Skip to content

Commit

Permalink
Fix __post_init__ detection in Module
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 6, 2024
1 parent 14e2c49 commit dfa4acc
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 2 deletions.
5 changes: 3 additions & 2 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,11 +250,12 @@ def __new__(
else:
assert name == "Module"
has_dataclass_init = True # eqx.Module itself
has_dataclass_post_init = "__post_init__" in cls.__dict__

# Check for a common error. (Check for `_Initable` to avoid duplicate warnings.)
if (
not has_dataclass_init
and hasattr(cls, "__post_init__")
and has_dataclass_post_init
and not issubclass(cls, _Initable)
):
warnings.warn(
Expand Down Expand Up @@ -284,7 +285,7 @@ def __new__(

# Add support for `eqx.field(converter=...)` when using `__post_init__`.
# (Scenario (c) above. Scenarios (a) and (b) are handled later.)
if has_dataclass_init and hasattr(cls, "__post_init__"):
if has_dataclass_init and has_dataclass_post_init:
post_init = cls.__post_init__

@ft.wraps(post_init) # pyright: ignore
Expand Down
36 changes: 36 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,39 @@ class InvalidArr(eqx.Module):
match="A JAX array is being set as static!",
):
InvalidArr((), jnp.ones(10))


def test_multiple_inheritance():
class A(eqx.Module):
def __post_init__(self) -> None:
if hasattr(super(), "__post_init__"):
super().__post_init__() # pyright: ignore

class B(A):
x: jax.Array = eqx.field(init=False)

def __post_init__(self) -> None:
super().__post_init__()
self.x = jnp.zeros(())

class C(A):
pass

class D(C, A):
def __post_init__(self) -> None:
super().__post_init__()

class E(D, B):
pass

E()


def test_init_despite_post_init_in_super():
class A(eqx.Module):
def __post_init__(self) -> None:
pass

class B(A):
def __init__(self) -> None:
pass

0 comments on commit dfa4acc

Please sign in to comment.