fp8 #266

Open · wants to merge 46 commits into main

Conversation

@xrsrke (Member) commented Dec 18, 2024

This PR contains the implementation of the 2nd FP8 pretraining recipe, except for the FP8 optimizer states and update clipping. For the experimental implementation of recipe 1, please check out [this pull request].

Convergence

Found two stable FP8 pretraining recipes that pretrain a LLaMA 2 architecture in FP8 for both the forward and backward passes, as well as both optimizer momentums (50% memory reduction), while matching the standard BF16 mixed-precision baseline after 100B tokens [link] [pull request]

  • Recipe 1 with architectural and optimizer changes [link]

[image]

  • Ablated recipe 2 without architectural changes (the better recipe) [link]: remove all the architectural changes from recipe 1 and add gradient clipping (omitting it earlier was a silly mistake); a minimal sketch of the clipping step is shown below
[image]
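For reference, the gradient clipping added in recipe 2 is standard global-norm clipping applied before the optimizer step. A minimal sketch of where it sits in a training step; the `model`, `optimizer`, `batch` objects and `max_norm=1.0` are placeholders, not the exact settings used in this PR:

```python
import torch

def training_step(model, optimizer, batch, max_norm: float = 1.0):
    loss = model(batch)          # forward pass (FP8 GEMMs happen inside the model); placeholder API
    loss.backward()              # backward pass

    # Global-norm gradient clipping: this is the extra step recipe 2 adds on top of recipe 1.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    return loss.detach(), grad_norm
```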

Loss curve of a 1B LLaMA 2 trained for 100B tokens

[image]

Loss curve of a 7B LLaMA 2 trained for 24k steps with a 300k batch size (except the 2nd momentum, which is kept in bfloat16)

[image]
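For orientation on the optimizer-state side of the memory claim, here is the per-parameter byte cost of Adam's two momentums in the dtype combinations used above. This is only a sizing aid under the usual dtype sizes; which baseline the "50% memory reduction" is measured against is not restated here.

```python
BYTES = {"fp32": 4, "bf16": 2, "fp8": 1}

def momentum_bytes_per_param(exp_avg: str, exp_avg_sq: str) -> int:
    """Bytes of Adam momentum state stored per model parameter."""
    return BYTES[exp_avg] + BYTES[exp_avg_sq]

print(momentum_bytes_per_param("bf16", "bf16"))  # 4 bytes/param with bf16 momentums
print(momentum_bytes_per_param("fp8", "fp8"))    # 2 bytes/param: both momentums in FP8 (1B run)
print(momentum_bytes_per_param("fp8", "bf16"))   # 3 bytes/param: 2nd momentum kept in bf16 (7B run)
```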

Speed

  • Got 1033 FLOPs for an FP8 tensor parallel linear [link]

[image]
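For context on how a number like that is typically measured: time the GEMM behind the linear layer and divide 2·M·N·K by the elapsed time. The sketch below uses a plain bf16 matmul as a stand-in for the FP8 kernel, and the shapes are arbitrary rather than those of the linked benchmark:

```python
import time
import torch

def measured_tflops(m: int, n: int, k: int, iters: int = 50) -> float:
    """Time an (m, k) @ (k, n) matmul and report achieved TFLOPS (2*m*n*k FLOPs per GEMM)."""
    a = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(k, n, device="cuda", dtype=torch.bfloat16)
    for _ in range(5):                     # warmup
        a @ b
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        a @ b
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start
    return 2 * m * n * k * iters / elapsed / 1e12

# Example: measured_tflops(8192, 8192, 8192)
```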

Failed Experiments

  • Smooth quantization (the paper shows it works well for inference): without gradient clipping it solves some divergence issues, but with gradient clipping it hurts model performance
  • AdamW_atan2: it blew up even earlier than AdamW
  • Syncing the amax of the same FP8 tensor across TP ranks (weights, input gradients, weight gradients, output gradients) → doesn't make much difference in performance
  • Weight decay without learning rate decay
  • QKV clipping
  • CohereLayerNorm
  • Tune Triton RMS norm
  • Tune gradient clipping's epsilon factor (as expected, at some scale this factor does influence training stability)
  • Delayed quantization: not much difference (see the delayed-scaling sketch after this list)
  • Tuned the quantization interval: 1, 2, 16, 32 → not much of a difference
  • Warm-up quantization: not much difference (warm up the amax before computing the amax at each interval)
  • Try a truncated normal distribution (timm's trunc_normal_) for initializing FP8 weights → it didn't fix the divergence in recipe 1's setup
    • trunc_normal_(weight, std=0.02)
    • trunc_normal_(weight, std=math.sqrt(1 / 64))
    • trunc_normal_(weight, std=math.sqrt(1 / 64 * 4))
    • trunc_normal_(weight, std=1)
  • In recipe 1's setup:
    • The model's loss got stuck at 8 when using PyTorch's default random weight initialization
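Several of the items above (delayed quantization, the quantization interval, amax warm-up) revolve around the standard delayed-scaling bookkeeping: track each tensor's amax over recent steps and only refresh its FP8 scale every `interval` steps. A minimal sketch of that idea, assuming E4M3 with a max representable value of 448; the class and constants are illustrative, not this PR's implementation:

```python
import torch

E4M3_MAX = 448.0  # largest finite value representable in fp8 E4M3

class DelayedScaler:
    """Track amax history and refresh the quantization scale every `interval` steps."""

    def __init__(self, interval: int = 16, history: int = 16):
        self.interval = interval
        self.history = history
        self.amax_history = []
        self.scale = 1.0
        self.step = 0

    def update(self, tensor: torch.Tensor) -> float:
        # Record the current amax, keeping only the most recent `history` values.
        self.amax_history = (self.amax_history + [tensor.abs().max().item()])[-self.history:]
        if self.step % self.interval == 0:
            # Delayed scaling: recompute the scale from the amax window only at the interval.
            amax = max(self.amax_history)
            self.scale = E4M3_MAX / amax if amax > 0 else 1.0
        self.step += 1
        return self.scale

    def quantize(self, tensor: torch.Tensor) -> torch.Tensor:
        scale = self.update(tensor)
        # Scale into the E4M3 range and clamp; a real kernel would also cast to an fp8 dtype.
        return (tensor * scale).clamp_(-E4M3_MAX, E4M3_MAX)
```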
