diff --git a/equinox/_module.py b/equinox/_module.py index 3c726b80..75559c5d 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -250,11 +250,12 @@ def __new__( else: assert name == "Module" has_dataclass_init = True # eqx.Module itself + has_dataclass_post_init = "__post_init__" in cls.__dict__ # Check for a common error. (Check for `_Initable` to avoid duplicate warnings.) if ( not has_dataclass_init - and hasattr(cls, "__post_init__") + and has_dataclass_post_init and not issubclass(cls, _Initable) ): warnings.warn( @@ -284,7 +285,7 @@ def __new__( # Add support for `eqx.field(converter=...)` when using `__post_init__`. # (Scenario (c) above. Scenarios (a) and (b) are handled later.) - if has_dataclass_init and hasattr(cls, "__post_init__"): + if has_dataclass_init and has_dataclass_post_init: post_init = cls.__post_init__ @ft.wraps(post_init) # pyright: ignore diff --git a/tests/test_module.py b/tests/test_module.py index 76996e36..33b0af7c 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1169,3 +1169,39 @@ class InvalidArr(eqx.Module): match="A JAX array is being set as static!", ): InvalidArr((), jnp.ones(10)) + + +def test_multiple_inheritance(): + class A(eqx.Module): + def __post_init__(self) -> None: + if hasattr(super(), "__post_init__"): + super().__post_init__() # pyright: ignore + + class B(A): + x: jax.Array = eqx.field(init=False) + + def __post_init__(self) -> None: + super().__post_init__() + self.x = jnp.zeros(()) + + class C(A): + pass + + class D(C, A): + def __post_init__(self) -> None: + super().__post_init__() + + class E(D, B): + pass + + E() + + +def test_init_despite_post_init_in_super(): + class A(eqx.Module): + def __post_init__(self) -> None: + pass + + class B(A): + def __init__(self) -> None: + pass