Skip to content

Conversation

@Hatsunespica
Copy link
Collaborator

This PR adds two new files:
rewrite_smt_tensor.py
This file matches the pattern (extract (op args) indices) to (extract tensor' indicts')

For example: (extract (transpose T) x y) -> (extract T y x)

More patterns will be added in some next PRs. The only pattern included in this PR is for transposition served as an example usage.

lower_smt_tensor.py
This file lowers a smt_tensor type into smt_arrry type which is <X, <Y, elementType>>

For example: smt_tensor<2x2xi8> -> <IndexType, <IndexType, bv<8>> while IndexType is a bv<64> type defined in `smt_tensor_dialect

Next, it lowers TensorExtractOp as a chain of indices applications.


For the next PR, I want to add refine_tensor_semantics in xdsl-smt and provide test cases over there.
Do you think it's a good idea?

Copy link
Contributor

@math-fehr math-fehr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you split this PR in two different PRs (one for the lowering, and one for the optimizations).
Also, can you add tests for these?

@@ -0,0 +1,68 @@
from xdsl_smt.dialects import smt_array_dialect as smt_array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this file doesn't use the SMTLowerer and use the OperationSemantics class?

from xdsl.passes import ModulePass


bv_constants: dict[int, smt_bv.ConstantOp] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would either pass this dictionary to the RewritePattern, or I would use the CSE pass after the pattern rewriter is done.
Having a global variable will mess things up if you run the pass multiple time

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I would just use CSE after the pass, that should be fast enough.

return bv_constants[x]


class TensorRewritePattern(RewritePattern, ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not use that class, I would instead use @dataclass on all the other subclasses and just have the variable there

Comment on lines +42 to +61
class RewriteTransposeOpPattern(TensorRewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter):
extract_op = self.extract_op
permutations = op.get_permutation()
new_indices: list[SSAValue] = []
for i in permutations:
new_indices.append(extract_op.indices[i])
new_extract_op = TensorExtractOp(op.operand, new_indices)
rewriter.replace_op(extract_op, new_extract_op)
rewriter.erase_matched_op()


class TensorExtractOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
source = op.tensor
source_parent_op = source.owner
if isinstance(source_parent_op, TensorTransposeOp):
RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would change this logic here to just have a single rewrite pattern that matches a TensorExtractOp of TensorTransposeOp.
I would not try to add abstractions that early, that complexifies the code.

@Hatsunespica
Copy link
Collaborator Author

Can you split this PR in two different PRs (one for the lowering, and one for the optimizations). Also, can you add tests for these?

Hi Mathieu, this PR is only for lowerings without any optimizations.
It provides lowerings from tensor to smt_array.
And the next PR is about tensor_refinemnet in xdsl-smt

@math-fehr
Copy link
Contributor

But xdsl_smt/passes/rewrite_smt_tensor.py is a file only containing optimizations right?

@Hatsunespica
Copy link
Collaborator Author

But xdsl_smt/passes/rewrite_smt_tensor.py is a file only containing optimizations right?

Nope, it's not for optimization, but a part of verification.

Here is an example:
To verify

T(T(x)) == x

let's find indices where 
T(T(x))[i, j] != x[i, j]
    |
  ( Rewrite Tensor)
   |
  V
T(x)[j, i] != x[i, j]
    |
  ( Rewrite Tensor)
   |
  V
x[i, j] != x[i, j]
    |
  ( Lower Tensor)
   |
  V
  apply(smt_array(x), i, j) != apply(smt_array(x), i, j))
unsat

tensosr_refinement (that is not in this PR) instantiates indices [i, j] and adds element accessor on both sides.
rewrite tensor keeps rewriting operation with an element accessor with another accessor until reach the simplest form
lower tensor lowers element accessor and tensor to smt_array and apply operations.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants