Address comments
diff --git a/src/analysis/bits-bounds-lattice.h b/src/analysis/bits-bounds-lattice.h
index f2f334d..cebddfc 100644
--- a/src/analysis/bits-bounds-lattice.h
+++ b/src/analysis/bits-bounds-lattice.h
@@ -1,7 +1,24 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 #ifndef wasm_analysis_bits_bounds_lattice_h
 #define wasm_analysis_bits_bounds_lattice_h
 
 #include <optional>
+#include <variant>
 
 #include "lattice.h"
 #include "wasm.h"
@@ -9,67 +26,129 @@
 namespace wasm::analysis {
 
 struct MaxBitsLattice {
-  enum LatticeState { BOTTOM, REGULAR_VALUE, TOP };
+  enum LatticeState { BOTTOM, TOP };
 
-  struct Element {
-    Index upperBound = 0;
-    std::optional<Literal> constVal;
-    LatticeState state = BOTTOM;
+  class Element {
+    friend MaxBitsLattice; // compare() reads `value` directly.
+    // A LatticeState marks TOP/BOTTOM (BOTTOM by default), an Index is the
+    // maximum number of bits a value has, and a Literal is its known value.
+    std::variant<Index, Literal, LatticeState> value = LatticeState::BOTTOM;
 
-    bool isTop() const { return state == TOP; }
+  public:
+    bool isTop() const {
+      if (std::holds_alternative<LatticeState>(value)) {
+        return std::get<LatticeState>(value) == LatticeState::TOP;
+      }
+      return false;
+    }
 
-    bool isBottom() const { return state == BOTTOM; }
+    void setTop() { value = LatticeState::TOP; }
 
-    void setUpperBound(Index upperBound) {
-      if (state != TOP) {
-        state = REGULAR_VALUE;
-        this->upperBound = upperBound;
-        constVal.reset();
+    bool isBottom() const {
+      if (std::holds_alternative<LatticeState>(value)) {
+        return std::get<LatticeState>(value) == LatticeState::BOTTOM;
+      }
+      return false;
+    }
+
+    void setBottom() { value = LatticeState::BOTTOM; }
+
+    // Returns an optional which contains a literal (i.e. constant)
+    // value if it exists.
+    std::optional<Literal> getLiteral() const {
+      std::optional<Literal> result;
+      if (std::holds_alternative<Literal>(value)) {
+        result = std::get<Literal>(value);
+      }
+      return result;
+    }
+
+    // Returns the maximum number of bits. The optional is empty
+    // for the top and bottom elements, or if the literal happens
+    // not to be i32 or i64.
+    std::optional<Index> getUpperBound() const {
+      std::optional<Index> result;
+      if (std::holds_alternative<Index>(value)) {
+        result = std::get<Index>(value);
+      } else if (std::holds_alternative<Literal>(value)) {
+        Literal val = std::get<Literal>(value);
+        if (val.type == Type::i32) {
+          result = 32 - val.countLeadingZeroes().geti32();
+        } else if (val.type == Type::i64) {
+          result = 64 - val.countLeadingZeroes().geti64();
+        }
+      }
+      return result;
+    }
+
+    // Returns an upper bound approximated for i32s. If Top, returns 32. If
+    // bottom, or has no upper bound, returns 0.
+    Index geti32ApproxUpperBound() const {
+      std::optional<Index> result = getUpperBound();
+      if (result.has_value()) {
+        return result.value();
+      } else if (isTop()) {
+        return Index(32);
+      } else {
+        return Index(0);
       }
     }
 
-    void setUpperBound(Index upperBound, Literal constVal) {
-      if (state != TOP) {
-        state = REGULAR_VALUE;
-        this->upperBound = upperBound;
-        this->constVal = constVal;
+    // Returns an upper bound approximated for i64s. If Top, returns 64. If
+    // bottom, or has no upper bound, returns 0.
+    Index geti64ApproxUpperBound() const {
+      std::optional<Index> result = getUpperBound();
+      if (result.has_value()) {
+        return result.value();
+      } else if (isTop()) {
+        return Index(64);
+      } else {
+        return Index(0);
       }
     }
 
-    void setTop() {
-      state = TOP;
-      upperBound = UINT32_MAX;
-      constVal.reset();
+    void setUpperBound(Index upperBound) { value = upperBound; }
+
+    // If we don't have a desired literal type, we switch to
+    // bottom (i.e. the literal isn't supposed to have a max bits value).
+    void setLiteralValue(Literal val) {
+      if (val.type == Type::i32 || val.type == Type::i64) {
+        value = val;
+      } else {
+        value = LatticeState::BOTTOM;
+      }
     }
 
     bool makeLeastUpperBound(const Element& other) {
       if (other.isBottom() || isTop()) {
         return false;
       } else if (other.isTop()) {
-        setTop();
+        value = LatticeState::TOP;
         return true;
       } else if (isBottom()) {
-        if (other.constVal.has_value()) {
-          setUpperBound(other.upperBound, other.constVal.value());
+        if (std::holds_alternative<Literal>(other.value)) {
+          value = std::get<Literal>(other.value);
         } else {
-          setUpperBound(other.upperBound);
+          value = std::get<Index>(other.value);
         }
         return true;
       }
 
-      if (upperBound < other.upperBound) {
-        if (other.constVal.has_value()) {
-          setUpperBound(other.upperBound, other.constVal.value());
+      Index currMaxBits = getUpperBound().value();
+      Index otherMaxBits = other.getUpperBound().value();
+
+      if (currMaxBits < otherMaxBits) {
+        if (std::holds_alternative<Literal>(other.value)) {
+          value = std::get<Literal>(other.value);
         } else {
-          setUpperBound(other.upperBound);
+          value = otherMaxBits;
         }
         return true;
-      } else if (upperBound == other.upperBound && other.constVal.has_value()) {
-        if (constVal.has_value()) {
-          constVal.reset();
-        } else {
-          constVal = other.constVal.value();
-        }
+      } else if (currMaxBits == otherMaxBits &&
+                 std::holds_alternative<Literal>(value) &&
+                 (!std::holds_alternative<Literal>(other.value) ||
+                  std::get<Literal>(value) != std::get<Literal>(other.value))) {
+        value = currMaxBits;
         return true;
       }
 
@@ -77,15 +156,16 @@
     }
 
     void print(std::ostream& os) {
-      if (state == TOP) {
-        os << "TOP";
-      } else if (state == BOTTOM) {
-        os << "BOTTOM";
-      } else {
-        os << upperBound;
-        if (constVal.has_value()) {
-          os << " " << constVal.value();
+      if (LatticeState* val = std::get_if<LatticeState>(&value)) {
+        if (*val == LatticeState::BOTTOM) {
+          os << "BOTTOM";
+        } else {
+          os << "TOP";
         }
+      } else if (Index* val = std::get_if<Index>(&value)) {
+        os << *val;
+      } else {
+        os << "Literal " << std::get<Literal>(value);
       }
     }
 
@@ -93,32 +173,38 @@
   };
 
   LatticeComparison compare(const Element& left, const Element& right) {
-    if (left.isTop()) {
-      if (right.isTop()) {
-        return LatticeComparison::EQUAL;
+    if (std::holds_alternative<LatticeState>(left.value)) {
+      LatticeState leftVal = std::get<LatticeState>(left.value);
+      if (std::holds_alternative<LatticeState>(right.value)) {
+        LatticeState rightVal = std::get<LatticeState>(right.value);
+        if (leftVal < rightVal) {
+          return LatticeComparison::LESS;
+        } else if (leftVal > rightVal) {
+          return LatticeComparison::GREATER;
+        } else {
+          return LatticeComparison::EQUAL;
+        }
       } else {
-        return LatticeComparison::GREATER;
+        return leftVal == LatticeState::TOP ? LatticeComparison::GREATER
+                                            : LatticeComparison::LESS;
       }
-    } else if (right.isTop()) {
-      return LatticeComparison::LESS;
-    } else if (left.isBottom()) {
-      if (right.isBottom()) {
-        return LatticeComparison::EQUAL;
-      } else {
-        return LatticeComparison::LESS;
-      }
-    } else if (right.isBottom()) {
-      return LatticeComparison::GREATER;
+    } else if (std::holds_alternative<LatticeState>(right.value)) {
+      return std::get<LatticeState>(right.value) == LatticeState::TOP
+               ? LatticeComparison::LESS
+               : LatticeComparison::GREATER;
     }
 
-    if (left.upperBound > right.upperBound) {
-      return LatticeComparison::GREATER;
-    } else if (left.upperBound < right.upperBound) {
+    Index leftMaxBits = left.getUpperBound().value();
+    Index rightMaxBits = right.getUpperBound().value();
+
+    if (leftMaxBits < rightMaxBits) {
       return LatticeComparison::LESS;
+    } else if (leftMaxBits > rightMaxBits) {
+      return LatticeComparison::GREATER;
     } else {
-      if (left.constVal.has_value()) {
-        if (right.constVal.has_value()) {
-          if (left.constVal.value() == right.constVal.value()) {
+      if (std::holds_alternative<Literal>(left.value)) {
+        if (std::holds_alternative<Literal>(right.value)) {
+          if (std::get<Literal>(left.value) == std::get<Literal>(right.value)) {
             return LatticeComparison::EQUAL;
           } else {
             return LatticeComparison::NO_RELATION;
@@ -127,7 +213,7 @@
           return LatticeComparison::LESS;
         }
       } else {
-        if (right.constVal.has_value()) {
+        if (std::holds_alternative<Literal>(right.value)) {
           return LatticeComparison::GREATER;
         } else {
           return LatticeComparison::EQUAL;
diff --git a/src/analysis/bits-bounds-transfer-function.h b/src/analysis/bits-bounds-transfer-function.h
index c5e53cf..49a5d0d 100644
--- a/src/analysis/bits-bounds-transfer-function.h
+++ b/src/analysis/bits-bounds-transfer-function.h
@@ -1,3 +1,19 @@
+/*
+ * Copyright 2023 WebAssembly Community Group participants
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
 #ifndef wasm_analysis_bits_bounds_transfer_function_h
 #define wasm_analysis_bits_bounds_transfer_function_h
 
@@ -26,25 +42,22 @@
   void visitConst(Const* curr) {
     MaxBitsLattice::Element currElement = bitsLattice.getBottom();
 
-    bool addInformation = true;
     switch (curr->type.getBasic()) {
       case Type::i32:
-        currElement.setUpperBound(
-          32 - curr->value.countLeadingZeroes().geti32(), curr->value);
+        currElement.setLiteralValue(curr->value);
+        if (collectingResults) {
+          exprMaxBounds[curr] = currElement.getUpperBound().value();
+        }
         break;
       case Type::i64:
-        currElement.setUpperBound(
-          64 - curr->value.countLeadingZeroes().geti64(), curr->value);
+        currElement.setLiteralValue(curr->value);
+        if (collectingResults) {
+          exprMaxBounds[curr] = currElement.getUpperBound().value();
+        }
         break;
       default: {
-        addInformation = false;
       }
     }
-
-    if (collectingResults && addInformation) {
-      exprMaxBounds[curr] = currElement.upperBound;
-    }
-
     currState->push(std::move(currElement));
   }
 
@@ -59,22 +72,27 @@
       case RotLInt32:
       case RotRInt32:
       case SubInt32: {
+        // TODO: Use a more precise estimate for these cases.
         currElement.setUpperBound(32);
         break;
       }
       case AddInt32: {
         currElement.setUpperBound(
-          std::min(Index(32), std::max(left.upperBound, right.upperBound) + 1));
+          std::min(Index(32),
+                   std::max(left.geti32ApproxUpperBound(),
+                            right.geti32ApproxUpperBound()) +
+                     1));
         break;
       }
       case MulInt32: {
-        currElement.setUpperBound(
-          std::min(Index(32), left.upperBound + right.upperBound));
+        currElement.setUpperBound(std::min(Index(32),
+                                           left.geti32ApproxUpperBound() +
+                                             right.geti32ApproxUpperBound()));
         break;
       }
       case DivSInt32: {
-        int32_t maxBitsLeft = left.upperBound;
-        int32_t maxBitsRight = right.upperBound;
+        int32_t maxBitsLeft = left.geti32ApproxUpperBound();
+        int32_t maxBitsRight = right.getLiteral() ? *right.getUpperBound() : 32;
         if (maxBitsLeft == 32 || maxBitsRight == 32) {
           currElement.setUpperBound(32);
         } else {
@@ -84,19 +102,21 @@
         break;
       }
       case DivUInt32: {
-        int32_t maxBitsLeft = left.upperBound;
-        int32_t maxBitsRight = right.upperBound;
+        int32_t maxBitsLeft = left.geti32ApproxUpperBound();
+        int32_t maxBitsRight = right.getLiteral() ? *right.getUpperBound() : 1;
         currElement.setUpperBound(std::max(0, maxBitsLeft - maxBitsRight + 1));
         break;
       }
       case RemSInt32: {
-        if (right.constVal.has_value()) {
-          if (left.upperBound == 32) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        if (constRightValue.has_value()) {
+          Index leftUpperBound = left.geti32ApproxUpperBound();
+          if (leftUpperBound == 32) {
             currElement.setUpperBound(32);
           } else {
             auto bitsRight =
-              Index(wasm::Bits::ceilLog2(right.constVal.value().geti32()));
-            currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+              Index(wasm::Bits::ceilLog2(constRightValue.value().geti32()));
+            currElement.setUpperBound(std::min(leftUpperBound, bitsRight));
           }
         } else {
           currElement.setUpperBound(32);
@@ -104,56 +124,67 @@
         break;
       }
       case RemUInt32: {
-        if (right.constVal.has_value()) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        if (constRightValue.has_value()) {
           auto bitsRight =
-            Index(wasm::Bits::ceilLog2(right.constVal.value().geti32()));
-          currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+            Index(wasm::Bits::ceilLog2(constRightValue.value().geti32()));
+          currElement.setUpperBound(
+            std::min(left.geti32ApproxUpperBound(), bitsRight));
         } else {
           currElement.setUpperBound(32);
         }
         break;
       }
       case AndInt32: {
-        currElement.setUpperBound(std::min(left.upperBound, right.upperBound));
+        currElement.setUpperBound(std::min(left.geti32ApproxUpperBound(),
+                                           right.geti32ApproxUpperBound()));
         break;
       }
       case OrInt32:
       case XorInt32: {
-        currElement.setUpperBound(std::max(left.upperBound, right.upperBound));
+        currElement.setUpperBound(std::max(left.geti32ApproxUpperBound(),
+                                           right.geti32ApproxUpperBound()));
         break;
       }
       case ShlInt32: {
-        if (right.constVal.has_value()) {
-          currElement.setUpperBound(std::min(
-            Index(32),
-            left.upperBound + Bits::getEffectiveShifts(
-                                right.constVal.value().geti32(), Type::i32)));
+        std::optional<Literal> constRightValue = right.getLiteral();
+        currElement.setUpperBound(32); // Default for a non-constant shift.
+        if (constRightValue.has_value()) {
+          Index shifts = Bits::getEffectiveShifts(
+            constRightValue.value().geti32(), Type::i32);
+          currElement.setUpperBound(
+            std::min(Index(32), left.geti32ApproxUpperBound() + shifts));
         }
         break;
       }
       case ShrUInt32: {
-        if (right.constVal.has_value()) {
-          auto shifts = std::min(Index(Bits::getEffectiveShifts(
-                                   right.constVal.value().geti32(), Type::i32)),
-                                 left.upperBound);
+        std::optional<Literal> constRightValue = right.getLiteral();
+        if (constRightValue.has_value()) {
+          Index leftUpperBound = left.geti32ApproxUpperBound();
+          auto shifts =
+            std::min(Index(Bits::getEffectiveShifts(
+                       constRightValue.value().geti32(), Type::i32)),
+                     leftUpperBound);
           currElement.setUpperBound(
-            std::max(Index(0), left.upperBound - shifts));
+            std::max(Index(0), leftUpperBound - shifts));
         } else {
           currElement.setUpperBound(32);
         }
         break;
       }
       case ShrSInt32: {
-        if (right.constVal.has_value()) {
-          if (left.upperBound == 32) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        Index leftUpperBound = left.geti32ApproxUpperBound();
+        if (constRightValue.has_value()) {
+          if (leftUpperBound == 32) {
             currElement.setUpperBound(32);
           } else {
             auto shifts =
               std::min(Index(Bits::getEffectiveShifts(
-                         right.constVal.value().geti32(), Type::i32)),
-                       left.upperBound);
+                         constRightValue.value().geti32(), Type::i32)),
+                       leftUpperBound);
             currElement.setUpperBound(
-              std::max(Index(0), left.upperBound - shifts));
+              std::max(Index(0), leftUpperBound - shifts));
           }
         } else {
           currElement.setUpperBound(32);
@@ -168,17 +199,20 @@
       }
       case AddInt64: {
         currElement.setUpperBound(
-          std::min(Index(64), std::max(left.upperBound, right.upperBound)));
+          std::min(Index(64), // +1: a carry can add one more bit.
+                   std::max(left.geti64ApproxUpperBound(),
+                            right.geti64ApproxUpperBound()) + 1));
         break;
       }
       case MulInt64: {
-        currElement.setUpperBound(
-          std::min(Index(64), left.upperBound + right.upperBound));
+        currElement.setUpperBound(std::min(Index(64),
+                                           left.geti64ApproxUpperBound() +
+                                             right.geti64ApproxUpperBound()));
         break;
       }
       case DivSInt64: {
-        int32_t maxBitsLeft = left.upperBound;
-        int32_t maxBitsRight = right.upperBound;
+        int32_t maxBitsLeft = left.geti64ApproxUpperBound();
+        int32_t maxBitsRight = right.getLiteral() ? *right.getUpperBound() : 64;
         if (maxBitsLeft == 64 || maxBitsRight == 64) {
           currElement.setUpperBound(64);
         } else {
@@ -188,19 +222,21 @@
         break;
       }
       case DivUInt64: {
-        int32_t maxBitsLeft = left.upperBound;
-        int32_t maxBitsRight = right.upperBound;
+        int32_t maxBitsLeft = left.geti64ApproxUpperBound();
+        int32_t maxBitsRight = right.getLiteral() ? *right.getUpperBound() : 1;
         currElement.setUpperBound(std::max(0, maxBitsLeft - maxBitsRight + 1));
         break;
       }
       case RemSInt64: {
-        if (right.constVal.has_value()) {
-          if (left.upperBound == 64) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        Index leftUpperBound = left.geti64ApproxUpperBound();
+        if (constRightValue.has_value()) {
+          if (leftUpperBound == 64) {
             currElement.setUpperBound(64);
           } else {
             auto bitsRight =
-              Index(wasm::Bits::ceilLog2(right.constVal.value().geti64()));
-            currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+              Index(wasm::Bits::ceilLog2(constRightValue.value().geti64()));
+            currElement.setUpperBound(std::min(leftUpperBound, bitsRight));
           }
         } else {
           currElement.setUpperBound(64);
@@ -208,59 +244,69 @@
         break;
       }
       case RemUInt64: {
-        if (right.constVal.has_value()) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        if (constRightValue.has_value()) {
           auto bitsRight =
-            Index(wasm::Bits::ceilLog2(right.constVal.value().geti64()));
-          currElement.setUpperBound(std::min(left.upperBound, bitsRight));
+            Index(wasm::Bits::ceilLog2(constRightValue.value().geti64()));
+          currElement.setUpperBound(
+            std::min(left.geti64ApproxUpperBound(), bitsRight));
         } else {
           currElement.setUpperBound(64);
         }
         break;
       }
       case AndInt64: {
-        currElement.setUpperBound(std::min(left.upperBound, right.upperBound));
+        currElement.setUpperBound(std::min(left.geti64ApproxUpperBound(),
+                                           right.geti64ApproxUpperBound()));
         break;
       }
       case OrInt64:
       case XorInt64: {
-        currElement.setUpperBound(std::max(left.upperBound, right.upperBound));
+        currElement.setUpperBound(std::max(left.geti64ApproxUpperBound(),
+                                           right.geti64ApproxUpperBound()));
         break;
       }
       case ShlInt64: {
-        if (right.constVal.has_value()) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        if (constRightValue.has_value()) {
           currElement.setUpperBound(
             std::min(Index(64),
-                     Bits::getEffectiveShifts(right.constVal.value().geti64(),
+                     Bits::getEffectiveShifts(constRightValue.value().geti64(),
                                               Type::i64) +
-                       left.upperBound));
+                       left.geti64ApproxUpperBound()));
         } else {
           currElement.setUpperBound(64);
         }
         break;
       }
       case ShrUInt64: {
-        if (right.constVal.has_value()) {
-          auto shifts = std::min(Index(Bits::getEffectiveShifts(
-                                   right.constVal.value().geti64(), Type::i64)),
-                                 left.upperBound);
+        std::optional<Literal> constRightValue = right.getLiteral();
+        if (constRightValue.has_value()) {
+          Index leftUpperBound = left.geti64ApproxUpperBound();
+          auto shifts =
+            std::min(Index(Bits::getEffectiveShifts(
+                       constRightValue.value().geti64(), Type::i64)),
+                     leftUpperBound);
           currElement.setUpperBound(
-            std::max(Index(0), left.upperBound - shifts));
+            std::max(Index(0), leftUpperBound - shifts));
         } else {
           currElement.setUpperBound(64);
         }
         break;
       }
       case ShrSInt64: {
-        if (right.constVal.has_value()) {
-          if (left.upperBound == 64) {
+        std::optional<Literal> constRightValue = right.getLiteral();
+        Index leftUpperBound = left.geti64ApproxUpperBound();
+        if (constRightValue.has_value()) {
+          if (leftUpperBound == 64) {
             currElement.setUpperBound(64);
           } else {
             auto shifts =
               std::min(Index(Bits::getEffectiveShifts(
-                         right.constVal.value().geti64(), Type::i64)),
-                       left.upperBound);
+                         constRightValue.value().geti64(), Type::i64)),
+                       leftUpperBound);
             currElement.setUpperBound(
-              std::max(Index(0), left.upperBound - shifts));
+              std::max(Index(0), leftUpperBound - shifts));
           }
         } else {
           currElement.setUpperBound(64);
@@ -273,14 +319,14 @@
     }
 
     if (collectingResults && addInformation) {
-      exprMaxBounds[curr] = currElement.upperBound;
+      exprMaxBounds[curr] = currElement.getUpperBound().value();
     }
 
     currState->push(std::move(currElement));
   }
 
   void visitUnary(Unary* curr) {
-    MaxBitsLattice::Element unaryVal = currState->pop();
+    MaxBitsLattice::Element val = currState->pop();
     MaxBitsLattice::Element currElement = bitsLattice.getBottom();
 
     bool addInformation = true;
@@ -297,35 +343,43 @@
         currElement.setUpperBound(7);
         break;
       }
-      case WrapInt64:
+      case WrapInt64:
+        currElement.setUpperBound(
+          std::min(Index(32), val.geti64ApproxUpperBound())); // i32 result.
+        break;
       case ExtendUInt32: {
-        currElement.setUpperBound(unaryVal.upperBound);
+        currElement.setUpperBound(val.geti32ApproxUpperBound());
         break;
       }
       case ExtendS8Int32: {
-        currElement.setUpperBound(
-          unaryVal.upperBound >= 8 ? Index(32) : unaryVal.upperBound);
+        Index upperBound = val.geti32ApproxUpperBound();
+        currElement.setUpperBound(upperBound >= 8 ? Index(32) : upperBound);
         break;
       }
       case ExtendS16Int32: {
-        currElement.setUpperBound(
-          unaryVal.upperBound >= 16 ? Index(32) : unaryVal.upperBound);
+        Index upperBound = val.geti32ApproxUpperBound();
+        currElement.setUpperBound(upperBound >= 16 ? Index(32) : upperBound);
         break;
       }
       case ExtendS8Int64: {
-        currElement.setUpperBound(
-          unaryVal.upperBound >= 8 ? Index(64) : unaryVal.upperBound);
+        Index upperBound = val.geti64ApproxUpperBound();
+        currElement.setUpperBound(upperBound >= 8 ? Index(64) : upperBound);
         break;
       }
       case ExtendS16Int64: {
-        currElement.setUpperBound(
-          unaryVal.upperBound >= 16 ? Index(64) : unaryVal.upperBound);
+        Index upperBound = val.geti64ApproxUpperBound();
+        currElement.setUpperBound(upperBound >= 16 ? Index(64) : upperBound);
         break;
       }
-      case ExtendS32Int64:
+      case ExtendS32Int64: {
+        Index upperBound = val.geti64ApproxUpperBound();
+        currElement.setUpperBound(upperBound >= 32 ? Index(64) : upperBound);
+        break;
+      }
+      // Unlike ExtendS32Int64 (an i64 operand), this takes an i32 operand.
       case ExtendSInt32: {
-        currElement.setUpperBound(
-          unaryVal.upperBound >= 32 ? Index(64) : unaryVal.upperBound);
+        Index upperBound = val.geti32ApproxUpperBound();
+        currElement.setUpperBound(upperBound >= 32 ? Index(64) : upperBound);
         break;
       }
       default: {
@@ -334,7 +388,7 @@
     }
 
     if (collectingResults && addInformation) {
-      exprMaxBounds[curr] = currElement.upperBound;
+      exprMaxBounds[curr] = currElement.getUpperBound().value();
     }
 
     currState->push(std::move(currElement));
@@ -343,8 +397,11 @@
   void visitLocalSet(LocalSet* curr) {
     MaxBitsLattice::Element val = currState->pop();
 
-    if (collectingResults && !val.isTop()) {
-      exprMaxBounds[curr] = val.upperBound;
+    if (collectingResults && curr->isTee()) {
+      std::optional<Index> upperBound = val.getUpperBound();
+      if (upperBound.has_value()) {
+        exprMaxBounds[curr] = upperBound.value();
+      }
     }
   }
 
diff --git a/src/ir/bits.h b/src/ir/bits.h
index 25d80fb..ec7a1d3 100644
--- a/src/ir/bits.h
+++ b/src/ir/bits.h
@@ -163,6 +163,9 @@
             return 32;
           }
           int32_t bitsRight = getMaxBits(c);
+          // NOTE: bitsRight is exact because c is a constant, so a positive c
+          // is at least 2^(bitsRight - 1). Were bitsRight only an upper bound,
+          // c could be 1 and the quotient could need maxBitsLeft bits.
           return std::max(0, maxBitsLeft - bitsRight + 1);
         }
         return 32;
diff --git a/test/gtest/cfg.cpp b/test/gtest/cfg.cpp
index 1e85834..a5a699c 100644
--- a/test/gtest/cfg.cpp
+++ b/test/gtest/cfg.cpp
@@ -680,6 +680,7 @@
             LatticeComparison::EQUAL);
 }
 
+// TODO: Add more thorough test cases for max bits analysis.
 TEST_F(CFGTest, MaxBitsAnalysis) {
   auto moduleText = R"wasm(
     (module