[ModernBERT] Add CausalLM functionality to ModernBERT #35946

Open
wants to merge 12 commits into main
Conversation

@orionw commented Jan 28, 2025

This PR adds CausalLM functionality to ModernBERT. This means:

  • The tokenizer can work in either setting.
  • The CausalLM head functions as either an MLM-based CausalLM or a standard decoder-only CausalLM.
  • The attention mask logic gets more complicated since there are more combinations (causal/bidirectional x local/global); see the sketch below.
  • Since decoders use use_cache=True by default, I changed the default attention implementation away from FA2 when is_causal is set, since otherwise we'd need to pad/repad too frequently. We lose a bit of speed, but the code is much cleaner.

This is for an upcoming release that trained models in both settings, for a comparison of encoders and decoders.
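For concreteness, here is a minimal, self-contained sketch of the four mask combinations the layers now have to handle. This is illustrative only (plain PyTorch, not the PR's actual masking code):

```python
# Illustrative only -- not the PR's implementation. Sketch of the four attention-mask
# combinations (causal/bidirectional x local/global) that ModernBERT layers must support.
from typing import Optional

import torch


def build_mask(seq_len: int, causal: bool, window: Optional[int]) -> torch.Tensor:
    """Boolean [seq_len, seq_len] mask; True means the query may attend to the key."""
    pos = torch.arange(seq_len)
    dist = pos[None, :] - pos[:, None]      # key position minus query position
    mask = torch.ones(seq_len, seq_len, dtype=torch.bool)
    if window is not None:                  # local (sliding-window) attention
        mask &= dist.abs() <= window // 2
    if causal:                              # decoder-style: never attend to future keys
        mask &= dist <= 0
    return mask


# The four combinations:
global_bidirectional = build_mask(8, causal=False, window=None)
local_bidirectional  = build_mask(8, causal=False, window=4)
global_causal        = build_mask(8, causal=True,  window=None)
local_causal         = build_mask(8, causal=True,  window=4)
```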

Anticipated FAQs:

Q: Why not just make a new model class?
A: Models are trained in both settings and need to be able to switch back and forth with the same weights. This also reduces redundant code, since all of the base functionality is the same other than KV caching and attention masking.

Q: Is there a test model?
A: Currently I have a test model at https://huggingface.co/blab-jhu/test-32m-dec that has been trained with CLM and works with this PR.
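A rough usage sketch, assuming this branch is installed and the checkpoint's auto-mapping resolves to the new causal LM class (illustrative, not a documented API yet):

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hypothetical flow under this PR: the test checkpoint loads as a causal LM and generates.
tokenizer = AutoTokenizer.from_pretrained("blab-jhu/test-32m-dec")
model = AutoModelForCausalLM.from_pretrained("blab-jhu/test-32m-dec")

inputs = tokenizer("The capital of France is", return_tensors="pt")
outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```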

Q: Why are there two methods of CausalLM?
A: One works for the encoder-only models, which perform better than I anticipated on decoder-only tasks. The other is standard decoder-only modeling, which is the default.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Tagging those who have reviewed ModernBERT PRs: @Rocketknight1 @ArthurZucker

oweller2 and others added 11 commits January 27, 2025 19:36
@orionw orionw changed the title Add CausalLM functionality to ModernBERT [ModernBERT] Add CausalLM functionality to ModernBERT Jan 31, 2025
@gante gante self-assigned this Feb 10, 2025
@orionw (Author) commented Feb 12, 2025

Thanks for assigning, @gante! Also cc'ing the others who may be interested: @Rocketknight1 @ArthurZucker.

Let me know if I need to do anything else; I see I need to update the branch with main since it's been a while.

@ArthurZucker (Collaborator) left a comment

Hey! Thanks for the PR 🤗
As you already anticipated, my main concern is that this is a different "architecture", per se! With modular, writing ModernBertDecoder + ModernBertForCausalLM should be easier, but we usually never change the codebase of existing models like this!

Other than this, the tokenizer changes are not needed either! The settings you added can be controlled directly via the template processor, which needs to be different for each model! 🤗
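For example, something along these lines could ship with each checkpoint (a rough sketch using the tokenizers library; the decoder-style template below is purely illustrative):

```python
from tokenizers.processors import TemplateProcessing
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
cls_id = tok.convert_tokens_to_ids("[CLS]")
sep_id = tok.convert_tokens_to_ids("[SEP]")

# e.g. a decoder-style checkpoint could ship a post-processor without a trailing [SEP]
# (purely illustrative -- each checkpoint defines whatever template it actually needs).
tok.backend_tokenizer.post_processor = TemplateProcessing(
    single="[CLS] $A",
    pair="[CLS] $A [SEP] $B:1",
    special_tokens=[("[CLS]", cls_id), ("[SEP]", sep_id)],
)
```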

@orionw (Author) commented Feb 13, 2025

Thanks @ArthurZucker! I'm not 100% sure whether you're saying that having them combined is a no-go or potentially still on the table, so I'll make one last attempt to describe why I think it's a good idea to keep them combined.

  1. Say a user wants to mix the two modes, e.g. mix causal and non-causal training or inference. In that case they would have to initialize two versions of the model (as far as I understand) and somehow tie the weights together (see the sketch after this list). Even simply loading the model seems unintuitive, as AutoModel would load it as one class and they would have to override that to load it as the other version.

  2. Keeping them together means that bug fixes only need to go to one model instead of two, especially since they share almost all of the same functionality (other than attention masks).

  3. I was told that the SmolLM team at HF is working on a similar project (combining encoder and decoder training in one model). If they will be adding this capability to transformers soon (which I believe is the case), I would love to get that exemption also :)
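A rough illustration of the workaround described in point 1 if the two modes lived in separate classes (the loading flow and tying loop here are hypothetical, not an existing API):

```python
from transformers import AutoModelForCausalLM, AutoModelForMaskedLM

# Hypothetical: the same checkpoint loaded twice under two separate model classes.
mlm = AutoModelForMaskedLM.from_pretrained("blab-jhu/test-32m-dec")
clm = AutoModelForCausalLM.from_pretrained("blab-jhu/test-32m-dec")

# Crude manual weight tying; assumes both classes expose parameters in the same order.
for p_mlm, p_clm in zip(mlm.parameters(), clm.parameters()):
    p_clm.data = p_mlm.data  # share storage so in-place updates to one are seen by both
```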

What are your thoughts? You're the expert here, so if you still think it's best to separate them I will move the implementation to a ModernBERTDecoder class.

@gante (Member) commented Feb 14, 2025

@orionw a question: this would be equivalent to what we have in the original bert code, which has both a masked LM head and a causal LM head, correct?

If so, I'd be inclined to include your PR, given that it is a pattern present in such a staple model. One modeling file would be a near copy of the other. But @ArthurZucker has the final word here :)
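For reference, the existing BERT pattern looks roughly like this (both heads live in the same modeling file and load the same backbone weights):

```python
from transformers import BertForMaskedLM, BertLMHeadModel

mlm = BertForMaskedLM.from_pretrained("bert-base-uncased")
# The causal variant just needs is_decoder=True on the config.
clm = BertLMHeadModel.from_pretrained("bert-base-uncased", is_decoder=True)
```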
