sashakolpakov/dire-jax

DiRe-JAX logo

A high-performance DImensionality REduction package with JAX

DiRe offers fast dimensionality reduction preserving the global dataset structure, with benchmarks showing competitive performance against UMAP and t-SNE. Built with JAX for efficient computation on CPUs and GPUs.

Quick start (a Colab notebook is available)

Basic installation (JAX backend only):

pip install dire-jax

With benchmarking utilities:

pip install "dire-jax[utils]"

Complete installation (all extras):

pip install "dire-jax[all]"

Note: For GPU or TPU acceleration, JAX needs to be specifically installed with hardware support. See the JAX documentation for more details on enabling GPU/TPU support.
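For example, on an NVIDIA GPU a CUDA-enabled JAX build can be installed alongside the package; the `cuda12` extra below assumes CUDA 12, so check the JAX installation guide for the variant matching your hardware:

```shell
# Install dire-jax, then a CUDA-enabled JAX build (NVIDIA GPUs, CUDA 12)
pip install dire-jax
pip install -U "jax[cuda12]"
```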

Example usage:

from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Generate a synthetic dataset: 12 Gaussian blobs in 1,000 dimensions.
n_samples  = 100_000
n_features = 1_000
n_centers  = 12
features_blobs, labels_blobs = make_blobs(n_samples=n_samples,
                                          n_features=n_features,
                                          centers=n_centers,
                                          random_state=42)

# Reduce to 2D, initializing the layout with PCA.
reducer_blobs = DiRe(n_components=2,
                     n_neighbors=16,
                     init='pca',
                     max_iter_layout=32,
                     min_dist=1e-4,
                     spread=1.0,
                     cutoff=4.0,
                     n_sample_dirs=8,
                     sample_size=16,
                     neg_ratio=32,
                     verbose=False)

_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)

The output should look similar to:

[Figure: 12 blobs with 100k points in 1k dimensions embedded in dimension 2]

Documentation

Please refer to the DiRe API documentation for more instructions.

Project documentation structure:

  • /docs/ - API documentation and architecture details
  • /benchmarking/ - Performance benchmarks and scaling results
  • /examples/ - Example usage and demos
  • /tests/ - Test suite and benchmarking notebooks

Working paper

Our working paper is available on the arXiv (arXiv:2503.03156).

Performance Characteristics

DiRe-JAX is optimized for small-to-medium datasets (under ~50K points), with excellent CPU performance and GPU acceleration via JAX. Features include:

  • Fully vectorized computation with JIT compilation for optimal performance
  • Memory-efficient chunking to handle large datasets without excessive memory usage
  • Mixed precision arithmetic (MPA) support for improved performance on modern hardware
  • Optimized kernel caching to avoid recompilation and improve runtime efficiency
  • Large dataset mode with automatic memory management for datasets >65K points

Benchmarking and utilities

For benchmarking utilities and quality metrics:

pip install dire-jax[utils]

This provides access to dimensionality reduction quality metrics and benchmarking routines. Some utilities rely on external packages for persistent homology computations, which may increase runtime.
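The package's own metric functions are not listed here; as a generic illustration of one common embedding-quality check, the sketch below measures how many of each point's k nearest neighbors survive the reduction (pure NumPy, brute force; all names are illustrative, not dire-jax's API):

```python
import numpy as np

def knn_indices(x, k):
    """Brute-force k-nearest-neighbor indices, excluding each point itself."""
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(-1)
    np.fill_diagonal(d2, np.inf)            # a point is not its own neighbor
    return np.argsort(d2, axis=1)[:, :k]

def neighbor_preservation(high, low, k=10):
    """Average fraction of each point's k-NN set shared between the
    high-dimensional data and its low-dimensional embedding (in [0, 1])."""
    hi, lo = knn_indices(high, k), knn_indices(low, k)
    overlap = [len(set(a) & set(b)) / k for a, b in zip(hi, lo)]
    return float(np.mean(overlap))

rng = np.random.default_rng(1)
data = rng.normal(size=(200, 50))
score_same = neighbor_preservation(data, data.copy(), k=10)  # identical data: 1.0
```

A score near 1.0 means local neighborhoods are preserved; scores near k/n indicate the embedding is no better than chance.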

Contributing

Please follow the contributing guide. Thanks!

Citation

If you use this work, please cite it as:

BibTeX:

@misc{kolpakov-rivin-2025dimensionality,
  title={Dimensionality reduction for homological stability and global structure preservation},
  author={Kolpakov, Alexander and Rivin, Igor},
  year={2025},
  eprint={2503.03156},
  archivePrefix={arXiv},
  primaryClass={cs.LG},
  url={https://arxiv.org/abs/2503.03156}
}

APA Style:

Kolpakov, A., & Rivin, I. (2025). Dimensionality reduction for homological stability and global structure preservation. arXiv preprint arXiv:2503.03156. https://arxiv.org/abs/2503.03156

Acknowledgement

This work is supported by the Google Cloud Research Award number GCP19980904.