Address comments
diff --git a/src/analysis/bits-bounds-lattice.h b/src/analysis/bits-bounds-lattice.h
index f2f334d..cebddfc 100644
--- a/src/analysis/bits-bounds-lattice.h
+++ b/src/analysis/bits-bounds-lattice.h
@@ -1,7 +1,24 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#ifndef wasm_analysis_bits_bounds_lattice_h
#define wasm_analysis_bits_bounds_lattice_h
#include <optional>
+#include <variant>
#include "lattice.h"
#include "wasm.h"
@@ -9,67 +26,129 @@
namespace wasm::analysis {
struct MaxBitsLattice {
- enum LatticeState { BOTTOM, REGULAR_VALUE, TOP };
+ enum LatticeState { BOTTOM, TOP };
- struct Element {
- Index upperBound = 0;
- std::optional<Literal> constVal;
- LatticeState state = BOTTOM;
+ class Element {
+ // If this holds a LatticeState, it indicates if the element is a top or
+ // bottom element. An Index will represent the maximum number of bits a
+ // value has. A Literal contains the actual value if we know what it is.
+ std::variant<Index, Literal, LatticeState> value;
- bool isTop() const { return state == TOP; }
+ public:
+ bool isTop() const {
+ if (std::holds_alternative<LatticeState>(value)) {
+ return std::get<LatticeState>(value) == LatticeState::TOP;
+ }
+ return false;
+ }
- bool isBottom() const { return state == BOTTOM; }
+ void setTop() { value = LatticeState::TOP; }
- void setUpperBound(Index upperBound) {
- if (state != TOP) {
- state = REGULAR_VALUE;
- this->upperBound = upperBound;
- constVal.reset();
+ bool isBottom() const {
+ if (std::holds_alternative<LatticeState>(value)) {
+ return std::get<LatticeState>(value) == LatticeState::BOTTOM;
+ }
+ return false;
+ }
+
+ void setBottom() { value = LatticeState::BOTTOM; }
+
+ // Returns an optional which contains a literal (i.e. constant)
+ // value if it exists.
+ std::optional<Literal> getLiteral() const {
+ std::optional<Literal> result;
+ if (std::holds_alternative<Literal>(value)) {
+ result = std::get<Literal>(value);
+ }
+ return result;
+ }
+
+ // Returns the maximum number of bits. The optional is empty
+ // for the top and bottom elements, or if the literal happens
+ // not to be i32 or i64.
+ std::optional<Index> getUpperBound() const {
+ std::optional<Index> result;
+ if (std::holds_alternative<Index>(value)) {
+ result = std::get<Index>(value);
+ } else if (std::holds_alternative<Literal>(value)) {
+ Literal val = std::get<Literal>(value);
+ if (val.type == Type::i32) {
+ result = 32 - val.countLeadingZeroes().geti32();
+ } else if (val.type == Type::i64) {
+ result = 64 - val.countLeadingZeroes().geti64();
+ }
+ }
+ return result;
+ }
+
+ // Returns an upper bound approximated for i32s. If Top, returns 32. If
+ // bottom, or has no upper bound, returns 0.
+ Index geti32ApproxUpperBound() const {
+ std::optional<Index> result = getUpperBound();
+ if (result.has_value()) {
+ return result.value();
+ } else if (isTop()) {
+ return Index(32);
+ } else {
+ return Index(0);
}
}
- void setUpperBound(Index upperBound, Literal constVal) {
- if (state != TOP) {
- state = REGULAR_VALUE;
- this->upperBound = upperBound;
- this->constVal = constVal;
+ // Returns an upper bound approximated for i64s. If Top, returns 64. If
+ // bottom, or has no upper bound, returns 0.
+ Index geti64ApproxUpperBound() const {
+ std::optional<Index> result = getUpperBound();
+ if (result.has_value()) {
+ return result.value();
+ } else if (isTop()) {
+ return Index(64);
+ } else {
+ return Index(0);
}
}
- void setTop() {
- state = TOP;
- upperBound = UINT32_MAX;
- constVal.reset();
+ void setUpperBound(Index upperBound) { value = upperBound; }
+
+ // Only i32 and i64 literals are stored. Any other literal type switches
+ // the element to bottom, since it has no meaningful max bits value.
+ void setLiteralValue(Literal val) {
+ if (val.type == Type::i32 || val.type == Type::i64) {
+ value = val;
+ } else {
+ value = LatticeState::BOTTOM;
+ }
}
bool makeLeastUpperBound(const Element& other) {
if (other.isBottom() || isTop()) {
return false;
} else if (other.isTop()) {
- setTop();
+ value = LatticeState::TOP;
return true;
} else if (isBottom()) {
- if (other.constVal.has_value()) {
- setUpperBound(other.upperBound, other.constVal.value());
+ if (std::holds_alternative<Literal>(other.value)) {
+ value = std::get<Literal>(other.value);
} else {
- setUpperBound(other.upperBound);
+ value = std::get<Index>(other.value);
}
return true;
}
- if (upperBound < other.upperBound) {
- if (other.constVal.has_value()) {
- setUpperBound(other.upperBound, other.constVal.value());
+ Index currMaxBits = getUpperBound().value();
+ Index otherMaxBits = other.getUpperBound().value();
+
+ if (currMaxBits < otherMaxBits) {
+ if (std::holds_alternative<Literal>(other.value)) {
+ value = std::get<Literal>(other.value);
} else {
- setUpperBound(other.upperBound);
+ value = otherMaxBits;
}
return true;
- } else if (upperBound == other.upperBound && other.constVal.has_value()) {
- if (constVal.has_value()) {
- constVal.reset();
- } else {
- constVal = other.constVal.value();
- }
+ } else if (currMaxBits == otherMaxBits &&
+ std::holds_alternative<Literal>(value) &&
+ (!std::holds_alternative<Literal>(other.value) ||
+ std::get<Literal>(value) != std::get<Literal>(other.value))) {
+ value = currMaxBits;
return true;
}
@@ -77,15 +156,16 @@
}
void print(std::ostream& os) {
- if (state == TOP) {
- os << "TOP";
- } else if (state == BOTTOM) {
- os << "BOTTOM";
- } else {
- os << upperBound;
- if (constVal.has_value()) {
- os << " " << constVal.value();
+ if (LatticeState* val = std::get_if<LatticeState>(&value)) {
+ if (*val == LatticeState::BOTTOM) {
+ os << "BOTTOM";
+ } else {
+ os << "TOP";
}
+ } else if (Index* val = std::get_if<Index>(&value)) {
+ os << *val;
+ } else {
+ os << "Literal " << std::get<Literal>(value);
}
}
@@ -93,32 +173,38 @@
};
LatticeComparison compare(const Element& left, const Element& right) {
- if (left.isTop()) {
- if (right.isTop()) {
- return LatticeComparison::EQUAL;
+ if (std::holds_alternative<LatticeState>(left.value)) {
+ LatticeState leftVal = std::get<LatticeState>(left.value);
+ if (std::holds_alternative<LatticeState>(right.value)) {
+ LatticeState rightVal = std::get<LatticeState>(right.value);
+ if (leftVal < rightVal) {
+ return LatticeComparison::LESS;
+ } else if (leftVal > rightVal) {
+ return LatticeComparison::GREATER;
+ } else {
+ return LatticeComparison::EQUAL;
+ }
} else {
- return LatticeComparison::GREATER;
+ return leftVal == LatticeState::TOP ? LatticeComparison::GREATER
+ : LatticeComparison::LESS;
}
- } else if (right.isTop()) {
- return LatticeComparison::LESS;
- } else if (left.isBottom()) {
- if (right.isBottom()) {
- return LatticeComparison::EQUAL;
- } else {
- return LatticeComparison::LESS;
- }
- } else if (right.isBottom()) {
- return LatticeComparison::GREATER;
+ } else if (std::holds_alternative<LatticeState>(right.value)) {
+ return std::get<LatticeState>(right.value) == LatticeState::TOP
+ ? LatticeComparison::LESS
+ : LatticeComparison::GREATER;
}
- if (left.upperBound > right.upperBound) {
- return LatticeComparison::GREATER;
- } else if (left.upperBound < right.upperBound) {
+ Index leftMaxBits = left.getUpperBound().value();
+ Index rightMaxBits = right.getUpperBound().value();
+
+ if (leftMaxBits < rightMaxBits) {
return LatticeComparison::LESS;
+ } else if (leftMaxBits > rightMaxBits) {
+ return LatticeComparison::GREATER;
} else {
- if (left.constVal.has_value()) {
- if (right.constVal.has_value()) {
- if (left.constVal.value() == right.constVal.value()) {
+ if (std::holds_alternative<Literal>(left.value)) {
+ if (std::holds_alternative<Literal>(right.value)) {
+ if (std::get<Literal>(left.value) == std::get<Literal>(right.value)) {
return LatticeComparison::EQUAL;
} else {
return LatticeComparison::NO_RELATION;
@@ -127,7 +213,7 @@
return LatticeComparison::LESS;
}
} else {
- if (right.constVal.has_value()) {
+ if (std::holds_alternative<Literal>(right.value)) {
return LatticeComparison::GREATER;
} else {
return LatticeComparison::EQUAL;
diff --git a/src/analysis/bits-bounds-transfer-function.h b/src/analysis/bits-bounds-transfer-function.h
index c5e53cf..49a5d0d 100644
--- a/src/analysis/bits-bounds-transfer-function.h
+++ b/src/analysis/bits-bounds-transfer-function.h
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
#ifndef wasm_analysis_bits_bounds_transfer_function_h
#define wasm_analysis_bits_bounds_transfer_function_h
@@ -26,25 +42,22 @@
void visitConst(Const* curr) {
MaxBitsLattice::Element currElement = bitsLattice.getBottom();
- bool addInformation = true;
switch (curr->type.getBasic()) {
case Type::i32:
- currElement.setUpperBound(
- 32 - curr->value.countLeadingZeroes().geti32(), curr->value);
+ currElement.setLiteralValue(curr->value);
+ if (collectingResults) {
+ exprMaxBounds[curr] = currElement.getUpperBound().value();
+ }
break;
case Type::i64:
- currElement.setUpperBound(
- 64 - curr->value.countLeadingZeroes().geti64(), curr->value);
+ currElement.setLiteralValue(curr->value);
+ if (collectingResults) {
+ exprMaxBounds[curr] = currElement.getUpperBound().value();
+ }
break;
default: {
- addInformation = false;
}
}
-
- if (collectingResults && addInformation) {
- exprMaxBounds[curr] = currElement.upperBound;
- }
-
currState->push(std::move(currElement));
}
@@ -59,22 +72,27 @@
case RotLInt32:
case RotRInt32:
case SubInt32: {
+ // TODO: Use a more precise estimate for these cases.
currElement.setUpperBound(32);
break;
}
case AddInt32: {
currElement.setUpperBound(
- std::min(Index(32), std::max(left.upperBound, right.upperBound) + 1));
+ std::min(Index(32),
+ std::max(left.geti32ApproxUpperBound(),
+ right.geti32ApproxUpperBound()) +
+ 1));
break;
}
case MulInt32: {
- currElement.setUpperBound(
- std::min(Index(32), left.upperBound + right.upperBound));
+ currElement.setUpperBound(std::min(Index(32),
+ left.geti32ApproxUpperBound() +
+ right.geti32ApproxUpperBound()));
break;
}
case DivSInt32: {
- int32_t maxBitsLeft = left.upperBound;
- int32_t maxBitsRight = right.upperBound;
+ int32_t maxBitsLeft = left.geti32ApproxUpperBound();
+ int32_t maxBitsRight = right.geti32ApproxUpperBound();
if (maxBitsLeft == 32 || maxBitsRight == 32) {
currElement.setUpperBound(32);
} else {
@@ -84,19 +102,21 @@
break;
}
case DivUInt32: {
- int32_t maxBitsLeft = left.upperBound;
- int32_t maxBitsRight = right.upperBound;
+ int32_t maxBitsLeft = left.geti32ApproxUpperBound();
+ int32_t maxBitsRight = right.geti32ApproxUpperBound();
currElement.setUpperBound(std::max(0, maxBitsLeft - maxBitsRight + 1));
break;
}
case RemSInt32: {
- if (right.constVal.has_value()) {
- if (left.upperBound == 32) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
+ Index leftUpperBound = left.geti32ApproxUpperBound();
+ if (leftUpperBound == 32) {
currElement.setUpperBound(32);
} else {
auto bitsRight =
- Index(wasm::Bits::ceilLog2(right.constVal.value().geti32()));
- currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+ Index(wasm::Bits::ceilLog2(constRightValue.value().geti32()));
+ currElement.setUpperBound(std::min(leftUpperBound, bitsRight));
}
} else {
currElement.setUpperBound(32);
@@ -104,56 +124,67 @@
break;
}
case RemUInt32: {
- if (right.constVal.has_value()) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
auto bitsRight =
- Index(wasm::Bits::ceilLog2(right.constVal.value().geti32()));
- currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+ Index(wasm::Bits::ceilLog2(constRightValue.value().geti32()));
+ currElement.setUpperBound(
+ std::min(left.geti32ApproxUpperBound(), bitsRight));
} else {
currElement.setUpperBound(32);
}
break;
}
case AndInt32: {
- currElement.setUpperBound(std::min(left.upperBound, right.upperBound));
+ currElement.setUpperBound(std::min(left.geti32ApproxUpperBound(),
+ right.geti32ApproxUpperBound()));
break;
}
case OrInt32:
case XorInt32: {
- currElement.setUpperBound(std::max(left.upperBound, right.upperBound));
+ currElement.setUpperBound(std::max(left.geti32ApproxUpperBound(),
+ right.geti32ApproxUpperBound()));
break;
}
case ShlInt32: {
- if (right.constVal.has_value()) {
- currElement.setUpperBound(std::min(
- Index(32),
- left.upperBound + Bits::getEffectiveShifts(
- right.constVal.value().geti32(), Type::i32)));
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
+ currElement.setUpperBound(
+ std::min(Index(32),
+ left.geti32ApproxUpperBound() +
+ Bits::getEffectiveShifts(
+ constRightValue.value().geti32(), Type::i32)));
}
break;
}
case ShrUInt32: {
- if (right.constVal.has_value()) {
- auto shifts = std::min(Index(Bits::getEffectiveShifts(
- right.constVal.value().geti32(), Type::i32)),
- left.upperBound);
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
+ Index leftUpperBound = left.geti32ApproxUpperBound();
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(
+ constRightValue.value().geti32(), Type::i32)),
+ leftUpperBound);
currElement.setUpperBound(
- std::max(Index(0), left.upperBound - shifts));
+ std::max(Index(0), leftUpperBound - shifts));
} else {
currElement.setUpperBound(32);
}
break;
}
case ShrSInt32: {
- if (right.constVal.has_value()) {
- if (left.upperBound == 32) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ Index leftUpperBound = left.geti32ApproxUpperBound();
+ if (constRightValue.has_value()) {
+ if (leftUpperBound == 32) {
currElement.setUpperBound(32);
} else {
auto shifts =
std::min(Index(Bits::getEffectiveShifts(
- right.constVal.value().geti32(), Type::i32)),
- left.upperBound);
+ constRightValue.value().geti32(), Type::i32)),
+ leftUpperBound);
currElement.setUpperBound(
- std::max(Index(0), left.upperBound - shifts));
+ std::max(Index(0), leftUpperBound - shifts));
}
} else {
currElement.setUpperBound(32);
@@ -168,17 +199,20 @@
}
case AddInt64: {
currElement.setUpperBound(
- std::min(Index(64), std::max(left.upperBound, right.upperBound)));
+ std::min(Index(64),
+ std::max(left.geti64ApproxUpperBound(),
+ right.geti64ApproxUpperBound())));
break;
}
case MulInt64: {
- currElement.setUpperBound(
- std::min(Index(64), left.upperBound + right.upperBound));
+ currElement.setUpperBound(std::min(Index(64),
+ left.geti64ApproxUpperBound() +
+ right.geti64ApproxUpperBound()));
break;
}
case DivSInt64: {
- int32_t maxBitsLeft = left.upperBound;
- int32_t maxBitsRight = right.upperBound;
+ int32_t maxBitsLeft = left.geti64ApproxUpperBound();
+ int32_t maxBitsRight = right.geti64ApproxUpperBound();
if (maxBitsLeft == 64 || maxBitsRight == 64) {
currElement.setUpperBound(64);
} else {
@@ -188,19 +222,21 @@
break;
}
case DivUInt64: {
- int32_t maxBitsLeft = left.upperBound;
- int32_t maxBitsRight = right.upperBound;
+ int32_t maxBitsLeft = left.geti64ApproxUpperBound();
+ int32_t maxBitsRight = right.geti64ApproxUpperBound();
currElement.setUpperBound(std::max(0, maxBitsLeft - maxBitsRight + 1));
break;
}
case RemSInt64: {
- if (right.constVal.has_value()) {
- if (left.upperBound == 64) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ Index leftUpperBound = left.geti64ApproxUpperBound();
+ if (constRightValue.has_value()) {
+ if (leftUpperBound == 64) {
currElement.setUpperBound(64);
} else {
auto bitsRight =
- Index(wasm::Bits::ceilLog2(right.constVal.value().geti64()));
- currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+ Index(wasm::Bits::ceilLog2(constRightValue.value().geti64()));
+ currElement.setUpperBound(std::min(leftUpperBound, bitsRight));
}
} else {
currElement.setUpperBound(64);
@@ -208,59 +244,69 @@
break;
}
case RemUInt64: {
- if (right.constVal.has_value()) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
auto bitsRight =
- Index(wasm::Bits::ceilLog2(right.constVal.value().geti64()));
- currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+ Index(wasm::Bits::ceilLog2(constRightValue.value().geti64()));
+ currElement.setUpperBound(
+ std::min(left.geti64ApproxUpperBound(), bitsRight));
} else {
currElement.setUpperBound(64);
}
break;
}
case AndInt64: {
- currElement.setUpperBound(std::min(left.upperBound, right.upperBound));
+ currElement.setUpperBound(std::min(left.geti64ApproxUpperBound(),
+ right.geti64ApproxUpperBound()));
break;
}
case OrInt64:
case XorInt64: {
- currElement.setUpperBound(std::max(left.upperBound, right.upperBound));
+ currElement.setUpperBound(std::max(left.geti64ApproxUpperBound(),
+ right.geti64ApproxUpperBound()));
break;
}
case ShlInt64: {
- if (right.constVal.has_value()) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
currElement.setUpperBound(
std::min(Index(64),
- Bits::getEffectiveShifts(right.constVal.value().geti64(),
+ Bits::getEffectiveShifts(constRightValue.value().geti64(),
Type::i64) +
- left.upperBound));
+ left.geti64ApproxUpperBound()));
} else {
currElement.setUpperBound(64);
}
break;
}
case ShrUInt64: {
- if (right.constVal.has_value()) {
- auto shifts = std::min(Index(Bits::getEffectiveShifts(
- right.constVal.value().geti64(), Type::i64)),
- left.upperBound);
+ std::optional<Literal> constRightValue = right.getLiteral();
+ if (constRightValue.has_value()) {
+ Index leftUpperBound = left.geti64ApproxUpperBound();
+ auto shifts =
+ std::min(Index(Bits::getEffectiveShifts(
+ constRightValue.value().geti64(), Type::i64)),
+ leftUpperBound);
currElement.setUpperBound(
- std::max(Index(0), left.upperBound - shifts));
+ std::max(Index(0), leftUpperBound - shifts));
} else {
currElement.setUpperBound(64);
}
break;
}
case ShrSInt64: {
- if (right.constVal.has_value()) {
- if (left.upperBound == 64) {
+ std::optional<Literal> constRightValue = right.getLiteral();
+ Index leftUpperBound = left.geti64ApproxUpperBound();
+ if (constRightValue.has_value()) {
+ if (leftUpperBound == 64) {
currElement.setUpperBound(64);
} else {
auto shifts =
std::min(Index(Bits::getEffectiveShifts(
- right.constVal.value().geti64(), Type::i64)),
- left.upperBound);
+ constRightValue.value().geti64(), Type::i64)),
+ leftUpperBound);
currElement.setUpperBound(
- std::max(Index(0), left.upperBound - shifts));
+ std::max(Index(0), leftUpperBound - shifts));
}
} else {
currElement.setUpperBound(64);
@@ -273,14 +319,14 @@
}
if (collectingResults && addInformation) {
- exprMaxBounds[curr] = currElement.upperBound;
+ exprMaxBounds[curr] = currElement.getUpperBound().value();
}
currState->push(std::move(currElement));
}
void visitUnary(Unary* curr) {
- MaxBitsLattice::Element unaryVal = currState->pop();
+ MaxBitsLattice::Element val = currState->pop();
MaxBitsLattice::Element currElement = bitsLattice.getBottom();
bool addInformation = true;
@@ -297,35 +343,43 @@
currElement.setUpperBound(7);
break;
}
- case WrapInt64:
+ case WrapInt64: {
+ currElement.setUpperBound(val.geti64ApproxUpperBound());
+ break;
+ }
case ExtendUInt32: {
- currElement.setUpperBound(unaryVal.upperBound);
+ currElement.setUpperBound(val.geti32ApproxUpperBound());
break;
}
case ExtendS8Int32: {
- currElement.setUpperBound(
- unaryVal.upperBound >= 8 ? Index(32) : unaryVal.upperBound);
+ Index upperBound = val.geti32ApproxUpperBound();
+ currElement.setUpperBound(upperBound >= 8 ? Index(32) : upperBound);
break;
}
case ExtendS16Int32: {
- currElement.setUpperBound(
- unaryVal.upperBound >= 16 ? Index(32) : unaryVal.upperBound);
+ Index upperBound = val.geti32ApproxUpperBound();
+ currElement.setUpperBound(upperBound >= 16 ? Index(32) : upperBound);
break;
}
case ExtendS8Int64: {
- currElement.setUpperBound(
- unaryVal.upperBound >= 8 ? Index(64) : unaryVal.upperBound);
+ Index upperBound = val.geti64ApproxUpperBound();
+ currElement.setUpperBound(upperBound >= 8 ? Index(64) : upperBound);
break;
}
case ExtendS16Int64: {
- currElement.setUpperBound(
- unaryVal.upperBound >= 16 ? Index(64) : unaryVal.upperBound);
+ Index upperBound = val.geti64ApproxUpperBound();
+ currElement.setUpperBound(upperBound >= 16 ? Index(64) : upperBound);
break;
}
- case ExtendS32Int64:
+ case ExtendS32Int64: {
+ Index upperBound = val.geti64ApproxUpperBound();
+ currElement.setUpperBound(upperBound >= 32 ? Index(64) : upperBound);
+ break;
+ }
+ // TODO: How does this case differ from ExtendS32Int64 above?
case ExtendSInt32: {
- currElement.setUpperBound(
- unaryVal.upperBound >= 32 ? Index(64) : unaryVal.upperBound);
+ Index upperBound = val.geti32ApproxUpperBound();
+ currElement.setUpperBound(upperBound >= 32 ? Index(64) : upperBound);
break;
}
default: {
@@ -334,7 +388,7 @@
}
if (collectingResults && addInformation) {
- exprMaxBounds[curr] = currElement.upperBound;
+ exprMaxBounds[curr] = currElement.getUpperBound().value();
}
currState->push(std::move(currElement));
@@ -343,8 +397,11 @@
void visitLocalSet(LocalSet* curr) {
MaxBitsLattice::Element val = currState->pop();
- if (collectingResults && !val.isTop()) {
- exprMaxBounds[curr] = val.upperBound;
+ if (collectingResults && curr->isTee()) {
+ std::optional<Index> upperBound = val.getUpperBound();
+ if (upperBound.has_value()) {
+ exprMaxBounds[curr] = upperBound.value();
+ }
}
}
diff --git a/src/ir/bits.h b/src/ir/bits.h
index 25d80fb..ec7a1d3 100644
--- a/src/ir/bits.h
+++ b/src/ir/bits.h
@@ -163,6 +163,9 @@
return 32;
}
int32_t bitsRight = getMaxBits(c);
+ // TODO: bitsRight is only an estimate. Even when it equals maxBitsLeft,
+ // the right operand could actually be 1, in which case the division
+ // result needs maxBitsLeft bits.
return std::max(0, maxBitsLeft - bitsRight + 1);
}
return 32;
diff --git a/test/gtest/cfg.cpp b/test/gtest/cfg.cpp
index 1e85834..a5a699c 100644
--- a/test/gtest/cfg.cpp
+++ b/test/gtest/cfg.cpp
@@ -680,6 +680,7 @@
LatticeComparison::EQUAL);
}
+// TODO: Add more thorough test cases for max bits analysis.
TEST_F(CFGTest, MaxBitsAnalysis) {
auto moduleText = R"wasm(
(module