- Takes input tokens of `shape=(batch, seq)`
- Outputs embeddings of `shape=(batch, seq, hidden_dim)`

Now that we've sharded the embedding weight tensor, the layer will actually output:

- Sharded output embeddings of `shape=(batch, seq, hidden_dim / mesh["tp"].size())`.
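
To see that sharded output shape concretely, here is a minimal, self-contained sketch (a stand-in, not this guide's model): it shards a bare `nn.Embedding` with `ColwiseParallel` over an assumed 4-GPU `tp` mesh and prints the per-rank output shape.

```python
# Minimal sketch -- assumes 4 GPUs, launched with `torchrun --nproc-per-node=4 demo.py`.
import os
import torch
import torch.nn as nn
import torch.distributed.tensor.parallel as tp
from torch.distributed.device_mesh import init_device_mesh

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))

torch.manual_seed(0)  # identical weights and tokens on every rank
embed = nn.Embedding(32_000, 1024, device="cuda")  # made-up vocab / hidden sizes

# ColwiseParallel shards nn.Embedding's weight along the embedding dimension,
# so each rank holds a (32_000, 1024 / 4) slice of it.
tp.parallelize_module(embed, mesh["tp"], tp.ColwiseParallel())

tokens = torch.randint(0, 32_000, (8, 128), device="cuda")  # (batch, seq)
out = embed(tokens)
# By default ColwiseParallel hands back the local shard, so each rank sees
# (8, 128, 256) == (batch, seq, hidden_dim / mesh["tp"].size()).
print(out.shape)
```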

We have a problem though: the *colwise* pieces of our `self_attn` module will receive the output of this embedding, and ColwiseParallel actually expects its input to be **replicated**, not sharded.

So we need to do an all-gather on the tensor to replicate it across the group (i.e. it will be back to `shape=(batch, seq, hidden_dim)`). Luckily, we can specify this additional transformation with the `output_layouts` argument:

```python
tp.parallelize_module(
    # ... (full plan omitted here; it sets `output_layouts=Replicate()` where needed, see the sketch below)
)
```

We have to include `Replicate()` here because our loss expects replicated tensors, but ColwiseParallel by default shards its output on the last dimension.
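
Putting the pieces together, here is a hedged, self-contained sketch of a plan that uses `output_layouts=Replicate()` in both of the places discussed above. The `TinyLM` module, its layer names, and the 4-GPU mesh are illustrative stand-ins, not the model or plan used in this guide.

```python
# Sketch only -- run with `torchrun --nproc-per-node=4 plan_demo.py`.
import os
import torch
import torch.nn as nn
import torch.distributed.tensor.parallel as tp
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate

class TinyLM(nn.Module):
    """Toy stand-in with just enough structure to show the plan's shape."""
    def __init__(self, vocab=1000, hidden=64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, hidden)
        self.up_proj = nn.Linear(hidden, 4 * hidden)
        self.down_proj = nn.Linear(4 * hidden, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

    def forward(self, tokens):
        h = self.embed_tokens(tokens)
        h = self.down_proj(torch.relu(self.up_proj(h)))
        return self.lm_head(h)

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
mesh = init_device_mesh("cuda", (4,), mesh_dim_names=("tp",))
torch.manual_seed(0)  # identical init on every rank
model = TinyLM().cuda()

plan = {
    # The embedding weight is sharded colwise; `output_layouts=Replicate()`
    # all-gathers its activations back to (batch, seq, hidden) so the colwise
    # layers downstream get the replicated input they expect.
    "embed_tokens": tp.ColwiseParallel(output_layouts=Replicate()),
    # The usual colwise-in / rowwise-out pairing inside the block.
    "up_proj": tp.ColwiseParallel(),
    "down_proj": tp.RowwiseParallel(),
    # Colwise shards the logits on the last (vocab) dimension by default, but
    # the loss wants replicated logits, hence the explicit Replicate() here too.
    "lm_head": tp.ColwiseParallel(output_layouts=Replicate()),
}
tp.parallelize_module(model, mesh["tp"], plan)

tokens = torch.randint(0, 1000, (2, 16), device="cuda")
print(model(tokens).shape)  # (2, 16, 1000) on every rank: replicated logits
```
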
## Parallelizing Norm Layers with SequenceParallel