@@ -1345,6 +1345,199 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
1345
1345
}
1346
1346
};
1347
1347
1348
+ // Collapse tensor<1xiN> into tensor<iN>
1349
+ // E.g. tensor.collapse_shape %arg1 [] : tensor<1xi16> into tensor<i16>
1350
+ static Value collapse1xNTensorToN (PatternRewriter &rewriter, Value input,
1351
+ Location loc) {
1352
+ SmallVector<ReassociationExprs, 1 > reassociation;
1353
+ // Create the collapsed type
1354
+ auto inputType = cast<RankedTensorType>(input.getType ());
1355
+ auto elemType = inputType.getElementType ();
1356
+ auto collapsedType = RankedTensorType::get ({}, elemType);
1357
+ // Emit the collapse op
1358
+ return rewriter.create <tensor::CollapseShapeOp>(loc, collapsedType, input,
1359
+ reassociation);
1360
+ }
1361
+
1362
+ // The multiplier may be either constant or non-constant, depending on
1363
+ // whether dynamic extension is enabled.
1364
+ // - If the multiplier is non-constant, add it as an input to linalg::GenericOp
1365
+ // by:
1366
+ // 1. Pushing it into 'genericInputs'.
1367
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1368
+ // - If the multiplier is constant, set 'multiplierConstant' instead.
1369
+ static void setupLinalgGenericOpInputAndIndexingMapForMultiplier (
1370
+ PatternRewriter &rewriter, llvm::SmallVector<int32_t > &multiplierValues,
1371
+ SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1372
+ bool isConstant, tosa::RescaleOp op, Value &multiplierConstant,
1373
+ int64_t &multiplierArg) {
1374
+
1375
+ auto loc = op.getLoc ();
1376
+ auto inputTy = cast<ShapedType>(op.getInput ().getType ());
1377
+ unsigned rank = inputTy.getRank ();
1378
+ SmallVector<AffineExpr, 2 > multiplierExprs{
1379
+ rewriter.getAffineDimExpr (rank - 1 )};
1380
+
1381
+ if (isConstant) {
1382
+ // If we are rescaling per-channel then we need to store the multiplier
1383
+ // values in a buffer.
1384
+ if (multiplierValues.size () == 1 ) {
1385
+ multiplierConstant = rewriter.create <arith::ConstantOp>(
1386
+ loc, rewriter.getI32IntegerAttr (multiplierValues.front ()));
1387
+ } else {
1388
+ auto multiplierType =
1389
+ RankedTensorType::get ({static_cast <int64_t >(multiplierValues.size ())},
1390
+ rewriter.getI32Type ());
1391
+ genericInputs.push_back (arith::ConstantOp::create (
1392
+ rewriter, loc,
1393
+ DenseIntElementsAttr::get (multiplierType, multiplierValues)));
1394
+
1395
+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1396
+ /* symbolCount=*/ 0 , multiplierExprs,
1397
+ rewriter.getContext ()));
1398
+ }
1399
+ } else {
1400
+ // If we are not rescaling per-channel then we need to collapse 1xN to N
1401
+ // and push broadcastMap.
1402
+ auto tensorType = dyn_cast<RankedTensorType>(op.getMultiplier ().getType ());
1403
+ if (tensorType && tensorType.hasStaticShape () &&
1404
+ tensorType.getShape ()[0 ] == 1 ) {
1405
+ // broadcastMap = affine_map<(d0, d1) -> ()>
1406
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1407
+ AffineMap broadcastMap =
1408
+ AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1409
+ genericInputs.push_back (
1410
+ collapse1xNTensorToN (rewriter, op.getMultiplier (), loc));
1411
+ indexingMaps.push_back (broadcastMap);
1412
+ } else {
1413
+ genericInputs.push_back (op.getMultiplier ());
1414
+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1415
+ /* symbolCount=*/ 0 , multiplierExprs,
1416
+ rewriter.getContext ()));
1417
+ }
1418
+ }
1419
+ multiplierArg = indexingMaps.size () - 1 ;
1420
+ }
1421
+
1422
+ // The shift may be either constant or non-constant, depending on
1423
+ // whether dynamic extension is enabled.
1424
+ // - If the shift is non-constant, add it as an input to linalg::GenericOp by:
1425
+ // 1. Pushing it into 'genericInputs'.
1426
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1427
+ // - If the shift is constant, set 'shiftConstant' instead.
1428
+ static void setupLinalgGenericOpInputAndIndexingMapForShift (
1429
+ PatternRewriter &rewriter, llvm::SmallVector<int8_t > &shiftValues,
1430
+ SmallVector<Value, 4 > &genericInputs, SmallVector<AffineMap> &indexingMaps,
1431
+ bool isConstant, tosa::RescaleOp op, Value &shiftConstant,
1432
+ int64_t &shiftArg) {
1433
+
1434
+ auto loc = op.getLoc ();
1435
+ auto inputTy = cast<ShapedType>(op.getInput ().getType ());
1436
+ unsigned rank = inputTy.getRank ();
1437
+ SmallVector<AffineExpr, 2 > shiftExprs = {rewriter.getAffineDimExpr (rank - 1 )};
1438
+
1439
+ if (isConstant) {
1440
+ // If we are rescaling per-channel then we need to store the shift
1441
+ // values in a buffer.
1442
+ if (shiftValues.size () == 1 ) {
1443
+ shiftConstant = rewriter.create <arith::ConstantOp>(
1444
+ loc, rewriter.getI8IntegerAttr (shiftValues.front ()));
1445
+ } else {
1446
+ auto shiftType =
1447
+ RankedTensorType::get ({static_cast <int64_t >(shiftValues.size ())},
1448
+ rewriter.getIntegerType (8 ));
1449
+ genericInputs.push_back (arith::ConstantOp::create (
1450
+ rewriter, loc, DenseIntElementsAttr::get (shiftType, shiftValues)));
1451
+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1452
+ /* symbolCount=*/ 0 , shiftExprs,
1453
+ rewriter.getContext ()));
1454
+ }
1455
+ } else {
1456
+ // If we are not rescaling per-channel then we need to collapse 1xN to N
1457
+ // and push broadcastMap.
1458
+ auto tensorType = dyn_cast<RankedTensorType>(op.getShift ().getType ());
1459
+ if (tensorType && tensorType.hasStaticShape () &&
1460
+ tensorType.getShape ()[0 ] == 1 ) {
1461
+ // broadcastMap = affine_map<(d0, d1) -> ()>
1462
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1463
+ AffineMap broadcastMap =
1464
+ AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1465
+ genericInputs.push_back (
1466
+ collapse1xNTensorToN (rewriter, op.getShift (), loc));
1467
+ indexingMaps.push_back (broadcastMap);
1468
+ } else {
1469
+ genericInputs.push_back (op.getShift ());
1470
+ indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1471
+ /* symbolCount=*/ 0 , shiftExprs,
1472
+ rewriter.getContext ()));
1473
+ }
1474
+ }
1475
+ shiftArg = indexingMaps.size () - 1 ;
1476
+ }
1477
+
1478
+ // Return the extended Zp to be used in subsequent arithmetic operations.
1479
+ static Value getExtendInputZp (OpBuilder &builder, Type valueTy,
1480
+ FailureOr<int64_t > maybeZp, Location loc,
1481
+ ValueRange blockArgs, int64_t iZpArg) {
1482
+ Value result;
1483
+ // The Zp value can be either constant or non-constant, depending on
1484
+ // whether dynamic extension is enabled.
1485
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1486
+ // be passed as an input to linalg::GenericOp.
1487
+ if (failed (maybeZp)) {
1488
+ result = blockArgs[iZpArg];
1489
+ auto zpTy = result.getType ();
1490
+ if (zpTy.getIntOrFloatBitWidth () < 32 ) {
1491
+ if (zpTy.isUnsignedInteger ()) {
1492
+ result =
1493
+ builder.create <arith::ExtUIOp>(loc, builder.getI32Type (), result);
1494
+ } else {
1495
+ result =
1496
+ builder.create <arith::ExtSIOp>(loc, builder.getI32Type (), result);
1497
+ }
1498
+ }
1499
+ } else {
1500
+ const int32_t bitwidth = valueTy.getIntOrFloatBitWidth ();
1501
+ // Extend zeropoint for sub-32bits widths.
1502
+ const int32_t attrBitwidth = bitwidth > 32 ? bitwidth : 32 ;
1503
+ result = builder.create <arith::ConstantOp>(
1504
+ loc, IntegerAttr::get (builder.getIntegerType (attrBitwidth), *maybeZp));
1505
+ }
1506
+ return result;
1507
+ }
1508
+
1509
+ // Return the i32 outputZp to be used in subsequent arithmetic operations.
1510
+ static Value getI32OutputZp (OpBuilder &builder, Type valueTy,
1511
+ FailureOr<int64_t > maybeZp, Location loc,
1512
+ ValueRange blockArgs, int64_t oZpArg) {
1513
+ Value result;
1514
+ // The Zp value can be either constant or non-constant, depending on
1515
+ // whether dynamic extension is enabled.
1516
+ // If 'maybeZp' fails, it indicates that Zp is non-constant and will
1517
+ // be passed as an input to linalg::GenericOp.
1518
+ if (failed (maybeZp)) {
1519
+ result = blockArgs[oZpArg];
1520
+ auto zpTy = result.getType ();
1521
+ if (zpTy.getIntOrFloatBitWidth () < 32 ) {
1522
+ if (zpTy.isUnsignedInteger ()) {
1523
+ result =
1524
+ builder.create <arith::ExtUIOp>(loc, builder.getI32Type (), result);
1525
+ } else {
1526
+ result =
1527
+ builder.create <arith::ExtSIOp>(loc, builder.getI32Type (), result);
1528
+ }
1529
+ } else if (zpTy.getIntOrFloatBitWidth () > 32 ) {
1530
+ result =
1531
+ builder.create <arith::TruncIOp>(loc, builder.getI32Type (), result);
1532
+ }
1533
+ } else {
1534
+ const int32_t attrBitwidth = 32 ;
1535
+ result = builder.create <arith::ConstantOp>(
1536
+ loc, IntegerAttr::get (builder.getIntegerType (attrBitwidth), *maybeZp));
1537
+ }
1538
+ return result;
1539
+ }
1540
+
1348
1541
class RescaleConverter : public OpRewritePattern <tosa::RescaleOp> {
1349
1542
public:
1350
1543
using OpRewritePattern<tosa::RescaleOp>::OpRewritePattern;
@@ -1376,40 +1569,43 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1376
1569
}
1377
1570
}
1378
1571
1379
- // The shift and multiplier values.
1380
1572
DenseElementsAttr shiftElems;
1381
- if (! matchPattern (op. getShift (), m_Constant (&shiftElems)))
1382
- return rewriter. notifyMatchFailure (
1383
- op, " tosa.rescale requires constant shift input values " ) ;
1573
+ bool isShiftConstant = false ;
1574
+ if ( matchPattern (op. getShift (), m_Constant (&shiftElems)))
1575
+ isShiftConstant = true ;
1384
1576
1385
1577
DenseElementsAttr multiplierElems;
1386
- if (!matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
1387
- return rewriter.notifyMatchFailure (
1388
- op, " tosa.rescale requires constant multiplier input values" );
1389
-
1390
- llvm::SmallVector<int8_t > shiftValues =
1391
- llvm::to_vector (shiftElems.getValues <int8_t >());
1392
- // explicit cast is required here
1393
- llvm::SmallVector<int32_t > multiplierValues = llvm::to_vector (
1394
- llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
1395
- [](IntegerAttr attr) -> int32_t {
1396
- return static_cast <int32_t >(attr.getInt ());
1397
- }));
1398
-
1399
- // If we shift by more than the bitwidth, this just sets to 0.
1400
- for (int i = 0 , s = multiplierValues.size (); i < s; i++) {
1401
- if (shiftValues[i] > 63 ) {
1402
- shiftValues[i] = 0 ;
1403
- multiplierValues[i] = 0 ;
1578
+ bool isMultiplierConstant = false ;
1579
+ if (matchPattern (op.getMultiplier (), m_Constant (&multiplierElems)))
1580
+ isMultiplierConstant = true ;
1581
+
1582
+ llvm::SmallVector<int8_t > shiftValues;
1583
+ llvm::SmallVector<int32_t > multiplierValues;
1584
+ bool doubleRound;
1585
+
1586
+ if (isMultiplierConstant && isShiftConstant) {
1587
+ shiftValues = llvm::to_vector (shiftElems.getValues <int8_t >());
1588
+ // explicit cast is required here
1589
+ multiplierValues = llvm::to_vector (
1590
+ llvm::map_range (multiplierElems.getValues <IntegerAttr>(),
1591
+ [](IntegerAttr attr) -> int32_t {
1592
+ return static_cast <int32_t >(attr.getInt ());
1593
+ }));
1594
+
1595
+ // If we shift by more than the bitwidth, this just sets to 0.
1596
+ for (int i = 0 , s = multiplierValues.size (); i < s; i++) {
1597
+ if (shiftValues[i] > 63 ) {
1598
+ shiftValues[i] = 0 ;
1599
+ multiplierValues[i] = 0 ;
1600
+ }
1404
1601
}
1405
- }
1602
+ // Double round only occurs if shift is greater than 31, check that this
1603
+ // is ever true.
1604
+ doubleRound = op.getRoundingMode () == RoundingMode::DOUBLE_ROUND &&
1605
+ llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
1606
+ } else
1607
+ doubleRound = op.getRoundingMode () == RoundingMode::DOUBLE_ROUND;
1406
1608
1407
- // Double round only occurs if shift is greater than 31, check that this
1408
- // is ever true.
1409
-
1410
- bool doubleRound =
1411
- op.getRoundingMode () == RoundingMode::DOUBLE_ROUND &&
1412
- llvm::any_of (shiftValues, [](int32_t v) { return v > 31 ; });
1413
1609
RoundingMode roundingMode =
1414
1610
doubleRound ? RoundingMode::DOUBLE_ROUND : RoundingMode::SINGLE_ROUND;
1415
1611
@@ -1421,45 +1617,41 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1421
1617
// values in a buffer.
1422
1618
Value multiplierConstant;
1423
1619
int64_t multiplierArg = 0 ;
1424
- if (multiplierValues.size () == 1 ) {
1425
- multiplierConstant = arith::ConstantOp::create (
1426
- rewriter, loc, rewriter.getI32IntegerAttr (multiplierValues.front ()));
1427
- } else {
1428
- SmallVector<AffineExpr, 2 > multiplierExprs{
1429
- rewriter.getAffineDimExpr (rank - 1 )};
1430
- auto multiplierType =
1431
- RankedTensorType::get ({static_cast <int64_t >(multiplierValues.size ())},
1432
- rewriter.getI32Type ());
1433
- genericInputs.push_back (arith::ConstantOp::create (
1434
- rewriter, loc,
1435
- DenseIntElementsAttr::get (multiplierType, multiplierValues)));
1436
-
1437
- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1438
- /* symbolCount=*/ 0 , multiplierExprs,
1439
- rewriter.getContext ()));
1440
-
1441
- multiplierArg = indexingMaps.size () - 1 ;
1442
- }
1620
+ setupLinalgGenericOpInputAndIndexingMapForMultiplier (
1621
+ rewriter, multiplierValues, genericInputs, indexingMaps,
1622
+ isMultiplierConstant, op, multiplierConstant, multiplierArg);
1443
1623
1444
1624
// If we are rescaling per-channel then we need to store the shift
1445
1625
// values in a buffer.
1446
1626
Value shiftConstant;
1447
1627
int64_t shiftArg = 0 ;
1448
- if (shiftValues.size () == 1 ) {
1449
- shiftConstant = arith::ConstantOp::create (
1450
- rewriter, loc, rewriter.getI8IntegerAttr (shiftValues.front ()));
1451
- } else {
1452
- SmallVector<AffineExpr, 2 > shiftExprs = {
1453
- rewriter.getAffineDimExpr (rank - 1 )};
1454
- auto shiftType =
1455
- RankedTensorType::get ({static_cast <int64_t >(shiftValues.size ())},
1456
- rewriter.getIntegerType (8 ));
1457
- genericInputs.push_back (arith::ConstantOp::create (
1458
- rewriter, loc, DenseIntElementsAttr::get (shiftType, shiftValues)));
1459
- indexingMaps.push_back (AffineMap::get (/* dimCount=*/ rank,
1460
- /* symbolCount=*/ 0 , shiftExprs,
1461
- rewriter.getContext ()));
1462
- shiftArg = indexingMaps.size () - 1 ;
1628
+ setupLinalgGenericOpInputAndIndexingMapForShift (
1629
+ rewriter, shiftValues, genericInputs, indexingMaps, isShiftConstant, op,
1630
+ shiftConstant, shiftArg);
1631
+
1632
+ // broadcastMap = affine_map<(d0, d1) -> ()>
1633
+ // It would affect as broadcast for scalar values in linalg::GenericOp.
1634
+ AffineMap broadcastMap = AffineMap::get (rank, 0 , {}, rewriter.getContext ());
1635
+ FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1636
+ FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1637
+ // The inputZp and outputZp may be either constant or non-constant,
1638
+ // depending on whether dynamic extension is enabled.
1639
+ // - If the zp is non-constant, add it as an input to linalg::GenericOp by:
1640
+ // 1. Pushing it into 'genericInputs'.
1641
+ // 2. Appending a corresponding affine map to 'indexingMaps'.
1642
+ int64_t iZpArg = 0 ;
1643
+ if (failed (maybeIZp)) {
1644
+ genericInputs.push_back (
1645
+ collapse1xNTensorToN (rewriter, op->getOperand (3 ), loc));
1646
+ indexingMaps.push_back (broadcastMap);
1647
+ iZpArg = indexingMaps.size () - 1 ;
1648
+ }
1649
+ int64_t oZpArg = 0 ;
1650
+ if (failed (maybeOZp)) {
1651
+ genericInputs.push_back (
1652
+ collapse1xNTensorToN (rewriter, op->getOperand (4 ), loc));
1653
+ indexingMaps.push_back (broadcastMap);
1654
+ oZpArg = indexingMaps.size () - 1 ;
1463
1655
}
1464
1656
1465
1657
// Indexing maps for output values.
@@ -1479,39 +1671,20 @@ class RescaleConverter : public OpRewritePattern<tosa::RescaleOp> {
1479
1671
Type valueTy = value.getType ();
1480
1672
1481
1673
FailureOr<int64_t > maybeIZp = op.getInputZeroPoint ();
1482
- if (failed (maybeIZp)) {
1483
- (void )rewriter.notifyMatchFailure (
1484
- op, " input zero point cannot be statically determined" );
1485
- return ;
1486
- }
1487
-
1488
- const int32_t inBitwidth = valueTy.getIntOrFloatBitWidth ();
1489
- // Extend zeropoint for sub-32bits widths.
1490
- const int32_t inAttrBitwidth = inBitwidth > 32 ? inBitwidth : 32 ;
1491
- auto inputZp = arith::ConstantOp::create (
1492
- nestedBuilder, loc,
1493
- IntegerAttr::get (rewriter.getIntegerType (inAttrBitwidth),
1494
- *maybeIZp));
1674
+ auto inputZp = getExtendInputZp (nestedBuilder, valueTy, maybeIZp,
1675
+ nestedLoc, blockArgs, iZpArg);
1495
1676
1496
1677
FailureOr<int64_t > maybeOZp = op.getOutputZeroPoint ();
1497
- if (failed (maybeOZp)) {
1498
- (void )rewriter.notifyMatchFailure (
1499
- op, " output zero point cannot be statically determined" );
1500
- return ;
1501
- };
1678
+ auto outputZp = getI32OutputZp (nestedBuilder, valueTy, maybeOZp,
1679
+ nestedLoc, blockArgs, oZpArg);
1502
1680
1503
1681
IntegerType outIntType =
1504
1682
cast<IntegerType>(blockArgs.back ().getType ());
1505
1683
unsigned outBitWidth = outIntType.getWidth ();
1506
- const int32_t outAttrBitwidth = 32 ;
1507
1684
assert (outBitWidth <= 32 && " Unexpected output zeropoint bitwidth" );
1508
- auto outputZp = arith::ConstantOp::create (
1509
- nestedBuilder, loc,
1510
- IntegerAttr::get (rewriter.getIntegerType (outAttrBitwidth),
1511
- *maybeOZp));
1512
1685
1513
- Value multiplier = multiplierConstant ? multiplierConstant
1514
- : blockArgs[multiplierArg];
1686
+ Value multiplier =
1687
+ multiplierConstant ? multiplierConstant : blockArgs[multiplierArg];
1515
1688
Value shift = shiftConstant ? shiftConstant : blockArgs[shiftArg];
1516
1689
1517
1690
if (valueTy.isUnsignedInteger ()) {
0 commit comments