diff --git a/docs/conf.py b/docs/conf.py index 99a775fa..12d79ab6 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -50,7 +50,7 @@ "pandas": ("https://pandas.pydata.org/docs/", None), "networkx": ("https://networkx.org/documentation/stable/", None), "jax": ("https://jax.readthedocs.io/en/latest/", None), - "ott": ("https://ott-jax.readthedocs.io/en/latest/", None), + "ott": ("https://ott-jax.readthedocs.io/", None), "matplotlib": ("https://matplotlib.org/stable/", None), "anndata": ("https://anndata.readthedocs.io/en/latest/", None), "scanpy": ("https://scanpy.readthedocs.io/en/latest/", None), @@ -75,6 +75,13 @@ # ignore these classes until ott-jax adds them to their docs ("py:class", "ott.initializers.quadratic.initializers.BaseQuadraticInitializer"), ("py:class", "ott.initializers.linear.initializers.SinkhornInitializer"), + # https://stackoverflow.com/questions/11417221/sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning + ("py:data", "typing.Union"), + ("py:data", "typing.Optional"), + ("py:data", "typing.Literal"), + ("py:class", "typing.Union"), + ("py:class", "typing.Optional"), + ("py:class", "typing.Literal"), ] # TODO(michalk8): remove once typing has been cleaned-up nitpick_ignore_regex = [ @@ -150,6 +157,7 @@ r"https://doi.org/10.1145/2516971.2516977", r"https://doi.org/10.3390/a13090212", r"https://www.mdpi.com/1999-4893/13/9/212", + r"https://pubmed\.ncbi\.nlm\.nih\.gov/.*", ] exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"] diff --git a/pyproject.toml b/pyproject.toml index efbdf699..afd57d6c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ dependencies = [ "wrapt>=1.13.2", "docrep>=0.3.2", "jax>=0.6.1", - "ott-jax>=0.5.0", + "ott-jax>=0.6.0", "cloudpickle>=2.2.0", "rich>=13.5", "docstring_inheritance>=2.0.0", diff --git a/src/moscot/backends/ott/__init__.py b/src/moscot/backends/ott/__init__.py index 7fdae526..eba507a6 100644 --- a/src/moscot/backends/ott/__init__.py +++ b/src/moscot/backends/ott/__init__.py @@ -1,15 +1,14 @@ from ott.geometry import costs from moscot.backends.ott._utils import sinkhorn_divergence -from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput -from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver +from moscot.backends.ott.output import GraphOTTOutput, OTTOutput +from moscot.backends.ott.solver import GWSolver, SinkhornSolver from moscot.costs import register_cost __all__ = [ "OTTOutput", "GWSolver", "SinkhornSolver", - "NeuralOutput", "sinkhorn_divergence", "GENOTLinSolver", "GraphOTTOutput", diff --git a/src/moscot/backends/ott/_utils.py b/src/moscot/backends/ott/_utils.py index 58bc5b82..c2fc4f55 100644 --- a/src/moscot/backends/ott/_utils.py +++ b/src/moscot/backends/ott/_utils.py @@ -10,7 +10,6 @@ from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud from ott.initializers.linear import initializers as init_lib from ott.initializers.linear import initializers_lr as lr_init_lib -from ott.neural import datasets from ott.solvers import utils as solver_utils from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div @@ -18,6 +17,7 @@ from moscot._types import ArrayLike, ScaleCost_t Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]] +OTDataset = Any # to be removed when neural part is being removed from moscot __all__ = ["sinkhorn_divergence"] @@ -272,7 +272,7 @@ def data_match_fn( class Loader: - def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None): + def __init__(self, dataset: OTDataset, batch_size: int, seed: Optional[int] = None): self.dataset = dataset self.batch_size = batch_size self._rng = np.random.default_rng(seed) diff --git a/src/moscot/backends/ott/output.py b/src/moscot/backends/ott/output.py index ecbd767f..0b7dd8bc 100644 --- a/src/moscot/backends/ott/output.py +++ b/src/moscot/backends/ott/output.py @@ -1,10 +1,8 @@ -from typing import Any, Callable, List, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union import jax import jax.numpy as jnp import numpy as np -import scipy.sparse as sp -from ott.neural.methods.flows.genot import GENOT from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr @@ -12,10 +10,9 @@ import matplotlib.pyplot as plt from moscot._types import ArrayLike, Device_t -from moscot.backends.ott._utils import get_nearest_neighbors -from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput +from moscot.base.output import BaseDiscreteSolverOutput -__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"] +__all__ = ["OTTOutput", "GraphOTTOutput"] class OTTOutput(BaseDiscreteSolverOutput): @@ -242,220 +239,6 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102 return jnp.ones((n,)) -class NeuralOutput(BaseNeuralOutput): - """Output wrapper for GENOT.""" - - def __init__(self, model: GENOT, logs: dict[str, list[float]]): - """Initialize `NeuralOutput`. - - Parameters - ---------- - model : GENOT - The OTT-Jax GENOT model - """ - self._logs = logs - self._model = model - - @property - def logs(self): - """Logs of the training. A dictionary containing what the numeric values are i.e., loss. - - Returns - ------- - dict[str, list[float]] - """ - return self._logs - - def _project_transport_matrix( - self, - src_dist: ArrayLike, - tgt_dist: ArrayLike, - func: Callable[[ArrayLike], ArrayLike], - save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments - batch_size: int = 1024, - k: int = 30, - length_scale: Optional[float] = None, - seed: int = 42, - recall_target: float = 0.95, - aggregate_to_topk: bool = True, - ) -> sp.csr_matrix: - row_indices: List[ArrayLike] = [] - column_indices: List[ArrayLike] = [] - distances_list: List[ArrayLike] = [] - if length_scale is None: - key = jax.random.PRNGKey(seed) - src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))] - tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))] - length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch))) - for index in range(0, len(src_dist), batch_size): - distances, indices = get_nearest_neighbors( - func(src_dist[index : index + batch_size, :]), - tgt_dist, - k, - recall_target=recall_target, - aggregate_to_topk=aggregate_to_topk, - ) - distances = jnp.exp(-((distances / length_scale) ** 2)) - distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1) - distances_list.append(distances.flatten()) - column_indices.append(indices.flatten()) - row_indices.append( - jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist))) - ) - distances = jnp.concatenate(distances_list) - row_indices = jnp.concatenate(row_indices) - column_indices = jnp.concatenate(column_indices) - tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)]) - if save_transport_matrix: - self._transport_matrix = tm - return tm - - def project_to_transport_matrix( # type:ignore[override] - self, - src_cells: ArrayLike, - tgt_cells: ArrayLike, - condition: ArrayLike = None, - save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments - batch_size: int = 1024, - k: int = 30, - length_scale: Optional[float] = None, - seed: int = 42, - recall_target: float = 0.95, - aggregate_to_topk: bool = True, - ) -> sp.csr_matrix: - """Project conditional neural OT map onto cells. - - In constrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells, - but a cell can also be mapped to a location between two cells. This function computes - a pseudo-transport matrix considering the neighborhood of where a cell is mapped to. - Therefore, a neighborhood graph of `k` target cells is computed around each transported cell - of the source distribution. The assignment likelihood of each mapped cell to the target cells is then - computed with a Gaussian kernel with parameter `length_scale`. - - Parameters - ---------- - condition - Condition `src_cells` correspond to. - src_cells - Cells which are to be mapped. - tgt_cells - Cells from which the neighborhood graph around the mapped `src_cells` are computed. - forward - Whether to map cells based on the forward transport map or backward transport map. - save_transport_matrix - Whether to save the transport matrix. - batch_size - Number of data points in the source distribution the neighborhood graph is computed - for in parallel. - k - Number of neighbors to construct the k-nearest neighbor graph of a mapped cell. - length_scale - Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`, - `length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the - mapped source and target distribution. - seed - Random seed for sampling the pairs of distributions for computing the variance in case `length_scale` - is `None`. - recall_target - Recall target for the approximation. - aggregate_to_topk - When true, the nearest neighbor aggregates approximate results to the top-k in sorted order. - When false, returns the approximate results unsorted. - In this case, the number of the approximate results is implementation defined and is greater or - equal to the specified k. - - Returns - ------- - The projected transport matrix. - """ - src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells) - conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition) - push = self.push if condition is None else conditioned_fn - func, src_dist, tgt_dist = (push, src_cells, tgt_cells) - return self._project_transport_matrix( - src_dist=src_dist, - tgt_dist=tgt_dist, - func=func, - save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments - batch_size=batch_size, - k=k, - length_scale=length_scale, - seed=seed, - recall_target=recall_target, - aggregate_to_topk=aggregate_to_topk, - ) - - def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: - """Push distribution `x` conditioned on condition `cond`. - - Parameters - ---------- - x - Distribution to push. - cond - Condition of conditional neural OT. - - Returns - ------- - Pushed distribution. - """ - if isinstance(x, (bool, int, float, complex)): - raise ValueError("Expected array, found scalar value.") - if x.ndim not in (1, 2): - raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.") - return self._apply_forward(x, cond=cond) - - def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: - return self._model.transport(x, condition=cond) - - @property - def is_linear(self) -> bool: # noqa: D102 - return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT - - @property - def shape(self) -> Tuple[int, int]: - """%(shape)s.""" - raise NotImplementedError() - - def to( - self, - device: Optional[Device_t] = None, - ) -> "NeuralOutput": - """Transfer the output to another device or change its data type. - - Parameters - ---------- - device - If not `None`, the output will be transferred to `device`. - - Returns - ------- - The output on a saved on `device`. - """ - # # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep - # if isinstance(device, str) and ":" in device: - # device, ix = device.split(":") - # idx = int(ix) - # else: - # idx = 0 - - # if not isinstance(device, jax.Device): - # try: - # device = jax.devices(device)[idx] - # except IndexError as err: - # raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err - - # out = jax.device_put(self._model, device) - # return NeuralOutput(out) - return self # TODO(ilan-gold) move model to device - - @property - def converged(self) -> bool: - """%(converged)s.""" - # always return True for now - return True - - class GraphOTTOutput(OTTOutput): """Output of :term:`OT` problems with a graph geometry in the linear term. diff --git a/src/moscot/backends/ott/solver.py b/src/moscot/backends/ott/solver.py index 50ca82fa..4f50ee87 100644 --- a/src/moscot/backends/ott/solver.py +++ b/src/moscot/backends/ott/solver.py @@ -1,12 +1,9 @@ import abc -import functools import inspect -import math import types from typing import ( Any, Hashable, - List, Literal, Mapping, NamedTuple, @@ -17,21 +14,13 @@ Union, ) -import optax - import jax import jax.numpy as jnp -import numpy as np from ott.geometry import costs, epsilon_scheduler, geodesic, geometry, pointcloud -from ott.neural.datasets import OTData, OTDataset -from ott.neural.methods.flows import dynamics, genot -from ott.neural.networks.layers import time_encoder -from ott.neural.networks.velocity_field import VelocityField from ott.problems.linear import linear_problem from ott.problems.quadratic import quadratic_problem from ott.solvers.linear import sinkhorn, sinkhorn_lr from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr -from ott.solvers.utils import uniform_sampler from moscot._logging import logger from moscot._types import ( @@ -43,23 +32,20 @@ ) from moscot.backends.ott._utils import ( InitializerResolver, - Loader, - MultiLoader, _instantiate_geodesic_cost, alpha_to_fused_penalty, check_shapes, convert_scipy_sparse, - data_match_fn, densify, ensure_2d, ) -from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput +from moscot.backends.ott.output import GraphOTTOutput, OTTOutput from moscot.base.problems._utils import TimeScalesHeatKernel from moscot.base.solver import OTSolver from moscot.costs import get_cost -from moscot.utils.tagged_array import DistributionCollection, TaggedArray +from moscot.utils.tagged_array import TaggedArray -__all__ = ["SinkhornSolver", "GWSolver", "GENOTLinSolver"] +__all__ = ["SinkhornSolver", "GWSolver"] OTTSolver_t = Union[ sinkhorn.Sinkhorn, @@ -516,193 +502,3 @@ def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: problem_kwargs -= {"geom_xx", "geom_yy", "geom_xy", "fused_penalty"} problem_kwargs |= {"alpha"} return geom_kwargs | problem_kwargs, {"epsilon"} - - -class GENOTLinSolver(OTSolver[OTTOutput]): - """Solver class for genot.GENOT linear :cite:`klein2023generative`.""" - - def __init__(self, **kwargs: Any) -> None: - """Initiate the class with any kwargs passed to the ott-jax class.""" - super().__init__() - self._train_sampler: Optional[MultiLoader] = None - self._valid_sampler: Optional[MultiLoader] = None - self._neural_kwargs = kwargs - - @property - def problem_kind(self) -> ProblemKind_t: # noqa: D102 - return "linear" - - def _prepare( # type: ignore[override] - self, - distributions: DistributionCollection[K], - sample_pairs: List[Tuple[Any, Any]], - train_size: float = 0.9, - batch_size: int = 1024, - is_conditional: bool = True, - **kwargs: Any, - ) -> Tuple[MultiLoader, MultiLoader]: - train_loaders = [] - validate_loaders = [] - seed = kwargs.get("seed") - is_aligned = kwargs.get("is_aligned", False) - if train_size == 1.0: - for sample_pair in sample_pairs: - source_key = sample_pair[0] - target_key = sample_pair[1] - src_data = OTData( - lin=distributions[source_key].xy, - condition=distributions[source_key].conditions if is_conditional else None, - ) - tgt_data = OTData( - lin=distributions[target_key].xy, - condition=distributions[target_key].conditions if is_conditional else None, - ) - dataset = OTDataset(src_data=src_data, tgt_data=tgt_data, seed=seed, is_aligned=is_aligned) - loader = Loader(dataset, batch_size=batch_size, seed=seed) - train_loaders.append(loader) - validate_loaders.append(loader) - else: - if train_size > 1.0 or train_size <= 0.0: - raise ValueError("Invalid train_size. Must be: 0 < train_size <= 1") - - seed = kwargs.get("seed", 0) - for sample_pair in sample_pairs: - source_key = sample_pair[0] - target_key = sample_pair[1] - source_data: ArrayLike = distributions[source_key].xy - target_data: ArrayLike = distributions[target_key].xy - source_split_data = self._split_data( - source_data, - conditions=distributions[source_key].conditions, - train_size=train_size, - seed=seed, - a=distributions[source_key].a, - b=distributions[source_key].b, - ) - target_split_data = self._split_data( - target_data, - conditions=distributions[target_key].conditions, - train_size=train_size, - seed=seed, - a=distributions[target_key].a, - b=distributions[target_key].b, - ) - src_data_train = OTData( - lin=source_split_data.data_train, - condition=source_split_data.conditions_train if is_conditional else None, - ) - tgt_data_train = OTData( - lin=target_split_data.data_train, - condition=target_split_data.conditions_train if is_conditional else None, - ) - train_dataset = OTDataset( - src_data=src_data_train, tgt_data=tgt_data_train, seed=seed, is_aligned=is_aligned - ) - train_loader = Loader(train_dataset, batch_size=batch_size, seed=seed) - src_data_validate = OTData( - lin=source_split_data.data_valid, - condition=source_split_data.conditions_valid if is_conditional else None, - ) - tgt_data_validate = OTData( - lin=target_split_data.data_valid, - condition=target_split_data.conditions_valid if is_conditional else None, - ) - validate_dataset = OTDataset( - src_data=src_data_validate, tgt_data=tgt_data_validate, seed=seed, is_aligned=is_aligned - ) - validate_loader = Loader(validate_dataset, batch_size=batch_size, seed=seed) - train_loaders.append(train_loader) - validate_loaders.append(validate_loader) - source_dim = self._neural_kwargs.get("input_dim", 0) - target_dim = source_dim - condition_dim = self._neural_kwargs.get("cond_dim", 0) - # TODO(ilan-gold): What are reasonable defaults here? - neural_vf = VelocityField( - output_dims=[*self._neural_kwargs.get("velocity_field_output_dims", []), target_dim], - condition_dims=( - self._neural_kwargs.get("velocity_field_condition_dims", [source_dim + condition_dim]) - if is_conditional - else None - ), - hidden_dims=self._neural_kwargs.get("velocity_field_hidden_dims", [1024, 1024, 1024]), - time_dims=self._neural_kwargs.get("velocity_field_time_dims", None), - time_encoder=self._neural_kwargs.get( - "velocity_field_time_encoder", functools.partial(time_encoder.cyclical_time_encoder, n_freqs=1024) - ), - ) - seed = self._neural_kwargs.get("seed", 0) - rng = jax.random.PRNGKey(seed) - data_match_fn_kwargs = self._neural_kwargs.get( - "data_match_fn_kwargs", - {} if "data_match_fn" in self._neural_kwargs else {"epsilon": 1e-1, "tau_a": 1.0, "tau_b": 1.0}, - ) - time_sampler = self._neural_kwargs.get("time_sampler", uniform_sampler) - optimizer = self._neural_kwargs.get("optimizer", optax.adam(learning_rate=1e-4)) - self._solver = genot.GENOT( - vf=neural_vf, - flow=self._neural_kwargs.get( - "flow", - dynamics.ConstantNoiseFlow(0.1), - ), - data_match_fn=functools.partial( - self._neural_kwargs.get("data_match_fn", data_match_fn), typ="lin", **data_match_fn_kwargs - ), - source_dim=source_dim, - target_dim=target_dim, - condition_dim=condition_dim if is_conditional else None, - optimizer=optimizer, - time_sampler=time_sampler, - rng=rng, - latent_noise_fn=self._neural_kwargs.get("latent_noise_fn", None), - **self._neural_kwargs.get("velocity_field_train_state_kwargs", {}), - ) - return ( - MultiLoader(datasets=train_loaders, seed=seed), - MultiLoader(datasets=validate_loaders, seed=seed), - ) - - def _split_data( # TODO: adapt for Gromov terms - self, - x: ArrayLike, - conditions: Optional[ArrayLike], - train_size: float, - seed: int, - a: Optional[ArrayLike] = None, - b: Optional[ArrayLike] = None, - ) -> SingleDistributionData: - n_samples_x = x.shape[0] - n_train_x = math.ceil(train_size * n_samples_x) - rng = np.random.default_rng(seed) - x = rng.permutation(x) - if a is not None: - a = rng.permutation(a) - if b is not None: - b = rng.permutation(b) - - return SingleDistributionData( - data_train=x[:n_train_x], - data_valid=x[n_train_x:], - conditions_train=conditions[:n_train_x] if conditions is not None else None, - conditions_valid=conditions[n_train_x:] if conditions is not None else None, - a_train=a[:n_train_x] if a is not None else None, - a_valid=a[n_train_x:] if a is not None else None, - b_train=b[:n_train_x] if b is not None else None, - b_valid=b[n_train_x:] if b is not None else None, - ) - - @property - def solver(self) -> genot.GENOT: - """Underlying optimal transport solver.""" - return self._solver - - @classmethod - def _call_kwargs(cls) -> Tuple[Set[str], Set[str]]: - return {"batch_size", "train_size", "trainloader", "validloader", "seed"}, {} # type: ignore[return-value] - - def _solve(self, data_samplers: Tuple[MultiLoader, MultiLoader]) -> NeuralOutput: # type: ignore[override] - seed = self._neural_kwargs.get("seed", 0) # TODO(ilan-gold): unify rng hadnling like OTT tests - rng = jax.random.PRNGKey(seed) - logs = self.solver( - data_samplers[0], n_iters=self._neural_kwargs.get("n_iters", 100), rng=rng - ) # TODO(ilan-gold): validation and figure out defualts - return NeuralOutput(self.solver, logs) diff --git a/src/moscot/backends/utils.py b/src/moscot/backends/utils.py index 988e0541..12ffff23 100644 --- a/src/moscot/backends/utils.py +++ b/src/moscot/backends/utils.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Callable, Literal, Tuple, Union from moscot import _registry from moscot._types import ProblemKind_t @@ -9,8 +9,8 @@ __all__ = ["get_solver", "register_solver", "get_available_backends"] register_solver_t = Callable[ - [Literal["linear", "quadratic"], Optional[Literal["GENOTLinSolver"]]], - Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"], + [Literal["linear", "quadratic"]], + Union["ott.SinkhornSolver", "ott.GWSolver"], ] @@ -21,13 +21,13 @@ def get_solver(problem_kind: ProblemKind_t, *, backend: str = "ott", return_clas """TODO.""" if backend not in _REGISTRY: raise ValueError(f"Backend `{backend!r}` is not available.") - solver_class = _REGISTRY[backend](problem_kind, solver_name=kwargs.pop("solver_name", None)) + solver_class = _REGISTRY[backend](problem_kind) return solver_class if return_class else solver_class(**kwargs) def register_solver( backend: str, -) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]: +) -> Union["ott.SinkhornSolver", "ott.GWSolver"]: """Register a solver for a specific backend. Parameters @@ -45,18 +45,14 @@ def register_solver( @register_solver("ott") def _( problem_kind: Literal["linear", "quadratic"], - solver_name: Optional[Literal["GENOTLinSolver"]] = None, -) -> Union["ott.SinkhornSolver", "ott.GWSolver", "ott.GENOTLinSolver"]: +) -> Union["ott.SinkhornSolver", "ott.GWSolver"]: from moscot.backends import ott if problem_kind == "linear": - if solver_name == "GENOTLinSolver": - return ott.GENOTLinSolver # type: ignore[return-value] - if solver_name is None: - return ott.SinkhornSolver # type: ignore[return-value] + return ott.SinkhornSolver # type: ignore[return-value] if problem_kind == "quadratic": return ott.GWSolver # type: ignore[return-value] - raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`, {solver_name} problem.") + raise NotImplementedError(f"Unable to create solver for `{problem_kind!r}`.") def get_available_backends() -> Tuple[str, ...]: diff --git a/src/moscot/base/output.py b/src/moscot/base/output.py index 4e10203f..00d8b36b 100644 --- a/src/moscot/base/output.py +++ b/src/moscot/base/output.py @@ -3,7 +3,6 @@ import abc import copy import functools -from abc import abstractmethod from typing import Any, Callable, Iterable, Literal, Optional, Union import numpy as np @@ -13,7 +12,7 @@ from moscot._logging import logger from moscot._types import ArrayLike, Device_t, DTypeLike -__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput", "BaseNeuralOutput"] +__all__ = ["BaseDiscreteSolverOutput", "MatrixSolverOutput"] class BaseSolverOutput(abc.ABC): @@ -394,21 +393,3 @@ def _ones(self, n: int) -> ArrayLike: import jax.numpy as jnp return jnp.ones((n,), dtype=self.transport_matrix.dtype) - - -class BaseNeuralOutput(BaseSolverOutput, abc.ABC): - """Base class for output of.""" - - @abstractmethod - def project_to_transport_matrix( - self, - source: Optional[ArrayLike] = None, - target: Optional[ArrayLike] = None, - condition: Optional[ArrayLike] = None, - save_transport_matrix: bool = False, - batch_size: int = 1024, - k: int = 30, - length_scale: Optional[float] = None, - seed: int = 42, - ) -> sp.csr_matrix: - """Project transport matrix.""" diff --git a/src/moscot/neural/base/__init__.py b/src/moscot/neural/base/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/src/moscot/neural/base/problems/__init__.py b/src/moscot/neural/base/problems/__init__.py deleted file mode 100644 index ec15beb2..00000000 --- a/src/moscot/neural/base/problems/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from moscot.neural.base.problems.problem import NeuralOTProblem - -__all__ = ["NeuralOTProblem"] diff --git a/src/moscot/neural/base/problems/problem.py b/src/moscot/neural/base/problems/problem.py deleted file mode 100644 index cc142f98..00000000 --- a/src/moscot/neural/base/problems/problem.py +++ /dev/null @@ -1,243 +0,0 @@ -from typing import ( - Any, - Hashable, - Iterable, - List, - Literal, - Mapping, - Optional, - Sequence, - Tuple, - TypeVar, - Union, -) - -import numpy as np -import pandas as pd - -from anndata import AnnData - -from moscot import backends -from moscot._types import ArrayLike, Device_t -from moscot.base.output import BaseNeuralOutput -from moscot.base.problems._utils import wrap_prepare, wrap_solve -from moscot.base.problems.problem import BaseProblem -from moscot.base.solver import OTSolver -from moscot.utils.subset_policy import ( # type:ignore[attr-defined] - ExplicitPolicy, - Policy_t, - StarPolicy, - SubsetPolicy, - create_policy, -) -from moscot.utils.tagged_array import DistributionCollection, DistributionContainer - -K = TypeVar("K", bound=Hashable) - -__all__ = ["NeuralOTProblem"] - - -class NeuralOTProblem(BaseProblem): # TODO(@MUCDK) check generic types, save and load - """ - Base class for all conditional (nerual) optimal transport problems. - - Parameters - ---------- - adata - Source annotated data object. - kwargs - Keyword arguments for :class:`moscot.base.problems.problem.BaseProblem` - """ - - def __init__( - self, - adata: AnnData, - **kwargs: Any, - ): - super().__init__(**kwargs) - self._adata = adata - - self._distributions: Optional[DistributionCollection[K]] = None # type: ignore[valid-type] - self._policy: Optional[SubsetPolicy[Any]] = None - self._sample_pairs: Optional[List[Tuple[Any, Any]]] = None - - self._solver: Optional[OTSolver[BaseNeuralOutput]] = None - self._solution: Optional[BaseNeuralOutput] = None - - self._a: Optional[str] = None - self._b: Optional[str] = None - - @wrap_prepare - def prepare( - self, - policy_key: str, - policy: Policy_t, - xy: Mapping[str, Any], - xx: Mapping[str, Any], - conditions: Mapping[str, Any], - a: Optional[str] = None, - b: Optional[str] = None, - subset: Optional[Sequence[Tuple[K, K]]] = None, - reference: K = None, - **kwargs: Any, - ) -> "NeuralOTProblem": - """Prepare conditional optimal transport problem. - - Parameters - ---------- - xy - Geometry defining the linear term. If passed as a :class:`dict`, - :meth:`~moscot.utils.tagged_array.TaggedArray.from_adata` will be called. - policy - Policy defining which pairs of distributions to sample from during training. - policy_key - %(key)s - a - Source marginals. - b - Target marginals. - kwargs - Keyword arguments when creating the source/target marginals. - - - Returns - ------- - Self and modifies the following attributes: - TODO. - """ - self._problem_kind = "linear" - self._distributions = DistributionCollection() - self._solution = None - self._policy_key = policy_key - try: - self._distribution_id = pd.Series(self.adata.obs[policy_key]) - except KeyError: - raise KeyError(f"Unable to find data in `adata.obs[{policy_key!r}]`.") from None - - self._policy = create_policy(policy, adata=self.adata, key=policy_key) - if isinstance(self._policy, ExplicitPolicy): - self._policy = self._policy.create_graph(subset=subset) - elif isinstance(self._policy, StarPolicy): - self._policy = self._policy.create_graph(reference=reference) - else: - _ = self.policy.create_graph() # type: ignore[union-attr] - self._sample_pairs = list(self.policy._graph) # type: ignore[union-attr] - - for el in self.policy.categories: # type: ignore[union-attr] - adata_masked = self.adata[self._create_mask(el)] - a_created = self._create_marginals(adata_masked, data=a, source=True, **kwargs) - b_created = self._create_marginals(adata_masked, data=b, source=False, **kwargs) - self.distributions[el] = DistributionContainer.from_adata( # type: ignore[index] - adata_masked, a=a_created, b=b_created, **xy, **xx, **conditions - ) - return self - - @wrap_solve - def solve( - self, - backend: Literal["ott"] = "ott", - solver_name: Literal["GENOTLinSolver"] = "GENOTLinSolver", - device: Optional[Device_t] = None, - **kwargs: Any, - ) -> "NeuralOTProblem": - """Solve optimal transport problem. - - Parameters - ---------- - backend - Which backend to use, see :func:`moscot.backends.utils.get_available_backends`. - device - Device where to transfer the solution, see :meth:`moscot.base.output.BaseNeuralOutput.to`. - kwargs - Keyword arguments for :meth:`moscot.base.solver.BaseSolver.__call__`. - - - Returns - ------- - Self and modifies the following attributes: - - :attr:`solver`: optimal transport solver. - - :attr:`solution`: optimal transport solution. - """ - tmp = next(iter(self.distributions)) # type: ignore[arg-type] - input_dim = self.distributions[tmp].xy.shape[1] # type: ignore[union-attr, index] - cond_dim = self.distributions[tmp].conditions.shape[1] # type: ignore[union-attr, index] - - solver_class = backends.get_solver( - self.problem_kind, solver_name=solver_name, backend=backend, return_class=True - ) - init_kwargs, call_kwargs = solver_class._partition_kwargs(**kwargs) - self._solver = solver_class(input_dim=input_dim, cond_dim=cond_dim, **init_kwargs) - # note that the solver call consists of solver._prepare and solver._solve - sample_pairs = self._sample_pairs if self._sample_pairs is not None else [] - self._solution = self._solver( # type: ignore[misc] - device=device, - distributions=self.distributions, - sample_pairs=self._sample_pairs, - is_conditional=len(sample_pairs) > 1, - **call_kwargs, - ) - - return self - - def _create_marginals( - self, adata: AnnData, *, source: bool, data: Optional[str] = None, **kwargs: Any - ) -> ArrayLike: - if data is True: - marginals = self.estimate_marginals(adata, source=source, **kwargs) - elif data in (False, None): - marginals = np.ones((adata.n_obs,), dtype=float) / adata.n_obs - elif isinstance(data, str): - try: - marginals = np.asarray(adata.obs[data], dtype=float) - except KeyError: - raise KeyError(f"Unable to find data in `adata.obs[{data!r}]`.") from None - return marginals - - def _create_mask(self, value: Union[K, Sequence[K]], *, allow_empty: bool = False) -> ArrayLike: - """Create a mask used to subset the data. - - TODO(@MUCDK): this is copied from SubsetPolicy, consider making this a function. - - Parameters - ---------- - value - Values in the data which determine the mask. - allow_empty - Whether to allow empty mask. - - Returns - ------- - Boolean mask of the same shape as the data. - """ - if isinstance(value, str) or not isinstance(value, Iterable): - mask = self._distribution_id == value - else: - mask = self._distribution_id.isin(value) - if not allow_empty and not np.sum(mask): - raise ValueError("Unable to construct an empty mask, use `allow_empty=True` to override.") - return np.asarray(mask) - - @property - def distributions(self) -> Optional[DistributionCollection[K]]: - """Collection of distributions.""" - return self._distributions - - @property - def adata(self) -> AnnData: - """Source annotated data object.""" - return self._adata - - @property - def solution(self) -> Optional[BaseNeuralOutput]: - """Solution of the optimal transport problem.""" - return self._solution - - @property - def solver(self) -> Optional[OTSolver[BaseNeuralOutput]]: - """Solver of the optimal transport problem.""" - return self._solver - - @property - def policy(self) -> Optional[SubsetPolicy[Any]]: - """Policy used to subset the data.""" - return self._policy diff --git a/src/moscot/neural/problems/__init__.py b/src/moscot/neural/problems/__init__.py deleted file mode 100644 index cd884d36..00000000 --- a/src/moscot/neural/problems/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from moscot.neural.problems.generic import GENOTLinProblem - -__all__ = ["GENOTLinProblem"] diff --git a/src/moscot/neural/problems/generic/__init__.py b/src/moscot/neural/problems/generic/__init__.py deleted file mode 100644 index 657b4ea6..00000000 --- a/src/moscot/neural/problems/generic/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from moscot.neural.problems.generic._generic import GENOTLinProblem - -__all__ = ["GENOTLinProblem"] diff --git a/src/moscot/neural/problems/generic/_generic.py b/src/moscot/neural/problems/generic/_generic.py deleted file mode 100644 index c37a2f75..00000000 --- a/src/moscot/neural/problems/generic/_generic.py +++ /dev/null @@ -1,78 +0,0 @@ -import types -from types import MappingProxyType -from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Type, Union - -from moscot import _constants -from moscot._types import CostKwargs_t, OttCostFn_t, Policy_t -from moscot.neural.base.problems.problem import NeuralOTProblem -from moscot.problems._utils import ( - handle_conditional_attr, - handle_cost_tmp, - handle_joint_attr_tmp, -) - -__all__ = ["GENOTLinProblem"] - - -class GENOTLinProblem(NeuralOTProblem): - """Class for solving Conditional Parameterized Monge Map problems / Conditional Neural OT problems.""" - - def prepare( - self, - key: str, - joint_attr: Union[str, Mapping[str, Any]], - conditional_attr: Union[str, Mapping[str, Any]], - policy: Literal["sequential", "star", "explicit"] = "sequential", - a: Optional[str] = None, - b: Optional[str] = None, - cost: OttCostFn_t = "sq_euclidean", - cost_kwargs: CostKwargs_t = types.MappingProxyType({}), - **kwargs: Any, - ) -> "GENOTLinProblem": - """Prepare the :class:`moscot.problems.generic.GENOTLinProblem`.""" - self.batch_key = key - xy, kwargs = handle_joint_attr_tmp(joint_attr, kwargs) - conditions = handle_conditional_attr(conditional_attr) - xy, xx = handle_cost_tmp(xy=xy, x={}, y={}, cost=cost, cost_kwargs=cost_kwargs) - return super().prepare( - policy_key=key, - policy=policy, - xy=xy, - xx=xx, - conditions=conditions, - a=a, - b=b, - **kwargs, - ) - - def solve( - self, - batch_size: int = 1024, - seed: int = 0, - iterations: int = 25000, # TODO(@MUCDK): rename to max_iterations - valid_freq: int = 50, - valid_sinkhorn_kwargs: Dict[str, Any] = MappingProxyType({}), - train_size: float = 1.0, - **kwargs: Any, - ) -> "GENOTLinProblem": - """Solve.""" - return super().solve( - batch_size=batch_size, - # tau_a=tau_a, # TODO: unbalancedness handler - # tau_b=tau_b, - seed=seed, - n_iters=iterations, - valid_freq=valid_freq, - valid_sinkhorn_kwargs=valid_sinkhorn_kwargs, - train_size=train_size, - solver_name="GENOTLinSolver", - **kwargs, - ) - - @property - def _base_problem_type(self) -> Type[NeuralOTProblem]: - return NeuralOTProblem - - @property - def _valid_policies(self) -> Tuple[Policy_t, ...]: - return _constants.SEQUENTIAL, _constants.EXPLICIT # type: ignore[return-value] diff --git a/tests/data/alignment_solutions.pkl b/tests/data/alignment_solutions.pkl index 145468ed..cb9eda2e 100644 Binary files a/tests/data/alignment_solutions.pkl and b/tests/data/alignment_solutions.pkl differ diff --git a/tests/data/mapping_solutions.pkl b/tests/data/mapping_solutions.pkl index 63a4d0f2..871930fc 100644 Binary files a/tests/data/mapping_solutions.pkl and b/tests/data/mapping_solutions.pkl differ diff --git a/tests/neural/problems/generic/test_conditional_neural_problem.py b/tests/neural/problems/generic/test_conditional_neural_problem.py deleted file mode 100644 index e4cd9b83..00000000 --- a/tests/neural/problems/generic/test_conditional_neural_problem.py +++ /dev/null @@ -1,86 +0,0 @@ -import optax -import pytest - -import numpy as np -from ott.geometry import costs - -import anndata as ad - -from moscot.base.output import BaseSolverOutput -from moscot.neural.base.problems import NeuralOTProblem -from moscot.neural.problems.generic import GENOTLinProblem # type: ignore[attr-defined] -from moscot.utils.tagged_array import DistributionCollection, DistributionContainer -from tests._utils import ATOL, RTOL -from tests.problems.conftest import neurallin_cond_args_1 - - -class TestGENOTLinProblem: - @pytest.mark.fast - def test_prepare(self, adata_time: ad.AnnData): - problem = GENOTLinProblem(adata=adata_time) - problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) - assert isinstance(problem, NeuralOTProblem) - assert isinstance(problem.distributions, DistributionCollection) - assert list(problem.distributions.keys()) == [0, 1, 2] - - container = problem.distributions[0] - n_obs_0 = adata_time[adata_time.obs["time"] == 0].n_obs - assert isinstance(container, DistributionContainer) - assert isinstance(container.xy, np.ndarray) - assert container.xy.shape == (n_obs_0, 50) - assert container.xx is None - assert isinstance(container.conditions, np.ndarray) - assert container.conditions.shape == (n_obs_0, 1) - assert isinstance(container.a, np.ndarray) - assert container.a.shape == (n_obs_0,) - assert isinstance(container.b, np.ndarray) - assert container.b.shape == (n_obs_0,) - assert isinstance(container.cost_xy, costs.SqEuclidean) - assert container.cost_xx is None - - @pytest.mark.parametrize("train_size", [0.9, 1.0]) - def test_solve_balanced_no_baseline(self, adata_time: ad.AnnData, train_size: float): # type: ignore[no-untyped-def] # noqa: E501 - problem = GENOTLinProblem(adata=adata_time) - problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) - problem = problem.solve(train_size=train_size, **neurallin_cond_args_1) - assert isinstance(problem.solution, BaseSolverOutput) - - def test_reproducibility(self, adata_time: ad.AnnData): - cond_zero_mask = np.array(adata_time.obs["time"] == 0) - pc_tzero = adata_time[cond_zero_mask].obsm["X_pca"] - problem_one = GENOTLinProblem(adata=adata_time) - problem_one = problem_one.prepare( - key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}, seed=0 - ) - problem_one = problem_one.solve(**neurallin_cond_args_1) - problem_two = GENOTLinProblem(adata=adata_time) - problem_two = problem_two.prepare( - key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}, seed=0 - ) - problem_two = problem_two.solve(**neurallin_cond_args_1) - assert np.allclose( - problem_one.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(), 1))), - problem_two.solution.push(pc_tzero, cond=np.zeros((cond_zero_mask.sum(), 1))), - rtol=RTOL, - atol=ATOL, - ) - - # def test_pass_arguments(self, adata_time: ad.AnnData): # TODO(ilan-gold) implement this once the OTT PR is settled - # problem = GENOTLinProblem(adata=adata_time) - # adata_time = adata_time[adata_time.obs["time"].isin((0, 1))] - # problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) - # problem = problem.solve(**neurallin_cond_args_1) - - # solver = problem.solver._solver - # for arg, val in neurallin_cond_args_1.items(): - # assert hasattr(solver, val) - # el = getattr(solver, val)[0] if isinstance(getattr(solver, val), tuple) else getattr(solver, val) - # assert el == neurallin_cond_args_1[arg] - - def test_pass_custom_optimizers(self, adata_time: ad.AnnData): - problem = GENOTLinProblem(adata=adata_time) - adata_time = adata_time[adata_time.obs["time"].isin((0, 1))] - problem = problem.prepare(key="time", joint_attr="X_pca", conditional_attr={"attr": "obs", "key": "time"}) - custom_opt = optax.adagrad(1e-4) - - problem = problem.solve(iterations=2, optimizer=custom_opt) diff --git a/tests/problems/time/test_mixins.py b/tests/problems/time/test_mixins.py index f062adef..4ea69cdb 100644 --- a/tests/problems/time/test_mixins.py +++ b/tests/problems/time/test_mixins.py @@ -259,7 +259,7 @@ def test_compute_interpolated_distance_regression(self, gt_temporal_adata: AnnDa assert isinstance(interpolation_result, float) assert interpolation_result > 0 np.testing.assert_allclose( - interpolation_result, gt_temporal_adata.uns["interpolated_distance_10_105_11"], rtol=1e-6, atol=1e-6 + interpolation_result, gt_temporal_adata.uns["interpolated_distance_10_105_11"], rtol=1e-5, atol=1e-4 ) def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnData): @@ -288,7 +288,7 @@ def test_compute_time_point_distances_regression(self, gt_temporal_adata: AnnDat result[0], gt_temporal_adata.uns["time_point_distances_10_105_11"][0], rtol=1e-6, atol=1e-6 ) np.testing.assert_allclose( - result[1], gt_temporal_adata.uns["time_point_distances_10_105_11"][1], rtol=1e-6, atol=1e-6 + result[1], gt_temporal_adata.uns["time_point_distances_10_105_11"][1], rtol=1e-5, atol=5e-3 ) def test_compute_batch_distances_regression(self, gt_temporal_adata: AnnData):