Skip to content

Commit a11cae2

Browse files
vam-googlecopybara-github
authored andcommitted
fix GPU pywrap implementation
PiperOrigin-RevId: 724563620
1 parent 41a11a7 commit a11cae2

File tree

1 file changed

+103
-9
lines changed

1 file changed

+103
-9
lines changed

third_party/py/rules_pywrap/pywrap.impl.bzl

+103-9
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ PywrapInfo = provider(
88
"py_stub": "Pybind Python stub used to resolve cross-package references",
99
"cc_only": "True if this PywrapInfo represents cc-only library (no PyIni_)",
1010
"starlark_only": "",
11+
"default_runfiles": "",
1112
},
1213
)
1314

@@ -21,6 +22,7 @@ PywrapFilters = provider(
2122
fields = {
2223
"pywrap_lib_filter": "",
2324
"common_lib_filters": "",
25+
"dynamic_lib_filter": "",
2426
},
2527
)
2628

@@ -68,12 +70,16 @@ def pywrap_library(
6870
if starlark_only_pywrap_count > 0:
6971
starlark_only_filter_full_name = "%s%s__starlark_only_common" % (cur_pkg, name)
7072

73+
inverse_common_lib_filters = _construct_inverse_common_lib_filters(
74+
common_lib_filters,
75+
)
76+
7177
_linker_input_filters(
7278
name = linker_input_filters_name,
7379
dep = ":%s" % info_collector_name,
7480
pywrap_lib_filter = pywrap_lib_filter,
7581
pywrap_lib_exclusion_filter = pywrap_lib_exclusion_filter,
76-
common_lib_filters = {v: k for k, v in common_lib_filters.items()},
82+
common_lib_filters = inverse_common_lib_filters,
7783
starlark_only_filter_name = starlark_only_filter_full_name,
7884
)
7985

@@ -108,7 +114,7 @@ def pywrap_library(
108114
common_cc_binary_name = "%s" % common_lib_name
109115
common_import_name = _construct_common_binary(
110116
common_cc_binary_name,
111-
[":%s" % common_split_name] + common_deps,
117+
common_deps + [":%s" % common_split_name],
112118
linkopts,
113119
testonly,
114120
compatible_with,
@@ -117,6 +123,7 @@ def pywrap_library(
117123
binaries_data.values(),
118124
common_lib_pkg,
119125
ver_script,
126+
data = [":%s" % common_split_name],
120127
)
121128
actual_binaries_data = binaries_data
122129
actual_common_deps = common_deps
@@ -223,7 +230,8 @@ def _construct_common_binary(
223230
local_defines,
224231
dependency_common_lib_packages,
225232
dependent_common_lib_package,
226-
version_script):
233+
version_script,
234+
data):
227235
actual_linkopts = _construct_linkopt_soname(name) + _construct_linkopt_rpaths(
228236
dependency_common_lib_packages,
229237
dependent_common_lib_package,
@@ -242,6 +250,7 @@ def _construct_common_binary(
242250
compatible_with = compatible_with,
243251
win_def_file = win_def_file,
244252
local_defines = local_defines,
253+
# data = data,
245254
)
246255

247256
if_lib_name = "%s_if_lib" % name
@@ -263,6 +272,14 @@ def _construct_common_binary(
263272
compatible_with = compatible_with,
264273
)
265274

275+
cc_lib_name = "%s_cc_library" % name
276+
native.cc_library(
277+
name = cc_lib_name,
278+
deps = [":%s" % import_name],
279+
testonly = testonly,
280+
data = data,
281+
)
282+
266283
return import_name
267284

268285
def _pywrap_split_library_impl(ctx):
@@ -275,6 +292,7 @@ def _pywrap_split_library_impl(ctx):
275292

276293
split_linker_inputs = []
277294
private_linker_inputs = []
295+
default_runfiles = None
278296
if not pw.cc_only:
279297
split_linker_inputs.append(li)
280298
pywrap_lib_filter = ctx.attr.linker_input_filters[PywrapFilters].pywrap_lib_filter
@@ -286,11 +304,14 @@ def _pywrap_split_library_impl(ctx):
286304
depset(direct = private_lis),
287305
]
288306

307+
# default_runfiles = pw.default_runfiles
308+
289309
return _construct_split_library_cc_info(
290310
ctx,
291311
split_linker_inputs,
292312
user_link_flags,
293313
private_linker_inputs,
314+
default_runfiles,
294315
)
295316

296317
_pywrap_split_library = rule(
@@ -332,15 +353,30 @@ def _pywrap_common_split_library_impl(ctx):
332353
else:
333354
libs_to_include = filters.common_lib_filters[ctx.attr.common_lib_full_name]
334355

356+
user_link_flags = {}
357+
dynamic_lib_filter = filters.dynamic_lib_filter
358+
default_runfiles = ctx.runfiles()
335359
for pw in pywrap_infos:
336360
pw_lis = pw.cc_info.linking_context.linker_inputs.to_list()[1:]
361+
pw_runfiles_merged = False
337362
for li in pw_lis:
338363
if li in libs_to_exclude:
339364
continue
340-
if include_all_not_excluded or (li in libs_to_include):
365+
if include_all_not_excluded or (li in libs_to_include) or li in dynamic_lib_filter:
341366
split_linker_inputs.append(li)
367+
for user_link_flag in li.user_link_flags:
368+
user_link_flags[user_link_flag] = True
369+
if not pw_runfiles_merged:
370+
default_runfiles = default_runfiles.merge(pw.default_runfiles)
371+
pw_runfiles_merged = True
342372

343-
return _construct_split_library_cc_info(ctx, split_linker_inputs, [], [])
373+
return _construct_split_library_cc_info(
374+
ctx,
375+
split_linker_inputs,
376+
list(user_link_flags.keys()),
377+
[],
378+
default_runfiles,
379+
)
344380

345381
_pywrap_common_split_library = rule(
346382
attrs = {
@@ -367,7 +403,8 @@ def _construct_split_library_cc_info(
367403
ctx,
368404
split_linker_inputs,
369405
user_link_flags,
370-
private_linker_inputs):
406+
private_linker_inputs,
407+
default_runfiles):
371408
dependency_libraries = _construct_dependency_libraries(
372409
ctx,
373410
split_linker_inputs,
@@ -386,7 +423,11 @@ def _construct_split_library_cc_info(
386423
),
387424
)
388425

389-
return [CcInfo(linking_context = linking_context)]
426+
return [
427+
CcInfo(linking_context = linking_context),
428+
# DefaultInfo(files = default_runfiles.files)
429+
DefaultInfo(runfiles = default_runfiles),
430+
]
390431

391432
def _construct_dependency_libraries(ctx, split_linker_inputs):
392433
cc_toolchain = find_cpp_toolchain(ctx)
@@ -400,7 +441,7 @@ def _construct_dependency_libraries(ctx, split_linker_inputs):
400441
for split_linker_input in split_linker_inputs:
401442
for lib in split_linker_input.libraries:
402443
lib_copy = lib
403-
if not lib.alwayslink:
444+
if not lib.alwayslink and (lib.static_library or lib.pic_static_library):
404445
lib_copy = cc_common.create_library_to_link(
405446
actions = ctx.actions,
406447
cc_toolchain = cc_toolchain,
@@ -418,6 +459,10 @@ def _linker_input_filters_impl(ctx):
418459
pywrap_lib_exclusion_filter = {}
419460
pywrap_lib_filter = {}
420461
visited_filters = {}
462+
463+
#
464+
# pywrap private filter
465+
#
421466
if ctx.attr.pywrap_lib_exclusion_filter:
422467
for li in ctx.attr.pywrap_lib_exclusion_filter[CcInfo].linking_context.linker_inputs.to_list():
423468
pywrap_lib_exclusion_filter[li] = li.owner
@@ -429,13 +474,19 @@ def _linker_input_filters_impl(ctx):
429474

430475
common_lib_filters = {k: {} for k in ctx.attr.common_lib_filters.values()}
431476

477+
#
478+
# common lib filters
479+
#
432480
for filter, name in ctx.attr.common_lib_filters.items():
433481
filter_li = filter[CcInfo].linking_context.linker_inputs.to_list()
434482
for li in filter_li:
435483
if li not in visited_filters:
436484
common_lib_filters[name][li] = li.owner
437485
visited_filters[li] = li.owner
438486

487+
#
488+
# starlark -only filter
489+
#
439490
pywrap_infos = ctx.attr.dep[CollectedPywrapInfo].pywrap_infos.to_list()
440491
starlark_only_filter = {}
441492

@@ -451,10 +502,29 @@ def _linker_input_filters_impl(ctx):
451502
starlark_only_filter.pop(li, None)
452503

453504
common_lib_filters[ctx.attr.starlark_only_filter_name] = starlark_only_filter
505+
506+
#
507+
# dynamic libs filter
508+
#
509+
dynamic_lib_filter = {}
510+
empty_lib_filter = {}
511+
for pw in pywrap_infos:
512+
for li in pw.cc_info.linking_context.linker_inputs.to_list()[1:]:
513+
all_dynamic = None
514+
for lib in li.libraries:
515+
if lib.static_library or lib.pic_static_library or not lib.dynamic_library:
516+
all_dynamic = False
517+
break
518+
elif all_dynamic == None:
519+
all_dynamic = True
520+
if all_dynamic:
521+
dynamic_lib_filter[li] = li.owner
522+
454523
return [
455524
PywrapFilters(
456525
pywrap_lib_filter = pywrap_lib_filter,
457526
common_lib_filters = common_lib_filters,
527+
dynamic_lib_filter = dynamic_lib_filter,
458528
),
459529
]
460530

@@ -488,7 +558,7 @@ _linker_input_filters = rule(
488558
def pywrap_common_library(name, dep, filter_name = None):
489559
native.alias(
490560
name = name,
491-
actual = "%s_import" % (filter_name if filter_name else dep + "_common"),
561+
actual = "%s_cc_library" % (filter_name if filter_name else dep + "_common"),
492562
)
493563

494564
def pywrap_binaries(name, dep, **kwargs):
@@ -621,10 +691,15 @@ def _pywrap_info_wrapper_impl(ctx):
621691
substitutions = substitutions,
622692
)
623693

694+
default_runfiles = ctx.runfiles().merge(
695+
ctx.attr.deps[0][DefaultInfo].default_runfiles,
696+
)
697+
624698
return [
625699
PyInfo(transitive_sources = depset()),
626700
PywrapInfo(
627701
cc_info = ctx.attr.deps[0][CcInfo],
702+
default_runfiles = default_runfiles,
628703
owner = ctx.label,
629704
common_lib_packages = ctx.attr.common_lib_packages,
630705
py_stub = py_stub,
@@ -652,11 +727,16 @@ _pywrap_info_wrapper = rule(
652727

653728
def _cc_only_pywrap_info_wrapper_impl(ctx):
654729
wrapped_dep = ctx.attr.deps[0]
730+
default_runfiles = ctx.runfiles().merge(
731+
ctx.attr.deps[0][DefaultInfo].default_runfiles,
732+
)
733+
655734
return [
656735
PyInfo(transitive_sources = depset()),
657736
PywrapInfo(
658737
cc_info = wrapped_dep[CcInfo],
659738
owner = ctx.label,
739+
default_runfiles = default_runfiles,
660740
common_lib_packages = ctx.attr.common_lib_packages,
661741
py_stub = None,
662742
cc_only = True,
@@ -923,6 +1003,20 @@ def _get_common_lib_package_and_name(common_lib_full_name):
9231003
return common_lib_full_name.rsplit("/", 1)
9241004
return "", common_lib_full_name
9251005

1006+
def _construct_inverse_common_lib_filters(common_lib_filters):
1007+
inverse_common_lib_filters = {}
1008+
for common_lib_k, common_lib_v in common_lib_filters.items():
1009+
new_common_lib_k = common_lib_v
1010+
if type(common_lib_v) == type([]):
1011+
new_common_lib_k = "_%s_common_lib_filter" % common_lib_k.rsplit("/", 1)[-1]
1012+
native.cc_library(
1013+
name = new_common_lib_k,
1014+
deps = common_lib_v,
1015+
)
1016+
1017+
inverse_common_lib_filters[new_common_lib_k] = common_lib_k
1018+
return inverse_common_lib_filters
1019+
9261020
def _construct_linkopt_soname(name):
9271021
soname = name.rsplit("/", 1)[1] if "/" in name else name
9281022
soname = soname if name.startswith("lib") else ("lib%s" % soname)

0 commit comments

Comments
 (0)