-
Notifications
You must be signed in to change notification settings - Fork 6
Add lowerings from smt tensor to low-level smt #82
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
math-fehr
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you split this PR in two different PRs (one for the lowering, and one for the optimizations).
Also, can you add tests for these?
| @@ -0,0 +1,68 @@ | |||
| from xdsl_smt.dialects import smt_array_dialect as smt_array | |||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there a reason this file doesn't use the SMTLowerer and use the OperationSemantics class?
| from xdsl.passes import ModulePass | ||
|
|
||
|
|
||
| bv_constants: dict[int, smt_bv.ConstantOp] = {} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would either pass this dictionary to the RewritePattern, or I would use the CSE pass after the pattern rewriter is done.
Having a global variable will mess things up if you run the pass multiple time
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, I would just use CSE after the pass, that should be fast enough.
| return bv_constants[x] | ||
|
|
||
|
|
||
| class TensorRewritePattern(RewritePattern, ABC): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would not use that class, I would instead use @dataclass on all the other subclasses and just have the variable there
| class RewriteTransposeOpPattern(TensorRewritePattern): | ||
| @op_type_rewrite_pattern | ||
| def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter): | ||
| extract_op = self.extract_op | ||
| permutations = op.get_permutation() | ||
| new_indices: list[SSAValue] = [] | ||
| for i in permutations: | ||
| new_indices.append(extract_op.indices[i]) | ||
| new_extract_op = TensorExtractOp(op.operand, new_indices) | ||
| rewriter.replace_op(extract_op, new_extract_op) | ||
| rewriter.erase_matched_op() | ||
|
|
||
|
|
||
| class TensorExtractOpPattern(RewritePattern): | ||
| @op_type_rewrite_pattern | ||
| def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): | ||
| source = op.tensor | ||
| source_parent_op = source.owner | ||
| if isinstance(source_parent_op, TensorTransposeOp): | ||
| RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would change this logic here to just have a single rewrite pattern that matches a TensorExtractOp of TensorTransposeOp.
I would not try to add abstractions that early, that complexifies the code.
Hi Mathieu, this PR is only for lowerings without any optimizations. |
|
But xdsl_smt/passes/rewrite_smt_tensor.py is a file only containing optimizations right? |
Nope, it's not for optimization, but a part of verification. Here is an example:
|
This PR adds two new files:
rewrite_smt_tensor.pyThis file matches the pattern
(extract (op args) indices)to(extract tensor' indicts')For example:
(extract (transpose T) x y)->(extract T y x)More patterns will be added in some next PRs. The only pattern included in this PR is for transposition served as an example usage.
lower_smt_tensor.pyThis file lowers a
smt_tensortype intosmt_arrrytype which is<X, <Y, elementType>>For example:
smt_tensor<2x2xi8>-><IndexType, <IndexType, bv<8>>whileIndexTypeis abv<64>type defined in `smt_tensor_dialectNext, it lowers
TensorExtractOpas a chain of indices applications.For the next PR, I want to add refine_tensor_semantics in xdsl-smt and provide test cases over there.
Do you think it's a good idea?