diff --git a/xdsl_smt/dialects/smt_tensor_dialect.py b/xdsl_smt/dialects/smt_tensor_dialect.py index eced1e75..3e1a7eb9 100644 --- a/xdsl_smt/dialects/smt_tensor_dialect.py +++ b/xdsl_smt/dialects/smt_tensor_dialect.py @@ -69,6 +69,8 @@ def get_element_type(self) -> AttributeCovT: AnySMTTensorType: TypeAlias = SMTTensorType[Attribute] +INDEX_WIDTH = 64 +IndexType = BitVectorType(INDEX_WIDTH) def to_integer_array_attr( diff --git a/xdsl_smt/passes/lower_smt_tensor.py b/xdsl_smt/passes/lower_smt_tensor.py new file mode 100644 index 00000000..288e7090 --- /dev/null +++ b/xdsl_smt/passes/lower_smt_tensor.py @@ -0,0 +1,68 @@ +from xdsl_smt.dialects import smt_array_dialect as smt_array + +from xdsl_smt.dialects.smt_dialect import ( + DeclareConstOp, +) +from xdsl_smt.dialects.smt_tensor_dialect import ( + IndexType, + SMTTensorType, + TensorExtractOp, +) +from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir import Attribute +from xdsl.utils.hints import isa +from xdsl.context import Context +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.passes import ModulePass + + +def lower_tensor_type(typ: Attribute) -> Attribute: + if isa(typ, SMTTensorType): + result = typ.element_type + index_type = IndexType + for _ in typ.shape: + result = smt_array.ArrayType(index_type, result) + return result + return typ + + +class DeclareConstOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter): + if isa(op.res.type, SMTTensorType): + new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type)) + rewriter.replace_matched_op(new_constant_op) + + +class TensorExtractOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): + source = op.tensor + assert isinstance(source.type, smt_array.ArrayType) + select_ops: list[smt_array.SelectOp] = [] + for idx in op.indices: + select_ops.append(smt_array.SelectOp(source, idx)) + source = select_ops[-1].res + rewriter.replace_matched_op(select_ops) + + +class LowerSMTTensor(ModulePass): + name = "lower-smt-tensor" + + def apply(self, ctx: Context, op: ModuleOp): + walker = PatternRewriteWalker( + GreedyRewritePatternApplier( + [DeclareConstOpPattern(), TensorExtractOpPattern()] + ) + ) + walker.rewrite_module(op) + + # Apply DCE pass + DeadCodeElimination().apply(ctx, op) diff --git a/xdsl_smt/passes/rewrite_smt_tensor.py b/xdsl_smt/passes/rewrite_smt_tensor.py new file mode 100644 index 00000000..d761fac8 --- /dev/null +++ b/xdsl_smt/passes/rewrite_smt_tensor.py @@ -0,0 +1,83 @@ +from abc import ABC + + +from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv + +from xdsl.ir import SSAValue +from xdsl_smt.dialects.smt_tensor_dialect import ( + TensorTransposeOp, + INDEX_WIDTH, + TensorExtractOp, +) +from xdsl.dialects.builtin import ModuleOp +from xdsl.context import Context +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.passes import ModulePass + + +bv_constants: dict[int, smt_bv.ConstantOp] = {} + + +def getBVConstant(x: int) -> smt_bv.ConstantOp: + global bv_constants + if x not in bv_constants: + bv_constants[x] = smt_bv.ConstantOp.from_int_value(x, INDEX_WIDTH) + return bv_constants[x] + + +class TensorRewritePattern(RewritePattern, ABC): + extract_op: TensorExtractOp + + def __init__(self, extract_op: TensorExtractOp): + self.extract_op = extract_op + super().__init__() + + +class RewriteTransposeOpPattern(TensorRewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter): + extract_op = self.extract_op + permutations = op.get_permutation() + new_indices: list[SSAValue] = [] + for i in permutations: + new_indices.append(extract_op.indices[i]) + new_extract_op = TensorExtractOp(op.operand, new_indices) + rewriter.replace_op(extract_op, new_extract_op) + rewriter.erase_matched_op() + + +class TensorExtractOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): + source = op.tensor + source_parent_op = source.owner + if isinstance(source_parent_op, TensorTransposeOp): + RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter) + + +def insertConstantsBeforeModule(op: ModuleOp): + global bv_constants + + block = op.body.block + first_op = block.first_op + assert first_op is not None + for val in bv_constants.values(): + block.insert_op_before(val, first_op) + + +class RewriteSMTTensor(ModulePass): + name = "rewrite-smt-tensor" + + def apply(self, ctx: Context, op: ModuleOp): + walker = PatternRewriteWalker( + GreedyRewritePatternApplier([TensorExtractOpPattern()]), walk_reverse=True + ) + walker.rewrite_module(op) + + insertConstantsBeforeModule(op)