diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 4cb1d62c2..d7ae684e3 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -55,6 +55,7 @@
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
     shards_all_to_all,
+    update_module_sharding_plan,
     update_state_dict_post_resharding,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
@@ -1232,11 +1233,19 @@ def _update_output_dist(self) -> None:
         # TODO: Optimize to only go through embedding shardings with new ranks
         self._output_dists: List[nn.Module] = []
         self._embedding_names: List[str] = []
+        self._embedding_dims: List[int] = []
+        self._uncombined_embedding_names: List[str] = []
+        self._uncombined_embedding_dims: List[int] = []
         for sharding in self._embedding_shardings:
             # TODO: if sharding type of table completely changes, need to regenerate everything
             self._embedding_names.extend(sharding.embedding_names())
             self._output_dists.append(sharding.create_output_dist(device=self._device))
             embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
+            self._embedding_dims.extend(sharding.embedding_dims())
+            self._uncombined_embedding_names.extend(
+                sharding.uncombined_embedding_names()
+            )
+            self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())
 
         embedding_shard_offsets: List[int] = [
             meta.shard_offsets[1] if meta is not None else 0
@@ -1585,6 +1594,26 @@ def update_shards(
 
         self._initialize_torch_state(skip_registering=True)
         self.load_state_dict(current_state)
+
+        # update optimizer
+        optims = []
+        for lookup in self._lookups:
+            for _, tbe_module in lookup.named_modules():
+                if isinstance(tbe_module, FusedOptimizerModule):
+                    # modify param keys to match EmbeddingBagCollection
+                    params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
+                    for (
+                        param_key,
+                        weight,
+                    ) in tbe_module.fused_optimizer.params.items():
+                        # pyre-fixme[16]: `Mapping` has no attribute `__setitem__`
+                        params["embedding_bags." + param_key] = weight
+                    tbe_module.fused_optimizer.params = params
+                    optims.append(("", tbe_module.fused_optimizer))
+
+        self._optim: CombinedOptimizer = CombinedOptimizer(optims)
+
+        update_module_sharding_plan(self, changed_sharding_params)
         return
 
     @property
diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py
index 05ca485f2..caa937db2 100644
--- a/torchrec/distributed/sharding/dynamic_sharding.py
+++ b/torchrec/distributed/sharding/dynamic_sharding.py
@@ -221,3 +221,17 @@ def update_state_dict_post_resharding(
             sharded_t._local_shards = []
 
     return state_dict
+
+
+def update_module_sharding_plan(
+    module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
+    changed_sharding_params: Dict[str, ParameterSharding],
+) -> None:
+    if not hasattr(module, "module_sharding_plan"):
+        return
+
+    # pyre-ignore
+    current_plan: Dict[str, ParameterSharding] = module.module_sharding_plan
+    for table_name, param_sharding in changed_sharding_params.items():
+        current_plan[table_name] = param_sharding
+    return
diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py
index 63da24ba5..f9a07fc50 100644
--- a/torchrec/distributed/tests/test_dynamic_sharding.py
+++ b/torchrec/distributed/tests/test_dynamic_sharding.py
@@ -141,13 +141,10 @@ def create_test_initial_state_dict(
     return initial_state_dict
 
 
-def are_modules_identical(
-    module1: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
-    module2: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
+def are_sharded_ebc_modules_identical(
+    module1: ShardedEmbeddingBagCollection,
+    module2: ShardedEmbeddingBagCollection,
 ) -> None:
-    # Check if both modules have the same type
-    assert type(module1) is type(module2)
-
     # Check if both modules have the same parameters
     params1 = list(module1.named_parameters())
     params2 = list(module2.named_parameters())
@@ -170,6 +167,52 @@ def are_modules_identical(
         assert buffer1[0] == buffer2[0]  # Check buffer names
         assert torch.allclose(buffer1[1], buffer2[1])  # Check buffer values
+    # Hard-coded attributes for EmbeddingBagCollection
+    attribute_list = [
+        "_module_fqn",
+        "_table_names",
+        "_pooling_type_to_rs_features",
+        "_output_dtensor",
+        "_sharding_types",
+        "_is_weighted",
+        "_embedding_names",
+        "_embedding_dims",
+        "_feature_splits",
+        "_features_order",
+        "_uncombined_embedding_names",
+        "_uncombined_embedding_dims",
+        "_has_mean_pooling_callback",
+        "_kjt_key_indices",
+        "_has_uninitialized_input_dist",
+        "_has_features_permute",
+        "_dim_per_key",  # Tensor
+        "_inverse_indices_permute_indices",  # Tensor
+        "_kjt_inverse_order",  # Tensor
+        "_kt_key_ordering",  # Tensor
+        # Non-primitive types which can be compared
+        "module_sharding_plan",
+        "_table_name_to_config",
+        # Excluding the non-primitive types that cannot be compared
+        # "sharding_type_to_sharding_infos",
+        # "_embedding_shardings"
+        # "_input_dists",
+        # "_lookups",
+        # "_output_dists",
+        # "_optim",
+    ]
+
+    for attr in attribute_list:
+        assert hasattr(module1, attr) and hasattr(module2, attr)
+
+        val1 = getattr(module1, attr)
+        val2 = getattr(module2, attr)
+
+        assert type(val1) is type(val2)
+        if type(val1) is torch.Tensor:
+            torch.testing.assert_close(val1, val2)
+        else:
+            assert val1 == val2
+
 
 
 def output_sharding_plan_delta(
     old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
@@ -274,7 +317,7 @@ def _test_ebc_resharding(
         device=ctx.device,
     )
 
-    are_modules_identical(sharded_m1, resharded_m2)
+    are_sharded_ebc_modules_identical(sharded_m1, resharded_m2)
 
     feature_keys = []
     for table in tables:
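
For context on the behavioral change above: the new update_module_sharding_plan helper overwrites the entries of a sharded module's module_sharding_plan for the tables whose ParameterSharding changed, and is a no-op for modules that carry no plan. The sketch below illustrates that logic in isolation; it is not part of the patch, and it substitutes a hypothetical SimpleNamespace stand-in and plain dicts for real ShardedEmbeddingBagCollection and ParameterSharding objects.

# Illustrative sketch only: mirrors the in-place plan update from the patch,
# using plain dicts instead of torchrec ParameterSharding objects.
from types import SimpleNamespace
from typing import Any, Dict


def update_module_sharding_plan_sketch(
    module: Any, changed_sharding_params: Dict[str, Any]
) -> None:
    # Modules without a per-table sharding plan are left untouched.
    if not hasattr(module, "module_sharding_plan"):
        return
    current_plan: Dict[str, Any] = module.module_sharding_plan
    for table_name, param_sharding in changed_sharding_params.items():
        # Overwrite (or insert) the entry for each resharded table in place.
        current_plan[table_name] = param_sharding


# A stand-in sharded module whose "table_0" moved from rank 0 to rank 1.
module = SimpleNamespace(
    module_sharding_plan={"table_0": {"ranks": [0]}, "table_1": {"ranks": [2]}}
)
update_module_sharding_plan_sketch(module, {"table_0": {"ranks": [1]}})
assert module.module_sharding_plan["table_0"] == {"ranks": [1]}
assert module.module_sharding_plan["table_1"] == {"ranks": [2]}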