Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions xdsl_smt/semantics/transfer_semantics.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
from dataclasses import dataclass
from xdsl.pattern_rewriter import (
PatternRewriter,
)
from typing import Mapping, Sequence

from xdsl.pattern_rewriter import PatternRewriter
from xdsl_smt.dialects import smt_bitvector_dialect as smt_bv
from xdsl_smt.dialects import smt_dialect as smt
from xdsl_smt.dialects import transfer
from xdsl_smt.passes.lower_to_smt.smt_lowerer import (
SMTLowerer,
)
from xdsl_smt.passes.lower_to_smt.smt_lowerer import SMTLowerer
from xdsl_smt.dialects.smt_utils_dialect import (
AnyPairType,
PairType,
Expand All @@ -19,13 +16,10 @@
from xdsl_smt.dialects.smt_dialect import BoolType
from xdsl_smt.semantics.semantics import OperationSemantics, TypeSemantics
from xdsl.ir import Operation, SSAValue, Attribute
from typing import Mapping, Sequence
from xdsl.utils.hints import isa
from xdsl.dialects.builtin import IntegerAttr, IntegerType
from xdsl_smt.utils.transfer_to_smt_util import (
get_low_bits,
set_high_bits,
set_low_bits,
count_lzeros,
count_rzeros,
count_lones,
Expand All @@ -34,15 +28,21 @@
is_non_negative,
is_negative,
get_high_bits,
clear_high_bits,
clear_low_bits,
)


class AbstractValueTypeSemantics(TypeSemantics):
"""Lower all types in an abstract value to SMT types
But the last element is useless, this makes GetOp easier"""

def lower_type(self, ty: Attribute) -> Attribute:
"""
If the input type is already a smt type, skip lowering
"""
if ty.name.startswith("smt"):
return ty
return SMTLowerer.lower_type(ty)

def get_semantics(self, type: Attribute) -> Attribute:
assert isinstance(type, transfer.AbstractValueType) or isinstance(
type, transfer.TupleType
Expand Down Expand Up @@ -266,6 +266,7 @@ def get_semantics(
bv_res, ops = smt_bool_to_bv1(umul_overflow.res)

poison_op = smt.ConstantBoolOp(False)

res = PairOp(bv_res, poison_op.result)
rewriter.insert_op_before_matched_op([umul_overflow] + ops + [poison_op, res])
return ((res.res,), effect_state)
Expand Down Expand Up @@ -639,9 +640,27 @@ def get_semantics(
effect_state: SSAValue | None,
rewriter: PatternRewriter,
) -> tuple[Sequence[SSAValue], SSAValue | None]:
result = set_high_bits(operands[0], operands[1])
rewriter.insert_op_before_matched_op(result)
return ((result[-1].results[0],), effect_state)
arg = operands[0]
count = operands[1]
assert isinstance(bv_type := arg.type, smt_bv.BitVectorType)

const_bw = smt_bv.ConstantOp(bv_type.width, bv_type.width)
const_one = smt_bv.ConstantOp(1, bv_type.width)

umin = smt_bv.UltOp(count, const_bw.res)
clamped_count = smt.IteOp(umin.res, count, const_bw.res)

sub = smt_bv.SubOp(const_bw.res, clamped_count.res)
shl = smt_bv.ShlOp(const_one.res, clamped_count.res)
sub2 = smt_bv.SubOp(shl.res, const_one.res)
shl2 = smt_bv.ShlOp(sub2.res, sub.res)
or_op = smt_bv.OrOp(arg, shl2.res)

rewriter.insert_op_before_matched_op(
[const_bw, const_one, umin, clamped_count, sub, shl, sub2, shl2, or_op]
)

return ((or_op.res,), effect_state)


class SetLowBitsOpSemantics(OperationSemantics):
Expand All @@ -653,9 +672,19 @@ def get_semantics(
effect_state: SSAValue | None,
rewriter: PatternRewriter,
) -> tuple[Sequence[SSAValue], SSAValue | None]:
result = set_low_bits(operands[0], operands[1])
rewriter.insert_op_before_matched_op(result)
return ((result[-1].results[0],), effect_state)
arg = operands[0]
count = operands[1]
assert isinstance(bv_type := arg.type, smt_bv.BitVectorType)

const_one = smt_bv.ConstantOp(1, bv_type.width)

shl = smt_bv.ShlOp(const_one.res, count)
sub = smt_bv.SubOp(shl.res, const_one.res)
or_op = smt_bv.OrOp(arg, sub.res)

rewriter.insert_op_before_matched_op([const_one, shl, sub, or_op])

return ((or_op.res,), effect_state)


class SetSignBitOpSemantics(OperationSemantics):
Expand Down Expand Up @@ -692,7 +721,7 @@ def get_semantics(
operand_type = operand.type
assert isinstance(operand_type, smt_bv.BitVectorType)
width = operand_type.width.data
signed_max_value = smt_bv.ConstantOp(1 << (width - 1) - 1, width)
signed_max_value = smt_bv.ConstantOp((1 << (width - 1)) - 1, width)
and_op = smt_bv.AndOp(signed_max_value.res, operand)
result = [signed_max_value, and_op]

Expand Down Expand Up @@ -737,9 +766,27 @@ def get_semantics(
effect_state: SSAValue | None,
rewriter: PatternRewriter,
) -> tuple[Sequence[SSAValue], SSAValue | None]:
result = clear_high_bits(operands[0], operands[1])
rewriter.insert_op_before_matched_op(result)
return ((result[-1].results[0],), effect_state)
arg = operands[0]
count = operands[1]
assert isinstance(bv_type := arg.type, smt_bv.BitVectorType)

const_bw = smt_bv.ConstantOp(bv_type.width, bv_type.width)
one = smt_bv.ConstantOp(1, bv_type.width)

umin = smt_bv.UltOp(count, const_bw.res)
new_count = smt.IteOp(umin.res, count, const_bw.res)

# mask = (1 << (width - count)) - 1
sub = smt_bv.SubOp(const_bw.res, new_count.res)
shl = smt_bv.ShlOp(one.res, sub.res)
mask = smt_bv.SubOp(shl.res, one.res)
masked = smt_bv.AndOp(arg, mask.res)

rewriter.insert_op_before_matched_op(
[const_bw, one, umin, new_count, sub, shl, mask, masked]
)

return ((masked.res,), effect_state)


class ClearLowBitsOpSemantics(OperationSemantics):
Expand All @@ -751,9 +798,21 @@ def get_semantics(
effect_state: SSAValue | None,
rewriter: PatternRewriter,
) -> tuple[Sequence[SSAValue], SSAValue | None]:
result = clear_low_bits(operands[0], operands[1])
rewriter.insert_op_before_matched_op(result)
return ((result[-1].results[0],), effect_state)
arg = operands[0]
count = operands[1]
assert isinstance(bv_type := arg.type, smt_bv.BitVectorType)

const_one = smt_bv.ConstantOp(1, bv_type.width)

# mask = ~((1 << count) - 1)
shl = smt_bv.ShlOp(const_one.res, count)
sub = smt_bv.SubOp(shl.res, const_one.res)
not_mask = smt_bv.NotOp(sub.res)
masked = smt_bv.AndOp(arg, not_mask.res)

rewriter.insert_op_before_matched_op([const_one, shl, sub, not_mask, masked])

return ((masked.res,), effect_state)


class SMinOpSemantics(OperationSemantics):
Expand Down
38 changes: 0 additions & 38 deletions xdsl_smt/utils/transfer_to_smt_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,6 @@ def get_high_bits_constant(high_bits: SSAValue) -> list[Operation]:
return result + get_bits_constant(result[-1].results[0], width_constant)


def clear_low_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]:
"""
clear_low_bits(x, low_bits) -> x & ~(get_low_bits_constant(low_bits))
"""
result = get_low_bits_constant(low_bits)
result.append(smt_bv.NotOp(result[-1].results[0]))
result.append(smt_bv.AndOp(result[-1].results[0], b))
return result


def clear_high_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]:
"""
clear_high_bits(x, high_bits) -> x & ~(get_high_bits_constant(high_bits))
"""
result = get_high_bits_constant(low_bits)
result.append(smt_bv.NotOp(result[-1].results[0]))
result.append(smt_bv.AndOp(result[-1].results[0], b))
return result


def get_low_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]:
"""
get_low_bits(x, low_bits) -> x & (get_low_bits_constant(low_bits))
Expand All @@ -111,24 +91,6 @@ def get_high_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]:
return result


def set_high_bits(b: SSAValue, high_bits: SSAValue) -> list[Operation]:
"""
set_high_bits(x, high_bits) -> x | (get_high_bits_constant(high_bits))
"""
result = get_high_bits_constant(high_bits)
result.append(smt_bv.OrOp(result[-1].results[0], b))
return result


def set_low_bits(b: SSAValue, low_bits: SSAValue) -> list[Operation]:
"""
set_low_bits(x, low_bits) -> x | (get_low_bits_constant(low_bits))
"""
result = get_low_bits_constant(low_bits)
result.append(smt_bv.OrOp(result[-1].results[0], b))
return result


def count_ones(b: SSAValue) -> list[Operation]:
assert isinstance(b.type, smt_bv.BitVectorType)
n = b.type.width.data
Expand Down
Loading