|
1 | | -from typing import Any, Callable, List, Optional, Tuple, Union |
| 1 | +from typing import Any, Optional, Tuple, Union |
2 | 2 |
|
3 | 3 | import jax |
4 | 4 | import jax.numpy as jnp |
5 | 5 | import numpy as np |
6 | | -import scipy.sparse as sp |
7 | | -from ott.neural.methods.flows.genot import GENOT |
8 | 6 | from ott.solvers.linear import sinkhorn, sinkhorn_lr |
9 | 7 | from ott.solvers.quadratic import gromov_wasserstein, gromov_wasserstein_lr |
10 | 8 |
|
11 | 9 | import matplotlib as mpl |
12 | 10 | import matplotlib.pyplot as plt |
13 | 11 |
|
14 | 12 | from moscot._types import ArrayLike, Device_t |
15 | | -from moscot.backends.ott._utils import get_nearest_neighbors |
16 | | -from moscot.base.output import BaseDiscreteSolverOutput, BaseNeuralOutput |
| 13 | +from moscot.base.output import BaseDiscreteSolverOutput |
17 | 14 |
|
18 | | -__all__ = ["OTTOutput", "GraphOTTOutput", "NeuralOutput"] |
| 15 | +__all__ = ["OTTOutput", "GraphOTTOutput"] |
19 | 16 |
|
20 | 17 |
|
21 | 18 | class OTTOutput(BaseDiscreteSolverOutput): |
@@ -242,220 +239,6 @@ def _ones(self, n: int) -> ArrayLike: # noqa: D102 |
242 | 239 | return jnp.ones((n,)) |
243 | 240 |
|
244 | 241 |
|
245 | | -class NeuralOutput(BaseNeuralOutput): |
246 | | - """Output wrapper for GENOT.""" |
247 | | - |
248 | | - def __init__(self, model: GENOT, logs: dict[str, list[float]]): |
249 | | - """Initialize `NeuralOutput`. |
250 | | -
|
251 | | - Parameters |
252 | | - ---------- |
253 | | - model : GENOT |
254 | | - The OTT-Jax GENOT model |
255 | | - """ |
256 | | - self._logs = logs |
257 | | - self._model = model |
258 | | - |
259 | | - @property |
260 | | - def logs(self): |
261 | | - """Logs of the training. A dictionary containing what the numeric values are i.e., loss. |
262 | | -
|
263 | | - Returns |
264 | | - ------- |
265 | | - dict[str, list[float]] |
266 | | - """ |
267 | | - return self._logs |
268 | | - |
269 | | - def _project_transport_matrix( |
270 | | - self, |
271 | | - src_dist: ArrayLike, |
272 | | - tgt_dist: ArrayLike, |
273 | | - func: Callable[[ArrayLike], ArrayLike], |
274 | | - save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments |
275 | | - batch_size: int = 1024, |
276 | | - k: int = 30, |
277 | | - length_scale: Optional[float] = None, |
278 | | - seed: int = 42, |
279 | | - recall_target: float = 0.95, |
280 | | - aggregate_to_topk: bool = True, |
281 | | - ) -> sp.csr_matrix: |
282 | | - row_indices: List[ArrayLike] = [] |
283 | | - column_indices: List[ArrayLike] = [] |
284 | | - distances_list: List[ArrayLike] = [] |
285 | | - if length_scale is None: |
286 | | - key = jax.random.PRNGKey(seed) |
287 | | - src_batch = src_dist[jax.random.choice(key, src_dist.shape[0], shape=((batch_size,)))] |
288 | | - tgt_batch = tgt_dist[jax.random.choice(key, tgt_dist.shape[0], shape=((batch_size,)))] |
289 | | - length_scale = jnp.std(jnp.concatenate((func(src_batch), tgt_batch))) |
290 | | - for index in range(0, len(src_dist), batch_size): |
291 | | - distances, indices = get_nearest_neighbors( |
292 | | - func(src_dist[index : index + batch_size, :]), |
293 | | - tgt_dist, |
294 | | - k, |
295 | | - recall_target=recall_target, |
296 | | - aggregate_to_topk=aggregate_to_topk, |
297 | | - ) |
298 | | - distances = jnp.exp(-((distances / length_scale) ** 2)) |
299 | | - distances /= jnp.expand_dims(jnp.sum(distances, axis=1), axis=1) |
300 | | - distances_list.append(distances.flatten()) |
301 | | - column_indices.append(indices.flatten()) |
302 | | - row_indices.append( |
303 | | - jnp.repeat(jnp.arange(index, index + min(batch_size, len(src_dist) - index)), min(k, len(tgt_dist))) |
304 | | - ) |
305 | | - distances = jnp.concatenate(distances_list) |
306 | | - row_indices = jnp.concatenate(row_indices) |
307 | | - column_indices = jnp.concatenate(column_indices) |
308 | | - tm = sp.csr_matrix((distances, (row_indices, column_indices)), shape=[len(src_dist), len(tgt_dist)]) |
309 | | - if save_transport_matrix: |
310 | | - self._transport_matrix = tm |
311 | | - return tm |
312 | | - |
313 | | - def project_to_transport_matrix( # type:ignore[override] |
314 | | - self, |
315 | | - src_cells: ArrayLike, |
316 | | - tgt_cells: ArrayLike, |
317 | | - condition: ArrayLike = None, |
318 | | - save_transport_matrix: bool = False, # TODO(@MUCDK) adapt order of arguments |
319 | | - batch_size: int = 1024, |
320 | | - k: int = 30, |
321 | | - length_scale: Optional[float] = None, |
322 | | - seed: int = 42, |
323 | | - recall_target: float = 0.95, |
324 | | - aggregate_to_topk: bool = True, |
325 | | - ) -> sp.csr_matrix: |
326 | | - """Project conditional neural OT map onto cells. |
327 | | -
|
328 | | - In constrast to discrete OT, (conditional) neural OT does not necessarily map cells onto cells, |
329 | | - but a cell can also be mapped to a location between two cells. This function computes |
330 | | - a pseudo-transport matrix considering the neighborhood of where a cell is mapped to. |
331 | | - Therefore, a neighborhood graph of `k` target cells is computed around each transported cell |
332 | | - of the source distribution. The assignment likelihood of each mapped cell to the target cells is then |
333 | | - computed with a Gaussian kernel with parameter `length_scale`. |
334 | | -
|
335 | | - Parameters |
336 | | - ---------- |
337 | | - condition |
338 | | - Condition `src_cells` correspond to. |
339 | | - src_cells |
340 | | - Cells which are to be mapped. |
341 | | - tgt_cells |
342 | | - Cells from which the neighborhood graph around the mapped `src_cells` are computed. |
343 | | - forward |
344 | | - Whether to map cells based on the forward transport map or backward transport map. |
345 | | - save_transport_matrix |
346 | | - Whether to save the transport matrix. |
347 | | - batch_size |
348 | | - Number of data points in the source distribution the neighborhood graph is computed |
349 | | - for in parallel. |
350 | | - k |
351 | | - Number of neighbors to construct the k-nearest neighbor graph of a mapped cell. |
352 | | - length_scale |
353 | | - Length scale of the Gaussian kernel used to compute the assignment likelihood. If `None`, |
354 | | - `length_scale` is set to the empirical standard deviation of `batch_size` pairs of data points of the |
355 | | - mapped source and target distribution. |
356 | | - seed |
357 | | - Random seed for sampling the pairs of distributions for computing the variance in case `length_scale` |
358 | | - is `None`. |
359 | | - recall_target |
360 | | - Recall target for the approximation. |
361 | | - aggregate_to_topk |
362 | | - When true, the nearest neighbor aggregates approximate results to the top-k in sorted order. |
363 | | - When false, returns the approximate results unsorted. |
364 | | - In this case, the number of the approximate results is implementation defined and is greater or |
365 | | - equal to the specified k. |
366 | | -
|
367 | | - Returns |
368 | | - ------- |
369 | | - The projected transport matrix. |
370 | | - """ |
371 | | - src_cells, tgt_cells = jnp.asarray(src_cells), jnp.asarray(tgt_cells) |
372 | | - conditioned_fn: Callable[[ArrayLike], ArrayLike] = lambda x: self.push(x, condition) |
373 | | - push = self.push if condition is None else conditioned_fn |
374 | | - func, src_dist, tgt_dist = (push, src_cells, tgt_cells) |
375 | | - return self._project_transport_matrix( |
376 | | - src_dist=src_dist, |
377 | | - tgt_dist=tgt_dist, |
378 | | - func=func, |
379 | | - save_transport_matrix=save_transport_matrix, # TODO(@MUCDK) adapt order of arguments |
380 | | - batch_size=batch_size, |
381 | | - k=k, |
382 | | - length_scale=length_scale, |
383 | | - seed=seed, |
384 | | - recall_target=recall_target, |
385 | | - aggregate_to_topk=aggregate_to_topk, |
386 | | - ) |
387 | | - |
388 | | - def push(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: |
389 | | - """Push distribution `x` conditioned on condition `cond`. |
390 | | -
|
391 | | - Parameters |
392 | | - ---------- |
393 | | - x |
394 | | - Distribution to push. |
395 | | - cond |
396 | | - Condition of conditional neural OT. |
397 | | -
|
398 | | - Returns |
399 | | - ------- |
400 | | - Pushed distribution. |
401 | | - """ |
402 | | - if isinstance(x, (bool, int, float, complex)): |
403 | | - raise ValueError("Expected array, found scalar value.") |
404 | | - if x.ndim not in (1, 2): |
405 | | - raise ValueError(f"Expected 1D or 2D array, found `{x.ndim}`.") |
406 | | - return self._apply_forward(x, cond=cond) |
407 | | - |
408 | | - def _apply_forward(self, x: ArrayLike, cond: Optional[ArrayLike] = None) -> ArrayLike: |
409 | | - return self._model.transport(x, condition=cond) |
410 | | - |
411 | | - @property |
412 | | - def is_linear(self) -> bool: # noqa: D102 |
413 | | - return True # TODO(ilan-gold): need to contribute something to ott-jax so this is resolvable from GENOT |
414 | | - |
415 | | - @property |
416 | | - def shape(self) -> Tuple[int, int]: |
417 | | - """%(shape)s.""" |
418 | | - raise NotImplementedError() |
419 | | - |
420 | | - def to( |
421 | | - self, |
422 | | - device: Optional[Device_t] = None, |
423 | | - ) -> "NeuralOutput": |
424 | | - """Transfer the output to another device or change its data type. |
425 | | -
|
426 | | - Parameters |
427 | | - ---------- |
428 | | - device |
429 | | - If not `None`, the output will be transferred to `device`. |
430 | | -
|
431 | | - Returns |
432 | | - ------- |
433 | | - The output on a saved on `device`. |
434 | | - """ |
435 | | - # # TODO(michalk8): when polishing docs, move the definition to the base class + use docrep |
436 | | - # if isinstance(device, str) and ":" in device: |
437 | | - # device, ix = device.split(":") |
438 | | - # idx = int(ix) |
439 | | - # else: |
440 | | - # idx = 0 |
441 | | - |
442 | | - # if not isinstance(device, jax.Device): |
443 | | - # try: |
444 | | - # device = jax.devices(device)[idx] |
445 | | - # except IndexError as err: |
446 | | - # raise IndexError(f"Unable to fetch the device with `id={idx}`.") from err |
447 | | - |
448 | | - # out = jax.device_put(self._model, device) |
449 | | - # return NeuralOutput(out) |
450 | | - return self # TODO(ilan-gold) move model to device |
451 | | - |
452 | | - @property |
453 | | - def converged(self) -> bool: |
454 | | - """%(converged)s.""" |
455 | | - # always return True for now |
456 | | - return True |
457 | | - |
458 | | - |
459 | 242 | class GraphOTTOutput(OTTOutput): |
460 | 243 | """Output of :term:`OT` problems with a graph geometry in the linear term. |
461 | 244 |
|
|
0 commit comments