
Support Z Loss in CE #239

Open · wants to merge 21 commits into main
Conversation

@Tcc0403 (Contributor) commented Sep 10, 2024

Summary

This PR aims to resolve #197

Implemented z loss in LigerCrossEntropy.

Note: lse_square_scale is not exposed in flce yet; still having issues passing the tests there.

Details

For loss:

$$\begin{align} L_{total} &= L_{ce} + z\_loss\\ z\_loss &= lse\_square\_scale \cdot lse^2\\ lse &= \log \sum e^{X_i} \end{align}$$

We can use $m = \max(X_i)$ and $d = \sum e^{X_i - m}$, both obtained from the online softmax algorithm, to calculate $lse$ directly.

$$\begin{align} lse &= \log \sum e^{X_i}\\ &= \log \sum e^{X_i - m + m} = \log \sum e^{X_i - m} \cdot e^m\\ &= \log\left(e^m\sum e^{X_i - m}\right) = m + \log d \end{align}$$
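
A quick PyTorch sanity check of this identity (just a sketch, not the Triton kernel; the function name `z_loss_reference` and the 2D logits shape are assumptions for illustration):

```python
import torch

def z_loss_reference(logits: torch.Tensor, lse_square_scale: float) -> torch.Tensor:
    # m and d are exactly what an online-softmax pass produces:
    # the running max and the sum of exp(x - m)
    m = logits.max(dim=-1, keepdim=True).values
    d = torch.exp(logits - m).sum(dim=-1, keepdim=True)
    lse = (m + torch.log(d)).squeeze(-1)          # lse = m + log d
    # agrees with the direct computation
    assert torch.allclose(lse, torch.logsumexp(logits, dim=-1), atol=1e-5)
    return lse_square_scale * lse ** 2            # z_loss = lse_square_scale * lse^2
```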

For gradients:

First, we calculate the derivative of $lse$:

$$\begin{align} \frac{\partial}{\partial x_i}(lse) &= \frac{\partial}{\partial x_i}\left(\log \sum_j e^{x_j}\right) \\ &= \frac{1}{\sum_j e^{x_j}} \cdot \frac{\partial}{\partial x_i} \sum_j e^{x_j}\\ &= \frac{e^{x_i}}{\sum_j e^{x_j}} = softmax(x_i). \end{align}$$

Then we can obtain the derivative of z_loss by the chain rule:

$$\frac{\partial z\_loss}{\partial x_i} = \frac{\partial}{\partial x_i}\left( lse\_square\_scale \cdot lse^2\right) = 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i),$$

and we have the derivative of the cross entropy loss with label smoothing:

$$\frac{\partial L_{ce}}{\partial x_i} = softmax(x_i) - (1 - \epsilon)\delta_{i,y} - \frac{\epsilon}{K}= \begin{cases} softmax(x_i) - \frac{\epsilon}{K}, & i \neq y \\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases}$$

where $\epsilon$ is label_smoothing and $K$ is the total number of classes.
Thus, the derivative of the total loss is

$$\begin{align} \frac{\partial}{\partial x_i}L_{total} &= \frac{\partial}{\partial x_i}L_{ce} + \frac{\partial}{\partial x_i}z\_loss\\ &= softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon)\delta_{i,y} + 2\cdot lse\_square\_scale \cdot lse \cdot softmax(x_i)\\ &=\begin{cases} (1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K}, & i \neq y\\ (1 + 2\cdot lse\_square\_scale \cdot lse)\ softmax(x_i) - \frac{\epsilon}{K} - (1 - \epsilon), & i = y \end{cases} \end{align}$$
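
As a sanity check on this formula, here is a rough PyTorch sketch (not part of the PR; the names `total_loss_reference` and `analytic_grad` are illustrative, and mean reduction over the batch is assumed) comparing the analytic gradient above against autograd:

```python
import torch
import torch.nn.functional as F

def total_loss_reference(logits, target, label_smoothing, lse_square_scale):
    # L_total = L_ce (with label smoothing) + mean(lse_square_scale * lse^2)
    lse = torch.logsumexp(logits, dim=-1)
    ce = F.cross_entropy(logits, target, label_smoothing=label_smoothing)
    return ce + (lse_square_scale * lse ** 2).mean()

def analytic_grad(logits, target, label_smoothing, lse_square_scale):
    n, K = logits.shape
    p = torch.softmax(logits, dim=-1)
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)
    # (1 + 2 * lse_square_scale * lse) * softmax(x_i) - eps/K, minus (1 - eps) at the target
    grad = (1.0 + 2.0 * lse_square_scale * lse) * p - label_smoothing / K
    grad[torch.arange(n), target] -= 1.0 - label_smoothing
    return grad / n  # mean reduction over the batch

logits = torch.randn(8, 32, dtype=torch.double, requires_grad=True)
target = torch.randint(0, 32, (8,))
total_loss_reference(logits, target, 0.1, 1e-4).backward()
assert torch.allclose(logits.grad, analytic_grad(logits.detach(), target, 0.1, 1e-4))
```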

References

PaLM: Scaling Language Modeling with Pathways
Chameleon: Mixed-Modal Early-Fusion Foundation Models

Testing Done

benchmark gist
Negligible difference in the speed benchmark.

This benchmark was run on my machine, so it is probably not very accurate.

liger ce: 66.123ms
Peak mem:  8.66200832

liger ce with zloss: 65.991ms
Peak mem:  8.66200832

liger ce with zloss with return zloss: 65.951ms
Peak mem:  8.662073856
  • Hardware Type:
  • run make test to ensure correctness
  • run make checkstyle to ensure code style
  • run make test-convergence to ensure convergence

@Tcc0403 changed the title from "Ce z loss" to "Support Z Loss in CE" on Sep 10, 2024

@Tcc0403 (Contributor, Author) commented Sep 10, 2024

Passed all tests. Ready for review!

loss_stride,  # stride of the loss tensor
n_cols,  # number of classes (vocab size)
n_non_ignore,  # number of targets not equal to ignore_index
ignore_index,  # target value excluded from the loss
label_smoothing: tl.constexpr,
lse_square_scale: tl.constexpr,

@Tcc0403 (Contributor, Author) commented on the code above:

I'm not sure whether making label_smoothing and lse_square_scale tl.constexpr is the right move.
I'm not familiar with model training. Are these two parameters often changed in practice? I'm worried that it might cause the same issue as #146.

Flash-attention's implementation instead derives a new constexpr in triton.heuristics to solve the branching issue.
I wonder what the difference is between

  1. declaring label_smoothing itself as a constexpr, and
  2. computing a flag in triton.heuristics and assigning it to the constexpr HAS_SMOOTHING.

My assumption is that:
in case 1, the kernel is re-JIT-compiled every time label_smoothing changes;
in case 2, it is re-JIT-compiled only when HAS_SMOOTHING changes, since the heuristic collapses label_smoothing to a boolean.

If so, I will go with flash-attn's approach (see the sketch below).
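
For reference, a minimal sketch of the flash-attn style pattern (the kernel `demo_kernel` and its body are purely illustrative, not this PR's kernel): `label_smoothing` stays a plain runtime argument, and `@triton.heuristics` derives a boolean constexpr `HAS_SMOOTHING` from it, so a recompile should only happen when the flag flips:

```python
import torch
import triton
import triton.language as tl


@triton.heuristics({"HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0})
@triton.jit
def demo_kernel(
    X_ptr,
    Y_ptr,
    n_cols,
    label_smoothing,              # runtime scalar: changing it does not force a re-JIT
    HAS_SMOOTHING: tl.constexpr,  # derived flag: re-JIT only when it flips True/False
    BLOCK_SIZE: tl.constexpr,
):
    row = tl.program_id(0)
    offs = tl.arange(0, BLOCK_SIZE)
    mask = offs < n_cols
    x = tl.load(X_ptr + row * n_cols + offs, mask=mask, other=0.0)
    if HAS_SMOOTHING:
        # resolved at compile time; the branch is compiled out when smoothing is off
        x = x * (1.0 - label_smoothing)
    tl.store(Y_ptr + row * n_cols + offs, x, mask=mask)


def run(x: torch.Tensor, label_smoothing: float) -> torch.Tensor:
    y = torch.empty_like(x)
    n_rows, n_cols = x.shape
    BLOCK_SIZE = triton.next_power_of_2(n_cols)
    demo_kernel[(n_rows,)](x, y, n_cols, label_smoothing, BLOCK_SIZE=BLOCK_SIZE)
    return y
```

With this pattern, calling run(x, 0.0) and run(x, 0.1) should compile two variants, but moving from 0.1 to 0.2 should reuse the same binary since HAS_SMOOTHING stays True and floats are not specialized.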

@Tcc0403 (Contributor, Author) commented Sep 14, 2024

Ignore the OOM errors. The current custom CrossEntropyWithZLoss (torch.nn.Module), used as the ground-truth implementation, has precision issues in its gradient calculations with bfloat16 and reduction="sum".

LigerCrossEntropyLoss in this PR passes the tests without issue when compared against flash-attn's CrossEntropyLoss.
(gist)

The current goal is to make the custom torch implementation match flash-attn's.
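
One common fix (just a sketch of the idea, not necessarily what this PR ends up doing; the class name is made up) is to upcast inside the reference so the bf16 ground truth does all its math in float32:

```python
import torch
import torch.nn.functional as F

class CrossEntropyWithZLossFP32(torch.nn.Module):
    """Reference CE + z-loss that computes in float32 to reduce bf16 error.
    Only 'mean' and 'sum' reductions are handled in this sketch."""

    def __init__(self, lse_square_scale=0.0, reduction="mean", ignore_index=-100):
        super().__init__()
        self.lse_square_scale = lse_square_scale
        self.reduction = reduction
        self.ignore_index = ignore_index

    def forward(self, logits, target):
        logits_fp32 = logits.float()  # upcast once, do all math in fp32
        ce = F.cross_entropy(
            logits_fp32, target,
            reduction=self.reduction, ignore_index=self.ignore_index,
        )
        lse = torch.logsumexp(logits_fp32, dim=-1)
        z = self.lse_square_scale * lse ** 2
        z = z[target != self.ignore_index]  # z-loss only over non-ignored tokens
        z = z.sum() if self.reduction == "sum" else z.mean()
        return (ce + z).to(logits.dtype)
```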

Update: problems solved

@Tcc0403 (Contributor, Author) commented Sep 14, 2024

All passed
