From f0c02acbc4a6c3ce566d3131ad9b1bd2f3638910 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 28 Jan 2025 23:15:04 -0500 Subject: [PATCH] add some more helpers --- .../eudsl-llvmpy/eudsl-llvmpy-generate.py | 17 +++--- .../eudsl-tblgen/src/eudsl_tblgen/__init__.py | 52 ++++++++++++++++++- .../eudsl-tblgen/src/eudsl_tblgen_ext.cpp | 15 ++++-- .../tests/td/CommonTypeConstraints.td | 4 ++ projects/eudsl-tblgen/tests/test_bindings.py | 24 +++++++-- 5 files changed, 96 insertions(+), 16 deletions(-) diff --git a/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py b/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py index 8921c234..f99b3312 100644 --- a/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py +++ b/projects/eudsl-llvmpy/eudsl-llvmpy-generate.py @@ -287,14 +287,15 @@ class LLVMMatchType(Generic[_T]): ): arg_types = [] ret_types = [] - for p in intr.get_values().ParamTypes.value: - p_s = p.as_string + for p in intr.get_values().ParamTypes.get_value(): + p_s = p.get_as_string() if p_s.startswith("anon"): - p_s = p.type.as_string + p_s = p.get_type().get_as_string() + pdv = p.get_def().get_values() if p_s == "LLVMMatchType": - p_s += f"[{p.def_.values.Number.value.value}]" + p_s += f"[{pdv.Number.get_value()}]" elif p_s == "LLVMQualPointerType": - _, addr_space = p.def_.values.Sig.value.values + kind, addr_space = pdv.Sig.get_value() p_s += f"[{addr_space}]" else: raise NotImplemented(f"unsupported {p_s=}") @@ -307,8 +308,8 @@ class LLVMMatchType(Generic[_T]): p_s = "pointer" arg_types.append(p_s) - for p in intr.get_values().RetTypes.value: - ret_types.append(p.as_string) + for p in intr.get_values().RetTypes.get_value(): + ret_types.append(p.get_as_string()) ret_str = "" if len(ret_types): @@ -383,5 +384,5 @@ def generate_nb_bindings(header_root: Path, output_root: Path): parser.add_argument("llvmpy_module_dir", type=Path) args = parser.parse_args() - generate_nb_bindings(args.llvm_include_root / "llvm-c", args.output_root) + # generate_nb_bindings(args.llvm_include_root / "llvm-c", args.output_root) generate_amdgcn_intrinsics(args.llvm_include_root, args.llvmpy_module_dir) diff --git a/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py b/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py index bc7580fe..c8aa205c 100644 --- a/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py +++ b/projects/eudsl-tblgen/src/eudsl_tblgen/__init__.py @@ -2,6 +2,7 @@ # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Copyright (c) 2024. +from typing import List, Optional from .eudsl_tblgen_ext import * @@ -39,7 +40,56 @@ def get_requested_op_definitions(records, op_inc_filter=None, op_exc_filter=None # Unless there is an exclude filter and it matches. if op_exc_filter and exclude_regex.match(get_operation_name(def_record)): continue - def_record.dump() defs.append(def_record) return defs + + +def collect_all_defs( + record_keeper: RecordKeeper, + selected_dialect: Optional[str] = None, +) -> List[AttrOrTypeDef]: + records = record_keeper.get_defs() + records = [records[d] for d in records] + # Nothing to do if no defs were found. + if not records: + return [] + + defs = [ + AttrOrTypeDef(rec) + for rec in records + if rec.get_value("builders") and rec.get_value("parameters") + ] + result_defs = [] + + if not selected_dialect: + # If a dialect was not specified, ensure that all found defs belong to the same dialect. + dialects = {definition.get_dialect().get_name() for definition in defs} + if len(dialects) > 1: + raise RuntimeError( + "Defs belong to more than one dialect. Must select one via '--(attr|type)defs-dialect'" + ) + result_defs.extend(defs) + else: + # Otherwise, generate the defs that belong to the selected dialect. + dialect_defs = [ + definition + for definition in defs + if definition.get_dialect().get_name() == selected_dialect + ] + result_defs.extend(dialect_defs) + + return result_defs + + +def get_all_type_constraints(records: RecordKeeper) -> List[Constraint]: + result = [] + for record in records.get_all_derived_definitions_if_defined("TypeConstraint"): + # Ignore constraints defined outside of the top-level file. + constr = Constraint(record) + # Generate C++ function only if "cppFunctionName" is set. + if not constr.get_cpp_function_name(): + continue + result.append(constr) + return result + diff --git a/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp b/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp index 06cd88b8..9e5354d9 100644 --- a/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp +++ b/projects/eudsl-tblgen/src/eudsl_tblgen_ext.cpp @@ -918,7 +918,8 @@ NB_MODULE(eudsl_tblgen_ext, m) { nb::rv_policy::reference_internal) .def("get_name_init_as_string", &llvm::Record::getNameInitAsString) .def("set_name", &llvm::Record::setName, "name"_a) - .def("get_loc", &llvm::Record::getLoc) + .def("get_loc", eudsl::coerceReturn>( + &llvm::Record::getLoc, nb::const_)) .def("append_loc", &llvm::Record::appendLoc, "loc"_a) .def("get_forward_declaration_locs", &llvm::Record::getForwardDeclarationLocs) @@ -1088,8 +1089,8 @@ NB_MODULE(eudsl_tblgen_ext, m) { const std::vector ¯oNames, bool noWarnOnUnusedTemplateArgs) { llvm::ErrorOr> fileOrErr = - llvm::MemoryBuffer::getFileOrSTDIN(inputFilename, - /*IsText=*/true); + llvm::MemoryBuffer::getFile(inputFilename, + /*IsText=*/true); if (std::error_code EC = fileOrErr.getError()) throw std::runtime_error("Could not open input file '" + inputFilename + "': " + EC.message() + @@ -1145,6 +1146,13 @@ NB_MODULE(eudsl_tblgen_ext, m) { -> std::vector { return self.getAllDerivedDefinitions(className); }, + "class_name"_a, nb::rv_policy::reference_internal) + .def( + "get_all_derived_definitions_if_defined", + [](llvm::RecordKeeper &self, const std::string &className) + -> std::vector { + return self.getAllDerivedDefinitionsIfDefined(className); + }, "class_name"_a, nb::rv_policy::reference_internal); nb::class_(m, "raw_ostream"); @@ -1239,6 +1247,7 @@ NB_MODULE(eudsl_tblgen_ext, m) { .def("get_kind", &mlir::tblgen::Constraint::getKind) .def("get_def", &mlir::tblgen::Constraint::getDef, nb::rv_policy::reference_internal); + nb::enum_(mlir_tblgen_Constraint, "Kind") .value("CK_Attr", mlir::tblgen::Constraint::CK_Attr) .value("CK_Region", mlir::tblgen::Constraint::CK_Region) diff --git a/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td b/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td index cd90e377..64546088 100644 --- a/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td +++ b/projects/eudsl-tblgen/tests/td/CommonTypeConstraints.td @@ -918,4 +918,8 @@ def SignlessIntegerOrFloatLike : TypeConstraint, "signless-integer-like or floating-point-like">; +def DummyConstraint : AnyTypeOf<[AnyInteger, Index, AnyFloat]> { + let cppFunctionName = "isValidDummy"; +} + #endif // COMMON_TYPE_CONSTRAINTS_TD diff --git a/projects/eudsl-tblgen/tests/test_bindings.py b/projects/eudsl-tblgen/tests/test_bindings.py index ace54101..7f3c688a 100644 --- a/projects/eudsl-tblgen/tests/test_bindings.py +++ b/projects/eudsl-tblgen/tests/test_bindings.py @@ -6,7 +6,12 @@ from pathlib import Path import pytest -from eudsl_tblgen import RecordKeeper, get_requested_op_definitions +from eudsl_tblgen import ( + RecordKeeper, + get_requested_op_definitions, + get_all_type_constraints, + collect_all_defs, +) @pytest.fixture(scope="function") @@ -174,7 +179,7 @@ def test_init_complex(record_keeper_test_dialect): assert ( repr(op.get_values()) - == "RecordValues(opDialect=Test_Dialect, opName=types, cppNamespace=test, summary=, description=, opDocGroup=?, arguments=(ins I32:$a, SI64:$b, UI8:$c, Index:$d, F32:$e, NoneType:$f, anonymous_347), results=(outs), regions=(region), successors=(successor), builders=?, skipDefaultBuilders=0, assemblyFormat=?, hasCustomAssemblyFormat=0, hasVerifier=0, hasRegionVerifier=0, hasCanonicalizer=0, hasCanonicalizeMethod=0, hasFolder=0, useCustomPropertiesEncoding=0, traits=[], extraClassDeclaration=?, extraClassDefinition=?)" + == "RecordValues(opDialect=Test_Dialect, opName=types, cppNamespace=test, summary=, description=, opDocGroup=?, arguments=(ins I32:$a, SI64:$b, UI8:$c, Index:$d, F32:$e, NoneType:$f, anonymous_348), results=(outs), regions=(region), successors=(successor), builders=?, skipDefaultBuilders=0, assemblyFormat=?, hasCustomAssemblyFormat=0, hasVerifier=0, hasRegionVerifier=0, hasCanonicalizer=0, hasCanonicalizeMethod=0, hasFolder=0, useCustomPropertiesEncoding=0, traits=[], extraClassDeclaration=?, extraClassDefinition=?)" ) arguments = op.get_values().arguments @@ -193,7 +198,7 @@ def test_init_complex(record_keeper_test_dialect): assert str(arguments.get_value()[5]) == "NoneType" attr = record_keeper_test_dialect.get_defs()["Test_TestAttr"] - assert str(attr.get_values().predicate) == "anonymous_334" + assert str(attr.get_values().predicate) == "anonymous_335" assert str(attr.get_values().storageType) == "test::TestAttr" assert str(attr.get_values().returnType) == "test::TestAttr" assert ( @@ -228,4 +233,15 @@ def test_init_complex(record_keeper_test_dialect): def test_mlir_tblgen(record_keeper_test_dialect): for op in get_requested_op_definitions(record_keeper_test_dialect): - op.dump() + print(op.get_name()) + for constraint in get_all_type_constraints(record_keeper_test_dialect): + print(constraint.get_def_name()) + print(constraint.get_summary()) + + all_defs = collect_all_defs(record_keeper_test_dialect) + for d in all_defs: + print(d.get_name()) + + all_defs = collect_all_defs(record_keeper_test_dialect, "test") + for d in all_defs: + print(d.get_name())