Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Provide PyTorch implementations by wrapping JAX functions #277

Draft
wants to merge 21 commits into
base: main
Choose a base branch
from

Conversation

matt-graham
Copy link
Collaborator

This PR removes the current manual reimplementation of the precompute transform and some of the utility functions provided by s2fft to allow use with PyTorch in favour of wrapping the JAX implementations using JAX and PyTorch's mutual support for the DLPack standard as outlined by Matt Johnson in this Gist.

Some local benchmarking suggests there is no performance degradation with this wrapping approach compared to the 'native' implementations beyond the very smallest bandlimits L and a potential a small constant factor speedup for larger L - see benchmarks results in files below

precompute-spherical-torch-benchmarks.json
precompute-spherical-torch-wrapper-benchmarks.json

As all imports from torch are after changes in this PR confined to the s2fft.utils.torch_wrapper module and the import there is guarded in an try: ... except ImportError block this PR also removes torch from the required dependencies for the project, with an informative error message being raised when the user tries to use the wrapper functionality without torch being installed.

Todo

  • Add tests for functions introduced in torch_wrapper module
  • Decide whether to keep wrapped utility modules s2fft.utils.quadrature_torch and s2fft.utils.resampling_torch
  • Add wrappers for non-precompute transforms
  • Update documentation to reflect new wider support for PyTorch

@matt-graham matt-graham added the enhancement New feature or request label Mar 11, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancement New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant