
[Question] Modifying the Static Variable in a Model #806

Open
ahahn2813 opened this issue Aug 20, 2024 · 5 comments
Comments

@ahahn2813

Hello,

My advisor and I are attempting to do something non-typical with Equinox where we are trying to figure out how to change the static part of the neural network architecture on the fly.

Suppose we seek to learn the best activation function for a single layer of our neural network (all other architecture features are pre-chosen). The line of code below:

params, static = equinox.partition(model, equinox.is_array)

allows one to separate the model into a "static" variable and a "params" variable. The params variable contains all the weights, while the static variable contains the information about the layer's activation function. Since we will need to change the activation function when we update the architecture, we would like to know whether it is possible to make modifications within the static variable. In other words, is there an easy way to convert the static variable into a params variable and then back into a static variable?

One approach we can think of is to convert part of the static variable into an array so that we can modify it, but we do not know how to convert it back into the static variable once it has been changed.

Thank you!

@lockwo
Contributor

lockwo commented Aug 20, 2024

If you just want to update one member variable of the module, you can use `eqx.tree_at`:

import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
    w: jax.Array
    b: jax.Array
    act_fn: Callable

    def __call__(self, x):
        return self.act_fn(self.w @ x + self.b)

net = NN(jnp.ones((10, 10)), jnp.ones(10), jax.nn.relu)
print(net(jnp.ones(10)))
print(net)

# Out-of-place update: returns a new model with act_fn swapped out.
net = eqx.tree_at(lambda x: x.act_fn, net, jax.nn.sigmoid)
print(net(jnp.ones(10)))
print(net)

where `act_fn` is the part that would be partitioned into static in your code above.

@krm9c

krm9c commented Aug 20, 2024

Thank you for your response. Just so I understand properly: I can define a field inside the class for each quantity that I want to change dynamically (for instance, the number of layers or the activation function), and then call `eqx.tree_at` to replace those values as needed. As pseudocode:

import jax
from jax import numpy as jnp
import equinox as eqx
from typing import Callable

class NN(eqx.Module):
    w: jax.Array
    b: jax.Array
    act_fn: Callable
    width: int

    def __init__(self, width, act):
        self.act_fn = act
        self.width = width
        # initialize the weights for the chosen width
        self.w = jnp.ones((width, width))
        self.b = jnp.ones(width)

    def __call__(self, x):
        return self.act_fn(self.w @ x + self.b)


net = NN(width=10, act=jax.nn.relu)
print(net(jnp.ones(10)))
print(net)

# pseudocode: outer loop over candidate architectures
for ... in architecture_search_loop:
    net = eqx.tree_at(lambda x: x.width, net, 5)
    net = eqx.tree_at(lambda x: x.act_fn, net, jax.nn.sigmoid)

    # inner loop: train the network weights
    for ... in weight_training_loop:
        ...

As long as I have the right field names within my `NN` class, I would be able to assign them on the fly with `eqx.tree_at`. The pseudocode might be crude, but have I understood the way you meant it?

@lockwo
Contributor

lockwo commented Aug 20, 2024

Sure, that would work. It seems like width impacts/determines other variables (such as w), but the code as you've written it would run.

@krm9c

krm9c commented Aug 20, 2024

This is exactly what we want to do: determine the architecture/hyperparameters on the fly.

We are trying to build a neural architecture/hyperparameter search setup with Equinox, and this will be helpful in that regard.

Equinox is a wonderful library. Thank you for maintaining and working on it.

@lockwo
Contributor

lockwo commented Aug 21, 2024

I just answer a few issues; all credit goes to Patrick.
