From f56e2dc6894ad9fcb5560d66e8ac8da7970efc82 Mon Sep 17 00:00:00 2001 From: Yuyou Fan Date: Sun, 30 Nov 2025 06:33:55 -0700 Subject: [PATCH 1/5] Add lowereings from smt tensor to low-level smt --- xdsl_smt/dialects/smt_tensor_dialect.py | 2 + .../passes/{lower_to_smt => }/lower_to_smt.py | 0 .../passes/lower_to_smt/lower_smt_tensor.py | 67 ++++++ xdsl_smt/passes/rewrite_smt_tensor.py | 223 ++++++++++++++++++ 4 files changed, 292 insertions(+) rename xdsl_smt/passes/{lower_to_smt => }/lower_to_smt.py (100%) create mode 100644 xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py create mode 100644 xdsl_smt/passes/rewrite_smt_tensor.py diff --git a/xdsl_smt/dialects/smt_tensor_dialect.py b/xdsl_smt/dialects/smt_tensor_dialect.py index eced1e75..3e1a7eb9 100644 --- a/xdsl_smt/dialects/smt_tensor_dialect.py +++ b/xdsl_smt/dialects/smt_tensor_dialect.py @@ -69,6 +69,8 @@ def get_element_type(self) -> AttributeCovT: AnySMTTensorType: TypeAlias = SMTTensorType[Attribute] +INDEX_WIDTH = 64 +IndexType = BitVectorType(INDEX_WIDTH) def to_integer_array_attr( diff --git a/xdsl_smt/passes/lower_to_smt/lower_to_smt.py b/xdsl_smt/passes/lower_to_smt.py similarity index 100% rename from xdsl_smt/passes/lower_to_smt/lower_to_smt.py rename to xdsl_smt/passes/lower_to_smt.py diff --git a/xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py b/xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py new file mode 100644 index 00000000..6c745281 --- /dev/null +++ b/xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py @@ -0,0 +1,67 @@ +from xdsl_smt.dialects import smt_array_dialect as smt_array + +from xdsl_smt.dialects.smt_dialect import ( + DeclareConstOp, +) +from xdsl_smt.dialects.smt_tensor_dialect import ( + IndexType, + SMTTensorType, + TensorExtractOp +) +from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination +from xdsl.dialects.builtin import ModuleOp +from xdsl.ir import Attribute +from xdsl.context import Context +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.passes import ModulePass + + +def lower_tensor_type(typ: Attribute) -> Attribute: + if isinstance(typ, SMTTensorType): + result = typ.element_type + index_type = IndexType + for _ in typ.shape: + result = smt_array.ArrayType(index_type, result) + return result + return typ + + +class DeclareConstOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter): + if isinstance(op.res.type, SMTTensorType): + new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type)) + rewriter.replace_matched_op(new_constant_op) + + +class TensorExtractOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): + source = op.tensor + assert isinstance(source.type, smt_array.ArrayType) + select_ops: list[smt_array.SelectOp] = [] + for idx in op.indices: + select_ops.append(smt_array.SelectOp(source, idx)) + source = select_ops[-1].res + rewriter.replace_matched_op(select_ops) + + +class LowerSMTTensor(ModulePass): + name = "lower-smt-tensor" + + def apply(self, ctx: Context, op: ModuleOp): + walker = PatternRewriteWalker( + GreedyRewritePatternApplier( + [DeclareConstOpPattern(), TensorExtractOpPattern()] + ) + ) + walker.rewrite_module(op) + + # Apply DCE pass + DeadCodeElimination().apply(ctx, op) diff --git a/xdsl_smt/passes/rewrite_smt_tensor.py b/xdsl_smt/passes/rewrite_smt_tensor.py new file mode 100644 index 00000000..be809588 --- /dev/null +++ b/xdsl_smt/passes/rewrite_smt_tensor.py @@ -0,0 +1,223 @@ +from abc import ABC +from typing import Callable + + +from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv + +from xdsl_smt.dialects.smt_dialect import DeclareFunOp, IteOp +from xdsl.ir import Operation, SSAValue +from xdsl_smt.dialects.smt_tensor_dialect import ( + ElementwiseBinaryOperation, + TensorTransposeOp, + ElementwiseUnaryOperation, + INDEX_WIDTH, + TensorExtractOp +) +from xdsl.dialects.builtin import FunctionType, ModuleOp +from xdsl.ir import Attribute +from xdsl.context import Context +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriteWalker, + PatternRewriter, + RewritePattern, + op_type_rewrite_pattern, +) +from xdsl.passes import ModulePass + +from xdsl_smt.dialects.smt_dialect import ( + CallOp, +) + + +bv_constants: dict[int, smt_bv.ConstantOp] = {} + + +def getBVConstant(x: int) -> smt_bv.ConstantOp: + global bv_constants + if x not in bv_constants: + bv_constants[x] = smt_bv.ConstantOp.from_int_value(x, INDEX_WIDTH) + return bv_constants[x] + + +class TensorRewritePattern(RewritePattern, ABC): + extract_op: TensorExtractOp + + def __init__(self, extract_op): + self.extract_op = extract_op + super().__init__() + + +class RewriteTransposeOpPattern(TensorRewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter): + extract_op = self.extract_op + permutations = op.permutation.get_values() + new_indices: list[SSAValue] = [] + for i in permutations: + new_indices.append(extract_op.indices[i]) + new_extract_op = TensorExtractOp(op.operand, new_indices) + rewriter.replace_op(extract_op, new_extract_op) + rewriter.erase_matched_op() + + +def toFuncName(name: str) -> str: + return name.replace(".", "_") + + +elementwise_unary_function_set: set[DeclareFunOp] = set() +elementwise_unary_functions: dict[str, Callable[[SSAValue], list[Operation]]] = {} + + +elementwise_binary_function_set: set[DeclareFunOp] = set() +elementwise_binary_functions: dict[ + str, Callable[[SSAValue, SSAValue], list[Operation]] +] = {} + + +def initElementwiseIntFunction(): + global elementwise_binary_functions + global elementwise_unary_functions + elementwise_binary_functions["smt_tensor_add"] = lambda x, y: [smt_bv.AddOp(x, y)] + elementwise_binary_functions["smt_tensor_subtract"] = lambda x, y: [ + smt_bv.SubOp(x, y) + ] + elementwise_binary_functions["smt_tensor_multiply"] = lambda x, y: [ + smt_bv.MulOp(x, y) + ] + + def get_maximum_ops(lhs: SSAValue, rhs: SSAValue) -> list[Operation]: + less_than_op = smt_bv.SltOp(lhs, rhs) + ite_op = IteOp(less_than_op.res, rhs, lhs) + return [less_than_op, ite_op] + + elementwise_binary_functions["smt_tensor_maximum"] = get_maximum_ops + + def get_minimum_ops(lhs: SSAValue, rhs: SSAValue) -> list[Operation]: + less_than_op = smt_bv.SltOp(lhs, rhs) + ite_op = IteOp(less_than_op.res, lhs, rhs) + return [less_than_op, ite_op] + + elementwise_binary_functions["smt_tensor_minimum"] = get_minimum_ops + + def get_abs_ops(val: SSAValue) -> list[Operation]: + neg_op = smt_bv.NegOp(val) + less_than_op = smt_bv.SltOp(val, neg_op.res) + ite_op = IteOp(less_than_op.res, neg_op.res, val) + return [neg_op, less_than_op, ite_op] + + elementwise_unary_functions["smt_tensor_abs"] = get_abs_ops + elementwise_unary_functions["smt_tensor_negate"] = lambda x: [smt_bv.NegOp(x)] + + +def getElementwiseBinaryFunction(op_name: str, element_type: Attribute): + global elementwise_binary_function_set + global elementwise_binary_functions + if op_name not in elementwise_binary_functions: + element_uf_type = FunctionType.from_lists( + [element_type, element_type], [element_type] + ) + defun_op = DeclareFunOp(element_uf_type, op_name) + elementwise_binary_function_set.add(defun_op) + elementwise_binary_functions[op_name] = lambda x, y: [ + CallOp(defun_op.ret, [x, y]) + ] + return elementwise_binary_functions[op_name] + + +def getElementwiseUnaryFunction(op_name: str, element_type: Attribute): + global elementwise_unary_function_set + global elementwise_unary_functions + if op_name not in elementwise_unary_functions: + element_uf_type = FunctionType.from_lists([element_type], [element_type]) + defun_op = DeclareFunOp(element_uf_type, op_name) + elementwise_unary_function_set.add(defun_op) + elementwise_unary_functions[op_name] = lambda x: [CallOp(defun_op.ret, [x])] + return elementwise_unary_functions[op_name] + + + +class RewriteElementwiseUnaryOpPattern(TensorRewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: ElementwiseUnaryOperation, rewriter: PatternRewriter + ): + element_type = self.extract_op.result.type + op_name = toFuncName(op.name) + unary_function = getElementwiseUnaryFunction(op_name, element_type) + extract_op_op = TensorExtractOp(op.op, self.extract_op.indices) + call_ops = unary_function(extract_op_op.result) + rewriter.replace_op(self.extract_op, [extract_op_op] + call_ops) + rewriter.erase_matched_op() + + +class RewriteElementwiseBinaryOpPattern(TensorRewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite( + self, op: ElementwiseBinaryOperation, rewriter: PatternRewriter + ): + element_type = self.extract_op.result.type + op_name = toFuncName(op.name) + binary_function = getElementwiseBinaryFunction(op_name, element_type) + extract_lhs_op = TensorExtractOp(op.lhs, self.extract_op.indices) + extract_rhs_op = TensorExtractOp(op.rhs, self.extract_op.indices) + call_ops = binary_function(extract_lhs_op.result, extract_rhs_op.result) + rewriter.replace_op( + self.extract_op, [extract_lhs_op, extract_rhs_op] + call_ops + ) + rewriter.erase_op(op) + + +class TensorExtractOpPattern(RewritePattern): + @op_type_rewrite_pattern + def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): + source = op.tensor + source_parent_op = source.owner + if isinstance(source_parent_op, ElementwiseUnaryOperation): + RewriteElementwiseUnaryOpPattern(op).match_and_rewrite( + source_parent_op, rewriter + ) + elif isinstance(source_parent_op, ElementwiseBinaryOperation): + RewriteElementwiseBinaryOpPattern(op).match_and_rewrite( + source_parent_op, rewriter + ) + elif isinstance(source_parent_op, TensorTransposeOp): + RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter) + + +def insertFunctionBeforeModule(op: ModuleOp): + block = op.body.block + first_op = block.first_op + assert first_op is not None + while len(elementwise_binary_function_set) > 0: + function_op = elementwise_binary_function_set.pop() + block.insert_op_before(function_op, first_op) + + while len(elementwise_unary_function_set) > 0: + function_op = elementwise_unary_function_set.pop() + block.insert_op_before(function_op, first_op) + + +def insertConstantsBeforeModule(op: ModuleOp): + global bv_constants + + block = op.body.block + first_op = block.first_op + assert first_op is not None + for val in bv_constants.values(): + block.insert_op_before(val, first_op) + + +class RewriteSMTTensor(ModulePass): + name = "rewrite-smt-tensor" + + def apply(self, ctx: Context, op: ModuleOp): + initElementwiseIntFunction() + + walker = PatternRewriteWalker( + GreedyRewritePatternApplier([TensorExtractOpPattern()]), walk_reverse=True + ) + walker.rewrite_module(op) + + insertFunctionBeforeModule(op) + insertConstantsBeforeModule(op) From 894c8355488f47a960a85f9f35381fd635059aa4 Mon Sep 17 00:00:00 2001 From: Yuyou Fan Date: Sun, 30 Nov 2025 06:36:15 -0700 Subject: [PATCH 2/5] Fix format --- xdsl_smt/passes/{lower_to_smt => }/lower_smt_tensor.py | 2 +- xdsl_smt/passes/rewrite_smt_tensor.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) rename xdsl_smt/passes/{lower_to_smt => }/lower_smt_tensor.py (99%) diff --git a/xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py b/xdsl_smt/passes/lower_smt_tensor.py similarity index 99% rename from xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py rename to xdsl_smt/passes/lower_smt_tensor.py index 6c745281..cf716e3b 100644 --- a/xdsl_smt/passes/lower_to_smt/lower_smt_tensor.py +++ b/xdsl_smt/passes/lower_smt_tensor.py @@ -6,7 +6,7 @@ from xdsl_smt.dialects.smt_tensor_dialect import ( IndexType, SMTTensorType, - TensorExtractOp + TensorExtractOp, ) from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination from xdsl.dialects.builtin import ModuleOp diff --git a/xdsl_smt/passes/rewrite_smt_tensor.py b/xdsl_smt/passes/rewrite_smt_tensor.py index be809588..28755bea 100644 --- a/xdsl_smt/passes/rewrite_smt_tensor.py +++ b/xdsl_smt/passes/rewrite_smt_tensor.py @@ -4,14 +4,14 @@ from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv -from xdsl_smt.dialects.smt_dialect import DeclareFunOp, IteOp +from xdsl_smt.dialects.smt_dialect import DeclareFunOp, IteOp from xdsl.ir import Operation, SSAValue from xdsl_smt.dialects.smt_tensor_dialect import ( ElementwiseBinaryOperation, TensorTransposeOp, ElementwiseUnaryOperation, INDEX_WIDTH, - TensorExtractOp + TensorExtractOp, ) from xdsl.dialects.builtin import FunctionType, ModuleOp from xdsl.ir import Attribute @@ -136,7 +136,6 @@ def getElementwiseUnaryFunction(op_name: str, element_type: Attribute): return elementwise_unary_functions[op_name] - class RewriteElementwiseUnaryOpPattern(TensorRewritePattern): @op_type_rewrite_pattern def match_and_rewrite( From 3f52dd0653a363344e15af6111f410a3119ee7e2 Mon Sep 17 00:00:00 2001 From: Yuyou Fan Date: Sun, 30 Nov 2025 06:40:17 -0700 Subject: [PATCH 3/5] Fix format --- xdsl_smt/passes/rewrite_smt_tensor.py | 145 +------------------------- 1 file changed, 3 insertions(+), 142 deletions(-) diff --git a/xdsl_smt/passes/rewrite_smt_tensor.py b/xdsl_smt/passes/rewrite_smt_tensor.py index 28755bea..23103259 100644 --- a/xdsl_smt/passes/rewrite_smt_tensor.py +++ b/xdsl_smt/passes/rewrite_smt_tensor.py @@ -1,20 +1,15 @@ from abc import ABC -from typing import Callable from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv -from xdsl_smt.dialects.smt_dialect import DeclareFunOp, IteOp -from xdsl.ir import Operation, SSAValue +from xdsl.ir import SSAValue from xdsl_smt.dialects.smt_tensor_dialect import ( - ElementwiseBinaryOperation, TensorTransposeOp, - ElementwiseUnaryOperation, INDEX_WIDTH, TensorExtractOp, ) -from xdsl.dialects.builtin import FunctionType, ModuleOp -from xdsl.ir import Attribute +from xdsl.dialects.builtin import ModuleOp from xdsl.context import Context from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -25,10 +20,6 @@ ) from xdsl.passes import ModulePass -from xdsl_smt.dialects.smt_dialect import ( - CallOp, -) - bv_constants: dict[int, smt_bv.ConstantOp] = {} @@ -61,142 +52,15 @@ def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter): rewriter.erase_matched_op() -def toFuncName(name: str) -> str: - return name.replace(".", "_") - - -elementwise_unary_function_set: set[DeclareFunOp] = set() -elementwise_unary_functions: dict[str, Callable[[SSAValue], list[Operation]]] = {} - - -elementwise_binary_function_set: set[DeclareFunOp] = set() -elementwise_binary_functions: dict[ - str, Callable[[SSAValue, SSAValue], list[Operation]] -] = {} - - -def initElementwiseIntFunction(): - global elementwise_binary_functions - global elementwise_unary_functions - elementwise_binary_functions["smt_tensor_add"] = lambda x, y: [smt_bv.AddOp(x, y)] - elementwise_binary_functions["smt_tensor_subtract"] = lambda x, y: [ - smt_bv.SubOp(x, y) - ] - elementwise_binary_functions["smt_tensor_multiply"] = lambda x, y: [ - smt_bv.MulOp(x, y) - ] - - def get_maximum_ops(lhs: SSAValue, rhs: SSAValue) -> list[Operation]: - less_than_op = smt_bv.SltOp(lhs, rhs) - ite_op = IteOp(less_than_op.res, rhs, lhs) - return [less_than_op, ite_op] - - elementwise_binary_functions["smt_tensor_maximum"] = get_maximum_ops - - def get_minimum_ops(lhs: SSAValue, rhs: SSAValue) -> list[Operation]: - less_than_op = smt_bv.SltOp(lhs, rhs) - ite_op = IteOp(less_than_op.res, lhs, rhs) - return [less_than_op, ite_op] - - elementwise_binary_functions["smt_tensor_minimum"] = get_minimum_ops - - def get_abs_ops(val: SSAValue) -> list[Operation]: - neg_op = smt_bv.NegOp(val) - less_than_op = smt_bv.SltOp(val, neg_op.res) - ite_op = IteOp(less_than_op.res, neg_op.res, val) - return [neg_op, less_than_op, ite_op] - - elementwise_unary_functions["smt_tensor_abs"] = get_abs_ops - elementwise_unary_functions["smt_tensor_negate"] = lambda x: [smt_bv.NegOp(x)] - - -def getElementwiseBinaryFunction(op_name: str, element_type: Attribute): - global elementwise_binary_function_set - global elementwise_binary_functions - if op_name not in elementwise_binary_functions: - element_uf_type = FunctionType.from_lists( - [element_type, element_type], [element_type] - ) - defun_op = DeclareFunOp(element_uf_type, op_name) - elementwise_binary_function_set.add(defun_op) - elementwise_binary_functions[op_name] = lambda x, y: [ - CallOp(defun_op.ret, [x, y]) - ] - return elementwise_binary_functions[op_name] - - -def getElementwiseUnaryFunction(op_name: str, element_type: Attribute): - global elementwise_unary_function_set - global elementwise_unary_functions - if op_name not in elementwise_unary_functions: - element_uf_type = FunctionType.from_lists([element_type], [element_type]) - defun_op = DeclareFunOp(element_uf_type, op_name) - elementwise_unary_function_set.add(defun_op) - elementwise_unary_functions[op_name] = lambda x: [CallOp(defun_op.ret, [x])] - return elementwise_unary_functions[op_name] - - -class RewriteElementwiseUnaryOpPattern(TensorRewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite( - self, op: ElementwiseUnaryOperation, rewriter: PatternRewriter - ): - element_type = self.extract_op.result.type - op_name = toFuncName(op.name) - unary_function = getElementwiseUnaryFunction(op_name, element_type) - extract_op_op = TensorExtractOp(op.op, self.extract_op.indices) - call_ops = unary_function(extract_op_op.result) - rewriter.replace_op(self.extract_op, [extract_op_op] + call_ops) - rewriter.erase_matched_op() - - -class RewriteElementwiseBinaryOpPattern(TensorRewritePattern): - @op_type_rewrite_pattern - def match_and_rewrite( - self, op: ElementwiseBinaryOperation, rewriter: PatternRewriter - ): - element_type = self.extract_op.result.type - op_name = toFuncName(op.name) - binary_function = getElementwiseBinaryFunction(op_name, element_type) - extract_lhs_op = TensorExtractOp(op.lhs, self.extract_op.indices) - extract_rhs_op = TensorExtractOp(op.rhs, self.extract_op.indices) - call_ops = binary_function(extract_lhs_op.result, extract_rhs_op.result) - rewriter.replace_op( - self.extract_op, [extract_lhs_op, extract_rhs_op] + call_ops - ) - rewriter.erase_op(op) - - class TensorExtractOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: TensorExtractOp, rewriter: PatternRewriter): source = op.tensor source_parent_op = source.owner - if isinstance(source_parent_op, ElementwiseUnaryOperation): - RewriteElementwiseUnaryOpPattern(op).match_and_rewrite( - source_parent_op, rewriter - ) - elif isinstance(source_parent_op, ElementwiseBinaryOperation): - RewriteElementwiseBinaryOpPattern(op).match_and_rewrite( - source_parent_op, rewriter - ) - elif isinstance(source_parent_op, TensorTransposeOp): + if isinstance(source_parent_op, TensorTransposeOp): RewriteTransposeOpPattern(op).match_and_rewrite(source_parent_op, rewriter) -def insertFunctionBeforeModule(op: ModuleOp): - block = op.body.block - first_op = block.first_op - assert first_op is not None - while len(elementwise_binary_function_set) > 0: - function_op = elementwise_binary_function_set.pop() - block.insert_op_before(function_op, first_op) - - while len(elementwise_unary_function_set) > 0: - function_op = elementwise_unary_function_set.pop() - block.insert_op_before(function_op, first_op) - - def insertConstantsBeforeModule(op: ModuleOp): global bv_constants @@ -211,12 +75,9 @@ class RewriteSMTTensor(ModulePass): name = "rewrite-smt-tensor" def apply(self, ctx: Context, op: ModuleOp): - initElementwiseIntFunction() - walker = PatternRewriteWalker( GreedyRewritePatternApplier([TensorExtractOpPattern()]), walk_reverse=True ) walker.rewrite_module(op) - insertFunctionBeforeModule(op) insertConstantsBeforeModule(op) From d1c7f5151779e911dfe0cff02cbdf91e59b7ba62 Mon Sep 17 00:00:00 2001 From: Yuyou Fan Date: Sun, 30 Nov 2025 06:57:45 -0700 Subject: [PATCH 4/5] Remove an accident change --- xdsl_smt/passes/{ => lower_to_smt}/lower_to_smt.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename xdsl_smt/passes/{ => lower_to_smt}/lower_to_smt.py (100%) diff --git a/xdsl_smt/passes/lower_to_smt.py b/xdsl_smt/passes/lower_to_smt/lower_to_smt.py similarity index 100% rename from xdsl_smt/passes/lower_to_smt.py rename to xdsl_smt/passes/lower_to_smt/lower_to_smt.py From 9276bf51d0510fbb4bb00cf3883a3f6c8fd9f5cd Mon Sep 17 00:00:00 2001 From: Yuyou Fan Date: Sun, 30 Nov 2025 07:07:39 -0700 Subject: [PATCH 5/5] Fix pyright --- xdsl_smt/passes/lower_smt_tensor.py | 5 +++-- xdsl_smt/passes/rewrite_smt_tensor.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/xdsl_smt/passes/lower_smt_tensor.py b/xdsl_smt/passes/lower_smt_tensor.py index cf716e3b..288e7090 100644 --- a/xdsl_smt/passes/lower_smt_tensor.py +++ b/xdsl_smt/passes/lower_smt_tensor.py @@ -11,6 +11,7 @@ from xdsl_smt.passes.dead_code_elimination import DeadCodeElimination from xdsl.dialects.builtin import ModuleOp from xdsl.ir import Attribute +from xdsl.utils.hints import isa from xdsl.context import Context from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, @@ -23,7 +24,7 @@ def lower_tensor_type(typ: Attribute) -> Attribute: - if isinstance(typ, SMTTensorType): + if isa(typ, SMTTensorType): result = typ.element_type index_type = IndexType for _ in typ.shape: @@ -35,7 +36,7 @@ def lower_tensor_type(typ: Attribute) -> Attribute: class DeclareConstOpPattern(RewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: DeclareConstOp, rewriter: PatternRewriter): - if isinstance(op.res.type, SMTTensorType): + if isa(op.res.type, SMTTensorType): new_constant_op = DeclareConstOp(lower_tensor_type(op.res.type)) rewriter.replace_matched_op(new_constant_op) diff --git a/xdsl_smt/passes/rewrite_smt_tensor.py b/xdsl_smt/passes/rewrite_smt_tensor.py index 23103259..d761fac8 100644 --- a/xdsl_smt/passes/rewrite_smt_tensor.py +++ b/xdsl_smt/passes/rewrite_smt_tensor.py @@ -34,7 +34,7 @@ def getBVConstant(x: int) -> smt_bv.ConstantOp: class TensorRewritePattern(RewritePattern, ABC): extract_op: TensorExtractOp - def __init__(self, extract_op): + def __init__(self, extract_op: TensorExtractOp): self.extract_op = extract_op super().__init__() @@ -43,7 +43,7 @@ class RewriteTransposeOpPattern(TensorRewritePattern): @op_type_rewrite_pattern def match_and_rewrite(self, op: TensorTransposeOp, rewriter: PatternRewriter): extract_op = self.extract_op - permutations = op.permutation.get_values() + permutations = op.get_permutation() new_indices: list[SSAValue] = [] for i in permutations: new_indices.append(extract_op.indices[i])