
Reasons for upcasting the logits dtype outside the kernel #241

Open
yzhangcs opened this issue Sep 10, 2024 · 7 comments

Comments

@yzhangcs

Hello, thank you for this great work.

logits_chunk = logits_chunk.float()

# gradient of logits_chunk is computed in-place by the above triton kernel.
# Following HuggingFace model source code, we do the forward and backward
# w.r.t. logits in fp32 for numerical stability especially as the num classes (vocab size) is huge.
# (reference: https://github.com/huggingface/transformers/blob/v4.42.4/src/transformers/models/llama/modeling_llama.py#L1194)
# Propagating to lm_head's backward, we'll switch back to the original dtype.
logits_chunk = logits_chunk.to(dtype)

I'm wondering if there are any reasons for upcasting/downcasting the logits dtype outside the kernel?
If I understand correctly, we already do the fp32 upcast inside the kernel, so this op is redundant?
I just compared the outputs of the two versions, i.e., w/ and w/o the upcast, and found there's no precision loss if the above code is removed.
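To illustrate what I mean, here is a minimal sketch (not the actual Liger kernel; the log-sum-exp example and all names are just for illustration) of upcasting at load time inside the Triton kernel itself, which is why the wrapper-level .float()/.to(dtype) round trip looks unnecessary:

import torch
import triton
import triton.language as tl

@triton.jit
def logsumexp_kernel(logits_ptr, out_ptr, V: tl.constexpr, BLOCK: tl.constexpr):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK)
    mask = offs < V
    # load bf16 logits and immediately upcast to fp32 for the reduction
    x = tl.load(logits_ptr + row * V + offs, mask=mask, other=-float("inf")).to(tl.float32)
    m = tl.max(x, 0)
    lse = m + tl.log(tl.sum(tl.exp(x - m), 0))
    # the reduction result stays in fp32; no fp32 copy of the logits is ever materialized
    tl.store(out_ptr + row, lse)

logits = torch.randn(8, 1000, device="cuda", dtype=torch.bfloat16)
out = torch.empty(8, device="cuda", dtype=torch.float32)
logsumexp_kernel[(8,)](logits, out, V=1000, BLOCK=1024)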

@ByronHsu
Collaborator

ByronHsu commented Sep 10, 2024

Thanks! We did the casting to stay consistent with HuggingFace's behavior. But yes, I think we can do it inside. There is a PR doing this: #238.

found there's no precision loss if the above code is removed.

Curious, how did you measure the loss?

Also, very impressive background. Contributions are welcome :-)

@yzhangcs
Author

yzhangcs commented Sep 10, 2024

@ByronHsu

Curious, how did you measure the loss?

import torch
import torch.nn as nn
import torch.nn.functional as F
# FusedLinearCrossEntropyLoss is the fused (chunked) implementation under test
# (its import was omitted in the original snippet).

torch.manual_seed(42)
batch_size, seq_len, hidden_size, vocab_size = 8, 4096, 2048, 128000
x = torch.randn(batch_size * seq_len, hidden_size).cuda().bfloat16().requires_grad_()
target = torch.randint(0, vocab_size, (batch_size * seq_len,)).cuda()
weight = torch.randn(vocab_size, hidden_size).cuda().bfloat16().requires_grad_()
bias = torch.randn(vocab_size).cuda().bfloat16().requires_grad_()

# reference: unfused linear + cross entropy with logits upcast to fp32
logits = F.linear(x, weight, bias).float()

output1 = nn.CrossEntropyLoss()(logits, target)
do = torch.randn_like(output1).cuda().bfloat16()

output1.backward(do)
ref_dx, x.grad = x.grad.clone(), None
ref_dw, weight.grad = weight.grad.clone(), None
ref_db, bias.grad = bias.grad.clone(), None

# fused implementation
output2 = FusedLinearCrossEntropyLoss()(x, target, weight, bias)
output2.backward(do)
tri_dx, x.grad = x.grad.clone(), None
tri_dw, weight.grad = weight.grad.clone(), None
tri_db, bias.grad = bias.grad.clone(), None
# print('\n\n', output1, )
# print(output2, '\n\n',)
print(" o", torch.abs(output1 - output2).max())
print("dx", torch.abs(ref_dx - tri_dx).max())
print("dw", torch.abs(ref_dw - tri_dw).max())
print("db", torch.abs(ref_db - tri_db).max())

Very simple testing code. Both versions (w/ and w/o the extra cast) give exactly the same abs errors:

 o tensor(0.0005, device='cuda:0', grad_fn=<MaxBackward1>)
dx tensor(2.3842e-07, device='cuda:0', dtype=torch.bfloat16)
dw tensor(4.7684e-07, device='cuda:0', dtype=torch.bfloat16)
db tensor(1.1921e-07, device='cuda:0', dtype=torch.bfloat16)

@ByronHsu
Collaborator

 o tensor(0.0005, device='cuda:0', grad_fn=<MaxBackward1>)

The o difference is a bit high?

@yzhangcs
Author

yzhangcs commented Sep 10, 2024

@ByronHsu Those are the diffs of the current impls. I think this precision loss makes sense given that the vocab is large and the first output is computed in fp32 while the second is in bf16.
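As a rough back-of-envelope check (assuming the loss sits near its random-init value): with random weights and random targets the cross-entropy is roughly $\ln(V) = \ln(128000) \approx 11.76$, so an absolute diff of $5\times10^{-4}$ corresponds to a relative error of about $4\times10^{-5}$, which seems plausible for a 128K-way softmax reduced in bf16 vs fp32.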

@yzhangcs
Author

@ByronHsu Hi, have you compared the final loss of FLCE with its naive counterpart?
I think chunking the input into several pieces might be problematic for very large V and relatively small H.
For example, if H=1024 and V=128*1024, the inv_factor would be 128. Since the accumulated dw is maintained in bf16/fp16, I'm uncertain whether this could cause any issues.

It might be better to limit the maximum number of chunks. For instance, we could set inv_factor to min(8, triton.cdiv(V, H)), e.g. as in the sketch below.
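A rough sketch of the cap I have in mind (the helper name num_chunks is just for illustration, not something from the Liger codebase):

import triton

def num_chunks(hidden_size: int, vocab_size: int, max_chunks: int = 8) -> int:
    # cap the V/H-derived chunk count so dw is accumulated over fewer partial
    # sums when V >> H (e.g. H=1024, V=128*1024 would otherwise give 128 chunks)
    return min(max_chunks, triton.cdiv(vocab_size, hidden_size))

print(num_chunks(1024, 128 * 1024))  # 8 instead of 128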

@yzhangcs
Author

yzhangcs commented Sep 14, 2024

I've also found that for a 128K vocab, 8 chunks can be faster, at the cost of less than 1 GiB of additional memory.
These are my benchmarks on an H100 using the script at https://github.com/sustcsonglin/flash-linear-attention/blob/main/benchmarks/benchmark_training_throughput.py:

| V | Configuration | K Tokens / sec | GiB |
|------|-----------------|----------------|-------|
| 32K | w/ chunk (V/H) | 46.18 | 35.09 |
| 32K | w/ chunk (4) | 46.55 | 35.09 |
| 32K | w/ chunk (8) | 46.52 | 35.10 |
| 32K | w/o chunk | 46.13 | 39.21 |
| 128K | w/ chunk (V/H) | 40.80 | 37.66 |
| 128K | w/ chunk (4) | 41.82 | 39.03 |
| 128K | w/ chunk (8) | 41.80 | 38.05 |
| 128K | w/o chunk | 41.12 | 56.05 |

@yzhangcs yzhangcs reopened this Sep 14, 2024
@yzhangcs
Author

yzhangcs commented Sep 14, 2024

UPDATE:

Just trained three 370M models on 10B tokens of FineWeb-Edu with 8K ctx length and a 32K vocab.
Below are the results. Surprisingly, 8 chunks exhibits the best ppl, while V/H = 32K/1K = 32 chunks performs worse than the others.

| Configuration | PPL $_\downarrow$ | Throughput |
|-----------------|-------------------|------------|
| w/ chunk (V/H) | 12.73 | 109.22 |
| w/ chunk (8) | 12.70 | 109.54 |
| w/o chunk | 12.71 | 109.27 |

Regarding throughput, 8 chunks is also the fastest.
