@@ -54,6 +54,7 @@ def _cat_jagged_values(jd: Dict[str, JaggedTensor]) -> torch.Tensor:
54
54
return torch .cat ([jt .values () for jt in jd .values ()])
55
55
56
56
57
+ # TODO: keep the old implementation for backward compatibility and will remove it later
57
58
@torch .fx .wrap
58
59
def _mcc_lazy_init (
59
60
features : KeyedJaggedTensor ,
@@ -78,6 +79,34 @@ def _mcc_lazy_init(
78
79
return (features , created_feature_order , features_order )
79
80
80
81
82
+ @torch .fx .wrap
83
+ def _mcc_lazy_init_inplace (
84
+ features : KeyedJaggedTensor ,
85
+ feature_names : List [str ],
86
+ features_order : List [int ],
87
+ created_feature_order : List [bool ],
88
+ ) -> KeyedJaggedTensor :
89
+ input_feature_names : List [str ] = features .keys ()
90
+ if not created_feature_order or not created_feature_order [0 ]:
91
+ for f in feature_names :
92
+ features_order .append (input_feature_names .index (f ))
93
+
94
+ if features_order == list (range (len (input_feature_names ))):
95
+ features_order .clear ()
96
+
97
+ if len (created_feature_order ) > 0 :
98
+ created_feature_order [0 ] = True
99
+ else :
100
+ created_feature_order .append (True )
101
+
102
+ if len (features_order ) > 0 :
103
+ features = features .permute (
104
+ features_order ,
105
+ )
106
+
107
+ return features
108
+
109
+
81
110
@torch .fx .wrap
82
111
def _get_length_per_key (kjt : KeyedJaggedTensor ) -> torch .Tensor :
83
112
return torch .tensor (kjt .length_per_key ())
@@ -298,6 +327,7 @@ class ManagedCollisionCollection(nn.Module):
298
327
299
328
_table_to_features : Dict [str , List [str ]]
300
329
_features_order : List [int ]
330
+ _created_feature_order : List [bool ] # use list for inplace update in leaf function
301
331
302
332
def __init__ (
303
333
self ,
@@ -338,7 +368,7 @@ def __init__(
338
368
self ._feature_names : List [str ] = [
339
369
feature for config in embedding_configs for feature in config .feature_names
340
370
]
341
- self ._created_feature_order = False
371
+ self ._created_feature_order = [ False ]
342
372
self ._features_order = []
343
373
344
374
def _create_feature_order (
@@ -360,11 +390,7 @@ def forward(
360
390
self ,
361
391
features : KeyedJaggedTensor ,
362
392
) -> KeyedJaggedTensor :
363
- (
364
- features ,
365
- self ._created_feature_order ,
366
- self ._features_order ,
367
- ) = _mcc_lazy_init (
393
+ features = _mcc_lazy_init_inplace (
368
394
features ,
369
395
self ._feature_names ,
370
396
self ._features_order ,
0 commit comments