No “emergent behavior / aha moment” when retraining GPT-2 on FineWeb; warmup / “warm training” guidance requested #889
Replies: 7 comments 4 replies
(Screenshot: TensorBoard metrics showing the training loss, validation loss, learning rate, and tokens seen.)
(Screenshot: sample responses from the model after a certain number of steps, completing a sentence from the given start context.)
Thanks for sharing this very interesting discussion! Regarding your points:

A. This looks like a reasonable mod. I have a suggestion below regarding QK-Norm that might additionally help.

B. I would stick with it for now, but maybe in the next run you could print the large-gradient/high-loss samples to investigate further.

C. I am not sure if you'd see it with this small model, but it should match the published GPT-2, I'd say.

D. Usually this is done with re-warming (a minimal sketch follows right after these points). I briefly wrote about it here [https://magazine.sebastianraschka.com/p/tips-for-llm-pretraining-and-evaluating-rms] based on the "Simple and Scalable Strategies to Continually Pre-train Large Language Models" paper.

E. Yes, it would be reasonable to expect the same loss. What I would do is take a sample from a news article that wasn't in the training data, calculate the loss or perplexity for the base GPT-2 127M, and then do this periodically for your trained model. Use something that we know could not have been in the training data, for example a paragraph from a WSJ article published today.
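To make the re-warming in D a bit more concrete, here is a minimal sketch of resuming from a checkpoint with a short linear re-warming phase followed by cosine decay. The peak learning rate, warmup length, and total step count are illustrative assumptions, not recommended values:

```python
import math
import torch

# Illustrative settings (assumptions; tune for your run):
peak_lr = 1e-4          # often lower than the original pretraining peak LR
min_lr = 1e-5
warmup_steps = 1000     # short linear re-warming phase
total_steps = 50_000    # length of the continued-pretraining run

def rewarm_cosine_lr(step):
    # Linear re-warming from ~0 up to peak_lr ...
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    # ... then cosine decay from peak_lr down to min_lr
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))

# Usage inside the training loop, after loading the checkpoint:
# optimizer = torch.optim.AdamW(model.parameters(), lr=peak_lr, weight_decay=0.1)
# for step, (inputs, targets) in enumerate(train_loader):
#     lr = rewarm_cosine_lr(step)
#     for param_group in optimizer.param_groups:
#         param_group["lr"] = lr
#     ...
```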
Evaluating on such held-out text would maybe help to more fairly compare the two models to each other. PS: As reference points, I am getting numbers for GPT-2 127M, gpt2-medium (355M), and Qwen 0.6B Base.
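As a rough illustration of this held-out perplexity check, here is a minimal sketch using the Hugging Face transformers library. The model IDs and the sample text are placeholders; you would swap in your own checkpoint and a fresh news paragraph:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def heldout_perplexity(model_id, text):
    # Tokenize the held-out text and compute the average next-token loss
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id)
    model.eval()

    input_ids = tokenizer(text, return_tensors="pt").input_ids
    with torch.no_grad():
        # Passing labels makes the model return the mean cross-entropy loss
        loss = model(input_ids, labels=input_ids).loss
    return loss.item(), torch.exp(loss).item()

# Placeholder text; use a paragraph that is guaranteed to postdate the training data
sample = "Paste a paragraph from today's news article here."
for model_id in ["gpt2", "gpt2-medium"]:  # add your own model here as well
    loss, ppl = heldout_perplexity(model_id, sample)
    print(f"{model_id}: loss={loss:.2f}, ppl={ppl:.1f}")
```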
I have a few questions too if you don't mind:
Regarding tips: I agree that these spikes could come from the data. But that being said, there are a few improvements I would try. The important thing, if you have the budget and time, is to try one thing at a time so you can see where the differences come from. Some suggestions are:

F. I would probably remove dropout (or set it to 0); in my experience it doesn't help and may make things even worse.

G. Another one would be to add QK-Norm like in Qwen3 here: https://github.com/rasbt/LLMs-from-scratch/blob/main/ch05/11_qwen3/standalone-qwen3.ipynb

```python
class MultiHeadAttention(nn.Module):
    def __init__(...):
        ...
        if qk_norm:
            self.q_norm = LayerNorm(head_dim)
            self.k_norm = LayerNorm(head_dim)

    def forward(...):
        # ...
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)
        attn_scores = queries @ keys.transpose(2, 3)
        # ...
```

Usually, QK-Norm is nowadays implemented with RMSNorm, but for consistency, I would try LayerNorm first. Maybe it gets rid of the spikes. (A more complete, runnable sketch of this is included below.)

H. I would also be curious how Qwen3 0.6B performs in terms of smoothness, if it is not too large to run. You could technically shrink it by reducing the number of layers. The code in that notebook should work as a drop-in replacement for the GPT model.

If all of that doesn't work, it might be related to the dataset or to the optimizer and learning rate schedule. But out of curiosity, I would try the things above first. I'd be curious what the results are.
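For reference, here is a more complete, self-contained sketch of the QK-Norm variant from point G. It is modeled loosely on a book-style multi-head attention module; the argument names and the exact placement of the LayerNorm calls are assumptions rather than the notebook's verbatim code:

```python
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads,
                 qkv_bias=False, qk_norm=True):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # QK-Norm: normalize queries and keys per head before the dot product
        self.q_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = nn.LayerNorm(self.head_dim) if qk_norm else nn.Identity()

        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, _ = x.shape

        # Split projections into (batch, heads, tokens, head_dim)
        queries = self.W_query(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        keys = self.W_key(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.W_value(x).view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        # Apply QK-Norm before computing attention scores
        queries = self.q_norm(queries)
        keys = self.k_norm(keys)

        attn_scores = queries @ keys.transpose(2, 3)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], float("-inf"))
        attn_weights = torch.softmax(attn_scores / self.head_dim**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context = (attn_weights @ values).transpose(1, 2).reshape(b, num_tokens, -1)
        return self.out_proj(context)
```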
As for the training speed of your code on an H100: I used your code from Chapter 4, and it reaches nearly 400,000 tokens/s while fully utilizing the H100's compute and memory. A step takes less than 0.5 s, so training on 3.2B tokens took about 3.2B / 400,000 = 8,000 s ≈ 2.2 hours. There is definitely room for improvement, but currently I am working on ensuring the correctness of the model.
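Spelling out that throughput arithmetic as a quick check (just the numbers from this comment):

```python
tokens_total = 3.2e9        # tokens in the run
tokens_per_sec = 400_000    # observed throughput on the H100

seconds = tokens_total / tokens_per_sec
print(f"{seconds:,.0f} s ~ {seconds / 3600:.1f} h")  # 8,000 s ~ 2.2 h
```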
Here is the full code I used: https://github.com/talentJay-ux/LLMs-from-scratch/blob/b66c1c9c74a2f06bc612054d030bff0093b693d8/ch05/10_llm-training-speed/03_train_from_scratch.py. Once I am able to fix the model-collapse issue, I will definitely try your other recommendations, such as using different norms and position embeddings!
If I may add my 2 cents: besides the QK-Norm that Sebastian mentioned, you can maybe find some ideas to stabilize your training in this paper: https://arxiv.org/abs/2410.16682. I'm also interested to know what Sebastian thinks about applying L2 weight decay to LayerNorm weights; I had the impression it wasn't a good idea. Qwen, for their Qwen3-Next, used RMSNorm with "zero-centered weights" to better adapt it for L2. By the way, since Sebastian mentioned optimizers, a nice variant of Muon that Moonshot used to train their awesome Kimi K2, and that could help here, is MuonClip (Muon plus QK-weight rescaling based on the max attention logits seen; it can be implemented separately, even with Adam). A rough sketch of the QK rescaling part is below. In any case, good luck with your training, it's a good hands-on project 👍
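To illustrate the QK rescaling part only, here is a rough sketch. It is my paraphrase of the idea, not Moonshot's implementation, and the module and attribute names follow the attention sketch earlier in this thread rather than any real API:

```python
import torch

@torch.no_grad()
def qk_clip_(attn_module, max_logit, tau=100.0):
    """Rescale the query/key projection weights of one attention module
    if the largest pre-softmax attention logit seen in the last step
    exceeded the threshold tau.

    attn_module is assumed to expose W_query / W_key linear layers, as in
    the MultiHeadAttention sketch above; max_logit is the largest attention
    score recorded during the forward pass.
    """
    if max_logit > tau:
        # Splitting the correction between W_q and W_k scales q @ k by tau / max_logit
        scale = (tau / max_logit) ** 0.5
        attn_module.W_query.weight.mul_(scale)
        attn_module.W_key.weight.mul_(scale)

# Usage sketch: record the max attention logit per block during forward(),
# then call qk_clip_ after optimizer.step():
# for block, logit in zip(model.trf_blocks, max_logits_seen):
#     qk_clip_(block.att, logit)
```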
Summary
I re-trained a GPT-2-style model from random parameters using the HuggingFaceFW/FineWeb dataset. Training and validation loss plateau around ~4 and don't exhibit a sudden drop, and evaluation generations repeat heavily. I'm looking for guidance on "warm training" parameters and optimization suggestions for training the model from scratch.
Environment & Model

Data
- HuggingFaceFW/fineweb (streaming)
- language_score ≥ 0.9
- <|endoftext|> inserted between documents

Dataloader (key settings)
- num_workers=4, pin_memory=True
- val_mod=100, fixed eval loaders for stable loss

Run stats (this run)
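For reference, a hedged sketch of what the streaming setup described under Data and Dataloader above might look like. The batch size, context length, and helper names are illustrative assumptions rather than the exact script, and per-worker sharding is omitted for brevity:

```python
import tiktoken
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader, IterableDataset

tokenizer = tiktoken.get_encoding("gpt2")
EOT = "<|endoftext|>"

class FineWebStream(IterableDataset):
    """Streams FineWeb, filters by language_score, and yields fixed-length
    (input, target) blocks with <|endoftext|> between documents."""

    def __init__(self, context_length=1024, min_language_score=0.9):
        self.context_length = context_length
        self.ds = load_dataset("HuggingFaceFW/fineweb", split="train", streaming=True)
        self.ds = self.ds.filter(lambda ex: ex["language_score"] >= min_language_score)

    def __iter__(self):
        buffer = []
        for example in self.ds:
            # Append the document plus an <|endoftext|> separator
            buffer += tokenizer.encode(example["text"] + EOT, allowed_special={EOT})
            # Emit (input, target) pairs shifted by one token
            while len(buffer) > self.context_length:
                chunk = buffer[: self.context_length + 1]
                yield torch.tensor(chunk[:-1]), torch.tensor(chunk[1:])
                buffer = buffer[self.context_length :]

# Note: with num_workers > 0, each worker iterates its own copy of the stream
# unless sharding is added; this is omitted here to keep the sketch short.
train_loader = DataLoader(
    FineWebStream(), batch_size=16, num_workers=4, pin_memory=True
)
```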
What I observe
Reproduction (minimal)
(Happy to post a full runnable script if helpful.)
Questions / requests for guidance
A. Gradient normalization
I still couldn't avoid the sudden loss spikes. Do you recommend other methods? For example, I could try to skip the update entirely if the training loss is too large (see the sketch below).
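A minimal sketch of what I mean by skipping spiky updates. The loss threshold and the gradient-clipping value are illustrative, and model, optimizer, and train_loader are assumed to be set up already:

```python
import torch

SKIP_LOSS_THRESHOLD = 7.0   # illustrative; e.g., a few nats above the recent average
MAX_GRAD_NORM = 1.0

for step, (inputs, targets) in enumerate(train_loader):
    logits = model(inputs)
    loss = torch.nn.functional.cross_entropy(
        logits.flatten(0, 1), targets.flatten()
    )

    # Skip the whole update if this batch looks like an outlier
    if loss.item() > SKIP_LOSS_THRESHOLD:
        optimizer.zero_grad(set_to_none=True)
        continue

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), MAX_GRAD_NORM)
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
```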
B. Data
C. “Emergence” expectations
D. Warm training suggestions
I would pick up one of the pretrained models and train on top of it. Do you have learning rate suggestions when doing warm training?
E. Cross-entropy loss
What would be a good expectation for the loss? Random baseline: ≈ ln(50304) ≈ 10.83 nats. GPT-2 ~124–128M (well-trained on WebText-like): ≈ 3.4–3.8 (PPL ≈ 30–45). Given the above, is ~3.5–3.8 a reasonable validation loss target for GPT-2-124M on FineWeb?
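For quick reference, the numbers above are related by loss = ln(perplexity):

```python
import math

vocab_size = 50304
print(math.log(vocab_size))           # random-guess baseline, about 10.83 nats
print(math.exp(3.4), math.exp(3.8))   # losses of 3.4 and 3.8 correspond to PPL of about 30 and 45
```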
Thank you!