|
5 | 5 | from typing import ClassVar, TypeAlias |
6 | 6 |
|
7 | 7 | from xdsl.dialects import builtin |
| 8 | +from xdsl.dialects.builtin import ( |
| 9 | + IntAttr, |
| 10 | + StringAttr, |
| 11 | + SymbolRefAttr, |
| 12 | + TupleType, |
| 13 | +) |
8 | 14 | from xdsl.dialects.utils import parse_func_op_like, print_func_op_like |
9 | 15 | from xdsl.ir import ( |
10 | 16 | Attribute, |
|
18 | 24 | VerifyException, |
19 | 25 | ) |
20 | 26 | from xdsl.irdl import ( |
| 27 | + AnyAttr, |
21 | 28 | BaseAttr, |
| 29 | + GenericAttrConstraint, |
22 | 30 | IRDLOperation, |
| 31 | + ParamAttrConstraint, |
23 | 32 | ParameterDef, |
24 | 33 | VarConstraint, |
25 | 34 | irdl_attr_definition, |
|
37 | 46 | from xdsl.printer import Printer |
38 | 47 | from xdsl.traits import ( |
39 | 48 | CallableOpInterface, |
| 49 | + ConstantLike, |
40 | 50 | HasParent, |
41 | 51 | IsolatedFromAbove, |
42 | 52 | IsTerminator, |
| 53 | + Pure, |
43 | 54 | SymbolOpInterface, |
44 | 55 | SymbolTable, |
45 | 56 | SymbolUserOpInterface, |
@@ -264,6 +275,24 @@ def print_parameters(self, printer: Printer) -> None: |
264 | 275 | printer.print_string("x") |
265 | 276 | printer.print_attribute(self.element_type) |
266 | 277 |
|
| 278 | + @classmethod |
| 279 | + def constr( |
| 280 | + cls, |
| 281 | + element_type: IRDLAttrConstraint | None = None, |
| 282 | + *, |
| 283 | + shape: IRDLGenericAttrConstraint[builtin.ArrayAttr[IntAttr]] | None = None, |
| 284 | + ) -> GenericAttrConstraint[ArrayType]: |
| 285 | + if element_type is None and shape is None: |
| 286 | + return BaseAttr[ArrayType](ArrayType) |
| 287 | + shape_constr = AnyAttr() if shape is None else shape |
| 288 | + return ParamAttrConstraint[ArrayType]( |
| 289 | + ArrayType, |
| 290 | + ( |
| 291 | + shape_constr, |
| 292 | + element_type, |
| 293 | + ), |
| 294 | + ) |
| 295 | + |
267 | 296 |
|
268 | 297 | @irdl_attr_definition |
269 | 298 | class BitVectorAttr(ParametrizedAttribute): |
@@ -1558,6 +1587,163 @@ def __init__( |
1558 | 1587 | ) |
1559 | 1588 |
|
1560 | 1589 |
|
| 1590 | +@irdl_attr_definition |
| 1591 | +class ReferenceType(ParametrizedAttribute, TypeAttribute): |
| 1592 | + """ |
| 1593 | + The type of a reference to an object |
| 1594 | + """ |
| 1595 | + |
| 1596 | + name = "asl.ref" |
| 1597 | + type: ParameterDef[Attribute] |
| 1598 | + |
| 1599 | + def print_parameters(self, printer: Printer) -> None: |
| 1600 | + # We need this to pretty print a tuple and its members if |
| 1601 | + # this is referencing one, otherwise just let the type |
| 1602 | + # handle its own printing |
| 1603 | + printer.print("<") |
| 1604 | + printer.print(self.type) |
| 1605 | + printer.print(">") |
| 1606 | + |
| 1607 | + @classmethod |
| 1608 | + def parse_parameters(cls, parser: AttrParser) -> list[Attribute]: |
| 1609 | + # This is complicated by the fact we need to parse tuple |
| 1610 | + # here also as the buildin dialect does not support this |
| 1611 | + # yet |
| 1612 | + parser.parse_characters("<") |
| 1613 | + has_tuple = parser.parse_optional_keyword("tuple") |
| 1614 | + if has_tuple is None: |
| 1615 | + param_type = parser.parse_type() |
| 1616 | + parser.parse_characters(">") |
| 1617 | + return [param_type] |
| 1618 | + else: |
| 1619 | + # If its a tuple then there are any number of types |
| 1620 | + def parse_types(): |
| 1621 | + return parser.parse_type() |
| 1622 | + |
| 1623 | + param_types = parser.parse_comma_separated_list( |
| 1624 | + parser.Delimiter.ANGLE, parse_types |
| 1625 | + ) |
| 1626 | + parser.parse_characters(">") |
| 1627 | + return [TupleType(param_types)] |
| 1628 | + |
| 1629 | + @classmethod |
| 1630 | + def constr( |
| 1631 | + cls, |
| 1632 | + type: IRDLAttrConstraint | None = None, |
| 1633 | + ) -> GenericAttrConstraint[ReferenceType]: |
| 1634 | + if type is None: |
| 1635 | + return BaseAttr[ReferenceType](ReferenceType) |
| 1636 | + return ParamAttrConstraint[ReferenceType]( |
| 1637 | + ReferenceType, |
| 1638 | + (type,), |
| 1639 | + ) |
| 1640 | + |
| 1641 | + |
| 1642 | +@irdl_op_definition |
| 1643 | +class GlobalOp(IRDLOperation): |
| 1644 | + name = "asl.global" |
| 1645 | + |
| 1646 | + assembly_format = "$sym_name `:` $global_type attr-dict" |
| 1647 | + |
| 1648 | + global_type = prop_def(Attribute) |
| 1649 | + sym_name = prop_def(StringAttr) |
| 1650 | + |
| 1651 | + traits = traits_def(SymbolOpInterface()) |
| 1652 | + |
| 1653 | + def __init__( |
| 1654 | + self, |
| 1655 | + global_type: Attribute, |
| 1656 | + sym_name: str | StringAttr, |
| 1657 | + ): |
| 1658 | + if isinstance(sym_name, str): |
| 1659 | + sym_name = StringAttr(sym_name) |
| 1660 | + |
| 1661 | + props: dict[str, Attribute] = { |
| 1662 | + "global_type": global_type, |
| 1663 | + "sym_name": sym_name, |
| 1664 | + } |
| 1665 | + |
| 1666 | + super().__init__(properties=props) |
| 1667 | + |
| 1668 | + |
| 1669 | +@irdl_op_definition |
| 1670 | +class AddressOfOp(IRDLOperation): |
| 1671 | + """ |
| 1672 | + Convert a global reference to an SSA-value to be |
| 1673 | + used in other operations. |
| 1674 | +
|
| 1675 | + %p = asl.address_of @symbol : !asl.ref<i1> |
| 1676 | + """ |
| 1677 | + |
| 1678 | + name = "asl.address_of" |
| 1679 | + |
| 1680 | + traits = traits_def( |
| 1681 | + ConstantLike(), |
| 1682 | + Pure(), |
| 1683 | + ) |
| 1684 | + |
| 1685 | + symbol = prop_def(SymbolRefAttr) |
| 1686 | + res = result_def() |
| 1687 | + |
| 1688 | + assembly_format = "`(` $symbol `)` `:` type($res) attr-dict" |
| 1689 | + |
| 1690 | + assembly_format = "$symbol `:` type($res) attr-dict" |
| 1691 | + |
| 1692 | + |
| 1693 | +@irdl_op_definition |
| 1694 | +class ArrayRefOp(IRDLOperation): |
| 1695 | + """ |
| 1696 | + Create a ref for an array element. |
| 1697 | +
|
| 1698 | + %element_ref = asl.array_ref %array_ref [ %index ] |
| 1699 | + : !asl.ref<!asl.array<16 x i8>> -> !asl.ref<i8> |
| 1700 | + """ |
| 1701 | + |
| 1702 | + name = "asl.array_ref" |
| 1703 | + |
| 1704 | + traits = traits_def( |
| 1705 | + Pure(), |
| 1706 | + ) |
| 1707 | + |
| 1708 | + T: ClassVar = VarConstraint("T", AnyAttr()) |
| 1709 | + A: ClassVar = VarConstraint("A", ArrayType.constr(T)) |
| 1710 | + ref = operand_def(ReferenceType.constr(A)) |
| 1711 | + index = operand_def(IntegerType()) |
| 1712 | + res = result_def(ReferenceType.constr(T)) |
| 1713 | + |
| 1714 | + assembly_format = "$ref `[` $index `]` `:` type($ref) `->` type($res) attr-dict" |
| 1715 | + |
| 1716 | + |
| 1717 | +@irdl_op_definition |
| 1718 | +class LoadOp(IRDLOperation): |
| 1719 | + """ |
| 1720 | + Load from a reference. |
| 1721 | + """ |
| 1722 | + |
| 1723 | + name = "asl.load" |
| 1724 | + |
| 1725 | + T: ClassVar = VarConstraint("T", AnyAttr()) |
| 1726 | + ref = operand_def(ReferenceType.constr(T)) |
| 1727 | + res = result_def(T) |
| 1728 | + |
| 1729 | + assembly_format = "`from` $ref `:` type($res) attr-dict" |
| 1730 | + |
| 1731 | + |
| 1732 | +@irdl_op_definition |
| 1733 | +class StoreOp(IRDLOperation): |
| 1734 | + """ |
| 1735 | + Store from a reference. |
| 1736 | + """ |
| 1737 | + |
| 1738 | + name = "asl.store" |
| 1739 | + |
| 1740 | + T: ClassVar = VarConstraint("T", AnyAttr()) |
| 1741 | + ref = operand_def(ReferenceType.constr(T)) |
| 1742 | + value = operand_def(T) |
| 1743 | + |
| 1744 | + assembly_format = "$value `to` $ref `:` type($value) attr-dict" |
| 1745 | + |
| 1746 | + |
1561 | 1747 | ASLDialect = Dialect( |
1562 | 1748 | "asl", |
1563 | 1749 | [ |
@@ -1622,12 +1808,19 @@ def __init__( |
1622 | 1808 | # Slices |
1623 | 1809 | GetSliceOp, |
1624 | 1810 | SetSliceOp, |
| 1811 | + # References |
| 1812 | + GlobalOp, |
| 1813 | + AddressOfOp, |
| 1814 | + ArrayRefOp, |
| 1815 | + LoadOp, |
| 1816 | + StoreOp, |
1625 | 1817 | ], |
1626 | 1818 | [ |
1627 | 1819 | IntegerType, |
1628 | 1820 | BitVectorType, |
1629 | 1821 | BitVectorAttr, |
1630 | 1822 | StringType, |
1631 | 1823 | ArrayType, |
| 1824 | + ReferenceType, |
1632 | 1825 | ], |
1633 | 1826 | ) |
0 commit comments