diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index fc1ce453e7..5c0f0a85b3 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -2404,3 +2404,7 @@ extern "C" void dump_operation(Operation *op, const char *filename) { extern "C" bool pjrt_device_is_addressable(PjRtDevice *device) { return device->IsAddressable(); } + +extern "C" mlir::Operation *mlirGetParentOfTypeFunctionOp(mlir::Operation *op) { + return op->getParentOfType(); +} diff --git a/src/Ops.jl b/src/Ops.jl index 7c4ebbd2c6..0f13fa097a 100644 --- a/src/Ops.jl +++ b/src/Ops.jl @@ -126,6 +126,11 @@ end result_inference=false, ) + parent_func_op = MLIR.IR.get_parent_of_type_function_op(cstop) + if parent_func_op == C_NULL + error("Constant must be created inside a Function Op.") + end + res = MLIR.IR.result(cstop) tres = TracedRArray{T,N}((), res, size(x)) constants[value] = tres @@ -201,6 +206,12 @@ for (T, mlir_func) in ( splatattr = MLIR.API.$mlir_func(tt, number) cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) + + parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op) + if parent_func_op == C_NULL + error("Constant must be created inside a Function Op.") + end + cst = MLIR.IR.result(cst_op) ta = TracedRArray{$T,length(shape)}((), cst, shape) return ta @@ -221,6 +232,12 @@ end tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T)) splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element)) cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location) + + parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op) + if parent_func_op == C_NULL + error("Constant must be created inside a Function Op.") + end + cst = MLIR.IR.result(cst_op) ta = TracedRArray{T,length(shape)}((), cst, shape) return ta diff --git a/src/mlir/IR/Operation.jl b/src/mlir/IR/Operation.jl index 32f42b6838..c1768b44a7 100644 --- a/src/mlir/IR/Operation.jl +++ b/src/mlir/IR/Operation.jl @@ -68,6 +68,12 @@ Gets the operation that owns this operation, returning null if the operation is parent_op(operation::Operation) = Operation(API.mlirOperationGetParentOperation(operation), false) +""" + parent_region(op) +Gets the region that owns this operation. +""" +parent_region(operation::Operation) = parent_region(block(operation)) + """ rmfromparent!(op) @@ -331,8 +337,21 @@ function create_operation_common( end end +function create_operation_common_with_checks(args...; operands=nothing, kwargs...) + op = create_operation_common(args...; operands, kwargs...) + if !isnothing(operands) + parent_function_op = get_parent_of_type_function_op(op) + if parent_function_op != C_NULL + function_op_region = parent_region(parent_function_op) + operand_region = parent_region.(operands) + # TODO: add the checks + end + end + return op +end + function create_operation(args...; kwargs...) - res = create_operation_common(args...; kwargs...) + res = create_operation_common_with_checks(args...; kwargs...) if _has_block() push!(block(), res) end @@ -340,7 +359,17 @@ function create_operation(args...; kwargs...) end function create_operation_at_front(args...; kwargs...) - res = create_operation_common(args...; kwargs...) + res = create_operation_common_with_checks(args...; kwargs...) Base.pushfirst!(block(), res) return res end + +function get_parent_of_type_function_op(op::Operation) + GC.@preserve op begin + funcop = @ccall API.mlir_c.mlirGetParentOfTypeFunctionOp( + op::API.MlirOperation + )::API.MlirOperation + end + funcop.ptr == C_NULL && return C_NULL + return Operation(funcop, false) +end diff --git a/src/mlir/IR/Value.jl b/src/mlir/IR/Value.jl index a24632d934..38c877f763 100644 --- a/src/mlir/IR/Value.jl +++ b/src/mlir/IR/Value.jl @@ -121,3 +121,5 @@ function Base.show(io::IO, value::Value) API.mlirValuePrint(value, c_print_callback, ref) end end + +parent_region(value::Value) = parent_region(owner(value))