Skip to content

Commit 3475e51

Browse files
committed
Add support for automatically calling unsafe_load() in getproperty()
Copying the description from the code: > By default the getproperty!(x::Ptr, ::Symbol) methods created for wrapped > types will return pointers (Ptr{T}) to the struct fields. That behaviour is > useful for accessing nested struct fields but it does require explicitly > calling unsafe_load() every time. When enabled this option will automatically > call unsafe_load() for you *except on nested struct fields and arrays*, which > should make explicitly calling unsafe_load() unnecessary in most cases.
1 parent 129555d commit 3475e51

File tree

7 files changed

+256
-34
lines changed

7 files changed

+256
-34
lines changed

docs/src/changelog.md

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ Changelog](https://keepachangelog.com).
1111
([5a1cc29](https://github.com/JuliaInterop/Clang.jl/commit/5a1cc29c154ed925f01e59dfd705cbf8042158e4)).
1212
- Added bindings for Clang 17, which should allow compatibility with Julia 1.12
1313
([#494]).
14+
- Experimental support for automatically dereferencing struct fields in
15+
`Base.getproperty()` with the `auto_field_dereference` option ([#502]).
1416

1517
### Fixed
1618

gen/generator.toml

+13
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,19 @@ wrap_variadic_function = false
181181
# generate getproperty/setproperty! methods for the types in the following list
182182
field_access_method_list = []
183183

184+
# EXPERIMENTAL:
185+
# By default the getproperty!(x::Ptr, ::Symbol) methods created for wrapped
186+
# types will return pointers (Ptr{T}) to the struct fields. That behaviour is
187+
# useful for accessing nested struct fields but it does require explicitly
188+
# calling unsafe_load() every time. When enabled this option will automatically
189+
# call unsafe_load() for you *except on nested struct fields and arrays*, which
190+
# should make explicitly calling unsafe_load() unnecessary in most cases. A @ptr
191+
# macro will be defined for cases where you really do want a pointer to a field
192+
# (e.g. for writing), which supports syntax like `@ptr(foo.bar)`.
193+
#
194+
# This should be used with `field_access_method_list`.
195+
auto_field_dereference = false
196+
184197
# the generator will prefix the function argument names in the following list with a "_" to
185198
# prevent the generated symbols from conflicting with the symbols defined and exported in Base.
186199
function_argument_conflict_symbols = []

src/generator/codegen.jl

+63-7
Original file line numberDiff line numberDiff line change
@@ -296,19 +296,20 @@ end
296296

297297
############################### Struct ###############################
298298

299-
function _emit_getproperty_ptr!(body, root_cursor, cursor, options)
299+
function _emit_pointer_access!(body, root_cursor, cursor, options)
300300
field_cursors = fields(getCursorType(cursor))
301301
field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors
302302
for field_cursor in field_cursors
303303
n = name(field_cursor)
304304
if isempty(n)
305-
_emit_getproperty_ptr!(body, root_cursor, field_cursor, options)
305+
_emit_pointer_access!(body, root_cursor, field_cursor, options)
306306
continue
307307
end
308308
fsym = make_symbol_safe(n)
309309
fty = getCursorType(field_cursor)
310310
ty = translate(tojulia(fty), options)
311311
offset = getOffsetOf(getCursorType(root_cursor), n)
312+
312313
if isBitField(field_cursor)
313314
w = getFieldDeclBitWidth(field_cursor)
314315
@assert w <= 32 # Bit fields should not be larger than int(32 bits)
@@ -322,12 +323,63 @@ function _emit_getproperty_ptr!(body, root_cursor, cursor, options)
322323
end
323324
end
324325

325-
# Base.getproperty(x::Ptr, f::Symbol) -> Ptr
326+
# getptr(x::Ptr, f::Symbol) -> Ptr
327+
function emit_getptr!(dag, node, options)
328+
sym = make_symbol_safe(node.id)
329+
signature = Expr(:call, :getptr, :(x::Ptr{$sym}), :(f::Symbol))
330+
body = Expr(:block)
331+
_emit_pointer_access!(body, node.cursor, node.cursor, options)
332+
333+
push!(body.args, :(error($("Unrecognized field of type `$sym`") * ": $f")))
334+
push!(node.exprs, Expr(:function, signature, body))
335+
return dag
336+
end
337+
338+
function emit_deref_getproperty!(body, root_cursor, cursor, options)
339+
field_cursors = fields(getCursorType(cursor))
340+
field_cursors = isempty(field_cursors) ? children(cursor) : field_cursors
341+
for field_cursor in field_cursors
342+
n = name(field_cursor)
343+
if isempty(n)
344+
emit_deref_getproperty!(body, root_cursor, field_cursor, options)
345+
continue
346+
end
347+
fsym = make_symbol_safe(n)
348+
fty = getCursorType(field_cursor)
349+
canonical_type = getCanonicalType(fty)
350+
351+
return_expr = :(getptr(x, f))
352+
353+
# Automatically dereference all field types except for nested structs
354+
# and arrays.
355+
if !(canonical_type isa Union{CLRecord, CLConstantArray}) && !isBitField(field_cursor)
356+
return_expr = :(unsafe_load($return_expr))
357+
elseif isBitField(field_cursor)
358+
return_expr = :(getbitfieldproperty(x, $return_expr))
359+
end
360+
361+
ex = :(f === $(QuoteNode(fsym)) && return $return_expr)
362+
push!(body.args, ex)
363+
end
364+
end
365+
366+
# Base.getproperty(x::Ptr, f::Symbol)
326367
function emit_getproperty_ptr!(dag, node, options)
368+
auto_deref = get(options, "auto_field_dereference", false)
327369
sym = make_symbol_safe(node.id)
370+
371+
# If automatically dereferencing, we first need to emit getptr!()
372+
if auto_deref
373+
emit_getptr!(dag, node, options)
374+
end
375+
328376
signature = Expr(:call, :(Base.getproperty), :(x::Ptr{$sym}), :(f::Symbol))
329377
body = Expr(:block)
330-
_emit_getproperty_ptr!(body, node.cursor, node.cursor, options)
378+
if auto_deref
379+
emit_deref_getproperty!(body, node.cursor, node.cursor, options)
380+
else
381+
_emit_pointer_access!(body, node.cursor, node.cursor, options)
382+
end
331383
push!(body.args, :(return getfield(x, f)))
332384
getproperty_expr = Expr(:function, signature, body)
333385
push!(node.exprs, getproperty_expr)
@@ -370,10 +422,14 @@ end
370422
function emit_setproperty!(dag, node, options)
371423
sym = make_symbol_safe(node.id)
372424
signature = Expr(:call, :(Base.setproperty!), :(x::Ptr{$sym}), :(f::Symbol), :v)
373-
store_expr = :(unsafe_store!(getproperty(x, f), v))
425+
426+
auto_deref = get(options, "auto_field_dereference", false)
427+
pointer_getter = auto_deref ? :getptr : :getproperty
428+
store_expr = :(unsafe_store!($pointer_getter(x, f), v))
429+
374430
if is_bitfield_type(node.type)
375431
body = quote
376-
fptr = getproperty(x, f)
432+
fptr = $pointer_getter(x, f)
377433
if fptr isa Ptr
378434
$store_expr
379435
else
@@ -398,7 +454,7 @@ function get_names_types(root_cursor, cursor, options)
398454
for field_cursor in field_cursors
399455
n = name(field_cursor)
400456
if isempty(n)
401-
_emit_getproperty_ptr!(root_cursor, field_cursor, options)
457+
_emit_pointer_access!(root_cursor, field_cursor, options)
402458
continue
403459
end
404460
fsym = make_symbol_safe(n)

src/generator/passes.jl

+16
Original file line numberDiff line numberDiff line change
@@ -1094,6 +1094,7 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict)
10941094
use_native_enum = get(general_options, "use_julia_native_enum_type", false)
10951095
print_CEnum = get(general_options, "print_using_CEnum", true)
10961096
wrap_variadic_function = get(codegen_options, "wrap_variadic_function", false)
1097+
auto_deref = get(codegen_options, "auto_field_dereference", false)
10971098

10981099
show_info && @info "[ProloguePrinter]: print to $(x.file)"
10991100
open(x.file, "w") do io
@@ -1186,6 +1187,21 @@ function (x::ProloguePrinter)(dag::ExprDAG, options::Dict)
11861187
println(io, string(set_expr), "\n")
11871188
end
11881189

1190+
if auto_deref
1191+
println(io, raw"""
1192+
macro ptr(expr)
1193+
if !Meta.isexpr(expr, :.)
1194+
error("Expression is not a property access, cannot use @ptr on it.")
1195+
end
1196+
1197+
quote
1198+
local penultimate_obj = $(esc(expr.args[1]))
1199+
getptr(penultimate_obj, $(esc(expr.args[2])))
1200+
end
1201+
end
1202+
""")
1203+
end
1204+
11891205
# print prelogue patches
11901206
if !isempty(prologue_file_path)
11911207
println(io, read(prologue_file_path, String))

test/generators.jl

+101
Original file line numberDiff line numberDiff line change
@@ -249,3 +249,104 @@ end
249249
@test docstring_has("callback")
250250
end
251251
end
252+
253+
@testset "Struct getproperty()/setproperty!()" begin
254+
options = Dict("general" => Dict{String, Any}("auto_mutability" => true,
255+
"auto_mutability_with_new" => false,
256+
"auto_mutability_includelist" => ["WithFields"]),
257+
"codegen" => Dict{String, Any}("field_access_method_list" => ["WithFields", "Other"]))
258+
259+
# Test the default getproperty()/setproperty!() behaviour
260+
mktemp() do path, io
261+
options["general"]["output_file_path"] = path
262+
ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options)
263+
build!(ctx)
264+
265+
println(read(path, String))
266+
267+
m = Module()
268+
Base.include(m, path)
269+
270+
# We now have to run in the latest world to use the new definitions
271+
Base.invokelatest() do
272+
obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1))
273+
274+
GC.@preserve obj begin
275+
obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj))
276+
277+
# The default getproperty() should basically always return a
278+
# pointer to the field (except for bitfields, which are tested
279+
# elsewhere).
280+
@test obj_ptr.int_value isa Ptr{Cint}
281+
@test obj_ptr.int_ptr isa Ptr{Ptr{Cint}}
282+
@test obj_ptr.struct_value isa Ptr{m.Other}
283+
@test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct}
284+
@test obj_ptr.array isa Ptr{NTuple{2, Cint}}
285+
286+
# Sanity test
287+
int_value = unsafe_load(obj_ptr.int_value)
288+
@test int_value == obj.int_value
289+
290+
# Test setproperty!()
291+
obj_ptr.int_value = int_value + 1
292+
@test unsafe_load(obj_ptr.int_value) == int_value + 1
293+
end
294+
end
295+
end
296+
297+
# Test the auto_field_dereference option
298+
mktemp() do path, io
299+
options["general"]["output_file_path"] = path
300+
options["codegen"]["auto_field_dereference"] = true
301+
ctx = create_context([joinpath(@__DIR__, "include/struct-properties.h")], get_default_args(), options)
302+
build!(ctx)
303+
304+
println(read(path, String))
305+
306+
m = Module()
307+
Base.include(m, path)
308+
309+
# We now have to run in the latest world to use the new definitions
310+
Base.invokelatest() do
311+
obj = m.WithFields(1, C_NULL, m.Other(42), C_NULL, m.TypedefStruct(1), (1, 1))
312+
313+
GC.@preserve obj begin
314+
obj_ptr = Ptr{m.WithFields}(pointer_from_objref(obj))
315+
316+
# Test getproperty()
317+
@test obj_ptr.int_value isa Cint
318+
@test obj_ptr.int_value == obj.int_value
319+
@test obj_ptr.int_ptr isa Ptr{Cint}
320+
321+
@test obj_ptr.struct_value isa Ptr{m.Other}
322+
@test obj_ptr.struct_value.i == obj.struct_value.i
323+
@test obj_ptr.struct_ptr isa Ptr{m.Other}
324+
@test obj_ptr.typedef_struct_value isa Ptr{m.TypedefStruct}
325+
326+
@test obj_ptr.array isa Ptr{NTuple{2, Cint}}
327+
328+
@test_throws ErrorException obj_ptr.foo
329+
330+
# Test @ptr
331+
val_ptr = @eval m @ptr $obj_ptr.int_value
332+
@test val_ptr isa Ptr{Cint}
333+
int_ptr = @eval m @ptr $obj_ptr.int_ptr
334+
@test int_ptr isa Ptr{Ptr{Cint}}
335+
336+
@test_throws LoadError (@eval m @ptr $obj_ptr)
337+
@test_throws ErrorException (@eval m @ptr $obj_ptr.foo)
338+
339+
# Test setproperty!()
340+
new_value = obj.int_value * 2
341+
obj_ptr.int_value = new_value
342+
@test obj.int_value == new_value
343+
344+
new_value = obj.struct_value.i * 2
345+
obj_ptr.struct_value.i = new_value
346+
@test obj.struct_value.i == new_value
347+
348+
@test_throws ErrorException obj_ptr.foo = 1
349+
end
350+
end
351+
end
352+
end

test/include/struct-properties.h

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
typedef struct {
2+
int i;
3+
} TypedefStruct;
4+
5+
struct Other {
6+
int i;
7+
};
8+
9+
struct WithFields {
10+
int int_value;
11+
int* int_ptr;
12+
13+
struct Other struct_value;
14+
struct Other* struct_ptr;
15+
TypedefStruct typedef_struct_value;
16+
17+
int array[2];
18+
};

test/test_bitfield.jl

+43-27
Original file line numberDiff line numberDiff line change
@@ -61,44 +61,60 @@ function build_libbitfield()
6161
error("Could not build libbitfield binary")
6262
end
6363

64-
# Generate wrappers
65-
@info "Building libbitfield wrapper"
66-
args = get_default_args()
67-
headers = joinpath(@__DIR__, "build", "include", "bitfield.h")
68-
options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml"))
69-
lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? "bitfield.dll" : "libbitfield")
70-
options["general"]["library_name"] = "\"$(escape_string(lib_path))\""
71-
options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl")
72-
ctx = create_context(headers, args, options)
73-
build!(ctx)
74-
75-
# Call a function to ensure build is successful
76-
include("LibBitField.jl")
77-
m = Base.@invokelatest LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3)
78-
Base.@invokelatest LibBitField.toBitfield(Ref(m))
64+
# Test the binary
65+
generate_wrappers(false)
7966
catch e
8067
@warn "Building libbitfield failed: $e"
8168
success = false
8269
end
8370
return success
8471
end
8572

73+
function generate_wrappers(auto_deref::Bool)
74+
@info "Building libbitfield wrapper"
75+
args = get_default_args()
76+
headers = joinpath(@__DIR__, "build", "include", "bitfield.h")
77+
options = load_options(joinpath(@__DIR__, "bitfield", "generate.toml"))
78+
options["codegen"]["auto_field_dereference"] = auto_deref
79+
options["codegen"]["field_access_method_list"] = ["BitField"]
80+
81+
lib_path = joinpath(@__DIR__, "build", "lib", Sys.iswindows() ? "bitfield.dll" : "libbitfield")
82+
options["general"]["library_name"] = "\"$(escape_string(lib_path))\""
83+
options["general"]["output_file_path"] = joinpath(@__DIR__, "LibBitField.jl")
84+
ctx = create_context(headers, args, options)
85+
build!(ctx)
8686

87+
# Call a function to ensure build is successful
88+
anonmod = Module()
89+
Base.include(anonmod, "LibBitField.jl")
90+
m = Base.@invokelatest anonmod.LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3)
91+
Base.@invokelatest anonmod.LibBitField.toBitfield(Ref(m))
92+
93+
return anonmod
94+
end
8795

8896
@testset "Bitfield" begin
8997
if build_libbitfield()
90-
bf = Ref(LibBitField.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3)))
91-
m = Ref(LibBitField.Mirror(10, 1.5, 1e6, -4, 7, 3))
92-
GC.@preserve bf m begin
93-
pbf = Ptr{LibBitField.BitField}(pointer_from_objref(bf))
94-
pm = Ptr{LibBitField.Mirror}(pointer_from_objref(m))
95-
@test LibBitField.toMirror(bf) == m[]
96-
@test LibBitField.toBitfield(m).a == bf[].a
97-
@test LibBitField.toBitfield(m).b == bf[].b
98-
@test LibBitField.toBitfield(m).c == bf[].c
99-
@test LibBitField.toBitfield(m).d == bf[].d
100-
@test LibBitField.toBitfield(m).e == bf[].e
101-
@test LibBitField.toBitfield(m).f == bf[].f
98+
# Test the wrappers with and without auto-dereferencing. In the case of
99+
# bitfields they should have identical behaviour.
100+
for auto_deref in [false, true]
101+
anonmod = generate_wrappers(auto_deref)
102+
lib = anonmod.LibBitField
103+
104+
bf = Ref(lib.BitField(Int8(10), 1.5, Int32(1e6), Int32(-4), Int32(7), UInt32(3)))
105+
m = Ref(lib.Mirror(10, 1.5, 1e6, -4, 7, 3))
106+
107+
GC.@preserve bf m begin
108+
pbf = Ptr{lib.BitField}(pointer_from_objref(bf))
109+
pm = Ptr{lib.Mirror}(pointer_from_objref(m))
110+
@test lib.toMirror(bf) == m[]
111+
@test lib.toBitfield(m).a == bf[].a
112+
@test lib.toBitfield(m).b == bf[].b
113+
@test lib.toBitfield(m).c == bf[].c
114+
@test lib.toBitfield(m).d == bf[].d
115+
@test lib.toBitfield(m).e == bf[].e
116+
@test lib.toBitfield(m).f == bf[].f
117+
end
102118
end
103119
end
104120
end

0 commit comments

Comments
 (0)