Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,20 @@ struct AttentionOpConversion
loc, rewriter.getF32Type(), dimInt);
Value scale = rewriter.createOrFold<math::RsqrtOp>(loc, dimFloat);

int64_t numBatches = op.getQueryType().getRank() - 2;

// When the TMTensor op is marked causal, fuse the mask into the
// attention region body using iree_linalg_ext.index ops and drop the
// materialized mask operand.
bool causal = op.getIsCausal().value_or(false);
if (causal) {
optionalMask = std::nullopt;
}

// Add batches to standard attention indexing maps.
SmallVector<AffineMap> indexingMaps =
getStandardAttentionIndexingMaps(ctx, optionalMask.has_value());

int64_t numBatches = op.getQueryType().getRank() - 2;
for (AffineMap &map : indexingMaps) {
map = map.shiftDims(numBatches);
if (map.getNumResults() == 0) {
Expand All @@ -196,7 +205,29 @@ struct AttentionOpConversion
block->addArgument(rewriter.getF32Type(), loc);
rewriter.setInsertionPoint(block, block->begin());

IREE::LinalgExt::YieldOp::create(rewriter, loc, block->getArgument(0));
if (causal) {
// In the standard layout after shiftDims(numBatches):
// m = numBatches, k2 = numBatches + 3.
int64_t mDim = numBatches;
int64_t k2Dim = numBatches + 3;

Value mIdx = IREE::LinalgExt::IndexOp::create(
rewriter, loc, rewriter.getIndexType(), mDim);
Value k2Idx = IREE::LinalgExt::IndexOp::create(
rewriter, loc, rewriter.getIndexType(), k2Dim);
Value cmp = arith::CmpIOp::create(
rewriter, loc, arith::CmpIPredicate::ugt, k2Idx, mIdx);
// Use the element type of the score (f32).
Value negInf = arith::ConstantOp::create(
rewriter, loc,
rewriter.getFloatAttr(rewriter.getF32Type(), -INFINITY));
Value score = block->getArgument(0);
Value masked =
arith::SelectOp::create(rewriter, loc, cmp, negInf, score);
IREE::LinalgExt::YieldOp::create(rewriter, loc, masked);
} else {
IREE::LinalgExt::YieldOp::create(rewriter, loc, block->getArgument(0));
}
}

rewriter.replaceOp(op, attention.getResult(0));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ static Value computeMatmul(OpBuilder &builder, Location loc, AffineMap lhsMap,
}

static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc,
Region &region, Value value) {
Region &region, AffineMap sMap,
Value value) {
auto rank = cast<RankedTensorType>(value.getType()).getRank();
AffineMap identityMap =
AffineMap::getMultiDimIdentityMap(rank, builder.getContext());
Expand All @@ -199,6 +200,34 @@ static Value applyPostQKMatmulElementwise(OpBuilder &builder, Location loc,
value, indexingMaps, iteratorTypes);
auto &dstRegion = genericOp.getRegion();
builder.cloneRegionBefore(region, dstRegion, dstRegion.end());

// Build a mapping from attention iteration domain dim -> S tensor dim.
// The linalg.generic uses an identity map over S, so linalg iteration
// dim i == S tensor dim i.
DenseMap<int64_t, int64_t> attentionDimToSDim;
for (auto [sIdx, expr] : llvm::enumerate(sMap.getResults())) {
attentionDimToSDim[cast<AffineDimExpr>(expr).getPosition()] = sIdx;
}

// Replace iree_linalg_ext.index ops with linalg.index ops.
SmallVector<IREE::LinalgExt::IndexOp> indexOps;
for (auto indexOp : dstRegion.back().getOps<IREE::LinalgExt::IndexOp>()) {
indexOps.push_back(indexOp);
}
{
OpBuilder::InsertionGuard guard(builder);
for (auto indexOp : indexOps) {
auto it = attentionDimToSDim.find(indexOp.getDim());
assert(it != attentionDimToSDim.end() &&
"index op dim not found in S map");
builder.setInsertionPoint(indexOp);
Value linalgIdx =
linalg::IndexOp::create(builder, loc, it->second)->getResult(0);
indexOp.replaceAllUsesWith(linalgIdx);
indexOp.erase();
}
}

{
OpBuilder::InsertionGuard withinRegion(builder);
builder.setInsertionPoint(dstRegion.back().getTerminator());
Expand Down Expand Up @@ -350,7 +379,7 @@ Value computeQKAndElementwise(Location loc, OpBuilder &b, Value query,
s.getDefiningOp()->setAttrs(qkAttrs);
}

s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, s);
s = applyPostQKMatmulElementwise(b, loc, elementwiseRegion, sMap, s);

if (lowPrecision) {
// For low bit-depth types we perform post Q @ K scaling. This is to avoid
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3097,19 +3097,25 @@ CustomOp::reifyResultShapes(OpBuilder &builder,

LogicalResult IREE::LinalgExt::IndexOp::verify() {
  // `iree_linalg_ext.index` only has meaning inside an op that defines an
  // iteration domain for it to index into.
  auto parentOp = getOperation()->getParentOp();
  if (!isa<CustomOp, AttentionOp, OnlineAttentionOp>(parentOp)) {
    return emitOpError(
        "expected parent op to be one of `iree_linalg_ext.custom_op`, "
        "`iree_linalg_ext.attention`, `iree_linalg_ext.online_attention`");
  }
  // Query the parent's iteration-domain rank. The isa<> check above
  // guarantees exactly one case fires, so no Default is needed. Note the
  // TypeSwitch result must be assigned to `numLoops` — discarding it would
  // leave the bound check below without a value.
  int64_t numLoops =
      TypeSwitch<Operation *, int64_t>(parentOp)
          .Case<CustomOp>(
              [](CustomOp op) -> int64_t { return op.getNumLoops(); })
          .Case<AttentionOp>([](AttentionOp op) -> int64_t {
            return op.getIterationDomainRank();
          })
          .Case<OnlineAttentionOp>([](OnlineAttentionOp op) -> int64_t {
            return op.getIterationDomainRank();
          });
  // The requested dim must be strictly inside the parent's loop nest.
  if (numLoops <= getDim()) {
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops (" << numLoops
           << ") of the enclosing operation";
  }
  return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def IREELinalgExt_IndexOp : IREELinalgExt_PureOp<"index", [Pure]>,
This operation is a mirror of `linalg.index` operation and has the same
semantics, except that `linalg.index` enforces that the parent op is a
`LinalgOp`, and the `iree_linalg_ext.index` operation enforces that the
parent op is one of `IREE::LinalgExt::CustomOp` or `IREE::LinalgExt::AttentionOp`.
parent op is one of `IREE::LinalgExt::CustomOp`,
`IREE::LinalgExt::AttentionOp`, or `IREE::LinalgExt::OnlineAttentionOp`.
}];

let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2770,19 +2770,35 @@ SmallVector<utils::IteratorType> AttentionOp::getLoopIteratorTypes() {
getKeyMap(), getValueMap(), getOutputMap());
}

/// Offsets every `iree_linalg_ext.index` op inside `body` by the tile offset
/// of its dimension. After tiling, iteration dim `d` of the tile starts at
/// `offsets[d]` in the untiled domain, so each index result is rewritten to
/// `index + offsets[d]`. Dims that are out of range of `offsets`, or whose
/// offset is a null OpFoldResult, are left untouched.
static void offsetAttentionIndices(OpBuilder &b, Region &body,
                                   ArrayRef<OpFoldResult> offsets) {
  IRRewriter rewriter(b);
  for (auto indexOp : body.getOps<IREE::LinalgExt::IndexOp>()) {
    // Skip dims with no offset to apply.
    if (indexOp.getDim() >= offsets.size() || !offsets[indexOp.getDim()]) {
      continue;
    }
    OpBuilder::InsertionGuard guard(b);
    rewriter.setInsertionPointAfter(indexOp);
    // Build `index + offset` as a folded affine.apply so constant offsets
    // fold away instead of materializing extra IR.
    AffineExpr index, offset;
    bindDims(b.getContext(), index, offset);
    OpFoldResult applied = affine::makeComposedFoldedAffineApply(
        rewriter, indexOp.getLoc(), index + offset,
        {getAsOpFoldResult(indexOp.getResult()), offsets[indexOp.getDim()]});
    Value materialized =
        getValueOrCreateConstantIndexOp(b, indexOp.getLoc(), applied);
    // Replace all uses of the original index EXCEPT the one inside the newly
    // created apply op itself, which must keep consuming the original value
    // (otherwise we would create a self-referential cycle).
    rewriter.replaceUsesWithIf(indexOp, materialized, [&](OpOperand &use) {
      return use.getOwner() != materialized.getDefiningOp();
    });
  }
}

FailureOr<TilingResult>
AttentionOp::getTiledImplementation(OpBuilder &builder,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes) {
assert(offsets.size() == getIterationDomainRank());
assert(sizes.size() == getIterationDomainRank());

// TODO: Add support for linalg_ext.index operations in the region.
// Currently, tiling will break if index operations are present.
if (!getBody()->getOps<IREE::LinalgExt::IndexOp>().empty()) {
return failure();
}

Location loc = getLoc();

SmallVector<Range> querySlice =
Expand Down Expand Up @@ -2847,6 +2863,7 @@ AttentionOp::getTiledImplementation(OpBuilder &builder,

Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
offsetAttentionIndices(builder, tiledOp->getRegion(0), offsets);

return TilingResult{
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
Expand Down Expand Up @@ -3006,6 +3023,7 @@ OnlineAttentionOp::getTiledImplementation(OpBuilder &builder,

Operation *tiledOp =
mlir::clone(builder, getOperation(), resultTypes, tiledOperands);
offsetAttentionIndices(builder, tiledOp->getRegion(0), offsets);

return TilingResult{
{tiledOp}, SmallVector<Value>(tiledOp->getResults()), slices};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1781,20 +1781,48 @@ func.func @custom_op_yield_type_mismatch(%arg0 : tensor<?xf32>, %arg1 : tensor<1
// -----

// Verify that `iree_linalg_ext.index` is rejected when its parent is not one
// of the supported ops (custom_op / attention / online_attention) — here it
// sits directly in a func body.
func.func @index_op_outside_custom_op() -> index {
  // expected-error @+1 {{expected parent op to be one of `iree_linalg_ext.custom_op`, `iree_linalg_ext.attention`, `iree_linalg_ext.online_attention`}}
  %0 = iree_linalg_ext.index 0 : index
  return %0 : index
}

// -----

// Verify that an `iree_linalg_ext.index` dim equal to the online_attention
// iteration-domain rank (5 loops: d0..d4) is rejected as out of range.
func.func @index_op_invalid_dim_online_attention(
    %query: tensor<192x1024x64xf16>,
    %key: tensor<192x1024x64xf16>,
    %value: tensor<192x1024x64xf16>,
    %output: tensor<192x1024x64xf32>,
    %max: tensor<192x1024xf32>,
    %sum: tensor<192x1024xf32>) -> (tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
  %scale = arith.constant 1.0 : f32
  %out:3 = iree_linalg_ext.online_attention
      {indexing_maps = [affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
                        affine_map<(d0, d1, d2, d3, d4) -> ()>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>,
                        affine_map<(d0, d1, d2, d3, d4) -> (d0, d1)>]}
      ins(%query, %key, %value, %scale : tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, tensor<192x1024x64xf16>, f32)
      outs(%output, %max, %sum : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>) {
  ^bb0(%score: f32):
    // expected-error @+1 {{expected dim (5) to be lower than the number of loops (5) of the enclosing operation}}
    %idx = iree_linalg_ext.index 5 : index
    iree_linalg_ext.yield %score : f32
  } -> tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
  return %out#0, %out#1, %out#2 : tensor<192x1024x64xf32>, tensor<192x1024xf32>, tensor<192x1024xf32>
}

// -----

func.func @index_op_invalid_dim(%arg0 : tensor<?xindex>) -> tensor<?xindex> {
%0 = iree_linalg_ext.custom_op {
indexing_maps = [affine_map<(d0) -> (d0)>],
iterator_types = [#iree_linalg_ext.iterator_type<parallel>]}
outs(%arg0: tensor<?xindex>) {
^bb0(%b0 : tensor<?xindex>):
// expected-error @+1 {{expected dim (1) to be lower than the number of loops (1) of the enclosing CustomOp}}
// expected-error @+1 {{expected dim (1) to be lower than the number of loops (1) of the enclosing operation}}
%1 = iree_linalg_ext.index 1 : index
%2 = linalg.generic {
indexing_maps = [affine_map<(d0) -> (d0)>],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,40 @@ module {

// -----

// Round-trip an attention op whose region implements causal masking with
// iree_linalg_ext.index ops: scores at positions where k2 > m are replaced
// by -inf (0xFF800000) before the yield.
func.func @attention_causal(%arg0: tensor<192x1024x64xf32>, %arg1: tensor<192x1024x64xf32>, %arg2: tensor<192x1024x64xf32>) -> tensor<192x1024x64xf32> {
  %cst = arith.constant dense<0.000000e+00> : tensor<192x1024x64xf32>
  %scale = arith.constant 1.000000e+00 : f32
  %0 = iree_linalg_ext.attention {indexing_maps = [
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2)>,
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d2)>,
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d3, d4)>,
      affine_map<(d0, d1, d2, d3, d4) -> ()>,
      affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d4)>
    ]
  } ins(%arg0, %arg1, %arg2, %scale : tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, tensor<192x1024x64xf32>, f32) outs(%cst : tensor<192x1024x64xf32>) {
  ^bb0(%score: f32):
    // d1 = m (query position), d3 = k2 (key position) in the maps above.
    %m = iree_linalg_ext.index 1 : index
    %k2 = iree_linalg_ext.index 3 : index
    %cmp = arith.cmpi ugt, %k2, %m : index
    %neg_inf = arith.constant 0xFF800000 : f32
    %masked = arith.select %cmp, %neg_inf, %score : f32
    iree_linalg_ext.yield %masked : f32
  } -> tensor<192x1024x64xf32>
  return %0 : tensor<192x1024x64xf32>
}

// CHECK-LABEL: func.func @attention_causal(
// CHECK: iree_linalg_ext.attention
// CHECK: ^bb0(%[[SCORE:.+]]: f32):
// CHECK: %[[M:.+]] = iree_linalg_ext.index 1 : index
// CHECK: %[[K2:.+]] = iree_linalg_ext.index 3 : index
// CHECK: %[[CMP:.+]] = arith.cmpi ugt, %[[K2]], %[[M]] : index
// CHECK: %[[NEG_INF:.+]] = arith.constant 0xFF800000 : f32
// CHECK: %[[MASKED:.+]] = arith.select %[[CMP]], %[[NEG_INF]], %[[SCORE]] : f32
// CHECK: iree_linalg_ext.yield %[[MASKED]] : f32

// -----

func.func @custom_op_default(%arg0 : tensor<?xf32>, %arg1 : tensor<?xf32>) -> tensor<?xf32> {
%0 = iree_linalg_ext.custom_op {
indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,40 @@ func.func @attention(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16
// CHECK: arith.mulf
// CHECK: arith.truncf
// CHECK: linalg.yield

// -----

// Decompose a causal-masked attention op: the region (index ops and all) is
// expected to survive into the resulting online_attention op unchanged.
// Iteration domain: (d0, d1) = batch dims, d2 = m, d3 = n, d4 = k1, d5 = k2.
#map = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d4)>
#map1 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d4)>
#map2 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d5, d3)>
#map3 = affine_map<(d0, d1, d2, d3, d4, d5) -> ()>
#map4 = affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3)>

func.func @attention_causal(%q: tensor<2x10x4096x128xf16>, %k: tensor<2x10x4096x128xf16>, %v: tensor<2x10x4096x128xf16>)
    -> tensor<2x10x4096x128xf16> {
  %scale = arith.constant 0.125 : f16
  %acc = tensor.empty() : tensor<2x10x4096x128xf16>
  %out = iree_linalg_ext.attention
      {indexing_maps = [#map, #map1, #map2, #map3, #map4]}
      ins(%q, %k, %v, %scale : tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, tensor<2x10x4096x128xf16>, f16)
      outs(%acc : tensor<2x10x4096x128xf16>) {
  ^bb0(%score: f32):
    // Causal mask: d2 = m, d5 = k2; positions with k2 > m get -inf.
    %m = iree_linalg_ext.index 2 : index
    %k2 = iree_linalg_ext.index 5 : index
    %cmp = arith.cmpi ugt, %k2, %m : index
    %neg_inf = arith.constant 0xFF800000 : f32
    %masked = arith.select %cmp, %neg_inf, %score : f32
    iree_linalg_ext.yield %masked : f32
  } -> tensor<2x10x4096x128xf16>
  func.return %out : tensor<2x10x4096x128xf16>
}

// CHECK-LABEL: func.func @attention_causal
// CHECK: iree_linalg_ext.online_attention
// CHECK-NEXT: ^{{.+}}(%[[SCORE:.+]]: f32):
// CHECK-NEXT: %[[M:.+]] = iree_linalg_ext.index 2 : index
// CHECK-NEXT: %[[K2:.+]] = iree_linalg_ext.index 5 : index
// CHECK-NEXT: %[[CMP:.+]] = arith.cmpi ugt, %[[K2]], %[[M]] : index
// CHECK-NEXT: %[[NEG_INF:.+]] = arith.constant 0xFF800000 : f32
// CHECK-NEXT: %[[MASKED:.+]] = arith.select %[[CMP]], %[[NEG_INF]], %[[SCORE]] : f32
// CHECK-NEXT: iree_linalg_ext.yield %[[MASKED]] : f32
Original file line number Diff line number Diff line change
Expand Up @@ -89,3 +89,59 @@ func.func @online_attention_f8(%query: tensor<192x1024x64xf8E4M3FNUZ>,
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield

// -----

// Test that iree_linalg_ext.index ops in the attention region are remapped
// to linalg.index ops in the decomposed output (causal masking pattern).
// The remapping must translate attention iteration-domain dims into the
// dims of the S (score) tensor via the S indexing map.

#mapQ = affine_map<(batch, m, k1, k2, n) -> (batch, m, k1)>
#mapK = affine_map<(batch, m, k1, k2, n) -> (batch, k2, k1)>
#mapV = affine_map<(batch, m, k1, k2, n) -> (batch, k2, n)>
#mapS = affine_map<(batch, m, k1, k2, n) -> ()>
#mapO = affine_map<(batch, m, k1, k2, n) -> (batch, m, n)>
#mapR = affine_map<(batch, m, k1, k2, n) -> (batch, m)>

func.func @online_attention_causal(
    %query: tensor<4x1024x64xf16>,
    %key: tensor<4x1024x64xf16>,
    %value: tensor<4x1024x64xf16>,
    %output: tensor<4x1024x64xf32>,
    %max: tensor<4x1024xf32>,
    %sum: tensor<4x1024xf32>)
    -> (tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32>) {
  %scale = arith.constant 1.0 : f32
  %out:3 = iree_linalg_ext.online_attention
      {indexing_maps = [#mapQ, #mapK, #mapV, #mapS, #mapO, #mapR, #mapR]}
      ins(%query, %key, %value, %scale : tensor<4x1024x64xf16>, tensor<4x1024x64xf16>, tensor<4x1024x64xf16>, f32)
      outs(%output, %max, %sum : tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32>) {
  ^bb0(%score: f32):
    // Causal mask in the region: dim 1 = m, dim 3 = k2.
    %m = iree_linalg_ext.index 1 : index
    %k2 = iree_linalg_ext.index 3 : index
    %cmp = arith.cmpi ugt, %k2, %m : index
    %neg_inf = arith.constant 0xFF800000 : f32
    %masked = arith.select %cmp, %neg_inf, %score : f32
    iree_linalg_ext.yield %masked : f32
  } -> tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32>
  return %out#0, %out#1, %out#2 : tensor<4x1024x64xf32>, tensor<4x1024xf32>, tensor<4x1024xf32>
}

// CHECK-LABEL: @online_attention_causal
// S = Q @ K
// CHECK: linalg.generic
// CHECK: arith.extf
// CHECK: arith.extf
// CHECK: arith.mulf
// CHECK: arith.addf
// CHECK: linalg.yield
// S = S * scale (pre-applied to Q)
// Post QK matmul elementwise (the causal masking region):
// iree_linalg_ext.index ops should be remapped to linalg.index ops.
// sMap = (batch, m, k1, k2, n) -> (batch, m, k2)
// So attention dim 1 (m) -> S dim 1, attention dim 3 (k2) -> S dim 2
// CHECK: linalg.generic
// CHECK: %[[M_IDX:.+]] = linalg.index 1
// CHECK: %[[K2_IDX:.+]] = linalg.index 2
// CHECK: arith.cmpi ugt, %[[K2_IDX]], %[[M_IDX]]
// CHECK: arith.select
// CHECK: linalg.yield
Loading
Loading