Skip to content

Commit

Permalink
Add overloads to eqx.combine
Browse files Browse the repository at this point in the history
In the common case wherein all trees have the same structure, the return
should have the same structure.  Helps type checking a lot.
  • Loading branch information
NeilGirdhar authored and patrick-kidger committed Sep 15, 2024
1 parent 05b01c1 commit fab726b
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion equinox/_filters.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any, Optional, Union
from typing import Any, Optional, overload, TypeVar, Union

import jax
import jax.numpy as jnp
Expand Down Expand Up @@ -163,6 +163,15 @@ def _is_none(x):
return x is None


_T = TypeVar("_T", bound=PyTree)


@overload
def combine(*pytrees: _T, is_leaf: Optional[Callable[[Any], bool]] = None) -> _T: ...
@overload
def combine(
*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None
) -> PyTree: ...
def combine(
*pytrees: PyTree, is_leaf: Optional[Callable[[Any], bool]] = None
) -> PyTree:
Expand Down

0 comments on commit fab726b

Please sign in to comment.