Initialization of large models on multi-hosts environment #778

Open
kazewong opened this issue Jul 7, 2024 · 9 comments

@kazewong

kazewong commented Jul 7, 2024

Hi all,

I am wondering what the preferred way is to create a model that is too large to fit on a single device.

As a reference starting point: with data parallelism, I first create per-device data arrays and then use make_array_from_single_device_arrays to put them on the global mesh (this basically follows https://jax.readthedocs.io/en/latest/_autosummary/jax.make_array_from_single_device_arrays.html#jax.make_array_from_single_device_arrays).
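For reference, the data-parallel pattern above looks roughly like this. It is a minimal sketch assuming a 1-D mesh with a hypothetical axis name `"data"` spanning every device; in a real multi-host job each process would read only its own rows from storage rather than slicing a global NumPy array that every host holds.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# A 1-D mesh over all devices; the batch (first) axis is split across it.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
sharding = NamedSharding(mesh, P("data", None))
global_shape = (8 * jax.device_count(), 4)

# Stand-in for a dataset; a multi-host loader would only read the local rows.
global_data = np.arange(np.prod(global_shape), dtype=np.float32).reshape(global_shape)

# One array per addressable device, each holding exactly that device's slice.
index_map = sharding.addressable_devices_indices_map(global_shape)
per_device = [jax.device_put(global_data[idx], d) for d, idx in index_map.items()]

# Stitch the per-device pieces together into one global jax.Array.
batch = jax.make_array_from_single_device_arrays(global_shape, sharding, per_device)
```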

Because the pre-defined modules in Equinox initialize the full set of parameters on every device by default, I cannot simply follow the guide in https://docs.kidger.site/equinox/tricks/#custom-parameter-initialisation and replace the parameters with sharded versions after creating the model: the full, unsharded model would not fit on one device in the first place.

My current workaround is to write a wrapper class around each of the eqx.nn modules I want to use, so that I can create the sharded parameters on each device and then combine them as I would in the data-parallel case.

Here is a minimal example that wraps the eqx.nn.Linear class:
https://gist.github.com/kazewong/c976b48c5870d866496740341382acb5
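The gist itself is not reproduced here, but the core idea (initialise each device's shard locally, then combine the shards as in the data-parallel case) can be sketched as follows. `sharded_uniform` and `sharded_linear_params` are hypothetical helpers, the mesh axis name `"model"` is an assumption, and the uniform range 1/sqrt(in_features) mirrors eqx.nn.Linear's default init, although the actual random values differ from what eqx.nn.Linear would produce.

```python
import math

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharded_uniform(key, shape, sharding, lim):
    """Hypothetical helper: build a global array whose shards are initialised independently.

    Each shard's key is derived from the shard's starting offsets, so devices that
    hold the same slice (replicas) produce identical values.
    """
    shard_shape = sharding.shard_shape(shape)
    pieces = []
    for device, index in sharding.addressable_devices_indices_map(shape).items():
        shard_key = key
        for sl in index:  # one slice per dimension
            shard_key = jax.random.fold_in(shard_key, sl.start or 0)
        shard = jax.random.uniform(shard_key, shard_shape, minval=-lim, maxval=lim)
        pieces.append(jax.device_put(shard, device))
    return jax.make_array_from_single_device_arrays(shape, sharding, pieces)


def sharded_linear_params(in_features, out_features, mesh, *, key):
    """Hypothetical wrapper: weight split along its output axis, bias replicated."""
    wkey, bkey = jax.random.split(key)
    lim = 1 / math.sqrt(in_features)
    w_sharding = NamedSharding(mesh, P("model", None))
    weight = sharded_uniform(wkey, (out_features, in_features), w_sharding, lim)
    bias = sharded_uniform(bkey, (out_features,), NamedSharding(mesh, P()), lim)
    return weight, bias


mesh = Mesh(np.array(jax.devices()), axis_names=("model",))
weight, bias = sharded_linear_params(512, 256, mesh, key=jax.random.PRNGKey(0))
```

One trade-off of keying by shard offset: the initial values depend on the mesh layout, so the same seed on a differently shaped mesh gives different weights.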

Since the multi-host interface in JAX is still experimental, we probably don't want to put too much of this into the Equinox core code. To make creating large models easier right now, I think a wrapper class or a decorator is probably the easiest route, but I want to see what people think before submitting a PR. @patrick-kidger

@patrick-kidger
Owner

So one way to do this with the current API is to create the model 'skeleton' via

```python
eqx.filter_eval_shape(SomeModule, ...)
```

and then fill in the parameters with eqx.tree_at afterwards.

I do this when translating model weights from PyTorch, for example.

I'm not sure what the cleanest way of doing this in general is. Ideally it wouldn't require wrapping the constructors. These can be thought of as functions that return pytrees, so something respecting that in some way would be ideal.

@kazewong
Author

kazewong commented Jul 7, 2024

I see! I tried using filter_eval_shape and it does work for me. Do you think it is worth making a short tutorial notebook for this? I already have the script, so converting it into a notebook shouldn't take long. Let me know if this would be useful and I can work on a PR.

@patrick-kidger
Owner

If this ends up being the best way to do it, then yes! I think having an example would be really useful.

Right now I'm not completely convinced it is the best way, though. It's certainly a lot of work. I'm wondering if there's some way to directly create the arrays in the right way by wrapping the constructor in a filtered/tree-map'd version of the usual JAX operations here (make_array_from_single_device_arrays etc.). I simply haven't thought it through yet!

@kazewong
Author

I agree: an ergonomic way to initialize a sharded model would be great. I am trying to figure out how best to do this, and here are some thoughts:

Ideally, we would have a function like this:

```python
def init_shard_model(ModelClass, mesh, sharding, *args, **kwargs):
    # Returns ModelClass(*args, **kwargs) with every parameter already sharded over `mesh`.
    ...
```

I have so far run into two complications:

  1. Sharding the internals of a model differently. Take as an example a linear layer with 512 inputs and 256 outputs: the weight matrix has shape (256, 512) and the bias vector has shape (256,). Say I want to shard the weight matrix along its second axis and the bias vector along its first axis; where should that information be specified? Right now I define a sharding function for each kind of layer, so it lives there.
  2. Sharding structured arrays. Say my model's weight is initialized as an upper triangular matrix, and I want to shard it along one axis. When constructing the local arrays, I then need to pass in process IDs or other bookkeeping parameters to make sure the correct part is initialized on each host before combining them.
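For point 2, one option (a sketch, not something proposed in this thread) is `jax.make_array_from_callback`: JAX calls the callback once per local shard with that shard's global index, so each host can build its block of a structured matrix from global row/column coordinates without tracking process IDs by hand. `triu_block`, the mesh axis name `"rows"`, and the all-ones upper-triangular init are illustrative assumptions.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

N = 256
mesh = Mesh(np.array(jax.devices()), axis_names=("rows",))
sharding = NamedSharding(mesh, P("rows", None))  # split along the first axis


def triu_block(index):
    """Hypothetical callback: the block of an N x N upper-triangular matrix at `index`."""
    rows = np.arange(N)[index[0]]  # global row ids covered by this shard
    cols = np.arange(N)[index[1]]  # global column ids covered by this shard
    return (cols[None, :] >= rows[:, None]).astype(np.float32)


# The callback is only invoked for shards addressable by this process.
weight = jax.make_array_from_callback((N, N), sharding, triu_block)
```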

I suspect this may be why the Levanter developers rolled their own NN library, Haliax, on top of Equinox, so that they can handle these issues. Is it essentially smart wrapping around the Equinox layers? @dlwh

I am going to try a number of things in the coming weeks; any input would be awesome!

@dlwh
Sponsor Contributor

dlwh commented Jul 16, 2024

This is indeed one of the main reasons I went with named axes. (There are a few more, but I won't belabor them here.) It basically defines this problem away, though it does mean you have to do weird gymnastics when you have square arrays with semantically identical axes. But it's a price I'm happy to pay. (Obviously happy to have you over in Haliax land, but only if it appeals.)

You could instead follow something closer to what Flax does, if you wanted; see, for example, t5x. It's basically the same as Haliax except that the axis names aren't carried around with the array, so you occasionally have to sprinkle in more wscs (with_sharding_constraint calls). Obviously I'm partial to Haliax, but tastes vary.

Either way, the basic idea is to do initialization inside jit and use with_sharding_constraint and/or out_shardings to ensure the parameters end up sharded correctly. IMHO, re (2), the right approach is to always, always do as much as you can in a "global" way using jit, and only fall back on make_array_from_callback when you absolutely have to (or for data loading).

@patrick-kidger
Owner

Thank you both!

To riff on @dlwh's last paragraph -- doing as much as possible inside JIT -- would something like

```python
@jax.jit
def f():
    model = SomeModel(**hyperparams)
    return eqx.filter_shard(model, shard_tree)
```

work? With an appropriately constructed shard_tree, this would also address @kazewong's first point, about sharding each layer differently.

(I can see that this would be a t5x way of doing things rather than a Haliax way of doing things.)

@dlwh
Sponsor Contributor

dlwh commented Jul 20, 2024

I'd actually say Haliax and flax/t5x are closer to each other than what you're proposing. (Not saying that's necessarily a bad thing, just trying to clarify.)

Haliax looks like this:

```python
class SomeModel(Module):
    param: hax.NamedArray

    # i use static methods but whatever
    def __init__(...):
        self.param = hax.random.normal((Embed, Mlp), ...)  # [Embed, Mlp]
```

and hax.random.normal is basically:

```python
def normal(key, shape: AxisSpec, dtype=float):
    shape = ensure_tuple(shape)
    jax_shape = _to_jax_shape(shape)
    jax_array = jrandom.normal(key=key, shape=jax_shape, dtype=dtype)

    # this snippet is in a wrapper. simplifying some
    physical_axes = infer_physical_axes_from_logical(shape)  # uses a global context manager
    jax_array = with_sharding_constraint(jax_array, physical_axes)

    return NamedArray(jax_array, shape)
```

Flax/t5x is more like:

```python
class SomeModel(Module):

    @nn.compact
    def __call__(...):
        param = param_with_axes("param", init, (512, 2048), jnp.float32, axes=("embed", "mlp"))
```

and param_with_axes would have something like this in it (simplifying a bit):

```python
physical_axes = map_logical_to_physical(global_axis_mapping, axes)  # uses a global context manager
param = with_sharding_constraint(param, PartitionSpec(*physical_axes))
```

So, like Haliax, it injects wscs at parameter declaration and also uses an implicit logical-to-physical axis map.

By comparison, you're proposing keeping sharding as an explicit outer step that operates on entire module trees. Haliax also supports doing it that way, but it's not really the default anymore. (In Haliax, when you do it that way, we can still use the map_logical_to_physical function because the module tree has the names in it, so it's relatively painless from a user's perspective. I think this would be harder in Flax, but I could be wrong.)

My suspicion is that what you're proposing will prove quite cumbersome compared to keeping sharding at the per-parameter declaration site (as in Flax/t5x or Haliax) for more complex models, but I'm happy to be proven wrong.

@patrick-kidger
Owner

Right! I'm just getting at the Flax-vs-Haliax distinction of whether there's a named array object available at call time.

The thing is that I don't see another way to support this kind of per-parameter behaviour in an off-the-shelf library like Equinox. I think if per-parameter behaviour is required then the user usually needs to control the definition of every layer? (Or else use some eval_shape + tree_at hackery to add such support after-the-fact, as above.)

To turn the above into something actionable: I'd like to include an example on this topic, as per @kazewong's suggestion, and I'm trying to figure out what approach to recommend. I'd like to include both (a) a pure-Equinox solution (whether that's my latest suggestion, an eval_shape + tree_at combination, or something else) and (b) a reference to Haliax.

(To wax a little philosophical, by the way -- this kind of thing comes up in other scenarios too, such as which parameters to apply a transform to: differentiate/vmap/etc. JAX's model here is either to specify things at the call site (e.g. jax.jit(..., static_argnums=...)) or in the call itself (e.g. lax.stop_gradient, wsc), as opposed to putting them in the model definition (e.g. torch.tensor(..., requires_grad=True)). The problem with the latter is that it only works for a fixed set of transformations known ahead of time, and isn't extensible to new ones that may become available after the model definition has been written. This extensibility, and consistency with JAX, are the reasons Equinox tends to go for top-level pytree-manipulation tools like eqx.partition, rather than in-the-pytree specifiers like (the nonexistent) eqx.field(requires_grad=True).)

@dlwh
Sponsor Contributor

dlwh commented Jul 22, 2024

Yeah, that makes sense as a philosophy. In my experience, though, just as with stop_gradient in e.g. softmax, you will need to insert with_sharding_constraint calls deep inside "__call__" and not just at the top level, particularly when dealing with Linear layers in FSDP-ish code. Maybe it's fine to expect anyone who cares to thread through enough context to call wsc themselves. That means you probably won't be able to use the built-in MHA implementation, but that's probably fine.
