
v0.6.0

@henrylhtsang henrylhtsang released this 30 Jan 23:40
7c266e6

VBE

TorchRec now natively supports VBE (variable batched embeddings) within the EmbeddingBagCollection module. This allows a variable batch size per feature, unlocking sparse input data deduplication, which can greatly speed up embedding lookup and all-to-all time. To enable it, initialize KeyedJaggedTensor with the stride_per_key_per_rank and inverse_indices fields, which specify the batch size per feature and the inverse indices to reindex the embedding output, respectively.
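
A minimal sketch of constructing such a VBE input is shown below; the feature names, values, lengths, and batch sizes are illustrative only.

```python
# Hedged sketch: a KeyedJaggedTensor carrying VBE metadata for one rank.
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two features with different (deduplicated) batch sizes:
# "f1" has 2 deduplicated samples, "f2" has 3, while the global batch size is 3.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6]),
    lengths=torch.tensor([1, 2, 1, 1, 2]),   # 2 lengths for f1, 3 for f2
    stride_per_key_per_rank=[[2], [3]],      # batch size per feature, per rank
    inverse_indices=(                        # maps each deduplicated row back to
        ["f1", "f2"],                        # the global batch of size 3
        torch.tensor([[0, 1, 0], [0, 1, 2]]),
    ),
)
# kjt can then be fed to a sharded EmbeddingBagCollection as usual.
```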

Embedding offloading

Embedding offloading is UVM caching (i.e. storing embedding tables in host memory with a cache in HBM) plus prefetching and optimal sizing of the cache. Embedding offloading allows running larger models with fewer GPUs while maintaining competitive performance. To use it, apply the prefetching pipeline (PrefetchTrainPipelineSparseDist) and pass in the per-table cache load factor and the prefetch_pipeline flag through constraints in the planner.
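
A minimal sketch of the planner setup, assuming the ParameterConstraints/CacheParams API and import paths of this release; the table name and numbers are illustrative.

```python
# Hedged sketch: per-table cache load factor and prefetch_pipeline via planner constraints.
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.planner.types import ParameterConstraints
from torchrec.distributed.types import CacheParams

constraints = {
    "my_table": ParameterConstraints(
        compute_kernels=[EmbeddingComputeKernel.FUSED_UVM_CACHING.value],
        cache_params=CacheParams(
            load_factor=0.2,         # per-table cache load factor
            prefetch_pipeline=True,  # enable prefetching for this table
        ),
    ),
}

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
    constraints=constraints,
)

# After sharding the model with the resulting plan, train with the prefetching
# pipeline (import path may differ across versions):
# from torchrec.distributed.train_pipeline import PrefetchTrainPipelineSparseDist
# pipeline = PrefetchTrainPipelineSparseDist(model, optimizer, device)
```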

Trec.shard/shard_modules

These APIs replace embedding submodules with their sharded variants. The shard API applies to an individual embedding module, while the shard_modules API replaces all embedding modules and leaves non-embedding submodules untouched.
Embedding sharding follows behavior similar to the prior TorchRec DistributedModelParallel behavior, except that the ShardedModules have been made composable, meaning the modules are backed by TableBatchedEmbeddingSlices, which are views into the underlying TBE (including .grad). This means that fused parameters are now returned by named_parameters(), including in DistributedModelParallel.
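
A minimal sketch of shard_modules on a toy model, assuming an initialized process group and this release's import paths; the model and table definitions are illustrative.

```python
# Hedged sketch: sharding only the embedding submodules of a model in place.
import torch
import torchrec
from torchrec.distributed.shard import shard_modules

# Requires an initialized process group, e.g. when launched with torchrun:
# torch.distributed.init_process_group(backend="nccl")

ebc = torchrec.EmbeddingBagCollection(
    device=torch.device("meta"),
    tables=[
        torchrec.EmbeddingBagConfig(
            name="t1",
            embedding_dim=16,
            num_embeddings=1000,
            feature_names=["f1"],
            pooling=torchrec.PoolingType.SUM,
        ),
    ],
)

class MyModel(torch.nn.Module):
    def __init__(self, ebc: torchrec.EmbeddingBagCollection) -> None:
        super().__init__()
        self.ebc = ebc
        self.dense = torch.nn.Linear(16, 1)

model = MyModel(ebc)

# shard_modules replaces model.ebc with its sharded variant and leaves
# model.dense untouched; shard() would instead be applied to model.ebc alone.
sharded_model = shard_modules(model, device=torch.device("cuda"))

# Because sharded modules are composable, fused parameters appear in
# named_parameters() as TableBatchedEmbeddingSlices views into the TBE.
for name, param in sharded_model.named_parameters():
    print(name, param.shape)
```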