Support Z Loss in CE #239

Merged · 37 commits · Nov 7, 2024

Commits
0454a12
Implement z loss in LigerCrossEntropyFunction
Tcc0403 Sep 9, 2024
9349e89
Merge branch 'main' into z-loss
lancerts Sep 9, 2024
27783be
Merge branch 'main' into z-loss
lancerts Sep 9, 2024
02e90db
Rename z_loss_scale to lse_square_scale
Tcc0403 Sep 9, 2024
aa43dca
Merge branch 'z-loss' of github.com:Tcc0403/Liger-Kernel into z-loss
Tcc0403 Sep 10, 2024
aa4a4b2
Fix a mistake of the gradient calculation and update comments
Tcc0403 Sep 10, 2024
f53f61c
Remove the parameter `lse_square_scale` in FusedLinearCrossEntropyLos…
Tcc0403 Sep 10, 2024
b43c457
Implement z loss in LigerCrossEntropyFunction
Tcc0403 Sep 9, 2024
59bc0a3
Rename z_loss_scale to lse_square_scale
Tcc0403 Sep 9, 2024
0921c81
Fix a mistake of the gradient calculation and update comments
Tcc0403 Sep 10, 2024
c19f69c
Remove the parameter `lse_square_scale` in FusedLinearCrossEntropyLos…
Tcc0403 Sep 10, 2024
83c99ad
Merge branch 'z-loss' of github.com:Tcc0403/Liger-Kernel into z-loss
Tcc0403 Sep 10, 2024
1ee07de
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 10, 2024
83f23d0
Support z loss in flce
Tcc0403 Sep 11, 2024
fcd5ff4
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 11, 2024
295aab7
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 13, 2024
f72e9bb
Fix parameter orders of ce and flce
Tcc0403 Sep 13, 2024
10fa578
Fix functional tests
Tcc0403 Sep 14, 2024
03beb05
Fix bfloat16 precision issue on custom model
Tcc0403 Sep 14, 2024
3a6cad4
Add missing arguments in test and cleanup stdout
Tcc0403 Sep 14, 2024
7e4cc4b
Merge branch 'main' into ce-z-loss
lancerts Sep 19, 2024
c0f2581
Merge branch 'main' into ce-z-loss
lancerts Sep 21, 2024
9abd163
Merge branch 'main' into ce-z-loss
Tcc0403 Sep 28, 2024
5c24241
Merge branch 'main' into ce-z-loss
Tcc0403 Oct 1, 2024
97db6b4
Merge branch 'main' into ce-z-loss
lancerts Oct 1, 2024
cf632d8
Merge branch 'main' into ce-z-loss
lancerts Oct 3, 2024
91b62fd
Merge branch 'main' into ce-z-loss
Tcc0403 Oct 12, 2024
d2d6e44
Fix merge conflicts
Tcc0403 Oct 12, 2024
f7083f2
Merge branch 'ce-z-loss' of github.com:Tcc0403/Liger-Kernel into ce-z…
Tcc0403 Oct 12, 2024
b89f335
Merge branch 'main' into ce-z-loss
Tcc0403 Oct 27, 2024
9a6079a
Merge branch 'main' into ce-z-loss
Tcc0403 Nov 2, 2024
c8d0fac
Merge branch 'main' into ce-z-loss
Tcc0403 Nov 5, 2024
c957357
chekcstyle
Tcc0403 Nov 5, 2024
4e34bf2
Merge branch 'main' into ce-z-loss
ByronHsu Nov 6, 2024
c304cc3
Merge branch 'main' into ce-z-loss
ByronHsu Nov 6, 2024
fb7aff7
Merge branch 'main' into ce-z-loss
ByronHsu Nov 7, 2024
d2ab058
Update src/liger_kernel/ops/cross_entropy.py
ByronHsu Nov 7, 2024
119 changes: 105 additions & 14 deletions src/liger_kernel/ops/cross_entropy.py
@@ -2,6 +2,9 @@
import triton
import triton.language as tl

_TRUE = tl.constexpr(1)
_FALSE = tl.constexpr(0)


@triton.jit
def liger_cross_entropy_kernel(
@@ -10,11 +13,14 @@ def liger_cross_entropy_kernel(
Y_ptr,
Y_stride,
loss_ptr,
z_loss_ptr,
loss_stride,
n_cols,
n_non_ignore,
ignore_index,
label_smoothing: tl.constexpr,
lse_square_scale: tl.constexpr,
Comment from the PR author (Tcc0403, Contributor):

I'm not sure whether making label_smoothing and lse_square_scale tl.constexpr is the correct move.
I'm not familiar with model training; are these two parameters often changed in practice? I'm worried that it might cause the same issue as #146.

Flash-attention's implementation creates a new constexpr for it in triton.heuristics to solve branching issues.
I wonder what the difference is between

  1. declaring `label_smoothing` as a constexpr, and
  2. doing the calculation in `triton.heuristics` and then assigning the result to the constexpr `HAS_SMOOTHING`.

My assumption is that:
in case 1, we re-JIT every time `label_smoothing` changes;
in case 2, we re-JIT only when `HAS_SMOOTHING` changes as a result of the calculation on `label_smoothing`.

If so, I will go with flash-attn's approach.

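For reference, a minimal sketch of the flash-attn-style approach discussed in this comment, assuming `label_smoothing` stays a plain runtime argument and only the derived flag `HAS_SMOOTHING` is a constexpr (the kernel body and remaining arguments are omitted; this is an illustration, not the code in this PR):

```python
import triton
import triton.language as tl

# Illustrative sketch: derive a boolean constexpr from label_smoothing via
# triton.heuristics, so the kernel is re-JITed only when HAS_SMOOTHING flips,
# not for every distinct label_smoothing value.
@triton.heuristics({"HAS_SMOOTHING": lambda args: args["label_smoothing"] > 0.0})
@triton.jit
def ce_kernel_sketch(
    X_ptr,
    loss_ptr,
    n_cols,
    label_smoothing,             # plain runtime float
    HAS_SMOOTHING: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    # ... online-softmax passes as in the kernel in this diff ...
    if HAS_SMOOTHING:
        # smoothing-specific math is only compiled in when the flag is set
        pass
```
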
RETURN_Z_LOSS: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
):
"""
@@ -27,11 +33,14 @@ def liger_cross_entropy_kernel(
Y_ptr: Pointer to target tensor.
Y_stride (int): The stride of the target tensor.
loss_ptr: Pointer to tensor to store the loss.
z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
loss_stride (int): The stride of the loss tensor.
n_cols (int): The number of columns in the input tensor.
n_non_ignore (int): The number of non-ignored elements in the batch.
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, added to the loss for training stability.
RETURN_Z_LOSS (int): Boolean flag deciding whether the z loss is stored to z_loss_ptr. Must be 0 or 1.
BLOCK_SIZE (int): The block size for Triton operations.
"""

@@ -54,6 +63,7 @@ def liger_cross_entropy_kernel(
return

loss_ptr += program_id * loss_stride
z_loss_ptr += program_id * loss_stride

# Online softmax: 2 loads + 1 store (compared with 3 loads + 1 store for the safe softmax)
# Refer to Algorithm 3 in the paper: https://arxiv.org/pdf/1805.02867
@@ -83,20 +93,35 @@ def liger_cross_entropy_kernel(
d = d * tl.exp(m - m_new) + tl.sum(tl.exp(X_block - m_new))
m = m_new

# log (sum(e^(X_i))) = log (sum(e ^ (max(X)) * e ^ (X_i - max(X))))
# = log (e^(max(X)) * sum(e ^ (X_i - max(X))))
# = max(X) + log (sum(e ^ (X_i - max(X)))) = m + log d
lse = m + tl.log(d)

# 4. [Online softmax] second pass: calculate the gradients
# dx_y = (softmax(x_y) - 1) / N
# dx_i = softmax(x_i) / N, i != y
# N is the number of non ignored elements in the batch
# For label smoothing:
# dx_i = (softmax(x_y) - label_smoothing / V) / N, V = n_cols, i != y
# dx_i = (softmax(x_i) - label_smoothing / V) / N, V = n_cols, i != y
# dx_y = (softmax(x_y) - label_smoothing / V - (1 - label_smoothing)) / N
# = dx_i - (1 - label_smoothing) / N
# With Z loss:
# dx_i = ((1 + 2 * lse_square_scale * lse) * softmax(x_i) - label_smoothing / V) / N, i != y
# dx_y = dx_i - (1 - label_smoothing) / N
for i in range(0, n_cols, BLOCK_SIZE):
X_offsets = i + tl.arange(0, BLOCK_SIZE)
X_block = tl.load(
X_ptr + X_offsets, mask=X_offsets < n_cols, other=float("-inf")
)
X_block = (tl.exp(X_block - m) / d - eps) / (n_non_ignore)
# softmax(x_i)
X_block = tl.exp(X_block - m) / d
# derivative of z-loss: 2 * lse_square_scale * lse * softmax(x_i)
X_block += 2 * lse_square_scale * lse * X_block
# smoothing term
X_block += -eps
# reduction scale
X_block = X_block / (n_non_ignore)
tl.store(X_ptr + X_offsets, X_block, mask=X_offsets < n_cols)

# We need tl.debug_barrier() to ensure the new result of X_ptr is written as mentioned in
@@ -105,11 +130,12 @@ def liger_cross_entropy_kernel(

# 5. Calculate the loss

# loss = log (softmax(X_y)) = log ((e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# -loss = log (softmax(X_y)) = log (e ^ (X_y - max(X)) / sum(e ^ (X - max(X))))
# = (X_y - max(X)) - log(sum(e ^ (X - max(X))))
# = X_y - m - log d = X_y - lse
# sum(e ^ (X - max(X))) must >= 1 because the max term is e ^ 0 = 1
# So we can safely calculate log (softmax(X_y)) without overflow
loss = -(ori_X_y - m - tl.log(d))
loss = lse - ori_X_y

# Original loss = H(q, p), with label smoothing regularization = H(q', p) and (label_smoothing / V) = eps
# H(q', p) = (1 - label_smoothing) * H(q, p) + label_smoothing * H(u, p)
@@ -120,14 +146,21 @@ def liger_cross_entropy_kernel(
# pytorch: https://github.com/pytorch/pytorch/blob/2981534f54d49fa3a9755c9b0855e7929c2527f0/aten/src/ATen/native/LossNLL.cpp#L516
# See full derivation at https://github.com/linkedin/Liger-Kernel/pull/198#issuecomment-2333753087
if label_smoothing > 0:
smooth_loss = scaled_x_sum + label_smoothing * (m + tl.log(d))
smooth_loss = scaled_x_sum + label_smoothing * lse
loss = loss * (1 - label_smoothing) + smooth_loss

# An auxiliary loss, z_loss
# Refer to the loss function section on page 14 of the PaLM paper: https://www.jmlr.org/papers/v24/22-1144.html
z_loss = lse_square_scale * lse * lse
loss += z_loss

# 6. Handle the special case i == y, where `dx_y = (softmax(x_y) - (1 - label_smoothing)) / N`
X_y = tl.load(X_ptr + y)
X_y += -(1 - label_smoothing) / (n_non_ignore)

tl.store(loss_ptr, loss)
if RETURN_Z_LOSS == _TRUE:
tl.store(z_loss_ptr, z_loss)
tl.store(X_ptr + y, X_y)


@@ -173,14 +206,41 @@ def element_mul_kernel(
tl.store(X_ptr + X_offsets, X_block * grad_output, mask=X_offsets < n_cols)


def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
_bool_to_return_z_loss = {
True: _TRUE.value,
False: _FALSE.value,
}


def cross_entropy_forward(
_input,
target,
ignore_index,
label_smoothing,
lse_square_scale,
return_z_loss,
):
if not isinstance(return_z_loss, int):
assert (
return_z_loss in _bool_to_return_z_loss
), f"return_z_loss must be True or False. Got: {return_z_loss}"
return_z_loss = _bool_to_return_z_loss[return_z_loss]
else:
assert (
return_z_loss in _bool_to_return_z_loss
), f"return_z_loss must be True or False. Got: {return_z_loss}"

BT, V = _input.shape
n_rows = BT

BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(V))

# unreduced loss
loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
if return_z_loss == _TRUE.value:
z_loss_1d = torch.zeros(n_rows, dtype=_input.dtype, device=_input.device)
else:
z_loss_1d = loss_1d # dummy ptr when return_z_loss == False

n_non_ignore = (target != ignore_index).sum().item()

@@ -197,19 +257,27 @@ def cross_entropy_forward(_input, target, ignore_index, label_smoothing):
Y_ptr=target,
Y_stride=target.stride(-1), # always 1
loss_ptr=loss_1d,
z_loss_ptr=z_loss_1d,
loss_stride=loss_1d.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
lse_square_scale=lse_square_scale,
BLOCK_SIZE=BLOCK_SIZE,
RETURN_Z_LOSS=return_z_loss,
# TODO: 32 seems to give the best performance
# Performance is quite sensitive to num_warps
num_warps=32,
)

loss = torch.sum(loss_1d) / n_non_ignore
return loss, _input
if return_z_loss == _TRUE.value:
z_loss = torch.sum(z_loss_1d) / n_non_ignore
else:
z_loss = None

return loss, z_loss, _input


def cross_entropy_backward(_input, grad_output):
@@ -243,7 +311,15 @@ class LigerCrossEntropyFunction(torch.autograd.Function):
"""

@staticmethod
def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
def forward(
ctx,
_input,
target,
ignore_index=-100,
label_smoothing=0.0,
lse_square_scale=0.0,
return_z_loss=False,
):
"""
The forward pass of the Liger Cross Entropy loss.

@@ -253,36 +329,51 @@ def forward(ctx, _input, target, ignore_index=-100, label_smoothing=0.0):
target (tensor): The target tensor of shape (BT) where each value is in [0, V-1].
ignore_index (int): The index to ignore in the target.
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
lse_square_scale (float): The scale factor applied to (logsumexp(_input)) ^ 2, added to the loss for training stability.
return_z_loss (bool): When `return_z_loss` is `True`, returns (loss, z_loss) instead of (loss, None). Default: `False`


Returns:
tensor: The computed loss.
tuple: A tuple (loss, z_loss) of the computed losses. The elements are tensors or None.
"""
loss, _input = cross_entropy_forward(
_input, target, ignore_index, label_smoothing
loss, z_loss, _input = cross_entropy_forward(
_input,
target,
ignore_index,
label_smoothing,
lse_square_scale,
return_z_loss,
)
# TODO: investigate further
# If we don't detach the _input tensor, memory usage doubles.
# Not sure why, but it seems there is a moment when both the grad and the value exist, in different locations.
ctx.save_for_backward(_input.detach())
return loss
ctx.return_z_loss = return_z_loss

return loss, z_loss

@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output, grad_output2):
"""
The backward pass of the Liger Cross Entropy loss.

Parameters:
ctx : The context object with saved tensors.
grad_output (tensor): The tensor containing the gradient of the loss with respect to the output.

grad_output2 (tensor): Not used (z_loss is for logging only).
Returns:
tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None.
"""
if ctx.return_z_loss:
del grad_output2  # z_loss is only for logging

(_input,) = ctx.saved_tensors
_input = cross_entropy_backward(_input, grad_output)
return (
_input,
None,
None,
None,
None,
None,
)
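As a sanity check for the kernel math above, here is a minimal unfused PyTorch reference for cross entropy with the PaLM-style z loss; the function name and default values are illustrative, not part of the Liger API:

```python
import torch
import torch.nn.functional as F

def ce_with_z_loss_reference(logits, target, lse_square_scale=1e-4, ignore_index=-100):
    # standard mean-reduced cross entropy over non-ignored tokens
    ce = F.cross_entropy(logits, target, ignore_index=ignore_index)
    # z loss = lse_square_scale * logsumexp(logits)^2, averaged over non-ignored tokens
    lse = torch.logsumexp(logits.float(), dim=-1)
    mask = target != ignore_index
    z_loss = lse_square_scale * (lse[mask] ** 2).mean()
    return ce + z_loss, z_loss

# quick check with illustrative shapes
logits = torch.randn(8, 128, requires_grad=True)
target = torch.randint(0, 128, (8,))
loss, z_loss = ce_with_z_loss_reference(logits, target)
loss.backward()
```
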
30 changes: 26 additions & 4 deletions src/liger_kernel/ops/fused_linear_cross_entropy.py
@@ -13,7 +13,13 @@


def fused_linear_cross_entropy_forward(
_input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
_input,
weight,
target,
bias=None,
ignore_index=-100,
label_smoothing=0.0,
lse_square_scale=0.0,
):
dtype = (
torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else _input.dtype
@@ -79,12 +85,15 @@ def fused_linear_cross_entropy_forward(
Y_ptr=target_chunk,
Y_stride=target_chunk.stride(-1), # always 1
loss_ptr=loss_1d_slice,
z_loss_ptr=loss_1d_slice, # dummy ptr, not used
loss_stride=loss_1d_slice.stride(-1), # always 1
n_cols=V,
n_non_ignore=n_non_ignore,
ignore_index=ignore_index,
label_smoothing=label_smoothing,
lse_square_scale=lse_square_scale,
BLOCK_SIZE=BLOCK_SIZE,
RETURN_Z_LOSS=0, # False
num_warps=32,
)

@@ -179,7 +188,14 @@ def fused_linear_cross_entropy_backward(
class LigerFusedLinearCrossEntropyFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx, _input, weight, target, bias=None, ignore_index=-100, label_smoothing=0.0
ctx,
_input,
weight,
target,
bias=None,
ignore_index=-100,
label_smoothing=0.0,
lse_square_scale=0.0,
):
"""
Fusing the last linear layer with cross-entropy loss
@@ -198,7 +214,13 @@ def forward(
label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.
"""
loss, grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_forward(
_input, weight, target, bias, ignore_index, label_smoothing
_input,
weight,
target,
bias,
ignore_index,
label_smoothing,
lse_square_scale,
)
# downcast to dtype and store for backward
ctx.save_for_backward(
@@ -214,4 +236,4 @@ def backward(ctx, grad_output):
grad_input, grad_weight, grad_bias = fused_linear_cross_entropy_backward(
grad_output, grad_input, grad_weight, grad_bias
)
return (grad_input, grad_weight, None, grad_bias, None, None)
return (grad_input, grad_weight, None, grad_bias, None, None, None)
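A hedged usage sketch of the fused path with the new argument (tensor names, sizes, and dtypes are illustrative). Note that the fused path folds the z loss into `loss` and does not expose `return_z_loss`, which is why the kernel launch above passes a dummy `z_loss_ptr`:

```python
import torch
from liger_kernel.ops.fused_linear_cross_entropy import (
    LigerFusedLinearCrossEntropyFunction,
)

BT, H, V = 16, 4096, 32000  # illustrative sizes (batch*seq, hidden, vocab)
hidden = torch.randn(BT, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
lm_head_weight = torch.randn(V, H, device="cuda", dtype=torch.bfloat16, requires_grad=True)
labels = torch.randint(0, V, (BT,), device="cuda")

loss = LigerFusedLinearCrossEntropyFunction.apply(
    hidden,          # _input
    lm_head_weight,  # weight
    labels,          # target
    None,            # bias
    -100,            # ignore_index
    0.0,             # label_smoothing
    1e-4,            # lse_square_scale
)
loss.backward()
```
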
30 changes: 24 additions & 6 deletions src/liger_kernel/transformers/cross_entropy.py
@@ -1,16 +1,34 @@
from torch.nn import CrossEntropyLoss
import torch.nn as nn

from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction


class LigerCrossEntropyLoss(CrossEntropyLoss):
def __init__(self, *args, **kwargs):
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs)
class LigerCrossEntropyLoss(nn.Module):
def __init__(
self,
ignore_index=-100,
label_smoothing=0.0,
lse_square_scale=0.0,
return_z_loss=False,
):
super().__init__()
self.ignore_index = ignore_index
self.label_smoothing = label_smoothing
self.lse_square_scale = lse_square_scale
self.return_z_loss = return_z_loss
assert (self.label_smoothing >= 0) and (
self.label_smoothing <= 1
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}"

def forward(self, _input, target):
return LigerCrossEntropyFunction.apply(
_input, target, self.ignore_index, self.label_smoothing
loss, z_loss = LigerCrossEntropyFunction.apply(
_input,
target,
self.ignore_index,
self.label_smoothing,
self.lse_square_scale,
self.return_z_loss,
)
if not self.return_z_loss:
return loss
return loss, z_loss
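
Finally, a short usage sketch of the updated module (variable names and sizes are illustrative):

```python
import torch
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss

V = 32000  # illustrative vocab size
criterion = LigerCrossEntropyLoss(lse_square_scale=1e-4, return_z_loss=True)

logits = torch.randn(8, V, device="cuda", requires_grad=True)
labels = torch.randint(0, V, (8,), device="cuda")

loss, z_loss = criterion(logits, labels)  # z_loss is returned for logging only
loss.backward()
```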