[feat] Hybrid Mamba model with Mamba and discrete Mamba 2 layers #194


Merged 53 commits into main on Apr 24, 2025

Conversation

@oleksost (Contributor) commented Mar 20, 2025

✨ Description

This PR integrates Mamba 1 and discrete Mamba 2 blocks into the fast-llm training pipeline; it is the initial step toward addressing #68.
It introduces a basic hybrid architecture that can interleave transformer and Mamba 1 blocks.

Next steps:

Training with a simple hybrid model can be tested as follows:

  1. Install the `mamba_ssm` and `causal-conv1d` dependencies:

     ```bash
     pip install "mamba_ssm[causal-conv1d]==2.2.4"
     ```

  2. Launch training by passing:

     ```
     "args": [
         "train",
         "hybrid_ssm",
         "--config",
         "path/to/hybrid_config.yaml"
     ],
     ```
and the following simple config to build a hybrid model:

```yaml
model:
  base_model:
    transformer:
      num_layers: 6
      use_flash_attention: no
    ssm:
      dt_rank: auto
      state_size: 16
      expansion_factor: 2
      debug_ssm: false
    block_pattern: ["m", "t", "m", "m2", "m", "m"] # mixing transformer, Mamba 1, and discrete Mamba 2 layers

  distributed:
    training_dtype: bf16
    tensor_parallel: 1
    pipeline_parallel: 1
    world_size: 1

training:
  train_iters: 1000
  logs:
    interval: 10
  validation:
    iterations: 25
    interval: 1000
  wandb:
    project_name: fast-llm-ssm-test
    group_name: ssm
    entity_name: null

data:
  datasets:
    Training:
      type: memmap
      path: /home/toolkit/dev/fast-llm-tutorial/dataset/shard_0_0
    Validation:
      type: memmap
      path: /home/toolkit/dev/fast-llm-tutorial/dataset/shard_0_0
```

To load the Llamba-1B model, add the following to the config:

```yaml
pretrained:
  format: llamba
  path: /mnt/checkpoints/pretrained_models/Llamba-1B
```

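To illustrate how a `block_pattern` like the one above could drive block construction, here is a minimal sketch. The class names (`TransformerBlock`, `MambaBlock`, `DiscreteMamba2Block`) and the helper are hypothetical, not Fast-LLM's actual API; only the `"t"`/`"m"`/`"m2"` symbols come from the config above.

```python
# Hypothetical sketch: mapping block_pattern symbols to block classes.
# The class names below are illustrative placeholders, not Fast-LLM's API.
BLOCK_TYPES = {
    "t": "TransformerBlock",      # attention mixer
    "m": "MambaBlock",            # Mamba 1 mixer
    "m2": "DiscreteMamba2Block",  # discrete Mamba 2 mixer
}

def resolve_block_pattern(pattern: list[str]) -> list[str]:
    """Validate a block pattern and resolve each symbol to a block class name."""
    unknown = [s for s in pattern if s not in BLOCK_TYPES]
    if unknown:
        raise ValueError(f"Unknown block symbols: {unknown}")
    return [BLOCK_TYPES[s] for s in pattern]

print(resolve_block_pattern(["m", "t", "m", "m2", "m", "m"]))
```

Validating the pattern up front keeps a typo in the config from surfacing as an obscure error deep in model construction.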
🔍 Type of change

Select all that apply:

  • 🐛 Bug fix (non-breaking change that addresses a specific issue)
  • 🚀 New feature (non-breaking change that adds functionality)
  • ⚠️ Breaking change (a change that could affect existing functionality)
  • 📈 Performance improvement/optimization (improves speed, memory usage, or efficiency)
  • 🛠️ Code refactor (non-functional changes that improve code readability, structure, etc.)
  • 📦 Dependency bump (updates dependencies, including Dockerfile or package changes)
  • 📝 Documentation change (updates documentation, including new content or typo fixes)
  • 🔧 Infrastructure/Build change (affects build process, CI/CD, or dependencies)

📝 Changes

  • Added a minimal Mamba 1 layer and block
  • Added the hybrid model and corresponding configs
  • The implementation follows https://github.com/Zyphra/Zamba2 and https://github.com/state-spaces/mamba

✅ Checklist

Make sure the following tasks are completed before submitting the PR:

General

  • 📜 I have read and followed the contributing guidelines.
  • 🏷️ I am using a clear and descriptive PR title that summarizes the key change or feature introduced.
  • 🎉 The functionality is complete, and I have tested the changes.
  • 📝 I have updated the documentation if needed.
  • ⚠️ The change does not introduce any new issues (e.g., runtime warnings, type checker errors, linting problems, unhandled edge cases).
  • 🧩 I have commented my code, especially in hard-to-understand areas.

Dependencies and Configuration

  • 🐋 I have updated the Docker configuration or dependencies, if applicable.
  • 🔄 I have ensured compatibility with the existing setup after dependency changes.

Testing

  • 🧪 I have added or updated tests to cover my changes.
  • ✔️ New and existing tests pass locally with my changes.
  • 🚦 I have tested these changes on GPUs and verified training stability.
  • 🏋️ I have tested the changes on realistic training workloads, if applicable.

Performance Impact

  • 📊 I have run benchmarks where applicable to evaluate the performance impact.
  • ✅ The benchmarks show no performance regression.
  • 🚀 The benchmarks indicate a potential performance improvement.
  • ⚠️ The benchmarks indicate a potential performance degradation.
  • 📈 I have provided benchmark results and detailed any performance impact below, if applicable.

🗒️ Additional Notes

  • Currently, some parameters used to define the hybrid model's architecture (e.g. num_layers) live in the transformer config; they should probably be moved to higher-level configs at some point.

@oleksost oleksost marked this pull request as draft March 20, 2025 02:47
@oleksost oleksost changed the title Mamba 1 blocks [feat] Mamba 1 blocks Mar 20, 2025
@tscholak tscholak added the enhancement New feature or request label Mar 23, 2025
@oleksost oleksost changed the title [feat] Mamba 1 blocks [feat] Hybrid Mamba-1 model Mar 31, 2025
@oleksost oleksost requested a review from tscholak March 31, 2025 13:06
@oleksost oleksost requested a review from jlamypoirier March 31, 2025 17:22
@tscholak tscholak marked this pull request as ready for review March 31, 2025 20:01
@oleksost oleksost changed the title [feat] Hybrid Mamba-1 model [feat] Hybrid Mamba model with mamba and discrete mamba 2 layers Mar 31, 2025
@oleksost oleksost requested a review from jlamypoirier April 11, 2025 15:00


@config_class()
class SSMArchitectureConfig(BaseModelArchitectureConfig):
Collaborator:

Please adjust field names for our naming conventions.

hint=FieldHint.core,
)

dt_rank: str | int = Field(
Collaborator:

Please use None for derived defaults: `dt_rank: int = Field(default=None, ...`

return init_


class MambaLayer(torch.nn.Module):
Collaborator:

This doesn't work with TP, need to explicitly prevent. (Not sure about PP).

Collaborator:

we're gonna have TP eventually, not in scope for this PR though

@tscholak (Collaborator):

@jlamypoirier, naming-wise, "layer" implies a single functional unit, like attention or an MLP. But here, the repeated unit is a composite: It includes a mixer (attention, SSM, etc.), normalization, residuals, and post-mixer processing (MLPs, MoEs, etc.). This structure isn't atomic. It's multiple layers stitched into a reusable computation unit, which is typically referred to as a "block" in other model families (e.g., ResNet, Swin, and even “Transformer blocks” in many papers and blogs).
Now that we're supporting architectures beyond transformers (Mamba, etc.), the term “block” avoids misleading assumptions:

  1. "Layer" strongly implies a fixed layout rooted in the transformer design.
  2. "Block" reflects the actual structure: a swappable, composite computation unit.

Keeping the name BaseBlock lets us consistently subclass it for Mamba, attention, etc., without implying that all of them are "layers" in the same architectural tradition. In this sense, "TransformerLayer" can be a Block instance, but not every Block is a TransformerLayer.

Also, and I find this the most important argument: internally and casually, we already call them blocks. Making the code match mental models reduces friction.
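The composite structure described above can be sketched as follows. This is illustrative only; the class is not Fast-LLM's actual `BaseBlock`, and the norm/MLP choices are generic placeholders standing in for whatever the block actually uses.

```python
# Illustrative sketch of the "block" naming argument: a block is a composite
# unit wrapping a swappable mixer plus shared plumbing (norms, residuals, MLP).
import torch

class BaseBlock(torch.nn.Module):
    def __init__(self, mixer: torch.nn.Module, hidden: int):
        super().__init__()
        self.norm1 = torch.nn.LayerNorm(hidden)
        self.mixer = mixer  # attention, Mamba 1, discrete Mamba 2, ...
        self.norm2 = torch.nn.LayerNorm(hidden)
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(hidden, 4 * hidden),
            torch.nn.GELU(),
            torch.nn.Linear(4 * hidden, hidden),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x + self.mixer(self.norm1(x))  # pre-norm residual around the mixer
        x = x + self.mlp(self.norm2(x))    # pre-norm residual around the MLP
        return x

# A "TransformerLayer" is then just a BaseBlock with an attention mixer,
# while a Mamba block swaps in an SSM mixer without touching the rest.
```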

@jlamypoirier (Collaborator):

@tscholak Sounds reasonable; the problem is these things are called "layers" everywhere else in Fast-LLM. Should we think about renaming those too?

@tscholak left a comment:

Looks already very good to me, thanks @oleksost
This should go in asap!
I had a few comments and suggestions. Mostly, I think we don't want to be too picky with this PR at this point because we're going to be actively working on improving many parts of this anyway in the next weeks.

@jlamypoirier left a comment:

I guess we can ignore some issues with the current SSM implementation, if we handle it properly. This means clear warnings that the model is experimental and may break at any point, e.g. at model runtime and/or in file headers. We still need to fix the code changes outside the model, though (`transformers.py` and `set_nested_dict_value`).

Also keep in mind that future modifications may break experiment configs, pretrained models and checkpoints, hence the importance of getting good config and parameter structures as soon as possible.

@@ -1,5 +1,6 @@
import logging
import typing
from abc import ABC, abstractmethod
Collaborator:
Import

Collaborator:

This is a good moment to reflect on the pattern.

Right now, Fast-LLM expects contributors to internalize a set of nuanced import rules (see https://servicenow.github.io/Fast-LLM/contributing/style-guide/#imports) that go beyond what most Python projects require. That may have worked at some point, but it doesn't scale. New contributors can't memorize this, and even returning ones keep tripping over it.

If this style is important, it needs to be enforced and automatically fixable through linting or a pre-commit hook. If it can't be, we should let go of it. Patterns that can't be learned quickly or applied automatically create friction and slow down the team.

@jlamypoirier, could you file a ticket outlining what it would take to automate this rule? That's the only sustainable way forward.

@tscholak (Collaborator):

> This means clear warnings that the model is experimental and may break at any point, e.g. at model runtime and/or in file headers.

@oleksost, can you emit a warning in the logger when someone tries to instantiate the config for this model class? that should be enough for now.
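The suggested warning could look like the sketch below. The class name `HybridSSMConfig` is a placeholder for the real config class, and the wording of the message is illustrative.

```python
# Sketch: emit a logger warning when the experimental config is instantiated.
# HybridSSMConfig is a hypothetical stand-in for the real config class.
import logging

logger = logging.getLogger(__name__)

class HybridSSMConfig:
    def __init__(self, **kwargs):
        logger.warning(
            "HybridSSMConfig is experimental: the hybrid SSM model may break at "
            "any point, and future changes may invalidate configs and checkpoints."
        )
        self.options = kwargs
```

Hooking the warning into config instantiation (rather than model construction) means it fires even when the config is only being validated or converted.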

@jlamypoirier left a comment:

Some minor suggestions, otherwise looks ok to merge. I'll leave final approval to @tscholak

@@ -43,6 +43,7 @@ class ActivationType(str, enum.Enum):
silu = "silu"
relu = "relu"
squared_relu = "squared_relu"
identity = "identity"
Collaborator:

Can you please add support for it in the MLP? (I know it's triton, but this one is trivial: https://github.com/ServiceNow/Fast-LLM/blob/main/fast_llm/functional/triton/mlp.py.) Or otherwise prevent it in the config?
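For illustration, adding `identity` to an activation lookup is as simple as the sketch below. The `ACTIVATIONS` table is hypothetical (the real dispatch lives in Fast-LLM's functional code, and the triton kernel needs its own case); only the enum values come from the diff above.

```python
# Sketch: an activation lookup table including the new "identity" entry.
# Illustrative only; the real dispatch lives in fast_llm/functional.
import torch

ACTIVATIONS = {
    "silu": torch.nn.functional.silu,
    "relu": torch.nn.functional.relu,
    "squared_relu": lambda x: torch.nn.functional.relu(x) ** 2,
    "identity": lambda x: x,  # no-op: pass the input through unchanged
}

x = torch.tensor([-1.0, 0.0, 2.0])
print(ACTIVATIONS["identity"](x))  # input values unchanged
print(ACTIVATIONS["relu"](x))      # negatives clamped to zero
```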

run_test = MambaLayer is not None and torch.cuda.is_available()


def materialize_meta_tensors(model, tensor_space):

@tscholak left a comment:

excellent work @oleksost, very proud of what you achieved here!

@oleksost oleksost merged commit df30991 into main Apr 24, 2025
4 checks passed
@oleksost oleksost deleted the ssm branch April 24, 2025 20:54
Labels: enhancement (New feature or request)
Successfully merging this pull request may close these issues:

  • Llamba support
  • [feat] Support Mamba 2 blocks

3 participants