enable ut test for xpu devices #11712
Merged

Kangyan-Zhou merged 49 commits into sgl-project:main from DiweiSun:molly/ut_enabling_xpu on Feb 3, 2026

Changes from 47 of 49 commits
Commits:

9f0ca6d merge latest commit (1pikachu)
3899d93 adjust the code (1pikachu)
aecef41 add Rocm empty cache (1pikachu)
791b0cf Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
c1be899 pre-commit reformatted (1pikachu)
ec82c87 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
40968f4 Made the code simpler and fixed the import issue (1pikachu)
efde53e Merge branch 'molly/ut_enabling_xpu' of https://github.com/DiweiSun/s… (1pikachu)
544f29c Fix the issue where torch.accelerator doesn't support empty_cache whe… (1pikachu)
630c70d fix parameter command error (1pikachu)
32c58b5 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
4add2ca fix parameter command error (1pikachu)
8c506d9 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
8540ef3 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
980079f Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
27921d5 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
b1cf749 Merge branch 'main' of https://github.com/sgl-project/sglang into mol… (1pikachu)
c65601f Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
4a36917 remove the global device setting to avoid some unknown errors (1pikachu)
77ba062 Avoid falling back to CPU when the device is not found (1pikachu)
03649c7 Merge branch 'molly/ut_enabling_xpu' of https://github.com/DiweiSun/s… (1pikachu)
8a92401 Merge branch 'molly/ut_enabling_xpu' of https://github.com/DiweiSun/s… (1pikachu)
72f0c74 Merge branch 'main' into molly/ut_enabling_xpu (Kangyan-Zhou)
e146d0e add xpu support for ut (gaopengff)
e66c545 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
febe90c Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
24fe8be Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
c35d196 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
5a0aad6 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
f569fa8 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
ad855f2 fix lint check issue (1pikachu)
9963009 fix conflicts (1pikachu)
7aeb46f Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
bab4582 fix conflict (1pikachu)
89a71e3 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
6d42b84 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
f62d650 Merge pull request #21 from gaopengff/gaopengf/add_more_xpu_cases (1pikachu)
e7d6d53 fix conflict (1pikachu)
8df8a69 fix conflict (1pikachu)
a07633e fix conflict (1pikachu)
1c62bba fix conflict (1pikachu)
41dfd05 fix conflict (1pikachu)
38916b7 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
f48ee60 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
5da947b Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
00ccb81 Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
be731eb Merge branch 'main' into molly/ut_enabling_xpu (1pikachu)
1bd4d2e fallback automatic device detection (1pikachu)
1e216fb fix conflict (1pikachu)
```diff
@@ -37,7 +37,9 @@
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device,
+    is_cuda,
     is_port_available,
+    is_xpu,
     kill_process_tree,
     retry,
 )
```
```diff
@@ -1474,11 +1476,6 @@ def run_bench_one_batch(model, other_args):
     device: Device type ("auto", "cuda", "rocm" or "cpu").
         If "auto", will detect available platforms automatically.
     """
-    # Auto-detect device if needed
-    device = auto_config_device()
-    print(f"Auto-configed device: {device}", flush=True)
-    other_args += ["--device", str(device)]
-
     command = [
         "python3",
```
```diff
@@ -2243,6 +2240,44 @@ def wrapper(self):
     return decorator


+def get_gpu_count():
+    if get_device() == "cpu":
+        gpu_count = 0
+    else:
+        gpu_count = torch.accelerator.device_count()
+    return gpu_count
+
+
+def empty_gpu_cache():
+    """
+    Unified empty_cache for PyTorch 2.8 (no torch.accelerator)
+    and PyTorch 2.9+ (where torch.accelerator.empty_cache() exists).
+    """
+    if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "empty_cache"):
+        return torch.accelerator.empty_cache()
+
+    # CUDA
+    if hasattr(torch, "cuda") and torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        return
+
+    # XPU (Intel)
+    if hasattr(torch, "xpu") and torch.xpu.is_available():
+        torch.xpu.empty_cache()
+        return
+
+    return
+
+
+def get_gpu_memory_gb():
+    if is_cuda():
+        return torch.cuda.device_memory_used() / 1024**3
+    elif is_xpu():
+        return torch.xpu.memory_allocated() / 1024**3
+    else:
+        return 0
+
+
 def run_doctests(obj: Callable[..., Any] | ModuleType):
     mod = inspect.getmodule(obj)
     globals = dict(mod.__dict__)
```

Collaborator comment on lines 2256 to 2266: do we need to have rocm here? and also it needs a final else.
```diff
@@ -4,6 +4,7 @@
 import torch

 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
+from sglang.srt.utils import get_device
 from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
 from sglang.test.test_utils import CustomTestCase
```

```diff
@@ -15,30 +16,28 @@
 class TestCreateKvIndices(CustomTestCase):
     @classmethod
     def setUpClass(cls):
-        if not torch.cuda.is_available():
-            raise unittest.SkipTest("CUDA is not available")
-        torch.set_default_device("cuda")
+        torch.set_default_device(get_device())

     def _run_test(self, batch, max_batch, max_context_len):
         req_to_token = torch.arange(
-            max_batch * max_context_len, dtype=torch.int32, device="cuda"
+            max_batch * max_context_len, dtype=torch.int32, device=get_device()
         ).reshape((max_batch, max_context_len))
         req_pool_indices = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_batch), size=batch, replace=False)
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=get_device(),
         )
         paged_kernel_lens = torch.tensor(
             torch.from_numpy(
                 np.random.choice(range(max_context_len), size=batch, replace=False)
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=get_device(),
         )

-        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device=get_device())
         kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

         # ref
```

Contributor comment on the removed CUDA availability check: ditto

```diff
@@ -53,7 +52,9 @@ def _run_test(self, batch, max_batch, max_context_len):
         ).contiguous()

         # triton
-        kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
+        kv_indices_triton = torch.empty(
+            kv_indptr[-1], dtype=torch.int32, device=get_device()
+        )
         create_flashinfer_kv_indices_triton[(batch,)](
             req_to_token,
             req_pool_indices,
```
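The diff above replaces the hard-coded `"cuda"` default with `get_device()`, but in doing so it also drops the `SkipTest` guard that the "ditto" comment flags. A hypothetical sketch of a `setUpClass` helper that keeps both properties: device-agnostic resolution plus a skip on accelerator-less hosts. `get_device_fn` stands in for `sglang.srt.utils.get_device` (assumed to return strings like `"cuda"`, `"xpu"`, or `"cpu"`):

```python
import unittest

def resolve_test_device(get_device_fn):
    """Resolve the test device once, but keep the old behavior of
    skipping rather than silently running accelerator tests on CPU.

    Mirrors the removed `raise unittest.SkipTest(...)` guard, made
    device-agnostic instead of CUDA-only.
    """
    device = get_device_fn()
    if device == "cpu":
        raise unittest.SkipTest("no accelerator available")
    return device

print(resolve_test_device(lambda: "xpu"))  # xpu
try:
    resolve_test_device(lambda: "cpu")
except unittest.SkipTest as e:
    print("skipped:", e)
```

In a real `setUpClass`, the returned device would then be passed to `torch.set_default_device`, and `unittest` would report the test class as skipped instead of failing on CPU-only runners.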
Review comment: Why is this removed?

Reply: The main issue is here: sglang/python/sglang/test/test_utils.py, line 418 in 7541da1. This change here may not be ideal, but I think we should not fall back to CPU and should raise an error directly.
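The reply argues for raising an error instead of silently falling back to CPU. A minimal sketch of what that could look like (the function name, flags, and error message are hypothetical, not the actual sglang code; the boolean flags stand in for the real `torch.cuda.is_available()` / `torch.xpu.is_available()` checks):

```python
def auto_config_device_sketch(is_cuda_available, is_xpu_available):
    """Auto-detect the device, failing loudly when no accelerator exists.

    A silent CPU fallback can mask broken driver setups and make GPU
    tests pass vacuously; raising forces the caller to opt in to CPU.
    """
    if is_cuda_available:
        return "cuda"
    if is_xpu_available:
        return "xpu"
    raise RuntimeError(
        "No supported accelerator detected; refusing to fall back to CPU. "
        "Pass --device explicitly to run on CPU."
    )

print(auto_config_device_sketch(True, False))   # cuda
print(auto_config_device_sketch(False, True))   # xpu
```

This matches the direction of commit 77ba062 ("Avoid falling back to CPU when the device is not found") in the commit list above.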