Add lowerings from smt tensor to low-level smt #82
The PR adds two new files. The first (`@@ -0,0 +1,68 @@`) defines the `lower-smt-tensor` pass, which lowers SMT tensor types to nested SMT arrays and `TensorExtractOp`s to chains of array selects:

```python
from xdsl_smt.dialects import smt_array_dialect as smt_array

from xdsl_smt.dialects.smt_dialect import (
    DeclareConstOp,
)
from xdsl_smt.dialects.smt_tensor_dialect import (
    IndexType,
    SMTTensorType,
    TensorExtractOp,
)
from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import Attribute
from xdsl.utils.hints import isa
from xdsl.context import Context
from xdsl.pattern_rewriter import (
    GreedyRewritePatternApplier,
    PatternRewriteWalker,
    PatternRewriter,
    RewritePattern,
    op_type_rewrite_pattern,
)
from xdsl.passes import ModulePass


def lower_tensor_type(typ: Attribute) -> Attribute:
    # Lower an SMT tensor type to nested arrays indexed by the index sort,
    # adding one array level per tensor dimension.
    if isa(typ, SMTTensorType):
        result = typ.element_type
        index_type = IndexType
        for _ in typ.shape:
            result = smt_array.ArrayType(index_type, result)
        return result
    return typ


class DeclareConstOpPattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter):
        # Re-declare tensor-typed constants with the lowered nested-array type.
        if isa(op.res.type, SMTTensorType):
            new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type))
            rewriter.replace_matched_op(new_constant_op)


class TensorExtractOpPattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
        # The extracted value is expected to already carry the lowered
        # nested-array type; the extract becomes one select per index.
        source = op.tensor
        assert isinstance(source.type, smt_array.ArrayType)
        select_ops: list[smt_array.SelectOp] = []
        for idx in op.indices:
            select_ops.append(smt_array.SelectOp(source, idx))
            source = select_ops[-1].res
        rewriter.replace_matched_op(select_ops)


class LowerSMTTensor(ModulePass):
    name = "lower-smt-tensor"

    def apply(self, ctx: Context, op: ModuleOp):
        walker = PatternRewriteWalker(
            GreedyRewritePatternApplier(
                [DeclareConstOpPattern(), TensorExtractOpPattern()]
            )
        )
        walker.rewrite_module(op)

        # Apply DCE to clean up ops left dead by the rewrites.
        DeadCodeElimination().apply(ctx, op)
```
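
For intuition, here is a small self-contained illustration of what these two rewrites compute (plain Python strings standing in for the dialect classes; `Index` is a placeholder for whatever `IndexType` resolves to, not the actual sort): one array level per tensor dimension, and one `select` per extract index.

```python
# Illustration only: plain strings stand in for the dialect types and ops.
def lower_tensor_type(shape: tuple[int, ...], element: str) -> str:
    # Mirrors lower_tensor_type above: one (Array Index ...) level per dimension.
    result = element
    for _ in shape:
        result = f"(Array Index {result})"
    return result


def lower_extract(tensor: str, indices: list[str]) -> str:
    # Mirrors TensorExtractOpPattern above: one select per index, chained.
    expr = tensor
    for idx in indices:
        expr = f"(select {expr} {idx})"
    return expr


if __name__ == "__main__":
    print(lower_tensor_type((4, 8), "(_ BitVec 32)"))
    # (Array Index (Array Index (_ BitVec 32)))
    print(lower_extract("t", ["i", "j"]))
    # (select (select t i) j)
```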
The second new file (`@@ -0,0 +1,83 @@`) defines the `rewrite-smt-tensor` pass, which folds `TensorExtractOp`s through other tensor operations (currently only `TensorTransposeOp`), so that extracts end up applying directly to values the lowering above can handle:

```python
from abc import ABC

from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv

from xdsl.ir import SSAValue
from xdsl_smt.dialects.smt_tensor_dialect import (
    TensorTransposeOp,
    INDEX_WIDTH,
    TensorExtractOp,
)
from xdsl.dialects.builtin import ModuleOp
from xdsl.context import Context
from xdsl.pattern_rewriter import (
    GreedyRewritePatternApplier,
    PatternRewriteWalker,
    PatternRewriter,
    RewritePattern,
    op_type_rewrite_pattern,
)
from xdsl.passes import ModulePass


# Module-level cache of index constants, created once per distinct value.
bv_constants: dict[int, smt_bv.ConstantOp] = {}
```
> Contributor (on the `bv_constants` global): I would either pass this dictionary to the …

> Contributor: Actually, I would just use CSE after the pass, that should be fast enough.
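
A minimal sketch of the second suggestion, assuming xDSL exposes a CSE module pass at the import path below (the path and class name are an assumption and may differ between xDSL versions): emit a fresh constant wherever one is needed and let CSE deduplicate afterwards, the same way `LowerSMTTensor` already chains `DeadCodeElimination`. Constants would then be inserted at their use sites by the rewriter rather than by `insertConstantsBeforeModule`.

```python
# Sketch only: drop the global bv_constants cache and let CSE deduplicate.
# Assumption: xDSL provides a CSE module pass at this import path.
from xdsl.transforms.common_subexpression_elimination import (
    CommonSubexpressionElimination,
)


def getBVConstant(x: int) -> smt_bv.ConstantOp:
    # No caching: build a fresh constant each time; CSE merges duplicates later.
    return smt_bv.ConstantOp.from_int_value(x, INDEX_WIDTH)


class RewriteSMTTensorWithCSE(ModulePass):
    # Hypothetical variant of RewriteSMTTensor, for illustration only.
    name = "rewrite-smt-tensor"

    def apply(self, ctx: Context, op: ModuleOp):
        walker = PatternRewriteWalker(
            GreedyRewritePatternApplier([TensorExtractOpPattern()]), walk_reverse=True
        )
        walker.rewrite_module(op)
        # Merge the duplicate index constants created during the rewrites.
        CommonSubexpressionElimination().apply(ctx, op)
```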

The file continues with the constant-cache helper and the rewrite patterns:

```python
def getBVConstant(x: int) -> smt_bv.ConstantOp:
    global bv_constants
    if x not in bv_constants:
        bv_constants[x] = smt_bv.ConstantOp.from_int_value(x, INDEX_WIDTH)
    return bv_constants[x]


class TensorRewritePattern(RewritePattern, ABC):
    # The TensorExtractOp being folded into the matched tensor operation.
    extract_op: TensorExtractOp

    def __init__(self, extract_op: TensorExtractOp):
        self.extract_op = extract_op
        super().__init__()


class RewriteTransposeOpPattern(TensorRewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter):
        # extract(transpose(t), i_0, ..., i_n) -> extract(t, i_perm(0), ..., i_perm(n))
        extract_op = self.extract_op
        permutations = op.get_permutation()
        new_indices: list[SSAValue] = []
        for i in permutations:
            new_indices.append(extract_op.indices[i])
        new_extract_op = TensorExtractOp(op.operand, new_indices)
        rewriter.replace_op(extract_op, new_extract_op)
        rewriter.erase_matched_op()


class TensorExtractOpPattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
        # Dispatch on the operation that produced the extracted tensor.
        source = op.tensor
        source_parent_op = source.owner
        if isinstance(source_parent_op, TensorTransposeOp):
            RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter)
```

> Contributor (on the `TensorRewritePattern` class): I would not use that class, I would instead use …

> Contributor (on lines +42 to +61, i.e. the pattern classes above): I would change this logic here to just have a single rewrite pattern that matches a …
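
A possible shape for the suggested restructuring (an interpretation of the truncated comments above, not code from the PR; it reuses only names already imported in this file): a single pattern that matches a `TensorExtractOp` whose source is a `TensorTransposeOp` and folds the permutation into the extract directly, without a stateful helper pattern.

```python
# Sketch: one self-contained pattern replacing TensorRewritePattern +
# RewriteTransposeOpPattern. Interpretation of the review comments, not PR code.
class ExtractOfTransposePattern(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
        transpose = op.tensor.owner
        if not isinstance(transpose, TensorTransposeOp):
            return
        # Same fold as above: new_indices[k] = indices[permutation[k]].
        permutation = transpose.get_permutation()
        new_indices: list[SSAValue] = [op.indices[i] for i in permutation]
        rewriter.replace_matched_op(TensorExtractOp(transpose.operand, new_indices))
        # The transpose becomes dead once no extract uses it and can be
        # removed by a DCE pass afterwards.
```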

Finally, the cached constants are materialized at the top of the module body and the pass driver is defined:

```python
def insertConstantsBeforeModule(op: ModuleOp):
    global bv_constants

    # Insert every cached constant before the first operation of the module body.
    block = op.body.block
    first_op = block.first_op
    assert first_op is not None
    for val in bv_constants.values():
        block.insert_op_before(val, first_op)


class RewriteSMTTensor(ModulePass):
    name = "rewrite-smt-tensor"

    def apply(self, ctx: Context, op: ModuleOp):
        walker = PatternRewriteWalker(
            GreedyRewritePatternApplier([TensorExtractOpPattern()]), walk_reverse=True
        )
        walker.rewrite_module(op)

        insertConstantsBeforeModule(op)
```
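
For intuition about the combined effect of the two passes (a plain-Python illustration, not dialect code; the pass presumably runs before `lower-smt-tensor`, since the lowering has no pattern for transposes): folding an extract through a transpose just permutes the index list, so `extract(transpose(t, perm=[1, 0]), i, j)` ends up as `(select (select t j) i)` after lowering.

```python
# Illustration only: model the two rewrites with plain lists and strings.
def fold_extract_of_transpose(indices: list[str], permutation: list[int]) -> list[str]:
    # Mirrors RewriteTransposeOpPattern: new_indices[k] = indices[permutation[k]].
    return [indices[p] for p in permutation]


def lower_extract(tensor: str, indices: list[str]) -> str:
    # Mirrors TensorExtractOpPattern from lower-smt-tensor: one select per index.
    expr = tensor
    for idx in indices:
        expr = f"(select {expr} {idx})"
    return expr


if __name__ == "__main__":
    folded = fold_extract_of_transpose(["i", "j"], [1, 0])
    print(lower_extract("t", folded))  # (select (select t j) i)
```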
> Contributor: Is there a reason this file doesn't use the `SMTLowerer` and use the `OperationSemantics` class?