Skip to content

Fine-grained number tracing #1051

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 22 additions & 14 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,19 +129,23 @@ end
```
"""
macro trace(args...)
track_numbers = true
expr = first(args)
if length(args) > 1 && Meta.isexpr(args[1], :(=))
tn_expr = args[1]
tn_expr.args[1] == :track_numbers ||
error("@trace supports setting track_numbers, but got $(tn_expr)")

track_numbers = tn_expr.args[2]
expr = only(args[2:end])
else
expr = only(args)
options = Dict([:track_numbers => false, :include_paths => :([])])
while length(args) > 1
kwarg, args = first(args), args[2:end]
if !Meta.isexpr(kwarg, :(=))
error("Expected keyword argument but got $(kwarg)")
end
option, value = kwarg.args
if !haskey(options, option)
error("Unknown keyword argument $(option), expected one of $(keys(options))")
else
options[option] = value
end
end
track_numbers = track_numbers ? Number : Union{}
expr = only(args)
track_numbers = options[:track_numbers] ? Number : Union{}
include_paths_expr = options[:include_paths]
expr = macroexpand(__module__, expr)

if Meta.isexpr(expr, :(=))
Expand All @@ -157,11 +161,12 @@ macro trace(args...)
return esc(trace_call(__module__, call))
end
Meta.isexpr(expr, :if) && return esc(trace_if(__module__, expr; track_numbers))
Meta.isexpr(expr, :for) && return (esc(trace_for(__module__, expr; track_numbers)))
Meta.isexpr(expr, :for) &&
return (esc(trace_for(__module__, expr; track_numbers, include_paths_expr)))
return error("Only `if-elseif-else` blocks are currently supported by `@trace`")
end

function trace_for(mod, expr; track_numbers)
function trace_for(mod, expr; track_numbers, include_paths_expr)
Meta.isexpr(expr, :for, 2) || error("expected for expr")
assign, body = expr.args

Expand Down Expand Up @@ -216,6 +221,8 @@ function trace_for(mod, expr; track_numbers)
) for (s, ref) in zip(external_syms, ref_syms)
]

include_paths = gensym(:include_paths)

reactant_code_block = quote
let args = $(args_init)
cond_fn =
Expand All @@ -238,13 +245,14 @@ function trace_for(mod, expr; track_numbers)
$counter[].mlir_data = ($counter[] + 1).mlir_data
nothing
end

$(include_paths) = $(include_paths_expr)
$(ReactantCore).traced_while(
cond_fn,
body_fn,
args;
track_numbers=$(track_numbers),
verify_arg_names=$(QuoteNode(args_names)),
include_paths=$(include_paths),
)
end
end
Expand Down
11 changes: 9 additions & 2 deletions src/ControlFlow.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ function ReactantCore.traced_call(f::Function, args...)
end

function ReactantCore.traced_while(
cond_fn::CFn, body_fn::BFn, args; track_numbers=Number, verify_arg_names=nothing
cond_fn::CFn,
body_fn::BFn,
args;
track_numbers=Number,
verify_arg_names=nothing,
include_paths=[],
) where {CFn,BFn}
return Ops.while_loop(cond_fn, body_fn, args...; track_numbers, verify_arg_names)
return Ops.while_loop(
cond_fn, body_fn, args...; track_numbers, verify_arg_names, include_paths
)
end
9 changes: 7 additions & 2 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1722,7 +1722,12 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
end

@noinline function while_loop(
cond_fn::CFn, body_fn::BFn, args...; track_numbers, verify_arg_names=nothing
cond_fn::CFn,
body_fn::BFn,
args...;
track_numbers,
verify_arg_names=nothing,
include_paths=[],
) where {CFn,BFn}
# TODO: detect and prevent mutation within the condition

Expand All @@ -1733,7 +1738,7 @@ end

for (i, prev) in enumerate(args)
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers
seen_args, prev, (), Reactant.NoStopTracedTrack; track_numbers, include_paths
)
end

Expand Down
4 changes: 3 additions & 1 deletion src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ function make_mlir_fn(
argprefix::Symbol=:args,
resprefix::Symbol=:result,
resargprefix::Symbol=:resargs,
include_paths=[],
num_replicas=1,
)
if sizeof(typeof(f)) != 0 || f isa Base.BroadcastFunction
Expand Down Expand Up @@ -238,7 +239,7 @@ function make_mlir_fn(
end
for i in 1:N
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args, args[i], (argprefix, i), inmode; toscalar, runtime
seen_args, args[i], (argprefix, i), inmode; toscalar, runtime, include_paths
)
end

Expand Down Expand Up @@ -376,6 +377,7 @@ function make_mlir_fn(
(resargprefix, i),
Reactant.NoStopTracedTrack;
runtime,
include_paths=[],
)
end
traced_result
Expand Down
Loading
Loading