|
| 1 | +""" |
| 2 | + @leaf type [make_tracer = true] |
| 3 | +
|
| 4 | +This marks a type as a leaf type for the purposes of tracing in reactant. This means that |
| 5 | +we won't recurse into the type and it will be left untouched. |
| 6 | +""" |
| 7 | +macro leaf(args...) |
| 8 | + @assert length(args) ≥ 1 |
| 9 | + orig_type, args = args[1], args[2:end] |
| 10 | + |
| 11 | + options = Dict{Symbol,Any}() |
| 12 | + while length(args) ≥ 1 |
| 13 | + if !Meta.isexpr(args[1], :(=)) |
| 14 | + error("Invalid argument $(args[1])") |
| 15 | + end |
| 16 | + options[args[1].args[1]] = args[1].args[2] |
| 17 | + args = args[2:end] |
| 18 | + end |
| 19 | + |
| 20 | + subtype = Meta.isexpr(orig_type, :(<:)) |
| 21 | + type = subtype ? orig_type.args[1] : orig_type |
| 22 | + |
| 23 | + traced_type_inner_expr = quote |
| 24 | + Base.@nospecializeinfer function Reactant.traced_type_inner( |
| 25 | + @nospecialize(T::Type{$(orig_type)}), |
| 26 | + seen, |
| 27 | + @nospecialize(mode::$(TraceMode)), |
| 28 | + @nospecialize(track_numbers::Type), |
| 29 | + @nospecialize(sharding), |
| 30 | + ) |
| 31 | + return T |
| 32 | + end |
| 33 | + end |
| 34 | + |
| 35 | + make_tracer_expr = if get(options, :make_tracer, true) |
| 36 | + quote |
| 37 | + function Reactant.make_tracer( |
| 38 | + seen, |
| 39 | + @nospecialize(prev::$(type)), |
| 40 | + @nospecialize(path), |
| 41 | + mode::$(TraceMode); |
| 42 | + kwargs..., |
| 43 | + ) |
| 44 | + return prev |
| 45 | + end |
| 46 | + end |
| 47 | + else |
| 48 | + :() |
| 49 | + end |
| 50 | + |
| 51 | + return esc( |
| 52 | + quote |
| 53 | + $traced_type_inner_expr |
| 54 | + $make_tracer_expr |
| 55 | + end, |
| 56 | + ) |
| 57 | +end |
| 58 | + |
1 | 59 | @enum TraceMode begin |
2 | 60 | ConcreteToTraced = 1 |
3 | 61 | TracedTrack = 2 |
|
14 | 72 |
|
15 | 73 | function traced_type_inner end |
16 | 74 |
|
17 | | -Base.@nospecializeinfer function traced_type_inner( |
18 | | - @nospecialize(T::Type{Union{}}), |
19 | | - seen, |
20 | | - mode::TraceMode, |
21 | | - @nospecialize(track_numbers::Type), |
22 | | - @nospecialize(sharding) |
23 | | -) |
24 | | - return T |
| 75 | +for T in (Symbol, Union{}) |
| 76 | + @eval @leaf $T make_tracer = false |
25 | 77 | end |
26 | 78 |
|
27 | 79 | for T in ( |
28 | 80 | DataType, |
29 | 81 | Module, |
30 | 82 | Nothing, |
31 | | - Symbol, |
32 | 83 | AbstractChar, |
33 | 84 | AbstractString, |
34 | 85 | AbstractFloat, |
35 | 86 | Integer, |
36 | 87 | RNumber, |
37 | 88 | Val, |
38 | 89 | VersionNumber, |
| 90 | + Base.ExceptionStack, |
| 91 | + Core.MethodInstance, |
39 | 92 | ) |
40 | | - @eval Base.@nospecializeinfer function traced_type_inner( |
41 | | - @nospecialize(T::Type{<:$T}), |
42 | | - seen, |
43 | | - mode::TraceMode, |
44 | | - @nospecialize(track_numbers::Type), |
45 | | - @nospecialize(sharding) |
46 | | - ) |
47 | | - return T |
48 | | - end |
| 93 | + @eval @leaf <:$T |
49 | 94 | end |
50 | 95 |
|
51 | 96 | Base.@nospecializeinfer function traced_type_inner( |
@@ -754,15 +799,6 @@ function Base.showerror(io::IO, err::NoFieldMatchError) |
754 | 799 | end |
755 | 800 | end |
756 | 801 |
|
757 | | -function make_tracer( |
758 | | - seen, |
759 | | - @nospecialize(prev::Union{Base.ExceptionStack,Core.MethodInstance}), |
760 | | - @nospecialize(path), |
761 | | - mode; |
762 | | - kwargs..., |
763 | | -) |
764 | | - return prev |
765 | | -end |
766 | 802 | append_path(@nospecialize(path), i) = (path..., i) |
767 | 803 |
|
768 | 804 | function make_tracer( |
|
0 commit comments