
Commit 0981db6

aporialia authored and facebook-github-bot committed
Enable Optimizer Storing & Fix incomplete updates to Sharded EBC attributes in resharding (#2911)
Summary: Previously, the dynamic sharding unit test did not fully verify that a resharded EBC has all of its attributes updated correctly. I ran into these issues when trying to enable optimizer state storing and the DMP interface in D73049934.

Main changes:
1. Add to the dynamic sharding unit test's `are_sharded_ebc_modules_identical` the private attributes of ShardedEmbeddingBagCollection. This method only compares primitive types, primitive reference types, and tensors.
   1. This helped identify gaps in the current DS implementation, namely `module_sharding_plan`, `_embedding_dims`, `_uncombined_embedding_names`, and `_uncombined_embedding_dims` not being updated to reflect the new shard placements and order.
2. Update `module_sharding_plan`, `_embedding_dims`, `_uncombined_embedding_names`, and `_uncombined_embedding_dims` in the reshard API for Sharded EBC.
3. Add a call to update the optimizer. The diff splits are not ideal; the full optimizer unit test will be added in D73049934.

Differential Revision: D73530909
1 parent a28ac22 commit 0981db6
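
The comparison rule in change 1 (compare only primitives, primitive containers, and tensors) matters because tensor-valued attributes cannot be checked with a plain `==` assert: elementwise comparison yields a bool tensor rather than a single truth value. A minimal standalone sketch of that rule, using hypothetical toy objects rather than real sharded modules:

import torch


def assert_attrs_equal(obj1, obj2, attribute_list) -> None:
    # Toy version of the tensor-aware attribute check added to the unit test.
    for attr in attribute_list:
        assert hasattr(obj1, attr) and hasattr(obj2, attr)
        val1, val2 = getattr(obj1, attr), getattr(obj2, attr)
        assert type(val1) is type(val2)
        if isinstance(val1, torch.Tensor):
            # `val1 == val2` on tensors is elementwise, so compare values instead.
            torch.testing.assert_close(val1, val2)
        else:
            assert val1 == val2


class _Toy:  # hypothetical stand-in for a sharded module
    def __init__(self) -> None:
        self._embedding_dims = [8, 16]
        self._dim_per_key = torch.tensor([8, 16])


assert_attrs_equal(_Toy(), _Toy(), ["_embedding_dims", "_dim_per_key"])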

File tree

3 files changed: +93 -7 lines


torchrec/distributed/embeddingbag.py (+29)
@@ -55,6 +55,7 @@
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
     shards_all_to_all,
+    update_module_sharding_plan,
     update_state_dict_post_resharding,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding

@@ -1232,11 +1233,19 @@ def _update_output_dist(self) -> None:
         # TODO: Optimize to only go through embedding shardings with new ranks
         self._output_dists: List[nn.Module] = []
         self._embedding_names: List[str] = []
+        self._embedding_dims: List[int] = []
+        self._uncombined_embedding_names: List[str] = []
+        self._uncombined_embedding_dims: List[int] = []
         for sharding in self._embedding_shardings:
             # TODO: if sharding type of table completely changes, need to regenerate everything
             self._embedding_names.extend(sharding.embedding_names())
             self._output_dists.append(sharding.create_output_dist(device=self._device))
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+            self._embedding_dims.extend(sharding.embedding_dims())
+            self._uncombined_embedding_names.extend(
+                sharding.uncombined_embedding_names()
+            )
+            self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())

         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0

@@ -1585,6 +1594,26 @@ def update_shards(
         self._initialize_torch_state(skip_registering=True)

         self.load_state_dict(current_state)
+
+        # update optimizer
+        optims = []
+        for lookup in self._lookups:
+            for _, tbe_module in lookup.named_modules():
+                if isinstance(tbe_module, FusedOptimizerModule):
+                    # modify param keys to match EmbeddingBagCollection
+                    params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
+                    for (
+                        param_key,
+                        weight,
+                    ) in tbe_module.fused_optimizer.params.items():
+                        # pyre-fixme[16]: `Mapping` has no attribute `__setitem__`
+                        params["embedding_bags." + param_key] = weight
+                    tbe_module.fused_optimizer.params = params
+                    optims.append(("", tbe_module.fused_optimizer))
+
+        self._optim: CombinedOptimizer = CombinedOptimizer(optims)
+
+        update_module_sharding_plan(self, changed_sharding_params)
         return

     @property
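
The optimizer rewrite above re-keys each TBE module's fused optimizer so its parameter names line up with the EmbeddingBagCollection FQNs (prefix "embedding_bags.") before the per-lookup optimizers are recombined into a single CombinedOptimizer. A minimal sketch of just the key-remapping step on a plain dict, with made-up parameter names:

import torch

# Hypothetical fused-optimizer params keyed by table, as a TBE module might expose them.
fused_params = {"table_0.weight": torch.zeros(4), "table_1.weight": torch.zeros(8)}

# Re-key so the names match the EmbeddingBagCollection module hierarchy.
remapped = {"embedding_bags." + key: weight for key, weight in fused_params.items()}

assert sorted(remapped) == ["embedding_bags.table_0.weight", "embedding_bags.table_1.weight"]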

torchrec/distributed/sharding/dynamic_sharding.py (+14)
@@ -221,3 +221,17 @@ def update_state_dict_post_resharding(
             sharded_t._local_shards = []

     return state_dict
+
+
+def update_module_sharding_plan(
+    module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
+    changed_sharding_params: Dict[str, ParameterSharding],
+) -> None:
+    if not hasattr(module, "module_sharding_plan"):
+        return
+
+    # pyre-ignore
+    current_plan: Dict[str, ParameterSharding] = module.module_sharding_plan
+    for table_name, param_sharding in changed_sharding_params.items():
+        current_plan[table_name] = param_sharding
+    return
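
`update_module_sharding_plan` is an in-place merge: every table listed in `changed_sharding_params` overwrites its entry in the module's existing `module_sharding_plan`, and modules without a plan attribute are left untouched. A toy sketch of those semantics, using plain dicts and a hypothetical stand-in class instead of real ShardedModule / ParameterSharding objects:

class _FakeShardedModule:  # hypothetical stand-in; real modules hold ParameterSharding values
    def __init__(self, plan):
        self.module_sharding_plan = plan


def merge_plan(module, changed_sharding_params) -> None:
    # Mirrors the guard-and-merge behavior shown in the diff above.
    if not hasattr(module, "module_sharding_plan"):
        return
    for table_name, param_sharding in changed_sharding_params.items():
        module.module_sharding_plan[table_name] = param_sharding


m = _FakeShardedModule({"table_0": "row_wise on rank 0", "table_1": "table_wise on rank 1"})
merge_plan(m, {"table_1": "table_wise on rank 3"})
assert m.module_sharding_plan["table_1"] == "table_wise on rank 3"
merge_plan(object(), {"table_0": "ignored"})  # no plan attribute -> no-op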

torchrec/distributed/tests/test_dynamic_sharding.py (+50 -7)
@@ -141,13 +141,10 @@ def create_test_initial_state_dict(
     return initial_state_dict


-def are_modules_identical(
-    module1: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
-    module2: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
+def are_sharded_ebc_modules_identical(
+    module1: ShardedEmbeddingBagCollection,
+    module2: ShardedEmbeddingBagCollection,
 ) -> None:
-    # Check if both modules have the same type
-    assert type(module1) is type(module2)
-
     # Check if both modules have the same parameters
     params1 = list(module1.named_parameters())
     params2 = list(module2.named_parameters())

@@ -170,6 +167,52 @@ def are_modules_identical(
         assert buffer1[0] == buffer2[0]  # Check buffer names
         assert torch.allclose(buffer1[1], buffer2[1])  # Check buffer values

+    # Hard-coded attributes for EmbeddingBagCollection
+    attribute_list = [
+        "_module_fqn",
+        "_table_names",
+        "_pooling_type_to_rs_features",
+        "_output_dtensor",
+        "_sharding_types",
+        "_is_weighted",
+        "_embedding_names",
+        "_embedding_dims",
+        "_feature_splits",
+        "_features_order",
+        "_uncombined_embedding_names",
+        "_uncombined_embedding_dims",
+        "_has_mean_pooling_callback",
+        "_kjt_key_indices",
+        "_has_uninitialized_input_dist",
+        "_has_features_permute",
+        "_dim_per_key",  # Tensor
+        "_inverse_indices_permute_indices",  # Tensor
+        "_kjt_inverse_order",  # Tensor
+        "_kt_key_ordering",  # Tensor
+        # Non-primitive types which can be compared
+        "module_sharding_plan",
+        "_table_name_to_config",
+        # Excluding the non-primitive types that cannot be compared
+        # "sharding_type_to_sharding_infos",
+        # "_embedding_shardings"
+        # "_input_dists",
+        # "_lookups",
+        # "_output_dists",
+        # "_optim",
+    ]
+
+    for attr in attribute_list:
+        assert hasattr(module1, attr) and hasattr(module2, attr)
+
+        val1 = getattr(module1, attr)
+        val2 = getattr(module2, attr)
+
+        assert type(val1) is type(val2)
+        if type(val1) is torch.Tensor:
+            torch.testing.assert_close(val1, val2)
+        else:
+            assert val1 == val2
+

 def output_sharding_plan_delta(
     old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan

@@ -274,7 +317,7 @@ def _test_ebc_resharding(
         device=ctx.device,
     )

-    are_modules_identical(sharded_m1, resharded_m2)
+    are_sharded_ebc_modules_identical(sharded_m1, resharded_m2)

     feature_keys = []
     for table in tables:
