Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xdsl_smt/dialects/smt_tensor_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def get_element_type(self) -> AttributeCovT:


AnySMTTensorType: TypeAlias = SMTTensorType[Attribute]
INDEX_WIDTH = 64
IndexType = BitVectorType(INDEX_WIDTH)


def to_integer_array_attr(
Expand Down
68 changes: 68 additions & 0 deletions xdsl_smt/passes/lower_smt_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from xdsl_smt.dialects import smt_array_dialect as smt_array
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a reason this file doesn't use the SMTLowerer and use the OperationSemantics class?


from xdsl_smt.dialects.smt_dialect import (
DeclareConstOp,
)
from xdsl_smt.dialects.smt_tensor_dialect import (
IndexType,
SMTTensorType,
TensorExtractOp,
)
from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination
from xdsl.dialects.builtin import ModuleOp
from xdsl.ir import Attribute
from xdsl.utils.hints import isa
from xdsl.context import Context
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.passes import ModulePass


def lower_tensor_type(typ: Attribute) -> Attribute:
if isa(typ, SMTTensorType):
result = typ.element_type
index_type = IndexType
for _ in typ.shape:
result = smt_array.ArrayType(index_type, result)
return result
return typ


class DeclareConstOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter):
if isa(op.res.type, SMTTensorType):
new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type))
rewriter.replace_matched_op(new_constant_op)


class TensorExtractOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
source = op.tensor
assert isinstance(source.type, smt_array.ArrayType)
select_ops: list[smt_array.SelectOp] = []
for idx in op.indices:
select_ops.append(smt_array.SelectOp(source, idx))
source = select_ops[-1].res
rewriter.replace_matched_op(select_ops)


class LowerSMTTensor(ModulePass):
name = "lower-smt-tensor"

def apply(self, ctx: Context, op: ModuleOp):
walker = PatternRewriteWalker(
GreedyRewritePatternApplier(
[DeclareConstOpPattern(), TensorExtractOpPattern()]
)
)
walker.rewrite_module(op)

# Apply DCE pass
DeadCodeElimination().apply(ctx, op)
83 changes: 83 additions & 0 deletions xdsl_smt/passes/rewrite_smt_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from abc import ABC


from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv

from xdsl.ir import SSAValue
from xdsl_smt.dialects.smt_tensor_dialect import (
TensorTransposeOp,
INDEX_WIDTH,
TensorExtractOp,
)
from xdsl.dialects.builtin import ModuleOp
from xdsl.context import Context
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriteWalker,
PatternRewriter,
RewritePattern,
op_type_rewrite_pattern,
)
from xdsl.passes import ModulePass


bv_constants: dict[int, smt_bv.ConstantOp] = {}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would either pass this dictionary to the RewritePattern, or I would use the CSE pass after the pattern rewriter is done.
Having a global variable will mess things up if you run the pass multiple time

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I would just use CSE after the pass, that should be fast enough.



def getBVConstant(x: int) -> smt_bv.ConstantOp:
global bv_constants
if x not in bv_constants:
bv_constants[x] = smt_bv.ConstantOp.from_int_value(x, INDEX_WIDTH)
return bv_constants[x]


class TensorRewritePattern(RewritePattern, ABC):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not use that class, I would instead use @dataclass on all the other subclasses and just have the variable there

extract_op: TensorExtractOp

def __init__(self, extract_op: TensorExtractOp):
self.extract_op = extract_op
super().__init__()


class RewriteTransposeOpPattern(TensorRewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter):
extract_op = self.extract_op
permutations = op.get_permutation()
new_indices: list[SSAValue] = []
for i in permutations:
new_indices.append(extract_op.indices[i])
new_extract_op = TensorExtractOp(op.operand, new_indices)
rewriter.replace_op(extract_op, new_extract_op)
rewriter.erase_matched_op()


class TensorExtractOpPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter):
source = op.tensor
source_parent_op = source.owner
if isinstance(source_parent_op, TensorTransposeOp):
RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter)
Comment on lines +42 to +61
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would change this logic here to just have a single rewrite pattern that matches a TensorExtractOp of TensorTransposeOp.
I would not try to add abstractions that early, that complexifies the code.



def insertConstantsBeforeModule(op: ModuleOp):
global bv_constants

block = op.body.block
first_op = block.first_op
assert first_op is not None
for val in bv_constants.values():
block.insert_op_before(val, first_op)


class RewriteSMTTensor(ModulePass):
name = "rewrite-smt-tensor"

def apply(self, ctx: Context, op: ModuleOp):
walker = PatternRewriteWalker(
GreedyRewritePatternApplier([TensorExtractOpPattern()]), walk_reverse=True
)
walker.rewrite_module(op)

insertConstantsBeforeModule(op)