Commit 1b0bddc

aporialiao authored and facebook-github-bot committed
Add unsharded module property to sharded modules and EBC
Summary: Add a simple unsharded-module reference to sharded modules. This will be used in dynamic sharding by DistributedModelParallel (DMP) to reshard an already-sharded module. Because DMP was created with only a one-way relationship in mind, accessing the unsharded module type helps determine which sharder to use when resharding. See the comment under `types.py`. Differential Revision: D73537260
1 parent 0981db6 commit 1b0bddc
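
To illustrate the intended use described in the summary (this sketch is not part of the commit): a minimal example of how resharding code could map an already-sharded module back to a sharder. The helper names build_sharder_map and reshard_module are hypothetical, introduced here only for illustration.

# Minimal sketch, not part of this commit. Assumes each sharder exposes the
# unsharded module type it handles (TorchRec sharders do, via `module_type`);
# `build_sharder_map` and `reshard_module` are hypothetical helper names.
from typing import Dict, Type

import torch.nn as nn


def build_sharder_map(sharders) -> Dict[Type[nn.Module], object]:
    # Key each sharder by the unsharded module type it knows how to shard,
    # similar to how DMP organizes its sharders internally.
    return {sharder.module_type: sharder for sharder in sharders}


def reshard_module(sharded_module, sharder_map, new_plan):
    # An already-sharded module is no longer an instance of its unsharded
    # type, so the new `unsharded_module_type` property is what lets us
    # recover the matching sharder.
    sharder = sharder_map[sharded_module.unsharded_module_type]
    # ... hand `sharded_module`, `sharder`, and `new_plan` to the actual
    # resharding logic (out of scope for this sketch).
    return sharder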

File tree: 4 files changed (+48, -2 lines)


torchrec/distributed/embedding_types.py (+22, -1)

@@ -11,7 +11,18 @@
 import copy
 from dataclasses import dataclass
 from enum import Enum, unique
-from typing import Any, Dict, Generic, Iterator, List, Optional, Tuple, TypeVar, Union
+from typing import (
+    Any,
+    Dict,
+    Generic,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    Type,
+    TypeVar,
+    Union,
+)
 
 import torch
 from fbgemm_gpu.split_table_batched_embeddings_ops_training import EmbeddingLocation
@@ -399,6 +410,16 @@ def train(self, mode: bool = True):  # pyre-ignore[3]
 
         return self
 
+    @property
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        As this is the generic ShardedEmbeddingModule class, simply
+        return the generic nn.Module type. In the inherited classes of
+        ShardedEmbeddingModule, the specific unsharded module type will
+        be returned.
+        """
+        return nn.Module
+
 
 M = TypeVar("M", bound=nn.Module)

torchrec/distributed/embeddingbag.py (+8, -0)

@@ -1627,6 +1627,10 @@ def create_context(self) -> EmbeddingBagCollectionContext:
     def extend_shard_name(shard_name: str) -> str:
         return f"embedding_bags.{shard_name}.weight"
 
+    @property
+    def unsharded_module_type(self) -> Type[EmbeddingBagCollection]:
+        return EmbeddingBagCollection
+
 
 class EmbeddingBagCollectionSharder(BaseEmbeddingSharder[EmbeddingBagCollection]):
     """
@@ -1916,6 +1920,10 @@ def fused_optimizer(self) -> KeyedOptimizer:
     def create_context(self) -> NullShardedModuleContext:
         return NullShardedModuleContext()
 
+    @property
+    def unsharded_module_type(self) -> Type[nn.EmbeddingBag]:
+        return nn.EmbeddingBag
+
 
 class EmbeddingBagSharder(BaseEmbeddingSharder[nn.EmbeddingBag]):
     """

torchrec/distributed/object_pool.py (+5, -1)

@@ -8,7 +8,7 @@
 # pyre-strict
 
 from abc import abstractmethod
-from typing import Generic
+from typing import Generic, Type
 
 import torch
 from torch._prims_common import is_integer_dtype
@@ -144,3 +144,7 @@ def compute(self, ctx: ShrdCtx, dist_input: torch.Tensor) -> DistOut:
     # `None`.
     def output_dist(self, ctx: ShrdCtx, output: DistOut) -> LazyAwaitable[Out]:
         pass
+
+    @property
+    def unsharded_module_type(self) -> Type[ObjectPool[Out]]:
+        return ObjectPool[Out]

torchrec/distributed/types.py (+13, -0)

@@ -1034,6 +1034,19 @@ def sharded_parameter_names(self, prefix: str = "") -> Iterator[str]:
         for key, _ in self.named_parameters(prefix):
             yield key
 
+    @property
+    @abc.abstractmethod
+    def unsharded_module_type(self) -> Type[nn.Module]:
+        """
+        This property is added as part of the dynamic sharding implementation.
+
+        When resharding an already-sharded module wrapped in DMP, the unsharded
+        module type is needed to identify the proper sharder to use. This is
+        because DistributedModelParallel (DMP) references module sharders by
+        their unsharded module type.
+        """
+        ...
+
 
 def get_tensor_size_bytes(t: torch.Tensor) -> int:
     b: int = t.numel() * t.element_size()
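
A self-contained toy (not TorchRec code) showing the contract this abstract property creates: any concrete sharded-module class must now declare its unsharded counterpart, as the EmbeddingBagCollection change above does. Class names here are made up.

# Toy illustration of the abstract-property contract; class names are invented.
import abc
from typing import Type

import torch.nn as nn


class ShardedBase(abc.ABC):
    @property
    @abc.abstractmethod
    def unsharded_module_type(self) -> Type[nn.Module]:
        ...


class ShardedLinearLike(ShardedBase):
    @property
    def unsharded_module_type(self) -> Type[nn.Module]:
        return nn.Linear


# ShardedBase() raises TypeError (abstract property not implemented);
# ShardedLinearLike().unsharded_module_type is nn.Linear.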
