Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 259 additions & 86 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
}
};

// Collapse tensor<1xiN> into tensor<iN>
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
Location loc) {
SmallVector<ReassociationExprs, 1> reassociation;
// Create the collapsed type
auto inputType = cast<RankedTensorType>(input.getType());
auto elemType = inputType.getElementType();
auto collapsedType = RankedTensorType::get({}, elemType);
// Emit the collapse op
return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
reassociation);
}

// The multiplier may be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
// by:
// 1. Pushing it into 'genericInputs'.
// 2. Appending a corresponding affine map to 'indexingMaps'.
// - If the multiplier is constant, set 'multiplierConstant' instead.
static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
int64_t &multiplierArg) {

auto loc = op.getLoc();
auto inputTy = cast<ShapedType>(op.getInput().getType());
unsigned rank = inputTy.getRank();
SmallVector<AffineExpr, 2> multiplierExprs{
rewriter.getAffineDimExpr(rank - 1)};

if (isConstant) {
// If we are rescaling per-channel then we need to store the multiplier
// values in a buffer.
if (multiplierValues.size() == 1) {
multiplierConstant = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
} else {
auto multiplierType =
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
rewriter.getI32Type());
genericInputs.push_back(arith::ConstantOp::create(
rewriter, loc,
DenseIntElementsAttr::get(multiplierType, multiplierValues)));

indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
rewriter.getContext()));
}
} else {
// If we are not rescaling per-channel then we need to collapse 1xN to N
// and push broadcastMap.
auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
if (tensorType && tensorType.hasStaticShape() &&
tensorType.getShape()[0] == 1) {
// broadcastMap = affine_map<(d0, d1) -> ()>
// It would affect as broadcast for scalar values in linalg::GenericOp.
AffineMap broadcastMap =
AffineMap::get(rank, 0, {}, rewriter.getContext());
genericInputs.push_back(
collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
indexingMaps.push_back(broadcastMap);
} else {
genericInputs.push_back(op.getMultiplier());
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
rewriter.getContext()));
}
}
multiplierArg = indexingMaps.size() - 1;
}

// The shift may be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
// 1. Pushing it into 'genericInputs'.
// 2. Appending a corresponding affine map to 'indexingMaps'.
// - If the shift is constant, set 'shiftConstant' instead.
static void setupLinalgGenericOpInputAndIndexingMapForShift(
PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
int64_t &shiftArg) {

auto loc = op.getLoc();
auto inputTy = cast<ShapedType>(op.getInput().getType());
unsigned rank = inputTy.getRank();
SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};

if (isConstant) {
// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
if (shiftValues.size() == 1) {
shiftConstant = rewriter.create<arith::ConstantOp>(
loc, rewriter.getI8IntegerAttr(shiftValues.front()));
} else {
auto shiftType =
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
rewriter.getIntegerType(8));
genericInputs.push_back(arith::ConstantOp::create(
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
}
} else {
// If we are not rescaling per-channel then we need to collapse 1xN to N
// and push broadcastMap.
auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
if (tensorType && tensorType.hasStaticShape() &&
tensorType.getShape()[0] == 1) {
// broadcastMap = affine_map<(d0, d1) -> ()>
// It would affect as broadcast for scalar values in linalg::GenericOp.
AffineMap broadcastMap =
AffineMap::get(rank, 0, {}, rewriter.getContext());
genericInputs.push_back(
collapse1xNTensorToN(rewriter, op.getShift(), loc));
indexingMaps.push_back(broadcastMap);
} else {
genericInputs.push_back(op.getShift());
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
}
}
shiftArg = indexingMaps.size() - 1;
}

// Return the extended Zp to be used in subsequent arithmetic operations.
static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
FailureOr<int64_t> maybeZp, Location loc,
ValueRange blockArgs, int64_t iZpArg) {
Value result;
// The Zp value can be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
// be passed as an input to linalg::GenericOp.
if (failed(maybeZp)) {
result = blockArgs[iZpArg];
auto zpTy = result.getType();
if (zpTy.getIntOrFloatBitWidth() < 32) {
if (zpTy.isUnsignedInteger()) {
result =
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
} else {
result =
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
}
}
} else {
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
result = builder.create<arith::ConstantOp>(
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
}
return result;
}

// Return the i32 outputZp to be used in subsequent arithmetic operations.
static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
FailureOr<int64_t> maybeZp, Location loc,
ValueRange blockArgs, int64_t oZpArg) {
Value result;
// The Zp value can be either constant or non-constant, depending on
// whether dynamic extension is enabled.
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
// be passed as an input to linalg::GenericOp.
if (failed(maybeZp)) {
result = blockArgs[oZpArg];
auto zpTy = result.getType();
if (zpTy.getIntOrFloatBitWidth() < 32) {
if (zpTy.isUnsignedInteger()) {
result =
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
} else {
result =
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
}
} else if (zpTy.getIntOrFloatBitWidth() > 32) {
result =
builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
}
} else {
const int32_t attrBitwidth = 32;
result = builder.create<arith::ConstantOp>(
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
}
return result;
}

class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
public:
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
Expand Down Expand Up @@ -1376,40 +1569,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
}
}

// The shift and multiplier values.
DenseElementsAttr shiftElems;
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant shift input values");
bool isShiftConstant = false;
if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
isShiftConstant = true;

DenseElementsAttr multiplierElems;
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
return rewriter.notifyMatchFailure(
op, "tosa.rescale requires constant multiplier input values");

llvm::SmallVector<int8_t> shiftValues =
llvm::to_vector(shiftElems.getValues<int8_t>());
// explicit cast is required here
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
[](IntegerAttr attr) -> int32_t {
return static_cast<int32_t>(attr.getInt());
}));

// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
if (shiftValues[i] > 63) {
shiftValues[i] = 0;
multiplierValues[i] = 0;
bool isMultiplierConstant = false;
if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
isMultiplierConstant = true;

llvm::SmallVector<int8_t> shiftValues;
llvm::SmallVector<int32_t> multiplierValues;
bool doubleRound;

if (isMultiplierConstant && isShiftConstant) {
shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
// explicit cast is required here
multiplierValues = llvm::to_vector(
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
[](IntegerAttr attr) -> int32_t {
return static_cast<int32_t>(attr.getInt());
}));

// If we shift by more than the bitwidth, this just sets to 0.
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
if (shiftValues[i] > 63) {
shiftValues[i] = 0;
multiplierValues[i] = 0;
}
}
}
// Double round only occurs if shift is greater than 31, check that this
// is ever true.
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
} else
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;

// Double round only occurs if shift is greater than 31, check that this
// is ever true.

bool doubleRound =
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
RoundingMode roundingMode =
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;

Expand All @@ -1421,45 +1617,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
// values in a buffer.
Value multiplierConstant;
int64_t multiplierArg = 0;
if (multiplierValues.size() == 1) {
multiplierConstant = arith::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
} else {
SmallVector<AffineExpr, 2> multiplierExprs{
rewriter.getAffineDimExpr(rank - 1)};
auto multiplierType =
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
rewriter.getI32Type());
genericInputs.push_back(arith::ConstantOp::create(
rewriter, loc,
DenseIntElementsAttr::get(multiplierType, multiplierValues)));

indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, multiplierExprs,
rewriter.getContext()));

multiplierArg = indexingMaps.size() - 1;
}
setupLinalgGenericOpInputAndIndexingMapForMultiplier(
rewriter, multiplierValues, genericInputs, indexingMaps,
isMultiplierConstant, op, multiplierConstant, multiplierArg);

// If we are rescaling per-channel then we need to store the shift
// values in a buffer.
Value shiftConstant;
int64_t shiftArg = 0;
if (shiftValues.size() == 1) {
shiftConstant = arith::ConstantOp::create(
rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
} else {
SmallVector<AffineExpr, 2> shiftExprs = {
rewriter.getAffineDimExpr(rank - 1)};
auto shiftType =
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
rewriter.getIntegerType(8));
genericInputs.push_back(arith::ConstantOp::create(
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
/*symbolCount=*/0, shiftExprs,
rewriter.getContext()));
shiftArg = indexingMaps.size() - 1;
setupLinalgGenericOpInputAndIndexingMapForShift(
rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
shiftConstant, shiftArg);

// broadcastMap = affine_map<(d0, d1) -> ()>
// It would affect as broadcast for scalar values in linalg::GenericOp.
AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
// The inputZp and outputZp may be either constant or non-constant,
// depending on whether dynamic extension is enabled.
// - If the zp is non-constant, add it as an input to linalg::GenericOp by:
// 1. Pushing it into 'genericInputs'.
// 2. Appending a corresponding affine map to 'indexingMaps'.
int64_t iZpArg = 0;
if (failed(maybeIZp)) {
genericInputs.push_back(
collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
indexingMaps.push_back(broadcastMap);
iZpArg = indexingMaps.size() - 1;
}
int64_t oZpArg = 0;
if (failed(maybeOZp)) {
genericInputs.push_back(
collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
indexingMaps.push_back(broadcastMap);
oZpArg = indexingMaps.size() - 1;
}

// Indexing maps for output values.
Expand All @@ -1479,36 +1671,17 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
Type valueTy = value.getType();

FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp)) {
(void)rewriter.notifyMatchFailure(
op, "input zero point cannot be statically determined");
return;
}

const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
// Extend zeropoint for sub-32bits widths.
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
auto inputZp = arith::ConstantOp::create(
nestedBuilder, loc,
IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
*maybeIZp));
auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
nestedLoc, blockArgs, iZpArg);

FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
if (failed(maybeOZp)) {
(void)rewriter.notifyMatchFailure(
op, "output zero point cannot be statically determined");
return;
};
auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
nestedLoc, blockArgs, oZpArg);

IntegerType outIntType =
cast<IntegerType>(blockArgs.back().getType());
unsigned outBitWidth = outIntType.getWidth();
const int32_t outAttrBitwidth = 32;
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
auto outputZp = arith::ConstantOp::create(
nestedBuilder, loc,
IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
*maybeOZp));

Value multiplier = multiplierConstant ? multiplierConstant
: blockArgs[multiplierArg];
Expand Down
Loading