Scan over hidden layers instead of python for-loop #294

Open
wants to merge 1 commit into main

Conversation

@boris-kuz (Contributor) commented Mar 19, 2023

This is my attempt at #293; I'm not exactly sure it's the best way of doing things.
Initially I thought I could simply pass layers[1:-1] to the xs argument of lax.scan, but that would lead scan to pass a list of nn.Linear of reduced dimensionality to f.
I'm particularly unsure whether the unflattening is really the right approach, but I couldn't find a better way to do it.
Open to suggestions!
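
For concreteness, a rough sketch of this kind of stack-and-scan approach (illustrative only, not the actual diff; it assumes at least one hidden layer, that all hidden layers share the same width, and that the activation is stored on the MLP as self.activation):

# Inside MLP.__call__, roughly:
import jax
import jax.numpy as jnp
import equinox as eqx

hidden = self.layers[1:-1]  # the hidden eqx.nn.Linear layers, all the same shape
# Stack the per-layer parameters along a new leading axis so lax.scan can iterate over them.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *hidden)
dynamic, static = eqx.partition(stacked, eqx.is_array)

def body(x, layer_dynamic):
    layer = eqx.combine(layer_dynamic, static)  # "unflatten" one slice back into a Linear
    return self.activation(layer(x)), None

x, _ = jax.lax.scan(body, x, dynamic)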

I didn't bother adding new tests, as I consider this an implementation detail that should be covered by the existing tests.

@patrick-kidger (Owner)

Whoops sorry, I missed this one! Thanks for the pull request.

So it should be possible to make this happen more efficiently by batching the Linear constructor, e.g. eqx.filter_vmap(Linear)(...); see here.
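
For example, something along these lines (hypothetical names; num_hidden, width_size and key stand in for whatever the MLP constructor has in scope):

import jax
import equinox as eqx

keys = jax.random.split(key, num_hidden)
make_linear = lambda k: eqx.nn.Linear(width_size, width_size, key=k)
hidden = eqx.filter_vmap(make_linear)(keys)
# hidden is a single Linear whose weight has shape (num_hidden, width_size, width_size),
# i.e. all hidden layers stored as one stacked set of parameters.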

The only question is ensuring backward compatibility. For example, some users may be used to doing something like mlp.layers[-1] to get access to the final layer (indeed this is an example I use frequently; see here), and changing the internal representation of the MLP would break this. I'm not sure how best to tackle this.

@boris-kuz (Contributor, Author)

No worries, thanks for having a look!

Wouldn't the approach I've taken ensure backwards compatibility, even if it's not the most efficient implementation?
Maybe it would also be worth adding something like a FusedMLP class that does the batching of the Linear constructor? Or making it opt-in via a defaulted constructor argument to MLP?

@patrick-kidger (Owner)

Yup, you're right!

Honestly, I'm thinking the best approach may just be to do the backwards-incompatible thing, and bump the version number. I suspect there are few-to-no cases of people actually modifying the hidden layers; probably most such cases involve only the final layer. That means we could make the change by setting:

# Inside MLP.__init__
self.layers = [input_layer, _HiddenLayers(...), output_layer]

where _HiddenLayers is another module, which calls eqx.filter_vmap(Linear) at initialisation and runs a lax.scan at call time.
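
A rough sketch of what such a module could look like (names, fields and the activation handling are illustrative only):

import jax
import equinox as eqx

class _HiddenLayers(eqx.Module):
    stacked: eqx.nn.Linear  # weight has shape (num_hidden, width, width)

    def __init__(self, num_hidden, width, *, key):
        # Batch the Linear constructor: all hidden layers as one stacked set of parameters.
        keys = jax.random.split(key, num_hidden)
        make = lambda k: eqx.nn.Linear(width, width, key=k)
        self.stacked = eqx.filter_vmap(make)(keys)

    def __call__(self, x):
        dynamic, static = eqx.partition(self.stacked, eqx.is_array)

        def body(carry, layer_dynamic):
            layer = eqx.combine(layer_dynamic, static)  # rebuild one Linear per scan step
            return jax.nn.relu(layer(carry)), None      # activation hard-coded for brevity

        out, _ = jax.lax.scan(body, x, dynamic)
        return out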

WDYT?

@boris-kuz (Contributor, Author)

I think that's a good idea! Plus if you really depend on being able to change the hidden layers, you could just implement your own MLP using Sequential.
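
For example, something along these lines (illustrative; this uses eqx.nn.Sequential with eqx.nn.Lambda to wrap the activations):

import jax
import equinox as eqx

key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
model = eqx.nn.Sequential([
    eqx.nn.Linear(2, 8, key=key1),
    eqx.nn.Lambda(jax.nn.relu),
    eqx.nn.Linear(8, 8, key=key2),
    eqx.nn.Lambda(jax.nn.relu),
    eqx.nn.Linear(8, 1, key=key3),
])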

Would the __call__ part of this PR have to change?
Since we're only batching the constructor of Linear, it feels like it shouldn't, but I thought I'd check.

@patrick-kidger (Owner)

I think leaving __call__ as

for layer in self.layers[:-1]:
    x = layer(x)

should probably work.
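
(Presumably this works because, under the proposed representation, self.layers would be [input_layer, _HiddenLayers(...), output_layer], and _HiddenLayers is itself a callable module, so the second loop iteration applies the whole scanned hidden block as a single layer.)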
