|
| 1 | +include("./Utils.jl") |
| 2 | + |
| 3 | +using MLIR |
| 4 | +using MLIR: IR, API, API.mlir_c |
| 5 | +using .Utils |
| 6 | + |
| 7 | +f(x, y) = 2*(x+y) |
| 8 | +ir, ret = @code_ircode f(2, 3) |
| 9 | + |
| 10 | + |
| 11 | +function registerAllDialects!(ctx) |
| 12 | + registry = MLIR.API.mlirDialectRegistryCreate() |
| 13 | + MLIR.API.mlirRegisterAllDialects(registry) |
| 14 | + handle = MLIR.API.mlirGetDialectHandle__jlir__() |
| 15 | + API.mlirDialectHandleInsertDialect(handle, registry) |
| 16 | + MLIR.API.mlirContextAppendDialectRegistry(ctx, registry) |
| 17 | + # MLIR.API.mlirDialectRegistryDestroy(registry) |
| 18 | + |
| 19 | + MLIR.API.mlirContextLoadAllAvailableDialects(ctx) |
| 20 | + return registry |
| 21 | +end |
| 22 | + |
| 23 | +ctx = API.mlirContextCreate() |
| 24 | +registry = registerAllDialects!(ctx) |
| 25 | + |
| 26 | +state = Ref(API.mlirOperationStateGet("func.func", API.mlirLocationUnknownGet(ctx))) |
| 27 | + |
| 28 | +function API.MlirType(ctx::API.MlirContext, t) |
| 29 | + return @ccall mlir_c.brutus_get_jlirtype(ctx::API.MlirContext, t::Any)::API.MlirType |
| 30 | +end |
| 31 | + |
| 32 | +argtypes = let |
| 33 | + argtypes = getfield(ir, :argtypes) |
| 34 | + API.MlirType.(Ref(ctx), argtypes) |
| 35 | +end |
| 36 | + |
| 37 | +reg = API.mlirRegionCreate() |
| 38 | +entry_block = API.mlirBlockCreate(length(argtypes), argtypes, [API.mlirLocationUnknownGet(ctx) for _ in enumerate(argtypes)]) |
| 39 | + |
| 40 | +API.mlirRegionAppendOwnedBlock(reg, entry_block) |
| 41 | +API.mlirOperationStateAddOwnedRegions(state, 1, [reg]) |
| 42 | + |
| 43 | +push!(block::API.MlirBlock, type::API.MlirType, loc::API.MlirLocation) = |
| 44 | + API.mlirBlockAddArgument(block, type, loc) |
| 45 | + |
| 46 | +input_types = API.mlirBlockGetArgument.(Ref(entry_block), eachindex(argtypes) .- 1) |
| 47 | + |
| 48 | +API.mlirBlockGetNumArguments(entry_block) |
| 49 | + |
| 50 | +push!(block::API.MlirBlock, op::API.MlirOperation) = |
| 51 | + API.mlirBlockAppendOwnedOperation(block, op) |
| 52 | + |
| 53 | +add_op = IR.create_operation("jlir.add_int", API.mlirLocationUnknownGet(ctx); operands=[API.mlirBlockGetArgument(entry_block, 1), API.mlirBlockGetArgument(entry_block, 2)]) # "jlir.add_int"(<<UNKNOWN SSA VALUE>>, <<UNKNOWN SSA VALUE>>) : (!jlir.Int64, !jlir.Int64) -> !jlir.Int64 |
| 54 | + |
| 55 | +named_val_attr = let |
| 56 | + val_attr = @ccall mlir_c.brutus_get_jlirattr(ctx::API.MlirContext, 2::Any)::API.MlirAttribute |
| 57 | + API.mlirNamedAttributeGet(API.mlirIdentifierGet(ctx, "value"), val_attr) |
| 58 | +end |
| 59 | + |
| 60 | +constant_op = IR.create_operation("jlir.constant", API.mlirLocationUnknownGet(ctx); attributes=[named_val_attr], results=[API.MlirType(ctx, Int)]) # "jlir.constant"() {value = #jlir<2>} : () -> !jlir.Int64 |
| 61 | + |
| 62 | +mul_op = IR.create_operation("jlir.mul_int", API.mlirLocationUnknownGet(ctx); operands=[API.mlirOperationGetResult(constant_op, 0), API.mlirOperationGetResult(add_op, 0)]) |
| 63 | + |
| 64 | +ret_op = IR.create_operation("func.return", API.mlirLocationUnknownGet(ctx); operands=[API.mlirBlockGetArgument(entry_block, 2)], result_inference = false) |
| 65 | + |
| 66 | +push!(entry_block, add_op) |
| 67 | +push!(entry_block, constant_op) |
| 68 | +push!(entry_block, mul_op) |
| 69 | +push!(entry_block, ret_op) |
| 70 | + |
| 71 | +named_type_attr = let |
| 72 | + function_type = API.mlirFunctionTypeGet( |
| 73 | + ctx, |
| 74 | + length(input_types), API.mlirValueGetType.(input_types), |
| 75 | + 1, [API.MlirType(ctx, ret)]) |
| 76 | + |
| 77 | + type_attr = API.mlirTypeAttrGet(function_type) |
| 78 | + |
| 79 | + API.mlirNamedAttributeGet(API.mlirIdentifierGet(ctx, "function_type"), type_attr) |
| 80 | +end |
| 81 | + |
| 82 | +named_symbol_name_attr = let |
| 83 | + name = "f" |
| 84 | + |
| 85 | + symbol_name_attr = API.mlirStringAttrGet(ctx, name) |
| 86 | + |
| 87 | + API.mlirNamedAttributeGet(API.mlirIdentifierGet(ctx, "sym_name"), symbol_name_attr) |
| 88 | +end |
| 89 | + |
| 90 | +named_viz_attr = let |
| 91 | + viz_attr = API.mlirStringAttrGet(ctx, "nested") |
| 92 | + |
| 93 | + API.mlirNamedAttributeGet(API.mlirIdentifierGet(ctx, "sym_visibility"), viz_attr) |
| 94 | +end |
| 95 | + |
| 96 | +named_unit_attr = let |
| 97 | + unit_attr = API.mlirUnitAttrGet(ctx) |
| 98 | + |
| 99 | + API.mlirNamedAttributeGet(API.mlirIdentifierGet(ctx, "llvm.emit_c_interface"), unit_attr) |
| 100 | +end |
| 101 | + |
| 102 | +function push!(state::Base.RefValue{MLIR.API.MlirOperationState}, attr::IR.MlirNamedAttribute) |
| 103 | + API.mlirOperationStateAddAttributes(state, 1, Ref(attr)) |
| 104 | +end |
| 105 | + |
| 106 | +push!(state, named_type_attr) |
| 107 | +push!(state, named_symbol_name_attr) |
| 108 | +push!(state, named_viz_attr) |
| 109 | +push!(state, named_unit_attr) |
| 110 | + |
| 111 | +op = API.mlirOperationCreate(state) |
| 112 | + |
| 113 | +API.mlirOperationVerify(op) |
| 114 | + |
| 115 | +function print_callback(str::API.MlirStringRef, userdata) |
| 116 | + data = unsafe_wrap(Array, Base.convert(Ptr{Cchar}, str.data), str.length; own=false) |
| 117 | + write(userdata isa Base.RefValue ? userdata[] : userdata, data) |
| 118 | + return Cvoid() |
| 119 | +end |
| 120 | + |
| 121 | +function Base.show(io::IO, operation::API.MlirOperation) |
| 122 | + c_print_callback = @cfunction(print_callback, Cvoid, (API.MlirStringRef, Any)) |
| 123 | + ref = Ref(io) |
| 124 | + flags = API.mlirOpPrintingFlagsCreate() |
| 125 | + get(io, :debug, false) && API.mlirOpPrintingFlagsEnableDebugInfo(flags, true, true) |
| 126 | + API.mlirOperationPrintWithFlags(operation, flags, c_print_callback, ref) |
| 127 | + println(io) |
| 128 | +end |
| 129 | + |
| 130 | +@show op |
0 commit comments