diff --git a/fast_llm/models/gpt/model.py b/fast_llm/models/gpt/model.py
index bb2f4ed5..d29fc28d 100644
--- a/fast_llm/models/gpt/model.py
+++ b/fast_llm/models/gpt/model.py
@@ -249,21 +249,28 @@ def preprocess(
                 TransformerKwargs.presents: presents,
             }
             if phase != PhaseType.inference:
+                sequence_offset = sequence_k - sequence_q + 1
                 if sequence_first:
-                    labels = batch.token_ids[sequence_k - sequence_q + 1 : sequence_k + 1]
+                    labels = batch.token_ids[sequence_offset : sequence_k + 1]
                 else:
                     # TODO: Avoid multiple contiguous calls?
                     labels = batch.token_ids[:, sequence_k - sequence_q + 1 : sequence_k + 1].contiguous()
                 # We set label indices to -100 for masked spans, inline with ignore_index in torch.nn.CrossEntropyLoss
                 # TODO: take ignore_index from config
-                if batch.loss_masking_spans is not None:
-                    for i, spans_i in enumerate(batch.loss_masking_spans):
-                        mask_indices = (
-                            torch.cat([torch.arange(s - 1, e) for s, e in spans_i])
-                            if len(spans_i)
-                            else torch.tensor([], dtype=torch.int64)
-                        )
-                        labels[i, mask_indices] = -100
+                if batch.loss_masking_spans is not None:
+                    for i, spans in enumerate(batch.loss_masking_spans):
+                        if not spans.numel():
+                            continue
+                        valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)]
+                        if valid_spans.numel():
+                            valid_spans[:, 0].clamp_(min=sequence_offset)
+                            valid_spans[:, 1].clamp_(max=sequence_k)
+                            valid_spans -= sequence_offset
+                            for start, end in valid_spans:
+                                if sequence_first:
+                                    labels[start : end + 1, i] = -100
+                                else:
+                                    labels[i, start : end + 1] = -100
                 kwargs[LanguageModelKwargs.labels] = labels
             if self._config.use_absolute_position_embeddings:
                 self._position_embedding_preprocessor.preprocess(kwargs)
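
The new masking branch can be exercised in isolation. Below is a minimal, self-contained sketch of the clamping logic for a single sample, assuming the diff's convention of inclusive, 0-indexed [start, end] spans expressed in full-sequence token coordinates; mask_labels and the example values are illustrative only, not part of the change.

import torch


def mask_labels(labels: torch.Tensor, spans: torch.Tensor, sequence_offset: int, sequence_k: int) -> torch.Tensor:
    # Set every label covered by a span to -100, the default ignore_index of
    # torch.nn.CrossEntropyLoss. `labels` is the label window for one sample,
    # covering token positions [sequence_offset, sequence_k] of the full sequence;
    # `spans` is an (n, 2) tensor of inclusive [start, end] positions.
    if not spans.numel():
        return labels
    # Keep only spans that overlap the visible label window.
    valid_spans = spans[(spans[:, 0] <= sequence_k) & (spans[:, 1] >= sequence_offset)]
    if valid_spans.numel():
        # Clip overlapping spans to the window, then shift into label coordinates.
        valid_spans[:, 0].clamp_(min=sequence_offset)
        valid_spans[:, 1].clamp_(max=sequence_k)
        valid_spans -= sequence_offset
        for start, end in valid_spans:
            labels[start : end + 1] = -100
    return labels


# One span falls entirely before the window (dropped); one straddles its left
# edge (clipped to the window before masking).
labels = torch.arange(4, 12)
spans = torch.tensor([[0, 2], [3, 5]])
print(mask_labels(labels, spans, sequence_offset=4, sequence_k=11))
# tensor([-100, -100,    6,    7,    8,    9,   10,   11])

Note that boolean indexing produces a copy, so the in-place clamp_ calls on valid_spans never mutate the batch's span tensor; the same property makes the diff's clamping safe across micro-sequence steps.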