Add TPU support #629
Conversation
Pull Request Overview
This PR adds support for TPU (Tensor Processing Unit) devices by introducing XLA backend support to the infinity embedding library. The changes enable automatic detection and configuration of TPU devices through Google's XLA framework.
- Add "xla" as a new device type option
- Implement TPU device detection and configuration logic
- Add optional import handling for torch_xla dependency
Reviewed Changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| primitives.py | Adds "xla" as a new device enum value |
| loading_strategy.py | Implements TPU detection logic and device configuration for XLA backend |
| _optional_imports.py | Adds torch_xla as an optional dependency with proper import checking |
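For orientation, here is a minimal sketch (not taken from the diff itself) of what the enum addition described above might look like, assuming a plain enum.Enum-style Device class; the exact set of existing members in primitives.py may differ:

```python
# primitives.py (sketch): the new device value alongside the existing ones
import enum

class Device(enum.Enum):
    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"
    xla = "xla"  # new: XLA backend, used for TPUs
    auto = "auto"
```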
The comment below refers to this excerpt from the diff:

```python
from transformers import is_torch_npu_available  # type: ignore
from transformers.utils.import_utils import is_torch_xla_available

if CHECK_XLA.is_available:
```
Copilot AI commented on Aug 6, 2025:
The torch_xla import is only conditionally executed when CHECK_XLA.is_available is true, but torch_xla is used unconditionally on line 69 in the device count check. This will cause a NameError when XLA is available but torch_xla fails to import for other reasons.
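One way to make the call site robust is sketched below with a plain try/except instead of the CHECK_XLA guard (the suggested change further down takes the equivalent else-branch approach):

```python
# Sketch: bind torch_xla to None when the optional package is absent, so later
# calls can be gated on `torch_xla is not None` instead of risking a NameError.
try:
    import torch_xla
except ImportError:
    torch_xla = None

device_count = torch_xla.device_count() if torch_xla is not None else 0
```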
Greptile Summary
This PR adds TPU (Tensor Processing Unit) support to the infinity_emb library, enabling users to run embedding models on Google's specialized ML hardware accelerators. The changes introduce XLA (Accelerated Linear Algebra) device support through three key modifications:
- Device Type Addition: Adds `xla = "xla"` to the `Device` enum in `primitives.py`, following the established pattern for device types like CUDA and MPS
- Optional Import Management: Introduces `CHECK_XLA = OptionalImports("torch_xla", "torch_xla")` in `_optional_imports.py` to handle the torch_xla dependency gracefully when not available
- Loading Strategy Integration: Updates `loading_strategy.py` with XLA device auto-detection logic using `is_torch_xla_available()` from transformers, device counting via `torch_xla.device_count()`, and proper device validation
The implementation integrates with the existing device auto-detection system, allowing TPUs to be automatically selected when available or explicitly specified by users. This extends the library's deployment capabilities beyond traditional CPU/GPU setups to include Google Cloud TPU instances and Colab TPU environments, potentially offering significant performance improvements for large-scale embedding computations.
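As a rough sketch of how such auto-detection could fit together (the helper name resolve_device is hypothetical; the actual logic in loading_strategy.py may be structured differently):

```python
from transformers.utils.import_utils import is_torch_xla_available

def resolve_device(requested: str) -> tuple[str, int]:
    # Hypothetical helper: returns (device_type, device_count).
    if requested in ("auto", "xla") and is_torch_xla_available():
        try:
            import torch_xla  # optional; may be missing even if XLA is reported available
            return "xla", torch_xla.device_count()
        except ImportError:
            pass  # fall back to CPU below
    return "cpu", 1
```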
Confidence score: 2/5
- This PR has significant dependency and integration issues that could prevent TPU functionality from working properly
- Score reflects missing torch_xla dependency in pyproject.toml, potential import failures, and lack of comprehensive error handling
- Pay close attention to loading_strategy.py and the missing dependency configuration in pyproject.toml
3 files reviewed, 2 comments
The comment below refers to this excerpt:

```python
if CHECK_XLA.is_available:
    import torch_xla
```
logic: the `torch_xla` import is only guarded by `CHECK_XLA.is_available`, but `torch_xla.device_count()` is called unconditionally on line 69
Suggested change:

```python
if CHECK_XLA.is_available:
    import torch_xla
else:
    torch_xla = None
```
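Binding `torch_xla` to `None` in the else branch keeps the name defined on every code path, so downstream code such as the device-count check can test `torch_xla is not None` (or re-check `CHECK_XLA.is_available`) instead of raising a `NameError` when the optional package is missing.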
Linting problem related to this PR: pytorch/xla#9515
michaelfeil left a comment:
Have not reviewed!
Waiting for 3.12 wheels to add
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@            Coverage Diff             @@
##             main     #629      +/-   ##
==========================================
- Coverage   79.60%   79.48%   -0.12%
==========================================
  Files          43       43
  Lines        3486     3495       +9
==========================================
+ Hits         2775     2778       +3
- Misses        711      717       +6
```

☔ View full report in Codecov by Sentry.
looks good to me @wirthual
@wirthual we can add it without changes to the lock file.
@michaelfeil Using the lock file from main would result in a poetry error that the lock file is outdated. Should we consider not committing the lock file at all?
NO lock file, no change to pyproject for now.
Let's merge it like this.
Add support for TPU.
Tested on a Google Colab with a Cloud TPU v6e (Trillium).
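For reference, usage on such a TPU runtime might look roughly like the sketch below; it assumes the new "xla" value can be passed through EngineArgs' device field the same way "cuda" or "cpu" are today, which is illustrative rather than a confirmed part of the public API:

```python
# Illustrative only: assumes EngineArgs accepts device="xla" after this PR,
# mirroring how "cuda" or "cpu" are passed today.
import asyncio
from infinity_emb import AsyncEngineArray, EngineArgs

array = AsyncEngineArray.from_args([
    EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", device="xla")
])
engine = array[0]

async def main():
    async with engine:
        embeddings, usage = await engine.embed(sentences=["Embed this on a TPU."])
        print(len(embeddings), usage)

asyncio.run(main())
```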