Skip to content

Commit 477f3cf

Browse files
seanx92facebook-github-bot
authored andcommitted
move module attribute inplace update to leaf function in ManagedCollisionModule (#2913)
Summary: Pull Request resolved: #2913 inplace update will cause unexpected module attribute mutation described in pytorch/pytorch#70449 by moving it to leaf function we guaranteed no side effect during fx tracing. Differential Revision: D73448087
1 parent a28ac22 commit 477f3cf

File tree

1 file changed

+32
-6
lines changed

1 file changed

+32
-6
lines changed

torchrec/modules/mc_modules.py

+32-6
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
5454
return torch.cat([jt.values() for jt in jd.values()])
5555

5656

57+
# TODO: keep the old implementation for backward compatibility and will remove it later
5758
@torch.fx.wrap
5859
def _mcc_lazy_init(
5960
features: KeyedJaggedTensor,
@@ -78,6 +79,34 @@ def _mcc_lazy_init(
7879
return (features, created_feature_order, features_order)
7980

8081

82+
@torch.fx.wrap
83+
def _mcc_lazy_init_inplace(
84+
features: KeyedJaggedTensor,
85+
feature_names: List[str],
86+
features_order: List[int],
87+
created_feature_order: List[bool],
88+
) -> KeyedJaggedTensor:
89+
input_feature_names: List[str] = features.keys()
90+
if not created_feature_order or not created_feature_order[0]:
91+
for f in feature_names:
92+
features_order.append(input_feature_names.index(f))
93+
94+
if features_order == list(range(len(input_feature_names))):
95+
features_order.clear()
96+
97+
if len(created_feature_order) > 0:
98+
created_feature_order[0] = True
99+
else:
100+
created_feature_order.append(True)
101+
102+
if len(features_order) > 0:
103+
features = features.permute(
104+
features_order,
105+
)
106+
107+
return features
108+
109+
81110
@torch.fx.wrap
82111
def _get_length_per_key(kjt: KeyedJaggedTensor) -> torch.Tensor:
83112
return torch.tensor(kjt.length_per_key())
@@ -298,6 +327,7 @@ class ManagedCollisionCollection(nn.Module):
298327

299328
_table_to_features: Dict[str, List[str]]
300329
_features_order: List[int]
330+
_created_feature_order: List[bool] # use list for inplace update in leaf function
301331

302332
def __init__(
303333
self,
@@ -338,7 +368,7 @@ def __init__(
338368
self._feature_names: List[str] = [
339369
feature for config in embedding_configs for feature in config.feature_names
340370
]
341-
self._created_feature_order = False
371+
self._created_feature_order = [False]
342372
self._features_order = []
343373

344374
def _create_feature_order(
@@ -360,11 +390,7 @@ def forward(
360390
self,
361391
features: KeyedJaggedTensor,
362392
) -> KeyedJaggedTensor:
363-
(
364-
features,
365-
self._created_feature_order,
366-
self._features_order,
367-
) = _mcc_lazy_init(
393+
features = _mcc_lazy_init_inplace(
368394
features,
369395
self._feature_names,
370396
self._features_order,

0 commit comments

Comments
 (0)