From 2660764546165fbc2348277149a24f032bdc89fb Mon Sep 17 00:00:00 2001
From: Emma Lin <line@fb.com>
Date: Wed, 9 Apr 2025 23:11:41 -0700
Subject: [PATCH 1/2] mock bucket, id/weight tensors, add total num buckets

Differential Revision: D72350620
---
 .../distributed/batched_embedding_kernel.py   | 312 +++++++++++++++++-
 torchrec/distributed/embedding.py             |  26 +-
 torchrec/distributed/embedding_kernel.py      |  13 +-
 torchrec/distributed/embedding_sharding.py    |  15 +
 torchrec/distributed/embedding_types.py       |  15 +
 .../distributed/quant_embedding_kernel.py     |   3 +
 .../sharding/rw_sequence_sharding.py          |  12 +
 torchrec/distributed/sharding/rw_sharding.py  |  26 ++
 torchrec/modules/embedding_configs.py         |   2 +
 9 files changed, 404 insertions(+), 20 deletions(-)

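Note (sketch, not part of the patch): the new `total_num_buckets` / `zero_collision` config
fields drive the per-rank bucket math implemented by get_sharded_local_buckets() below. A
minimal Python sketch of that math; the concrete values (4096 rows, 64 buckets, 4 ranks) are
illustrative assumptions, not taken from the diff:

    def local_bucket_range(num_embeddings, total_num_buckets, world_size, rank):
        # even sharding only, matching the prototype assert in get_sharded_local_buckets()
        assert num_embeddings % world_size == 0
        bucket_length = total_num_buckets // world_size    # buckets owned by this rank
        bucket_offset = bucket_length * rank               # first bucket id owned by this rank
        bucket_size = num_embeddings // total_num_buckets  # rows covered by each bucket
        return bucket_offset, bucket_length, bucket_size

    # local_bucket_range(4096, 64, 4, 1) -> (16, 16, 64):
    # rank 1 owns buckets [16, 32), each covering 64 consecutive ids.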
diff --git a/torchrec/distributed/batched_embedding_kernel.py b/torchrec/distributed/batched_embedding_kernel.py
index 1aff0ecf6..15fba983b 100644
--- a/torchrec/distributed/batched_embedding_kernel.py
+++ b/torchrec/distributed/batched_embedding_kernel.py
@@ -13,6 +13,7 @@
 import itertools
 import logging
 import tempfile
+from collections import defaultdict, OrderedDict
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -65,7 +66,7 @@
     ShardMetadata,
     TensorProperties,
 )
-from torchrec.distributed.utils import append_prefix
+from torchrec.distributed.utils import append_prefix, none_throws
 from torchrec.modules.embedding_configs import (
     data_type_to_sparse_type,
     pooling_type_to_pooling_mode,
@@ -88,6 +89,10 @@ def _populate_ssd_tbe_params(config: GroupedEmbeddingConfig) -> Dict[str, Any]:
 
     ssd_tbe_params: Dict[str, Any] = {}
 
+    for table in config.embedding_tables:
+        if table.zero_collision:
+            ssd_tbe_params["enable_zero_collision_tbe"] = True
+            logger.info("Enabling zero collision TBE")
     # drop the non-ssd tbe fused params
     ssd_tbe_signature = inspect.signature(
         SSDTableBatchedEmbeddingBags.__init__
@@ -904,7 +909,7 @@ def __init__(
         embedding_location = compute_kernel_to_embedding_location(compute_kernel)
 
         self._emb_module: SSDTableBatchedEmbeddingBags = SSDTableBatchedEmbeddingBags(
-            embedding_specs=list(zip(self._local_rows, self._local_cols)),
+            embedding_specs=list(zip(self._num_embeddings, self._local_cols)),
             feature_table_map=self._feature_table_map,
             ssd_cache_location=embedding_location,
             pooling_mode=PoolingMode.NONE,
@@ -926,6 +931,18 @@ def __init__(
         )
         self.init_parameters()
 
+        self._enable_zero_collision_tbe: bool = ssd_tbe_params.get(
+            "enable_zero_collision_tbe", False
+        )
+        self._tracked_ids: Optional[KeyedJaggedTensor] = None
+        self._sharded_local_buckets: Optional[List[Tuple[int, int, int]]] = None
+        if self._enable_zero_collision_tbe:
+            self._sharded_local_buckets = self.get_sharded_local_buckets()
+        # temporary tensors auto-generated for checkpointing
+        # once training is resumed and forward() is called, these tensors will be reset to None
+        # since the values can be changed by the backward pass, we don't want to duplicate memory
+        self._split_weights: Optional[List[Dict[str, ShardedTensor]]] = None
+
     def init_parameters(self) -> None:
         """
         An advantage of SSD TBE is that we don't need to init weights. Hence skipping.
@@ -963,19 +980,96 @@ def state_dict(
         # in the case no_snapshot=False, a flush is required. we rely on the flush operation in
         # ShardedEmbeddingBagCollection._pre_state_dict_hook()
 
-        emb_tables = self.split_embedding_weights(no_snapshot=no_snapshot)
+        emb_tables = self.split_embedding_weights_with_id_buckets(
+            no_snapshot=no_snapshot
+        )
+        weights = [emb_table[0] for emb_table in emb_tables]
         emb_table_config_copy = copy.deepcopy(self._config.embedding_tables)
         for emb_table in emb_table_config_copy:
             emb_table.local_metadata.placement._device = torch.device("cpu")
         ret = get_state_dict(
             emb_table_config_copy,
-            emb_tables,
+            weights,
             self._pg,
             destination,
             prefix,
         )
         return ret
 
+    def get_sharded_split_tensors(
+        self,
+        prefix: str,
+        emb_table_inx: int,
+        weight_tensor: torch.Tensor,
+        bucket_tensor: torch.Tensor,
+        id_tensor: torch.Tensor,
+    ) -> Dict[str, Any]:
+        if not self._enable_zero_collision_tbe:
+            return {}
+
+        table_config = copy.deepcopy(self._config.embedding_tables[emb_table_inx])
+        table_config.local_metadata.placement._device = torch.device("cpu")
+        ret: Dict[str, Any] = {}
+
+        weight_key = append_prefix(prefix, f"{table_config.name}.weight")
+        weight_local_metadata = copy.deepcopy(table_config.local_metadata)
+        weight_local_metadata.shard_sizes = list(weight_tensor.size())
+        weight_local_shards = [Shard(weight_tensor, weight_local_metadata)]
+        weight_global_size = (
+            self._num_embeddings[emb_table_inx],
+            self._local_cols[emb_table_inx],
+        )
+        if self._pg is not None:
+            ret[weight_key] = ShardedTensor._init_from_local_shards_and_reset_offsets(
+                weight_local_shards,
+                weight_global_size,
+                process_group=self._pg,
+            )
+
+        id_key = append_prefix(prefix, f"{table_config.name}.weight_id")
+        id_local_metadata = copy.deepcopy(table_config.local_metadata)
+        id_local_metadata.shard_offsets[1] = 0
+        id_local_metadata.shard_sizes = list(id_tensor.size())  # one column tensor
+        id_local_shards = [Shard(id_tensor, id_local_metadata)]
+        id_global_size = (self._num_embeddings[emb_table_inx], 1)
+        if self._pg is not None:
+            ret[id_key] = ShardedTensor._init_from_local_shards_and_reset_offsets(
+                id_local_shards,
+                id_global_size,
+                process_group=self._pg,
+            )
+
+        bucket_key = append_prefix(prefix, f"{table_config.name}.bucket")
+        bucket_global_metadata = copy.deepcopy(table_config.global_metadata)
+        bucket_global_metadata.tensor_properties.dtype = torch.int64
+        bucket_global_metadata.tensor_properties.requires_grad = False
+        bucket_global_metadata.size = torch.Size((table_config.total_num_buckets, 1))
+        # prototype: assuming even sharding here
+        bucket_length = self._sharded_local_buckets[emb_table_inx][1]
+        for j, shard in enumerate(bucket_global_metadata.shards_metadata):
+            shard.shard_offsets[0] = j * bucket_length
+            shard.shard_offsets[1] = 0
+            shard.shard_sizes = list(bucket_tensor.size())
+        bucket_local_metadata = copy.deepcopy(table_config.local_metadata)
+        bucket_local_metadata.shard_offsets[0] = self._sharded_local_buckets[
+            emb_table_inx
+        ][0]
+        bucket_local_metadata.shard_offsets[1] = 0
+        bucket_local_metadata.shard_sizes[0] = bucket_length
+        bucket_local_metadata.shard_sizes[1] = 1
+        local_shards = [Shard(bucket_tensor, bucket_local_metadata)]
+        if self._pg is not None:
+            ret[bucket_key] = ShardedTensor._init_from_local_shards_and_global_metadata(
+                local_shards=local_shards,
+                sharded_tensor_metadata=bucket_global_metadata,
+                process_group=self._pg,
+            )
+
+        logger.info(
+            f"get_sharded_split_tensors generated sharded weight, id and bucket tensors: {ret.keys()}"
+        )
+        return ret
+
     def named_parameters(
         self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True
     ) -> Iterator[Tuple[str, nn.Parameter]]:
@@ -1002,22 +1096,62 @@ def named_split_embedding_weights(
         ), "remove_duplicate=False not supported in BaseBatchedEmbedding.named_split_embedding_weights"
         for config, tensor in zip(
             self._config.embedding_tables,
-            self.split_embedding_weights(),
+            self.split_embedding_weights_with_id_buckets(),
         ):
+            weight_tensor = tensor[0]
             key = append_prefix(prefix, f"{config.name}.weight")
-            yield key, tensor
+            yield key, weight_tensor
 
     def get_named_split_embedding_weights_snapshot(
         self, prefix: str = ""
-    ) -> Iterator[Tuple[str, PartiallyMaterializedTensor]]:
+    ) -> Iterator[Tuple[str, Union[PartiallyMaterializedTensor, ShardedTensor]]]:
         """
         Return an iterator over embedding tables, yielding both the table name as well as the embedding
         table itself. The embedding table is in the form of PartiallyMaterializedTensor with a valid
         RocksDB snapshot to support windowed access.
         """
-        for config, tensor in zip(
-            self._config.embedding_tables,
-            self.split_embedding_weights(no_snapshot=False),
+        if self._enable_zero_collision_tbe:
+            if self._split_weights is not None:
+                split_weights = self._split_weights
+                for splits in split_weights:
+                    for key, tensor in splits.items():
+                        yield key, tensor
+                return
+            else:
+                self._split_weights = []
+                split_weights = []
+
+                for i, (config, tensors) in enumerate(
+                    zip(
+                        self._config.embedding_tables,
+                        self.split_embedding_weights_with_id_buckets(no_snapshot=False),
+                    )
+                ):
+                    weight_tensor = tensors[0]
+                    bucket_tensor = tensors[1]
+                    id_tensor = tensors[2]
+
+                    if not self._enable_zero_collision_tbe:
+                        key = append_prefix(prefix, f"{config.name}")
+                        yield key, weight_tensor
+                    else:
+                        if id_tensor is None:
+                            continue
+
+                        sharded_tensors = self.get_sharded_split_tensors(
+                            prefix, i, weight_tensor, bucket_tensor, id_tensor
+                        )
+                        split_weights.append(sharded_tensors)
+                        for key, tensor in sharded_tensors.items():
+                            yield key, tensor
+                self._split_weights = split_weights
+                return
+
+        for config, tensor in zip(
+            self._config.embedding_tables,
+            self.split_embedding_weights(no_snapshot=False),
         ):
             key = append_prefix(prefix, f"{config.name}")
             yield key, tensor
@@ -1036,12 +1170,168 @@ def purge(self) -> None:
         self.emb_module.lxu_cache_weights.zero_()
         self.emb_module.lxu_cache_state.fill_(-1)
 
-    # pyre-ignore [15]
+    # pyre-ignore[15]
     def split_embedding_weights(
         self, no_snapshot: bool = True
     ) -> List[PartiallyMaterializedTensor]:
         return self.emb_module.split_embedding_weights(no_snapshot)
 
+    # TODO: read result from torchrec sharding plan
+    def get_sharded_local_buckets(self) -> List[Tuple[int, int, int]]:
+        """
+        Returns (bucket_offset, bucket_length, bucket_size) per table based on the embedding sharding spec.
+        """
+        sharded_local_buckets: List[Tuple[int, int, int]] = []
+        world_size = dist.get_world_size(self._pg)
+        local_rank = dist.get_rank(self._pg)
+
+        for table in self._config.embedding_tables:
+            # temporary until uneven sharding utils are ready
+            assert (
+                table.num_embeddings % world_size == 0
+            ), "total_num_embeddings must be divisible by world_size"
+            total_num_buckets = none_throws(table.total_num_buckets)
+            bucket_offset = total_num_buckets // world_size * local_rank
+            bucket_length = total_num_buckets // world_size
+            bucket_size = table.num_embeddings // total_num_buckets
+            sharded_local_buckets.append((bucket_offset, bucket_length, bucket_size))
+            logger.info(
+                f"bucket_offset: {bucket_offset}, bucket_length: {bucket_length}, bucket_size: {bucket_size} for table {table.name}"
+            )
+        return sharded_local_buckets
+
+    @torch.jit.export
+    def split_embedding_weights_with_id_buckets(
+        self,
+        no_snapshot: bool = True,
+        should_flush: bool = False,
+    ) -> List[Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]]:
+        """
+        Copied from SSDTableBatchedEmbeddingBags.split_embedding_weights for debugging
+        purposes; otherwise a change in SSDTableBatchedEmbeddingBags directly would require
+        rebuilding the light package.
+        Returns a (weight, bucket, id) tensor tuple per table.
+        """
+        if not self._enable_zero_collision_tbe:
+            return [
+                (weight, None, None)
+                for weight in self.emb_module.split_embedding_weights(no_snapshot)
+            ]
+
+        # TODO: move the logic to SSD TBE when debugging is done
+        # Force device synchronize for now
+        # torch.cuda.synchronize()
+        # # Create a snapshot
+        # if no_snapshot:
+        #     snapshot_handle = None
+        # else:
+        #     if should_flush:
+        #         # Flush L1 and L2 caches
+        #         self.emb_module.flush()
+        #     snapshot_handle = self.emb_module.ssd_db.create_snapshot()
+        dtype = self.emb_module.weights_precision.as_dtype()
+        splits = []
+        assert (
+            len(self._config.embedding_tables) == 1
+        ), "only support 1 table in prototype"
+        if self._tracked_ids is None:
+            bucket_length = self._sharded_local_buckets[0][1]
+            bucket_tensor = torch.zeros(
+                (bucket_length, 1),
+                dtype=torch.int64,
+                device=torch.device("cpu"),
+            )
+            splits.append(
+                (
+                    torch.empty(
+                        (1, self._local_cols[0]),
+                        device=torch.device("cpu"),
+                        dtype=dtype,
+                    ),
+                    bucket_tensor,
+                    torch.empty((1, 1), device=torch.device("cpu"), dtype=torch.int64),
+                )
+            )
+            return splits
+
+        # TODO: to support multiple tables
+        # 1. split ids per table
+        # 2. unique and sort ids per table
+        # 3. linearize ids per table
+        # 4. query weights with get_cuda
+        # 5. return sorted ids as the weight_id tensor, and queried weights as the weight tensor
+        # once ids come from the embedding backend directly, we need to:
+        # 1. sort ids
+        # 2. split ids per table based on table fusion
+        # 3. query weights with get_cuda using the split ids per table
+        # 4. subtract the table offset from the split ids
+        ids = self._tracked_ids.values().long().cpu()
+        sorted_id = torch.unique(ids, sorted=True).view(-1, 1)
+        # to test size-mismatch handling: keep only the first half of the ids
+        sorted_id, _ = torch.chunk(sorted_id, 2, dim=0)
+
+        bucket_offset = self._sharded_local_buckets[0][0]
+        bucket_length = self._sharded_local_buckets[0][1]
+        bucket_size = self._sharded_local_buckets[0][2]
+
+        def get_bucket_tensor(
+            ids, bucket_offset, bucket_length, bucket_size
+        ) -> torch.Tensor:
+            # Step 1: Compute bucket index for each id
+            ids = ids.flatten()
+            bucket_ids = ids // bucket_size
+
+            # Step 2: Verify bucket range
+            min_bucket = bucket_offset
+            max_bucket = bucket_offset + bucket_length
+            assert torch.all(
+                (bucket_ids >= min_bucket) & (bucket_ids < max_bucket)
+            ), f"Some IDs fall outside the expected bucket range [{min_bucket}, {max_bucket})"
+
+            # Step 3: Normalize bucket indices to 0-based range
+            norm_bucket_ids = bucket_ids - bucket_offset
+
+            # Step 4: Count occurrences
+            counts = torch.bincount(norm_bucket_ids, minlength=bucket_length)
+
+            # Step 5: Return as 2D tensor [bucket_length, 1]
+            return counts.view(-1, 1)
+
+        for i, _ in enumerate(self._config.embedding_tables):
+            bucket_tensor = get_bucket_tensor(
+                sorted_id, bucket_offset, bucket_length, bucket_size
+            )
+
+            # get weight tensor from tracked global id
+            weight_tensor = torch.empty(
+                (sorted_id.size(0), self.emb_module.max_D),
+                dtype=dtype,
+            )
+            # this call currently throws; comment it out to unblock downstream work
+            self.emb_module.ssd_db.get_cuda(
+                sorted_id.to(torch.int64),
+                weight_tensor,
+                torch.as_tensor(sorted_id.size(0)),
+            )
+            splits.append((weight_tensor, bucket_tensor, sorted_id))
+        return splits
+
+    def forward(self, features: KeyedJaggedTensor) -> torch.Tensor:
+        if self._enable_zero_collision_tbe:
+            # track the last accessed ids for testing purposes
+            self._tracked_ids = KeyedJaggedTensor.from_lengths_sync(
+                keys=features.keys().copy(),
+                values=features.values().clone(),
+                lengths=features.lengths().clone(),
+            )
+            # reset split weights during training
+            self._split_weights = None
+
+        return self.emb_module(
+            indices=features.values().long(),
+            offsets=features.offsets().long(),
+        )
+
 
 class BatchedFusedEmbedding(BaseBatchedEmbedding[torch.Tensor], FusedOptimizerModule):
     def __init__(
diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 4f79bd6b9..d989456c9 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -259,6 +259,8 @@ def create_sharding_infos_by_sharding(
                         embedding_names=embedding_names,
                         weight_init_max=config.weight_init_max,
                         weight_init_min=config.weight_init_min,
+                        total_num_buckets=config.total_num_buckets,
+                        zero_collision=config.zero_collision,
                     ),
                     param_sharding=parameter_sharding,
                     param=param,
@@ -353,6 +355,8 @@ def create_sharding_infos_by_sharding_device_group(
                         embedding_names=embedding_names,
                         weight_init_max=config.weight_init_max,
                         weight_init_min=config.weight_init_min,
+                        total_num_buckets=config.total_num_buckets,
+                        zero_collision=config.zero_collision,
                     ),
                     param_sharding=parameter_sharding,
                     param=param,
@@ -767,11 +771,13 @@ def _initialize_torch_state(self) -> None:  # noqa
             )
 
         self._name_to_table_size = {}
+        table_zero_collision = {}
         for table in self._embedding_configs:
             self._name_to_table_size[table.name] = (
                 table.num_embeddings,
                 table.embedding_dim,
             )
+            table_zero_collision[table.name] = table.zero_collision
 
         for sharding_type, lookup in zip(
             self._sharding_type_to_sharding.keys(), self._lookups
@@ -871,8 +877,9 @@ def _initialize_torch_state(self) -> None:  # noqa
                 # created ShardedTensors once in init, use in post_state_dict_hook
                 # note: at this point kvstore backed tensors don't own valid snapshots, so no read
                 # access is allowed on them.
+                # for collision-free TBE, the shard sizes should be recalculated during ShardedTensor initialization
                 self._model_parallel_name_to_sharded_tensor[table_name] = (
-                    ShardedTensor._init_from_local_shards(
+                    ShardedTensor._init_from_local_shards_and_reset_offsets(
                         local_shards,
                         self._name_to_table_size[table_name],
                         process_group=(
@@ -925,20 +932,29 @@ def post_state_dict_hook(
                 return
 
             sharded_kvtensors_copy = copy.deepcopy(sharded_kvtensors)
+            sharded_id_buckets_state_dict = None
             for lookup, sharding_type in zip(
                 module._lookups, module._sharding_type_to_sharding.keys()
             ):
                 if sharding_type != ShardingType.DATA_PARALLEL.value:
-                    # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
-                    for key, v in lookup.get_named_split_embedding_weights_snapshot():
-                        assert key in sharded_kvtensors_copy
-                        sharded_kvtensors_copy[key].local_shards()[0].tensor = v
+                    for (
+                        key,
+                        v,
+                    ) in lookup.get_named_split_embedding_weights_snapshot():  # pyre-fixme[29]: `Union[Module, Tensor]` is not a function.
+                        if key in sharded_kvtensors_copy:
+                            sharded_kvtensors_copy[key].local_shards()[0].tensor = v
+                        else:
+                            d_k = f"{prefix}embeddings.{key}"
+                            destination[d_k] = v
+                            logger.info(f"add sharded tensor key {d_k} to state dict")
             for (
                 table_name,
                 sharded_kvtensor,
             ) in sharded_kvtensors_copy.items():
                 destination_key = f"{prefix}embeddings.{table_name}.weight"
                 destination[destination_key] = sharded_kvtensor
+            if sharded_id_buckets_state_dict:
+                destination.update(sharded_id_buckets_state_dict)
 
         self.register_state_dict_pre_hook(self._pre_state_dict_hook)
         self._register_state_dict_hook(post_state_dict_hook)
diff --git a/torchrec/distributed/embedding_kernel.py b/torchrec/distributed/embedding_kernel.py
index f3bb60619..7ba953d1e 100644
--- a/torchrec/distributed/embedding_kernel.py
+++ b/torchrec/distributed/embedding_kernel.py
@@ -8,6 +8,7 @@
 # pyre-strict
 
 import abc
+import copy
 import logging
 from collections import defaultdict, OrderedDict
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -103,9 +104,10 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
             qbias = param[2]
             param = param[0]
 
-        assert embedding_table.local_rows == param.size(  # pyre-ignore[16]
-            0
-        ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}"  # pyre-ignore[16]
+        if not embedding_table.zero_collision:
+            assert embedding_table.local_rows == param.size(  # pyre-ignore[16]
+                0
+            ), f"{embedding_table.local_rows=}, {param.size(0)=}, {param.shape=}"  # pyre-ignore[16]
 
         if qscale is not None:
             assert embedding_table.local_cols == param.size(1)  # pyre-ignore[16]
@@ -128,14 +130,17 @@ def get_key_from_embedding_table(embedding_table: ShardedEmbeddingTable) -> str:
                 param.requires_grad  # pyre-ignore[16]
             )
             key_to_global_metadata[key] = embedding_table.global_metadata
+            local_metadata = copy.deepcopy(embedding_table.local_metadata)
+            local_metadata.shard_sizes = list(param.size())
 
             key_to_local_shards[key].append(
                 # pyre-fixme[6]: For 1st argument expected `Tensor` but got
                 #  `Union[Module, Tensor]`.
                 # pyre-fixme[6]: For 2nd argument expected `ShardMetadata` but got
                 #  `Optional[ShardMetadata]`.
-                Shard(param, embedding_table.local_metadata)
+                Shard(param, local_metadata)
             )
+
         else:
             destination[key] = param
             if qscale is not None:
diff --git a/torchrec/distributed/embedding_sharding.py b/torchrec/distributed/embedding_sharding.py
index 98fa2d15f..59ebf472d 100644
--- a/torchrec/distributed/embedding_sharding.py
+++ b/torchrec/distributed/embedding_sharding.py
@@ -483,6 +483,16 @@ def _prefetch_and_cached(
     )
 
 
+def _is_kv_tbe(
+    table: ShardedEmbeddingTable,
+) -> bool:
+    """
+    Returns True if this embedding table enables bucketized sharding for KV-style TBE to support ZCH v.Next.
+    https://docs.google.com/document/d/13atWlDEkrkRulgC_gaoLv8ZsogQefdvsTdwlyam7ed0/edit?tab=t.0#heading=h.lxb1lainm4tc
+    """
+    return table.zero_collision
+
+
 def _all_tables_are_quant_kernel(
     tables: List[ShardedEmbeddingTable],
 ) -> bool:
@@ -558,7 +568,9 @@ def _group_tables_per_rank(
                     table.data_type,
                 ),
                 _prefetch_and_cached(table),
+                _is_kv_tbe(table),
             )
+            print(f"line debug: grouping_key: {grouping_key}")
             # micromanage the order of we traverse the groups to ensure backwards compatibility
             if grouping_key not in groups:
                 grouping_keys.append(grouping_key)
@@ -573,6 +585,7 @@ def _group_tables_per_rank(
                 compute_kernel_type,
                 _,
                 _,
+                is_kv_tbe,
             ) = grouping_key
             grouped_tables = groups[grouping_key]
             # remove non-native fused params
@@ -581,6 +594,8 @@ def _group_tables_per_rank(
                 for k, v in fused_params_tuple
                 if k not in ["_batch_key", USE_ONE_TBE_PER_TABLE]
             }
+            if is_kv_tbe:
+                per_tbe_fused_params["enable_zero_collision_tbe"] = True
             cache_load_factor = _get_weighted_avg_cache_load_factor(grouped_tables)
             if cache_load_factor is not None:
                 per_tbe_fused_params[CACHE_LOAD_FACTOR_STR] = cache_load_factor
diff --git a/torchrec/distributed/embedding_types.py b/torchrec/distributed/embedding_types.py
index 1e155b8ad..903615ca9 100644
--- a/torchrec/distributed/embedding_types.py
+++ b/torchrec/distributed/embedding_types.py
@@ -231,6 +231,21 @@ def feature_hash_sizes(self) -> List[int]:
             feature_hash_sizes.extend(table.num_features() * [table.num_embeddings])
         return feature_hash_sizes
 
+    def feature_total_num_buckets(self) -> Optional[List[int]]:
+        feature_total_num_buckets = []
+        for table in self.embedding_tables:
+            if table.total_num_buckets:
+                feature_total_num_buckets.extend(
+                    table.num_features() * [table.total_num_buckets]
+                )
+        return feature_total_num_buckets if len(feature_total_num_buckets) > 0 else None
+
+    def _is_zero_collision(self) -> bool:
+        for table in self.embedding_tables:
+            if table.zero_collision:
+                return True
+        return False
+
     def num_features(self) -> int:
         num_features = 0
         for table in self.embedding_tables:
diff --git a/torchrec/distributed/quant_embedding_kernel.py b/torchrec/distributed/quant_embedding_kernel.py
index cc324d52a..8ff887f98 100644
--- a/torchrec/distributed/quant_embedding_kernel.py
+++ b/torchrec/distributed/quant_embedding_kernel.py
@@ -495,6 +495,9 @@ def __init__(
         )
         if device is not None:
             self._emb_module.initialize_weights()
+        self._enable_zero_collision_tbe: bool = any(
+            table.zero_collision for table in config.embedding_tables
+        )
 
     @property
     def emb_module(
diff --git a/torchrec/distributed/sharding/rw_sequence_sharding.py b/torchrec/distributed/sharding/rw_sequence_sharding.py
index 4029d9aa6..8a550328a 100644
--- a/torchrec/distributed/sharding/rw_sequence_sharding.py
+++ b/torchrec/distributed/sharding/rw_sequence_sharding.py
@@ -126,6 +126,11 @@ def create_input_dist(
     ) -> BaseSparseFeaturesDist[KeyedJaggedTensor]:
         num_features = self._get_num_features()
         feature_hash_sizes = self._get_feature_hash_sizes()
+        is_zero_collision = any(
+            emb_config._is_zero_collision()
+            for emb_config in self._grouped_embedding_configs
+        )
+        feature_total_num_buckets = self._get_feature_total_num_buckets()
         return RwSparseFeaturesDist(
             # pyre-fixme[6]: For 1st param expected `ProcessGroup` but got
             #  `Optional[ProcessGroup]`.
@@ -136,6 +141,8 @@ def create_input_dist(
             is_sequence=True,
             has_feature_processor=self._has_feature_processor,
             need_pos=False,
+            feature_total_num_buckets=feature_total_num_buckets,
+            keep_original_indices=is_zero_collision,
         )
 
     def create_lookup(
@@ -265,6 +272,10 @@ def create_input_dist(
         (emb_sharding, is_even_sharding) = get_embedding_shard_metadata(
             self._grouped_embedding_configs_per_rank
         )
+        is_zero_collision = any(
+            emb_config._is_zero_collision()
+            for emb_config in self._grouped_embedding_configs
+        )
 
         return InferRwSparseFeaturesDist(
             world_size=self._world_size,
@@ -275,6 +286,7 @@ def create_input_dist(
             has_feature_processor=self._has_feature_processor,
             need_pos=False,
             embedding_shard_metadata=emb_sharding if not is_even_sharding else None,
+            keep_original_indices=is_zero_collision,
         )
 
     def create_lookup(
diff --git a/torchrec/distributed/sharding/rw_sharding.py b/torchrec/distributed/sharding/rw_sharding.py
index b62609da1..cc01b9ba1 100644
--- a/torchrec/distributed/sharding/rw_sharding.py
+++ b/torchrec/distributed/sharding/rw_sharding.py
@@ -56,6 +56,7 @@
     ShardingType,
     ShardMetadata,
 )
+from torchrec.distributed.utils import none_throws
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
 from torchrec.streamable import Multistreamable
 
@@ -137,6 +138,9 @@ def __init__(
             []
         )
         self._grouped_embedding_configs_per_rank = group_tables(sharded_tables_per_rank)
+        logger.info(
+            f"self._grouped_embedding_configs_per_rank: {self._grouped_embedding_configs_per_rank}"
+        )
         self._grouped_embedding_configs: List[GroupedEmbeddingConfig] = (
             self._grouped_embedding_configs_per_rank[self._rank]
         )
@@ -146,6 +150,9 @@ def __init__(
             if group_config.has_feature_processor:
                 self._has_feature_processor = True
 
+        for group_config in self._grouped_embedding_configs:
+            group_config.feature_names
+
     def _shard(
         self,
         sharding_infos: List[EmbeddingShardingInfo],
@@ -217,6 +224,8 @@ def _shard(
                         weight_init_min=info.embedding_config.weight_init_min,
                         fused_params=info.fused_params,
                         num_embeddings_post_pruning=info.embedding_config.num_embeddings_post_pruning,
+                        total_num_buckets=info.embedding_config.total_num_buckets,
+                        zero_collision=info.embedding_config.zero_collision,
                     )
                 )
         return tables_per_rank
@@ -272,6 +281,23 @@ def _get_feature_hash_sizes(self) -> List[int]:
             feature_hash_sizes.extend(group_config.feature_hash_sizes())
         return feature_hash_sizes
 
+    def _get_feature_total_num_buckets(self) -> Optional[List[int]]:
+        feature_total_num_buckets: List[int] = []
+        for group_config in self._grouped_embedding_configs:
+            if group_config.feature_total_num_buckets() is not None:
+                feature_total_num_buckets.extend(
+                    none_throws(group_config.feature_total_num_buckets())
+                )
+        return (
+            feature_total_num_buckets if len(feature_total_num_buckets) > 0 else None
+        )  # If no feature_total_num_buckets is provided, we return None to keep backward compatibility.
+
+    def _is_zero_collision(self) -> bool:
+        for group_config in self._grouped_embedding_configs:
+            if group_config._is_zero_collision():
+                return True
+        return False
+
 
 class RwSparseFeaturesDist(BaseSparseFeaturesDist[KeyedJaggedTensor]):
     """
diff --git a/torchrec/modules/embedding_configs.py b/torchrec/modules/embedding_configs.py
index b665257a8..82c739afd 100644
--- a/torchrec/modules/embedding_configs.py
+++ b/torchrec/modules/embedding_configs.py
@@ -195,6 +195,8 @@ class BaseEmbeddingConfig:
 
     # handle the special case
     input_dim: Optional[int] = None
+    total_num_buckets: Optional[int] = None
+    zero_collision: bool = False
 
     def get_weight_init_max(self) -> float:
         if self.weight_init_max is None:

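Note (sketch, not part of the patch): for a zero-collision table,
get_named_split_embedding_weights_snapshot() now yields three ShardedTensor entries per table
instead of a single PartiallyMaterializedTensor. A sketch of the entries yielded for a
hypothetical table "table_0" with num_embeddings=4096, embedding_dim=64 and
total_num_buckets=64 (before the sharded module adds its own key prefix):

    # key                   global size   local shard contents
    expected_entries = {
        "table_0.weight":    (4096, 64),  # weights queried for the ids tracked on this rank
        "table_0.weight_id": (4096, 1),   # sorted global ids backing each local weight row
        "table_0.bucket":    (64, 1),     # per-bucket counts of the locally owned ids
    }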
From 10b4d75dbf7e69d29ec53a5c39500f9958ec9f3c Mon Sep 17 00:00:00 2001
From: Faran Ahmad <faran95@meta.com>
Date: Tue, 15 Apr 2025 06:17:01 -0700
Subject: [PATCH 2/2] Bucket offsets and sizes in torchrec shard metadata for
 bucket wise sharding (#2885)

Summary:
X-link: https://github.com/pytorch/pytorch/pull/151192

Pull Request resolved: https://github.com/pytorch/torchrec/pull/2885

Pull Request resolved: https://github.com/pytorch/torchrec/pull/2884

Bucket offsets and sizes in torchrec shard metadata for bucket wise sharding for ZCH v.Next

Differential Revision: D72921209
---
 torchrec/distributed/sharding_plan.py         |  35 ++-
 .../distributed/tests/test_sharding_plan.py   | 224 +++++++++++++++++-
 torchrec/distributed/types.py                 |  39 ++-
 3 files changed, 292 insertions(+), 6 deletions(-)

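Note (sketch, not part of the patch): the (num_buckets, bucket_id_offset) pair attached to each
ShardMetadata by row_wise(..., num_buckets_per_rank=...) is a running prefix sum over
num_buckets_per_rank. A minimal sketch of that bookkeeping, with the illustrative input
[20, 30, 40]:

    def bucket_offset_sizes(num_buckets_per_rank):
        out, offset = [], 0
        for num_buckets in num_buckets_per_rank:
            out.append((num_buckets, offset))  # (num_buckets, bucket_id_offset) per rank
            offset += num_buckets
        return out

    # bucket_offset_sizes([20, 30, 40]) -> [(20, 0), (30, 20), (40, 50)],
    # matching the expected ShardMetadata in test_row_wise_bucket_level_sharding below.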
diff --git a/torchrec/distributed/sharding_plan.py b/torchrec/distributed/sharding_plan.py
index 27b011300..e0374b534 100644
--- a/torchrec/distributed/sharding_plan.py
+++ b/torchrec/distributed/sharding_plan.py
@@ -361,6 +361,7 @@ def _get_parameter_sharding(
     sharder: ModuleSharder[nn.Module],
     placements: Optional[List[str]] = None,
     compute_kernel: Optional[str] = None,
+    bucket_offset_sizes: Optional[List[Tuple[int, int]]] = None,
 ) -> ParameterSharding:
     return ParameterSharding(
         sharding_spec=(
@@ -371,6 +372,8 @@ def _get_parameter_sharding(
                     ShardMetadata(
                         shard_sizes=size,
                         shard_offsets=offset,
+                        bucket_id_offset=bucket_id_offset,
+                        num_buckets=num_buckets,
                         placement=(
                             placement(
                                 device_type,
@@ -381,9 +384,17 @@ def _get_parameter_sharding(
                             else device_placement
                         ),
                     )
-                    for (size, offset, rank), device_placement in zip(
+                    for (size, offset, rank), device_placement, (
+                        num_buckets,
+                        bucket_id_offset,
+                    ) in zip(
                         size_offset_ranks,
                         placements if placements else [None] * len(size_offset_ranks),
+                        (
+                            bucket_offset_sizes
+                            if bucket_offset_sizes
+                            else [(None, None)] * len(size_offset_ranks)
+                        ),
                     )
                 ]
             )
@@ -512,7 +523,8 @@ def _parameter_sharding_generator(
 
 
 def row_wise(
-    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None
+    sizes_placement: Optional[Tuple[List[int], Union[str, List[str]]]] = None,
+    num_buckets_per_rank: Optional[List[int]] = None,  # propagate num buckets per rank
 ) -> ParameterShardingGenerator:
     """
     Returns a generator of ParameterShardingPlan for `ShardingType::ROW_WISE` for construct_module_sharding_plan.
@@ -545,6 +557,7 @@ def _parameter_sharding_generator(
         device_type: str,
         sharder: ModuleSharder[nn.Module],
     ) -> ParameterSharding:
+        bucket_offset_sizes = None
         if sizes_placement is None:
             size_and_offsets = _get_parameter_size_offsets(
                 param,
@@ -558,17 +571,34 @@ def _parameter_sharding_generator(
                 size_offset_ranks.append((size, offset, rank))
         else:
             size_offset_ranks = []
+            bucket_offset_sizes = None if num_buckets_per_rank is None else []
             sizes = sizes_placement[0]
+            if num_buckets_per_rank is not None:
+                assert len(sizes) == len(
+                    num_buckets_per_rank
+                ), f"sizes and num_buckets_per_rank must have the same length during row_wise sharding, got {len(sizes)} and {len(num_buckets_per_rank)} respectively"
             (rows, cols) = param.shape
             cur_offset = 0
             prev_offset = 0
+            prev_bucket_offset = 0
+            cur_bucket_offset = 0
             for rank, size in enumerate(sizes):
                 per_rank_row = size
+                per_rank_bucket_size = None
+                if num_buckets_per_rank is not None:
+                    per_rank_bucket_size = num_buckets_per_rank[rank]
+                    cur_bucket_offset += per_rank_bucket_size
                 cur_offset += per_rank_row
                 cur_offset = min(cur_offset, rows)
                 per_rank_row = cur_offset - prev_offset
                 size_offset_ranks.append(([per_rank_row, cols], [prev_offset, 0], rank))
                 prev_offset = cur_offset
+                if num_buckets_per_rank is not None:
+                    # bucket has only one col for now
+                    none_throws(bucket_offset_sizes).append(
+                        (per_rank_bucket_size, prev_bucket_offset)
+                    )
+                    prev_bucket_offset = cur_bucket_offset
 
             if cur_offset < rows:
                 raise ValueError(
@@ -601,6 +631,7 @@ def _parameter_sharding_generator(
             compute_kernel=(
                 EmbeddingComputeKernel.QUANT.value if sizes_placement else None
             ),
+            bucket_offset_sizes=bucket_offset_sizes,
         )
 
     return _parameter_sharding_generator
diff --git a/torchrec/distributed/tests/test_sharding_plan.py b/torchrec/distributed/tests/test_sharding_plan.py
index 5dc18885a..f6c1c568a 100644
--- a/torchrec/distributed/tests/test_sharding_plan.py
+++ b/torchrec/distributed/tests/test_sharding_plan.py
@@ -816,6 +816,159 @@ def test_row_wise_set_heterogenous_device(self, data_type: DataType) -> None:
                 0,
             )
 
+    # pyre-fixme[56]
+    @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
+    @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
+    def test_row_wise_bucket_level_sharding(self, data_type: DataType) -> None:
+
+        embedding_config = [
+            EmbeddingBagConfig(
+                name=f"table_{idx}",
+                feature_names=[f"feature_{idx}"],
+                embedding_dim=64,
+                num_embeddings=4096,
+                data_type=data_type,
+            )
+            for idx in range(2)
+        ]
+        module_sharding_plan = construct_module_sharding_plan(
+            EmbeddingCollection(tables=embedding_config),
+            per_param_sharding={
+                "table_0": row_wise(
+                    sizes_placement=(
+                        [2048, 1024, 1024],
+                        ["cpu", "cuda", "cuda"],
+                    ),
+                    num_buckets_per_rank=[20, 30, 40],
+                ),
+                "table_1": row_wise(
+                    sizes_placement=([2048, 1024, 1024], ["cpu", "cpu", "cpu"])
+                ),
+            },
+            local_size=1,
+            world_size=2,
+            device_type="cuda",
+        )
+
+        # Make sure the per_param_sharding settings override the default device_type
+        device_table_0_shard_0 = (
+            # pyre-ignore[16]
+            module_sharding_plan["table_0"]
+            .sharding_spec.shards[0]
+            .placement
+        )
+        self.assertEqual(
+            device_table_0_shard_0.device().type,
+            "cpu",
+        )
+        # cpu always has rank 0
+        self.assertEqual(
+            device_table_0_shard_0.rank(),
+            0,
+        )
+        for i in range(1, 3):
+            device_table_0_shard_i = (
+                module_sharding_plan["table_0"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_0_shard_i.device().type,
+                "cuda",
+            )
+            # first rank is assigned to cpu so index = rank - 1
+            self.assertEqual(
+                device_table_0_shard_i.device().index,
+                i - 1,
+            )
+            self.assertEqual(
+                device_table_0_shard_i.rank(),
+                i,
+            )
+        for i in range(3):
+            device_table_1_shard_i = (
+                module_sharding_plan["table_1"].sharding_spec.shards[i].placement
+            )
+            self.assertEqual(
+                device_table_1_shard_i.device().type,
+                "cpu",
+            )
+            # cpu always has rank 0
+            self.assertEqual(
+                device_table_1_shard_i.rank(),
+                0,
+            )
+
+        expected = {
+            "table_0": ParameterSharding(
+                sharding_type="row_wise",
+                compute_kernel="quant",
+                ranks=[
+                    0,
+                    1,
+                    2,
+                ],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[2048, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=0,
+                            num_buckets=20,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[2048, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:1/cuda:0",
+                            bucket_id_offset=20,
+                            num_buckets=30,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[3072, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:2/cuda:1",
+                            bucket_id_offset=50,
+                            num_buckets=40,
+                        ),
+                    ]
+                ),
+            ),
+            "table_1": ParameterSharding(
+                sharding_type="row_wise",
+                compute_kernel="quant",
+                ranks=[
+                    0,
+                    1,
+                    2,
+                ],
+                sharding_spec=EnumerableShardingSpec(
+                    shards=[
+                        ShardMetadata(
+                            shard_offsets=[0, 0],
+                            shard_sizes=[2048, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[2048, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                        ShardMetadata(
+                            shard_offsets=[3072, 0],
+                            shard_sizes=[1024, 64],
+                            placement="rank:0/cpu",
+                            bucket_id_offset=None,
+                            num_buckets=None,
+                        ),
+                    ]
+                ),
+            ),
+        }
+        self.assertDictEqual(expected, module_sharding_plan)
+
     # pyre-fixme[56]
     @given(data_type=st.sampled_from([DataType.FP32, DataType.FP16]))
     @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None)
@@ -929,18 +1082,85 @@ def test_str(self) -> None:
         )
         expected = """module: ebc
 
- param   | sharding type | compute kernel | ranks
+ param   | sharding type | compute kernel | ranks  
 -------- | ------------- | -------------- | ------
 user_id  | table_wise    | dense          | [0]
 movie_id | row_wise      | dense          | [0, 1]
 
- param   | shard offsets | shard sizes |   placement
+ param   | shard offsets | shard sizes |   placement  
 -------- | ------------- | ----------- | -------------
 user_id  | [0, 0]        | [4096, 32]  | rank:0/cuda:0
 movie_id | [0, 0]        | [2048, 32]  | rank:0/cuda:0
 movie_id | [2048, 0]     | [2048, 32]  | rank:0/cuda:1
+"""
+        for i in range(len(expected.splitlines())):
+            self.assertEqual(
+                expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
+            )
+
+    def test_str_bucket_wise_sharding(self) -> None:
+        plan = ShardingPlan(
+            {
+                "ebc": EmbeddingModuleShardingPlan(
+                    {
+                        "user_id": ParameterSharding(
+                            sharding_type="table_wise",
+                            compute_kernel="dense",
+                            ranks=[0],
+                            sharding_spec=EnumerableShardingSpec(
+                                [
+                                    ShardMetadata(
+                                        shard_offsets=[0, 0],
+                                        shard_sizes=[4096, 32],
+                                        placement="rank:0/cuda:0",
+                                    ),
+                                ]
+                            ),
+                        ),
+                        "movie_id": ParameterSharding(
+                            sharding_type="row_wise",
+                            compute_kernel="dense",
+                            ranks=[0, 1],
+                            sharding_spec=EnumerableShardingSpec(
+                                [
+                                    ShardMetadata(
+                                        shard_offsets=[0, 0],
+                                        shard_sizes=[2048, 32],
+                                        placement="rank:0/cuda:0",
+                                        bucket_id_offset=0,
+                                        num_buckets=20,
+                                    ),
+                                    ShardMetadata(
+                                        shard_offsets=[2048, 0],
+                                        shard_sizes=[2048, 32],
+                                        placement="rank:0/cuda:1",
+                                        bucket_id_offset=20,
+                                        num_buckets=30,
+                                    ),
+                                ]
+                            ),
+                        ),
+                    }
+                )
+            }
+        )
+        expected = """module: ebc
+        
+ param   | sharding type | compute kernel | ranks  
+-------- | ------------- | -------------- | ------
+user_id  | table_wise    | dense          | [0]
+movie_id | row_wise      | dense          | [0, 1]
+
+ param   | shard offsets | shard sizes |   placement   | bucket id offset | num buckets
+-------- | ------------- | ----------- | ------------- | ---------------- | -----------
+user_id  | [0, 0]        | [4096, 32]  | rank:0/cuda:0 | None             | None       
+movie_id | [0, 0]        | [2048, 32]  | rank:0/cuda:0 | 0                | 20       
+movie_id | [2048, 0]     | [2048, 32]  | rank:0/cuda:1 | 20               | 30       
 """
         self.maxDiff = None
+        print("STR PLAN BUCKET WISE")
+        print(str(plan))
+        print("=======")
         for i in range(len(expected.splitlines())):
             self.assertEqual(
                 expected.splitlines()[i].strip(), str(plan).splitlines()[i].strip()
diff --git a/torchrec/distributed/types.py b/torchrec/distributed/types.py
index 1d45bff69..2be9a9c86 100644
--- a/torchrec/distributed/types.py
+++ b/torchrec/distributed/types.py
@@ -737,6 +737,7 @@ def __str__(self) -> str:
         out = ""
         param_table = []
         shard_table = []
+        contains_bucket_wise_shards = False
         for param_name, param_sharding in self.items():
             param_table.append(
                 [
@@ -749,20 +750,54 @@ def __str__(self) -> str:
             if isinstance(param_sharding.sharding_spec, EnumerableShardingSpec):
                 shards = param_sharding.sharding_spec.shards
                 if shards is not None:
+                    param_sharding_contains_bucket_info = any(
+                        shard.bucket_id_offset is not None for shard in shards
+                    )
+                    if param_sharding_contains_bucket_info:
+                        contains_bucket_wise_shards = True
                     for shard in shards:
-                        shard_table.append(
+                        cols = (
                             [
                                 param_name,
                                 shard.shard_offsets,
                                 shard.shard_sizes,
                                 shard.placement,
                             ]
+                            if not param_sharding_contains_bucket_info
+                            else [
+                                param_name,
+                                shard.shard_offsets,
+                                shard.shard_sizes,
+                                shard.placement,
+                                shard.bucket_id_offset,
+                                shard.num_buckets,
+                            ]
                         )
+                        shard_table.append(cols)
+        if contains_bucket_wise_shards:
+            for i in range(len(shard_table)):
+                if len(shard_table[i]) == 4:
+                    # add None for the tables that don't have bucket info
+                    shard_table[i].append(None)
+                    shard_table[i].append(None)
         out += "\n\n" + _tabulate(
             param_table, ["param", "sharding type", "compute kernel", "ranks"]
         )
+        column_str = (
+            ["param", "shard offsets", "shard sizes", "placement"]
+            if not contains_bucket_wise_shards
+            else [
+                "param",
+                "shard offsets",
+                "shard sizes",
+                "placement",
+                "bucket id offset",
+                "num buckets",
+            ]
+        )
         out += "\n\n" + _tabulate(
-            shard_table, ["param", "shard offsets", "shard sizes", "placement"]
+            shard_table,
+            column_str,
         )
         return out
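
Note (sketch, not part of the patch): str(plan) only switches to the six-column shard table when
at least one shard carries bucket metadata, and rows from non-bucketed params are padded with
None. A minimal sketch of that padding rule, with hypothetical row data:

    # 4-column rows come from shards without bucket info, 6-column rows from bucketed shards
    rows = [
        ["user_id",  [0, 0],    [4096, 32], "rank:0/cuda:0"],
        ["movie_id", [0, 0],    [2048, 32], "rank:0/cuda:0", 0, 20],
        ["movie_id", [2048, 0], [2048, 32], "rank:0/cuda:1", 20, 30],
    ]
    if any(len(r) == 6 for r in rows):
        rows = [r + [None, None] if len(r) == 4 else r for r in rows]
    # all rows now have 6 columns, matching the bucket-aware header used in __str__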