Skip to content

Commit bacec6c

Browse files
MUCDKselmanozleyen
andauthored
udpate ott-jax (#839)
* udpate ott-jax * udpate ott-jax * fix get_solver * regenerate true data * Fix 'ott' documentation URL in conf.py Updated the URL for the 'ott' documentation. * Add PubMed URL pattern to references * Update nitpick_ignore_regex in conf.py Add typing references to nitpick ignore list in Sphinx config * adapt thresholds --------- Co-authored-by: Selman Özleyen <[email protected]>
1 parent 89a7936 commit bacec6c

File tree

18 files changed

+31
-884
lines changed

18 files changed

+31
-884
lines changed

docs/conf.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@
5050
"pandas": ("https://pandas.pydata.org/docs/", None),
5151
"networkx": ("https://networkx.org/documentation/stable/", None),
5252
"jax": ("https://jax.readthedocs.io/en/latest/", None),
53-
"ott": ("https://ott-jax.readthedocs.io/en/latest/", None),
53+
"ott": ("https://ott-jax.readthedocs.io/", None),
5454
"matplotlib": ("https://matplotlib.org/stable/", None),
5555
"anndata": ("https://anndata.readthedocs.io/en/latest/", None),
5656
"scanpy": ("https://scanpy.readthedocs.io/en/latest/", None),
@@ -75,6 +75,13 @@
7575
# ignore these classes until ott-jax adds them to their docs
7676
("py:class", "ott.initializers.quadratic.initializers.BaseQuadraticInitializer"),
7777
("py:class", "ott.initializers.linear.initializers.SinkhornInitializer"),
78+
# https://stackoverflow.com/questions/11417221/sphinx-autodoc-gives-warning-pyclass-reference-target-not-found-type-warning
79+
("py:data", "typing.Union"),
80+
("py:data", "typing.Optional"),
81+
("py:data", "typing.Literal"),
82+
("py:class", "typing.Union"),
83+
("py:class", "typing.Optional"),
84+
("py:class", "typing.Literal"),
7885
]
7986
# TODO(michalk8): remove once typing has been cleaned-up
8087
nitpick_ignore_regex = [
@@ -150,6 +157,7 @@
150157
r"https://doi.org/10.1145/2516971.2516977",
151158
r"https://doi.org/10.3390/a13090212",
152159
r"https://www.mdpi.com/1999-4893/13/9/212",
160+
r"https://pubmed\.ncbi\.nlm\.nih\.gov/.*",
153161
]
154162

155163
exclude_patterns = ["_build", "**.ipynb_checkpoints", "notebooks/README.rst", "notebooks/CONTRIBUTING.rst"]

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ dependencies = [
5555
"wrapt>=1.13.2",
5656
"docrep>=0.3.2",
5757
"jax>=0.6.1",
58-
"ott-jax>=0.5.0",
58+
"ott-jax>=0.6.0",
5959
"cloudpickle>=2.2.0",
6060
"rich>=13.5",
6161
"docstring_inheritance>=2.0.0",

src/moscot/backends/ott/__init__.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
from ott.geometry import costs
22

33
from moscot.backends.ott._utils import sinkhorn_divergence
4-
from moscot.backends.ott.output import GraphOTTOutput, NeuralOutput, OTTOutput
5-
from moscot.backends.ott.solver import GENOTLinSolver, GWSolver, SinkhornSolver
4+
from moscot.backends.ott.output import GraphOTTOutput, OTTOutput
5+
from moscot.backends.ott.solver import GWSolver, SinkhornSolver
66
from moscot.costs import register_cost
77

88
__all__ = [
99
"OTTOutput",
1010
"GWSolver",
1111
"SinkhornSolver",
12-
"NeuralOutput",
1312
"sinkhorn_divergence",
1413
"GENOTLinSolver",
1514
"GraphOTTOutput",

src/moscot/backends/ott/_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
from ott.geometry import epsilon_scheduler, geodesic, geometry, pointcloud
1111
from ott.initializers.linear import initializers as init_lib
1212
from ott.initializers.linear import initializers_lr as lr_init_lib
13-
from ott.neural import datasets
1413
from ott.solvers import utils as solver_utils
1514
from ott.tools.sinkhorn_divergence import sinkhorn_divergence as sinkhorn_div
1615

1716
from moscot._logging import logger
1817
from moscot._types import ArrayLike, ScaleCost_t
1918

2019
Scale_t = Union[float, Literal["mean", "median", "max_cost", "max_norm", "max_bound"]]
20+
OTDataset = Any # to be removed when neural part is being removed from moscot
2121

2222

2323
__all__ = ["sinkhorn_divergence"]
@@ -272,7 +272,7 @@ def data_match_fn(
272272

273273
class Loader:
274274

275-
def __init__(self, dataset: datasets.OTDataset, batch_size: int, seed: Optional[int] = None):
275+
def __init__(self, dataset: OTDataset, batch_size: int, seed: Optional[int] = None):
276276
self.dataset = dataset
277277
self.batch_size = batch_size
278278
self._rng = np.random.default_rng(seed)

src/moscot/backends/ott/output.py

Lines changed: 3 additions & 220 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,18 @@
1-
from typing import Any, Callable, List, Optional, Tuple, Union
1+
from typing import Any, Optional, Tuple, Union
22

33
import jax
44
import jax.numpy as jnp
55
import numpy as np
6-
import scipy.sparse as sp
7-
from ott.neural.methods.flows.genot import GENOT
86
from ott.solvers.linear import sinkhorn, sinkhorn_lr
97
from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr
108

119
import matplotlib as mpl
1210
import matplotlib.pyplot as plt
1311

1412
from moscot._types import ArrayLike, Device_t
15-
from moscot.backends.ott._utils import get_nearest_neighbors
16-
from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput
13+
from moscot.base.output import BaseDiscreteSolverOutput
1714

18-
__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"]
15+
__all__ = ["OTTOutput", "GraphOTTOutput"]
1916

2017

2118
class OTTOutput(BaseDiscreteSolverOutput):
@@ -242,220 +239,6 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102
242239
return jnp.ones((n,))
243240

244241

245-
class NeuralOutput(BaseNeuralOutput):
246-
"""Output wrapper for GENOT."""
247-
248-
def __init__(self, model: GENOT, logs: dict[str, list[float]]):
249-
"""Initialize `NeuralOutput`.
250-
251-
Parameters
252-
----------
253-
model : GENOT
254-
The OTT-Jax GENOT model
255-
"""
256-
self._logs = logs
257-
self._model = model
258-
259-
@property
260-
def logs(self):
261-
"""Logs of the training. A dictionary containing what the numeric values are i.e., loss.
262-
263-
Returns
264-
-------
265-
dict[str, list[float]]
266-
"""
267-
return self._logs
268-
269-
def _project_transport_matrix(
270-
self,
271-
src_dist: ArrayLike,
272-
tgt_dist: ArrayLike,
273-
func: Callable[[ArrayLike], ArrayLike],
274-
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
275-
batch_size: int = 1024,
276-
k: int = 30,
277-
length_scale: Optional[float] = None,
278-
seed: int = 42,
279-
recall_target: float = 0.95,
280-
aggregate_to_topk: bool = True,
281-
) -> sp.csr_matrix:
282-
row_indices: List[ArrayLike] = []
283-
column_indices: List[ArrayLike] = []
284-
distances_list: List[ArrayLike] = []
285-
if length_scale is None:
286-
key = jax.random.PRNGKey(seed)
287-
src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))]
288-
tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))]
289-
length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch)))
290-
for index in range(0, len(src_dist), batch_size):
291-
distances, indices = get_nearest_neighbors(
292-
func(src_dist[index : index + batch_size, :]),
293-
tgt_dist,
294-
k,
295-
recall_target=recall_target,
296-
aggregate_to_topk=aggregate_to_topk,
297-
)
298-
distances = jnp.exp(-((distances / length_scale) ** 2))
299-
distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1)
300-
distances_list.append(distances.flatten())
301-
column_indices.append(indices.flatten())
302-
row_indices.append(
303-
jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist)))
304-
)
305-
distances = jnp.concatenate(distances_list)
306-
row_indices = jnp.concatenate(row_indices)
307-
column_indices = jnp.concatenate(column_indices)
308-
tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)])
309-
if save_transport_matrix:
310-
self._transport_matrix = tm
311-
return tm
312-
313-
def project_to_transport_matrix( # type:ignore[override]
314-
self,
315-
src_cells: ArrayLike,
316-
tgt_cells: ArrayLike,
317-
condition: ArrayLike = None,
318-
save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments
319-
batch_size: int = 1024,
320-
k: int = 30,
321-
length_scale: Optional[float] = None,
322-
seed: int = 42,
323-
recall_target: float = 0.95,
324-
aggregate_to_topk: bool = True,
325-
) -> sp.csr_matrix:
326-
"""Project conditional neural OT map onto cells.
327-
328-
In constrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells,
329-
but a cell can also be mapped to a location between two cells. This function computes
330-
a pseudo-transport matrix considering the neighborhood of where a cell is mapped to.
331-
Therefore, a neighborhood graph of `k` target cells is computed around each transported cell
332-
of the source distribution. The assignment likelihood of each mapped cell to the target cells is then
333-
computed with a Gaussian kernel with parameter `length_scale`.
334-
335-
Parameters
336-
----------
337-
condition
338-
Condition `src_cells` correspond to.
339-
src_cells
340-
Cells which are to be mapped.
341-
tgt_cells
342-
Cells from which the neighborhood graph around the mapped `src_cells` are computed.
343-
forward
344-
Whether to map cells based on the forward transport map or backward transport map.
345-
save_transport_matrix
346-
Whether to save the transport matrix.
347-
batch_size
348-
Number of data points in the source distribution the neighborhood graph is computed
349-
for in parallel.
350-
k
351-
Number of neighbors to construct the k-nearest neighbor graph of a mapped cell.
352-
length_scale
353-
Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`,
354-
`length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the
355-
mapped source and target distribution.
356-
seed
357-
Random seed for sampling the pairs of distributions for computing the variance in case `length_scale`
358-
is `None`.
359-
recall_target
360-
Recall target for the approximation.
361-
aggregate_to_topk
362-
When true, the nearest neighbor aggregates approximate results to the top-k in sorted order.
363-
When false, returns the approximate results unsorted.
364-
In this case, the number of the approximate results is implementation defined and is greater or
365-
equal to the specified k.
366-
367-
Returns
368-
-------
369-
The projected transport matrix.
370-
"""
371-
src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells)
372-
conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition)
373-
push = self.push if condition is None else conditioned_fn
374-
func, src_dist, tgt_dist = (push, src_cells, tgt_cells)
375-
return self._project_transport_matrix(
376-
src_dist=src_dist,
377-
tgt_dist=tgt_dist,
378-
func=func,
379-
save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments
380-
batch_size=batch_size,
381-
k=k,
382-
length_scale=length_scale,
383-
seed=seed,
384-
recall_target=recall_target,
385-
aggregate_to_topk=aggregate_to_topk,
386-
)
387-
388-
def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
389-
"""Push distribution `x` conditioned on condition `cond`.
390-
391-
Parameters
392-
----------
393-
x
394-
Distribution to push.
395-
cond
396-
Condition of conditional neural OT.
397-
398-
Returns
399-
-------
400-
Pushed distribution.
401-
"""
402-
if isinstance(x, (bool, int, float, complex)):
403-
raise ValueError("Expected array, found scalar value.")
404-
if x.ndim not in (1, 2):
405-
raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.")
406-
return self._apply_forward(x, cond=cond)
407-
408-
def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike:
409-
return self._model.transport(x, condition=cond)
410-
411-
@property
412-
def is_linear(self) -> bool: # noqa: D102
413-
return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT
414-
415-
@property
416-
def shape(self) -> Tuple[int, int]:
417-
"""%(shape)s."""
418-
raise NotImplementedError()
419-
420-
def to(
421-
self,
422-
device: Optional[Device_t] = None,
423-
) -> "NeuralOutput":
424-
"""Transfer the output to another device or change its data type.
425-
426-
Parameters
427-
----------
428-
device
429-
If not `None`, the output will be transferred to `device`.
430-
431-
Returns
432-
-------
433-
The output on a saved on `device`.
434-
"""
435-
# # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep
436-
# if isinstance(device, str) and ":" in device:
437-
# device, ix = device.split(":")
438-
# idx = int(ix)
439-
# else:
440-
# idx = 0
441-
442-
# if not isinstance(device, jax.Device):
443-
# try:
444-
# device = jax.devices(device)[idx]
445-
# except IndexError as err:
446-
# raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err
447-
448-
# out = jax.device_put(self._model, device)
449-
# return NeuralOutput(out)
450-
return self # TODO(ilan-gold) move model to device
451-
452-
@property
453-
def converged(self) -> bool:
454-
"""%(converged)s."""
455-
# always return True for now
456-
return True
457-
458-
459242
class GraphOTTOutput(OTTOutput):
460243
"""Output of :term:`OT` problems with a graph geometry in the linear term.
461244

0 commit comments

Comments
 (0)