diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index c1ee9e46178b..255c2e5912cd 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -57,6 +57,7 @@ is_cuda, is_hip, is_npu, + is_xpu, ) from sglang.srt.utils.patch_torch import register_fake_if_exists @@ -69,6 +70,7 @@ _is_hip = is_hip() _is_cpu = is_cpu() _is_cpu_amx_available = cpu_has_amx_support() +_is_xpu = is_xpu() _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip @@ -85,7 +87,7 @@ except ImportError as e: pass -if _is_cuda or _is_hip: +if _is_cuda or _is_hip or _is_xpu: from sgl_kernel import topk_softmax try: diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index ebc9912da41a..95190904d82a 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -32,7 +32,7 @@ from sglang.srt.entrypoints.engine import Engine from sglang.srt.model_loader.ci_weight_validation import ci_validate_and_clean_hf_cache -from sglang.srt.utils import is_npu, load_image +from sglang.srt.utils import get_device, is_npu, load_image from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l @@ -122,7 +122,7 @@ def _get_sentence_transformer_embedding_model( modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim ) - return model.cuda() + return model.to(get_device()) @dataclass @@ -271,7 +271,7 @@ def start_model_process( torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code, low_cpu_mem_usage=True, - ).cuda() + ).to(get_device()) elif self.model_type == "embedding": if "gme-qwen2-vl" in model_path.lower(): self.model = AutoModelForVision2Seq.from_pretrained( @@ -279,10 +279,10 @@ def start_model_process( torch_dtype=torch_dtype, trust_remote_code=False, low_cpu_mem_usage=True, - ).cuda() + ).to(get_device()) self.processor = AutoProcessor.from_pretrained(model_path) elif "clip" in model_path.lower(): - self.model = AutoModel.from_pretrained(model_path).cuda() + self.model = AutoModel.from_pretrained(model_path).to(get_device()) self.processor = AutoProcessor.from_pretrained(model_path) else: self.model = _get_sentence_transformer_embedding_model( @@ -295,7 +295,7 @@ def start_model_process( model_path, torch_dtype=torch_dtype, trust_remote_code=self.needs_trust_remote_code(model_path), - ).cuda() + ).to(get_device()) else: raise Exception(f"Unrecognized model type {self.model_type}") self.tokenizer = get_tokenizer( @@ -338,15 +338,19 @@ def start_model_process( images=image[0], return_tensors="pt" ) logits = self.model.get_image_features( - pixel_values=inputs.data["pixel_values"].cuda(), + pixel_values=inputs.data["pixel_values"].to( + get_device() + ), ).tolist() else: inputs = self.tokenizer( prompts, padding=True, return_tensors="pt" ) logits = self.model.get_text_features( - input_ids=inputs.data["input_ids"].cuda(), - attention_mask=inputs.data["attention_mask"].cuda(), + input_ids=inputs.data["input_ids"].to(get_device()), + attention_mask=inputs.data["attention_mask"].to( + get_device() + ), ).tolist() else: logits = self.model.encode(prompts).tolist() @@ -354,7 +358,7 @@ def start_model_process( elif self.model_type == "cross_encoder": inputs = self.tokenizer( prompts, padding=True, return_tensors="pt" - ).to("cuda") + ).to(get_device()) scores = self.model(**inputs).logits scores = scores.squeeze().tolist() if not isinstance(scores, list): @@ -369,7 +373,7 @@ def 
start_model_process( ) conv_tokenized = self.tokenizer( conv_formatted, return_tensors="pt" - ).to("cuda") + ).to(get_device()) scores.append( float(self.model(**conv_tokenized).logits[0][0].item()) ) @@ -426,9 +430,9 @@ def forward_generation_raw( for i, p in enumerate(prompts): if isinstance(p, str): - input_ids = tokenizer.encode(p, return_tensors="pt").cuda() + input_ids = tokenizer.encode(p, return_tensors="pt").to(get_device()) else: - input_ids = torch.tensor([p], device="cuda") + input_ids = torch.tensor([p], device=get_device()) if lora_paths is not None and lora_paths[i] is not None: from peft import PeftModel diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 06550ff037c2..7be56b3783a8 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -37,7 +37,9 @@ from sglang.srt.utils import ( get_bool_env_var, get_device, + is_cuda, is_port_available, + is_xpu, kill_process_tree, retry, ) @@ -2243,6 +2245,44 @@ def wrapper(self): return decorator +def get_gpu_count(): + if get_device() == "cpu": + gpu_count = 0 + else: + gpu_count = torch.accelerator.device_count() + return gpu_count + + +def empty_gpu_cache(): + """ + Unified empty_cache for PyTorch 2.8 (no torch.accelerator.empty_cache()) + and PyTorch 2.9+ (where torch.accelerator.empty_cache() exists). + """ + if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "empty_cache"): + return torch.accelerator.empty_cache() + + # CUDA + if hasattr(torch, "cuda") and torch.cuda.is_available(): + torch.cuda.empty_cache() + return + + # XPU (Intel) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() + return + + return + + +def get_gpu_memory_gb(): + if is_cuda(): + return torch.cuda.device_memory_used() / 1024**3 + elif is_xpu(): + return torch.xpu.memory_allocated() / 1024**3 + else: + return 0 + + def run_doctests(obj: Callable[..., Any] | ModuleType): mod = inspect.getmodule(obj) globals = dict(mod.__dict__) diff --git a/test/manual/test_expert_location_updater.py b/test/manual/test_expert_location_updater.py index 094540294dbe..513205e72ff1 100644 --- a/test/manual/test_expert_location_updater.py +++ b/test/manual/test_expert_location_updater.py @@ -10,6 +10,7 @@ from torch.multiprocessing import Process from sglang.srt.eplb import expert_location_updater +from sglang.srt.utils import get_device from sglang.test.test_utils import CustomTestCase, find_available_port from sglang.utils import is_in_ci @@ -61,7 +62,7 @@ def test_cpu_slow(self): def test_gpu(self): if is_in_ci(): return - self._test_common(device="cuda") + self._test_common(device=get_device()) def _test_common(self, device): infos = [] @@ -135,6 +136,8 @@ def _run_subprocess( ) if device == "cuda": torch.cuda.set_device(f"cuda:{rank}") + if device == "xpu": + torch.xpu.set_device(f"xpu:{rank}") for info in infos: _execute_test(info, rank=rank, num_gpus=num_gpus, device=device) diff --git a/test/manual/test_forward_split_prefill.py b/test/manual/test_forward_split_prefill.py index 66e3262badb5..7c23f4f14306 100644 --- a/test/manual/test_forward_split_prefill.py +++ b/test/manual/test_forward_split_prefill.py @@ -20,6 +20,7 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.utils import get_device from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
CustomTestCase @@ -32,7 +33,7 @@ def setUpClass(cls): """Set up the test environment once for all tests.""" cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.tp_size = 1 - cls.device = "cuda" + cls.device = get_device() # Initialize server args cls.server_args = ServerArgs( diff --git a/test/manual/test_get_weights_by_name.py b/test/manual/test_get_weights_by_name.py index 3d404df10a72..fa97c7df8070 100644 --- a/test/manual/test_get_weights_by_name.py +++ b/test/manual/test_get_weights_by_name.py @@ -3,16 +3,18 @@ import numpy as np import requests -import torch from transformers import AutoModelForCausalLM import sglang as sgl +from sglang.srt.utils import get_device from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + empty_gpu_cache, + get_gpu_count, is_in_ci, popen_launch_server, ) @@ -32,7 +34,7 @@ class TestGetWeightsByName(CustomTestCase): def init_hf_model(self, model_name, tie_word_embeddings): self.hf_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings - ).to("cuda:0") + ).to(get_device()) def init_backend(self, backend, dp, tp, model_name): self.backend = backend @@ -61,7 +63,7 @@ def init_backend(self, backend, dp, tp, model_name): def clean_up(self): del self.hf_model gc.collect() - torch.cuda.empty_cache() + empty_gpu_cache() if self.backend == "Engine": self.engine.shutdown() else: @@ -132,11 +134,11 @@ def test_get_weights_by_name(self): ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST), ] - if torch.cuda.device_count() >= 2: + if get_gpu_count() >= 2: test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST)) test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST)) - if torch.cuda.device_count() >= 4: + if get_gpu_count() >= 4: test_suits.extend( [ ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), diff --git a/test/manual/test_triton_moe_wna16.py b/test/manual/test_triton_moe_wna16.py index a7e4a3a89382..35983a04c240 100644 --- a/test/manual/test_triton_moe_wna16.py +++ b/test/manual/test_triton_moe_wna16.py @@ -7,6 +7,7 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from sglang.srt.utils import get_device NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -159,10 +160,10 @@ def test_fused_moe_wn16( weight_bits: int, ): print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + a = torch.randn((m, k), device=get_device(), dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=get_device(), dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=get_device(), dtype=dtype) / 10 + score = torch.randn((m, e), device=get_device(), dtype=dtype) if weight_bits == 4: pack_factor = 2 @@ -174,16 +175,22 @@ def test_fused_moe_wn16( w1_ref = w1.clone() w2_ref = w2.clone() w1_qweight = torch.empty( - (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 + (e, 2 * n, k // pack_factor), device=get_device(), dtype=torch.uint8 ) - w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", 
dtype=torch.uint8) - w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype) - w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) + w2_qweight = torch.empty( + (e, k, n // pack_factor), device=get_device(), dtype=torch.uint8 + ) + w1_scales = torch.empty( + (e, 2 * n, k // group_size), device=get_device(), dtype=dtype + ) + w2_scales = torch.empty((e, k, n // group_size), device=get_device(), dtype=dtype) w1_qzeros = torch.empty( - (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 + (e, 2 * n // pack_factor, k // group_size), + device=get_device(), + dtype=torch.uint8, ) w2_qzeros = torch.empty( - (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 + (e, k // pack_factor, n // group_size), device=get_device(), dtype=torch.uint8 ) for i in range(e * 2): diff --git a/test/registered/attention/test_create_kvindices.py b/test/registered/attention/test_create_kvindices.py index 881e68d6e1f5..3117abe9679b 100644 --- a/test/registered/attention/test_create_kvindices.py +++ b/test/registered/attention/test_create_kvindices.py @@ -4,6 +4,7 @@ import torch from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.test_utils import CustomTestCase @@ -15,30 +16,28 @@ class TestCreateKvIndices(CustomTestCase): @classmethod def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is not available") - torch.set_default_device("cuda") + torch.set_default_device(get_device()) def _run_test(self, batch, max_batch, max_context_len): req_to_token = torch.arange( - max_batch * max_context_len, dtype=torch.int32, device="cuda" + max_batch * max_context_len, dtype=torch.int32, device=get_device() ).reshape((max_batch, max_context_len)) req_pool_indices = torch.tensor( torch.from_numpy( np.random.choice(range(max_batch), size=batch, replace=False) ), dtype=torch.int32, - device="cuda", + device=get_device(), ) paged_kernel_lens = torch.tensor( torch.from_numpy( np.random.choice(range(max_context_len), size=batch, replace=False) ), dtype=torch.int32, - device="cuda", + device=get_device(), ) - kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda") + kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device=get_device()) kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) # ref @@ -53,7 +52,9 @@ def _run_test(self, batch, max_batch, max_context_len): ).contiguous() # triton - kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + kv_indices_triton = torch.empty( + kv_indptr[-1], dtype=torch.int32, device=get_device() + ) create_flashinfer_kv_indices_triton[(batch,)]( req_to_token, req_pool_indices, diff --git a/test/registered/attention/test_wave_attention_kernels.py b/test/registered/attention/test_wave_attention_kernels.py index f7cd5c3b32f9..fbd3470487f3 100644 --- a/test/registered/attention/test_wave_attention_kernels.py +++ b/test/registered/attention/test_wave_attention_kernels.py @@ -21,6 +21,7 @@ from sglang.srt.layers.attention.wave_ops.prefill_attention import ( prefill_attention_wave, ) +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci # Wave attention kernel unit tests (AMD only - requires wave_lang) @@ -47,24 +48,24 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): extend_seq_len = 1024 b_seq_len_prefix = 
torch.full( - (B,), N_CTX // B, dtype=torch.int32, device="cuda" + (B,), N_CTX // B, dtype=torch.int32, device=get_device() ) b_seq_len_extend = torch.full( - (B,), extend_seq_len, dtype=torch.int32, device="cuda" + (B,), extend_seq_len, dtype=torch.int32, device=get_device() ) b_seq_len = b_seq_len_prefix + b_seq_len_extend max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") - b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_req_idx = torch.arange(B, dtype=torch.int32, device=get_device()) + b_start_loc = torch.zeros((B,), dtype=torch.int32, device=get_device()) b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) - b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=get_device()) b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=get_device()) kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) kv_indices = torch.zeros( - (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda" + (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=get_device() ) for i in range(B): @@ -75,15 +76,21 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): total_token_num = torch.sum(b_seq_len).item() extend_token_num = torch.sum(b_seq_len_extend).item() k_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" + (total_token_num, H_KV, D), dtype=dtype, device=get_device() ).normal_(mean=0.1, std=0.2) v_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" + (total_token_num, H_KV, D), dtype=dtype, device=get_device() ).normal_(mean=0.1, std=0.2) - k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + k_extend = torch.empty( + (extend_token_num, H_KV, D), dtype=dtype, device=get_device() + ) + v_extend = torch.empty( + (extend_token_num, H_KV, D), dtype=dtype, device=get_device() + ) + q_extend = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() + ) for i in range(B): extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] @@ -96,20 +103,22 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): extend_start_in_buffer:extend_end_in_buffer ] q_extend[extend_start:extend_end] = torch.empty( - (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=get_device() ).normal_(mean=0.1, std=0.2) - o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_extend = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() + ) o_extend_mask = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) o_redundant = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) b_seq_len_extend = b_seq_len - b_seq_len_prefix max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() - qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + qo_indptr = torch.zeros((B + 1,), 
dtype=torch.int32, device=get_device()) qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) custom_mask = None @@ -129,7 +138,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): is_causal = True - o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_extend = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() + ) extend_attention_fwd( q_extend, k_extend, @@ -146,7 +157,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): max_len_extend, ) - o_wave = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_wave = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() + ) extend_attention_wave( q_extend, k_extend, @@ -181,33 +194,37 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) max_kv_splits = 8 - num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device="cuda") + num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=get_device()) # q represents the new token being generated, one per batch - q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + q = torch.randn(B, H_Q, D, dtype=dtype, device=get_device()) # k_buffer and v_buffer represent all previous tokens - k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") - v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=get_device()) + v_buffer = torch.randn( + total_tokens, H_KV, D_V, dtype=dtype, device=get_device() + ) # o will have the same shape as q - o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device=get_device()) + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=get_device()) - req_to_token = torch.arange(total_tokens, device="cuda", dtype=torch.int32) - b_req_idx = torch.zeros(B + 1, device="cuda", dtype=torch.int32) - b_seq_len = torch.full((B,), seq_len, device="cuda", dtype=torch.int32) + req_to_token = torch.arange( + total_tokens, device=get_device(), dtype=torch.int32 + ) + b_req_idx = torch.zeros(B + 1, device=get_device(), dtype=torch.int32) + b_seq_len = torch.full((B,), seq_len, device=get_device(), dtype=torch.int32) b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0) attn_logits = torch.empty( (B, H_Q, max_kv_splits, D_V + 1), dtype=torch.float32, - device="cuda", + device=get_device(), ) attn_lse = torch.empty( (B, H_Q, max_kv_splits), dtype=torch.float32, - device="cuda", + device=get_device(), ) logit_cap = 0.0 @@ -233,13 +250,13 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): attn_logits = torch.empty( attn_logits_shape, dtype=torch.float32, - device="cuda", + device=get_device(), ) attn_logits_max = torch.empty( attn_logits_max_shape, dtype=torch.float32, - device="cuda", + device=get_device(), ) decode_attention_wave( @@ -288,17 +305,25 @@ def _test_context_attention_once(self, head_dim, is_causal): max_seq_len = max(seq_lens) # Create random input tensors - q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") - k = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda") - v = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda") + q = torch.randn( + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=get_device() + ) + k = torch.randn( + sum(seq_lens), 
kv_heads, head_dim, dtype=dtype, device=get_device() + ) + v = torch.randn( + sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=get_device() + ) o_triton = torch.zeros( - sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda" + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=get_device() + ) + o = torch.zeros( + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=get_device() ) - o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") # Create b_start_loc and b_seq_len tensors - b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") - b_seq_len = torch.tensor(seq_lens, device="cuda") + b_start_loc = torch.tensor([0, seq_lens[0]], device=get_device()) + b_seq_len = torch.tensor(seq_lens, device=get_device()) context_attention_fwd( q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal diff --git a/test/registered/core/test_hidden_states.py b/test/registered/core/test_hidden_states.py index 5ddbf17c814a..4bbdf828aedd 100644 --- a/test/registered/core/test_hidden_states.py +++ b/test/registered/core/test_hidden_states.py @@ -4,7 +4,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import sglang as sgl -from sglang.srt.utils import is_hip +from sglang.srt.utils import get_device, is_hip from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase @@ -57,7 +57,7 @@ def test_return_hidden_states(self): ) model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.bfloat16, device_map="cuda" + model_path, torch_dtype=torch.bfloat16, device_map=get_device() ) for input_id, output in zip(input_ids, outputs): @@ -75,7 +75,7 @@ def test_return_hidden_states(self): i.unsqueeze(0) if len(i.shape) == 1 else i for i in output["meta_info"]["hidden_states"] ] - ).to("cuda") + ).to(get_device()) print("=== SRT Hiddens ===") print(sg_hidden_states) diff --git a/test/registered/layers/mamba/test_causal_conv1d.py b/test/registered/layers/mamba/test_causal_conv1d.py index 953ba4488060..c777226eae00 100644 --- a/test/registered/layers/mamba/test_causal_conv1d.py +++ b/test/registered/layers/mamba/test_causal_conv1d.py @@ -18,6 +18,8 @@ causal_conv1d_fn, causal_conv1d_update, ) +from sglang.srt.utils import get_device +from sglang.test.test_utils import empty_gpu_cache def causal_conv1d_ref( @@ -154,10 +156,8 @@ def causal_conv1d_opcheck_fn( @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): - if not torch.cuda.is_available(): - pytest.skip("CUDA device not available") - device = "cuda" + device = get_device() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -193,10 +193,8 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity def test_causal_conv1d_update_with_batch_gather( batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype ): - if not torch.cuda.is_available(): - pytest.skip("CUDA device not available") - device = "cuda" + device = get_device() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -273,11 +271,9 @@ def test_causal_conv1d_update_with_batch_gather( def test_causal_conv1d_varlen( batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype ): - if not 
torch.cuda.is_available(): - pytest.skip("CUDA device not available") - device = "cuda" - torch.cuda.empty_cache() + device = get_device() + empty_gpu_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -336,7 +332,7 @@ def test_causal_conv1d_varlen( weight, bias=bias, conv_states=final_states, - query_start_loc=cumsum.cuda(), + query_start_loc=cumsum.to(get_device()), seq_lens_cpu=torch.tensor(seqlens[0]), cache_indices=padded_state_indices, has_initial_state=has_initial_states, diff --git a/test/registered/model_loading/test_modelopt_loader.py b/test/registered/model_loading/test_modelopt_loader.py index 96aaffa95845..606ddd2f06fc 100644 --- a/test/registered/model_loading/test_modelopt_loader.py +++ b/test/registered/model_loading/test_modelopt_loader.py @@ -15,6 +15,7 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES from sglang.srt.model_loader.loader import ModelOptModelLoader +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import CustomTestCase @@ -62,7 +63,7 @@ def setUp(self): self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" self.load_config = LoadConfig() - self.device_config = DeviceConfig(device="cuda") + self.device_config = DeviceConfig(device=get_device()) # Create a basic model config with unified quantization flag self.model_config = ModelConfig( diff --git a/test/registered/moe/test_fused_moe.py b/test/registered/moe/test_fused_moe.py index 3921ce1a8b62..a5f6d9234653 100644 --- a/test/registered/moe/test_fused_moe.py +++ b/test/registered/moe/test_fused_moe.py @@ -9,9 +9,9 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler -from sglang.srt.utils import is_hip +from sglang.srt.utils import get_device, get_device_capability, is_hip from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci -from sglang.test.test_utils import CustomTestCase +from sglang.test.test_utils import CustomTestCase, empty_gpu_cache register_cuda_ci(est_time=80, suite="stage-b-test-large-1-gpu") register_amd_ci(est_time=30, suite="stage-b-test-small-1-gpu-amd") @@ -25,8 +25,8 @@ class TestFusedMOE(CustomTestCase): TOP_KS = [2, 6] @staticmethod - def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): - """Create a random CUDA tensor + def create_random_gpu_tensor(shape, dtype, mean=0, std=0.01): + """Create a random tensor on the current device Args: shape: Tensor shape @@ -35,9 +35,9 @@ def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): std: Standard deviation Returns: - torch.Tensor: Randomly initialized CUDA tensor + torch.Tensor: Randomly initialized tensor on the current device """ - return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std) + return torch.empty(shape, dtype=dtype, device=get_device()).normal_(mean, std) def get_tolerance(self, dtype): """Get tolerance values for different data types @@ -109,20 +109,20 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 - capability = torch.cuda.get_device_capability() + capability = get_device_capability() if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)): return - a =
self.create_random_cuda_tensor((m, k), dtype) - w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) - w2 = self.create_random_cuda_tensor((e, k, n), dtype) + a = self.create_random_gpu_tensor((m, k), dtype) + w1 = self.create_random_gpu_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_gpu_tensor((e, k, n), dtype) w1 = w1.to(torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fn) - score = self.create_random_cuda_tensor((m, e), dtype) - w1_scale = self.create_random_cuda_tensor(e, torch.float32) - w2_scale = self.create_random_cuda_tensor(e, torch.float32) - a1_scale = self.create_random_cuda_tensor(1, torch.float32) - a2_scale = self.create_random_cuda_tensor(1, torch.float32) + score = self.create_random_gpu_tensor((m, e), dtype) + w1_scale = self.create_random_gpu_tensor(e, torch.float32) + w2_scale = self.create_random_gpu_tensor(e, torch.float32) + a1_scale = self.create_random_gpu_tensor(1, torch.float32) + a2_scale = self.create_random_gpu_tensor(1, torch.float32) # Handle HIP case: normalize float8 weights so fused kernel doesn't break # on ROCm. @@ -172,10 +172,10 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): sglang_output, torch_output, rtol=rtol, atol=atol ) else: - a = self.create_random_cuda_tensor((m, k), dtype) - w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) - w2 = self.create_random_cuda_tensor((e, k, n), dtype) - score = self.create_random_cuda_tensor((m, e), dtype) + a = self.create_random_gpu_tensor((m, k), dtype) + w1 = self.create_random_gpu_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_gpu_tensor((e, k, n), dtype) + score = self.create_random_gpu_tensor((m, e), dtype) topk_output = select_experts( hidden_states=a, @@ -236,7 +236,7 @@ def test_various_configurations(self): dtype, use_fp8_w8a8=use_fp8_w8a8, ) - torch.cuda.empty_cache() + empty_gpu_cache() pbar.update(1) diff --git a/test/registered/quant/test_awq_dequant.py b/test/registered/quant/test_awq_dequant.py index 18856aaf26d3..8f63a824364d 100644 --- a/test/registered/quant/test_awq_dequant.py +++ b/test/registered/quant/test_awq_dequant.py @@ -16,12 +16,13 @@ awq_dequantize_triton, awq_gemm_triton, ) +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci from sglang.test.test_utils import CustomTestCase register_amd_ci(est_time=2, suite="stage-a-test-1-amd") -device = "cuda" +device = get_device() def reverse_awq_order(t: torch.Tensor) -> torch.Tensor: diff --git a/test/registered/quant/test_gptqmodel_dynamic.py b/test/registered/quant/test_gptqmodel_dynamic.py index 7a52b9028b05..9d08cace6808 100644 --- a/test/registered/quant/test_gptqmodel_dynamic.py +++ b/test/registered/quant/test_gptqmodel_dynamic.py @@ -5,7 +5,7 @@ import torch from sglang.srt.server_args import set_global_server_args_for_scheduler -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import get_device, kill_process_tree from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, @@ -49,7 +49,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig() - device_config = DeviceConfig("cuda") + device_config = DeviceConfig(get_device()) model = get_model( model_config=model_config, load_config=load_config, device_config=device_config ) diff --git a/test/registered/radix_cache/test_mamba_unittest.py b/test/registered/radix_cache/test_mamba_unittest.py index 
4cc231095163..5109d23a0ff9 100644 --- a/test/registered/radix_cache/test_mamba_unittest.py +++ b/test/registered/radix_cache/test_mamba_unittest.py @@ -17,6 +17,7 @@ from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci register_cuda_ci(est_time=9, suite="stage-b-test-small-1-gpu") @@ -39,7 +40,7 @@ def test_hybrid_linear_kv_pool(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = get_device() full_attention_layer_ids = [ i for i in range(global_interval - 1, num_layers, global_interval) ] @@ -67,7 +68,7 @@ def test_mamba_pool(self): max_num_reqs = 10 mamba_cache_size = 20 max_context_len = 128 - device = "cuda" + device = get_device() global_interval = 4 num_layers = 48 full_attention_layer_ids = [ @@ -151,7 +152,7 @@ def test_mamba_radix_cache_1(self): max_num_reqs = 10 mamba_cache_size = 20 max_context_len = 128 - device = "cuda" + device = get_device() full_attention_layer_ids = [ i for i in range(global_interval - 1, num_layers, global_interval) ] diff --git a/test/registered/radix_cache/test_swa_unittest.py b/test/registered/radix_cache/test_swa_unittest.py index 24c5615de80e..626fda769549 100644 --- a/test/registered/radix_cache/test_swa_unittest.py +++ b/test/registered/radix_cache/test_swa_unittest.py @@ -13,6 +13,7 @@ from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_memory_pool import SWAKVPool, SWATokenToKVPoolAllocator from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci register_cuda_ci(est_time=8, suite="stage-b-test-large-1-gpu") @@ -37,7 +38,7 @@ def test_swa_memory_pool(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = get_device() full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] full_attention_layer_ids_set = set(full_attention_layer_ids) swa_attention_layer_ids = [ @@ -89,7 +90,7 @@ def test_swa_radix_cache_1(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = get_device() full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] full_attention_layer_ids_set = set(full_attention_layer_ids) swa_attention_layer_ids = [ @@ -243,7 +244,7 @@ def test_swa_radix_cache_eagle(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = get_device() full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] full_attention_layer_ids_set = set(full_attention_layer_ids) swa_attention_layer_ids = [ diff --git a/test/registered/rl/test_fp32_lm_head.py b/test/registered/rl/test_fp32_lm_head.py index 6a96e8844dde..5f19332e7c06 100644 --- a/test/registered/rl/test_fp32_lm_head.py +++ b/test/registered/rl/test_fp32_lm_head.py @@ -12,6 +12,7 @@ get_global_server_args, set_global_server_args_for_scheduler, ) +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci register_cuda_ci(est_time=9, suite="stage-b-test-small-1-gpu") @@ -19,7 +20,7 @@ class LMHeadStub(nn.Module): - def __init__(self, vocab, hidden, dtype, device="cuda"): + def __init__(self, vocab, hidden, dtype, device=get_device()): 
super().__init__() self.weight = nn.Parameter( torch.randn(vocab, hidden, dtype=dtype, device=device) @@ -36,8 +37,10 @@ def compute_dp_attention_metadata(self): ... class TestLMHeadFP32(unittest.TestCase): @classmethod def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("needs CUDA GPU") + if not torch.cuda.is_available() and not ( + hasattr(torch, "xpu") and torch.xpu.is_available() + ): + raise unittest.SkipTest("needs CUDA GPU or XPU") def _make_logprocessor(self, vocab_size, enable_fp32): set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) @@ -54,7 +57,7 @@ def _run_case( expected_a_dtype, expected_b_dtype, ): - device = "cuda" + device = get_device() BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128 hidden_state = torch.randn( BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device diff --git a/test/registered/rl/test_release_memory_occupation.py b/test/registered/rl/test_release_memory_occupation.py index 59bc070a6cb5..7a4f8725da1c 100644 --- a/test/registered/rl/test_release_memory_occupation.py +++ b/test/registered/rl/test_release_memory_occupation.py @@ -31,7 +31,6 @@ import time import unittest -import torch from transformers import AutoModelForCausalLM import sglang as sgl @@ -40,6 +39,7 @@ GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, ) +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_cuda_ci from sglang.test.test_utils import ( DEFAULT_HYBRID_MAMBA_MODEL_NAME_FOR_TEST, @@ -48,6 +48,9 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, CustomTestCase, + empty_gpu_cache, + get_gpu_count, + get_gpu_memory_gb, ) register_cuda_ci( @@ -60,10 +63,6 @@ _DEBUG_EXTRA = False -def get_gpu_memory_gb(): - return torch.cuda.device_memory_used() / 1024**3 - - class TestReleaseMemoryOccupation(CustomTestCase): def _setup_engine( self, @@ -120,9 +119,7 @@ def _test_initial_generation( def test_release_and_resume_occupation(self): # Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - assert ( - torch.cuda.device_count() >= 2 - ), "Need at least 2 GPUs for tensor parallel tests" + assert get_gpu_count() >= 2, "Need at least 2 GPUs for tensor parallel tests" for tp_size in [1, 2]: @@ -165,13 +162,13 @@ def test_release_and_resume_occupation(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device_map="cuda", + device_map=get_device(), ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) # destroy the hf model del hf_model_new - torch.cuda.empty_cache() + empty_gpu_cache() print("generate (#2)") outputs = engine.generate(params["prompt"], params["sampling_params"])[ @@ -232,7 +229,7 @@ def test_multi_stage_release_and_resume(self): model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST for tp_size in [1, 2]: - if tp_size == 2 and torch.cuda.device_count() < 2: + if tp_size == 2 and get_gpu_count() < 2: continue print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume") @@ -320,14 +317,14 @@ def test_multi_stage_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device_map="cuda", + device_map=get_device(), ) gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb() engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) # destroy the hf model 
del hf_model_new - torch.cuda.empty_cache() + empty_gpu_cache() engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE]) gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb() @@ -399,13 +396,13 @@ def test_moe_model_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device_map="cuda", + device_map=get_device(), ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) # destroy the hf model del hf_model_new - torch.cuda.empty_cache() + empty_gpu_cache() print("generate (#2)") outputs = engine.generate(params["prompt_moe"], params["sampling_params_moe"])[ @@ -463,7 +460,7 @@ def test_hybrid_mamba_model_release_and_resume(self): engine.update_weights_from_disk(model_name) # destroy the hf model - torch.cuda.empty_cache() + empty_gpu_cache() print("generate (#2)") outputs = engine.generate( diff --git a/test/registered/spec/utils/test_build_eagle_tree.py b/test/registered/spec/utils/test_build_eagle_tree.py index 72b0b4215cdd..4a60e8475bc9 100644 --- a/test/registered/spec/utils/test_build_eagle_tree.py +++ b/test/registered/spec/utils/test_build_eagle_tree.py @@ -6,6 +6,7 @@ build_tree_kernel_efficient, organize_draft_results, ) +from sglang.srt.utils import get_device from sglang.test.ci.ci_register import register_amd_ci, register_cuda_ci register_cuda_ci(est_time=3, suite="stage-b-test-small-1-gpu") @@ -17,7 +18,7 @@ class TestBuildEagleTree(unittest.TestCase): def test_build_tree_kernel_efficient(self): """Test the build_tree_kernel_efficient function with known inputs and expected outputs.""" - verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32) + verified_id = torch.tensor([29974, 13], device=get_device(), dtype=torch.int32) score_list = [ torch.tensor( [ @@ -25,7 +26,7 @@ def test_build_tree_kernel_efficient(self): [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]], ], dtype=torch.float32, - device="cuda", + device=get_device(), ), torch.tensor( [ @@ -43,7 +44,7 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device="cuda", + device=get_device(), ), torch.tensor( [ @@ -61,7 +62,7 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device="cuda", + device=get_device(), ), torch.tensor( [ @@ -79,14 +80,14 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device="cuda", + device=get_device(), ), ] token_list = [ torch.tensor( [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]], dtype=torch.int64, - device="cuda", + device=get_device(), ), torch.tensor( [ @@ -127,7 +128,7 @@ def test_build_tree_kernel_efficient(self): 259, ], ], - device="cuda", + device=get_device(), ), torch.tensor( [ @@ -168,7 +169,7 @@ def test_build_tree_kernel_efficient(self): 2186, ], ], - device="cuda", + device=get_device(), ), torch.tensor( [ @@ -209,7 +210,7 @@ def test_build_tree_kernel_efficient(self): 13, ], ], - device="cuda", + device=get_device(), ), ] parents_list = [