Skip to content

Commit 3053bb4

Browse files
committed
Add global, address_of, array_ref, load, store, ref<T>
1 parent 1bd199c commit 3053bb4

File tree

5 files changed

+337
-128
lines changed

5 files changed

+337
-128
lines changed

asl_xdsl/dialects/asl.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
from typing import ClassVar, TypeAlias
66

77
from xdsl.dialects import builtin
8+
from xdsl.dialects.builtin import (
9+
IntAttr,
10+
StringAttr,
11+
SymbolRefAttr,
12+
TupleType,
13+
)
814
from xdsl.dialects.utils import parse_func_op_like, print_func_op_like
915
from xdsl.ir import (
1016
Attribute,
@@ -18,8 +24,11 @@
1824
VerifyException,
1925
)
2026
from xdsl.irdl import (
27+
AnyAttr,
2128
BaseAttr,
29+
GenericAttrConstraint,
2230
IRDLOperation,
31+
ParamAttrConstraint,
2332
ParameterDef,
2433
VarConstraint,
2534
irdl_attr_definition,
@@ -37,9 +46,11 @@
3746
from xdsl.printer import Printer
3847
from xdsl.traits import (
3948
CallableOpInterface,
49+
ConstantLike,
4050
HasParent,
4151
IsolatedFromAbove,
4252
IsTerminator,
53+
Pure,
4354
SymbolOpInterface,
4455
SymbolTable,
4556
SymbolUserOpInterface,
@@ -264,6 +275,24 @@ def print_parameters(self, printer: Printer) -> None:
264275
printer.print_string("x")
265276
printer.print_attribute(self.element_type)
266277

278+
@classmethod
279+
def constr(
280+
cls,
281+
element_type: IRDLAttrConstraint | None = None,
282+
*,
283+
shape: IRDLGenericAttrConstraint[builtin.ArrayAttr[IntAttr]] | None = None,
284+
) -> GenericAttrConstraint[ArrayType]:
285+
if element_type is None and shape is None:
286+
return BaseAttr[ArrayType](ArrayType)
287+
shape_constr = AnyAttr() if shape is None else shape
288+
return ParamAttrConstraint[ArrayType](
289+
ArrayType,
290+
(
291+
shape_constr,
292+
element_type,
293+
),
294+
)
295+
267296

268297
@irdl_attr_definition
269298
class BitVectorAttr(ParametrizedAttribute):
@@ -1558,6 +1587,163 @@ def __init__(
15581587
)
15591588

15601589

1590+
@irdl_attr_definition
1591+
class ReferenceType(ParametrizedAttribute, TypeAttribute):
1592+
"""
1593+
The type of a reference to an object
1594+
"""
1595+
1596+
name = "asl.ref"
1597+
type: ParameterDef[Attribute]
1598+
1599+
def print_parameters(self, printer: Printer) -> None:
1600+
# We need this to pretty print a tuple and its members if
1601+
# this is referencing one, otherwise just let the type
1602+
# handle its own printing
1603+
printer.print("<")
1604+
printer.print(self.type)
1605+
printer.print(">")
1606+
1607+
@classmethod
1608+
def parse_parameters(cls, parser: AttrParser) -> list[Attribute]:
1609+
# This is complicated by the fact we need to parse tuple
1610+
# here also as the buildin dialect does not support this
1611+
# yet
1612+
parser.parse_characters("<")
1613+
has_tuple = parser.parse_optional_keyword("tuple")
1614+
if has_tuple is None:
1615+
param_type = parser.parse_type()
1616+
parser.parse_characters(">")
1617+
return [param_type]
1618+
else:
1619+
# If its a tuple then there are any number of types
1620+
def parse_types():
1621+
return parser.parse_type()
1622+
1623+
param_types = parser.parse_comma_separated_list(
1624+
parser.Delimiter.ANGLE, parse_types
1625+
)
1626+
parser.parse_characters(">")
1627+
return [TupleType(param_types)]
1628+
1629+
@classmethod
1630+
def constr(
1631+
cls,
1632+
type: IRDLAttrConstraint | None = None,
1633+
) -> GenericAttrConstraint[ReferenceType]:
1634+
if type is None:
1635+
return BaseAttr[ReferenceType](ReferenceType)
1636+
return ParamAttrConstraint[ReferenceType](
1637+
ReferenceType,
1638+
(type,),
1639+
)
1640+
1641+
1642+
@irdl_op_definition
1643+
class GlobalOp(IRDLOperation):
1644+
name = "asl.global"
1645+
1646+
assembly_format = "$sym_name `:` $global_type attr-dict"
1647+
1648+
global_type = prop_def(Attribute)
1649+
sym_name = prop_def(StringAttr)
1650+
1651+
traits = traits_def(SymbolOpInterface())
1652+
1653+
def __init__(
1654+
self,
1655+
global_type: Attribute,
1656+
sym_name: str | StringAttr,
1657+
):
1658+
if isinstance(sym_name, str):
1659+
sym_name = StringAttr(sym_name)
1660+
1661+
props: dict[str, Attribute] = {
1662+
"global_type": global_type,
1663+
"sym_name": sym_name,
1664+
}
1665+
1666+
super().__init__(properties=props)
1667+
1668+
1669+
@irdl_op_definition
1670+
class AddressOfOp(IRDLOperation):
1671+
"""
1672+
Convert a global reference to an SSA-value to be
1673+
used in other operations.
1674+
1675+
%p = asl.address_of @symbol : !asl.ref<i1>
1676+
"""
1677+
1678+
name = "asl.address_of"
1679+
1680+
traits = traits_def(
1681+
ConstantLike(),
1682+
Pure(),
1683+
)
1684+
1685+
symbol = prop_def(SymbolRefAttr)
1686+
res = result_def()
1687+
1688+
assembly_format = "`(` $symbol `)` `:` type($res) attr-dict"
1689+
1690+
assembly_format = "$symbol `:` type($res) attr-dict"
1691+
1692+
1693+
@irdl_op_definition
1694+
class ArrayRefOp(IRDLOperation):
1695+
"""
1696+
Create a ref for an array element.
1697+
1698+
%element_ref = asl.array_ref %array_ref [ %index ]
1699+
: !asl.ref<!asl.array<16 x i8>> -> !asl.ref<i8>
1700+
"""
1701+
1702+
name = "asl.array_ref"
1703+
1704+
traits = traits_def(
1705+
Pure(),
1706+
)
1707+
1708+
T: ClassVar = VarConstraint("T", AnyAttr())
1709+
A: ClassVar = VarConstraint("A", ArrayType.constr(T))
1710+
ref = operand_def(ReferenceType.constr(A))
1711+
index = operand_def(IntegerType())
1712+
res = result_def(ReferenceType.constr(T))
1713+
1714+
assembly_format = "$ref `[` $index `]` `:` type($ref) `->` type($res) attr-dict"
1715+
1716+
1717+
@irdl_op_definition
1718+
class LoadOp(IRDLOperation):
1719+
"""
1720+
Load from a reference.
1721+
"""
1722+
1723+
name = "asl.load"
1724+
1725+
T: ClassVar = VarConstraint("T", AnyAttr())
1726+
ref = operand_def(ReferenceType.constr(T))
1727+
res = result_def(T)
1728+
1729+
assembly_format = "`from` $ref `:` type($res) attr-dict"
1730+
1731+
1732+
@irdl_op_definition
1733+
class StoreOp(IRDLOperation):
1734+
"""
1735+
Store from a reference.
1736+
"""
1737+
1738+
name = "asl.store"
1739+
1740+
T: ClassVar = VarConstraint("T", AnyAttr())
1741+
ref = operand_def(ReferenceType.constr(T))
1742+
value = operand_def(T)
1743+
1744+
assembly_format = "$value `to` $ref `:` type($value) attr-dict"
1745+
1746+
15611747
ASLDialect = Dialect(
15621748
"asl",
15631749
[
@@ -1622,12 +1808,19 @@ def __init__(
16221808
# Slices
16231809
GetSliceOp,
16241810
SetSliceOp,
1811+
# References
1812+
GlobalOp,
1813+
AddressOfOp,
1814+
ArrayRefOp,
1815+
LoadOp,
1816+
StoreOp,
16251817
],
16261818
[
16271819
IntegerType,
16281820
BitVectorType,
16291821
BitVectorAttr,
16301822
StringType,
16311823
ArrayType,
1824+
ReferenceType,
16321825
],
16331826
)

asl_xdsl/tools/asl_opt.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ def register_all_targets(self):
3434

3535
def interpret_target(module: ModuleOp, output: IO[str]):
3636
from xdsl.interpreter import Interpreter
37-
from xdsl.interpreters import arith, scf
37+
from xdsl.interpreters import arith, cf, func
3838

3939
from asl_xdsl.interpreters.asl import ASLFunctions
4040

4141
interpreter = Interpreter(module, file=output)
4242
interpreter.register_implementations(ASLFunctions())
4343
interpreter.register_implementations(arith.ArithFunctions())
44-
interpreter.register_implementations(scf.ScfFunctions())
44+
interpreter.register_implementations(cf.CfFunctions())
45+
interpreter.register_implementations(func.FuncFunctions())
4546
op = interpreter.get_op_for_symbol("main.0")
4647
trait = op.get_trait(CallableOpInterface)
4748
assert trait is not None

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "asl-xdsl"
33
version = "0.0.0"
4-
dependencies = ["xdsl==0.40.0"]
4+
dependencies = ["xdsl==0.43.0"]
55
requires-python = ">=3.10"
66

77
[project.optional-dependencies]

tests/filecheck/dialects/asl/primitives.mlir

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,4 +115,21 @@ builtin.module {
115115
asl.print_sintN_dec %sint1 : i8
116116
// CHECK: asl.print_sintN_hex %sint1 : i8 -> ()
117117
// CHECK-NEXT: asl.print_sintN_dec %sint1 : i8 -> ()
118+
119+
asl.global "G" : !asl.bits<32>
120+
asl.global "A" : !asl.array<16 x !asl.bits<32>>
121+
122+
%gref = asl.address_of @G : !asl.ref<!asl.bits<32>>
123+
%aref = asl.address_of @A : !asl.ref<!asl.array<16 x !asl.bits<32>>>
124+
%eref = asl.array_ref %aref[%int1] : !asl.ref<!asl.array<16 x !asl.bits<32>>> -> !asl.ref<!asl.bits<32>>
125+
asl.store %bits1 to %gref : !asl.bits<32>
126+
%load = asl.load from %eref : !asl.bits<32>
127+
// CHECK: asl.global "G" : !asl.bits<32>
128+
// CHECK-NEXT: asl.global "A" : !asl.array<16x!asl.bits<32>>
129+
// CHECK-NEXT: %gref = asl.address_of @G : !asl.ref<!asl.bits<32>>
130+
// CHECK-NEXT: %aref = asl.address_of @A : !asl.ref<!asl.array<16x!asl.bits<32>>>
131+
// CHECK-NEXT: %eref = asl.array_ref %aref[%int1] : !asl.ref<!asl.array<16x!asl.bits<32>>> -> !asl.ref<!asl.bits<32>>
132+
// CHECK-NEXT: asl.store %bits1 to %gref : !asl.bits<32>
133+
// CHECK-NEXT: %load = asl.load from %eref : !asl.bits<32>
134+
118135
}

0 commit comments

Comments
 (0)