Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -262,15 +262,24 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention",
Attention is defined as matmul(softmax(matmul(Q, transpose(K))+M), V) and
has shape BxMxN, where M is the (optional) attention mask. Usually, this
operator also performs scaling and dropout, but we leave those out of the
current implementation.

When `is_causal` is true, the attention mask operand is a materialized
causal (lower-triangular) mask. Downstream consumers may use this flag
to replace the mask with a fused index computation.
}];

let arguments = (ins Variadic<AnyShaped>:$inputs,
Variadic<AnyShaped>:$outputs
Variadic<AnyShaped>:$outputs,
OptionalAttr<BoolAttr>:$is_causal
);

let builders = [
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs), [{
build($_builder, $_state, TypeRange(outputs), inputs, outputs);
OpBuilder<(ins "ValueRange":$inputs, "ValueRange":$outputs,
CArg<"std::optional<bool>", "std::nullopt">:$isCausal), [{
build($_builder, $_state, TypeRange(outputs), inputs, outputs,
isCausal.has_value()
? $_builder.getBoolAttr(*isCausal)
: BoolAttr());
}]>
];

Expand Down
9 changes: 6 additions & 3 deletions lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2039,9 +2039,12 @@ class ConvertAtenScaledDotProductAttentionOp
}

// Overwrite with tm_tensor::attention
Value attention = AttentionOp::create(rewriter, loc, outType, inputs,
SmallVector<Value>{output})
.getResult()[0];
std::optional<bool> isCausalOpt =
causal ? std::optional<bool>(true) : std::nullopt;
Value attention =
AttentionOp::create(rewriter, loc, inputs, SmallVector<Value>{output},
isCausalOpt)
.getResult()[0];

if (opTy != outType) {
attention = tensor::ExpandShapeOp::create(rewriter, loc, opTy, attention,
Expand Down
Loading