Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"pandas": ("https://pandas.pydata.org/docs/", None),
"networkx": ("https://networkx.org/documentation/stable/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
"ott": ("https://ott-jax.readthedocs.io/en/latest/", None),
"ott": ("https://ott-jax.readthedocs.io/", None),
"matplotlib": ("https://matplotlib.org/stable/", None),
"anndata": ("https://anndata.readthedocs.io/en/latest/", None),
"scanpy": ("https://scanpy.readthedocs.io/en/latest/", None),
Expand All @@ -75,6 +75,13 @@
# ignore these classes until ott-jax adds them to their docs
("py:class", "ott.initializers.quadratic.initializers.BaseQuadraticInitializer"),
("py:class", "ott.initializers.linear.initializers.SinkhornInitializer"),
# https://stackoverflow.com/questions/11417221/sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning
("py:data", "typing.Union"),
("py:data", "typing.Optional"),
("py:data", "typing.Literal"),
("py:class", "typing.Union"),
("py:class", "typing.Optional"),
("py:class", "typing.Literal"),
]
# TODO(michalk8): remove once typing has been cleaned-up
nitpick_ignore_regex = [
Expand Down Expand Up @@ -150,6 +157,7 @@
r"https://doi.org/10.1145/2516971.2516977",
r"https://doi.org/10.3390/a13090212",
r"https://www.mdpi.com/1999-4893/13/9/212",
r"https://pubmed\.ncbi\.nlm\.nih\.gov/.*",
]

exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"]
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dependencies = [
"wrapt>=1.13.2",
"docrep>=0.3.2",
"jax>=0.6.1",
"ott-jax>=0.5.0",
"ott-jax>=0.6.0",
"cloudpickle>=2.2.0",
"rich>=13.5",
"docstring_inheritance>=2.0.0",
Expand Down
5 changes: 2 additions & 3 deletions src/moscot/backends/ott/__init__.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from ott.geometry import costs

from moscot.backends.ott._utils import sinkhorn_divergence
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
from moscot.costs import register_cost

__all__ = [
"OTTOutput",
"GWSolver",
"SinkhornSolver",
"NeuralOutput",
"sinkhorn_divergence",
"GENOTLinSolver",
"GraphOTTOutput",
Expand Down
4 changes: 2 additions & 2 deletions src/moscot/backends/ott/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
from ott.initializers.linear import initializers as init_lib
from ott.initializers.linear import initializers_lr as lr_init_lib
from ott.neural import datasets
from ott.solvers import utils as solver_utils
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div

from moscot._logging import logger
from moscot._types import ArrayLike, ScaleCost_t

Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
OTDataset = Any # to be removed when neural part is being removed from moscot


__all__ = ["sinkhorn_divergence"]
Expand Down Expand Up @@ -272,7 +272,7 @@ def data_match_fn(

class Loader:

def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
def __init__(self, dataset: OTDataset, batch_size: int, seed: Optional[int] = None):
self.dataset = dataset
self.batch_size = batch_size
self._rng = np.random.default_rng(seed)
Expand Down
223 changes: 3 additions & 220 deletions src/moscot/backends/ott/output.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import numpy as np
import scipy.sparse as sp
from ott.neural.methods.flows.genot import GENOT
from ott.solvers.linear import sinkhorn, sinkhorn_lr
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr

import matplotlib as mpl
import matplotlib.pyplot as plt

from moscot._types import ArrayLike, Device_t
from moscot.backends.ott._utils import get_nearest_neighbors
from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput
from moscot.base.output import BaseDiscreteSolverOutput

__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"]
__all__ = ["OTTOutput", "GraphOTTOutput"]


class OTTOutput(BaseDiscreteSolverOutput):
Expand Down Expand Up @@ -242,220 +239,6 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102
return jnp.ones((n,))


class NeuralOutput(BaseNeuralOutput):
"""Output wrapper for GENOT."""

def __init__(self, model: GENOT, logs: dict[str, list[float]]):
"""Initialize `NeuralOutput`.

Parameters
----------
model : GENOT
The OTT-Jax GENOT model
"""
self._logs = logs
self._model = model

@property
def logs(self):
"""Logs of the training. A dictionary containing what the numeric values are i.e., loss.

Returns
-------
dict[str, list[float]]
"""
return self._logs

def _project_transport_matrix(
self,
src_dist: ArrayLike,
tgt_dist: ArrayLike,
func: Callable[[ArrayLike], ArrayLike],
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
batch_size: int = 1024,
k: int = 30,
length_scale: Optional[float] = None,
seed: int = 42,
recall_target: float = 0.95,
aggregate_to_topk: bool = True,
) -> sp.csr_matrix:
row_indices: List[ArrayLike] = []
column_indices: List[ArrayLike] = []
distances_list: List[ArrayLike] = []
if length_scale is None:
key = jax.random.PRNGKey(seed)
src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))]
length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch)))
for index in range(0, len(src_dist), batch_size):
distances, indices = get_nearest_neighbors(
func(src_dist[index : index + batch_size, :]),
tgt_dist,
k,
recall_target=recall_target,
aggregate_to_topk=aggregate_to_topk,
)
distances = jnp.exp(-((distances / length_scale) ** 2))
distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1)
distances_list.append(distances.flatten())
column_indices.append(indices.flatten())
row_indices.append(
jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist)))
)
distances = jnp.concatenate(distances_list)
row_indices = jnp.concatenate(row_indices)
column_indices = jnp.concatenate(column_indices)
tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
if save_transport_matrix:
self._transport_matrix = tm
return tm

def project_to_transport_matrix( # type:ignore[override]
self,
src_cells: ArrayLike,
tgt_cells: ArrayLike,
condition: ArrayLike = None,
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
batch_size: int = 1024,
k: int = 30,
length_scale: Optional[float] = None,
seed: int = 42,
recall_target: float = 0.95,
aggregate_to_topk: bool = True,
) -> sp.csr_matrix:
"""Project conditional neural OT map onto cells.

In constrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells,
but a cell can also be mapped to a location between two cells. This function computes
a pseudo-transport matrix considering the neighborhood of where a cell is mapped to.
Therefore, a neighborhood graph of `k` target cells is computed around each transported cell
of the source distribution. The assignment likelihood of each mapped cell to the target cells is then
computed with a Gaussian kernel with parameter `length_scale`.

Parameters
----------
condition
Condition `src_cells` correspond to.
src_cells
Cells which are to be mapped.
tgt_cells
Cells from which the neighborhood graph around the mapped `src_cells` are computed.
forward
Whether to map cells based on the forward transport map or backward transport map.
save_transport_matrix
Whether to save the transport matrix.
batch_size
Number of data points in the source distribution the neighborhood graph is computed
for in parallel.
k
Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
length_scale
Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`,
`length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the
mapped source and target distribution.
seed
Random seed for sampling the pairs of distributions for computing the variance in case `length_scale`
is `None`.
recall_target
Recall target for the approximation.
aggregate_to_topk
When true, the nearest neighbor aggregates approximate results to the top-k in sorted order.
When false, returns the approximate results unsorted.
In this case, the number of the approximate results is implementation defined and is greater or
equal to the specified k.

Returns
-------
The projected transport matrix.
"""
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
push = self.push if condition is None else conditioned_fn
func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
return self._project_transport_matrix(
src_dist=src_dist,
tgt_dist=tgt_dist,
func=func,
save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
batch_size=batch_size,
k=k,
length_scale=length_scale,
seed=seed,
recall_target=recall_target,
aggregate_to_topk=aggregate_to_topk,
)

def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
"""Push distribution `x` conditioned on condition `cond`.

Parameters
----------
x
Distribution to push.
cond
Condition of conditional neural OT.

Returns
-------
Pushed distribution.
"""
if isinstance(x, (bool, int, float, complex)):
raise ValueError("Expected array, found scalar value.")
if x.ndim not in (1, 2):
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
return self._apply_forward(x, cond=cond)

def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
return self._model.transport(x, condition=cond)

@property
def is_linear(self) -> bool: # noqa: D102
return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT

@property
def shape(self) -> Tuple[int, int]:
"""%(shape)s."""
raise NotImplementedError()

def to(
self,
device: Optional[Device_t] = None,
) -> "NeuralOutput":
"""Transfer the output to another device or change its data type.

Parameters
----------
device
If not `None`, the output will be transferred to `device`.

Returns
-------
The output on a saved on `device`.
"""
# # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep
# if isinstance(device, str) and ":" in device:
# device, ix = device.split(":")
# idx = int(ix)
# else:
# idx = 0

# if not isinstance(device, jax.Device):
# try:
# device = jax.devices(device)[idx]
# except IndexError as err:
# raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err

# out = jax.device_put(self._model, device)
# return NeuralOutput(out)
return self # TODO(ilan-gold) move model to device

@property
def converged(self) -> bool:
"""%(converged)s."""
# always return True for now
return True


class GraphOTTOutput(OTTOutput):
"""Output of :term:`OT` problems with a graph geometry in the linear term.

Expand Down
Loading