
[feat] Add support for encoder-only transformers (e.g. BERT) #131

Open
OxxoCodes opened this issue Aug 27, 2024 · 0 comments
@OxxoCodes

🚀 The feature, motivation and pitch

Liger Kernel is currently incompatible with encoder-only transformer architectures such as BERT, DistilBERT, RoBERTa, XLM-R, and DeBERTa.

Given how widely these models are still used in research and industry, it would be great to see support added for them too, to further reduce memory requirements and increase training throughput.

Alternatives

No response

Additional context

No response

OxxoCodes changed the title from "Add support for encoder-only transformers (e.g. BERT)" to "[feat] Add support for encoder-only transformers (e.g. BERT)" on Aug 27, 2024
lancerts added a commit that referenced this issue Aug 31, 2024
## Summary
- Added Embedding forward/backward kernels and a LigerEmbedding class that maps to nn.Embedding
- nn.Embedding is useful for encoder-only models such as BERT
- ref: #131
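For reference, here is what an embedding forward/backward pair has to compute, written as a plain-Python sketch. It assumes the standard `nn.Embedding` semantics: the forward gathers one weight row per index, and the backward scatter-adds each output-row gradient into the row it came from, with the `padding_idx` row receiving no gradient. The names `embedding_forward` and `embedding_backward` are hypothetical. This is not Liger Kernel's API, and the PR's GPU kernels are not reproduced here.

```python
from typing import List, Optional

Matrix = List[List[float]]


def embedding_forward(weight: Matrix, indices: List[int]) -> Matrix:
    """Gather one row of `weight` per index (nn.Embedding forward semantics)."""
    # Copy each row so the output does not alias the weight table.
    return [list(weight[i]) for i in indices]


def embedding_backward(
    grad_output: Matrix,
    indices: List[int],
    num_embeddings: int,
    padding_idx: Optional[int] = None,
) -> Matrix:
    """Scatter-add each output-row gradient into the weight row it was read from."""
    dim = len(grad_output[0]) if grad_output else 0
    grad_weight = [[0.0] * dim for _ in range(num_embeddings)]
    for row, idx in zip(grad_output, indices):
        # The padding row is never updated, so its gradient stays zero.
        if idx == padding_idx:
            continue
        # Repeated indices accumulate into the same row. A parallel kernel
        # typically needs atomic adds or a sort-and-reduce to do this safely.
        for j, g in enumerate(row):
            grad_weight[idx][j] += g
    return grad_weight


# Usage: index 2 appears twice, so its gradient row is accumulated twice.
weight = [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
out = embedding_forward(weight, [2, 0, 2])  # [[4.0, 5.0], [0.0, 1.0], [4.0, 5.0]]
grads = embedding_backward([[1.0, 1.0]] * 3, [2, 0, 2], 3)  # [[1.0, 1.0], [0.0, 0.0], [2.0, 2.0]]
```

The accumulation over repeated indices is what makes the backward more involved than the forward. The forward is a pure gather with no write conflicts.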


## Testing Done
- tested against nn.Embedding for correctness on various inputs
- tested with and without padding_idx
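The testing above compared against `nn.Embedding` directly. A dependency-free way to illustrate the same idea is to check a hand-written backward against a brute-force reference, both with and without `padding_idx`. The sketch below uses central finite differences of the scalar loss `sum(out * upstream)`. The padding row is the exception: finite differences would report a nonzero slope there, but `nn.Embedding` defines that row's gradient as zero, so the check asserts zero for it instead. All function names here are hypothetical, and nothing below is the PR's actual test code.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]


def embedding_forward(weight: Matrix, indices: List[int]) -> Matrix:
    return [list(weight[i]) for i in indices]


def embedding_backward(grad_output: Matrix, indices: List[int],
                       num_embeddings: int, padding_idx: Optional[int] = None) -> Matrix:
    grad_weight = [[0.0] * len(grad_output[0]) for _ in range(num_embeddings)]
    for row, idx in zip(grad_output, indices):
        if idx != padding_idx:  # padding row receives no gradient
            for j, g in enumerate(row):
                grad_weight[idx][j] += g
    return grad_weight


def loss(weight: Matrix, indices: List[int], upstream: Matrix) -> float:
    """Scalar loss sum(out * upstream); its gradient w.r.t. `out` is `upstream`."""
    out = embedding_forward(weight, indices)
    return sum(o * u for orow, urow in zip(out, upstream) for o, u in zip(orow, urow))


def check_backward(num_embeddings: int = 5, dim: int = 3, seq_len: int = 8,
                   padding_idx: Optional[int] = None, eps: float = 1e-3,
                   seed: int = 0) -> bool:
    rng = random.Random(seed)
    weight = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(num_embeddings)]
    # Sample indices with replacement so some rows are read more than once.
    indices = [rng.randrange(num_embeddings) for _ in range(seq_len)]
    upstream = [[rng.uniform(-1, 1) for _ in range(dim)] for _ in range(seq_len)]

    analytic = embedding_backward(upstream, indices, num_embeddings, padding_idx)
    for i in range(num_embeddings):
        for j in range(dim):
            if i == padding_idx:
                # Zero by definition of padding_idx, not by calculus.
                if analytic[i][j] != 0.0:
                    return False
                continue
            original = weight[i][j]
            weight[i][j] = original + eps
            plus = loss(weight, indices, upstream)
            weight[i][j] = original - eps
            minus = loss(weight, indices, upstream)
            weight[i][j] = original  # restore before the next coordinate
            numeric = (plus - minus) / (2 * eps)
            if abs(numeric - analytic[i][j]) > 1e-6:
                return False
    return True


print(check_backward(), check_backward(padding_idx=0))
```

Because the loss is linear in the weights, central differences are exact up to floating-point rounding, so a tight tolerance is safe here.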


- Hardware Type: RTX 3090 + RTX 4090
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

---------

Co-authored-by: Shao Tang <[email protected]>