Commit 53697ea

Update ch 6
1 parent a6abf9a commit 53697ea

File tree: 1 file changed, +7 -7 lines


06-tensor-parallel/README.md

@@ -127,18 +127,18 @@ The embeddings weight get's sharded along dimension 1. Meaning each GPU holds a
 
 | Embedding Weight Shape | Sharded Shape |
 | --- | --- |
-| (vocab_size, hidden_dim) | (vocab_size, hidden_dim / mesh["tp"].size()) |
+| `(vocab_size, hidden_dim)` | `(vocab_size, hidden_dim / mesh["tp"].size())` |
 
 In a normal embedding layer it:
-- Input tokens of `shape=(batch, seq), dtype=torch.long`
-- Outputs embeddings of `shape=(batch, seq, hidden_dim), dtype=torch.bfloat16`
+- Takes input tokens of `shape=(batch, seq)`
+- Outputs embeddings of `shape=(batch, seq, hidden_dim)`
 
-Now that we've sharded it it will actually output:
+Now that we've sharded the embedding weight tensor, the layer will actually output:
 - Sharded output embeddings of `shape=(batch, seq, hidden_dim / mesh["tp"].size())`.
 
-We have a problem though: Our `self_attn` module will first receive the output of this embedding, and we used ColwiseParallel on that. ColwiseParallel actually expects input to be **replicated** not shared.
+We have a problem though: Our *colwise* pieces of the `self_attn` module will receive the output of this module. ColwiseParallel actually expects input to be **replicated** not sharded.
 
-So we need to transform the tensor to `shape=(batch, seq, hidden_dim)`. Luckily we can just specify this additional transformation with the `output_layouts` argument:
+So we need to do an allgather on the tensor to replicate it across the group (i.e. it will be back to `shape=(batch, seq, hidden_dim)`). Luckily we can just specify this additional transformation with the `output_layouts` argument:
 
 ```python
 tp.parallelize_module(
@@ -162,7 +162,7 @@ tp.parallelize_module(
 )
 ```
 
-We have to include `Replicate()` here because by default colwise shards on the last dimension, but we need the output of the network to be replicated across our TP dimension.
+We have to include `Replicate()` here because our loss expects replicated tensors, but colwise by default shards on the last dimension.
 
 ## Parallelizing Norm Layers with SequenceParallel
 

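For readers skimming the commit, here is a minimal sketch of the pattern the changed lines describe. It is not the chapter's actual code (the body of the `tp.parallelize_module` call is elided between the two hunks above), it assumes a recent PyTorch with the public `torch.distributed.tensor` API, and the model class and submodule names (`TinyLM`, `embed_tokens`, `lm_head`) are illustrative assumptions: both the embedding and the output projection are parallelized with `ColwiseParallel`, and `output_layouts=Replicate()` requests the allgather back to replicated tensors.

```python
# A sketch under assumptions, not the chapter's code. Assumes launch via
# `torchrun` on GPUs; TinyLM, embed_tokens and lm_head are illustrative names.
import os

import torch
import torch.distributed as dist
import torch.distributed.tensor.parallel as tp
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Replicate
from torch.distributed.tensor.parallel import ColwiseParallel

dist.init_process_group("nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
tp_mesh = init_device_mesh("cuda", (dist.get_world_size(),), mesh_dim_names=("tp",))

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int = 128, hidden_dim: int = 64):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_dim)
        self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)

model = TinyLM().cuda()

tp.parallelize_module(
    model,
    tp_mesh,
    {
        # ColwiseParallel shards the embedding weight on dim 1, so its raw
        # output would be Shard(-1); Replicate() allgathers it back to a full
        # (batch, seq, hidden_dim) tensor for the colwise self_attn inputs.
        "embed_tokens": ColwiseParallel(output_layouts=Replicate()),
        # Colwise also shards its output on the last dim by default; the loss
        # expects replicated logits, so allgather here as well.
        "lm_head": ColwiseParallel(output_layouts=Replicate()),
    },
)
```

After the call, `embed_tokens.weight` is a `DTensor` placed as `Shard(1)` on the TP mesh, so each rank's local shard has shape `(vocab_size, hidden_dim / tp_mesh.size())`, matching the table in the diff.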
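A quick check of the two claims in the changed lines, continuing the sketch above under the same assumptions: the local weight shard loses a factor of `tp_mesh.size()` on the hidden dimension, while the embedding output comes back fully replicated because of `output_layouts=Replicate()`.

```python
# Continues the sketch above (same assumed names and toy sizes 128 / 64).
tokens = torch.randint(0, 128, (2, 16), device="cuda")   # shape=(batch, seq)

hidden = model.embed_tokens(tokens)
# Allgathered back to the full hidden size, as the diff describes.
assert hidden.shape == (2, 16, 64)

# The parameter is a DTensor sharded on dim 1, so the local shard matches the
# table: (vocab_size, hidden_dim / tp_mesh.size()).
local_weight = model.embed_tokens.weight.to_local()
assert local_weight.shape == (128, 64 // tp_mesh.size())
```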