[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194
Conversation
@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
Please adjust field names for our naming conventions.
fast_llm/layers/ssm/config.py
hint=FieldHint.core,
)

dt_rank: str | int = Field(
Please use None for derived defaults: dt_rank: int = Field(default=None, ...).
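For illustration, a minimal sketch of the suggested pattern. It assumes the Field/FieldHint/config_class API from fast_llm.config shown in the diff above, a _validate hook where the derived value is filled in, and a d_model field plus the ceil(d_model / 16) heuristic borrowed from the reference Mamba implementation; none of these details are confirmed by this PR.

import math

from fast_llm.config import Field, FieldHint, config_class  # assumed import path


@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):  # base class as in the diff above
    dt_rank: int = Field(
        default=None,
        desc="Rank of the dt projection; derived from the model width when left unset.",
        hint=FieldHint.core,
    )

    def _validate(self) -> None:
        if self.dt_rank is None:
            # Hypothetical derived default, mirroring the usual Mamba heuristic.
            self.dt_rank = math.ceil(self.d_model / 16)
        super()._validate()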
return init_


class MambaLayer(torch.nn.Module):
This doesn't work with tensor parallelism (TP); we need to explicitly prevent it. (Not sure about pipeline parallelism.)
We're going to have TP eventually, but it's not in scope for this PR.
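For reference, a minimal sketch of what "explicitly prevent" could look like. The tensor_space.distributed_config.tensor_parallel attribute path is an assumption about the Fast-LLM API, not the PR's actual code.

import torch


class MambaLayer(torch.nn.Module):
    """Sketch only: fail fast when tensor parallelism is requested."""

    def __init__(self, config, tensor_space):
        super().__init__()
        # Assumed attribute path; the real distributed-config access may differ.
        if tensor_space.distributed_config.tensor_parallel > 1:
            raise NotImplementedError(
                "MambaLayer does not support tensor parallelism yet; run with tensor_parallel=1."
            )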
@jlamypoirier, naming-wise, "layer" implies a single functional unit, like attention or an MLP. But here, the repeated unit is a composite: it includes a mixer (attention, SSM, etc.), normalization, residuals, and post-mixer processing (MLPs, MoEs, etc.). This structure isn't atomic. It's multiple layers stitched into a reusable computation unit, which is typically referred to as a "block" in other model families (e.g., ResNet, Swin, and even "Transformer blocks" in many papers and blogs).
Keeping the name "layer" for this composite hides that structure. Also, and I find this the most important argument: internally and casually, we already call them blocks. Making the code match mental models reduces friction.
@tscholak Sounds reasonable, the problem is these things are called "layers" everywhere else in Fast-LLM. Should we think about renaming those too?
This already looks very good to me, thanks @oleksost.
This should go in ASAP!
I had a few comments and suggestions. Mostly, I think we don't want to be too picky with this PR at this point because we're going to be actively working on improving many parts of this anyway in the next weeks.
I guess we can ignore some issues with the current SSM implementation if we handle it properly. This means clear warnings that the model is experimental and may break at any point, e.g. at model runtime and/or in file headers. We still need to fix the code changes outside the model though (transformers.py and set_nested_dict_value).
Also keep in mind that future modifications may break experiment configs, pretrained models and checkpoints, hence the importance of getting good config and parameter structures as soon as possible.
@@ -1,5 +1,6 @@
import logging
import typing
from abc import ABC, abstractmethod
Import the module itself (import abc) rather than individual names, per the import rules in the style guide.
This is a good moment to reflect on the pattern.
Right now, Fast-LLM expects contributors to internalize a set of nuanced import rules (see https://servicenow.github.io/Fast-LLM/contributing/style-guide/#imports) that go beyond what most Python projects require. That may have worked at some point, but it doesn't scale. New contributors can't memorize this, and even returning ones keep tripping over it.
If this style is important, it needs to be enforced and automatically fixable through linting or a pre-commit hook. If it can't be, we should let go of it. Patterns that can't be learned quickly or applied automatically create friction and slow down the team.
@jlamypoirier, could you file a ticket outlining what it would take to automate this rule? That's the only sustainable way forward.
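For illustration only, a rough sketch of the kind of check such a ticket could lead to, runnable as a pre-commit hook. It merely flags from ... import ... statements and ignores the more nuanced exceptions in the style guide, so it is a starting point rather than the actual rule.

import ast
import sys


def find_from_imports(path: str) -> list[str]:
    """Return the 'from ... import ...' statements found in a Python file."""
    with open(path, encoding="utf-8") as f:
        tree = ast.parse(f.read(), filename=path)
    return [
        f"{path}:{node.lineno}: from {node.module or '.'} import ..."
        for node in ast.walk(tree)
        if isinstance(node, ast.ImportFrom)
    ]


if __name__ == "__main__":
    violations = [v for path in sys.argv[1:] for v in find_from_imports(path)]
    print("\n".join(violations))
    sys.exit(1 if violations else 0)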
@oleksost, can you emit a warning in the logger when someone tries to instantiate the config for this model class? That should be enough for now.
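Something like the following would do, as a rough sketch; the class name and the exact place where the warning is emitted are assumptions, not the PR's actual code.

import logging

logger = logging.getLogger(__name__)


class HybridSSMModelConfig:  # stand-in for the real config class (name is an assumption)
    def __init__(self) -> None:
        # Emitted once per instantiation; the real config would do this in its
        # validation hook and then continue with normal setup.
        logger.warning(
            "The hybrid SSM model is experimental: configs, checkpoints and runtime "
            "behaviour may change or break without notice."
        )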
Some minor suggestions, otherwise looks ok to merge. I'll leave final approval to @tscholak
@@ -43,6 +43,7 @@ class ActivationType(str, enum.Enum):
    silu = "silu"
    relu = "relu"
    squared_relu = "squared_relu"
    identity = "identity"
Can you please add support for it in the MLP? (I know it's Triton, but this one is trivial.) See https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/triton/mlp.py. Or otherwise prevent it in the config?
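Both options in a rough sketch; the function names are illustrative and not the actual fast_llm/functional/triton/mlp.py API.

import torch


def apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
    # Option 1: support "identity" directly; it is a no-op in forward and backward.
    if activation == "identity":
        return x
    if activation == "silu":
        return torch.nn.functional.silu(x)
    raise ValueError(f"Unsupported activation: {activation}")


def validate_mlp_activation(activation: str) -> None:
    # Option 2: reject it early in config validation instead.
    if activation == "identity":
        raise ValueError("identity activation is not supported by the fused MLP kernel")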
run_test = MambaLayer is not None and torch.cuda.is_available()


def materialize_meta_tensors(model, tensor_space):
See also https://github.com/ServiceNow/Fast-LLM/pull/240/files#diff-97c4e262876c1009ed7547ee8518393e89b6622400d92cc88b0dfebe9e2f3de7R73. Using a mock stage handles all of this automatically.
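For context, a generic sketch of what such a test helper usually does; this is not the PR's actual implementation, which also has to respect Fast-LLM's tensor space and parameter initializers.

import torch


def materialize_meta_tensors(model: torch.nn.Module, device: torch.device) -> torch.nn.Module:
    """Replace parameters created on the 'meta' device with real, allocated tensors."""
    for module in model.modules():
        for name, param in list(module.named_parameters(recurse=False)):
            if param.is_meta:
                # Allocate real storage; real code would also run the parameter's initializer.
                setattr(module, name, torch.nn.Parameter(torch.empty_like(param, device=device)))
    return model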
Excellent work @oleksost, very proud of what you achieved here!
✨ Description
This PR integrates Mamba1 and discrete Mamba2 blocks into the Fast-LLM training pipeline; this is the initial step towards addressing #68.
It introduces a basic hybrid architecture that can interleave transformer and mamba-1 blocks.
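Structurally, interleaving just means the block stack is built from a pattern rather than a single block type. A minimal sketch, with class names that are illustrative rather than the PR's actual ones:

import torch


class HybridModel(torch.nn.Module):
    def __init__(self, blocks: list[torch.nn.Module]):
        super().__init__()
        # The block pattern, e.g. [TransformerBlock(...), MambaBlock(...), TransformerBlock(...), ...]
        self.blocks = torch.nn.ModuleList(blocks)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            hidden_states = block(hidden_states)
        return hidden_states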
Next steps:
Training with a simple hybrid model can be tested by installing the mamba_ssm and causal-conv1d dependencies (pip install mamba_ssm[causal-conv1d]==2.2.4) and using the following simple config to build a hybrid model:
To load the Llamba1B model, add the following to the config:
📝 Changes
https://github.com/Zyphra/Zamba2 and https://github.com/state-spaces/mamba
🗒️ Additional Notes
Some hybrid-model settings (e.g. num_layers) currently live in existing configs, but they should probably be moved to higher-level configs at some point.