Commit a76cc4b

Simplify log_softmax_and_gather operation (#1124)
1 parent 6b18946 commit a76cc4b

1 file changed: 7 additions, 6 deletions

open_instruct/model_utils.py

@@ -507,18 +507,19 @@ def save_with_accelerate(
     # customize model card (TODO (Costa): this can be prettier)
 
 
-@torch.compile(dynamic=True)
 def log_softmax_and_gather(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
     """
-    torch compiled version of the common `log_softmax -> gather` operation.
+    A more memory efficient version of the common `log_softmax -> gather` operation used in
+    post-training algorithms.
 
-    The compiled version of this opration avoids the (significant) memory overhead of
-    allocating a new (batch_size, seq_len, vocab_size) tensor to store the logprobs.
+    Using the negative cross entropy loss is equivalent to the log_softmax -> gather operation,
+    but is more memory efficient since it doesn't require allocating a new
+    (batch_size, seq_len, vocab_size) tensor to store the logprobs.
 
     See https://github.com/allenai/open-instruct/pull/584
     """
-    logprobs = logits.log_softmax(dim=-1)
-    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
+    B, T, V = logits.shape
+    return -torch.nn.functional.cross_entropy(logits.view(-1, V), index.view(-1), reduction="none").view(B, T)
 
 
 @retry_on_exception()
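For context, the two implementations compute the same per-token log-probabilities. Below is a minimal sanity-check sketch (not part of the commit; the helper names old_impl/new_impl and the shapes are illustrative) comparing the removed and added versions on random inputs:

import torch

def old_impl(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # Removed version: materializes a full (B, T, V) logprobs tensor.
    logprobs = logits.log_softmax(dim=-1)
    return torch.gather(logprobs, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)

def new_impl(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
    # New version: per-token cross entropy is -log_softmax(logits)[index],
    # so negating it recovers the gathered log-probabilities.
    B, T, V = logits.shape
    return -torch.nn.functional.cross_entropy(
        logits.view(-1, V), index.view(-1), reduction="none"
    ).view(B, T)

logits = torch.randn(2, 5, 11)        # (batch_size, seq_len, vocab_size)
index = torch.randint(0, 11, (2, 5))  # token ids to gather
assert torch.allclose(old_impl(logits, index), new_impl(logits, index), atol=1e-6)

As the new docstring notes, the saving comes from the cross-entropy path returning only a (batch_size, seq_len) result instead of keeping a (batch_size, seq_len, vocab_size) logprobs tensor around for the gather.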
