
Conversation

@wirthual
Collaborator

@wirthual wirthual commented Aug 6, 2025

Add support for TPU.

Tested on a Google Colab with a Cloud TPU v6e (Trillium).

@wirthual wirthual requested review from Copilot and michaelfeil August 6, 2025 11:37

Copilot AI left a comment


Pull Request Overview

This PR adds support for TPU (Tensor Processing Unit) devices by introducing XLA backend support to the infinity embedding library. The changes enable automatic detection and configuration of TPU devices through Google's XLA framework.

  • Add "xla" as a new device type option
  • Implement TPU device detection and configuration logic
  • Add optional import handling for torch_xla dependency
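A minimal sketch of the two pieces those bullets describe, assuming a plain enum and a try/except guard (the real code uses the OptionalImports helper and lives in primitives.py and _optional_imports.py; member names other than xla are assumed):

import enum

class Device(enum.Enum):
    # existing members such as cpu, cuda, and mps; "xla" is the new TPU option
    cpu = "cpu"
    cuda = "cuda"
    mps = "mps"
    xla = "xla"

# graceful handling of the optional torch_xla dependency
try:
    import torch_xla  # noqa: F401
    XLA_AVAILABLE = True
except ImportError:
    torch_xla = None
    XLA_AVAILABLE = False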

Reviewed Changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.

File                  Description
primitives.py         Adds "xla" as a new device enum value
loading_strategy.py   Implements TPU detection logic and device configuration for XLA backend
_optional_imports.py  Adds torch_xla as an optional dependency with proper import checking

from transformers import is_torch_npu_available # type: ignore
from transformers.utils.import_utils import is_torch_xla_available

if CHECK_XLA.is_available:

Copilot AI Aug 6, 2025


The torch_xla import is only conditionally executed when CHECK_XLA.is_available is true, but torch_xla is used unconditionally on line 69 in the device count check. This will cause a NameError when XLA is available but torch_xla fails to import for other reasons.
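A hedged illustration of a pattern that avoids the NameError described above; the xla_device_count helper is hypothetical, while torch_xla.device_count() is the call named in this review:

try:
    import torch_xla
except ImportError:
    torch_xla = None  # sentinel so later references never raise NameError

def xla_device_count() -> int:
    # return 0 instead of failing when torch_xla could not be imported
    if torch_xla is None:
        return 0
    return torch_xla.device_count()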

Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Summary

This PR adds TPU (Tensor Processing Unit) support to the infinity_emb library, enabling users to run embedding models on Google's specialized ML hardware accelerators. The changes introduce XLA (Accelerated Linear Algebra) device support through three key modifications:

  1. Device Type Addition: Adds xla = "xla" to the Device enum in primitives.py, following the established pattern for device types like CUDA and MPS

  2. Optional Import Management: Introduces CHECK_XLA = OptionalImports("torch_xla", "torch_xla") in _optional_imports.py to handle the torch_xla dependency gracefully when not available

  3. Loading Strategy Integration: Updates loading_strategy.py with XLA device auto-detection logic using is_torch_xla_available() from transformers, device counting via torch_xla.device_count(), and proper device validation

The implementation integrates with the existing device auto-detection system, allowing TPUs to be automatically selected when available or explicitly specified by users. This extends the library's deployment capabilities beyond traditional CPU/GPU setups to include Google Cloud TPU instances and Colab TPU environments, potentially offering significant performance improvements for large-scale embedding computations.
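A rough sketch, not the PR's exact code, of how that auto-detection could be wired up; resolve_device is an assumed helper name, while is_torch_xla_available() and torch_xla.device_count() are the calls named in the summary:

from transformers.utils.import_utils import is_torch_xla_available

try:
    import torch_xla
except ImportError:
    torch_xla = None

def resolve_device(requested: str) -> str:
    # prefer the XLA/TPU backend when requested explicitly or picked up via auto-detection
    if requested in ("xla", "auto") and is_torch_xla_available() and torch_xla is not None:
        if torch_xla.device_count() > 0:
            return "xla"
    # otherwise fall back to the existing CPU/GPU resolution path
    return "cpu" if requested == "auto" else requested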

Confidence score: 2/5

  • This PR has significant dependency and integration issues that could prevent TPU functionality from working properly
  • Score reflects missing torch_xla dependency in pyproject.toml, potential import failures, and lack of comprehensive error handling
  • Pay close attention to loading_strategy.py and the missing dependency configuration in pyproject.toml

3 files reviewed, 2 comments


Comment on lines 12 to 13
if CHECK_XLA.is_available:
    import torch_xla

logic: torch_xla import is only guarded by CHECK_XLA.is_available but torch_xla.device_count() is called unconditionally on line 69

Suggested change
- if CHECK_XLA.is_available:
-     import torch_xla
+ if CHECK_XLA.is_available:
+     import torch_xla
+ else:
+     torch_xla = None

@wirthual
Collaborator Author

wirthual commented Aug 6, 2025

Linting problem related to this PR: pytorch/xla#9515

Owner

@michaelfeil michaelfeil left a comment


Have not reviewed!

@wirthual
Collaborator Author

wirthual commented Aug 7, 2025

Waiting for 3.12 wheels to add torch_xla. See pytorch/xla#9500

@codecov-commenter

codecov-commenter commented Aug 23, 2025

⚠️ Please install the Codecov app to ensure uploads and comments are reliably processed by Codecov.

Codecov Report

❌ Patch coverage is 60.00000% with 4 lines in your changes missing coverage. Please review.
✅ Project coverage is 79.48%. Comparing base (ff80951) to head (965ff06).
⚠️ Report is 1 commit behind head on main.

Files with missing lines                                Patch %   Lines
...ity_emb/infinity_emb/inference/loading_strategy.py   50.00%    4 Missing ⚠️
❗ Your organization needs to install the Codecov GitHub app to enable full functionality.
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #629      +/-   ##
==========================================
- Coverage   79.60%   79.48%   -0.12%     
==========================================
  Files          43       43              
  Lines        3486     3495       +9     
==========================================
+ Hits         2775     2778       +3     
- Misses        711      717       +6     

☔ View full report in Codecov by Sentry.

@michaelfeil
Owner

looks good to me @wirthual

@michaelfeil
Owner

@wirthual we can add it without changes to the lock file.

@wirthual
Collaborator Author

@michaelfeil Using the lock file from main would result in a Poetry error saying the lock file is outdated. Should we consider not committing the lock file at all?

@michaelfeil
Owner

NO lock file, no change to pyproject for now.

@michaelfeil
Owner

Let's merge it like this.

@michaelfeil michaelfeil merged commit f98ccf4 into main Aug 29, 2025
24 checks passed
@michaelfeil michaelfeil deleted the tpu-support branch August 29, 2025 17:29