From c47d9dfdd39c625d7fa050348a17ad95a412ee22 Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Fri, 26 Sep 2025 11:05:13 -0600 Subject: [PATCH 1/7] add some fixes for lowering to C++ --- pyproject.toml | 1 - xdsl_smt/dialects/transfer.py | 11 + xdsl_smt/passes/transfer_lower.py | 77 +-- xdsl_smt/utils/lower_utils.py | 903 ++++++++++++++++++------------ 4 files changed, 591 insertions(+), 401 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index ddf69724..e476f7c4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,6 @@ include = ["xdsl_smt", "tests"] ignore = [ "xdsl_smt/utils/z3_to_dialect.py", "xdsl_smt/utils/integer_to_z3.py", - "xdsl_smt/utils/lower_utils.py", "xdsl_smt/passes/calculate_smt.py", "xdsl_smt/passes/transfer_lower.py", "xdsl_smt/cli/xdsl_translate.py", diff --git a/xdsl_smt/dialects/transfer.py b/xdsl_smt/dialects/transfer.py index 28c89f68..a46915cb 100644 --- a/xdsl_smt/dialects/transfer.py +++ b/xdsl_smt/dialects/transfer.py @@ -286,6 +286,15 @@ class UAddOverflowOp(PredicateOp): class SAddOverflowOp(PredicateOp): name = "transfer.sadd_overflow" +@irdl_op_definition +class USubOverflowOp(PredicateOp): + name = "transfer.usub_overflow" + + +@irdl_op_definition +class SSubOverflowOp(PredicateOp): + name = "transfer.ssub_overflow" + @irdl_op_definition class AndOp(BinOp): @@ -829,6 +838,8 @@ class GetSignedMinValueOp(UnaryOp): SAddOverflowOp, UShlOverflowOp, SShlOverflowOp, + USubOverflowOp, + SSubOverflowOp, SelectOp, IsPowerOf2Op, IsAllOnesOp, diff --git a/xdsl_smt/passes/transfer_lower.py b/xdsl_smt/passes/transfer_lower.py index 7a4416d3..874ab889 100644 --- a/xdsl_smt/passes/transfer_lower.py +++ b/xdsl_smt/passes/transfer_lower.py @@ -1,76 +1,79 @@ -from typing import TextIO -from xdsl.dialects.func import * -from xdsl.pattern_rewriter import * -from functools import singledispatch from dataclasses import dataclass -from xdsl.passes import ModulePass +from functools import singledispatch +from typing import TextIO -from xdsl.ir import Operation from xdsl.context import Context +from xdsl.dialects.builtin import ModuleOp +from xdsl.dialects.func import FuncOp +from xdsl.ir import Operation +from xdsl.passes import ModulePass +from xdsl.pattern_rewriter import ( + GreedyRewritePatternApplier, + PatternRewriter, + PatternRewriteWalker, + RewritePattern, + op_type_rewrite_pattern, +) + from ..utils.lower_utils import ( - lowerOperation, CPP_CLASS_KEY, - lowerDispatcher, INDUCTION_KEY, + lowerDispatcher, lowerInductionOps, + lowerOperation, + set_int_to_apint, ) -from xdsl.pattern_rewriter import ( - RewritePattern, - PatternRewriter, - op_type_rewrite_pattern, - PatternRewriteWalker, - GreedyRewritePatternApplier, -) -from xdsl.dialects import builtin - autogen = 0 @singledispatch -def transferFunction(op, fout): +def transferFunction(op: Operation, fout: TextIO): pass -funcStr = "" indent = "\t" +funcPrefix = 'extern "C" ' +funcStr = funcPrefix needDispatch: list[FuncOp] = [] inductionOp: list[FuncOp] = [] @transferFunction.register -def _(op: Operation, fout): +def _(op: Operation, fout: TextIO): global needDispatch global inductionOp if isinstance(op, ModuleOp): return - # print(lowerDispatcher(needDispatch)) - # fout.write(lowerDispatcher(needDispatch)) if len(op.results) > 0 and op.results[0].name_hint is None: global autogen op.results[0].name_hint = "autogen" + str(autogen) autogen += 1 + if isinstance(op, FuncOp): + for arg in op.args: + if arg.name_hint is None: + arg.name_hint = "autogen" + str(autogen) + autogen += 1 + if CPP_CLASS_KEY in op.attributes: + needDispatch.append(op) + if INDUCTION_KEY in op.attributes: + inductionOp.append(op) global funcStr funcStr += lowerOperation(op) parentOp = op.parent_op() if isinstance(parentOp, FuncOp) and parentOp.body.block.last_op == op: funcStr += "}\n" fout.write(funcStr) - funcStr = "" - if isinstance(op, FuncOp): - if CPP_CLASS_KEY in op.attributes: - needDispatch.append(op) - if INDUCTION_KEY in op.attributes: - inductionOp.append(op) + funcStr = funcPrefix @dataclass class LowerOperation(RewritePattern): - def __init__(self, fout): + def __init__(self, fout: TextIO): self.fout = fout @op_type_rewrite_pattern - def match_and_rewrite(self, op: Operation, rewriter: PatternRewriter): + def match_and_rewrite(self, op: Operation, _: PatternRewriter): transferFunction(op, self.fout) @@ -83,20 +86,26 @@ def addInductionOps(fout: TextIO): def addDispatcher(fout: TextIO, is_forward: bool): global needDispatch if len(needDispatch) != 0: - # print(lowerDispatcher(needDispatch)) fout.write(lowerDispatcher(needDispatch, is_forward)) @dataclass(frozen=True) class LowerToCpp(ModulePass): name = "trans_lower" - fout: TextIO = None + fout: TextIO + int_to_apint: bool = False - def apply(self, ctx: Context, op: builtin.ModuleOp) -> None: + def apply(self, ctx: Context, op: ModuleOp) -> None: + global autogen + autogen = 0 + set_int_to_apint(self.int_to_apint) + # We found PatternRewriteWalker skipped the op itself during iteration + # Do it manually on op + transferFunction(op, self.fout) walker = PatternRewriteWalker( GreedyRewritePatternApplier([LowerOperation(self.fout)]), walk_regions_first=False, - apply_recursively=True, + apply_recursively=False, walk_reverse=False, ) walker.rewrite_module(op) diff --git a/xdsl_smt/utils/lower_utils.py b/xdsl_smt/utils/lower_utils.py index 86e1c4e5..2b0aaef4 100644 --- a/xdsl_smt/utils/lower_utils.py +++ b/xdsl_smt/utils/lower_utils.py @@ -1,52 +1,63 @@ +from functools import singledispatch +from typing import Callable + +import xdsl.dialects.arith as arith +from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType +from xdsl.dialects.func import CallOp, FuncOp, ReturnOp +from xdsl.ir import Attribute, Block, BlockArgument, Operation, SSAValue + from ..dialects.transfer import ( AbstractValueType, - GetOp, - MakeOp, - NegOp, - Constant, + AddPoisonOp, + AShrOp, + ClearHighBitsOp, + ClearLowBitsOp, + ClearSignBitOp, CmpOp, - AndOp, - OrOp, - XorOp, - AddOp, - SubOp, + ConcatOp, + Constant, + ConstRangeForOp, CountLOneOp, CountLZeroOp, CountROneOp, CountRZeroOp, - SetHighBitsOp, - SetLowBitsOp, - GetLowBitsOp, - GetBitWidthOp, - UMulOverflowOp, - SMinOp, - SMaxOp, - UMinOp, - UMaxOp, - TransIntegerType, - ShlOp, - AShrOp, - LShrOp, ExtractOp, - ConcatOp, GetAllOnesOp, - SelectOp, + GetBitWidthOp, + GetHighBitsOp, + GetLowBitsOp, + GetOp, + GetSignedMaxValueOp, + GetSignedMinValueOp, + IntersectsOp, + LShrOp, + MakeOp, + NegOp, NextLoopOp, - ConstRangeForOp, + RemovePoisonOp, RepeatOp, - IntersectsOp, - # FromArithOp, + SAddOverflowOp, + SDivOp, + SelectOp, + SetHighBitsOp, + SetLowBitsOp, + SetSignBitOp, + ShlOp, + SMaxOp, + SMinOp, + SMulOverflowOp, + SRemOp, + SShlOverflowOp, + TransIntegerType, TupleType, - AddPoisonOp, - RemovePoisonOp, + UAddOverflowOp, + UDivOp, + UMaxOp, + UMinOp, + UMulOverflowOp, + URemOp, + UShlOverflowOp, ) -from xdsl.dialects.func import FuncOp, Return, Call -from xdsl.pattern_rewriter import * -from functools import singledispatch -from typing import TypeVar, cast -from xdsl.dialects.builtin import Signedness, IntegerType, IndexType, IntegerAttr -from xdsl.ir import Operation -import xdsl.dialects.arith as arith operNameToCpp = { "transfer.and": "&", @@ -62,15 +73,29 @@ "arith.subi": "-", "transfer.neg": "~", "transfer.mul": "*", + "transfer.udiv": ".udiv", + "transfer.sdiv": ".sdiv", + "transfer.urem": ".urem", + "transfer.srem": ".srem", "transfer.umul_overflow": ".umul_ov", + "transfer.smul_overflow": ".smul_ov", + "transfer.uadd_overflow": ".uadd_ov", + "transfer.sadd_overflow": ".sadd_ov", + "transfer.ushl_overflow": ".ushl_ov", + "transfer.sshl_overflow": ".sshl_ov", "transfer.get_bit_width": ".getBitWidth", "transfer.countl_zero": ".countl_zero", "transfer.countr_zero": ".countr_zero", "transfer.countl_one": ".countl_one", "transfer.countr_one": ".countr_one", + "transfer.get_high_bits": ".getHiBits", "transfer.get_low_bits": ".getLoBits", "transfer.set_high_bits": ".setHighBits", "transfer.set_low_bits": ".setLowBits", + "transfer.clear_high_bits": ".clearHighBits", + "transfer.clear_low_bits": ".clearLowBits", + "transfer.set_sign_bit": ".setSignBit", + "transfer.clear_sign_bit": ".clearSignBit", "transfer.intersects": ".intersects", "transfer.cmp": [ ".eq", @@ -84,7 +109,6 @@ ".ugt", ".uge", ], - # "transfer.fromArith": "APInt", "transfer.make": "{{{0}}}", "transfer.get": "[{0}]", "transfer.shl": ".shl", @@ -92,22 +116,68 @@ "transfer.lshr": ".lshr", "transfer.concat": ".concat", "transfer.extract": ".extractBits", - "transfer.umin": [".ule", "?", ":"], - "transfer.smin": [".sle", "?", ":"], - "transfer.umax": [".ugt", "?", ":"], - "transfer.smax": [".sgt", "?", ":"], + "transfer.umin": "A::APIntOps::umin", + "transfer.smin": "A::APIntOps::smin", + "transfer.umax": "A::APIntOps::umax", + "transfer.smax": "A::APIntOps::smax", "func.return": "return", "transfer.constant": "APInt", "arith.select": ["?", ":"], "arith.cmpi": ["==", "!=", "<", "<=", ">", ">="], "transfer.get_all_ones": "APInt::getAllOnes", + "transfer.get_signed_max_value": "APInt::getSignedMaxValue", + "transfer.get_signed_min_value": "APInt::getSignedMinValue", "transfer.select": ["?", ":"], "transfer.reverse_bits": ".reverseBits", "transfer.add_poison": " ", "transfer.remove_poison": " ", + "comb.add": "+", + "comb.sub": "-", + "comb.mul": "*", + "comb.and": "&", + "comb.or": "|", + "comb.xor": "^", + "comb.divs": ".sdiv", + "comb.divu": ".udiv", + "comb.mods": ".srem", + "comb.modu": ".urem", + "comb.mux": ["?", ":"], + "comb.shrs": ".ashr", + "comb.shru": ".lshr", + "comb.shl": ".shl", + "comb.extract": ".extractBits", + "comb.concat": ".concat", } # transfer.constRangeLoop and NextLoop are controller operations, should be handle specially + +VAL_EXCEEDS_BW = "{1}.uge({1}.getBitWidth())" +RHS_IS_ZERO = "{1} == 0" +RET_ZERO = "{0} = APInt({1}.getBitWidth(), 0)" +RET_ONE = "{0} = APInt({1}.getBitWidth(), 1)" +RET_ONES = "{0} = APInt({1}.getBitWidth(), -1)" +RET_SIGN_MIN_VAL = "{0} = APInt::getSignedMinValue({1}.getBitWidth())" +RET_LHS = "{0} = {1}" + +SHIFT_ACTION = (VAL_EXCEEDS_BW, RET_ZERO) +ASHR_ACTION0 = VAL_EXCEEDS_BW + " && {0}.isSignBitSet()", RET_ONES +ASHR_ACTION1 = VAL_EXCEEDS_BW + " && {0}.isSignBitClear()", RET_ZERO +REM_ACTION = RHS_IS_ZERO, RET_LHS +DIV_ACTION = RHS_IS_ZERO, RET_ONES +SDIV_ACTION0 = ("{0}.isMinSignedValue() && {1} == -1", RET_SIGN_MIN_VAL) +SDIV_ACTION1 = (RHS_IS_ZERO + " && {0}.isNonNegative()", RET_ONES) +SDIV_ACTION2 = (RHS_IS_ZERO + " && {0}.isNegative()", RET_ONE) + +op_to_cons: dict[type[Operation], list[tuple[str, str]]] = { + ShlOp: [SHIFT_ACTION], + LShrOp: [SHIFT_ACTION], + UDivOp: [DIV_ACTION], + URemOp: [REM_ACTION], + SRemOp: [REM_ACTION], + AShrOp: [ASHR_ACTION0, ASHR_ACTION1], + SDivOp: [SDIV_ACTION0, SDIV_ACTION1, SDIV_ACTION2], +} + unsignedReturnedType = { CountLOneOp, CountLZeroOp, @@ -116,11 +186,49 @@ GetBitWidthOp, } -ends = ";\n" -indent = "\t" +int_to_apint = False +use_custom_vec = True +EQ = " = " +END = ";\n" +IDNT = "\t" +CPP_CLASS_KEY = "CPPCLASS" +INDUCTION_KEY = "induction" +OPERATION_NO = "operationNo" + + +def set_int_to_apint(to_apint: bool) -> None: + global int_to_apint + int_to_apint = to_apint + + +def set_use_custom_vec(custom_vec: bool) -> None: + global use_custom_vec + use_custom_vec = custom_vec + + +def get_ret_val(op: Operation) -> str: + ret_val = op.results[0].name_hint + assert ret_val + return ret_val + + +def get_op_names(op: Operation) -> list[str]: + return [oper.name_hint for oper in op.operands if oper.name_hint] + +def get_operand(op: Operation, idx: int) -> str: + name = op.operands[idx].name_hint + assert name + return name -def lowerType(typ, specialOp=None): + +def get_op_str(op: Operation) -> str: + op_name = operNameToCpp[op.name] + assert isinstance(op_name, str) + return op_name + + +def lowerType(typ: Attribute, specialOp: Operation | Block | None = None) -> str: if specialOp is not None: for op in unsignedReturnedType: if isinstance(specialOp, op): @@ -132,20 +240,17 @@ def lowerType(typ, specialOp=None): typeName = lowerType(fields[0]) for i in range(1, len(fields)): assert lowerType(fields[i]) == typeName + if use_custom_vec: + return "Vec<" + str(len(fields)) + ">" return "std::vector<" + typeName + ">" elif isinstance(typ, IntegerType): - return "int" + return "int" if not int_to_apint else "APInt" elif isinstance(typ, IndexType): return "int" assert False and "unsupported type" -CPP_CLASS_KEY = "CPPCLASS" -INDUCTION_KEY = "induction" -OPERATION_NO = "operationNo" - - -def lowerInductionOps(inductionOp: list[FuncOp]): +def lowerInductionOps(inductionOp: list[FuncOp]) -> str: if len(inductionOp) > 0: functionSignature = """ {returnedType} {funcName}(ArrayRef<{returnedType}> operands){{ @@ -159,16 +264,16 @@ def lowerInductionOps(inductionOp: list[FuncOp]): """ result = "" for func in inductionOp: - returnedType = func.function_type.outputs.data[0] funcName = func.sym_name.data - returnedType = lowerType(returnedType) - result += functionSignature.format( - returnedType=returnedType, funcName=funcName - ) + ret_ty = lowerType(func.function_type.outputs.data[0]) + result += functionSignature.format(returnedType=ret_ty, funcName=funcName) + return result + return "" + -def lowerDispatcher(needDispatch: list[FuncOp], is_forward: bool): +def lowerDispatcher(needDispatch: list[FuncOp], is_forward: bool) -> str: if len(needDispatch) > 0: returnedType = needDispatch[0].function_type.outputs.data[0] for func in needDispatch: @@ -190,18 +295,18 @@ def lowerDispatcher(needDispatch: list[FuncOp], is_forward: bool): functionSignature = ( "std::optional<" + returnedType + "> " + funcName + expr + "{{\n{0}}}\n\n" ) - indent = "\t" + dyn_cast = ( - indent + IDNT + "if(auto castedOp=dyn_cast<{0}>(op);castedOp&&{1}){{\n{2}" - + indent + + IDNT + "}}\n" ) - return_inst = indent + indent + "return {0}({1});\n" + return_inst = IDNT + IDNT + "return {0}({1});\n" def handleOneTransferFunction(func: FuncOp, operationNo: int) -> str: blockStr = "" - for cppClass in func.attributes[CPP_CLASS_KEY]: + for cppClass in func.attributes[CPP_CLASS_KEY]: # type: ignore argStr = "" if INDUCTION_KEY in func.attributes: argStr = "operands" @@ -215,95 +320,117 @@ def handleOneTransferFunction(func: FuncOp, operationNo: int) -> str: operationNoStr = "true" else: operationNoStr = "operationNo == " + str(operationNo) - blockStr += dyn_cast.format(cppClass.data, operationNoStr, ifBody) + blockStr += dyn_cast.format(cppClass.data, operationNoStr, ifBody) # type: ignore return blockStr funcBody = "" for func in needDispatch: if is_forward: - funcBody += handleOneTransferFunction(func) + funcBody += handleOneTransferFunction(func, -1) else: operationNo = func.attributes[OPERATION_NO] assert isinstance(operationNo, IntegerAttr) funcBody += handleOneTransferFunction(func, operationNo.value.data) - funcBody += indent + "return {};\n" + funcBody += IDNT + "return {};\n" + return functionSignature.format(funcBody) + return "" -def isFunctionCall(opName): + +def isFunctionCall(opName: str) -> bool: return opName[0] == "." -def lowerToNonClassMethod(op: Operation): - returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" +def lowerToNonClassMethod(op: Operation) -> str: + ret_type = lowerType(op.results[0].type, op) + ret_val = get_ret_val(op) expr = "(" if len(op.operands) > 0: - expr += op.operands[0].name_hint + expr += get_operand(op, 0) for i in range(1, len(op.operands)): - expr += "," + op.operands[i].name_hint + expr += "," + get_operand(op, i) expr += ")" - return ( - indent - + returnedType - + " " - + returnedValue - + equals - + operNameToCpp[op.name] - + expr - + ends - ) + return IDNT + ret_type + " " + ret_val + EQ + get_op_str(op) + expr + END + + +def lowerToClassMethod( + op: Operation, + castOperand: Callable[[SSAValue | str], str] | None = None, + castResult: Callable[[Operation], str] | None = None, +) -> str: + ret_ty = lowerType(op.results[0].type, op) + ret_val = get_ret_val(op) -def lowerToClassMethod(op: Operation, castOperand=None, castResult=None): - returnedType = lowerType(op.results[0].type, op) if castResult is not None: - returnedValue = op.results[0].name_hint + "_autocast" - else: - returnedValue = op.results[0].name_hint - equals = "=" - expr = op.operands[0].name_hint + operNameToCpp[op.name] + "(" + ret_val += "_autocast" + expr = get_operand(op, 0) + get_op_str(op) + "(" + if castOperand is not None: operands = [castOperand(operand) for operand in op.operands] else: - operands = [operand.name_hint for operand in op.operands] + operands = get_op_names(op) + if len(operands) > 1: expr += operands[1] for i in range(2, len(operands)): expr += "," + operands[i] + expr += ")" - result = indent + returnedType + " " + returnedValue + equals + expr + ends + + if type(op) in op_to_cons: + conds, actions = zip(*op_to_cons[type(op)]) # type: ignore + + og_op_names = get_op_names(op) + conds: list[str] = [cond.format(*og_op_names) for cond in conds] + actions: list[str] = [act.format(ret_val, *og_op_names) for act in actions] + + if_fmt = "if ({cond}) {{\n" + IDNT + IDNT + "{act}" + END + IDNT + "}}" + + ifs = " else ".join( + [if_fmt.format(cond=c, act=a) for c, a in zip(conds, actions)] + ) + + final_else_br = IDNT + IDNT + ret_val + EQ + expr + END + + result = IDNT + ret_ty + " " + ret_val + END + result += IDNT + ifs + " else {\n" + final_else_br + IDNT + "}\n" + + else: + result = IDNT + ret_ty + " " + ret_val + EQ + expr + END + if castResult is not None: return result + castResult(op) + return result @singledispatch -def lowerOperation(op): +def lowerOperation(op: Operation) -> str: returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] - if isFunctionCall(operNameToCpp[op.name]): - expr = operandsName[0] + operNameToCpp[op.name] + "(" + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) + op_str = get_op_str(op) + + if isFunctionCall(op_str): + expr = operandsName[0] + op_str + "(" if len(operandsName) > 1: expr += operandsName[1] for i in range(2, len(operandsName)): expr += "," + operandsName[i] expr += ")" else: - expr = operandsName[0] + operNameToCpp[op.name] + operandsName[1] - result = indent + returnedType + " " + returnedValue + equals + expr + ends - return result + expr = operandsName[0] + op_str + operandsName[1] + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register def _(op: CmpOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) predicate = op.predicate.value.data operName = operNameToCpp[op.name][predicate] expr = operandsName[0] + operName + "(" @@ -312,219 +439,291 @@ def _(op: CmpOp): for i in range(2, len(operandsName)): expr += "," + operandsName[i] expr += ")" - return indent + returnedType + " " + returnedValue + equals + expr + ends + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: arith.Cmpi): +def _(op: arith.CmpiOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) assert len(operandsName) == 2 predicate = op.predicate.value.data operName = operNameToCpp[op.name][predicate] - expr = "(" + operandsName[0] + operName + operandsName[1] - expr += ")" - return indent + returnedType + " " + returnedValue + equals + expr + ends + expr = "(" + operandsName[0] + operName + operandsName[1] + ")" + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: arith.Select): +def _(op: arith.SelectOp): returnedType = lowerType(op.operands[1].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) operator = operNameToCpp[op.name] expr = "" for i in range(len(operandsName)): expr += operandsName[i] + " " if i < len(operator): expr += operator[i] + " " - return indent + returnedType + " " + returnedValue + equals + expr + ends + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register def _(op: SelectOp): returnedType = lowerType(op.operands[1].type, op) - returnedValue = op.results[0].name_hint - equals = "=" - operandsName = [oper.name_hint for oper in op.operands] + returnedValue = get_ret_val(op) + operandsName = get_op_names(op) operator = operNameToCpp[op.name] expr = "" for i in range(len(operandsName)): expr += operandsName[i] + " " if i < len(operator): expr += operator[i] + " " - return indent + returnedType + " " + returnedValue + equals + expr + ends + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register -def _(op: GetOp): +def _(op: GetOp) -> str: returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - index = op.attributes["index"].value.data + returnedValue = get_ret_val(op) + index = op.attributes["index"].value.data # type: ignore + return ( - indent + IDNT + returnedType + " " + returnedValue - + equals - + op.operands[0].name_hint - + operNameToCpp[op.name].format(index) - + ends + + EQ + + get_operand(op, 0) + + get_op_str(op).format(index) # type: ignore + + END ) @lowerOperation.register -def _(op: MakeOp): +def _(op: MakeOp) -> str: returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" + returnedValue = get_ret_val(op) expr = "" if len(op.operands) > 0: - expr += op.operands[0].name_hint + expr += get_operand(op, 0) for i in range(1, len(op.operands)): - expr += "," + op.operands[i].name_hint + expr += "," + get_operand(op, i) + return ( - indent + IDNT + returnedType + " " + returnedValue - + equals + + EQ + returnedType - + operNameToCpp[op.name].format(expr) - + ends + + get_op_str(op).format(expr) + + END ) +def trivial_overflow_predicate(op: Operation) -> str: + returnedValue = get_ret_val(op) + varDecls = "bool " + returnedValue + END + expr = get_operand(op, 0) + get_op_str(op) + "(" + expr += get_operand(op, 1) + "," + returnedValue + ")" + result = varDecls + IDNT + expr + END + return IDNT + result + + @lowerOperation.register def _(op: UMulOverflowOp): - varDecls = "bool " + op.results[0].name_hint + ends - expr = op.operands[0].name_hint + operNameToCpp[op.name] + "(" - expr += op.operands[1].name_hint + "," + op.results[0].name_hint - expr += ")" - result = varDecls + "\t" + expr + ends - return indent + result + return trivial_overflow_predicate(op) @lowerOperation.register -def _(op: NegOp): - returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - equals = "=" - return ( - indent - + returnedType - + " " - + returnedValue - + equals - + operNameToCpp[op.name] - + op.operands[0].name_hint - + ends - ) +def _(op: SMulOverflowOp): + return trivial_overflow_predicate(op) @lowerOperation.register -def _(op: Return): - opName = operNameToCpp[op.name] + " " - operand = op.arguments[0].name_hint - return indent + opName + operand + ends +def _(op: UAddOverflowOp): + return trivial_overflow_predicate(op) -""" @lowerOperation.register -def _(op: FromArithOp): - opTy = op.op.type - assert isinstance(opTy, IntegerType) - size = opTy.width.data - returnedType = "APInt" - returnedValue = op.results[0].name_hint - return ( - indent - + returnedType - + " " - + returnedValue - + "(" - + str(size) - + ", " - + op.op.name_hint - + ")" - + ends - ) -""" +def _(op: SAddOverflowOp): + return trivial_overflow_predicate(op) @lowerOperation.register -def _(op: arith.Constant): - value = op.value.value.data +def _(op: SShlOverflowOp): + return trivial_overflow_predicate(op) + + +@lowerOperation.register +def _(op: UShlOverflowOp): + return trivial_overflow_predicate(op) + + +@lowerOperation.register +def _(op: NegOp) -> str: + ret_type = lowerType(op.results[0].type) + ret_val = get_ret_val(op) + op_str = get_op_str(op) + operand = get_operand(op, 0) + + return IDNT + ret_type + " " + ret_val + EQ + op_str + operand + END + + +@lowerOperation.register +def _(op: ReturnOp) -> str: + opName = get_op_str(op) + " " + operand = op.arguments[0].name_hint + assert operand + + return IDNT + opName + operand + END + + +@lowerOperation.register +def _(op: arith.ConstantOp): + value = op.value.value.data # type: ignore + assert isinstance(value, int) or isinstance(value, float) assert isinstance(op.results[0].type, IntegerType) size = op.results[0].type.width.data + max_val_plus_one = 1 << size returnedType = "int" - if value > ((1 << 31) - 1): + if value >= (1 << 31): assert False and "arith constant overflow maximal int" - returnedValue = op.results[0].name_hint - return indent + returnedType + " " + returnedValue + " = " + str(value) + ends + returnedValue = get_ret_val(op) + return ( + IDNT + + returnedType + + " " + + returnedValue + + EQ + + str((value + max_val_plus_one) % max_val_plus_one) + + END + ) @lowerOperation.register def _(op: Constant): value = op.value.value.data returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint + returnedValue = get_ret_val(op) return ( - indent + IDNT + returnedType + " " + returnedValue + "(" - + op.operands[0].name_hint + + get_operand(op, 0) + ".getBitWidth()," + str(value) + ")" - + ends + + END ) @lowerOperation.register def _(op: GetAllOnesOp): + ret_type = lowerType(op.results[0].type) + ret_val = get_ret_val(op) + op_name = get_op_str(op) + + return ( + IDNT + + ret_type + + " " + + ret_val + + EQ + + op_name + + "(" + + get_operand(op, 0) + + ".getBitWidth()" + + ")" + + END + ) + + +@lowerOperation.register +def _(op: GetSignedMaxValueOp): + ret_type = lowerType(op.results[0].type) + ret_val = get_ret_val(op) + op_name = get_op_str(op) + + return ( + IDNT + + ret_type + + " " + + ret_val + + EQ + + op_name + + "(" + + get_operand(op, 0) + + ".getBitWidth()" + + ")" + + END + ) + + +@lowerOperation.register +def _(op: GetSignedMinValueOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - opName = operNameToCpp[op.name] + returnedValue = get_ret_val(op) + op_name = get_op_str(op) + return ( - indent + IDNT + returnedType + " " + returnedValue - + " = " - + opName + + EQ + + op_name + "(" - + op.operands[0].name_hint + + get_operand(op, 0) + ".getBitWidth()" + ")" - + ends + + END ) @lowerOperation.register -def _(op: Call): +def _(op: CallOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint + returnedValue = get_ret_val(op) callee = op.callee.string_value() + "(" - operandsName = [oper.name_hint for oper in op.operands] + operandsName = get_op_names(op) expr = "" if len(operandsName) > 0: expr += operandsName[0] for i in range(1, len(operandsName)): expr += "," + operandsName[i] expr += ")" - return indent + returnedType + " " + returnedValue + "=" + callee + expr + ends + return IDNT + returnedType + " " + returnedValue + EQ + callee + expr + END + + +def set_clear_bits( + op: SetHighBitsOp | SetLowBitsOp | ClearHighBitsOp | ClearLowBitsOp, +) -> str: + ret_ty = lowerType(op.results[0].type, op) + ret_val = get_ret_val(op) + arg = get_operand(op, 0) + count = get_operand(op, 1) + op_str = get_op_str(op) + + set_val = f"{IDNT}{ret_ty} {ret_val} = {arg};\n" + cond = f"{count}.ule({count}.getBitWidth())" + if_br = f"{IDNT}{IDNT}{ret_val}{op_str}({count}.getZExtValue());\n" + el_br = f"{IDNT}{IDNT}{ret_val}{op_str}({count}.getBitWidth());\n" + + return f"{set_val}{IDNT}if ({cond})\n{if_br}{IDNT}else\n{el_br}" @lowerOperation.register def _(op: FuncOp): - def lowerArgs(arg): + def lowerArgs(arg: BlockArgument) -> str: + assert arg.name_hint return lowerType(arg.type) + " " + arg.name_hint returnedType = lowerType(op.function_type.outputs.data[0]) @@ -535,21 +734,23 @@ def lowerArgs(arg): for i in range(1, len(op.args)): expr += "," + lowerArgs(op.args[i]) expr += ")" - # return returnedType + " " + funcName + expr + "{{\n{0}}}\n\n" - return returnedType + " " + funcName + expr + "{\n" + return returnedType + " " + funcName + expr + "{\n" # } -def castToAPIntFromUnsigned(op: Operation): - lastReturn = op.results[0].name_hint + "_autocast" + +def castToAPIntFromUnsigned(op: Operation) -> str: + returnedValue = get_ret_val(op) + lastReturn = returnedValue + "_autocast" apInt = None for operand in op.operands: if isinstance(operand.type, TransIntegerType): apInt = operand.name_hint break returnedType = "APInt" - returnedValue = op.results[0].name_hint + assert apInt + return ( - indent + IDNT + returnedType + " " + returnedValue @@ -558,10 +759,30 @@ def castToAPIntFromUnsigned(op: Operation): + ".getBitWidth()," + lastReturn + ")" - + ends + + END ) +@lowerOperation.register +def _(op: SDivOp): + return lowerToClassMethod(op, None, None) + + +@lowerOperation.register +def _(op: UDivOp): + return lowerToClassMethod(op, None, None) + + +@lowerOperation.register +def _(op: SRemOp): + return lowerToClassMethod(op, None, None) + + +@lowerOperation.register +def _(op: URemOp): + return lowerToClassMethod(op, None, None) + + @lowerOperation.register def _(op: IntersectsOp): return lowerToClassMethod(op, None, None) @@ -587,34 +808,57 @@ def _(op: CountRZeroOp): return lowerToClassMethod(op, None, castToAPIntFromUnsigned) -def castToUnisgnedFromAPInt(operand): - if isinstance(operand.type, TransIntegerType): - return operand.name_hint + ".getZExtValue()" - return operand.name_hint +def castToUnisgnedFromAPInt(operand: SSAValue | str) -> str: + if isinstance(operand, str): + return "(" + operand + ").getZExtValue()" + elif isinstance(operand.type, TransIntegerType): + return f"{operand.name_hint}.getZExtValue()" + + return str(operand.name_hint) @lowerOperation.register def _(op: SetHighBitsOp): + return set_clear_bits(op) + + +@lowerOperation.register +def _(op: SetLowBitsOp): + return set_clear_bits(op) + + +@lowerOperation.register +def _(op: ClearHighBitsOp): + return set_clear_bits(op) + + +@lowerOperation.register +def _(op: ClearLowBitsOp): + return set_clear_bits(op) + + +@lowerOperation.register +def _(op: SetSignBitOp): returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" + op.operands[0].name_hint + ends + "\t" - expr = op.results[0].name_hint + operNameToCpp[op.name] + "(" - operands = op.operands[1].name_hint + ".getZExtValue()" + returnedValue = get_ret_val(op) + equals = EQ + get_operand(op, 0) + END + IDNT + expr = returnedValue + get_op_str(op) + "(" + operands = "" expr = expr + operands + ")" - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + + return IDNT + returnedType + " " + returnedValue + equals + expr + END @lowerOperation.register -def _(op: SetLowBitsOp): +def _(op: ClearSignBitOp): returnedType = lowerType(op.results[0].type, op) - returnedValue = op.results[0].name_hint - equals = "=" + op.operands[0].name_hint + ends + "\t" - expr = op.results[0].name_hint + operNameToCpp[op.name] + "(" - operands = op.operands[1].name_hint + ".getZExtValue()" + returnedValue = get_ret_val(op) + equals = EQ + get_operand(op, 0) + END + IDNT + expr = returnedValue + get_op_str(op) + "(" + operands = "" expr = expr + operands + ")" - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + + return IDNT + returnedType + " " + returnedValue + equals + expr + END @lowerOperation.register @@ -622,98 +866,45 @@ def _(op: GetLowBitsOp): return lowerToClassMethod(op, castToUnisgnedFromAPInt) +@lowerOperation.register +def _(op: GetHighBitsOp): + return lowerToClassMethod(op, castToUnisgnedFromAPInt) + + @lowerOperation.register def _(op: GetBitWidthOp): return lowerToClassMethod(op, None, castToAPIntFromUnsigned) -# op1 < op2? op1: op2 @lowerOperation.register def _(op: SMaxOp): - returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + return lower_min_max(op) @lowerOperation.register def _(op: SMinOp): - returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + return lower_min_max(op) @lowerOperation.register def _(op: UMaxOp): - returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + return lower_min_max(op) @lowerOperation.register def _(op: UMinOp): + return lower_min_max(op) + + +def lower_min_max(op: UMinOp | UMaxOp | SMinOp | SMaxOp) -> str: returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - operands = [operand.name_hint for operand in op.operands] - operator = operNameToCpp[op.name] - equals = "=" - expr = ( - operands[0] - + operator[0] - + "(" - + operands[1] - + ")" - + operator[1] - + operands[0] - + operator[2] - + operands[1] - ) - result = returnedType + " " + returnedValue + equals + expr + ends - return indent + result + returnedValue = get_ret_val(op) + operands = get_op_names(op) + operator = get_op_str(op) + + expr = operator + "(" + operands[0] + "," + operands[1] + ")" + + return IDNT + returnedType + " " + returnedValue + EQ + expr + END @lowerOperation.register @@ -751,111 +942,91 @@ def _(op: ConstRangeForOp): indvar, *block_iter_args = loopBody.args iter_args = op.iter_args - global indent loopBefore = "" for i, blk_arg in enumerate(block_iter_args): iter_type = lowerType(iter_args[i].type, iter_args[i].owner) iter_name = blk_arg.name_hint - loopBefore += ( - indent + iter_type + " " + iter_name + " = " + iter_args[i].name_hint + ends - ) + iter_arg = iter_args[i].name_hint + assert iter_name + assert iter_arg + + loopBefore += IDNT + iter_type + " " + iter_name + EQ + iter_arg + END - loopFor = indent + "for(APInt {0} = {1}; {0}.ule({2}); {0}+={3}){{\n".format( + loopFor = IDNT + "for(APInt {0} = {1}; {0}.ule({2}); {0}+={3}){{\n".format( indvar.name_hint, lowerBound, upperBound, step ) - indent += "\t" - """ - mainLoop="" - for loopOp in loopBody.ops: - mainLoop+=(indent + indent+ lowerOperation(loopOp)) - endLoopFor=indent+"}\n" - """ + return loopBefore + loopFor @lowerOperation.register -def _(op: NextLoopOp): +def _(op: NextLoopOp) -> str: loopBlock = op.parent_block() - indvar, *block_iter_args = loopBlock.args - global indent + assert loopBlock + _, *block_iter_args = loopBlock.args assignments = "" for i, arg in enumerate(op.operands): - assignments += ( - indent + block_iter_args[i].name_hint + " = " + arg.name_hint + ends - ) - indent = indent[:-1] - endLoopFor = indent + "}\n" + block_arg = block_iter_args[i].name_hint + arg_name = arg.name_hint + assert block_arg + assert arg_name + + assignments += IDNT + block_arg + EQ + arg_name + END + + endLoopFor = IDNT + "}\n" loopOp = loopBlock.parent_op() + assert loopOp + for i, res in enumerate(loopOp.results): - endLoopFor += ( - indent - + lowerType(res.type, loopOp) - + " " - + res.name_hint - + " = " - + block_iter_args[i].name_hint - + ends - ) + ty = lowerType(res.type, loopOp) + res_name = res.name_hint + block_arg = block_iter_args[i].name_hint + assert res_name + assert block_arg + + endLoopFor += IDNT + ty + " " + res_name + EQ + block_arg + END + return assignments + endLoopFor @lowerOperation.register def _(op: RepeatOp): returnedType = lowerType(op.operands[0].type, op) - returnedValue = op.results[0].name_hint - arg0_name = op.operands[0].name_hint - count = op.operands[1].name_hint - initExpr = indent + returnedType + " " + returnedValue + " = " + arg0_name + ends + returnedValue = get_ret_val(op) + arg0_name = get_operand(op, 0) + count = get_operand(op, 1) + initExpr = IDNT + returnedType + " " + returnedValue + EQ + arg0_name + END forHead = ( - indent - + "for(APInt i(" - + count - + ".getBitWidth(),1);i.ult(" - + count - + ");++i){\n" + IDNT + "for(APInt i(" + count + ".getBitWidth(),1);i.ult(" + count + ");++i){\n" ) forBody = ( - indent - + "\t" + IDNT + + IDNT + returnedValue - + " = " + + EQ + returnedValue + ".concat(" + arg0_name + ")" - + ends + + END ) - forEnd = indent + "}\n" + forEnd = IDNT + "}\n" return initExpr + forHead + forBody + forEnd @lowerOperation.register def _(op: AddPoisonOp): returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - opName = operNameToCpp[op.name] - return ( - indent - + returnedType - + " " - + returnedValue - + " = " - + op.operands[0].name_hint - + ends - ) + returnedValue = get_ret_val(op) + operand = get_operand(op, 0) + + return IDNT + returnedType + " " + returnedValue + EQ + operand + END @lowerOperation.register -def _(op: RemovePoisonOp): +def _(op: RemovePoisonOp) -> str: returnedType = lowerType(op.results[0].type) - returnedValue = op.results[0].name_hint - opName = operNameToCpp[op.name] - return ( - indent - + returnedType - + " " - + returnedValue - + " = " - + op.operands[0].name_hint - + ends - ) + returnedValue = get_ret_val(op) + operand = get_operand(op, 0) + + return IDNT + returnedType + " " + returnedValue + EQ + operand + END From 6f386fd0379dda53f3eb7589b620329ec61533cb Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Fri, 26 Sep 2025 19:28:49 -0600 Subject: [PATCH 2/7] remove project specific features from cpp-translate --- pyproject.toml | 1 - xdsl_smt/cli/cpp_translate.py | 150 +++++++++--------------------- xdsl_smt/dialects/transfer.py | 17 +++- xdsl_smt/passes/transfer_lower.py | 38 ++------ xdsl_smt/utils/lower_utils.py | 60 +----------- 5 files changed, 70 insertions(+), 196 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e476f7c4..b18d6367 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ ignore = [ "xdsl_smt/utils/z3_to_dialect.py", "xdsl_smt/utils/integer_to_z3.py", "xdsl_smt/passes/calculate_smt.py", - "xdsl_smt/passes/transfer_lower.py", "xdsl_smt/cli/xdsl_translate.py", "xdsl_smt/cli/transfer_smt_verifier.py", ] diff --git a/xdsl_smt/cli/cpp_translate.py b/xdsl_smt/cli/cpp_translate.py index 29b13b26..f51913ed 100644 --- a/xdsl_smt/cli/cpp_translate.py +++ b/xdsl_smt/cli/cpp_translate.py @@ -1,128 +1,68 @@ -#!/usr/bin/env python3 - import argparse -from typing import cast import sys +from pathlib import Path from xdsl.context import Context -from xdsl.ir import Operation +from xdsl.dialects.arith import Arith +from xdsl.dialects.builtin import Builtin, ModuleOp +from xdsl.dialects.func import Func, FuncOp from xdsl.parser import Parser -from xdsl.dialects.arith import Arith -from xdsl.dialects.func import Func -from xdsl_smt.dialects.transfer import Transfer from xdsl_smt.dialects.llvm_dialect import LLVM -from xdsl_smt.passes.transfer_lower import LowerToCpp, addDispatcher, addInductionOps -from xdsl.dialects.func import FuncOp, ReturnOp -from xdsl.dialects.builtin import ( - Builtin, - ModuleOp, - IntegerAttr, - StringAttr, -) - - -def register_all_arguments(arg_parser: argparse.ArgumentParser): - arg_parser.add_argument( - "transfer_functions", type=str, nargs="?", help="path to the transfer functions" - ) - - -def parse_file(ctx: Context, file: str | None) -> Operation: - if file is None: - f = sys.stdin - file = "" - else: - f = open(file) - - parser = Parser(ctx, f.read(), file) - module = parser.parse_op() - return module - +from xdsl_smt.dialects.transfer import Transfer +from xdsl_smt.passes.transfer_lower import LowerToCpp -def is_transfer_function(func: FuncOp) -> bool: - return "applied_to" in func.attributes +def _register_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Translate MLIR code to C++") -def is_forward(func: FuncOp) -> bool: - if "is_forward" in func.attributes: - forward = func.attributes["is_forward"] - assert isinstance(forward, IntegerAttr) - return forward.value.data == 1 - return False + parser.add_argument( + "-i", + "--input", + type=Path, + default=None, + help="Path to the input MLIR file (defaults to stdin if omitted).", + ) + parser.add_argument( + "-o", + "--output", + type=Path, + default=None, + help="Path to the output MLIR file (defaults to stdout if omitted).", + ) + return parser.parse_args() -def getCounterexampleFunc(func: FuncOp) -> str | None: - if "soundness_counterexample" not in func.attributes: - return None - attr = func.attributes["soundness_counterexample"] - assert isinstance(attr, StringAttr) - return attr.data +def _parse_mlir_module(p: Path | None, ctx: Context) -> ModuleOp: + code = p.read_text() if p else sys.stdin.read() + fname = p.name if p else "" + mod = Parser(ctx, code, fname).parse_op() -def checkFunctionValidity(func: FuncOp) -> bool: - if len(func.function_type.inputs) != len(func.args): - return False - for func_type_arg, arg in zip(func.function_type.inputs, func.args): - if func_type_arg != arg.type: - return False - return_op = func.body.block.last_op - if not (return_op is not None and isinstance(return_op, ReturnOp)): - return False - return return_op.operands[0].type == func.function_type.outputs.data[0] + if isinstance(mod, ModuleOp): + return mod + elif isinstance(mod, FuncOp): + return ModuleOp([mod]) + else: + raise ValueError(f"mlir in '{fname}' is neither a ModuleOp, nor a FuncOp") -def main() -> None: +def _get_ctx() -> Context: ctx = Context() - arg_parser = argparse.ArgumentParser() - register_all_arguments(arg_parser) - args = arg_parser.parse_args() - - # Register all dialects ctx.load_dialect(Arith) ctx.load_dialect(Builtin) ctx.load_dialect(Func) ctx.load_dialect(Transfer) ctx.load_dialect(LLVM) - # Parse the files - module = parse_file(ctx, args.transfer_functions) - assert isinstance(module, ModuleOp) - - allFuncMapping: dict[str, FuncOp] = {} - forward = False - counterexampleFuncs: set[str] = set() - with open("tmp.cpp", "w") as fout: - LowerToCpp.fout = fout - for func in module.ops: - if isinstance(func, FuncOp): - if is_transfer_function(func): - forward |= is_transfer_function(func) and is_forward(func) - counterexampleFunc = getCounterexampleFunc(func) - if counterexampleFunc is not None: - counterexampleFuncs.add(counterexampleFunc) - allFuncMapping[func.sym_name.data] = func - - # check function validity - if not checkFunctionValidity(func): - print(func.sym_name) - # check function validity - - for counterexample in counterexampleFuncs: - assert counterexample in allFuncMapping - allFuncMapping[counterexample].detach() - del allFuncMapping[counterexample] - for func in module.ops: - if isinstance(func, FuncOp): - allFuncMapping[func.sym_name.data] = func - # HACK: we know the pass won't check that the operation is a module - LowerToCpp(fout).apply(ctx, cast(ModuleOp, func)) - addInductionOps(fout) - addDispatcher(fout, forward) - - # printer = Printer(target=Printer.Target.MLIR) - # printer.print_op(module) - - -if __name__ == "__main__": - main() + return ctx + + +def main() -> None: + args = _register_args() + + ctx = _get_ctx() + funcs = _parse_mlir_module(args.input, ctx) + output = args.output.open("w", encoding="utf-8") if args.output else sys.stdout + + LowerToCpp(output).apply(ctx, funcs) diff --git a/xdsl_smt/dialects/transfer.py b/xdsl_smt/dialects/transfer.py index a46915cb..5cb6144e 100644 --- a/xdsl_smt/dialects/transfer.py +++ b/xdsl_smt/dialects/transfer.py @@ -31,7 +31,7 @@ VarOperand, irdl_attr_definition, irdl_op_definition, - param_def, + ParameterDef, IRDLOperation, traits_def, lazy_traits_def, @@ -286,6 +286,7 @@ class UAddOverflowOp(PredicateOp): class SAddOverflowOp(PredicateOp): name = "transfer.sadd_overflow" + @irdl_op_definition class USubOverflowOp(PredicateOp): name = "transfer.usub_overflow" @@ -545,7 +546,7 @@ def __init__( @irdl_attr_definition class AbstractValueType(ParametrizedAttribute, TypeAttribute): name = "transfer.abs_value" - fields: ArrayAttr[Attribute] = param_def() + fields: ParameterDef[ArrayAttr[Attribute]] def get_num_fields(self) -> int: return len(self.fields.data) @@ -556,13 +557,13 @@ def get_fields(self): def __init__(self, shape: list[Attribute] | ArrayAttr[Attribute]) -> None: if isinstance(shape, list): shape = ArrayAttr(shape) - super().__init__(shape) + super().__init__([shape]) @irdl_attr_definition class TupleType(ParametrizedAttribute, TypeAttribute): name = "transfer.tuple" - fields: ArrayAttr[Attribute] = param_def() + fields: ParameterDef[ArrayAttr[Attribute]] def get_num_fields(self) -> int: return len(self.fields.data) @@ -573,7 +574,7 @@ def get_fields(self): def __init__(self, shape: list[Attribute] | ArrayAttr[Attribute]) -> None: if isinstance(shape, list): shape = ArrayAttr(shape) - super().__init__(shape) + super().__init__([shape]) @irdl_op_definition @@ -794,6 +795,11 @@ class GetSignedMinValueOp(UnaryOp): name = "transfer.get_signed_min_value" +@irdl_op_definition +class GetLimitedValueOp(BinOp): + name = "transfer.get_limited_value" + + Transfer = Dialect( "transfer", [ @@ -856,6 +862,7 @@ class GetSignedMinValueOp(UnaryOp): AddPoisonOp, RemovePoisonOp, ReverseBitsOp, + GetLimitedValueOp, ], [TransIntegerType, AbstractValueType, TupleType], ) diff --git a/xdsl_smt/passes/transfer_lower.py b/xdsl_smt/passes/transfer_lower.py index 874ab889..bf6864cc 100644 --- a/xdsl_smt/passes/transfer_lower.py +++ b/xdsl_smt/passes/transfer_lower.py @@ -1,5 +1,4 @@ from dataclasses import dataclass -from functools import singledispatch from typing import TextIO from xdsl.context import Context @@ -18,29 +17,18 @@ from ..utils.lower_utils import ( CPP_CLASS_KEY, INDUCTION_KEY, - lowerDispatcher, - lowerInductionOps, lowerOperation, set_int_to_apint, + set_use_custom_vec, ) autogen = 0 - - -@singledispatch -def transferFunction(op: Operation, fout: TextIO): - pass - - -indent = "\t" -funcPrefix = 'extern "C" ' -funcStr = funcPrefix +funcStr = "" needDispatch: list[FuncOp] = [] inductionOp: list[FuncOp] = [] -@transferFunction.register -def _(op: Operation, fout: TextIO): +def transfer_func(op: Operation, fout: TextIO): global needDispatch global inductionOp if isinstance(op, ModuleOp): @@ -64,7 +52,7 @@ def _(op: Operation, fout: TextIO): if isinstance(parentOp, FuncOp) and parentOp.body.block.last_op == op: funcStr += "}\n" fout.write(funcStr) - funcStr = funcPrefix + funcStr = "" @dataclass @@ -74,19 +62,7 @@ def __init__(self, fout: TextIO): @op_type_rewrite_pattern def match_and_rewrite(self, op: Operation, _: PatternRewriter): - transferFunction(op, self.fout) - - -def addInductionOps(fout: TextIO): - global inductionOp - if len(inductionOp) != 0: - fout.write(lowerInductionOps(inductionOp)) - - -def addDispatcher(fout: TextIO, is_forward: bool): - global needDispatch - if len(needDispatch) != 0: - fout.write(lowerDispatcher(needDispatch, is_forward)) + transfer_func(op, self.fout) @dataclass(frozen=True) @@ -94,14 +70,16 @@ class LowerToCpp(ModulePass): name = "trans_lower" fout: TextIO int_to_apint: bool = False + use_custom_vec: bool = False def apply(self, ctx: Context, op: ModuleOp) -> None: global autogen autogen = 0 set_int_to_apint(self.int_to_apint) + set_use_custom_vec(self.use_custom_vec) # We found PatternRewriteWalker skipped the op itself during iteration # Do it manually on op - transferFunction(op, self.fout) + transfer_func(op, self.fout) walker = PatternRewriteWalker( GreedyRewritePatternApplier([LowerOperation(self.fout)]), walk_regions_first=False, diff --git a/xdsl_smt/utils/lower_utils.py b/xdsl_smt/utils/lower_utils.py index 2b0aaef4..c1017458 100644 --- a/xdsl_smt/utils/lower_utils.py +++ b/xdsl_smt/utils/lower_utils.py @@ -187,7 +187,7 @@ } int_to_apint = False -use_custom_vec = True +use_custom_vec = False EQ = " = " END = ";\n" IDNT = "\t" @@ -818,22 +818,7 @@ def castToUnisgnedFromAPInt(operand: SSAValue | str) -> str: @lowerOperation.register -def _(op: SetHighBitsOp): - return set_clear_bits(op) - - -@lowerOperation.register -def _(op: SetLowBitsOp): - return set_clear_bits(op) - - -@lowerOperation.register -def _(op: ClearHighBitsOp): - return set_clear_bits(op) - - -@lowerOperation.register -def _(op: ClearLowBitsOp): +def _(op: SetHighBitsOp | SetLowBitsOp | ClearHighBitsOp | ClearLowBitsOp): return set_clear_bits(op) @@ -862,12 +847,7 @@ def _(op: ClearSignBitOp): @lowerOperation.register -def _(op: GetLowBitsOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: GetHighBitsOp): +def _(op: GetLowBitsOp | GetHighBitsOp): return lowerToClassMethod(op, castToUnisgnedFromAPInt) @@ -877,22 +857,7 @@ def _(op: GetBitWidthOp): @lowerOperation.register -def _(op: SMaxOp): - return lower_min_max(op) - - -@lowerOperation.register -def _(op: SMinOp): - return lower_min_max(op) - - -@lowerOperation.register -def _(op: UMaxOp): - return lower_min_max(op) - - -@lowerOperation.register -def _(op: UMinOp): +def _(op: SMaxOp | SMinOp | UMaxOp | UMinOp): return lower_min_max(op) @@ -908,22 +873,7 @@ def lower_min_max(op: UMinOp | UMaxOp | SMinOp | SMaxOp) -> str: @lowerOperation.register -def _(op: ShlOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: AShrOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: LShrOp): - return lowerToClassMethod(op, castToUnisgnedFromAPInt) - - -@lowerOperation.register -def _(op: ExtractOp): +def _(op: ShlOp | AShrOp | LShrOp | ExtractOp): return lowerToClassMethod(op, castToUnisgnedFromAPInt) From 9639b6abec6f10d2822e444974dfcdc85e8f4535 Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Fri, 26 Sep 2025 19:46:14 -0600 Subject: [PATCH 3/7] whoopsie --- xdsl_smt/dialects/transfer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/xdsl_smt/dialects/transfer.py b/xdsl_smt/dialects/transfer.py index 5cb6144e..45a0a8a1 100644 --- a/xdsl_smt/dialects/transfer.py +++ b/xdsl_smt/dialects/transfer.py @@ -31,7 +31,7 @@ VarOperand, irdl_attr_definition, irdl_op_definition, - ParameterDef, + param_def, IRDLOperation, traits_def, lazy_traits_def, @@ -546,7 +546,7 @@ def __init__( @irdl_attr_definition class AbstractValueType(ParametrizedAttribute, TypeAttribute): name = "transfer.abs_value" - fields: ParameterDef[ArrayAttr[Attribute]] + fields: ArrayAttr[Attribute] = param_def() def get_num_fields(self) -> int: return len(self.fields.data) @@ -557,13 +557,13 @@ def get_fields(self): def __init__(self, shape: list[Attribute] | ArrayAttr[Attribute]) -> None: if isinstance(shape, list): shape = ArrayAttr(shape) - super().__init__([shape]) + super().__init__(shape) @irdl_attr_definition class TupleType(ParametrizedAttribute, TypeAttribute): name = "transfer.tuple" - fields: ParameterDef[ArrayAttr[Attribute]] + fields: ArrayAttr[Attribute] = param_def() def get_num_fields(self) -> int: return len(self.fields.data) @@ -574,7 +574,7 @@ def get_fields(self): def __init__(self, shape: list[Attribute] | ArrayAttr[Attribute]) -> None: if isinstance(shape, list): shape = ArrayAttr(shape) - super().__init__([shape]) + super().__init__(shape) @irdl_op_definition From 4ee5cba93ccfa89151ec391cfca05192032ceafc Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Sat, 27 Sep 2025 10:32:57 -0600 Subject: [PATCH 4/7] add filecheck tests to cpp translate --- tests/filecheck/lower-to-cpp/arith.mlir | 131 ++++++++ tests/filecheck/lower-to-cpp/special-ops.mlir | 44 +++ .../lower-to-cpp/transfer-bin-ops.mlir | 291 ++++++++++++++++++ .../lower-to-cpp/transfer-pred-ops.mlir | 180 +++++++++++ .../lower-to-cpp/transfer-unary-ops.mlir | 131 ++++++++ xdsl_smt/cli/cpp_translate.py | 12 +- 6 files changed, 788 insertions(+), 1 deletion(-) create mode 100644 tests/filecheck/lower-to-cpp/arith.mlir create mode 100644 tests/filecheck/lower-to-cpp/special-ops.mlir create mode 100644 tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir create mode 100644 tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir create mode 100644 tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir diff --git a/tests/filecheck/lower-to-cpp/arith.mlir b/tests/filecheck/lower-to-cpp/arith.mlir new file mode 100644 index 00000000..e419a384 --- /dev/null +++ b/tests/filecheck/lower-to-cpp/arith.mlir @@ -0,0 +1,131 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.addi"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "add_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.subi"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "sub_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.andi"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "and_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.ori"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "or_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.xori"(%x, %y) : (i32, i32) -> i32 + "func.return"(%r) : (i32) -> () + }) {"sym_name" = "xor_test", "function_type" = (i32, i32) -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 0 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "eq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 1 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "neq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 2 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "lt_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 3 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "leq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 4 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "gt_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : i32, %y : i32): + %r = "arith.cmpi"(%x, %y) {"predicate" = 5 : i64} : (i32, i32) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "geq_test", "function_type" = (i32, i32) -> i1} : () -> () + + "func.func"() ({ + %x = "arith.constant"() {value = 3 : i32} : () -> i32 + "func.return"(%x) : (i32) -> () + }) {"sym_name" = "const_test", "function_type" = () -> i32} : () -> () + + "func.func"() ({ + ^0(%x : i32): + "func.return"(%x) : (i32) -> () + }) {"sym_name" = "empty_func_test", "function_type" = (i32) -> i32} : () -> () +}) : () -> () + +// CHECK: int add_test(int x,int y){ +// CHECK-NEXT: int r = x+y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sub_test(int x,int y){ +// CHECK-NEXT: int r = x-y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int and_test(int x,int y){ +// CHECK-NEXT: int r = x&y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int or_test(int x,int y){ +// CHECK-NEXT: int r = x|y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int xor_test(int x,int y){ +// CHECK-NEXT: int r = x^y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int eq_test(int x,int y){ +// CHECK-NEXT: int r = (x==y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int neq_test(int x,int y){ +// CHECK-NEXT: int r = (x!=y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int lt_test(int x,int y){ +// CHECK-NEXT: int r = (xy); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int geq_test(int x,int y){ +// CHECK-NEXT: int r = (x>=y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int const_test(){ +// CHECK-NEXT: int x = 3; +// CHECK-NEXT: return x; +// CHECK-NEXT: } +// CHECK-NEXT: int empty_func_test(int x){ +// CHECK-NEXT: return x; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/special-ops.mlir b/tests/filecheck/lower-to-cpp/special-ops.mlir new file mode 100644 index 00000000..ad8b16dd --- /dev/null +++ b/tests/filecheck/lower-to-cpp/special-ops.mlir @@ -0,0 +1,44 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%c : i1, %x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.select"(%c, %x, %y) : (i1, !transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "select_test", "function_type" = (i1, !transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.make"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> + "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> () + }) {"sym_name" = "make_2_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer, %z : !transfer.integer): + %r = "transfer.make"(%x, %y, %z) : (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]> + "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>) -> () + }) {"sym_name" = "make_3_test", "function_type" = (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.abs_value<[!transfer.integer, !transfer.integer]>): + %r = "transfer.get"(%x) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_test", "function_type" = (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer} : () -> () +}) : () -> () + +// CHECK: APInt select_test(int c,APInt x,APInt y){ +// CHECK-NEXT: APInt r = c ? x : y ; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: std::vector make_2_test(APInt x,APInt y){ +// CHECK-NEXT: std::vector r = std::vector{x,y}; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: std::vector make_3_test(APInt x,APInt y,APInt z){ +// CHECK-NEXT: std::vector r = std::vector{x,y,z}; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt get_test(std::vector x){ +// CHECK-NEXT: APInt r = x[0]; +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir new file mode 100644 index 00000000..45d71c5b --- /dev/null +++ b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir @@ -0,0 +1,291 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.add"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "add_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sub"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "sub_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.mul"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "mul_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.and"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "and_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.or"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "or_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.xor"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "xor_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.udiv"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "udiv_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sdiv"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "sdiv_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.urem"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "urem_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.srem"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "srem_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.shl"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "shl_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.ashr"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "ashr_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.lshr"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "lshr_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.umin"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "umin_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.smin"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "smin_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.umax"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "umax_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.smax"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "smax_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.get_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_high_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.get_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.set_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "set_high_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.set_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "set_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.clear_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "clear_high_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.clear_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "clear_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () +}) : () -> () + +// CHECK: APInt add_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x+y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt sub_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x-y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt mul_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x*y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt and_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x&y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt or_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x|y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt xor_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x^y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt udiv_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (y == 0) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), -1); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.udiv(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt sdiv_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (x.isMinSignedValue() && y == -1) { +// CHECK-NEXT: r = APInt::getSignedMinValue(x.getBitWidth()); +// CHECK-NEXT: } else if (y == 0 && x.isNonNegative()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), -1); +// CHECK-NEXT: } else if (y == 0 && x.isNegative()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 1); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.sdiv(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt urem_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (y == 0) { +// CHECK-NEXT: r = x; +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.urem(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt srem_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (y == 0) { +// CHECK-NEXT: r = x; +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.srem(y); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt shl_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (y.uge(y.getBitWidth())) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 0); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.shl(y.getZExtValue()); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt ashr_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (y.uge(y.getBitWidth()) && x.isSignBitSet()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), -1); +// CHECK-NEXT: } else if (y.uge(y.getBitWidth()) && x.isSignBitClear()) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 0); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.ashr(y.getZExtValue()); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt lshr_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r; +// CHECK-NEXT: if (y.uge(y.getBitWidth())) { +// CHECK-NEXT: r = APInt(x.getBitWidth(), 0); +// CHECK-NEXT: } else { +// CHECK-NEXT: r = x.lshr(y.getZExtValue()); +// CHECK-NEXT: } +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt umin_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = A::APIntOps::umin(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt smin_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = A::APIntOps::smin(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt umax_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = A::APIntOps::umax(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt smax_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = A::APIntOps::smax(x,y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt get_high_bits_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x.getHiBits(y.getZExtValue()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt get_low_bits_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x.getLoBits(y.getZExtValue()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt set_high_bits_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.setHighBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.setHighBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt set_low_bits_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.setLowBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.setLowBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt clear_high_bits_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.clearHighBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.clearHighBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt clear_low_bits_test(APInt x,APInt y){ +// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: if (y.ule(y.getBitWidth())) +// CHECK-NEXT: r.clearLowBits(y.getZExtValue()); +// CHECK-NEXT: else +// CHECK-NEXT: r.clearLowBits(y.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir new file mode 100644 index 00000000..5a363339 --- /dev/null +++ b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir @@ -0,0 +1,180 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.umul_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "umul_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.smul_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "smul_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.uadd_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "uadd_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sadd_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sadd_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.ushl_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ushl_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.sshl_overflow"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sshl_ov_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.intersects"(%x, %y) : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "intersects_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 0} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "eq_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 1} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "neq_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 2} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "slt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 3} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sle_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 4} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sgt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 5} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "sge_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 6} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ult_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 7} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ule_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 8} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "ugt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.cmp"(%x, %y) {predicate = 9} : (!transfer.integer, !transfer.integer) -> i1 + "func.return"(%r) : (i1) -> () + }) {"sym_name" = "uge_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () +}) : () -> () + +// CHECK: int umul_ov_test(APInt x,APInt y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.umul_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int smul_ov_test(APInt x,APInt y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.smul_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int uadd_ov_test(APInt x,APInt y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.uadd_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sadd_ov_test(APInt x,APInt y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.sadd_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ushl_ov_test(APInt x,APInt y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.ushl_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sshl_ov_test(APInt x,APInt y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.sshl_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int intersects_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.intersects(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int eq_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.eq(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int neq_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.ne(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int slt_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.slt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sle_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.sle(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sgt_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.sgt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sge_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.sge(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ult_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.ult(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ule_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.ule(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ugt_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.ugt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int uge_test(APInt x,APInt y){ +// CHECK-NEXT: int r = x.uge(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir new file mode 100644 index 00000000..6b0665e4 --- /dev/null +++ b/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir @@ -0,0 +1,131 @@ +// RUN: cpp-translate -i %s | filecheck %s + +"builtin.module"() ({ + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_bit_width"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_bw_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countl_zero"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countl_zero_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countr_zero"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countr_zero_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countl_one"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countl_one_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.countr_one"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "countr_one_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.neg"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "neg_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.clear_sign_bit"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "clear_sign_bit_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.set_sign_bit"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "set_sign_bit_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_all_ones"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_all_ones_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_signed_max_value"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_signed_max_value_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.get_signed_min_value"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "get_signed_min_value_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer): + %r = "transfer.reverse_bits"(%x) : (!transfer.integer) -> !transfer.integer + "func.return"(%r) : (!transfer.integer) -> () + }) {"sym_name" = "reverse_bits_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () +}) : () -> () + +// CHECK: APInt get_bw_test(APInt x){ +// CHECK-NEXT: unsigned r_autocast = x.getBitWidth(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt countl_zero_test(APInt x){ +// CHECK-NEXT: unsigned r_autocast = x.countl_zero(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt countr_zero_test(APInt x){ +// CHECK-NEXT: unsigned r_autocast = x.countr_zero(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt countl_one_test(APInt x){ +// CHECK-NEXT: unsigned r_autocast = x.countl_one(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt countr_one_test(APInt x){ +// CHECK-NEXT: unsigned r_autocast = x.countr_one(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt neg_test(APInt x){ +// CHECK-NEXT: APInt r = ~x; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt clear_sign_bit_test(APInt x){ +// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: r.clearSignBit(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt set_sign_bit_test(APInt x){ +// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: r.setSignBit(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt get_all_ones_test(APInt x){ +// CHECK-NEXT: APInt r = APInt::getAllOnes(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt get_signed_max_value_test(APInt x){ +// CHECK-NEXT: APInt r = APInt::getSignedMaxValue(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt get_signed_min_value_test(APInt x){ +// CHECK-NEXT: APInt r = APInt::getSignedMinValue(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: APInt reverse_bits_test(APInt x){ +// CHECK-NEXT: APInt r = x.reverseBits(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/xdsl_smt/cli/cpp_translate.py b/xdsl_smt/cli/cpp_translate.py index f51913ed..1f935c19 100644 --- a/xdsl_smt/cli/cpp_translate.py +++ b/xdsl_smt/cli/cpp_translate.py @@ -30,6 +30,14 @@ def _register_args() -> argparse.Namespace: default=None, help="Path to the output MLIR file (defaults to stdout if omitted).", ) + parser.add_argument( + "--apint", action="store_true", help="Use apints for bitvector type lowering" + ) + parser.add_argument( + "--custom_vec", + action="store_true", + help="Use custom vec class for transfer value lowering", + ) return parser.parse_args() @@ -65,4 +73,6 @@ def main() -> None: funcs = _parse_mlir_module(args.input, ctx) output = args.output.open("w", encoding="utf-8") if args.output else sys.stdout - LowerToCpp(output).apply(ctx, funcs) + LowerToCpp( + output, int_to_apint=args.apint, use_custom_vec=args.custom_vec + ).apply(ctx, funcs) From c33c6a25c4cdc7f0e4080ca5696ee1e3817e8a90 Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Sat, 27 Sep 2025 10:37:51 -0600 Subject: [PATCH 5/7] format --- tests/filecheck/lower-to-cpp/arith.mlir | 22 +++++++++---------- tests/filecheck/lower-to-cpp/special-ops.mlir | 2 +- .../lower-to-cpp/transfer-bin-ops.mlir | 4 ++-- .../lower-to-cpp/transfer-pred-ops.mlir | 2 +- xdsl_smt/cli/cpp_translate.py | 6 ++--- 5 files changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/filecheck/lower-to-cpp/arith.mlir b/tests/filecheck/lower-to-cpp/arith.mlir index e419a384..3d4fa0eb 100644 --- a/tests/filecheck/lower-to-cpp/arith.mlir +++ b/tests/filecheck/lower-to-cpp/arith.mlir @@ -6,25 +6,25 @@ %r = "arith.addi"(%x, %y) : (i32, i32) -> i32 "func.return"(%r) : (i32) -> () }) {"sym_name" = "add_test", "function_type" = (i32, i32) -> i32} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.subi"(%x, %y) : (i32, i32) -> i32 "func.return"(%r) : (i32) -> () }) {"sym_name" = "sub_test", "function_type" = (i32, i32) -> i32} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.andi"(%x, %y) : (i32, i32) -> i32 "func.return"(%r) : (i32) -> () }) {"sym_name" = "and_test", "function_type" = (i32, i32) -> i32} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.ori"(%x, %y) : (i32, i32) -> i32 "func.return"(%r) : (i32) -> () }) {"sym_name" = "or_test", "function_type" = (i32, i32) -> i32} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.xori"(%x, %y) : (i32, i32) -> i32 @@ -36,42 +36,42 @@ %r = "arith.cmpi"(%x, %y) {"predicate" = 0 : i64} : (i32, i32) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "eq_test", "function_type" = (i32, i32) -> i1} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.cmpi"(%x, %y) {"predicate" = 1 : i64} : (i32, i32) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "neq_test", "function_type" = (i32, i32) -> i1} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.cmpi"(%x, %y) {"predicate" = 2 : i64} : (i32, i32) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "lt_test", "function_type" = (i32, i32) -> i1} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.cmpi"(%x, %y) {"predicate" = 3 : i64} : (i32, i32) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "leq_test", "function_type" = (i32, i32) -> i1} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.cmpi"(%x, %y) {"predicate" = 4 : i64} : (i32, i32) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "gt_test", "function_type" = (i32, i32) -> i1} : () -> () - + "func.func"() ({ ^0(%x : i32, %y : i32): %r = "arith.cmpi"(%x, %y) {"predicate" = 5 : i64} : (i32, i32) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "geq_test", "function_type" = (i32, i32) -> i1} : () -> () - + "func.func"() ({ %x = "arith.constant"() {value = 3 : i32} : () -> i32 "func.return"(%x) : (i32) -> () }) {"sym_name" = "const_test", "function_type" = () -> i32} : () -> () - + "func.func"() ({ ^0(%x : i32): "func.return"(%x) : (i32) -> () diff --git a/tests/filecheck/lower-to-cpp/special-ops.mlir b/tests/filecheck/lower-to-cpp/special-ops.mlir index ad8b16dd..bdbba35a 100644 --- a/tests/filecheck/lower-to-cpp/special-ops.mlir +++ b/tests/filecheck/lower-to-cpp/special-ops.mlir @@ -12,7 +12,7 @@ %r = "transfer.make"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> () }) {"sym_name" = "make_2_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () - + "func.func"() ({ ^0(%x : !transfer.integer, %y : !transfer.integer, %z : !transfer.integer): %r = "transfer.make"(%x, %y, %z) : (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]> diff --git a/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir index 45d71c5b..06b0dc1e 100644 --- a/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir +++ b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir @@ -102,7 +102,7 @@ %r = "transfer.smax"(%x, %y) : (!transfer.integer,!transfer.integer) -> !transfer.integer "func.return"(%r) : (!transfer.integer) -> () }) {"sym_name" = "smax_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () - + "func.func"() ({ ^0(%x : !transfer.integer, %y : !transfer.integer): %r = "transfer.get_high_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer @@ -113,7 +113,7 @@ ^0(%x : !transfer.integer, %y : !transfer.integer): %r = "transfer.get_low_bits"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.integer "func.return"(%r) : (!transfer.integer) -> () - }) {"sym_name" = "get_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () + }) {"sym_name" = "get_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () "func.func"() ({ ^0(%x : !transfer.integer, %y : !transfer.integer): diff --git a/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir index 5a363339..d3d85556 100644 --- a/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir +++ b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir @@ -96,7 +96,7 @@ %r = "transfer.cmp"(%x, %y) {predicate = 8} : (!transfer.integer, !transfer.integer) -> i1 "func.return"(%r) : (i1) -> () }) {"sym_name" = "ugt_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () - + "func.func"() ({ ^0(%x : !transfer.integer, %y : !transfer.integer): %r = "transfer.cmp"(%x, %y) {predicate = 9} : (!transfer.integer, !transfer.integer) -> i1 diff --git a/xdsl_smt/cli/cpp_translate.py b/xdsl_smt/cli/cpp_translate.py index 1f935c19..6b98b0e4 100644 --- a/xdsl_smt/cli/cpp_translate.py +++ b/xdsl_smt/cli/cpp_translate.py @@ -73,6 +73,6 @@ def main() -> None: funcs = _parse_mlir_module(args.input, ctx) output = args.output.open("w", encoding="utf-8") if args.output else sys.stdout - LowerToCpp( - output, int_to_apint=args.apint, use_custom_vec=args.custom_vec - ).apply(ctx, funcs) + LowerToCpp(output, int_to_apint=args.apint, use_custom_vec=args.custom_vec).apply( + ctx, funcs + ) From 381187aedf2dcb1e5af516af499402acaa9139b9 Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Tue, 14 Oct 2025 12:45:36 -0600 Subject: [PATCH 6/7] Add flag for knownbits, pass APInt's as const ref --- tests/filecheck/lower-to-cpp/special-ops.mlir | 70 ++++-- .../lower-to-cpp/transfer-bin-ops.mlir | 106 ++++----- .../lower-to-cpp/transfer-pred-ops.mlir | 148 ++++++------ .../lower-to-cpp/transfer-unary-ops.mlir | 110 ++++----- xdsl_smt/cli/cpp_translate.py | 37 +-- xdsl_smt/passes/transfer_lower.py | 57 ++--- xdsl_smt/utils/lower_utils.py | 224 +++++++----------- 7 files changed, 355 insertions(+), 397 deletions(-) diff --git a/tests/filecheck/lower-to-cpp/special-ops.mlir b/tests/filecheck/lower-to-cpp/special-ops.mlir index bdbba35a..9c983a48 100644 --- a/tests/filecheck/lower-to-cpp/special-ops.mlir +++ b/tests/filecheck/lower-to-cpp/special-ops.mlir @@ -6,39 +6,57 @@ %r = "transfer.select"(%c, %x, %y) : (i1, !transfer.integer, !transfer.integer) -> !transfer.integer "func.return"(%r) : (!transfer.integer) -> () }) {"sym_name" = "select_test", "function_type" = (i1, !transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () - + "func.func"() ({ - ^0(%x : !transfer.integer, %y : !transfer.integer): - %r = "transfer.make"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> + ^0(%lhs : !transfer.abs_value<[!transfer.integer, !transfer.integer]>, %rhs : !transfer.abs_value<[!transfer.integer, !transfer.integer]>): + %lhs0 = "transfer.get"(%lhs) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %lhs1 = "transfer.get"(%lhs) {index = 1} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %rhs0 = "transfer.get"(%rhs) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %rhs1 = "transfer.get"(%rhs) {index = 1} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer + %res0 = "transfer.or"(%lhs0, %rhs0) : (!transfer.integer, !transfer.integer) -> !transfer.integer + %res1 = "transfer.and"(%lhs1, %rhs1) : (!transfer.integer, !transfer.integer) -> !transfer.integer + %r = "transfer.make"(%res0, %res1) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> () - }) {"sym_name" = "make_2_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () - - "func.func"() ({ - ^0(%x : !transfer.integer, %y : !transfer.integer, %z : !transfer.integer): - %r = "transfer.make"(%x, %y, %z) : (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]> - "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>) -> () - }) {"sym_name" = "make_3_test", "function_type" = (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>} : () -> () + }) {"sym_name" = "kb_and_test", "function_type" = (!transfer.abs_value<[!transfer.integer, !transfer.integer]>, !transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () "func.func"() ({ ^0(%x : !transfer.abs_value<[!transfer.integer, !transfer.integer]>): %r = "transfer.get"(%x) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer "func.return"(%r) : (!transfer.integer) -> () }) {"sym_name" = "get_test", "function_type" = (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer} : () -> () + + "func.func"() ({ + ^0(%x : !transfer.integer, %y : !transfer.integer): + %r = "transfer.make"(%x, %y) : (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]> + "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> () + }) {"sym_name" = "make_2_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer]>} : () -> () + + // "func.func"() ({ + // ^0(%x : !transfer.integer, %y : !transfer.integer, %z : !transfer.integer): + // %r = "transfer.make"(%x, %y, %z) : (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]> + // "func.return"(%r) : (!transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>) -> () + // }) {"sym_name" = "make_3_test", "function_type" = (!transfer.integer, !transfer.integer, !transfer.integer) -> !transfer.abs_value<[!transfer.integer, !transfer.integer, !transfer.integer]>} : () -> () }) : () -> () -// CHECK: APInt select_test(int c,APInt x,APInt y){ -// CHECK-NEXT: APInt r = c ? x : y ; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: std::vector make_2_test(APInt x,APInt y){ -// CHECK-NEXT: std::vector r = std::vector{x,y}; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: std::vector make_3_test(APInt x,APInt y,APInt z){ -// CHECK-NEXT: std::vector r = std::vector{x,y,z}; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt get_test(std::vector x){ -// CHECK-NEXT: APInt r = x[0]; -// CHECK-NEXT: return r; -// CHECK-NEXT: } +// CHECK: const APInt select_test(int c,const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = c ? x : y ; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: std::vector kb_and_test(std::vector &lhs,std::vector &rhs){ +// CHECK-NEXT: const APInt lhs0 = lhs[0]; +// CHECK-NEXT: const APInt lhs1 = lhs[1]; +// CHECK-NEXT: const APInt rhs0 = rhs[0]; +// CHECK-NEXT: const APInt rhs1 = rhs[1]; +// CHECK-NEXT: const APInt res0 = lhs0|rhs0; +// CHECK-NEXT: const APInt res1 = lhs1&rhs1; +// CHECK-NEXT: std::vector r = std::vector{res0,res1}; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_test(std::vector &x){ +// CHECK-NEXT: const APInt r = x[0]; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: std::vector make_2_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: std::vector r = std::vector{x,y}; +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir index 06b0dc1e..793d643e 100644 --- a/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir +++ b/tests/filecheck/lower-to-cpp/transfer-bin-ops.mlir @@ -140,32 +140,32 @@ }) {"sym_name" = "clear_low_bits_test", "function_type" = (!transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () }) : () -> () -// CHECK: APInt add_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x+y; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt sub_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x-y; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt mul_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x*y; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt and_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x&y; -// CHECK-NEXT: return r; +// CHECK: const APInt add_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x+y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt sub_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x-y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt mul_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x*y; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt and_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x&y; +// CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt or_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x|y; +// CHECK-NEXT: const APInt or_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x|y; // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt xor_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x^y; +// CHECK-NEXT: const APInt xor_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x^y; // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt udiv_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt udiv_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (y == 0) { // CHECK-NEXT: r = APInt(x.getBitWidth(), -1); // CHECK-NEXT: } else { @@ -173,8 +173,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt sdiv_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt sdiv_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (x.isMinSignedValue() && y == -1) { // CHECK-NEXT: r = APInt::getSignedMinValue(x.getBitWidth()); // CHECK-NEXT: } else if (y == 0 && x.isNonNegative()) { @@ -186,8 +186,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt urem_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt urem_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (y == 0) { // CHECK-NEXT: r = x; // CHECK-NEXT: } else { @@ -195,8 +195,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt srem_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt srem_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (y == 0) { // CHECK-NEXT: r = x; // CHECK-NEXT: } else { @@ -204,8 +204,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt shl_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt shl_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (y.uge(y.getBitWidth())) { // CHECK-NEXT: r = APInt(x.getBitWidth(), 0); // CHECK-NEXT: } else { @@ -213,8 +213,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt ashr_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt ashr_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (y.uge(y.getBitWidth()) && x.isSignBitSet()) { // CHECK-NEXT: r = APInt(x.getBitWidth(), -1); // CHECK-NEXT: } else if (y.uge(y.getBitWidth()) && x.isSignBitClear()) { @@ -224,8 +224,8 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt lshr_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r; +// CHECK-NEXT: const APInt lshr_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r; // CHECK-NEXT: if (y.uge(y.getBitWidth())) { // CHECK-NEXT: r = APInt(x.getBitWidth(), 0); // CHECK-NEXT: } else { @@ -233,56 +233,56 @@ // CHECK-NEXT: } // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt umin_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = A::APIntOps::umin(x,y); +// CHECK-NEXT: const APInt umin_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::umin(x,y); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt smin_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = A::APIntOps::smin(x,y); +// CHECK-NEXT: const APInt smin_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::smin(x,y); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt umax_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = A::APIntOps::umax(x,y); +// CHECK-NEXT: const APInt umax_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::umax(x,y); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt smax_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = A::APIntOps::smax(x,y); +// CHECK-NEXT: const APInt smax_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = A::APIntOps::smax(x,y); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt get_high_bits_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x.getHiBits(y.getZExtValue()); +// CHECK-NEXT: const APInt get_high_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x.getHiBits(y.getZExtValue()); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt get_low_bits_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x.getLoBits(y.getZExtValue()); +// CHECK-NEXT: const APInt get_low_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x.getLoBits(y.getZExtValue()); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt set_high_bits_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: const APInt set_high_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; // CHECK-NEXT: if (y.ule(y.getBitWidth())) // CHECK-NEXT: r.setHighBits(y.getZExtValue()); // CHECK-NEXT: else // CHECK-NEXT: r.setHighBits(y.getBitWidth()); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt set_low_bits_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: const APInt set_low_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; // CHECK-NEXT: if (y.ule(y.getBitWidth())) // CHECK-NEXT: r.setLowBits(y.getZExtValue()); // CHECK-NEXT: else // CHECK-NEXT: r.setLowBits(y.getBitWidth()); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt clear_high_bits_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: const APInt clear_high_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; // CHECK-NEXT: if (y.ule(y.getBitWidth())) // CHECK-NEXT: r.clearHighBits(y.getZExtValue()); // CHECK-NEXT: else // CHECK-NEXT: r.clearHighBits(y.getBitWidth()); // CHECK-NEXT: return r; // CHECK-NEXT: } -// CHECK-NEXT: APInt clear_low_bits_test(APInt x,APInt y){ -// CHECK-NEXT: APInt r = x; +// CHECK-NEXT: const APInt clear_low_bits_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: const APInt r = x; // CHECK-NEXT: if (y.ule(y.getBitWidth())) // CHECK-NEXT: r.clearLowBits(y.getZExtValue()); // CHECK-NEXT: else diff --git a/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir index d3d85556..293e2232 100644 --- a/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir +++ b/tests/filecheck/lower-to-cpp/transfer-pred-ops.mlir @@ -104,77 +104,77 @@ }) {"sym_name" = "uge_test", "function_type" = (!transfer.integer, !transfer.integer) -> i1} : () -> () }) : () -> () -// CHECK: int umul_ov_test(APInt x,APInt y){ -// CHECK-NEXT: bool r; -// CHECK-NEXT: x.umul_ov(y,r); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int smul_ov_test(APInt x,APInt y){ -// CHECK-NEXT: bool r; -// CHECK-NEXT: x.smul_ov(y,r); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int uadd_ov_test(APInt x,APInt y){ -// CHECK-NEXT: bool r; -// CHECK-NEXT: x.uadd_ov(y,r); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int sadd_ov_test(APInt x,APInt y){ -// CHECK-NEXT: bool r; -// CHECK-NEXT: x.sadd_ov(y,r); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int ushl_ov_test(APInt x,APInt y){ -// CHECK-NEXT: bool r; -// CHECK-NEXT: x.ushl_ov(y,r); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int sshl_ov_test(APInt x,APInt y){ -// CHECK-NEXT: bool r; -// CHECK-NEXT: x.sshl_ov(y,r); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int intersects_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.intersects(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int eq_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.eq(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int neq_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.ne(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int slt_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.slt(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int sle_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.sle(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int sgt_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.sgt(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int sge_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.sge(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int ult_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.ult(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int ule_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.ule(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int ugt_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.ugt(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: int uge_test(APInt x,APInt y){ -// CHECK-NEXT: int r = x.uge(y); -// CHECK-NEXT: return r; -// CHECK-NEXT: } +// CHECK: int umul_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.umul_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int smul_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.smul_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int uadd_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.uadd_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sadd_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.sadd_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ushl_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.ushl_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sshl_ov_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: bool r; +// CHECK-NEXT: x.sshl_ov(y,r); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int intersects_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.intersects(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int eq_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.eq(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int neq_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ne(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int slt_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.slt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sle_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.sle(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sgt_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.sgt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int sge_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.sge(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ult_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ult(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ule_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ule(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int ugt_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.ugt(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: int uge_test(const APInt &x,const APInt &y){ +// CHECK-NEXT: int r = x.uge(y); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir b/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir index 6b0665e4..7c5a0e8e 100644 --- a/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir +++ b/tests/filecheck/lower-to-cpp/transfer-unary-ops.mlir @@ -74,58 +74,58 @@ }) {"sym_name" = "reverse_bits_test", "function_type" = (!transfer.integer) -> !transfer.integer} : () -> () }) : () -> () -// CHECK: APInt get_bw_test(APInt x){ -// CHECK-NEXT: unsigned r_autocast = x.getBitWidth(); -// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt countl_zero_test(APInt x){ -// CHECK-NEXT: unsigned r_autocast = x.countl_zero(); -// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt countr_zero_test(APInt x){ -// CHECK-NEXT: unsigned r_autocast = x.countr_zero(); -// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt countl_one_test(APInt x){ -// CHECK-NEXT: unsigned r_autocast = x.countl_one(); -// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt countr_one_test(APInt x){ -// CHECK-NEXT: unsigned r_autocast = x.countr_one(); -// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt neg_test(APInt x){ -// CHECK-NEXT: APInt r = ~x; -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt clear_sign_bit_test(APInt x){ -// CHECK-NEXT: APInt r = x; -// CHECK-NEXT: r.clearSignBit(); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt set_sign_bit_test(APInt x){ -// CHECK-NEXT: APInt r = x; -// CHECK-NEXT: r.setSignBit(); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt get_all_ones_test(APInt x){ -// CHECK-NEXT: APInt r = APInt::getAllOnes(x.getBitWidth()); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt get_signed_max_value_test(APInt x){ -// CHECK-NEXT: APInt r = APInt::getSignedMaxValue(x.getBitWidth()); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt get_signed_min_value_test(APInt x){ -// CHECK-NEXT: APInt r = APInt::getSignedMinValue(x.getBitWidth()); -// CHECK-NEXT: return r; -// CHECK-NEXT: } -// CHECK-NEXT: APInt reverse_bits_test(APInt x){ -// CHECK-NEXT: APInt r = x.reverseBits(); -// CHECK-NEXT: return r; -// CHECK-NEXT: } +// CHECK: const APInt get_bw_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.getBitWidth(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countl_zero_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countl_zero(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countr_zero_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countr_zero(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countl_one_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countl_one(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt countr_one_test(const APInt &x){ +// CHECK-NEXT: unsigned r_autocast = x.countr_one(); +// CHECK-NEXT: APInt r(x.getBitWidth(),r_autocast); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt neg_test(const APInt &x){ +// CHECK-NEXT: const APInt r = ~x; +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt clear_sign_bit_test(const APInt &x){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: r.clearSignBit(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt set_sign_bit_test(const APInt &x){ +// CHECK-NEXT: const APInt r = x; +// CHECK-NEXT: r.setSignBit(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_all_ones_test(const APInt &x){ +// CHECK-NEXT: const APInt r = APInt::getAllOnes(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_signed_max_value_test(const APInt &x){ +// CHECK-NEXT: const APInt r = APInt::getSignedMaxValue(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt get_signed_min_value_test(const APInt &x){ +// CHECK-NEXT: const APInt r = APInt::getSignedMinValue(x.getBitWidth()); +// CHECK-NEXT: return r; +// CHECK-NEXT: } +// CHECK-NEXT: const APInt reverse_bits_test(const APInt &x){ +// CHECK-NEXT: const APInt r = x.reverseBits(); +// CHECK-NEXT: return r; +// CHECK-NEXT: } diff --git a/xdsl_smt/cli/cpp_translate.py b/xdsl_smt/cli/cpp_translate.py index 6b98b0e4..7f63f69f 100644 --- a/xdsl_smt/cli/cpp_translate.py +++ b/xdsl_smt/cli/cpp_translate.py @@ -10,7 +10,7 @@ from xdsl_smt.dialects.llvm_dialect import LLVM from xdsl_smt.dialects.transfer import Transfer -from xdsl_smt.passes.transfer_lower import LowerToCpp +from xdsl_smt.passes.transfer_lower import lower_to_cpp def _register_args() -> argparse.Namespace: @@ -31,12 +31,19 @@ def _register_args() -> argparse.Namespace: help="Path to the output MLIR file (defaults to stdout if omitted).", ) parser.add_argument( - "--apint", action="store_true", help="Use apints for bitvector type lowering" + "--apint", + action="store_true", + help="Use LLVM APInts for bitvector type lowering", ) parser.add_argument( "--custom_vec", action="store_true", - help="Use custom vec class for transfer value lowering", + help="Use custom vec class for abstract value lowering", + ) + parser.add_argument( + "--llvm_kb", + action="store_true", + help="Use LLVM KnownBits for abstract value lowering", ) return parser.parse_args() @@ -55,7 +62,9 @@ def _parse_mlir_module(p: Path | None, ctx: Context) -> ModuleOp: raise ValueError(f"mlir in '{fname}' is neither a ModuleOp, nor a FuncOp") -def _get_ctx() -> Context: +def main() -> None: + args = _register_args() + ctx = Context() ctx.load_dialect(Arith) ctx.load_dialect(Builtin) @@ -63,16 +72,16 @@ def _get_ctx() -> Context: ctx.load_dialect(Transfer) ctx.load_dialect(LLVM) - return ctx - - -def main() -> None: - args = _register_args() - - ctx = _get_ctx() - funcs = _parse_mlir_module(args.input, ctx) + module = _parse_mlir_module(args.input, ctx) output = args.output.open("w", encoding="utf-8") if args.output else sys.stdout - LowerToCpp(output, int_to_apint=args.apint, use_custom_vec=args.custom_vec).apply( - ctx, funcs + if args.custom_vec and args.llvm_kb: + raise ValueError("Cannot lower with both custom vectors and LLVM KnownBits") + + lower_to_cpp( + module, + output, + use_apint=args.apint, + use_custom_vec=args.custom_vec, + use_llvm_kb=args.llvm_kb, ) diff --git a/xdsl_smt/passes/transfer_lower.py b/xdsl_smt/passes/transfer_lower.py index bf6864cc..b7862647 100644 --- a/xdsl_smt/passes/transfer_lower.py +++ b/xdsl_smt/passes/transfer_lower.py @@ -1,11 +1,10 @@ from dataclasses import dataclass from typing import TextIO +import sys -from xdsl.context import Context from xdsl.dialects.builtin import ModuleOp from xdsl.dialects.func import FuncOp from xdsl.ir import Operation -from xdsl.passes import ModulePass from xdsl.pattern_rewriter import ( GreedyRewritePatternApplier, PatternRewriter, @@ -15,22 +14,17 @@ ) from ..utils.lower_utils import ( - CPP_CLASS_KEY, - INDUCTION_KEY, lowerOperation, - set_int_to_apint, + set_use_apint, set_use_custom_vec, + set_use_llvm_kb, ) autogen = 0 funcStr = "" -needDispatch: list[FuncOp] = [] -inductionOp: list[FuncOp] = [] def transfer_func(op: Operation, fout: TextIO): - global needDispatch - global inductionOp if isinstance(op, ModuleOp): return if len(op.results) > 0 and op.results[0].name_hint is None: @@ -42,10 +36,6 @@ def transfer_func(op: Operation, fout: TextIO): if arg.name_hint is None: arg.name_hint = "autogen" + str(autogen) autogen += 1 - if CPP_CLASS_KEY in op.attributes: - needDispatch.append(op) - if INDUCTION_KEY in op.attributes: - inductionOp.append(op) global funcStr funcStr += lowerOperation(op) parentOp = op.parent_op() @@ -65,25 +55,24 @@ def match_and_rewrite(self, op: Operation, _: PatternRewriter): transfer_func(op, self.fout) -@dataclass(frozen=True) -class LowerToCpp(ModulePass): - name = "trans_lower" - fout: TextIO - int_to_apint: bool = False - use_custom_vec: bool = False +def lower_to_cpp( + op: ModuleOp, + fout: TextIO = sys.stdout, + use_apint: bool = False, + use_custom_vec: bool = False, + use_llvm_kb: bool = False, +) -> None: + global autogen + autogen = 0 - def apply(self, ctx: Context, op: ModuleOp) -> None: - global autogen - autogen = 0 - set_int_to_apint(self.int_to_apint) - set_use_custom_vec(self.use_custom_vec) - # We found PatternRewriteWalker skipped the op itself during iteration - # Do it manually on op - transfer_func(op, self.fout) - walker = PatternRewriteWalker( - GreedyRewritePatternApplier([LowerOperation(self.fout)]), - walk_regions_first=False, - apply_recursively=False, - walk_reverse=False, - ) - walker.rewrite_module(op) + # set options + set_use_apint(use_apint) + set_use_custom_vec(use_custom_vec) + set_use_llvm_kb(use_llvm_kb) + + PatternRewriteWalker( + GreedyRewritePatternApplier([LowerOperation(fout)]), + walk_regions_first=False, + apply_recursively=False, + walk_reverse=False, + ).rewrite_module(op) diff --git a/xdsl_smt/utils/lower_utils.py b/xdsl_smt/utils/lower_utils.py index c1017458..7eebef6b 100644 --- a/xdsl_smt/utils/lower_utils.py +++ b/xdsl_smt/utils/lower_utils.py @@ -2,7 +2,7 @@ from typing import Callable import xdsl.dialects.arith as arith -from xdsl.dialects.builtin import IndexType, IntegerAttr, IntegerType +from xdsl.dialects.builtin import IndexType, IntegerType from xdsl.dialects.func import CallOp, FuncOp, ReturnOp from xdsl.ir import Attribute, Block, BlockArgument, Operation, SSAValue @@ -150,6 +150,10 @@ } # transfer.constRangeLoop and NextLoop are controller operations, should be handle specially +# consts +EQ = " = " +END = ";\n" +IDNT = "\t" VAL_EXCEEDS_BW = "{1}.uge({1}.getBitWidth())" RHS_IS_ZERO = "{1} == 0" @@ -178,34 +182,28 @@ SDivOp: [SDIV_ACTION0, SDIV_ACTION1, SDIV_ACTION2], } -unsignedReturnedType = { - CountLOneOp, - CountLZeroOp, - CountROneOp, - CountRZeroOp, - GetBitWidthOp, -} - -int_to_apint = False +# lowering config +use_apint = False use_custom_vec = False -EQ = " = " -END = ";\n" -IDNT = "\t" -CPP_CLASS_KEY = "CPPCLASS" -INDUCTION_KEY = "induction" -OPERATION_NO = "operationNo" +use_llvm_kb = False -def set_int_to_apint(to_apint: bool) -> None: - global int_to_apint - int_to_apint = to_apint +def set_use_apint(f: bool) -> None: + global use_apint + use_apint = f -def set_use_custom_vec(custom_vec: bool) -> None: +def set_use_custom_vec(f: bool) -> None: global use_custom_vec - use_custom_vec = custom_vec + use_custom_vec = f +def set_use_llvm_kb(f: bool) -> None: + global use_llvm_kb + use_llvm_kb = f + + +# helpers def get_ret_val(op: Operation) -> str: ret_val = op.results[0].name_hint assert ret_val @@ -229,132 +227,46 @@ def get_op_str(op: Operation) -> str: def lowerType(typ: Attribute, specialOp: Operation | Block | None = None) -> str: + unsigned_ret_type = { + CountLOneOp, + CountLZeroOp, + CountROneOp, + CountRZeroOp, + GetBitWidthOp, + } + if specialOp is not None: - for op in unsignedReturnedType: + for op in unsigned_ret_type: if isinstance(specialOp, op): return "unsigned" - if isinstance(typ, TransIntegerType): - return "APInt" + + if isinstance(typ, TransIntegerType) or ( + isinstance(typ, IntegerType) and use_apint + ): + return "const APInt" elif isinstance(typ, AbstractValueType) or isinstance(typ, TupleType): fields = typ.get_fields() typeName = lowerType(fields[0]) for i in range(1, len(fields)): assert lowerType(fields[i]) == typeName + if use_custom_vec: return "Vec<" + str(len(fields)) + ">" - return "std::vector<" + typeName + ">" - elif isinstance(typ, IntegerType): - return "int" if not int_to_apint else "APInt" - elif isinstance(typ, IndexType): - return "int" - assert False and "unsupported type" - - -def lowerInductionOps(inductionOp: list[FuncOp]) -> str: - if len(inductionOp) > 0: - functionSignature = """ -{returnedType} {funcName}(ArrayRef<{returnedType}> operands){{ - {returnedType} result={funcName}(operands[0], operands[1]); - for(int i=2;i str: - if len(needDispatch) > 0: - returnedType = needDispatch[0].function_type.outputs.data[0] - for func in needDispatch: - if func.function_type.outputs.data[0] != returnedType: - print(func) - print(func.function_type.outputs.data[0]) - assert ( - "we assume all transfer functions have the same returned type" - and False - ) - returnedType = lowerType(returnedType) - funcName = "naiveDispatcher" - # we assume all operands have the same type as expr - # User should tell the generator all operands - if is_forward: - expr = "(Operation* op, std::vector> operands)" + elif use_llvm_kb: + assert len(fields) == 2 + return "const llvm::KnownBits" else: - expr = "(Operation* op, std::vector> operands, unsigned operationNo)" - functionSignature = ( - "std::optional<" + returnedType + "> " + funcName + expr + "{{\n{0}}}\n\n" - ) + return "std::vector<" + typeName + ">" + elif isinstance(typ, IndexType) or isinstance(typ, IntegerType): + return "int" - dyn_cast = ( - IDNT - + "if(auto castedOp=dyn_cast<{0}>(op);castedOp&&{1}){{\n{2}" - + IDNT - + "}}\n" - ) - return_inst = IDNT + IDNT + "return {0}({1});\n" - - def handleOneTransferFunction(func: FuncOp, operationNo: int) -> str: - blockStr = "" - for cppClass in func.attributes[CPP_CLASS_KEY]: # type: ignore - argStr = "" - if INDUCTION_KEY in func.attributes: - argStr = "operands" - else: - if len(func.args) > 0: - argStr = "operands[0]" - for i in range(1, len(func.args)): - argStr += ", operands[" + str(i) + "]" - ifBody = return_inst.format(func.sym_name.data, argStr) - if operationNo == -1: - operationNoStr = "true" - else: - operationNoStr = "operationNo == " + str(operationNo) - blockStr += dyn_cast.format(cppClass.data, operationNoStr, ifBody) # type: ignore - return blockStr - - funcBody = "" - for func in needDispatch: - if is_forward: - funcBody += handleOneTransferFunction(func, -1) - else: - operationNo = func.attributes[OPERATION_NO] - assert isinstance(operationNo, IntegerAttr) - funcBody += handleOneTransferFunction(func, operationNo.value.data) - funcBody += IDNT + "return {};\n" - - return functionSignature.format(funcBody) - - return "" + raise ValueError(f"unsupported type: {type(typ)}") def isFunctionCall(opName: str) -> bool: return opName[0] == "." -def lowerToNonClassMethod(op: Operation) -> str: - ret_type = lowerType(op.results[0].type, op) - ret_val = get_ret_val(op) - expr = "(" - if len(op.operands) > 0: - expr += get_operand(op, 0) - for i in range(1, len(op.operands)): - expr += "," + get_operand(op, i) - expr += ")" - - return IDNT + ret_type + " " + ret_val + EQ + get_op_str(op) + expr + END - - def lowerToClassMethod( op: Operation, castOperand: Callable[[SSAValue | str], str] | None = None, @@ -492,22 +404,42 @@ def _(op: GetOp) -> str: returnedValue = get_ret_val(op) index = op.attributes["index"].value.data # type: ignore - return ( - IDNT - + returnedType - + " " - + returnedValue - + EQ - + get_operand(op, 0) - + get_op_str(op).format(index) # type: ignore - + END - ) + if use_llvm_kb: + return ( + IDNT + + returnedType + + " " + + returnedValue + + EQ + + get_operand(op, 0) + + (".Zero" if index == 0 else ".One") + + END + ) + + else: + return ( + IDNT + + returnedType + + " " + + returnedValue + + EQ + + get_operand(op, 0) + + get_op_str(op).format(index) # type: ignore + + END + ) @lowerOperation.register def _(op: MakeOp) -> str: - returnedType = lowerType(op.results[0].type, op) returnedValue = get_ret_val(op) + + if use_llvm_kb and isinstance(op.results[0].type, AbstractValueType): + s = f"{IDNT}llvm::KnownBits {returnedValue}{END}" + s += f"{IDNT}{returnedValue}.Zero = {get_operand(op, 0)}{END}" + s += f"{IDNT}{returnedValue}.One = {get_operand(op, 1)}{END}" + return s + + returnedType = lowerType(op.results[0].type, op) expr = "" if len(op.operands) > 0: expr += get_operand(op, 0) @@ -723,8 +655,18 @@ def set_clear_bits( @lowerOperation.register def _(op: FuncOp): def lowerArgs(arg: BlockArgument) -> str: + global use_apint assert arg.name_hint - return lowerType(arg.type) + " " + arg.name_hint + s = f"{lowerType(arg.type)} {arg.name_hint}" + if ( + isinstance(arg.type, AbstractValueType) + or isinstance(arg.type, TupleType) + or isinstance(arg.type, TransIntegerType) + or (isinstance(arg.type, IntegerType) and use_apint) + ): + s = f"{lowerType(arg.type)} &{arg.name_hint}" + + return s returnedType = lowerType(op.function_type.outputs.data[0]) funcName = op.sym_name.data From d426d3dfe892b01169239fb1159edcf982903028 Mon Sep 17 00:00:00 2001 From: Dominic Kennedy Date: Tue, 14 Oct 2025 12:47:09 -0600 Subject: [PATCH 7/7] fmt --- tests/filecheck/lower-to-cpp/special-ops.mlir | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/filecheck/lower-to-cpp/special-ops.mlir b/tests/filecheck/lower-to-cpp/special-ops.mlir index 9c983a48..a0680584 100644 --- a/tests/filecheck/lower-to-cpp/special-ops.mlir +++ b/tests/filecheck/lower-to-cpp/special-ops.mlir @@ -6,7 +6,7 @@ %r = "transfer.select"(%c, %x, %y) : (i1, !transfer.integer, !transfer.integer) -> !transfer.integer "func.return"(%r) : (!transfer.integer) -> () }) {"sym_name" = "select_test", "function_type" = (i1, !transfer.integer, !transfer.integer) -> !transfer.integer} : () -> () - + "func.func"() ({ ^0(%lhs : !transfer.abs_value<[!transfer.integer, !transfer.integer]>, %rhs : !transfer.abs_value<[!transfer.integer, !transfer.integer]>): %lhs0 = "transfer.get"(%lhs) {index = 0} : (!transfer.abs_value<[!transfer.integer, !transfer.integer]>) -> !transfer.integer