
Native Model Parallelism in equinox #825

Open
neel04 opened this issue Aug 30, 2024 · 9 comments

Comments

@neel04

neel04 commented Aug 30, 2024

Oftentimes, one wants to combine n-way data parallelism with m-way model parallelism, as helpfully explained in the official JAX docs.

Here, the common convention is to alternate the sharding axis between consecutive layers, as laid out in the Megatron paper. The linked JAX example also uses this:

import jax
from jax.sharding import NamedSharding, PartitionSpec as P
W2 = jax.device_put(W2, NamedSharding(mesh, P(None, 'model')))
...
W3 = jax.device_put(W3, NamedSharding(mesh, P('model', None)))

This reduces the communication required, because the activations between the two layers stay sharded along 'model'.
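A minimal, self-contained sketch of this alternating (column-then-row) sharding, assuming a 2D mesh whose 'data' axis has size 1 and whose 'model' axis spans all available devices, so it also runs on a single CPU. The shapes, PRNG keys and the `mlp` function are illustrative and not taken from the JAX docs.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumed 2D mesh: 'data' has size 1, 'model' spans every available device.
devices = np.array(jax.devices()).reshape(1, -1)
mesh = Mesh(devices, ("data", "model"))
n_model = devices.shape[1]

# Illustrative shapes, chosen so the hidden size divides evenly over 'model'.
d, h = 4 * n_model, 8 * n_model
k2, k3, kx = jax.random.split(jax.random.PRNGKey(0), 3)
W2 = jax.random.normal(k2, (d, h))
W3 = jax.random.normal(k3, (h, d))
x = jax.random.normal(kx, (2, d))

# Column-shard the first weight and row-shard the second (Megatron-style).
# The hidden activation x @ W2 is then already split along 'model', which is
# how W3 expects its input, so in principle only the partial sums of the
# second matmul need an all-reduce.
W2_s = jax.device_put(W2, NamedSharding(mesh, P(None, "model")))
W3_s = jax.device_put(W3, NamedSharding(mesh, P("model", None)))


@jax.jit
def mlp(x, w2, w3):
    return jax.nn.relu(x @ w2) @ w3


y_sharded = mlp(x, W2_s, W3_s)
y_ref = mlp(x, W2, W3)  # same computation without explicit shardings
```

The sharded and unsharded calls compute the same values; only the placement of the weights (and hence the communication pattern XLA chooses) differs.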

We don't really have an API for that in equinox. While an external library might be a better fit, ultimately I think this is such a common use case that it should be a core feature IMO.

Scalax has a rule-based system where it tries to guess the "correct" axis, like the FSDPShardingRule. For a while I was thinking of simply porting it to work with eqx, but it proved rather hairy: it does everything explicitly, because its primary use case is operating with flax, so porting it required so much fiddling, plumbing and replacing things with filtered versions that I ultimately gave up.
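To illustrate what a rule that "guesses" the axis does, here is a hypothetical heuristic (not scalax's actual implementation): replicate small arrays, and otherwise shard the largest dimension that is divisible by the mesh axis size. The function name `guess_fsdp_spec` and the `min_size` threshold are assumptions made for this sketch.

```python
import math

from jax.sharding import PartitionSpec as P


def guess_fsdp_spec(shape, axis_size, axis_name="fsdp", min_size=1024):
    """Hypothetical FSDP-style heuristic returning a PartitionSpec for `shape`."""
    # Small arrays (biases, norm scales) are cheap to replicate everywhere.
    if math.prod(shape) < min_size:
        return P()
    # Try dimensions from largest to smallest; shard the first divisible one.
    for dim in sorted(range(len(shape)), key=lambda i: shape[i], reverse=True):
        if shape[dim] % axis_size == 0:
            spec = [None] * len(shape)
            spec[dim] = axis_name
            return P(*spec)
    # Nothing divides evenly: fall back to full replication.
    return P()
```

For example, `guess_fsdp_spec((4096, 1024), 8)` returns `P('fsdp', None)`, while `guess_fsdp_spec((7, 333), 8)` falls back to `P()`. Heuristics like this are convenient, but they cannot express Megatron-style alternation on their own, since they look at one leaf at a time.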

I think eqx should adopt a simpler, more flexible API in which one can configure sharding once and apply it to arbitrary PyTrees, without needing to explicitly provide a sharding for every leaf.

What do you think?
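A minimal sketch of the "configure once, apply to arbitrary PyTrees" idea, written in plain JAX rather than any existing Equinox API: a policy function maps each array leaf to a PartitionSpec, and non-array leaves (activation functions, ints, as commonly found inside Equinox modules) pass through untouched. `default_policy` and `shard_pytree` are hypothetical names, and the mesh layout is an assumption.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def default_policy(leaf, axis_size):
    # Hypothetical rule: shard the last axis of matrices over 'model' when it
    # divides evenly; replicate everything else.
    if leaf.ndim == 2 and leaf.shape[-1] % axis_size == 0:
        return P(None, "model")
    return P()


def shard_pytree(tree, mesh, policy=default_policy):
    """Place every array leaf according to `policy`; pass other leaves through."""
    axis_size = mesh.shape["model"]

    def _shard(leaf):
        if isinstance(leaf, jax.Array):
            return jax.device_put(leaf, NamedSharding(mesh, policy(leaf, axis_size)))
        return leaf

    return jax.tree_util.tree_map(_shard, tree)


mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
n = mesh.shape["model"]
params = {
    "w": jnp.ones((4, 8 * n)),
    "b": jnp.zeros((8 * n,)),
    "act": jax.nn.relu,  # non-array leaf, as often found in Equinox modules
    "depth": 3,
}
sharded = shard_pytree(params, mesh)
```

The user only writes the policy; the traversal and the filtering of non-array leaves are handled once, for any PyTree.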

@patrick-kidger
Owner

So this is something which Haliax, a downstream library of Equinox, explicitly adds.

I'd be open to suggestions on how to accomplish something similar within Equinox, although I'm also conscious of not wanting to step on Haliax's toes.

@neel04
Author

neel04 commented Aug 30, 2024

The problem with depending on downstream libs is that maintainers often move on with life and, for various reasons, are unable to contribute as much. Having crucial features tied to the core framework would leave equinox better equipped for researchers, who could then extend its capabilities through wrappers without having to implement everything from scratch.

(Not to mention that porting an equinox codebase to haliax requires a non-trivial amount of effort)

I'm not sure what a good API here is, tbh. Personally, I'm strongly considering spinning off my own lib for parallelization w/ equinox, but I'd much prefer to see some sort of interface for autoparallelization within equinox itself.

I wonder: for a quick and dirty hack that is compatible with current codebases, could we integrate a jmp-like utility API, so that at least it's easy to port computations to being sharded, and that allows alternating sharding between weights by maintaining some state and explicitly specifying a "sharding policy"?
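A jmp-style sharding policy could look roughly like the sketch below. `ShardingPolicy`, its `cycle` field and `shard_weights` are hypothetical (jmp itself only handles mixed precision and has no sharding API); the only "state" is each weight's position in forward order, which is enough to alternate column and row sharding.

```python
import dataclasses

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


@dataclasses.dataclass(frozen=True)
class ShardingPolicy:
    """Hypothetical jmp-style policy: a mesh plus a cycle of PartitionSpecs
    that consecutive weight matrices are assigned to in order."""

    mesh: Mesh
    cycle: tuple = (P(None, "model"), P("model", None))

    def spec_for(self, index):
        # Alternate column / row sharding, as in Megatron-style MLPs.
        return self.cycle[index % len(self.cycle)]

    def shard_weights(self, weights):
        # `weights` lists 2D arrays in forward order; the policy's only
        # "state" is each weight's position in that order.
        return [
            jax.device_put(w, NamedSharding(self.mesh, self.spec_for(i)))
            for i, w in enumerate(weights)
        ]


mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
n = mesh.shape["model"]
policy = ShardingPolicy(mesh)
weights = [jnp.ones((4 * n, 4 * n)) for _ in range(3)]
sharded = policy.shard_weights(weights)
```

The hard part this sketch sidesteps is recovering "forward order" from an arbitrary Module, which is exactly the indexing problem discussed further down the thread.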

@dlwh
Sponsor Contributor

dlwh commented Aug 30, 2024

(Disclaimer: Haliax is mine.)

I really think you have to do something like what Haliax (or really even Flax) does in the general case, which is to associate semantic names with at least a subset of axes and then map those to the physical mesh axes. In Haliax I decided to go "all in" on names (which is not proving to be super popular), but there is a middle ground like Flax, where you could use a jmp-style object (or global state, like in Flax) to hold the semantic-to-physical mapping and then (like Flax) make a shard_with_names(tensor, ("fsdp", "...")) fn.
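A rough sketch of such a semantic-to-physical mapping, assuming a global dict of rules and a 2D ('data', 'model') mesh. `AXIS_RULES`, `logical_to_spec` and `shard_with_names` are illustrative names, not Haliax's or Flax's actual API; `jax.lax.with_sharding_constraint` is the real JAX primitive doing the work.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical semantic-to-physical rules: logical names used in model code
# map to mesh axis names; None (or an unknown name) means "replicated".
AXIS_RULES = {"fsdp": "data", "embed": None, "mlp": "model"}


def logical_to_spec(names, rules=AXIS_RULES):
    return P(*(rules.get(name) for name in names))


def shard_with_names(x, names, mesh, rules=AXIS_RULES):
    """Constrain `x` so that each dimension follows its logical axis name."""
    sharding = NamedSharding(mesh, logical_to_spec(names, rules))
    return jax.lax.with_sharding_constraint(x, sharding)


mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
n = mesh.shape["model"]


@jax.jit
def layer(h):
    # The model code only mentions logical names, never mesh axes.
    return shard_with_names(h, ("embed", "mlp"), mesh) * 2.0


out = layer(jnp.ones((4, 8 * n)))
```

Changing the parallelism strategy then means editing `AXIS_RULES`, not the model code.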

(Totally get where you're coming from w.r.t. people moving on. I will point out that I still merge PRs into my last "big" library Breeze and it's been going for almost 15 years, so I have a decent track record of not totally abandoning things!)

@dlwh
Sponsor Contributor

dlwh commented Aug 30, 2024

Thinking more about it, I had been in the process of refactoring Haliax to be a bit more jmp-like anyway (making a "mesh env" that holds a mesh and an axis mapping). With a bit more work I could break it out into a library, if you wanted to collaborate on that.

@patrick-kidger
Owner

If the question is directed at me -- happy to help out with anything needed from the Equinox side / if you think it's best to upstream anything into Equinox itself. If you're planning a totally separate library then I think I have enough commitments right now, and will politely decline 😅

@dlwh
Sponsor Contributor

dlwh commented Aug 30, 2024

No, I meant @neel04, but I would be happy to discuss upstreaming into Equinox. 😄 IMHO the right thing to do is to make it a separate lib and iterate until we get it right. Then, if you decide you like it, you can bring it into equinox.

@patrick-kidger
Owner

Haha!

In that case, SGTM!

@neel04
Author

neel04 commented Aug 30, 2024

> (Totally get where you're coming from w.r.t. people moving on. I will point out that I still merge PRs into my last "big" library Breeze and it's been going for almost 15 years, so I have a decent track record of not totally abandoning things!)

Haha I know - this wasn't in reference to you 🙂 Rather, I just feel that the approach of relying on 3rd-party support almost always leads to fragmentation.

I have strong opinions about torch's ecosystem precisely because of this - 3rd party libs often don't operate on shared abstractions which makes things tricky to then operate with other libs, leading to unforeseen edge cases and an awful dev experience - where it feels like you're perpetually in 'integration hell'.

I feel like equinox might be making a similar mistake here. Parallelization is the lifeblood of JAX, so it should have first-class support in any JAX-based framework, ideally as strong as flax's.

Looking at flax's approach, I feel like it's overcomplicating it and moving away from the spirit of equinox.

hmmm.. I wonder: is there any way to iteratively inject sharding_constraint and filter_shard at each step of computation in a Module? Even if it requires some hacks to analyze the __call__, the reward would potentially be a drop-in replacement for any equinox codebase, operating on arbitrary PyTrees.

I'll have to think more about it, but I'm imagining it effectively as a wrapper hook on the forward pass, which inserts sharding calls according to some preset policy without requiring explicit changes, like casting in jmp...
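The wrapper-hook idea can be sketched as a higher-order function that runs the forward pass and then applies `jax.lax.with_sharding_constraint` to every array in the output according to a policy. `with_output_sharding` and the example policy are hypothetical; a real version would also have to decide which intermediate activations inside `__call__` to constrain, which this sketch does not attempt.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def with_output_sharding(fn, mesh, policy):
    """Hypothetical hook: call `fn`, then constrain every array in its output
    PyTree to the PartitionSpec returned by `policy(array)`."""

    def wrapped(*args, **kwargs):
        out = fn(*args, **kwargs)

        def _constrain(x):
            # Inside jit, x is a tracer; isinstance(x, jax.Array) still holds.
            if isinstance(x, jax.Array):
                return jax.lax.with_sharding_constraint(
                    x, NamedSharding(mesh, policy(x))
                )
            return x

        return jax.tree_util.tree_map(_constrain, out)

    return wrapped


def last_axis_on_model(x):
    # Example policy: split the last axis over 'model', replicate the rest.
    return P(*([None] * (x.ndim - 1)), "model")


def layer(x, w):
    return jnp.tanh(x @ w)


mesh = Mesh(np.array(jax.devices()).reshape(1, -1), ("data", "model"))
n = mesh.shape["model"]
sharded_layer = jax.jit(with_output_sharding(layer, mesh, last_axis_on_model))
y = sharded_layer(jnp.ones((2, 4)), jnp.ones((4, 8 * n)))
```

Because the hook only wraps the callable, existing layers need no source changes, which is the "drop-in" property asked for above.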

@neel04
Author

neel04 commented Aug 31, 2024

I've been playing around with a couple of approaches in this Colab. I've done a minimal implementation of approach #1 there.


For consecutive layers, we need to alternate the sharding to minimize communication, as described in the Megatron paper.

So we need a way to maintain state and track the index of the leaf we are currently at. A couple of naive solutions come to mind:

  1. Use a tree_map to inject an index into each "shardable" leaf, i.e. tree_map with f: Array -> (Array, int). Then, during the sharding stage, we can extract that index and adjust the sharding based on it.

  2. Or somehow use a jax.lax.scan. The problem here is that scan needs every leaf to share the same leading axis size, which isn't guaranteed.

Maybe one could develop a custom map that incorporates scan or some other utility and embeds enough state to hold an index plus other information.

  3. Specify an explicit spec beforehand, wherein we provide our sharding utility a PyTree of (key, sharding) pairs which it can directly address. This is similar to how scalax does it, but it is a very brittle and time-consuming approach.

Here's how a sample spec would look:

from jax.sharding import PartitionSpec as P
# TreePathShardingRule is provided by scalax's sharding module
model_sharding_rule = TreePathShardingRule(
    ('embedding', P('fsdp', 'tp')),
    ('lm_head/kernel', P('tp', 'fsdp')),
    # Megatron-style feedforward sharding
    ('mlp/(up|gate)_proj/kernel', P('fsdp', 'tp')),
    ('mlp/down_proj/kernel', P('tp', 'fsdp')),
    # Attention should be sharded by heads
    ('self_attn/(k|q|v)_proj/kernel', P('fsdp', 'tp')),
    ('self_attn/o_proj/kernel', P('tp', 'fsdp')),
    ('norm', P()),
)
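For approach #3, the matching logic behind a path-based rule can be sketched in plain JAX with `jax.tree_util.tree_map_with_path` and regular expressions. `path_to_str` and `match_specs` are hypothetical stand-ins, not scalax's implementation; the sketch assumes first-match-wins semantics, and assumes that Equinox modules flatten with `GetAttrKey` entries for their fields.

```python
import re

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P


def path_to_str(path):
    """Join a JAX key path into 'a/b/c' form so rules can match it."""
    parts = []
    for k in path:
        if isinstance(k, jax.tree_util.DictKey):
            parts.append(str(k.key))
        elif isinstance(k, jax.tree_util.GetAttrKey):
            parts.append(k.name)
        elif isinstance(k, jax.tree_util.SequenceKey):
            parts.append(str(k.idx))
        else:
            parts.append(str(k))
    return "/".join(parts)


def match_specs(tree, rules):
    """Return a PyTree of PartitionSpecs, using the first (regex, spec) rule
    whose pattern is found in each leaf's path."""

    def _spec(path, leaf):
        name = path_to_str(path)
        for pattern, spec in rules:
            if re.search(pattern, name):
                return spec
        raise ValueError(f"no sharding rule matches {name!r}")

    return jax.tree_util.tree_map_with_path(_spec, tree)


params = {
    "embedding": jnp.zeros((16, 8)),
    "mlp": {
        "up_proj": {"kernel": jnp.zeros((8, 32))},
        "down_proj": {"kernel": jnp.zeros((32, 8))},
    },
    "norm": {"scale": jnp.ones((8,))},
}
rules = (
    ("embedding", P("fsdp", "tp")),
    ("mlp/(up|gate)_proj/kernel", P("fsdp", "tp")),
    ("mlp/down_proj/kernel", P("tp", "fsdp")),
    ("norm", P()),
)
specs = match_specs(params, rules)
```

Raising on an unmatched path is one way to surface the brittleness mentioned above early, instead of silently replicating a leaf.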

None of these approaches feel right to me, especially approach #1, which is intuitive but might make a copy of the PyTree.
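On the copying concern: in plain JAX, flattening and unflattening a PyTree builds new container nodes but reuses the same array objects, so tagging leaves with an index does not copy any array data. A minimal sketch of approach #1 under that assumption, where `tag_with_index` and its default `is_shardable` predicate (2D arrays only) are hypothetical:

```python
import jax
import jax.numpy as jnp


def is_matrix(x):
    return isinstance(x, jax.Array) and x.ndim == 2


def tag_with_index(tree, is_shardable=is_matrix):
    """Approach #1: pair each shardable leaf with its position among the
    shardable leaves in flatten order, i.e. Array -> (Array, int)."""
    leaves, treedef = jax.tree_util.tree_flatten(tree)
    tagged, count = [], 0
    for leaf in leaves:
        if is_shardable(leaf):
            tagged.append((leaf, count))
            count += 1
        else:
            tagged.append(leaf)
    # Only container nodes are rebuilt here; array leaves are reused as-is.
    return jax.tree_util.tree_unflatten(treedef, tagged)


w1, b1 = jnp.ones((4, 8)), jnp.zeros((8,))
w2, b2 = jnp.ones((8, 4)), jnp.zeros((4,))
params = [{"w": w1, "b": b1}, {"w": w2, "b": b2}]
tagged = tag_with_index(params)
```

A sharding pass could then pick `P(None, 'model')` for even indices and `P('model', None)` for odd ones. Note that the index follows flatten order (dict keys are sorted), which only matches forward order if the PyTree is laid out that way.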

I can use some combination for my personal codebase, but I think we're just hitting the limitations of the toolset equinox has to offer in terms of manipulating PyTrees.

What do you guys think?


3 participants