Enable Optimizer Storing & Fix incomplete updates to Sharded EBC attributes in resharding #2911

Closed · wants to merge 1 commit
29 changes: 29 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -55,6 +55,7 @@
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.dynamic_sharding import (
    shards_all_to_all,
    update_module_sharding_plan,
    update_state_dict_post_resharding,
)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
@@ -1232,11 +1233,19 @@ def _update_output_dist(self) -> None:
        # TODO: Optimize to only go through embedding shardings with new ranks
        self._output_dists: List[nn.Module] = []
        self._embedding_names: List[str] = []
        self._embedding_dims: List[int] = []
        self._uncombined_embedding_names: List[str] = []
        self._uncombined_embedding_dims: List[int] = []
        for sharding in self._embedding_shardings:
            # TODO: if sharding type of table completely changes, need to regenerate everything
            self._embedding_names.extend(sharding.embedding_names())
            self._output_dists.append(sharding.create_output_dist(device=self._device))
            embedding_shard_metadata.extend(sharding.embedding_shard_metadata())
            self._embedding_dims.extend(sharding.embedding_dims())
            self._uncombined_embedding_names.extend(
                sharding.uncombined_embedding_names()
            )
            self._uncombined_embedding_dims.extend(sharding.uncombined_embedding_dims())

        embedding_shard_offsets: List[int] = [
            meta.shard_offsets[1] if meta is not None else 0
@@ -1585,6 +1594,26 @@ def update_shards(
        self._initialize_torch_state(skip_registering=True)

        self.load_state_dict(current_state)

        # update optimizer
        optims = []
        for lookup in self._lookups:
            for _, tbe_module in lookup.named_modules():
                if isinstance(tbe_module, FusedOptimizerModule):
                    # modify param keys to match EmbeddingBagCollection
                    params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
                    for (
                        param_key,
                        weight,
                    ) in tbe_module.fused_optimizer.params.items():
                        # pyre-fixme[16]: `Mapping` has no attribute `__setitem__`
                        params["embedding_bags." + param_key] = weight
                    tbe_module.fused_optimizer.params = params
                    optims.append(("", tbe_module.fused_optimizer))

        self._optim: CombinedOptimizer = CombinedOptimizer(optims)

        update_module_sharding_plan(self, changed_sharding_params)
        return

    @property
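
The optimizer block added to `update_shards` rebuilds `self._optim` from the TBE fused optimizers and re-keys their parameters with an `embedding_bags.` prefix so the optimizer keys line up with the sharded module's `state_dict()` keys after resharding. Below is a minimal, self-contained sketch of just the re-keying step; `ToyFusedOptimizer` and its `params` mapping are hypothetical stand-ins for the real TBE fused optimizer, not TorchRec APIs.

from typing import Dict

import torch


class ToyFusedOptimizer:
    """Hypothetical stand-in for a fused TBE optimizer that exposes a `params` mapping."""

    def __init__(self, params: Dict[str, torch.Tensor]) -> None:
        self.params = params


def rekey_params_for_ebc(optim: ToyFusedOptimizer, prefix: str = "embedding_bags.") -> None:
    # Rebuild the mapping so each parameter key matches the corresponding
    # entry in the sharded EmbeddingBagCollection's state_dict().
    optim.params = {prefix + key: weight for key, weight in optim.params.items()}


optim = ToyFusedOptimizer({"table_0.weight": torch.zeros(4, 8)})
rekey_params_for_ebc(optim)
assert list(optim.params) == ["embedding_bags.table_0.weight"]
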
14 changes: 14 additions & 0 deletions torchrec/distributed/sharding/dynamic_sharding.py
@@ -221,3 +221,17 @@ def update_state_dict_post_resharding(
            sharded_t._local_shards = []

    return state_dict


def update_module_sharding_plan(
    module: ShardedModule[Any, Any, Any, Any],  # pyre-ignore
    changed_sharding_params: Dict[str, ParameterSharding],
) -> None:
    if not hasattr(module, "module_sharding_plan"):
        return

    # pyre-ignore
    current_plan: Dict[str, ParameterSharding] = module.module_sharding_plan
    for table_name, param_sharding in changed_sharding_params.items():
        current_plan[table_name] = param_sharding
    return
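
`update_module_sharding_plan` is an in-place merge: every table named in `changed_sharding_params` overwrites its entry in the module's existing `module_sharding_plan`, while tables that were not resharded keep their previous `ParameterSharding`. A rough illustration of that merge under simplified assumptions (the `SimpleNamespace` module and string plan values are hypothetical stand-ins for a real `ShardedModule` and `ParameterSharding` objects):

from types import SimpleNamespace

# Hypothetical module: in TorchRec the plan values would be ParameterSharding objects.
module = SimpleNamespace(
    module_sharding_plan={"t0": "table_wise@rank0", "t1": "table_wise@rank1"}
)
changed_sharding_params = {"t1": "table_wise@rank3"}

# Same in-place merge performed by update_module_sharding_plan above.
for table_name, param_sharding in changed_sharding_params.items():
    module.module_sharding_plan[table_name] = param_sharding

assert module.module_sharding_plan == {
    "t0": "table_wise@rank0",
    "t1": "table_wise@rank3",
}
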
57 changes: 50 additions & 7 deletions torchrec/distributed/tests/test_dynamic_sharding.py
@@ -141,13 +141,10 @@ def create_test_initial_state_dict(
    return initial_state_dict


def are_modules_identical(
    module1: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
    module2: Union[EmbeddingBagCollection, ShardedEmbeddingBagCollection],
def are_sharded_ebc_modules_identical(
    module1: ShardedEmbeddingBagCollection,
    module2: ShardedEmbeddingBagCollection,
) -> None:
    # Check if both modules have the same type
    assert type(module1) is type(module2)

    # Check if both modules have the same parameters
    params1 = list(module1.named_parameters())
    params2 = list(module2.named_parameters())
@@ -170,6 +167,52 @@ def are_modules_identical(
        assert buffer1[0] == buffer2[0]  # Check buffer names
        assert torch.allclose(buffer1[1], buffer2[1])  # Check buffer values

    # Hard-coded attributes for EmbeddingBagCollection
    attribute_list = [
        "_module_fqn",
        "_table_names",
        "_pooling_type_to_rs_features",
        "_output_dtensor",
        "_sharding_types",
        "_is_weighted",
        "_embedding_names",
        "_embedding_dims",
        "_feature_splits",
        "_features_order",
        "_uncombined_embedding_names",
        "_uncombined_embedding_dims",
        "_has_mean_pooling_callback",
        "_kjt_key_indices",
        "_has_uninitialized_input_dist",
        "_has_features_permute",
        "_dim_per_key",  # Tensor
        "_inverse_indices_permute_indices",  # Tensor
        "_kjt_inverse_order",  # Tensor
        "_kt_key_ordering",  # Tensor
        # Non-primitive types which can be compared
        "module_sharding_plan",
        "_table_name_to_config",
        # Excluding the non-primitive types that cannot be compared
        # "sharding_type_to_sharding_infos",
        # "_embedding_shardings"
        # "_input_dists",
        # "_lookups",
        # "_output_dists",
        # "_optim",
    ]

    for attr in attribute_list:
        assert hasattr(module1, attr) and hasattr(module2, attr)

        val1 = getattr(module1, attr)
        val2 = getattr(module2, attr)

        assert type(val1) is type(val2)
        if type(val1) is torch.Tensor:
            torch.testing.assert_close(val1, val2)
        else:
            assert val1 == val2


def output_sharding_plan_delta(
    old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan
@@ -274,7 +317,7 @@ def _test_ebc_resharding(
            device=ctx.device,
        )

        are_modules_identical(sharded_m1, resharded_m2)
        are_sharded_ebc_modules_identical(sharded_m1, resharded_m2)

        feature_keys = []
        for table in tables:
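
A note on the attribute loop added to `are_sharded_ebc_modules_identical`: tensor-valued attributes such as `_dim_per_key` get a dedicated branch because `==` on tensors is element-wise, so a bare `assert` on the result is ambiguous, while `torch.testing.assert_close` fails with a descriptive message on any mismatch. A small illustration of the distinction:

import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([1.0, 2.0, 3.0])

# `assert a == b` would raise "Boolean value of Tensor with more than one
# element is ambiguous"; tensors are therefore compared with assert_close.
torch.testing.assert_close(a, b)

# Plain `==` is sufficient for the primitive, list, and dict attributes.
assert ["_embedding_names", "_embedding_dims"] == ["_embedding_names", "_embedding_dims"]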