Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: early fail if not correct region #1008

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions deps/ReactantExtra/API.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2404,3 +2404,7 @@ extern "C" void dump_operation(Operation *op, const char *filename) {
extern "C" bool pjrt_device_is_addressable(PjRtDevice *device) {
return device->IsAddressable();
}

extern "C" mlir::Operation *mlirGetParentOfTypeFunctionOp(mlir::Operation *op) {
return op->getParentOfType<mlir::FunctionOpInterface>();
}
17 changes: 17 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ end
result_inference=false,
)

parent_func_op = MLIR.IR.get_parent_of_type_function_op(cstop)
if parent_func_op == C_NULL
error("Constant must be created inside a Function Op.")
end

res = MLIR.IR.result(cstop)
tres = TracedRArray{T,N}((), res, size(x))
constants[value] = tres
Expand Down Expand Up @@ -201,6 +206,12 @@ for (T, mlir_func) in (

splatattr = MLIR.API.$mlir_func(tt, number)
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)

parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
if parent_func_op == C_NULL
error("Constant must be created inside a Function Op.")
end

cst = MLIR.IR.result(cst_op)
ta = TracedRArray{$T,length(shape)}((), cst, shape)
return ta
Expand All @@ -221,6 +232,12 @@ end
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)

parent_func_op = MLIR.IR.get_parent_of_type_function_op(cst_op)
if parent_func_op == C_NULL
error("Constant must be created inside a Function Op.")
end

cst = MLIR.IR.result(cst_op)
ta = TracedRArray{T,length(shape)}((), cst, shape)
return ta
Expand Down
33 changes: 31 additions & 2 deletions src/mlir/IR/Operation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,12 @@ Gets the operation that owns this operation, returning null if the operation is
parent_op(operation::Operation) =
Operation(API.mlirOperationGetParentOperation(operation), false)

"""
parent_region(op)
Gets the region that owns this operation.
"""
parent_region(operation::Operation) = parent_region(block(operation))

"""
rmfromparent!(op)

Expand Down Expand Up @@ -331,16 +337,39 @@ function create_operation_common(
end
end

function create_operation_common_with_checks(args...; operands=nothing, kwargs...)
op = create_operation_common(args...; operands, kwargs...)
if !isnothing(operands)
parent_function_op = get_parent_of_type_function_op(op)
if parent_function_op != C_NULL
function_op_region = parent_region(parent_function_op)
operand_region = parent_region.(operands)
# TODO: add the checks
end
end
return op
end

function create_operation(args...; kwargs...)
res = create_operation_common(args...; kwargs...)
res = create_operation_common_with_checks(args...; kwargs...)
if _has_block()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

technically speaking the actual issue here is actually on like 349, where we're accidentally creating a stablehlo.constant operation, and pushing it into a module. It's probably better to figure out we're creating it at the wrong place, rather than figure out during the use stage that we're using an operation which is itself in the wrong place

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We had an older issue where a concreteRNG state wasn't being reset, so the region was completely different. We can't capture that without adding checks here

push!(block(), res)
end
return res
end

function create_operation_at_front(args...; kwargs...)
res = create_operation_common(args...; kwargs...)
res = create_operation_common_with_checks(args...; kwargs...)
Base.pushfirst!(block(), res)
return res
end

function get_parent_of_type_function_op(op::Operation)
GC.@preserve op begin
funcop = @ccall API.mlir_c.mlirGetParentOfTypeFunctionOp(
op::API.MlirOperation
)::API.MlirOperation
end
funcop.ptr == C_NULL && return C_NULL
return Operation(funcop, false)
end
2 changes: 2 additions & 0 deletions src/mlir/IR/Value.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,3 +121,5 @@ function Base.show(io::IO, value::Value)
API.mlirValuePrint(value, c_print_callback, ref)
end
end

parent_region(value::Value) = parent_region(owner(value))
Loading