Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Scan over hidden layers instead of python for-loop #294

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions equinox/nn/composed.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,10 @@

import jax.nn as jnn
import jax.random as jrandom
import jax.numpy as jnp
from jaxtyping import Array
from jax import lax
from jax.tree_util import tree_flatten

from ..custom_types import PRNGKey
from ..module import Module, static_field
Expand Down Expand Up @@ -101,6 +104,35 @@ def __init__(
self.activation = activation
self.final_activation = final_activation

def _scan_hidden_layers(self, x):
def step(inp, layer_weights):
weight = layer_weights[:, :-1]
bias = layer_weights[:, -1].T
layer = self.layers[1]._tree_unflatten(
[
["weight", "bias"],
["in_features", "out_features", "use_bias"],
[
self.layers[1].in_features,
self.layers[1].out_features,
self.layers[1].use_bias,
],
],
[weight, bias],
)
inp = self.activation(layer(inp))
return inp, None

flattened_layers, _ = tree_flatten(self.layers[1:-1])
concatenated_weight_bias = [
jnp.concatenate([weight, bias.reshape(-1, 1)], axis=1)
for weight, bias in zip(flattened_layers[::2], flattened_layers[1::2])
]
stacked_weights = jnp.stack(concatenated_weight_bias)

x, _ = lax.scan(step, x, stacked_weights)
return x

def __call__(self, x: Array, *, key: Optional[PRNGKey] = None) -> Array:
"""**Arguments:**

Expand All @@ -113,9 +145,14 @@ def __call__(self, x: Array, *, key: Optional[PRNGKey] = None) -> Array:

A JAX array with shape `(out_size,)`. (Or shape `()` if `out_size="scalar"`.)
"""
for layer in self.layers[:-1]:
x = layer(x)

if len(self.layers) > 1:
x = self.layers[0](x)
x = self.activation(x)

if len(self.layers) > 2:
x = self._scan_hidden_layers(x)

x = self.layers[-1](x)
x = self.final_activation(x)
return x
Expand Down