[torch] Fix attention on linalg for dynamic shapes
The current version does not work for a mixture of dynamic and static shaped
batch dimensions. Rework the mask construction to grab the correct dynamic shapes.
rsuderman committed Sep 16, 2024
1 parent 14ef05a commit 087d8ac
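To make the mixed case concrete (this illustration is not part of the commit): a query such as tensor<4x?x?x32xf32> has a static batch dimension, a dynamic batch dimension, a dynamic sequence length, and a static head dimension. A minimal sketch of how such a type is written with MLIR's C++ API, assuming only the builtin-types headers:

// Hypothetical, self-contained illustration (not from the commit): build a
// ranked tensor type that mixes static and dynamic dimensions, the situation
// the commit message describes.
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/MLIRContext.h"

int main() {
  mlir::MLIRContext ctx;
  auto f32 = mlir::Float32Type::get(&ctx);
  // tensor<4x?x?x32xf32>: static batch, dynamic batch, dynamic sequence
  // length, static head dimension.
  auto queryTy = mlir::RankedTensorType::get(
      {4, mlir::ShapedType::kDynamic, mlir::ShapedType::kDynamic, 32}, f32);
  (void)queryTy; // hasStaticShape() is false, yet dims 0 and 3 are static.
  return 0;
}

As the diff below shows, the old path chose between an all-static and an all-dynamic strategy, while the reworked code walks the dimensions one at a time, so a mix like this is handled uniformly.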
Showing 1 changed file with 14 additions and 26 deletions.
lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp (14 additions, 26 deletions)
@@ -1607,35 +1607,23 @@ class ConvertAtenScaledDotProductAttentionOp
           op.getLoc(), "expected no attention mask when isCausal is true");
     }
 
-    SmallVector<OpFoldResult> maskSizes;
-
-    if (queryTy.hasStaticShape() && keyTy.hasStaticShape()) {
-      auto seqLenQ =
-          rewriter.getIndexAttr(queryTy.getDimSize(queryTy.getRank() - 2));
-      auto seqLenK =
-          rewriter.getIndexAttr(keyTy.getDimSize(keyTy.getRank() - 2));
-      maskSizes = {seqLenQ, seqLenK};
-      for (int i = queryTy.getRank() - 3; i >= 0; --i) {
-        auto batchSize = rewriter.getIndexAttr(queryTy.getDimSize(i));
-        maskSizes.insert(maskSizes.begin(), batchSize);
-      }
-    } else { // Dynamic shape case: <?x?x...x?xf32> for example
-      for (int i = 0; i < queryTy.getRank() - 2; ++i) {
-        Value batchSize =
-            rewriter.create<tensor::DimOp>(op.getLoc(), query, i);
-        maskSizes.push_back(batchSize);
-      }
-      Value seqLenQ = rewriter.create<tensor::DimOp>(op.getLoc(), query,
-                                                     queryTy.getRank() - 2);
-      Value seqLenK = rewriter.create<tensor::DimOp>(op.getLoc(), key,
-                                                     keyTy.getRank() - 2);
-      maskSizes.push_back(seqLenQ);
-      maskSizes.push_back(seqLenK);
+    SmallVector<int64_t> maskStatic;
+    SmallVector<Value> maskDyn;
+    for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
+      maskStatic.push_back(queryTy.getDimSize(i));
+      if (maskStatic.back() == ShapedType::kDynamic)
+        maskDyn.push_back(
+            rewriter.create<tensor::DimOp>(op.getLoc(), query, i));
     }
 
+    maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
+    if (maskStatic.back() == ShapedType::kDynamic)
+      maskDyn.push_back(rewriter.create<tensor::DimOp>(op.getLoc(), key,
+                                                       keyTy.getRank() - 2));
+
     Type maskType = getElementTypeOrSelf(queryTy);
-    Value emptyMask =
-        rewriter.create<tensor::EmptyOp>(op.getLoc(), maskSizes, maskType);
+    Value emptyMask = rewriter.create<tensor::EmptyOp>(
+        op.getLoc(), maskStatic, maskType, maskDyn);
 
     Value zero = rewriter.create<arith::ConstantOp>(
         op.getLoc(),
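
For readers who want the new logic in one place, the following is a hedged, self-contained restatement of the added code as a standalone helper; buildEmptyMask is a hypothetical name, not part of the commit, and the sketch assumes the MLIR tensor dialect builders and the same query/key conventions as the pattern above.

// Hypothetical standalone sketch of the new mask-allocation logic; the real
// commit inlines this inside ConvertAtenScaledDotProductAttentionOp.
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/TypeUtilities.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;

static Value buildEmptyMask(OpBuilder &rewriter, Location loc, Value query,
                            Value key) {
  auto queryTy = cast<RankedTensorType>(query.getType());
  auto keyTy = cast<RankedTensorType>(key.getType());

  // Per dimension, record the static extent; when it is dynamic, also
  // materialize a tensor.dim value. Both lists are consumed together by
  // tensor::EmptyOp.
  SmallVector<int64_t> maskStatic;
  SmallVector<Value> maskDyn;
  // Batch dimensions plus the query sequence length (rank - 2).
  for (int i = 0, s = queryTy.getRank() - 1; i < s; ++i) {
    maskStatic.push_back(queryTy.getDimSize(i));
    if (maskStatic.back() == ShapedType::kDynamic)
      maskDyn.push_back(rewriter.create<tensor::DimOp>(loc, query, i));
  }
  // The key sequence length (rank - 2) becomes the last mask dimension.
  maskStatic.push_back(keyTy.getDimSize(keyTy.getRank() - 2));
  if (maskStatic.back() == ShapedType::kDynamic)
    maskDyn.push_back(
        rewriter.create<tensor::DimOp>(loc, key, keyTy.getRank() - 2));

  Type maskType = getElementTypeOrSelf(queryTy);
  // EmptyOp pairs each kDynamic entry in maskStatic with the next value in
  // maskDyn, in order.
  return rewriter.create<tensor::EmptyOp>(loc, maskStatic, maskType, maskDyn);
}

As a worked example: with a query of type tensor<4x?x?x32xf32> and a key of type tensor<4x?x16x32xf32>, maskStatic ends up as {4, kDynamic, kDynamic, 16}, maskDyn holds two tensor.dim values taken from the query, and the resulting empty mask has type tensor<4x?x?x16xf32>.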
