better bug hint when writing a simple neural network in equinox #788

Open
zhengqigao opened this issue Jul 22, 2024 · 1 comment
Labels
question User queries

Comments

@zhengqigao

Hi,

Thanks for the nice package. I am new to Equinox. I tried to write a simple MLP, but constructing it fails with an error, and the error message does not make it clear how I should revise my code.

import jax
import jax.numpy as jnp
import equinox as eqx

class MLPeqx(eqx.Module):
    def __init__(self, hidden_dims):
        super().__init__()
        tmp_key = jax.random.split(jax.random.PRNGKey(0), len(hidden_dims) - 1)
        self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in range(len(hidden_dims) - 1)]
        self.activation = jax.nn.relu

    def __call__(self, x):
        for i in range(len(self.layers) - 1):
            x = self.activation(self.layers[i](x))
        x = self.layers[-1](x)
        return x

MLP = MLPeqx(hidden_dims=[1,2,4,4,2,1])

The error I got:

Traceback (most recent call last):
  File "xxxxxx/misc/test1.py", line 18, in <module>
    MLP = MLPeqx(hidden_dims=[1,2,4,4,2,1])
  File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 548, in __call__
    self = super(_ModuleMeta, initable_cls).__call__(*args, **kwargs)
  File "xxxxxx/python3.9/site-packages/equinox/_better_abstract.py", line 226, in __call__
    self = super().__call__(*args, **kwargs)
  File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 376, in __init__
    init(self, *args, **kwargs)
  File "xxxxxx/misc/test1.py", line 9, in __init__
    self.layers = [eqx.nn.Linear(hidden_dims[i], hidden_dims[i + 1], key=tmp_key[i]) for i in range(len(hidden_dims) - 1)]
  File "xxxxxx/python3.9/site-packages/equinox/_module.py", line 811, in __setattr__
    raise AttributeError(f"Cannot set attribute {name}")
AttributeError: Cannot set attribute layers

What did I miss?

@lockwo
Contributor

lockwo commented Jul 22, 2024

Equinox modules are dataclasses (https://docs.python.org/3/library/dataclasses.html), so you have to declare each attribute as a type-annotated field in the class body before `__init__` can assign to it. See https://docs.kidger.site/equinox/ for an example.

@patrick-kidger patrick-kidger added the question User queries label Jul 23, 2024