[feat] FusedLinearCrossEntropy support for Gemma2 #127
Comments
#take @yundai424 I would like to take a stab at implementing this. I'm thinking this approach:
Can you assign it to me if this sounds okay?
@troy1729 Sounds reasonable to me. Assigned. Feel free to kick off the implementation and ping us to discuss or review any issues. Thank you!
Hi @qingquansong, I've made the changes, but I still have to add the tests, so I've kept the PR in draft.
Hey @troy1729, thanks for the question (not a silly question at all) and the fast kickoff! In sum, my suggestion would be: implement only the tanh option for now, and follow the geglu backward pass to see how the tanh gradient is computed with the chain rule, then derive the equation and implement it here.
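A minimal sketch of that derivation, assuming the Gemma2-style cap y = cap · tanh(x / cap). This is plain PyTorch for illustration, not the Liger kernel API, and the function names are made up:

```python
import torch

def softcap_forward(x: torch.Tensor, cap: float) -> torch.Tensor:
    # y = cap * tanh(x / cap), the tanh softcap
    return cap * torch.tanh(x / cap)

def softcap_grad(x: torch.Tensor, cap: float) -> torch.Tensor:
    # Chain rule: dy/dx = cap * (1 - tanh^2(x / cap)) * (1 / cap)
    #                   = 1 - tanh^2(x / cap)
    t = torch.tanh(x / cap)
    return 1.0 - t * t

# Sanity check against autograd (float64 for tight tolerances)
x = torch.randn(8, dtype=torch.float64, requires_grad=True)
softcap_forward(x, cap=30.0).sum().backward()
assert torch.allclose(x.grad, softcap_grad(x.detach(), cap=30.0))
```

The convenient part is that the cap cancels in the backward pass, so the gradient only needs the tanh of the capped input, which the forward pass computes anyway.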
I believe I've implemented softcap in the cross entropy function correctly, along with the FLCE support for Gemma2. But since Gemma2 currently can't pass the test even without FLCE, do I need to find a way to pass the relevant convergence test (test_mini_models_no_logits.py)? cc @yundai424
## Summary
Resolves #127. Fuses softcapping into the cross_entropy kernel so it can be called by the fused linear cross entropy function.
## Testing Done
The current monkey patch for Gemma2 can't pass the convergence test without FLCE either, so that test is commented out for now.
- Hardware Type:
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
---------
Co-authored-by: Byron Hsu <[email protected]>
Co-authored-by: Shao Tang <[email protected]>
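A hedged eager-mode reference for what the fused path computes, useful when checking the kernel numerically. This is plain PyTorch, not the Triton kernel, and the function name and shapes are assumptions:

```python
import torch
import torch.nn.functional as F

def softcapped_linear_cross_entropy(
    hidden: torch.Tensor,   # (N, H) last hidden states
    weight: torch.Tensor,   # (V, H) lm_head weight
    target: torch.Tensor,   # (N,) token ids, -100 = ignore
    softcap: float,
) -> torch.Tensor:
    # Materializes the full (N, V) logits; the fused kernel avoids this
    # allocation but should match these numbers.
    logits = hidden @ weight.t()
    logits = softcap * torch.tanh(logits / softcap)
    return F.cross_entropy(logits, target, ignore_index=-100)
```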
🚀 The feature, motivation and pitch
FLCE needs special handling for the soft capping in gemma2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma2/modeling_gemma2.py#L1054
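FLCE computes the loss in chunks so the full (N, V) logits tensor never exists at once, which is why the cap can't just be applied to the model's output. A hypothetical sketch of the shape of the problem (names and the chunking scheme are illustrative, not Liger's implementation):

```python
import torch
import torch.nn.functional as F

def chunked_softcapped_flce(hidden, weight, target, softcap, chunk_size=1024):
    # The softcap must be applied inside each per-chunk cross entropy,
    # since there is no full logits tensor to cap afterwards.
    total = hidden.new_zeros(())
    n_valid = (target != -100).sum()
    for h, t in zip(hidden.split(chunk_size), target.split(chunk_size)):
        logits = softcap * torch.tanh((h @ weight.t()) / softcap)
        total = total + F.cross_entropy(
            logits, t, ignore_index=-100, reduction="sum"
        )
    return total / n_valid
```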