Skip to content

Commit

Permalink
Fixes cooperative multiple inheritance __post_init__
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Sep 6, 2024
1 parent 14e2c49 commit dae889d
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 17 deletions.
32 changes: 15 additions & 17 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def __new__(

# Add support for `eqx.field(converter=...)` when using `__post_init__`.
# (Scenario (c) above. Scenarios (a) and (b) are handled later.)
if has_dataclass_init and hasattr(cls, "__post_init__"):
if has_dataclass_init and "__post_init__" in cls.__dict__:
post_init = cls.__post_init__

@ft.wraps(post_init) # pyright: ignore
Expand All @@ -293,29 +293,23 @@ def __post_init__(self, *args, **kwargs):
# We want to only convert once, at the top level.
#
# This check is basically testing whether or not the function we're in
# now (`cls.__post_init__`) is at the top level
# (`self.__class__.__post_init__`). If we are, do conversion. If we're
# not, it's presumably because someone is calling us via `super()` in
# the middle of their own `__post_init__`. No conversion then; their own
# version of this wrapper will do it at the appropriate time instead.
#
# One small foible: we write `cls.__post_init__`, rather than just
# `__post_init__`, to refer to this function. This allows someone else
# to also monkey-patch `cls.__post_init__` if they wish, and this won't
# remove conversion. (Conversion is a at-the-top-level thing, not a
# this-particular-function thing.)
# now (`cls`) is at the top level (`self.__class__`). If we are, do
# conversion. If we're not, it's presumably because someone is calling
# us via `super()` in the middle of their own `__post_init__`. No
# conversion then; their own version of this wrapper will do it at the
# appropriate time instead.
#
# This top-level business means that this is very nearly the same as
# doing conversion in `_ModuleMeta.__call__`. The differences are that
# (a) that wouldn't allow us to convert fields before the user-provided
# `__post_init__`, and (b) it allows other libraries (i.e. jaxtyping)
# to later monkey-patch `__init__`, and we have our converter run before
# their own monkey-patched-in code.
if self.__class__.__post_init__ is cls.__post_init__:
if self.__class__ is _make_initable_wrapper(cls):
# Convert all fields currently available.
_convert_fields(self, init=True)
post_init(self, *args, **kwargs) # pyright: ignore
if self.__class__.__post_init__ is cls.__post_init__:
if self.__class__ is _make_initable_wrapper(cls):
# Convert all the fields filled in by `__post_init__` as well.
_convert_fields(self, init=False)

Expand Down Expand Up @@ -377,7 +371,7 @@ def __init__(self, *args, **kwargs):
__tracebackhide__ = True
init(self, *args, **kwargs)
# Same `if` trick as with `__post_init__`.
if self.__class__.__init__ is cls.__init__:
if self.__class__ is _make_initable_wrapper(cls):
_convert_fields(self, init=True)
_convert_fields(self, init=False)

Expand Down Expand Up @@ -566,8 +560,7 @@ def __call__(cls, *args, **kwargs):
# else it's handled in __setattr__, but that isn't called here.
# [Step 1] Modules are immutable -- except during construction. So defreeze
# before init.
post_init = getattr(cls, "__post_init__", None)
initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False)
initable_cls = _make_initable_wrapper(cls)
# [Step 2] Instantiate the class as normal.
self = super(_ActualModuleMeta, initable_cls).__call__(*args, **kwargs)
assert not _is_abstract(cls)
Expand Down Expand Up @@ -792,6 +785,11 @@ def __call__(self, ...):
break


def _make_initable_wrapper(cls: _ActualModuleMeta) -> _ActualModuleMeta:
post_init = getattr(cls, "__post_init__", None)
return _make_initable(cls, cls.__init__, post_init, wraps=False)


@ft.lru_cache(maxsize=128)
def _make_initable(
cls: _ActualModuleMeta, init, post_init, wraps: bool
Expand Down
35 changes: 35 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,38 @@ class InvalidArr(eqx.Module):
match="A JAX array is being set as static!",
):
InvalidArr((), jnp.ones(10))


# https://github.com/patrick-kidger/equinox/issues/832
def test_cooperative_multiple_inheritance():
called_a = False
called_b = False
called_d = False

class A(eqx.Module):
def __post_init__(self) -> None:
nonlocal called_a
called_a = True

class B(A):
def __post_init__(self) -> None:
nonlocal called_b
called_b = True
super().__post_init__()

class C(A):
pass

class D(C, A):
def __post_init__(self) -> None:
nonlocal called_d
called_d = True
super().__post_init__()

class E(D, B):
pass

E()
assert called_a
assert called_b
assert called_d

0 comments on commit dae889d

Please sign in to comment.