@@ -50,7 +50,9 @@ def __init__(
         self._group_size = tensor_space.distributed_config.tensor_parallel
         self._sequence_parallel = tensor_space.distributed_config.sequence_tensor_parallel
         self._parallel_embeddings = tensor_space.distributed_config.tensor_parallel > 1 and config.parallel_embeddings
-        self._sequence_parallel_logits = self._sequence_parallel and not self._parallel_embeddings
+        self._sequence_parallel_logits = (
+            tensor_space.distributed_config.sequence_tensor_parallel and not config.parallel_embeddings
+        )
         self._cross_entropy_splits = config.cross_entropy_splits
         if self._cross_entropy_splits is not None and self._sequence_parallel:
             assert not self._parallel_embeddings
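The rewritten flag reads the raw config values rather than the derived self._parallel_embeddings, which only changes behaviour when tensor_parallel == 1. A minimal sketch of the difference, using plain booleans in place of the real config attributes (hypothetical values, not taken from the commit):

    # Hypothetical single-rank tensor-parallel setup with parallel embeddings requested.
    tensor_parallel = 1
    sequence_tensor_parallel = True
    config_parallel_embeddings = True

    # Derived flag, as computed a few lines above in the diff.
    parallel_embeddings = tensor_parallel > 1 and config_parallel_embeddings  # False

    old_logits_flag = sequence_tensor_parallel and not parallel_embeddings         # True
    new_logits_flag = sequence_tensor_parallel and not config_parallel_embeddings  # False

    print(old_logits_flag, new_logits_flag)  # the expressions diverge only when tensor_parallel == 1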
@@ -67,7 +69,7 @@ def __init__(
         # >0: multi-token prediction (MTP)
         Assert.geq(prediction_distance, 0)
         self._prediction_distance = prediction_distance
-        self.is_last_head = self._prediction_distance == config.prediction_heads - 1
+        self._is_last_head = self._prediction_distance == config.prediction_heads - 1

         self._init_output_weights(hidden_dim, config)

@@ -114,7 +116,7 @@ def forward(
                 tensor_name="Loss",
                 reductions=((DistributedDimNames.data, ReduceOp.AVG),),  # noqa
             )
-        if not self.is_last_head:
+        if not self._is_last_head:
             # MTP: split the stacked input
             shared_hidden, input_ = torch.unbind(input_, dim=0)
         # TODO: Pytorch copies the grads in backward for no reason (not sure if still the case)
@@ -123,10 +125,10 @@ def forward(
         # TODO: Drop autograd entirely.
         # TODO: Skip cross-entropy backward if not needed.
         language_model_loss = self._forward(input_, kwargs, losses)
-        if language_model_loss is not None:
+        if losses is not None and language_model_loss is not None:
             losses[self._loss_name].append(language_model_loss)
         # TODO: Return the model output when needed.
-        if self.is_last_head:
+        if self._is_last_head:
             # Last head should return the loss for backward.
             return language_model_loss
         else:
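The added losses is not None check covers the call path where forward runs without a loss dictionary (e.g. inference), in which case indexing losses would fail. A small sketch of the failure it avoids, with hypothetical stand-in values:

    # Hypothetical stand-ins: no losses dict is passed, but a loss value exists.
    losses = None
    loss_name = "language_model_loss"
    language_model_loss = 0.5

    # With the old guard, `language_model_loss is not None` alone would fall through
    # to `losses[loss_name]` and raise TypeError on the None dict.
    if losses is not None and language_model_loss is not None:
        losses[loss_name].append(language_model_loss)  # safely skipped here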
@@ -147,14 +149,13 @@ def _forward_backward(
         if target is not None:
             if self._config.distillation_model is None:
                 # MTP: Shift the labels
-                target = (
-                    target[self._prediction_distance : self._prediction_distance + input_.size(0),]
-                    if kwargs[TransformerKwargs.sequence_first]
-                    else target[
-                        :,
-                        self._prediction_distance : self._prediction_distance + input_.size(1),
-                    ]
+                target_sequence_length = (
+                    target.size(1 - kwargs[TransformerKwargs.sequence_first]) + 1 - self._config.prediction_heads
                 )
+                if TransformerKwargs.sequence_q_dim in kwargs:
+                    Assert.eq(target_sequence_length, kwargs[TransformerKwargs.sequence_q_dim].size)
+                target_slice = slice(self._prediction_distance, self._prediction_distance + target_sequence_length)
+                target = target[target_slice] if kwargs[TransformerKwargs.sequence_first] else target[:, target_slice]
                 target = target.flatten()
             else:
                 # Target is reference model logits.
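The new slicing derives each head's target window from the extended label length instead of input_.size(...), so every prediction head sees the same window length shifted by its distance. A worked example with hypothetical sizes (batch-first targets and three prediction heads):

    import torch

    # Hypothetical sizes, not from the commit: labels carry prediction_heads - 1
    # extra tokens so that every head has a full-length, shifted target.
    prediction_heads = 3
    sequence_first = False                   # the bool doubles as an index: 1 - False -> dim 1
    target = torch.arange(14).reshape(2, 7)  # (batch=2, sequence=5 + prediction_heads - 1)

    target_sequence_length = target.size(1 - sequence_first) + 1 - prediction_heads  # 5

    for prediction_distance in range(prediction_heads):
        target_slice = slice(prediction_distance, prediction_distance + target_sequence_length)
        head_target = target[target_slice] if sequence_first else target[:, target_slice]
        print(prediction_distance, head_target[0].tolist())  # head k sees tokens k..k+4 of each row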