diff --git a/equinox/_module.py b/equinox/_module.py index 3c726b80..839dafcd 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -284,7 +284,7 @@ def __new__( # Add support for `eqx.field(converter=...)` when using `__post_init__`. # (Scenario (c) above. Scenarios (a) and (b) are handled later.) - if has_dataclass_init and hasattr(cls, "__post_init__"): + if has_dataclass_init and "__post_init__" in cls.__dict__: post_init = cls.__post_init__ @ft.wraps(post_init) # pyright: ignore @@ -293,17 +293,11 @@ def __post_init__(self, *args, **kwargs): # We want to only convert once, at the top level. # # This check is basically testing whether or not the function we're in - # now (`cls.__post_init__`) is at the top level - # (`self.__class__.__post_init__`). If we are, do conversion. If we're - # not, it's presumably because someone is calling us via `super()` in - # the middle of their own `__post_init__`. No conversion then; their own - # version of this wrapper will do it at the appropriate time instead. - # - # One small foible: we write `cls.__post_init__`, rather than just - # `__post_init__`, to refer to this function. This allows someone else - # to also monkey-patch `cls.__post_init__` if they wish, and this won't - # remove conversion. (Conversion is a at-the-top-level thing, not a - # this-particular-function thing.) + # now (`cls`) is at the top level (`self.__class__`). If we are, do + # conversion. If we're not, it's presumably because someone is calling + # us via `super()` in the middle of their own `__post_init__`. No + # conversion then; their own version of this wrapper will do it at the + # appropriate time instead. # # This top-level business means that this is very nearly the same as # doing conversion in `_ModuleMeta.__call__`. The differences are that @@ -311,11 +305,11 @@ def __post_init__(self, *args, **kwargs): # `__post_init__`, and (b) it allows other libraries (i.e. jaxtyping) # to later monkey-patch `__init__`, and we have our converter run before # their own monkey-patched-in code. - if self.__class__.__post_init__ is cls.__post_init__: + if self.__class__ is _make_initable_wrapper(cls): # Convert all fields currently available. _convert_fields(self, init=True) post_init(self, *args, **kwargs) # pyright: ignore - if self.__class__.__post_init__ is cls.__post_init__: + if self.__class__ is _make_initable_wrapper(cls): # Convert all the fields filled in by `__post_init__` as well. _convert_fields(self, init=False) @@ -377,7 +371,7 @@ def __init__(self, *args, **kwargs): __tracebackhide__ = True init(self, *args, **kwargs) # Same `if` trick as with `__post_init__`. - if self.__class__.__init__ is cls.__init__: + if self.__class__ is _make_initable_wrapper(cls): _convert_fields(self, init=True) _convert_fields(self, init=False) @@ -566,8 +560,7 @@ def __call__(cls, *args, **kwargs): # else it's handled in __setattr__, but that isn't called here. # [Step 1] Modules are immutable -- except during construction. So defreeze # before init. - post_init = getattr(cls, "__post_init__", None) - initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False) + initable_cls = _make_initable_wrapper(cls) # [Step 2] Instantiate the class as normal. self = super(_ActualModuleMeta, initable_cls).__call__(*args, **kwargs) assert not _is_abstract(cls) @@ -792,6 +785,11 @@ def __call__(self, ...): break +def _make_initable_wrapper(cls: _ActualModuleMeta) -> _ActualModuleMeta: + post_init = getattr(cls, "__post_init__", None) + return _make_initable(cls, cls.__init__, post_init, wraps=False) + + @ft.lru_cache(maxsize=128) def _make_initable( cls: _ActualModuleMeta, init, post_init, wraps: bool diff --git a/tests/test_module.py b/tests/test_module.py index 76996e36..46bb4d46 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1169,3 +1169,38 @@ class InvalidArr(eqx.Module): match="A JAX array is being set as static!", ): InvalidArr((), jnp.ones(10)) + + +# https://github.com/patrick-kidger/equinox/issues/832 +def test_cooperative_multiple_inheritance(): + called_a = False + called_b = False + called_d = False + + class A(eqx.Module): + def __post_init__(self) -> None: + nonlocal called_a + called_a = True + + class B(A): + def __post_init__(self) -> None: + nonlocal called_b + called_b = True + super().__post_init__() + + class C(A): + pass + + class D(C, A): + def __post_init__(self) -> None: + nonlocal called_d + called_d = True + super().__post_init__() + + class E(D, B): + pass + + E() + assert called_a + assert called_b + assert called_d