Merged

Changes from 47 commits

Commits (49)
9f0ca6d
merge latest commit
1pikachu Nov 21, 2025
3899d93
adjust the code
1pikachu Nov 21, 2025
aecef41
add ROCm empty cache
1pikachu Nov 21, 2025
791b0cf
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 24, 2025
c1be899
pre-commit reformatted
1pikachu Nov 24, 2025
ec82c87
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 24, 2025
40968f4
Made the code simpler and fixed the import issue
1pikachu Nov 24, 2025
efde53e
Merge branch 'molly/ut_enabling_xpu' of https://github.com/DiweiSun/s…
1pikachu Nov 24, 2025
544f29c
Fix the issue where torch.accelerator doesn't support empty_cache whe…
1pikachu Nov 25, 2025
630c70d
fix parameter command error
1pikachu Nov 25, 2025
32c58b5
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 25, 2025
4add2ca
fix parameter command error
1pikachu Nov 25, 2025
8c506d9
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 25, 2025
8540ef3
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 26, 2025
980079f
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 26, 2025
27921d5
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Nov 27, 2025
b1cf749
Merge branch 'main' of https://github.com/sgl-project/sglang into mol…
1pikachu Dec 4, 2025
c65601f
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 5, 2025
4a36917
remove the global device setting to avoid some unknown errors
1pikachu Dec 8, 2025
77ba062
Avoid falling back to CPU when the device is not found
1pikachu Dec 8, 2025
03649c7
Merge branch 'molly/ut_enabling_xpu' of https://github.com/DiweiSun/s…
1pikachu Dec 8, 2025
8a92401
Merge branch 'molly/ut_enabling_xpu' of https://github.com/DiweiSun/s…
1pikachu Dec 10, 2025
72f0c74
Merge branch 'main' into molly/ut_enabling_xpu
Kangyan-Zhou Dec 10, 2025
e146d0e
add xpu support for ut
gaopengff Dec 11, 2025
e66c545
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 12, 2025
febe90c
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 15, 2025
24fe8be
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 16, 2025
c35d196
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 18, 2025
5a0aad6
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 19, 2025
f569fa8
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 22, 2025
ad855f2
fix lint check issue
1pikachu Dec 22, 2025
9963009
fix conflicts
1pikachu Dec 24, 2025
7aeb46f
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 26, 2025
bab4582
fix conflict
1pikachu Dec 29, 2025
89a71e3
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Dec 30, 2025
6d42b84
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Jan 3, 2026
f62d650
Merge pull request #21 from gaopengff/gaopengf/add_more_xpu_cases
1pikachu Jan 6, 2026
e7d6d53
fix conflict
1pikachu Jan 6, 2026
8df8a69
fix conflict
1pikachu Jan 9, 2026
a07633e
fix conflict
1pikachu Jan 13, 2026
1c62bba
fix conflict
1pikachu Jan 16, 2026
41dfd05
fix conflict
1pikachu Jan 20, 2026
38916b7
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Jan 21, 2026
f48ee60
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Jan 22, 2026
5da947b
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Jan 23, 2026
00ccb81
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Jan 28, 2026
be731eb
Merge branch 'main' into molly/ut_enabling_xpu
1pikachu Jan 30, 2026
1bd4d2e
fall back to automatic device detection
1pikachu Jan 30, 2026
1e216fb
fix conflict
1pikachu Feb 2, 2026
4 changes: 3 additions & 1 deletion python/sglang/srt/layers/moe/topk.py
@@ -57,6 +57,7 @@
is_cuda,
is_hip,
is_npu,
is_xpu,
)
from sglang.srt.utils.patch_torch import register_fake_if_exists

@@ -69,6 +70,7 @@
_is_hip = is_hip()
_is_cpu = is_cpu()
_is_cpu_amx_available = cpu_has_amx_support()
_is_xpu = is_xpu()
_is_npu = is_npu()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

@@ -85,7 +87,7 @@
except ImportError as e:
pass

if _is_cuda or _is_hip:
if _is_cuda or _is_hip or _is_xpu:
from sgl_kernel import topk_softmax

try:
30 changes: 17 additions & 13 deletions python/sglang/test/runners.py
@@ -32,7 +32,7 @@

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.model_loader.ci_weight_validation import ci_validate_and_clean_hf_cache
from sglang.srt.utils import is_npu, load_image
from sglang.srt.utils import get_device, is_npu, load_image
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

@@ -122,7 +122,7 @@ def _get_sentence_transformer_embedding_model(
modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim
)

return model.cuda()
return model.to(get_device())


@dataclass
@@ -271,18 +271,18 @@ def start_model_process(
torch_dtype=torch_dtype,
trust_remote_code=self.trust_remote_code,
low_cpu_mem_usage=True,
).cuda()
).to(get_device())
elif self.model_type == "embedding":
if "gme-qwen2-vl" in model_path.lower():
self.model = AutoModelForVision2Seq.from_pretrained(
model_path,
torch_dtype=torch_dtype,
trust_remote_code=False,
low_cpu_mem_usage=True,
).cuda()
).to(get_device())
self.processor = AutoProcessor.from_pretrained(model_path)
elif "clip" in model_path.lower():
self.model = AutoModel.from_pretrained(model_path).cuda()
self.model = AutoModel.from_pretrained(model_path).to(get_device())
self.processor = AutoProcessor.from_pretrained(model_path)
else:
self.model = _get_sentence_transformer_embedding_model(
@@ -295,7 +295,7 @@
model_path,
torch_dtype=torch_dtype,
trust_remote_code=self.needs_trust_remote_code(model_path),
).cuda()
).to(get_device())
else:
raise Exception(f"Unrecognized model type {self.model_type}")
self.tokenizer = get_tokenizer(
@@ -338,23 +338,27 @@ def start_model_process(
images=image[0], return_tensors="pt"
)
logits = self.model.get_image_features(
pixel_values=inputs.data["pixel_values"].cuda(),
pixel_values=inputs.data["pixel_values"].to(
get_device()
),
).tolist()
else:
inputs = self.tokenizer(
prompts, padding=True, return_tensors="pt"
)
logits = self.model.get_text_features(
input_ids=inputs.data["input_ids"].cuda(),
attention_mask=inputs.data["attention_mask"].cuda(),
input_ids=inputs.data["input_ids"].to(get_device()),
attention_mask=inputs.data["attention_mask"].to(
get_device()
),
).tolist()
else:
logits = self.model.encode(prompts).tolist()
out_queue.put(ModelOutput(embed_logits=logits))
elif self.model_type == "cross_encoder":
inputs = self.tokenizer(
prompts, padding=True, return_tensors="pt"
).to("cuda")
).to(get_device())
scores = self.model(**inputs).logits
scores = scores.squeeze().tolist()
if not isinstance(scores, list):
@@ -369,7 +373,7 @@
)
conv_tokenized = self.tokenizer(
conv_formatted, return_tensors="pt"
).to("cuda")
).to(get_device())
scores.append(
float(self.model(**conv_tokenized).logits[0][0].item())
)
@@ -426,9 +430,9 @@ def forward_generation_raw(

for i, p in enumerate(prompts):
if isinstance(p, str):
input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
input_ids = tokenizer.encode(p, return_tensors="pt").to(get_device())
else:
input_ids = torch.tensor([p], device="cuda")
input_ids = torch.tensor([p], device=get_device())

if lora_paths is not None and lora_paths[i] is not None:
from peft import PeftModel
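Throughout runners.py the hard-coded .cuda() calls become .to(get_device()). For readers unfamiliar with the helper, here is a stand-in sketch of the selection logic these call sites assume; the real sglang.srt.utils.get_device may differ, and the preference order shown (CUDA/ROCm, then XPU, then CPU) is an assumption for illustration only.

```python
import torch


def get_device_sketch() -> str:
    """Illustrative stand-in for a device-selection helper (not the sglang implementation)."""
    if torch.cuda.is_available():  # covers CUDA, and ROCm builds that expose HIP via torch.cuda
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel XPU
        return "xpu"
    return "cpu"


# Usage mirroring the diff: device-agnostic placement instead of hard-coding .cuda()
# model = model.to(get_device_sketch())
```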
45 changes: 40 additions & 5 deletions python/sglang/test/test_utils.py
@@ -37,7 +37,9 @@
from sglang.srt.utils import (
get_bool_env_var,
get_device,
is_cuda,
is_port_available,
is_xpu,
kill_process_tree,
retry,
)
@@ -1474,11 +1476,6 @@ def run_bench_one_batch(model, other_args):
device: Device type ("auto", "cuda", "rocm" or "cpu").
If "auto", will detect available platforms automatically.
"""
# Auto-detect device if needed
Collaborator:

Why is this removed?

Contributor:

The main issue is here:

except (RuntimeError, ImportError) as e:

This change here may not be ideal, but I think we should not fall back to CPU and should raise an error directly.
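A minimal sketch of the strict behavior argued for here: detect an accelerator for --device, but raise instead of silently falling back to CPU when none is found. The function name and error text are illustrative, not the PR's actual helper (the removed auto_config_device lines appear just below).

```python
import torch


def auto_config_device_strict() -> str:
    """Hypothetical strict variant of device auto-detection: never fall back to CPU silently."""
    if torch.cuda.is_available():  # CUDA, or ROCm builds routed through torch.cuda
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel XPU
        return "xpu"
    raise RuntimeError(
        "No accelerator (CUDA/ROCm/XPU) detected; refusing to fall back to CPU."
    )
```

Failing loudly here means a misconfigured CI runner surfaces as an error rather than a silently slow CPU run.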


device = auto_config_device()
print(f"Auto-configed device: {device}", flush=True)
other_args += ["--device", str(device)]

command = [
"python3",
@@ -2243,6 +2240,44 @@ def wrapper(self):
return decorator


def get_gpu_count():
if get_device() == "cpu":
gpu_count = 0
else:
gpu_count = torch.accelerator.device_count()
return gpu_count


def empty_gpu_cache():
"""
Unified empty_cache for PyTorch 2.8 (no torch.accelerator)
and PyTorch 2.9+ (where torch.accelerator.empty_cache() exists).
"""
if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "empty_cache"):
return torch.accelerator.empty_cache()

# CUDA
if hasattr(torch, "cuda") and torch.cuda.is_available():
torch.cuda.empty_cache()
Collaborator, on lines 2256 to 2266:

Do we need to have ROCm here?

And also it needs a final else.

return

# XPU (Intel)
if hasattr(torch, "xpu") and torch.xpu.is_available():
torch.xpu.empty_cache()
return

return


def get_gpu_memory_gb():
if is_cuda():
return torch.cuda.device_memory_used() / 1024**3
elif is_xpu():
return torch.xpu.memory_allocated() / 1024**3
else:
return 0


def run_doctests(obj: Callable[..., Any] | ModuleType):
mod = inspect.getmodule(obj)
globals = dict(mod.__dict__)
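On the empty_gpu_cache helper added above, the reviewer asks whether a ROCm branch is needed and notes the chain lacks a final else. A minimal sketch of one way to address both points (illustrative only, not part of this PR): ROCm builds of PyTorch expose the HIP device through the torch.cuda namespace, so the existing CUDA branch already covers ROCm, and the final branch can be an explicit no-op for CPU-only runs.

```python
import logging

import torch

logger = logging.getLogger(__name__)


def empty_gpu_cache_sketch() -> None:
    """Sketch of a cache-release helper with an explicit final branch."""
    # On PyTorch versions that provide it, torch.accelerator.empty_cache() picks the active backend.
    if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "empty_cache"):
        torch.accelerator.empty_cache()
        return

    # CUDA and ROCm: ROCm builds route HIP through torch.cuda, so no separate ROCm branch is needed.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        return

    # Intel XPU.
    if hasattr(torch, "xpu") and torch.xpu.is_available():
        torch.xpu.empty_cache()
        return

    # Final else: nothing to release on CPU-only setups.
    logger.debug("No accelerator cache to empty; running on CPU.")
```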
5 changes: 4 additions & 1 deletion test/manual/test_expert_location_updater.py
@@ -10,6 +10,7 @@
from torch.multiprocessing import Process

from sglang.srt.eplb import expert_location_updater
from sglang.srt.utils import get_device
from sglang.test.test_utils import CustomTestCase, find_available_port
from sglang.utils import is_in_ci

@@ -61,7 +62,7 @@ def test_cpu_slow(self):
def test_gpu(self):
if is_in_ci():
return
self._test_common(device="cuda")
self._test_common(device=get_device())

def _test_common(self, device):
infos = []
@@ -135,6 +136,8 @@ def _run_subprocess(
)
if device == "cuda":
torch.cuda.set_device(f"cuda:{rank}")
if device == "xpu":
torch.xpu.set_device(f"xpu:{rank}")

for info in infos:
_execute_test(info, rank=rank, num_gpus=num_gpus, device=device)
3 changes: 2 additions & 1 deletion test/manual/test_forward_split_prefill.py
@@ -20,6 +20,7 @@
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_device
from sglang.srt.utils.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase

@@ -32,7 +33,7 @@ def setUpClass(cls):
"""Set up the test environment once for all tests."""
cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
cls.tp_size = 1
cls.device = "cuda"
cls.device = get_device()

# Initialize server args
cls.server_args = ServerArgs(
12 changes: 7 additions & 5 deletions test/manual/test_get_weights_by_name.py
@@ -3,16 +3,18 @@

import numpy as np
import requests
import torch
from transformers import AutoModelForCausalLM

import sglang as sgl
from sglang.srt.utils import get_device
from sglang.test.test_utils import (
DEFAULT_MODEL_NAME_FOR_TEST,
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
empty_gpu_cache,
get_gpu_count,
is_in_ci,
popen_launch_server,
)
@@ -32,7 +34,7 @@ class TestGetWeightsByName(CustomTestCase):
def init_hf_model(self, model_name, tie_word_embeddings):
self.hf_model = AutoModelForCausalLM.from_pretrained(
model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings
).to("cuda:0")
).to(get_device())

def init_backend(self, backend, dp, tp, model_name):
self.backend = backend
@@ -61,7 +63,7 @@ def init_backend(self, backend, dp, tp, model_name):
def clean_up(self):
del self.hf_model
gc.collect()
torch.cuda.empty_cache()
empty_gpu_cache()
if self.backend == "Engine":
self.engine.shutdown()
else:
@@ -132,11 +134,11 @@ def test_get_weights_by_name(self):
("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST),
]
if torch.cuda.device_count() >= 2:
if get_gpu_count() >= 2:
test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST))
test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST))

if torch.cuda.device_count() >= 4:
if get_gpu_count() >= 4:
test_suits.extend(
[
("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST),
27 changes: 17 additions & 10 deletions test/manual/test_triton_moe_wna16.py
@@ -7,6 +7,7 @@
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import TopKConfig, select_experts
from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler
from sglang.srt.utils import get_device

NUM_EXPERTS = [8, 64]
TOP_KS = [2, 6]
@@ -159,10 +160,10 @@ def test_fused_moe_wn16(
weight_bits: int,
):
print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits)
a = torch.randn((m, k), device="cuda", dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10
w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10
score = torch.randn((m, e), device="cuda", dtype=dtype)
a = torch.randn((m, k), device=get_device(), dtype=dtype) / 10
w1 = torch.randn((e, 2 * n, k), device=get_device(), dtype=dtype) / 10
w2 = torch.randn((e, k, n), device=get_device(), dtype=dtype) / 10
score = torch.randn((m, e), device=get_device(), dtype=dtype)

if weight_bits == 4:
pack_factor = 2
@@ -174,16 +175,22 @@
w1_ref = w1.clone()
w2_ref = w2.clone()
w1_qweight = torch.empty(
(e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8
(e, 2 * n, k // pack_factor), device=get_device(), dtype=torch.uint8
)
w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8)
w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype)
w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype)
w2_qweight = torch.empty(
(e, k, n // pack_factor), device=get_device(), dtype=torch.uint8
)
w1_scales = torch.empty(
(e, 2 * n, k // group_size), device=get_device(), dtype=dtype
)
w2_scales = torch.empty((e, k, n // group_size), device=get_device(), dtype=dtype)
w1_qzeros = torch.empty(
(e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8
(e, 2 * n // pack_factor, k // group_size),
device=get_device(),
dtype=torch.uint8,
)
w2_qzeros = torch.empty(
(e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8
(e, k // pack_factor, n // group_size), device=get_device(), dtype=torch.uint8
)

for i in range(e * 2):
17 changes: 9 additions & 8 deletions test/registered/attention/test_create_kvindices.py
@@ -4,6 +4,7 @@
import torch

from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.utils import get_device
from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci
from sglang.test.test_utils import CustomTestCase

@@ -15,30 +16,28 @@
class TestCreateKvIndices(CustomTestCase):
@classmethod
def setUpClass(cls):
if not torch.cuda.is_available():
raise unittest.SkipTest("CUDA is not available")
Contributor: ditto

torch.set_default_device("cuda")
torch.set_default_device(get_device())

def _run_test(self, batch, max_batch, max_context_len):
req_to_token = torch.arange(
max_batch * max_context_len, dtype=torch.int32, device="cuda"
max_batch * max_context_len, dtype=torch.int32, device=get_device()
).reshape((max_batch, max_context_len))
req_pool_indices = torch.tensor(
torch.from_numpy(
np.random.choice(range(max_batch), size=batch, replace=False)
),
dtype=torch.int32,
device="cuda",
device=get_device(),
)
paged_kernel_lens = torch.tensor(
torch.from_numpy(
np.random.choice(range(max_context_len), size=batch, replace=False)
),
dtype=torch.int32,
device="cuda",
device=get_device(),
)

kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda")
kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device=get_device())
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# ref
@@ -53,7 +52,9 @@ def _run_test(self, batch, max_batch, max_context_len):
).contiguous()

# triton
kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda")
kv_indices_triton = torch.empty(
kv_indptr[-1], dtype=torch.int32, device=get_device()
)
create_flashinfer_kv_indices_triton[(batch,)](
req_to_token,
req_pool_indices,