// Source blob: 49a5d0d0c3ed4be7fd339216c760cfb73fecd4d9
/*
* Copyright 2023 WebAssembly Community Group participants
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef wasm_analysis_bits_bounds_transfer_function_h
#define wasm_analysis_bits_bounds_transfer_function_h
#include <algorithm>
#include <cstdint>
#include <optional>
#include <unordered_map>
#include "bits-bounds-lattice.h"
#include "ir/bits.h"
#include "ir/stack-utils.h"
#include "stack-lattice.h"
#include "visitor-transfer-function.h"
namespace wasm::analysis {
// Transfer function for a forward analysis that tracks, for every value on
// the modelled value stack, an upper bound on the number of bits the value
// may occupy.
//
// Each visit method pops the lattice elements of the expression's operands
// from the current stack state (`currState`) and pushes an element for the
// expression's result. While `collectingResults` is set, the bound computed
// for an expression is also recorded in `exprMaxBounds`.
class MaxBitsTransferFunction
  : public VisitorTransferFunc<MaxBitsTransferFunction,
                               StackLattice<MaxBitsLattice>,
                               AnalysisDirection::Forward> {
  MaxBitsLattice& bitsLattice;

  // Number of bits needed to write `value` as an unsigned integer (0 for 0).
  static Index bitWidth(uint64_t value) {
    Index bits = 0;
    for (; value != 0; value >>= 1) {
      ++bits;
    }
    return bits;
  }

  // A sum may need one bit more than its wider operand (the carry), but never
  // more than the type's `width`.
  static Index addBound(Index leftBits, Index rightBits, Index width) {
    return std::min(width, std::max(leftBits, rightBits) + 1);
  }

  // A product needs at most the sum of the operand widths, capped at `width`.
  static Index mulBound(Index leftBits, Index rightBits, Index width) {
    return std::min(width, leftBits + rightBits);
  }

  // Quotient bound for a non-negative dividend below 2^dividendBits divided
  // by a constant that needs exactly `divisorBits` bits, i.e. that is at
  // least 2^(divisorBits - 1): the quotient is below
  // 2^(dividendBits - divisorBits + 1). A zero divisor traps at runtime, so
  // any bound is acceptable for it; it is merely capped at `width`.
  static Index
  divByConstBound(Index dividendBits, Index divisorBits, Index width) {
    if (divisorBits > dividendBits) {
      return 0;
    }
    return std::min(width, dividendBits + 1 - divisorBits);
  }

  // Result bound for an unsigned division. `divisor` holds the divisor's
  // value (zero-extended) when it is a known constant. The upper bound of a
  // non-constant divisor says nothing about how small the divisor may be, so
  // in that case only "quotient <= dividend" can be used.
  static Index divUBound(Index dividendBits,
                         std::optional<uint64_t> divisor,
                         Index width) {
    if (divisor.has_value()) {
      return divByConstBound(dividendBits, bitWidth(*divisor), width);
    }
    return dividendBits;
  }

  // Result bound for a signed division. `divisor` holds the divisor's value
  // (sign-extended) when it is a known constant.
  static Index divSBound(Index dividendBits,
                         Index divisorBits,
                         std::optional<int64_t> divisor,
                         Index width) {
    // An operand that may use all `width` bits may be negative, in which case
    // the quotient may be negative as well and needs the full width.
    if (dividendBits == width || divisorBits == width) {
      return width;
    }
    if (divisor.has_value()) {
      if (*divisor < 0) {
        return width;
      }
      return divByConstBound(dividendBits, bitWidth(uint64_t(*divisor)), width);
    }
    // Both operands are non-negative: the quotient cannot exceed the dividend.
    return dividendBits;
  }

  // Bound after shifting a non-negative value right by `shifts` bits.
  // Shifting by more bits than the value has simply leaves zero bits.
  static Index shrBound(Index valueBits, Index shifts) {
    return valueBits - std::min(shifts, valueBits);
  }

public:
  // Bounds recorded per expression while `collectingResults` is set.
  std::unordered_map<Expression*, Index> exprMaxBounds;

  MaxBitsTransferFunction(MaxBitsLattice& bitsLattice)
    : bitsLattice(bitsLattice) {}

  // Pushes an element for a constant. Integer constants carry their literal
  // value; any other constant type pushes the unmodified bottom element and
  // records nothing.
  void visitConst(Const* curr) {
    MaxBitsLattice::Element currElement = bitsLattice.getBottom();

    switch (curr->type.getBasic()) {
      case Type::i32:
      case Type::i64:
        currElement.setLiteralValue(curr->value);
        if (collectingResults) {
          exprMaxBounds[curr] = currElement.getUpperBound().value();
        }
        break;
      default:
        break;
    }

    currState->push(std::move(currElement));
  }

  // Pops the right and then the left operand and pushes the bound of the
  // result. Operations without a rule push the bottom element and record
  // nothing.
  void visitBinary(Binary* curr) {
    MaxBitsLattice::Element right = currState->pop();
    MaxBitsLattice::Element left = currState->pop();
    MaxBitsLattice::Element currElement = bitsLattice.getBottom();
    bool addInformation = true;

    switch (curr->op) {
      // ---- 32-bit operations ----
      case RotLInt32:
      case RotRInt32:
      case SubInt32: {
        // TODO: Use a more precise estimate for these cases.
        currElement.setUpperBound(32);
        break;
      }
      case AddInt32: {
        currElement.setUpperBound(addBound(
          left.geti32ApproxUpperBound(), right.geti32ApproxUpperBound(), 32));
        break;
      }
      case MulInt32: {
        currElement.setUpperBound(mulBound(
          left.geti32ApproxUpperBound(), right.geti32ApproxUpperBound(), 32));
        break;
      }
      case DivSInt32: {
        std::optional<Literal> divisor = right.getLiteral();
        std::optional<int64_t> divisorValue;
        if (divisor.has_value()) {
          divisorValue = int64_t(divisor->geti32());
        }
        currElement.setUpperBound(divSBound(left.geti32ApproxUpperBound(),
                                            right.geti32ApproxUpperBound(),
                                            divisorValue,
                                            32));
        break;
      }
      case DivUInt32: {
        std::optional<Literal> divisor = right.getLiteral();
        std::optional<uint64_t> divisorValue;
        if (divisor.has_value()) {
          divisorValue = uint64_t(uint32_t(divisor->geti32()));
        }
        currElement.setUpperBound(
          divUBound(left.geti32ApproxUpperBound(), divisorValue, 32));
        break;
      }
      case RemSInt32: {
        std::optional<Literal> divisor = right.getLiteral();
        Index leftBits = left.geti32ApproxUpperBound();
        // If the dividend may be negative, so may the remainder.
        if (divisor.has_value() && leftBits != 32) {
          auto bitsRight = Index(wasm::Bits::ceilLog2(divisor->geti32()));
          currElement.setUpperBound(std::min(leftBits, bitsRight));
        } else {
          currElement.setUpperBound(32);
        }
        break;
      }
      case RemUInt32: {
        std::optional<Literal> divisor = right.getLiteral();
        if (divisor.has_value()) {
          auto bitsRight = Index(wasm::Bits::ceilLog2(divisor->geti32()));
          currElement.setUpperBound(
            std::min(left.geti32ApproxUpperBound(), bitsRight));
        } else {
          currElement.setUpperBound(32);
        }
        break;
      }
      case AndInt32: {
        currElement.setUpperBound(std::min(left.geti32ApproxUpperBound(),
                                           right.geti32ApproxUpperBound()));
        break;
      }
      case OrInt32:
      case XorInt32: {
        currElement.setUpperBound(std::max(left.geti32ApproxUpperBound(),
                                           right.geti32ApproxUpperBound()));
        break;
      }
      case ShlInt32: {
        std::optional<Literal> shiftAmount = right.getLiteral();
        if (shiftAmount.has_value()) {
          currElement.setUpperBound(std::min(
            Index(32),
            left.geti32ApproxUpperBound() +
              Bits::getEffectiveShifts(shiftAmount->geti32(), Type::i32)));
        } else {
          // Unknown shift amount: any bit of the result may end up set. A
          // bound must always be set here because it is read below.
          currElement.setUpperBound(32);
        }
        break;
      }
      case ShrUInt32: {
        std::optional<Literal> shiftAmount = right.getLiteral();
        if (shiftAmount.has_value()) {
          currElement.setUpperBound(shrBound(
            left.geti32ApproxUpperBound(),
            Index(Bits::getEffectiveShifts(shiftAmount->geti32(), Type::i32))));
        } else {
          currElement.setUpperBound(32);
        }
        break;
      }
      case ShrSInt32: {
        std::optional<Literal> shiftAmount = right.getLiteral();
        Index leftBits = left.geti32ApproxUpperBound();
        // A possibly-negative value keeps its sign bit under an arithmetic
        // shift, so only non-negative inputs get a tighter bound.
        if (shiftAmount.has_value() && leftBits != 32) {
          currElement.setUpperBound(shrBound(
            leftBits,
            Index(Bits::getEffectiveShifts(shiftAmount->geti32(), Type::i32))));
        } else {
          currElement.setUpperBound(32);
        }
        break;
      }

      // ---- 64-bit operations ----
      case RotLInt64:
      case RotRInt64:
      case SubInt64: {
        currElement.setUpperBound(64);
        break;
      }
      case AddInt64: {
        currElement.setUpperBound(addBound(
          left.geti64ApproxUpperBound(), right.geti64ApproxUpperBound(), 64));
        break;
      }
      case MulInt64: {
        currElement.setUpperBound(mulBound(
          left.geti64ApproxUpperBound(), right.geti64ApproxUpperBound(), 64));
        break;
      }
      case DivSInt64: {
        std::optional<Literal> divisor = right.getLiteral();
        std::optional<int64_t> divisorValue;
        if (divisor.has_value()) {
          divisorValue = int64_t(divisor->geti64());
        }
        currElement.setUpperBound(divSBound(left.geti64ApproxUpperBound(),
                                            right.geti64ApproxUpperBound(),
                                            divisorValue,
                                            64));
        break;
      }
      case DivUInt64: {
        std::optional<Literal> divisor = right.getLiteral();
        std::optional<uint64_t> divisorValue;
        if (divisor.has_value()) {
          divisorValue = uint64_t(divisor->geti64());
        }
        currElement.setUpperBound(
          divUBound(left.geti64ApproxUpperBound(), divisorValue, 64));
        break;
      }
      case RemSInt64: {
        std::optional<Literal> divisor = right.getLiteral();
        Index leftBits = left.geti64ApproxUpperBound();
        // If the dividend may be negative, so may the remainder.
        if (divisor.has_value() && leftBits != 64) {
          auto bitsRight = Index(wasm::Bits::ceilLog2(divisor->geti64()));
          currElement.setUpperBound(std::min(leftBits, bitsRight));
        } else {
          currElement.setUpperBound(64);
        }
        break;
      }
      case RemUInt64: {
        std::optional<Literal> divisor = right.getLiteral();
        if (divisor.has_value()) {
          auto bitsRight = Index(wasm::Bits::ceilLog2(divisor->geti64()));
          currElement.setUpperBound(
            std::min(left.geti64ApproxUpperBound(), bitsRight));
        } else {
          currElement.setUpperBound(64);
        }
        break;
      }
      case AndInt64: {
        currElement.setUpperBound(std::min(left.geti64ApproxUpperBound(),
                                           right.geti64ApproxUpperBound()));
        break;
      }
      case OrInt64:
      case XorInt64: {
        currElement.setUpperBound(std::max(left.geti64ApproxUpperBound(),
                                           right.geti64ApproxUpperBound()));
        break;
      }
      case ShlInt64: {
        std::optional<Literal> shiftAmount = right.getLiteral();
        if (shiftAmount.has_value()) {
          currElement.setUpperBound(std::min(
            Index(64),
            Bits::getEffectiveShifts(shiftAmount->geti64(), Type::i64) +
              left.geti64ApproxUpperBound()));
        } else {
          currElement.setUpperBound(64);
        }
        break;
      }
      case ShrUInt64: {
        std::optional<Literal> shiftAmount = right.getLiteral();
        if (shiftAmount.has_value()) {
          currElement.setUpperBound(shrBound(
            left.geti64ApproxUpperBound(),
            Index(Bits::getEffectiveShifts(shiftAmount->geti64(), Type::i64))));
        } else {
          currElement.setUpperBound(64);
        }
        break;
      }
      case ShrSInt64: {
        std::optional<Literal> shiftAmount = right.getLiteral();
        Index leftBits = left.geti64ApproxUpperBound();
        // See ShrSInt32: only non-negative inputs get a tighter bound.
        if (shiftAmount.has_value() && leftBits != 64) {
          currElement.setUpperBound(shrBound(
            leftBits,
            Index(Bits::getEffectiveShifts(shiftAmount->geti64(), Type::i64))));
        } else {
          currElement.setUpperBound(64);
        }
        break;
      }
      default: {
        addInformation = false;
      }
    }

    // Every case above that leaves addInformation set has set an upper bound.
    if (collectingResults && addInformation) {
      exprMaxBounds[curr] = currElement.getUpperBound().value();
    }
    currState->push(std::move(currElement));
  }

  // Pops the operand and pushes the bound of the result. Operations without a
  // rule push the bottom element and record nothing.
  void visitUnary(Unary* curr) {
    MaxBitsLattice::Element val = currState->pop();
    MaxBitsLattice::Element currElement = bitsLattice.getBottom();
    bool addInformation = true;

    switch (curr->op) {
      case ClzInt32:
      case CtzInt32:
      case PopcntInt32: {
        // The result is at most 32, which fits in 6 bits.
        currElement.setUpperBound(6);
        break;
      }
      case ClzInt64:
      case CtzInt64:
      case PopcntInt64: {
        // The result is at most 64, which fits in 7 bits.
        currElement.setUpperBound(7);
        break;
      }
      case WrapInt64: {
        // The result is an i32, so it can never need more than 32 bits even
        // if the i64 operand does.
        currElement.setUpperBound(
          std::min(Index(32), val.geti64ApproxUpperBound()));
        break;
      }
      case ExtendUInt32: {
        // Zero extension adds no significant bits.
        currElement.setUpperBound(val.geti32ApproxUpperBound());
        break;
      }
      // For the sign extensions below: if the source's sign bit may be set,
      // the result may have every bit set; otherwise the value is unchanged.
      case ExtendS8Int32: {
        Index upperBound = val.geti32ApproxUpperBound();
        currElement.setUpperBound(upperBound >= 8 ? Index(32) : upperBound);
        break;
      }
      case ExtendS16Int32: {
        Index upperBound = val.geti32ApproxUpperBound();
        currElement.setUpperBound(upperBound >= 16 ? Index(32) : upperBound);
        break;
      }
      case ExtendS8Int64: {
        Index upperBound = val.geti64ApproxUpperBound();
        currElement.setUpperBound(upperBound >= 8 ? Index(64) : upperBound);
        break;
      }
      case ExtendS16Int64: {
        Index upperBound = val.geti64ApproxUpperBound();
        currElement.setUpperBound(upperBound >= 16 ? Index(64) : upperBound);
        break;
      }
      case ExtendS32Int64: {
        // i64.extend32_s: the operand is already an i64 whose low 32 bits are
        // sign-extended.
        Index upperBound = val.geti64ApproxUpperBound();
        currElement.setUpperBound(upperBound >= 32 ? Index(64) : upperBound);
        break;
      }
      case ExtendSInt32: {
        // i64.extend_i32_s: unlike ExtendS32Int64 the operand is an i32, so
        // its bound is read with the i32 accessor.
        Index upperBound = val.geti32ApproxUpperBound();
        currElement.setUpperBound(upperBound >= 32 ? Index(64) : upperBound);
        break;
      }
      default: {
        addInformation = false;
      }
    }

    if (collectingResults && addInformation) {
      exprMaxBounds[curr] = currElement.getUpperBound().value();
    }
    currState->push(std::move(currElement));
  }

  // Pops the value being stored. A plain local.set produces nothing. A
  // local.tee also leaves the stored value on the stack, so the element is
  // pushed back (and its bound, if it has one, recorded for the tee).
  void visitLocalSet(LocalSet* curr) {
    MaxBitsLattice::Element val = currState->pop();
    if (!curr->isTee()) {
      return;
    }

    if (collectingResults) {
      std::optional<Index> upperBound = val.getUpperBound();
      if (upperBound.has_value()) {
        exprMaxBounds[curr] = upperBound.value();
      }
    }
    currState->push(std::move(val));
  }

  // Generic handling for expressions without a dedicated rule: consume one
  // stack element per parameter of the expression's stack signature and
  // produce a top ("nothing known") element per result.
  void visitExpression(Expression* curr) {
    StackSignature exprSignature(curr);

    for (size_t i = 0; i < exprSignature.params.size(); ++i) {
      currState->pop();
    }

    for (size_t i = 0; i < exprSignature.results.size(); ++i) {
      MaxBitsLattice::Element currElement = bitsLattice.getBottom();
      currElement.setTop();
      currState->push(std::move(currElement));
    }
  }
};
} // namespace wasm::analysis
#endif // wasm_analysis_bits_bounds_transfer_function_h