-
Notifications
You must be signed in to change notification settings - Fork 191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support Z Loss in CE #239
Merged
Merged
Support Z Loss in CE #239
Changes from 15 commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
0454a12
Implement z loss in LigerCrossEntropyFunction
Tcc0403 9349e89
Merge branch 'main' into z-loss
lancerts 27783be
Merge branch 'main' into z-loss
lancerts 02e90db
Rename z_loss_scale to lse_square_scale
Tcc0403 aa43dca
Merge branch 'z-loss' of github.com:Tcc0403/Liger-Kernel into z-loss
Tcc0403 aa4a4b2
Fix a mistake of the gradient calculation and update comments
Tcc0403 f53f61c
Remove the parameter `lse_square_scale` in FusedLinearCrossEntropyLos…
Tcc0403 b43c457
Implement z loss in LigerCrossEntropyFunction
Tcc0403 59bc0a3
Rename z_loss_scale to lse_square_scale
Tcc0403 0921c81
Fix a mistake of the gradient calculation and update comments
Tcc0403 c19f69c
Remove the parameter `lse_square_scale` in FusedLinearCrossEntropyLos…
Tcc0403 83c99ad
Merge branch 'z-loss' of github.com:Tcc0403/Liger-Kernel into z-loss
Tcc0403 1ee07de
Merge branch 'main' into ce-z-loss
Tcc0403 83f23d0
Support z loss in flce
Tcc0403 fcd5ff4
Merge branch 'main' into ce-z-loss
Tcc0403 295aab7
Merge branch 'main' into ce-z-loss
Tcc0403 f72e9bb
Fix parameter orders of ce and flce
Tcc0403 10fa578
Fix functional tests
Tcc0403 03beb05
Fix bfloat16 precision issue on custom model
Tcc0403 3a6cad4
Add missing arguments in test and cleanup stdout
Tcc0403 7e4cc4b
Merge branch 'main' into ce-z-loss
lancerts c0f2581
Merge branch 'main' into ce-z-loss
lancerts 9abd163
Merge branch 'main' into ce-z-loss
Tcc0403 5c24241
Merge branch 'main' into ce-z-loss
Tcc0403 97db6b4
Merge branch 'main' into ce-z-loss
lancerts cf632d8
Merge branch 'main' into ce-z-loss
lancerts 91b62fd
Merge branch 'main' into ce-z-loss
Tcc0403 d2d6e44
Fix merge conflicts
Tcc0403 f7083f2
Merge branch 'ce-z-loss' of github.com:Tcc0403/Liger-Kernel into ce-z…
Tcc0403 b89f335
Merge branch 'main' into ce-z-loss
Tcc0403 9a6079a
Merge branch 'main' into ce-z-loss
Tcc0403 c8d0fac
Merge branch 'main' into ce-z-loss
Tcc0403 c957357
chekcstyle
Tcc0403 4e34bf2
Merge branch 'main' into ce-z-loss
ByronHsu c304cc3
Merge branch 'main' into ce-z-loss
ByronHsu fb7aff7
Merge branch 'main' into ce-z-loss
ByronHsu d2ab058
Update src/liger_kernel/ops/cross_entropy.py
ByronHsu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,34 @@ | ||
from torch.nn import CrossEntropyLoss | ||
import torch.nn as nn | ||
|
||
from liger_kernel.ops.cross_entropy import LigerCrossEntropyFunction | ||
|
||
|
||
class LigerCrossEntropyLoss(CrossEntropyLoss): | ||
def __init__(self, *args, **kwargs): | ||
super(LigerCrossEntropyLoss, self).__init__(*args, **kwargs) | ||
class LigerCrossEntropyLoss(nn.Module): | ||
def __init__( | ||
self, | ||
ignore_index=-100, | ||
label_smoothing=0.0, | ||
lse_square_scale=0.0, | ||
return_z_loss=False, | ||
): | ||
super().__init__() | ||
self.ignore_index = ignore_index | ||
self.label_smoothing = label_smoothing | ||
self.lse_square_scale = lse_square_scale | ||
self.return_z_loss = return_z_loss | ||
assert (self.label_smoothing >= 0) and ( | ||
self.label_smoothing <= 1 | ||
), f"label_smoothing must be between 0.0 and 1.0. Got: {self.label_smoothing}" | ||
|
||
def forward(self, _input, target): | ||
return LigerCrossEntropyFunction.apply( | ||
_input, target, self.ignore_index, self.label_smoothing | ||
loss, z_loss = LigerCrossEntropyFunction.apply( | ||
_input, | ||
target, | ||
self.ignore_index, | ||
self.label_smoothing, | ||
self.lse_square_scale, | ||
self.return_z_loss, | ||
) | ||
if not self.return_z_loss: | ||
return loss | ||
return loss, z_loss |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure if making label_smoothing and lse_square_scale
tl.constexpr
is a correct move.Not familiar with model training. Are these two parameters often changed in practice? I'm worried that it might cause the same issue as #146.
Flash-attention's implementation creates a new constexpr for it in
triton.heuristics
to solve branching issues.I wonder what the difference is between
label_smoothing
as a constexpr, andtriton.heuristics
then assign a value to the constexprHAS_SMOOTHING
My assumption is that:
in case 1, JIT every time
label_smoothing
changesin case 2, JIT only when
HAS_SMOOTHING
changes because of calculations onlabel_smoothing
.If so, I will go with flash-attn's approach.