DiRe offers fast dimensionality reduction preserving the global dataset structure, with benchmarks showing competitive performance against UMAP and t-SNE. Built with JAX for efficient computation on CPUs and GPUs.
Basic installation (JAX backend only):
pip install dire-jax
With utilities for benchmarking:
pip install dire-jax[utils]
Complete installation with utilities:
pip install dire-jax[all]
Note: For GPU or TPU acceleration, JAX needs to be specifically installed with hardware support. See the JAX documentation for more details on enabling GPU/TPU support.
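To confirm which hardware JAX will actually use after installation, a quick check along the following lines may help. This is a minimal sketch: `describe_jax_backend` is a hypothetical helper name, not part of DiRe, and it relies only on the standard `jax.default_backend()` and `jax.devices()` calls.

```python
import importlib.util


def describe_jax_backend():
    """Return the active JAX backend and visible devices, or None if JAX is missing.

    Hypothetical helper for illustration; it is not part of the DiRe package.
    """
    if importlib.util.find_spec("jax") is None:
        return None  # JAX is not installed in this environment
    import jax

    return {
        "backend": jax.default_backend(),  # platform name reported by JAX, e.g. "cpu"
        "devices": [str(d) for d in jax.devices()],
    }


info = describe_jax_backend()
print(info if info is not None else "JAX is not installed")
```

If the reported backend is the CPU on a machine with an accelerator, JAX was most likely installed without the hardware-specific extras.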
Example usage:
from dire_jax import DiRe
from sklearn.datasets import make_blobs

n_samples = 100_000
n_features = 1_000
n_centers = 12
features_blobs, labels_blobs = make_blobs(n_samples=n_samples, n_features=n_features, centers=n_centers, random_state=42)
reducer_blobs = DiRe(
    n_components=2,
    n_neighbors=16,
    init='pca',
    max_iter_layout=32,
    min_dist=1e-4,
    spread=1.0,
    cutoff=4.0,
    n_sample_dirs=8,
    sample_size=16,
    neg_ratio=32,
    verbose=False,
)
_ = reducer_blobs.fit_transform(features_blobs)
reducer_blobs.visualize(labels=labels_blobs, point_size=4)
The output should look similar to the following: [figure: 2D embedding of the blobs dataset]
Please refer to the DiRe API documentation for more instructions.
Project documentation structure:
- /docs/ - API documentation and architecture details
- /benchmarking/ - Performance benchmarks and scaling results
- /examples/ - Example usage and demos
- /tests/ - Test suite and benchmarking notebooks
Our working paper is available on the arXiv.
DiRe-JAX is optimized for small to medium-sized datasets (<50K points), with excellent CPU performance and GPU acceleration via JAX. Features include:
- Fully vectorized computation with JIT compilation for optimal performance
- Memory-efficient chunking to handle large datasets without excessive memory usage
- Mixed precision arithmetic (MPA) support for improved performance on modern hardware
- Optimized kernel caching to avoid recompilation and improve runtime efficiency
- Large dataset mode with automatic memory management for datasets >65K points
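The memory-efficient chunking idea in the list above can be illustrated with a nearest-neighbor search that only ever holds one block of the distance matrix in memory. This is a minimal sketch and not DiRe's actual implementation: it uses NumPy rather than JAX for portability, and `knn_chunked` and `chunk_size` are hypothetical names chosen for illustration.

```python
import numpy as np


def knn_chunked(points, n_neighbors, chunk_size=1024):
    """Indices of the n_neighbors nearest neighbors of every point (self excluded).

    Illustrative sketch only; requires n_neighbors <= len(points) - 1.
    """
    n = points.shape[0]
    sq_norms = np.sum(points**2, axis=1)
    neighbors = np.empty((n, n_neighbors), dtype=np.int64)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        block = points[start:stop]
        # Squared distances from this block to all points: shape (stop - start, n),
        # so peak memory scales with chunk_size * n instead of n * n.
        d2 = sq_norms[start:stop, None] - 2.0 * block @ points.T + sq_norms[None, :]
        # Exclude each point from its own neighbor list.
        d2[np.arange(stop - start), np.arange(start, stop)] = np.inf
        # Select the k smallest distances per row, then sort only those k.
        idx = np.argpartition(d2, n_neighbors - 1, axis=1)[:, :n_neighbors]
        order = np.argsort(np.take_along_axis(d2, idx, axis=1), axis=1)
        neighbors[start:stop] = np.take_along_axis(idx, order, axis=1)
    return neighbors


rng = np.random.default_rng(0)
points = rng.normal(size=(500, 8))
chunked = knn_chunked(points, n_neighbors=5, chunk_size=128)
print(chunked.shape)
```

In a JAX setting, the per-block computation would be a natural candidate for `jax.jit`, with a fixed block shape so that the compiled kernel can be reused across chunks rather than recompiled.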
For benchmarking utilities and quality metrics:
pip install dire-jax[utils]
This provides access to dimensionality reduction quality metrics and benchmarking routines. Some utilities use external packages for persistent homology computations which may increase runtime.
Please follow the contributing guide. Thanks!
If you use this work, please cite it as:
BibTeX:
@misc{kolpakov-rivin-2025dimensionality,
title={Dimensionality reduction for homological stability and global structure preservation},
author={Kolpakov, Alexander and Rivin, Igor},
year={2025},
eprint={2503.03156},
archivePrefix={arXiv},
primaryClass={cs.LG},
url={https://arxiv.org/abs/2503.03156}
}
APA Style:
Kolpakov, A., & Rivin, I. (2025). Dimensionality reduction for homological stability and global structure preservation. arXiv preprint arXiv:2503.03156. https://arxiv.org/abs/2503.03156
This work is supported by the Google Cloud Research Award number GCP19980904.

