Skip to content

Commit 0baeef5

Browse files
committed
[mlir][tosa] Support RescaleOp with dynamic extension in TosaToLinalg
The shift, multiplier, inputZp, and outputZp can be either constant or non-constant, depending on whether dynamic extension is enabled. When these values are non-constant, they are added as inputs to linalg::GenericOp, and corresponding affine maps are appended to the indexingMaps.
1 parent 81035c3 commit 0baeef5

File tree

2 files changed

+317
-88
lines changed

2 files changed

+317
-88
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 261 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
13451345
}
13461346
};
13471347

1348+
// Collapse tensor<1xiN> into tensor<iN>
1349+
// E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1350+
static Value collapse1xNTensorToN(PatternRewriter &rewriter, Value input,
1351+
Location loc) {
1352+
SmallVector<ReassociationExprs, 1> reassociation;
1353+
// Create the collapsed type
1354+
auto inputType = cast<RankedTensorType>(input.getType());
1355+
auto elemType = inputType.getElementType();
1356+
auto collapsedType = RankedTensorType::get({}, elemType);
1357+
// Emit the collapse op
1358+
return rewriter.create<tensor::CollapseShapeOp>(loc, collapsedType, input,
1359+
reassociation);
1360+
}
1361+
1362+
// The multiplier may be either constant or non-constant, depending on
1363+
// whether dynamic extension is enabled.
1364+
// - If the multiplier is non-constant, add it as an input to linalg::GenericOp
1365+
// by:
1366+
// 1. Pushing it into 'genericInputs'.
1367+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1368+
// - If the multiplier is constant, set 'multiplierConstant' instead.
1369+
static void setupLinalgGenericOpInputAndIndexingMapForMultiplier(
1370+
PatternRewriter &rewriter, llvm::SmallVector<int32_t> &multiplierValues,
1371+
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1372+
bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
1373+
int64_t &multiplierArg) {
1374+
1375+
auto loc = op.getLoc();
1376+
auto inputTy = cast<ShapedType>(op.getInput().getType());
1377+
unsigned rank = inputTy.getRank();
1378+
SmallVector<AffineExpr, 2> multiplierExprs{
1379+
rewriter.getAffineDimExpr(rank - 1)};
1380+
1381+
if (isConstant) {
1382+
// If we are rescaling per-channel then we need to store the multiplier
1383+
// values in a buffer.
1384+
if (multiplierValues.size() == 1) {
1385+
multiplierConstant = rewriter.create<arith::ConstantOp>(
1386+
loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1387+
} else {
1388+
auto multiplierType =
1389+
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1390+
rewriter.getI32Type());
1391+
genericInputs.push_back(arith::ConstantOp::create(
1392+
rewriter, loc,
1393+
DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1394+
1395+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1396+
/*symbolCount=*/0, multiplierExprs,
1397+
rewriter.getContext()));
1398+
}
1399+
} else {
1400+
// If we are not rescaling per-channel then we need to collapse 1xN to N
1401+
// and push broadcastMap.
1402+
auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier().getType());
1403+
if (tensorType && tensorType.hasStaticShape() &&
1404+
tensorType.getShape()[0] == 1) {
1405+
// broadcastMap = affine_map<(d0, d1) -> ()>
1406+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1407+
AffineMap broadcastMap =
1408+
AffineMap::get(rank, 0, {}, rewriter.getContext());
1409+
genericInputs.push_back(
1410+
collapse1xNTensorToN(rewriter, op.getMultiplier(), loc));
1411+
indexingMaps.push_back(broadcastMap);
1412+
} else {
1413+
genericInputs.push_back(op.getMultiplier());
1414+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1415+
/*symbolCount=*/0, multiplierExprs,
1416+
rewriter.getContext()));
1417+
}
1418+
}
1419+
multiplierArg = indexingMaps.size() - 1;
1420+
}
1421+
1422+
// The shift may be either constant or non-constant, depending on
1423+
// whether dynamic extension is enabled.
1424+
// - If the shift is non-constant, add it as an input to linalg::GenericOp by:
1425+
// 1. Pushing it into 'genericInputs'.
1426+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1427+
// - If the shift is constant, set 'shiftConstant' instead.
1428+
static void setupLinalgGenericOpInputAndIndexingMapForShift(
1429+
PatternRewriter &rewriter, llvm::SmallVector<int8_t> &shiftValues,
1430+
SmallVector<Value, 4> &genericInputs, SmallVector<AffineMap> &indexingMaps,
1431+
bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
1432+
int64_t &shiftArg) {
1433+
1434+
auto loc = op.getLoc();
1435+
auto inputTy = cast<ShapedType>(op.getInput().getType());
1436+
unsigned rank = inputTy.getRank();
1437+
SmallVector<AffineExpr, 2> shiftExprs = {rewriter.getAffineDimExpr(rank - 1)};
1438+
1439+
if (isConstant) {
1440+
// If we are rescaling per-channel then we need to store the shift
1441+
// values in a buffer.
1442+
if (shiftValues.size() == 1) {
1443+
shiftConstant = rewriter.create<arith::ConstantOp>(
1444+
loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1445+
} else {
1446+
auto shiftType =
1447+
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1448+
rewriter.getIntegerType(8));
1449+
genericInputs.push_back(arith::ConstantOp::create(
1450+
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1451+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1452+
/*symbolCount=*/0, shiftExprs,
1453+
rewriter.getContext()));
1454+
}
1455+
} else {
1456+
// If we are not rescaling per-channel then we need to collapse 1xN to N
1457+
// and push broadcastMap.
1458+
auto tensorType = dyn_cast<RankedTensorType>(op.getShift().getType());
1459+
if (tensorType && tensorType.hasStaticShape() &&
1460+
tensorType.getShape()[0] == 1) {
1461+
// broadcastMap = affine_map<(d0, d1) -> ()>
1462+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1463+
AffineMap broadcastMap =
1464+
AffineMap::get(rank, 0, {}, rewriter.getContext());
1465+
genericInputs.push_back(
1466+
collapse1xNTensorToN(rewriter, op.getShift(), loc));
1467+
indexingMaps.push_back(broadcastMap);
1468+
} else {
1469+
genericInputs.push_back(op.getShift());
1470+
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1471+
/*symbolCount=*/0, shiftExprs,
1472+
rewriter.getContext()));
1473+
}
1474+
}
1475+
shiftArg = indexingMaps.size() - 1;
1476+
}
1477+
1478+
// Return the extended Zp to be used in subsequent arithmetic operations.
1479+
static Value getExtendInputZp(OpBuilder &builder, Type valueTy,
1480+
FailureOr<int64_t> maybeZp, Location loc,
1481+
ValueRange blockArgs, int64_t iZpArg) {
1482+
Value result;
1483+
// The Zp value can be either constant or non-constant, depending on
1484+
// whether dynamic extension is enabled.
1485+
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
1486+
// be passed as an input to linalg::GenericOp.
1487+
if (failed(maybeZp)) {
1488+
result = blockArgs[iZpArg];
1489+
auto zpTy = result.getType();
1490+
if (zpTy.getIntOrFloatBitWidth() < 32) {
1491+
if (zpTy.isUnsignedInteger()) {
1492+
result =
1493+
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
1494+
} else {
1495+
result =
1496+
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
1497+
}
1498+
}
1499+
} else {
1500+
const int32_t bitwidth = valueTy.getIntOrFloatBitWidth();
1501+
// Extend zeropoint for sub-32bits widths.
1502+
const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32;
1503+
result = builder.create<arith::ConstantOp>(
1504+
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
1505+
}
1506+
return result;
1507+
}
1508+
1509+
// Return the i32 outputZp to be used in subsequent arithmetic operations.
1510+
static Value getI32OutputZp(OpBuilder &builder, Type valueTy,
1511+
FailureOr<int64_t> maybeZp, Location loc,
1512+
ValueRange blockArgs, int64_t oZpArg) {
1513+
Value result;
1514+
// The Zp value can be either constant or non-constant, depending on
1515+
// whether dynamic extension is enabled.
1516+
// If 'maybeZp' fails, it indicates that Zp is non-constant and will
1517+
// be passed as an input to linalg::GenericOp.
1518+
if (failed(maybeZp)) {
1519+
result = blockArgs[oZpArg];
1520+
auto zpTy = result.getType();
1521+
if (zpTy.getIntOrFloatBitWidth() < 32) {
1522+
if (zpTy.isUnsignedInteger()) {
1523+
result =
1524+
builder.create<arith::ExtUIOp>(loc, builder.getI32Type(), result);
1525+
} else {
1526+
result =
1527+
builder.create<arith::ExtSIOp>(loc, builder.getI32Type(), result);
1528+
}
1529+
} else if (zpTy.getIntOrFloatBitWidth() > 32) {
1530+
result =
1531+
builder.create<arith::TruncIOp>(loc, builder.getI32Type(), result);
1532+
}
1533+
} else {
1534+
const int32_t attrBitwidth = 32;
1535+
result = builder.create<arith::ConstantOp>(
1536+
loc, IntegerAttr::get(builder.getIntegerType(attrBitwidth), *maybeZp));
1537+
}
1538+
return result;
1539+
}
1540+
13481541
class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13491542
public:
13501543
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1376,40 +1569,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
13761569
}
13771570
}
13781571

1379-
// The shift and multiplier values.
13801572
DenseElementsAttr shiftElems;
1381-
if (!matchPattern(op.getShift(), m_Constant(&shiftElems)))
1382-
return rewriter.notifyMatchFailure(
1383-
op, "tosa.rescale requires constant shift input values");
1573+
bool isShiftConstant = false;
1574+
if (matchPattern(op.getShift(), m_Constant(&shiftElems)))
1575+
isShiftConstant = true;
13841576

13851577
DenseElementsAttr multiplierElems;
1386-
if (!matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1387-
return rewriter.notifyMatchFailure(
1388-
op, "tosa.rescale requires constant multiplier input values");
1389-
1390-
llvm::SmallVector<int8_t> shiftValues =
1391-
llvm::to_vector(shiftElems.getValues<int8_t>());
1392-
// explicit cast is required here
1393-
llvm::SmallVector<int32_t> multiplierValues = llvm::to_vector(
1394-
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1395-
[](IntegerAttr attr) -> int32_t {
1396-
return static_cast<int32_t>(attr.getInt());
1397-
}));
1398-
1399-
// If we shift by more than the bitwidth, this just sets to 0.
1400-
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1401-
if (shiftValues[i] > 63) {
1402-
shiftValues[i] = 0;
1403-
multiplierValues[i] = 0;
1578+
bool isMultiplierConstant = false;
1579+
if (matchPattern(op.getMultiplier(), m_Constant(&multiplierElems)))
1580+
isMultiplierConstant = true;
1581+
1582+
llvm::SmallVector<int8_t> shiftValues;
1583+
llvm::SmallVector<int32_t> multiplierValues;
1584+
bool doubleRound;
1585+
1586+
if (isMultiplierConstant && isShiftConstant) {
1587+
shiftValues = llvm::to_vector(shiftElems.getValues<int8_t>());
1588+
// explicit cast is required here
1589+
multiplierValues = llvm::to_vector(
1590+
llvm::map_range(multiplierElems.getValues<IntegerAttr>(),
1591+
[](IntegerAttr attr) -> int32_t {
1592+
return static_cast<int32_t>(attr.getInt());
1593+
}));
1594+
1595+
// If we shift by more than the bitwidth, this just sets to 0.
1596+
for (int i = 0, s = multiplierValues.size(); i < s; i++) {
1597+
if (shiftValues[i] > 63) {
1598+
shiftValues[i] = 0;
1599+
multiplierValues[i] = 0;
1600+
}
14041601
}
1405-
}
1602+
// Double round only occurs if shift is greater than 31, check that this
1603+
// is ever true.
1604+
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1605+
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
1606+
} else
1607+
doubleRound = op.getRoundingMode() == RoundingMode::DOUBLE_ROUND;
14061608

1407-
// Double round only occurs if shift is greater than 31, check that this
1408-
// is ever true.
1409-
1410-
bool doubleRound =
1411-
op.getRoundingMode() == RoundingMode::DOUBLE_ROUND &&
1412-
llvm::any_of(shiftValues, [](int32_t v) { return v > 31; });
14131609
RoundingMode roundingMode =
14141610
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
14151611

@@ -1421,45 +1617,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14211617
// values in a buffer.
14221618
Value multiplierConstant;
14231619
int64_t multiplierArg = 0;
1424-
if (multiplierValues.size() == 1) {
1425-
multiplierConstant = arith::ConstantOp::create(
1426-
rewriter, loc, rewriter.getI32IntegerAttr(multiplierValues.front()));
1427-
} else {
1428-
SmallVector<AffineExpr, 2> multiplierExprs{
1429-
rewriter.getAffineDimExpr(rank - 1)};
1430-
auto multiplierType =
1431-
RankedTensorType::get({static_cast<int64_t>(multiplierValues.size())},
1432-
rewriter.getI32Type());
1433-
genericInputs.push_back(arith::ConstantOp::create(
1434-
rewriter, loc,
1435-
DenseIntElementsAttr::get(multiplierType, multiplierValues)));
1436-
1437-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1438-
/*symbolCount=*/0, multiplierExprs,
1439-
rewriter.getContext()));
1440-
1441-
multiplierArg = indexingMaps.size() - 1;
1442-
}
1620+
setupLinalgGenericOpInputAndIndexingMapForMultiplier(
1621+
rewriter, multiplierValues, genericInputs, indexingMaps,
1622+
isMultiplierConstant, op, multiplierConstant, multiplierArg);
14431623

14441624
// If we are rescaling per-channel then we need to store the shift
14451625
// values in a buffer.
14461626
Value shiftConstant;
14471627
int64_t shiftArg = 0;
1448-
if (shiftValues.size() == 1) {
1449-
shiftConstant = arith::ConstantOp::create(
1450-
rewriter, loc, rewriter.getI8IntegerAttr(shiftValues.front()));
1451-
} else {
1452-
SmallVector<AffineExpr, 2> shiftExprs = {
1453-
rewriter.getAffineDimExpr(rank - 1)};
1454-
auto shiftType =
1455-
RankedTensorType::get({static_cast<int64_t>(shiftValues.size())},
1456-
rewriter.getIntegerType(8));
1457-
genericInputs.push_back(arith::ConstantOp::create(
1458-
rewriter, loc, DenseIntElementsAttr::get(shiftType, shiftValues)));
1459-
indexingMaps.push_back(AffineMap::get(/*dimCount=*/rank,
1460-
/*symbolCount=*/0, shiftExprs,
1461-
rewriter.getContext()));
1462-
shiftArg = indexingMaps.size() - 1;
1628+
setupLinalgGenericOpInputAndIndexingMapForShift(
1629+
rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1630+
shiftConstant, shiftArg);
1631+
1632+
// broadcastMap = affine_map<(d0, d1) -> ()>
1633+
// It would affect as broadcast for scalar values in linalg::GenericOp.
1634+
AffineMap broadcastMap = AffineMap::get(rank, 0, {}, rewriter.getContext());
1635+
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1636+
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1637+
// The inputZp and outputZp may be either constant or non-constant,
1638+
// depending on whether dynamic extension is enabled.
1639+
// - If the zp is non-constant, add it as an input to linalg::GenericOp by:
1640+
// 1. Pushing it into 'genericInputs'.
1641+
// 2. Appending a corresponding affine map to 'indexingMaps'.
1642+
int64_t iZpArg = 0;
1643+
if (failed(maybeIZp)) {
1644+
genericInputs.push_back(
1645+
collapse1xNTensorToN(rewriter, op->getOperand(3), loc));
1646+
indexingMaps.push_back(broadcastMap);
1647+
iZpArg = indexingMaps.size() - 1;
1648+
}
1649+
int64_t oZpArg = 0;
1650+
if (failed(maybeOZp)) {
1651+
genericInputs.push_back(
1652+
collapse1xNTensorToN(rewriter, op->getOperand(4), loc));
1653+
indexingMaps.push_back(broadcastMap);
1654+
oZpArg = indexingMaps.size() - 1;
14631655
}
14641656

14651657
// Indexing maps for output values.
@@ -1479,39 +1671,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
14791671
Type valueTy = value.getType();
14801672

14811673
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
1482-
if (failed(maybeIZp)) {
1483-
(void)rewriter.notifyMatchFailure(
1484-
op, "input zero point cannot be statically determined");
1485-
return;
1486-
}
1487-
1488-
const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth();
1489-
// Extend zeropoint for sub-32bits widths.
1490-
const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32;
1491-
auto inputZp = arith::ConstantOp::create(
1492-
nestedBuilder, loc,
1493-
IntegerAttr::get(rewriter.getIntegerType(inAttrBitwidth),
1494-
*maybeIZp));
1674+
auto inputZp = getExtendInputZp(nestedBuilder, valueTy, maybeIZp,
1675+
nestedLoc, blockArgs, iZpArg);
14951676

14961677
FailureOr<int64_t> maybeOZp = op.getOutputZeroPoint();
1497-
if (failed(maybeOZp)) {
1498-
(void)rewriter.notifyMatchFailure(
1499-
op, "output zero point cannot be statically determined");
1500-
return;
1501-
};
1678+
auto outputZp = getI32OutputZp(nestedBuilder, valueTy, maybeOZp,
1679+
nestedLoc, blockArgs, oZpArg);
15021680

15031681
IntegerType outIntType =
15041682
cast<IntegerType>(blockArgs.back().getType());
15051683
unsigned outBitWidth = outIntType.getWidth();
1506-
const int32_t outAttrBitwidth = 32;
15071684
assert(outBitWidth <= 32 && "Unexpected output zeropoint bitwidth");
1508-
auto outputZp = arith::ConstantOp::create(
1509-
nestedBuilder, loc,
1510-
IntegerAttr::get(rewriter.getIntegerType(outAttrBitwidth),
1511-
*maybeOZp));
15121685

1513-
Value multiplier = multiplierConstant ? multiplierConstant
1514-
: blockArgs[multiplierArg];
1686+
Value multiplier =
1687+
multiplierConstant ? multiplierConstant : blockArgs[multiplierArg];
15151688
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
15161689

15171690
if (valueTy.isUnsignedInteger()) {

0 commit comments

Comments
 (0)