From 9f0ca6d4d1c5dbeff8ba5d90428d3c1f889c61d8 Mon Sep 17 00:00:00 2001 From: jundu Date: Fri, 21 Nov 2025 03:43:10 +0000 Subject: [PATCH 01/13] merge latest commit --- python/sglang/srt/layers/moe/topk.py | 4 +- python/sglang/test/runners.py | 31 +++--- python/sglang/test/test_utils.py | 34 ++++++ test/manual/test_expert_location_updater.py | 7 +- test/manual/test_forward_split_prefill.py | 5 +- test/manual/test_get_weights_by_name.py | 22 +++- test/manual/test_triton_moe_wna16.py | 29 +++-- .../attention/mamba/test_causal_conv1d.py | 20 ++-- test/srt/test_build_eagle_tree.py | 20 ++-- test/srt/test_create_kvindices.py | 15 ++- test/srt/test_fused_moe.py | 41 +++---- test/srt/test_gptqmodel_dynamic.py | 5 +- test/srt/test_hidden_states.py | 6 +- test/srt/test_mamba_unittest.py | 8 +- test/srt/test_modelopt_loader.py | 4 +- test/srt/test_release_memory_occupation.py | 25 +++-- test/srt/test_wave_attention_kernels.py | 101 +++++++++++------- 17 files changed, 242 insertions(+), 135 deletions(-) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 30b7cc5da496..819cda854ac3 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -52,6 +52,7 @@ is_cuda, is_hip, is_npu, + is_xpu, ) if TYPE_CHECKING: @@ -68,13 +69,14 @@ _is_hip = is_hip() _is_cpu = is_cpu() _is_cpu_amx_available = cpu_has_amx_support() +_is_xpu = is_xpu() _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip if _is_cuda: from sgl_kernel import kimi_k2_moe_fused_gate, moe_fused_gate -if _is_cuda or _is_hip: +if _is_cuda or _is_hip or _is_xpu: from sgl_kernel import topk_softmax if _use_aiter: try: diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e174eb0c2f39..9bdabde21954 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -35,6 +35,9 @@ from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + DEFAULT_PROMPTS = [ "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", "The capital of the United Kingdom is", @@ -114,7 +117,7 @@ def _get_sentence_transformer_embedding_model( modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim ) - return model.cuda() + return model.to(device_type) @dataclass @@ -259,7 +262,7 @@ def start_model_process( torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code, low_cpu_mem_usage=True, - ).cuda() + ).to(device_type) elif self.model_type == "embedding": if "gme-qwen2-vl" in model_path.lower(): self.model = AutoModelForVision2Seq.from_pretrained( @@ -267,10 +270,10 @@ def start_model_process( torch_dtype=torch_dtype, trust_remote_code=False, low_cpu_mem_usage=True, - ).cuda() + ).to(device_type) self.processor = AutoProcessor.from_pretrained(model_path) elif "clip" in model_path.lower(): - self.model = AutoModel.from_pretrained(model_path).cuda() + self.model = AutoModel.from_pretrained(model_path).to(device_type) self.processor = AutoProcessor.from_pretrained(model_path) else: self.model = _get_sentence_transformer_embedding_model( @@ -283,7 +286,7 @@ def start_model_process( model_path, torch_dtype=torch_dtype, trust_remote_code=self.needs_trust_remote_code(model_path), - ).cuda() + ).to(device_type) else: raise Exception(f"Unrecognized model type {self.model_type}") self.tokenizer = get_tokenizer( @@ -326,15 +329,19 @@ def start_model_process( images=image[0], return_tensors="pt" ) logits = self.model.get_image_features( - pixel_values=inputs.data["pixel_values"].cuda(), + pixel_values=inputs.data["pixel_values"].to( + device_type + ), ).tolist() else: inputs = self.tokenizer( prompts, padding=True, return_tensors="pt" ) logits = self.model.get_text_features( - input_ids=inputs.data["input_ids"].cuda(), - attention_mask=inputs.data["attention_mask"].cuda(), + input_ids=inputs.data["input_ids"].to(device_type), + attention_mask=inputs.data["attention_mask"].to( + device_type + ), ).tolist() else: logits = self.model.encode(prompts).tolist() @@ -342,7 +349,7 @@ def start_model_process( elif self.model_type == "cross_encoder": inputs = self.tokenizer( prompts, padding=True, return_tensors="pt" - ).to("cuda") + ).to(device_type) scores = self.model(**inputs).logits scores = scores.squeeze().tolist() if not isinstance(scores, list): @@ -357,7 +364,7 @@ def start_model_process( ) conv_tokenized = self.tokenizer( conv_formatted, return_tensors="pt" - ).to("cuda") + ).to(device_type) scores.append( float(self.model(**conv_tokenized).logits[0][0].item()) ) @@ -414,9 +421,9 @@ def forward_generation_raw( for i, p in enumerate(prompts): if isinstance(p, str): - input_ids = tokenizer.encode(p, return_tensors="pt").cuda() + input_ids = tokenizer.encode(p, return_tensors="pt").to(device_type) else: - input_ids = torch.tensor([p], device="cuda") + input_ids = torch.tensor([p], device=device_type) if lora_paths is not None and lora_paths[i] is not None: from peft import PeftModel diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 12e6c0fad5e5..419c45b00b89 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -33,7 +33,10 @@ from sglang.srt.utils import ( get_bool_env_var, get_device, + is_cuda, is_port_available, + is_rocm, + is_xpu, kill_process_tree, retry, ) @@ -1917,3 +1920,34 @@ def wrapper(self): return wrapper return decorator + + +def get_gpu_rank(): + if is_xpu(): + gpu_rank = torch.xpu.device_count() + elif is_cuda(): + gpu_rank = torch.cuda.device_count() + elif is_rocm(): + gpu_rank = 
torch.rocm.device_count() + return gpu_rank + + +def empty_gpu_cache(): + if is_xpu(): + torch.xpu.empty_cache() + elif is_cuda(): + torch.cuda.empty_cache() + + +def get_gpu_memory_gb(): + if is_cuda(): + return torch.cuda.device_memory_used() / 1024**3 + elif is_xpu(): + return torch.xpu.device_memory_used() / 1024**3 + + +def get_gpu_capability(): + if is_cuda(): + return torch.cuda.get_device_capability() + elif is_xpu(): + return torch.xpu.get_device_capability() diff --git a/test/manual/test_expert_location_updater.py b/test/manual/test_expert_location_updater.py index 094540294dbe..e069335d9787 100644 --- a/test/manual/test_expert_location_updater.py +++ b/test/manual/test_expert_location_updater.py @@ -13,6 +13,9 @@ from sglang.test.test_utils import CustomTestCase, find_available_port from sglang.utils import is_in_ci +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + @dataclass class _TestInfo: @@ -61,7 +64,7 @@ def test_cpu_slow(self): def test_gpu(self): if is_in_ci(): return - self._test_common(device="cuda") + self._test_common(device=device_type) def _test_common(self, device): infos = [] @@ -135,6 +138,8 @@ def _run_subprocess( ) if device == "cuda": torch.cuda.set_device(f"cuda:{rank}") + if device == "xpu": + torch.xpu.set_device(f"xpu:{rank}") for info in infos: _execute_test(info, rank=rank, num_gpus=num_gpus, device=device) diff --git a/test/manual/test_forward_split_prefill.py b/test/manual/test_forward_split_prefill.py index 4ca3c12fe0d8..f59fd0354565 100644 --- a/test/manual/test_forward_split_prefill.py +++ b/test/manual/test_forward_split_prefill.py @@ -23,6 +23,9 @@ from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + class TestForwardSplitPrefill(CustomTestCase): """Test cases for forward_split_prefill functionality.""" @@ -32,7 +35,7 @@ def setUpClass(cls): """Set up the test environment once for all tests.""" cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.tp_size = 1 - cls.device = "cuda" + cls.device = device_type # Initialize server args cls.server_args = ServerArgs( diff --git a/test/manual/test_get_weights_by_name.py b/test/manual/test_get_weights_by_name.py index 3d404df10a72..c7a4232a9c06 100644 --- a/test/manual/test_get_weights_by_name.py +++ b/test/manual/test_get_weights_by_name.py @@ -7,6 +7,7 @@ from transformers import AutoModelForCausalLM import sglang as sgl +from sglang.srt.utils import is_cuda, is_xpu from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -18,6 +19,16 @@ ) from sglang.utils import terminate_process +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + + +def get_gpu_rank(): + if is_xpu(): + gpu_rank = torch.xpu.device_count() + elif is_cuda(): + gpu_rank = torch.cuda.device_count() + return gpu_rank + def _process_return(ret): if isinstance(ret, list) and len(ret) == 2: @@ -32,7 +43,7 @@ class TestGetWeightsByName(CustomTestCase): def init_hf_model(self, model_name, tie_word_embeddings): self.hf_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings - ).to("cuda:0") + ).to(device_type) def init_backend(self, backend, dp, tp, model_name): self.backend = backend @@ -61,7 +72,10 @@ def 
init_backend(self, backend, dp, tp, model_name): def clean_up(self): del self.hf_model gc.collect() - torch.cuda.empty_cache() + if is_cuda(): + torch.cuda.empty_cache() + elif is_xpu(): + torch.xpu.empty_cache() if self.backend == "Engine": self.engine.shutdown() else: @@ -132,11 +146,11 @@ def test_get_weights_by_name(self): ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST), ] - if torch.cuda.device_count() >= 2: + if get_gpu_rank() >= 2: test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST)) test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST)) - if torch.cuda.device_count() >= 4: + if get_gpu_rank() >= 4: test_suits.extend( [ ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), diff --git a/test/manual/test_triton_moe_wna16.py b/test/manual/test_triton_moe_wna16.py index a7e4a3a89382..25e96ae56a2d 100644 --- a/test/manual/test_triton_moe_wna16.py +++ b/test/manual/test_triton_moe_wna16.py @@ -8,6 +8,9 @@ from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -159,10 +162,10 @@ def test_fused_moe_wn16( weight_bits: int, ): print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) - a = torch.randn((m, k), device="cuda", dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device="cuda", dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device="cuda", dtype=dtype) / 10 - score = torch.randn((m, e), device="cuda", dtype=dtype) + a = torch.randn((m, k), device=device_type, dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=device_type, dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=device_type, dtype=dtype) / 10 + score = torch.randn((m, e), device=device_type, dtype=dtype) if weight_bits == 4: pack_factor = 2 @@ -174,16 +177,22 @@ def test_fused_moe_wn16( w1_ref = w1.clone() w2_ref = w2.clone() w1_qweight = torch.empty( - (e, 2 * n, k // pack_factor), device="cuda", dtype=torch.uint8 + (e, 2 * n, k // pack_factor), device=device_type, dtype=torch.uint8 + ) + w2_qweight = torch.empty( + (e, k, n // pack_factor), device=device_type, dtype=torch.uint8 + ) + w1_scales = torch.empty( + (e, 2 * n, k // group_size), device=device_type, dtype=dtype ) - w2_qweight = torch.empty((e, k, n // pack_factor), device="cuda", dtype=torch.uint8) - w1_scales = torch.empty((e, 2 * n, k // group_size), device="cuda", dtype=dtype) - w2_scales = torch.empty((e, k, n // group_size), device="cuda", dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), device=device_type, dtype=dtype) w1_qzeros = torch.empty( - (e, 2 * n // pack_factor, k // group_size), device="cuda", dtype=torch.uint8 + (e, 2 * n // pack_factor, k // group_size), + device=device_type, + dtype=torch.uint8, ) w2_qzeros = torch.empty( - (e, k // pack_factor, n // group_size), device="cuda", dtype=torch.uint8 + (e, k // pack_factor, n // group_size), device=device_type, dtype=torch.uint8 ) for i in range(e * 2): diff --git a/test/srt/layers/attention/mamba/test_causal_conv1d.py b/test/srt/layers/attention/mamba/test_causal_conv1d.py index dd1a9a25fab6..0300f6c9d686 100644 --- a/test/srt/layers/attention/mamba/test_causal_conv1d.py +++ b/test/srt/layers/attention/mamba/test_causal_conv1d.py @@ -13,6 +13,10 @@ causal_conv1d_fn, causal_conv1d_update, ) +from sglang.test.test_utils import 
empty_gpu_cache + +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) def causal_conv1d_ref( @@ -149,10 +153,8 @@ def causal_conv1d_opcheck_fn( @pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): - if not torch.cuda.is_available(): - pytest.skip("CUDA device not available") - device = "cuda" + device = device_type rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -188,10 +190,8 @@ def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, ity def test_causal_conv1d_update_with_batch_gather( batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype ): - if not torch.cuda.is_available(): - pytest.skip("CUDA device not available") - device = "cuda" + device = device_type rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -268,11 +268,9 @@ def test_causal_conv1d_update_with_batch_gather( def test_causal_conv1d_varlen( batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype ): - if not torch.cuda.is_available(): - pytest.skip("CUDA device not available") - device = "cuda" - torch.cuda.empty_cache() + device = device_type + empty_gpu_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -331,7 +329,7 @@ def test_causal_conv1d_varlen( weight, bias=bias, conv_states=final_states, - query_start_loc=cumsum.cuda(), + query_start_loc=cumsum.to(device_type), seq_lens_cpu=torch.tensor(seqlens[0]), cache_indices=padded_state_indices, has_initial_state=has_initial_states, diff --git a/test/srt/test_build_eagle_tree.py b/test/srt/test_build_eagle_tree.py index 5372393da6db..50128c639d65 100644 --- a/test/srt/test_build_eagle_tree.py +++ b/test/srt/test_build_eagle_tree.py @@ -7,13 +7,15 @@ organize_draft_results, ) +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + class TestBuildEagleTree(unittest.TestCase): """Unit tests for build_eagle_tree functionality.""" def test_build_tree_kernel_efficient(self): """Test the build_tree_kernel_efficient function with known inputs and expected outputs.""" - verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32) + verified_id = torch.tensor([29974, 13], device=device_type, dtype=torch.int32) score_list = [ torch.tensor( [ @@ -21,7 +23,7 @@ def test_build_tree_kernel_efficient(self): [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]], ], dtype=torch.float32, - device="cuda", + device=device_type, ), torch.tensor( [ @@ -39,7 +41,7 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device="cuda", + device=device_type, ), torch.tensor( [ @@ -57,7 +59,7 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device="cuda", + device=device_type, ), torch.tensor( [ @@ -75,14 +77,14 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device="cuda", + device=device_type, ), ] token_list = [ torch.tensor( [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]], dtype=torch.int64, - device="cuda", + device=device_type, ), torch.tensor( [ @@ -123,7 +125,7 @@ def test_build_tree_kernel_efficient(self): 259, ], ], - device="cuda", + device=device_type, ), torch.tensor( [ @@ -164,7 +166,7 @@ def 
test_build_tree_kernel_efficient(self): 2186, ], ], - device="cuda", + device=device_type, ), torch.tensor( [ @@ -205,7 +207,7 @@ def test_build_tree_kernel_efficient(self): 13, ], ], - device="cuda", + device=device_type, ), ] parents_list = [ diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 7e63fd823f37..38ce247caf7d 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -6,34 +6,33 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.test.test_utils import CustomTestCase +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") class TestCreateKvIndices(CustomTestCase): @classmethod def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("CUDA is not available") - torch.set_default_device("cuda") + torch.set_default_device(device_type) def _run_test(self, batch, max_batch, max_context_len): req_to_token = torch.arange( - max_batch * max_context_len, dtype=torch.int32, device="cuda" + max_batch * max_context_len, dtype=torch.int32, device=device_type ).reshape((max_batch, max_context_len)) req_pool_indices = torch.tensor( torch.from_numpy( np.random.choice(range(max_batch), size=batch, replace=False) ), dtype=torch.int32, - device="cuda", + device=device_type, ) paged_kernel_lens = torch.tensor( torch.from_numpy( np.random.choice(range(max_context_len), size=batch, replace=False) ), dtype=torch.int32, - device="cuda", + device=device_type, ) - kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device="cuda") + kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device=device_type) kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) # ref @@ -48,7 +47,7 @@ def _run_test(self, batch, max_batch, max_context_len): ).contiguous() # triton - kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device="cuda") + kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device=device_type) create_flashinfer_kv_indices_triton[(batch,)]( req_to_token, req_pool_indices, diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 65d35c59aad3..d03db0d63e5e 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -10,19 +10,22 @@ from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler from sglang.srt.utils import is_hip -from sglang.test.test_utils import CustomTestCase +from sglang.test.test_utils import CustomTestCase, empty_gpu_cache, get_gpu_capability _is_hip = is_hip() _is_fp8_fnuz = is_fp8_fnuz() +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + class TestFusedMOE(CustomTestCase): NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @staticmethod - def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): - """Create a random CUDA tensor + def create_random_gpu_tensor(shape, dtype, mean=0, std=0.01): + """Create a random Torch(device) tensor Args: shape: Tensor shape @@ -31,9 +34,9 @@ def create_random_cuda_tensor(shape, dtype, mean=0, std=0.01): std: Standard deviation Returns: - torch.Tensor: Randomly initialized CUDA tensor + torch.Tensor: Randomly initialized Torch(device) tensor """ - return torch.empty(shape, dtype=dtype, device="cuda").normal_(mean, std) + return torch.empty(shape, dtype=dtype, device=device_type).normal_(mean, std) def get_tolerance(self, dtype): """Get tolerance values 
for different data types @@ -105,20 +108,20 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 - capability = torch.cuda.get_device_capability() + capability = get_gpu_capability() if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)): return - a = self.create_random_cuda_tensor((m, k), dtype) - w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) - w2 = self.create_random_cuda_tensor((e, k, n), dtype) + a = self.create_random_gpu_tensor((m, k), dtype) + w1 = self.create_random_gpu_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_gpu_tensor((e, k, n), dtype) w1 = w1.to(torch.float8_e4m3fn) w2 = w2.to(torch.float8_e4m3fn) - score = self.create_random_cuda_tensor((m, e), dtype) - w1_scale = self.create_random_cuda_tensor(e, torch.float32) - w2_scale = self.create_random_cuda_tensor(e, torch.float32) - a1_scale = self.create_random_cuda_tensor(1, torch.float32) - a2_scale = self.create_random_cuda_tensor(1, torch.float32) + score = self.create_random_gpu_tensor((m, e), dtype) + w1_scale = self.create_random_gpu_tensor(e, torch.float32) + w2_scale = self.create_random_gpu_tensor(e, torch.float32) + a1_scale = self.create_random_gpu_tensor(1, torch.float32) + a2_scale = self.create_random_gpu_tensor(1, torch.float32) # Handle HIP case: normalize float8 weights so fused kernel doesn't break # on ROCm. @@ -168,10 +171,10 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): sglang_output, torch_output, rtol=rtol, atol=atol ) else: - a = self.create_random_cuda_tensor((m, k), dtype) - w1 = self.create_random_cuda_tensor((e, 2 * n, k), dtype) - w2 = self.create_random_cuda_tensor((e, k, n), dtype) - score = self.create_random_cuda_tensor((m, e), dtype) + a = self.create_random_gpu_tensor((m, k), dtype) + w1 = self.create_random_gpu_tensor((e, 2 * n, k), dtype) + w2 = self.create_random_gpu_tensor((e, k, n), dtype) + score = self.create_random_gpu_tensor((m, e), dtype) topk_output = select_experts( hidden_states=a, @@ -232,7 +235,7 @@ def test_various_configurations(self): dtype, use_fp8_w8a8=use_fp8_w8a8, ) - torch.cuda.empty_cache() + empty_gpu_cache() pbar.update(1) diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index ea141df3e377..eb939eaa12e3 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -13,6 +13,9 @@ popen_launch_server, ) +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + def check_quant_method(model_path: str, use_marlin_kernel: bool): from sglang.srt.configs.device_config import DeviceConfig @@ -46,7 +49,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig() - device_config = DeviceConfig("cuda") + device_config = DeviceConfig(device_type) model = get_model( model_config=model_config, load_config=load_config, device_config=device_config ) diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 2046ce5297a3..66d832e5c246 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -6,6 +6,8 @@ import sglang as sgl from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + class TestHiddenState(CustomTestCase): def 
test_return_hidden_states(self): @@ -46,7 +48,7 @@ def test_return_hidden_states(self): ) model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.bfloat16, device_map="cuda" + model_path, torch_dtype=torch.bfloat16, device_map=device_type ) for input_id, output in zip(input_ids, outputs): @@ -64,7 +66,7 @@ def test_return_hidden_states(self): i.unsqueeze(0) if len(i.shape) == 1 else i for i in output["meta_info"]["hidden_states"] ] - ).to("cuda") + ).to(device_type) print("=== SRT Hiddens ===") print(sg_hidden_states) diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py index ae93c415121a..72f46a560a9e 100644 --- a/test/srt/test_mamba_unittest.py +++ b/test/srt/test_mamba_unittest.py @@ -11,6 +11,8 @@ from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.sampling.sampling_params import SamplingParams +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + class TestMamba(unittest.TestCase): @classmethod @@ -28,7 +30,7 @@ def test_hybrid_linear_kv_pool(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = device_type full_attention_layer_ids = [ i for i in range(global_interval - 1, num_layers, global_interval) ] @@ -56,7 +58,7 @@ def test_mamba_pool(self): max_num_reqs = 10 mamba_cache_size = 20 max_context_len = 128 - device = "cuda" + device = device_type global_interval = 4 num_layers = 48 full_attention_layer_ids = [ @@ -134,7 +136,7 @@ def test_mamba_radix_cache_1(self): max_num_reqs = 10 mamba_cache_size = 20 max_context_len = 128 - device = "cuda" + device = device_type full_attention_layer_ids = [ i for i in range(global_interval - 1, num_layers, global_interval) ] diff --git a/test/srt/test_modelopt_loader.py b/test/srt/test_modelopt_loader.py index a2bad70b5f0c..f0171205e4e2 100644 --- a/test/srt/test_modelopt_loader.py +++ b/test/srt/test_modelopt_loader.py @@ -29,6 +29,8 @@ from sglang.srt.model_loader.loader import ModelOptModelLoader from sglang.test.test_utils import CustomTestCase +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") + class TestModelOptModelLoader(CustomTestCase): """Test cases for ModelOptModelLoader functionality.""" @@ -64,7 +66,7 @@ def setUp(self): self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" self.load_config = LoadConfig() - self.device_config = DeviceConfig(device="cuda") + self.device_config = DeviceConfig(device=device_type) # Create a basic model config with unified quantization flag self.model_config = ModelConfig( diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index 12bc30933356..7e51dd7034f6 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -44,14 +44,15 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, CustomTestCase, + empty_gpu_cache, + get_gpu_memory_gb, + get_gpu_rank, ) # (temporarily) set to true to observe memory usage in nvidia-smi more clearly _DEBUG_EXTRA = False - -def get_gpu_memory_gb(): - return torch.cuda.device_memory_used() / 1024**3 +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") class TestReleaseMemoryOccupation(CustomTestCase): @@ -106,9 +107,7 @@ def _test_initial_generation( def test_release_and_resume_occupation(self): # Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - assert ( 
- torch.cuda.device_count() >= 2 - ), "Need at least 2 GPUs for tensor parallel tests" + assert get_gpu_rank() >= 2, "Need at least 2 GPUs for tensor parallel tests" for tp_size in [1, 2]: @@ -151,13 +150,13 @@ def test_release_and_resume_occupation(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device_map="cuda", + device=device_type, ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) # destroy the hf model del hf_model_new - torch.cuda.empty_cache() + empty_gpu_cache() print("generate (#2)") outputs = engine.generate(params["prompt"], params["sampling_params"])[ @@ -218,7 +217,7 @@ def test_multi_stage_release_and_resume(self): model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST for tp_size in [1, 2]: - if tp_size == 2 and torch.cuda.device_count() < 2: + if tp_size == 2 and get_gpu_rank() < 2: continue print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume") @@ -306,14 +305,14 @@ def test_multi_stage_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device_map="cuda", + device=device_type, ) gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb() engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) # destroy the hf model del hf_model_new - torch.cuda.empty_cache() + empty_gpu_cache() engine.resume_memory_occupation(tags=[GPU_MEMORY_TYPE_KV_CACHE]) gpu_memory_usage_after_resume_kv_cache = get_gpu_memory_gb() @@ -385,13 +384,13 @@ def test_moe_model_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device_map="cuda", + device=device_type, ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) # destroy the hf model del hf_model_new - torch.cuda.empty_cache() + empty_gpu_cache() print("generate (#2)") outputs = engine.generate(params["prompt_moe"], params["sampling_params_moe"])[ diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index d4c2ff8e5a55..818f8f940e3e 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -22,6 +22,9 @@ prefill_attention_wave, ) +device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +torch.set_default_device(device_type) + class TestWaveAttention(unittest.TestCase): @@ -43,24 +46,24 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): extend_seq_len = 1024 b_seq_len_prefix = torch.full( - (B,), N_CTX // B, dtype=torch.int32, device="cuda" + (B,), N_CTX // B, dtype=torch.int32, device=device_type ) b_seq_len_extend = torch.full( - (B,), extend_seq_len, dtype=torch.int32, device="cuda" + (B,), extend_seq_len, dtype=torch.int32, device=device_type ) b_seq_len = b_seq_len_prefix + b_seq_len_extend max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") - b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_req_idx = torch.arange(B, dtype=torch.int32, device=device_type) + b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device_type) b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) - b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device_type) b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - 
kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device_type) kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) kv_indices = torch.zeros( - (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda" + (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=device_type ) for i in range(B): @@ -71,15 +74,21 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): total_token_num = torch.sum(b_seq_len).item() extend_token_num = torch.sum(b_seq_len_extend).item() k_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" + (total_token_num, H_KV, D), dtype=dtype, device=device_type ).normal_(mean=0.1, std=0.2) v_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" + (total_token_num, H_KV, D), dtype=dtype, device=device_type ).normal_(mean=0.1, std=0.2) - k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + k_extend = torch.empty( + (extend_token_num, H_KV, D), dtype=dtype, device=device_type + ) + v_extend = torch.empty( + (extend_token_num, H_KV, D), dtype=dtype, device=device_type + ) + q_extend = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=device_type + ) for i in range(B): extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] @@ -92,20 +101,22 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): extend_start_in_buffer:extend_end_in_buffer ] q_extend[extend_start:extend_end] = torch.empty( - (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device_type ).normal_(mean=0.1, std=0.2) - o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_extend = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=device_type + ) o_extend_mask = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + (extend_token_num, H_Q, D), dtype=dtype, device=device_type ) o_redundant = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + (extend_token_num, H_Q, D), dtype=dtype, device=device_type ) b_seq_len_extend = b_seq_len - b_seq_len_prefix max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() - qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device_type) qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) custom_mask = None @@ -125,7 +136,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): is_causal = True - o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_extend = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=device_type + ) extend_attention_fwd( q_extend, k_extend, @@ -142,7 +155,9 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): max_len_extend, ) - o_wave = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_wave = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device=device_type + ) extend_attention_wave( q_extend, k_extend, @@ -177,33 +192,33 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) max_kv_splits = 8 - num_kv_splits = torch.full((B,), 4, 
dtype=torch.int32, device="cuda") + num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device_type) # q represents the new token being generated, one per batch - q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + q = torch.randn(B, H_Q, D, dtype=dtype, device=device_type) # k_buffer and v_buffer represent all previous tokens - k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") - v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device_type) + v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device_type) # o will have the same shape as q - o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") - o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device_type) + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device_type) - req_to_token = torch.arange(total_tokens, device="cuda", dtype=torch.int32) - b_req_idx = torch.zeros(B + 1, device="cuda", dtype=torch.int32) - b_seq_len = torch.full((B,), seq_len, device="cuda", dtype=torch.int32) + req_to_token = torch.arange(total_tokens, device=device_type, dtype=torch.int32) + b_req_idx = torch.zeros(B + 1, device=device_type, dtype=torch.int32) + b_seq_len = torch.full((B,), seq_len, device=device_type, dtype=torch.int32) b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0) attn_logits = torch.empty( (B, H_Q, max_kv_splits, D_V + 1), dtype=torch.float32, - device="cuda", + device=device_type, ) attn_lse = torch.empty( (B, H_Q, max_kv_splits), dtype=torch.float32, - device="cuda", + device=device_type, ) logit_cap = 0.0 @@ -229,13 +244,13 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): attn_logits = torch.empty( attn_logits_shape, dtype=torch.float32, - device="cuda", + device=device_type, ) attn_logits_max = torch.empty( attn_logits_max_shape, dtype=torch.float32, - device="cuda", + device=device_type, ) decode_attention_wave( @@ -284,17 +299,25 @@ def _test_context_attention_once(self, head_dim, is_causal): max_seq_len = max(seq_lens) # Create random input tensors - q = torch.randn(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") - k = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda") - v = torch.randn(sum(seq_lens), kv_heads, head_dim, dtype=dtype, device="cuda") + q = torch.randn( + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=device_type + ) + k = torch.randn( + sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=device_type + ) + v = torch.randn( + sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=device_type + ) o_triton = torch.zeros( - sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda" + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=device_type + ) + o = torch.zeros( + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=device_type ) - o = torch.zeros(sum(seq_lens), num_heads, head_dim, dtype=dtype, device="cuda") # Create b_start_loc and b_seq_len tensors - b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") - b_seq_len = torch.tensor(seq_lens, device="cuda") + b_start_loc = torch.tensor([0, seq_lens[0]], device=device_type) + b_seq_len = torch.tensor(seq_lens, device=device_type) context_attention_fwd( q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal From 3899d9372733522a01daed61b61c3fa2a2fcfc0e Mon Sep 17 00:00:00 2001 From: jundu Date: Fri, 21 Nov 2025 06:29:51 +0000 
Subject: [PATCH 02/13] adjust the code --- conftest.py | 9 ++ python/sglang/test/runners.py | 28 +++---- python/sglang/test/test_utils.py | 22 +++-- test/manual/test_expert_location_updater.py | 6 +- test/manual/test_forward_split_prefill.py | 6 +- test/manual/test_get_weights_by_name.py | 19 ++--- test/manual/test_triton_moe_wna16.py | 23 +++--- test/nightly/test_batch_invariant_ops.py | 3 +- .../attention/mamba/test_causal_conv1d.py | 12 ++- test/srt/test_build_eagle_tree.py | 21 +++-- test/srt/test_create_kvindices.py | 14 ++-- test/srt/test_fused_moe.py | 11 +-- test/srt/test_gptqmodel_dynamic.py | 7 +- test/srt/test_hidden_states.py | 7 +- test/srt/test_mamba_unittest.py | 9 +- test/srt/test_modelopt_loader.py | 5 +- test/srt/test_release_memory_occupation.py | 9 +- test/srt/test_wave_attention_kernels.py | 82 +++++++++---------- 18 files changed, 134 insertions(+), 159 deletions(-) create mode 100644 conftest.py diff --git a/conftest.py b/conftest.py new file mode 100644 index 000000000000..ab86886c1666 --- /dev/null +++ b/conftest.py @@ -0,0 +1,9 @@ +import pytest +import torch +from sglang.srt.utils import get_device + + +@pytest.fixture(scope="session", autouse=True) +def setup_session(): + torch.set_default_device(get_device()) + yield diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 9bdabde21954..7ee8ef640e48 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -31,12 +31,10 @@ ) from sglang.srt.entrypoints.engine import Engine -from sglang.srt.utils import load_image +from sglang.srt.utils import load_image, get_device from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) DEFAULT_PROMPTS = [ "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", @@ -117,7 +115,7 @@ def _get_sentence_transformer_embedding_model( modules=[word_embedding_model, pooling_model], truncate_dim=matryoshka_dim ) - return model.to(device_type) + return model.to(get_device()) @dataclass @@ -262,7 +260,7 @@ def start_model_process( torch_dtype=torch_dtype, trust_remote_code=self.trust_remote_code, low_cpu_mem_usage=True, - ).to(device_type) + ).to(get_device()) elif self.model_type == "embedding": if "gme-qwen2-vl" in model_path.lower(): self.model = AutoModelForVision2Seq.from_pretrained( @@ -270,10 +268,10 @@ def start_model_process( torch_dtype=torch_dtype, trust_remote_code=False, low_cpu_mem_usage=True, - ).to(device_type) + ).to(get_device()) self.processor = AutoProcessor.from_pretrained(model_path) elif "clip" in model_path.lower(): - self.model = AutoModel.from_pretrained(model_path).to(device_type) + self.model = AutoModel.from_pretrained(model_path).to(get_device()) self.processor = AutoProcessor.from_pretrained(model_path) else: self.model = _get_sentence_transformer_embedding_model( @@ -286,7 +284,7 @@ def start_model_process( model_path, torch_dtype=torch_dtype, trust_remote_code=self.needs_trust_remote_code(model_path), - ).to(device_type) + ).to(get_device()) else: raise Exception(f"Unrecognized model type {self.model_type}") self.tokenizer = get_tokenizer( @@ -330,7 +328,7 @@ def start_model_process( ) logits = self.model.get_image_features( pixel_values=inputs.data["pixel_values"].to( - device_type + get_device() ), ).tolist() else: @@ -338,9 +336,9 @@ def start_model_process( prompts, padding=True, return_tensors="pt" ) logits = self.model.get_text_features( - input_ids=inputs.data["input_ids"].to(device_type), + input_ids=inputs.data["input_ids"].to(get_device()), attention_mask=inputs.data["attention_mask"].to( - device_type + get_device() ), ).tolist() else: @@ -349,7 +347,7 @@ def start_model_process( elif self.model_type == "cross_encoder": inputs = self.tokenizer( prompts, padding=True, return_tensors="pt" - ).to(device_type) + ).to(get_device()) scores = self.model(**inputs).logits scores = scores.squeeze().tolist() if not isinstance(scores, list): @@ -364,7 +362,7 @@ def start_model_process( ) conv_tokenized = self.tokenizer( conv_formatted, return_tensors="pt" - ).to(device_type) + ).to(get_device()) scores.append( float(self.model(**conv_tokenized).logits[0][0].item()) ) @@ -421,9 +419,9 @@ def forward_generation_raw( for i, p in enumerate(prompts): if isinstance(p, str): - input_ids = tokenizer.encode(p, return_tensors="pt").to(device_type) + input_ids = tokenizer.encode(p, return_tensors="pt").to(get_device()) else: - input_ids = torch.tensor([p], device=device_type) + input_ids = torch.tensor([p], device=get_device()) if lora_paths is not None and lora_paths[i] is not None: from peft import PeftModel diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 419c45b00b89..39fb7c6fc40e 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1922,14 +1922,16 @@ def wrapper(self): return decorator -def get_gpu_rank(): +def get_gpu_count(): if is_xpu(): - gpu_rank = torch.xpu.device_count() + gpu_count = torch.xpu.device_count() elif is_cuda(): - gpu_rank = torch.cuda.device_count() + gpu_count = torch.cuda.device_count() elif is_rocm(): - gpu_rank = torch.rocm.device_count() - return gpu_rank + gpu_count = torch.rocm.device_count() + else: + gpu_count = 0 + return gpu_count def empty_gpu_cache(): @@ -1943,11 +1945,7 @@ def get_gpu_memory_gb(): if 
is_cuda(): return torch.cuda.device_memory_used() / 1024**3 elif is_xpu(): - return torch.xpu.device_memory_used() / 1024**3 - + return torch.xpu.memory_allocated() / 1024**3 + else: + return 0 -def get_gpu_capability(): - if is_cuda(): - return torch.cuda.get_device_capability() - elif is_xpu(): - return torch.xpu.get_device_capability() diff --git a/test/manual/test_expert_location_updater.py b/test/manual/test_expert_location_updater.py index e069335d9787..3f3be519e5ca 100644 --- a/test/manual/test_expert_location_updater.py +++ b/test/manual/test_expert_location_updater.py @@ -12,9 +12,7 @@ from sglang.srt.eplb import expert_location_updater from sglang.test.test_utils import CustomTestCase, find_available_port from sglang.utils import is_in_ci - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) +from sglang.srt.utils import get_device @dataclass @@ -64,7 +62,7 @@ def test_cpu_slow(self): def test_gpu(self): if is_in_ci(): return - self._test_common(device=device_type) + self._test_common(device=get_device()) def _test_common(self, device): infos = [] diff --git a/test/manual/test_forward_split_prefill.py b/test/manual/test_forward_split_prefill.py index f59fd0354565..7e8161d0476b 100644 --- a/test/manual/test_forward_split_prefill.py +++ b/test/manual/test_forward_split_prefill.py @@ -21,11 +21,9 @@ from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils.hf_transformers_utils import get_tokenizer +from sglang.srt.utils import get_device from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) - class TestForwardSplitPrefill(CustomTestCase): """Test cases for forward_split_prefill functionality.""" @@ -35,7 +33,7 @@ def setUpClass(cls): """Set up the test environment once for all tests.""" cls.model_path = DEFAULT_SMALL_MODEL_NAME_FOR_TEST cls.tp_size = 1 - cls.device = device_type + cls.device = get_device() # Initialize server args cls.server_args = ServerArgs( diff --git a/test/manual/test_get_weights_by_name.py b/test/manual/test_get_weights_by_name.py index c7a4232a9c06..509b00872c92 100644 --- a/test/manual/test_get_weights_by_name.py +++ b/test/manual/test_get_weights_by_name.py @@ -7,7 +7,7 @@ from transformers import AutoModelForCausalLM import sglang as sgl -from sglang.srt.utils import is_cuda, is_xpu +from sglang.srt.utils import is_cuda, is_xpu, get_device from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, @@ -16,19 +16,10 @@ CustomTestCase, is_in_ci, popen_launch_server, + get_gpu_count ) from sglang.utils import terminate_process -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") - - -def get_gpu_rank(): - if is_xpu(): - gpu_rank = torch.xpu.device_count() - elif is_cuda(): - gpu_rank = torch.cuda.device_count() - return gpu_rank - def _process_return(ret): if isinstance(ret, list) and len(ret) == 2: @@ -43,7 +34,7 @@ class TestGetWeightsByName(CustomTestCase): def init_hf_model(self, model_name, tie_word_embeddings): self.hf_model = AutoModelForCausalLM.from_pretrained( model_name, torch_dtype="bfloat16", tie_word_embeddings=tie_word_embeddings - ).to(device_type) + ).to(get_device()) def init_backend(self, backend, dp, tp, model_name): self.backend = backend @@ -146,11 +137,11 @@ def 
test_get_weights_by_name(self): ("Runtime", 1, 1, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), ("Engine", 1, 1, DEFAULT_MODEL_NAME_FOR_TEST), ] - if get_gpu_rank() >= 2: + if get_gpu_count() >= 2: test_suits.append(("Engine", 1, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST)) test_suits.append(("Runtime", 2, 1, DEFAULT_MODEL_NAME_FOR_TEST)) - if get_gpu_rank() >= 4: + if get_gpu_count() >= 4: test_suits.extend( [ ("Engine", 2, 2, DEFAULT_SMALL_MODEL_NAME_FOR_TEST), diff --git a/test/manual/test_triton_moe_wna16.py b/test/manual/test_triton_moe_wna16.py index 25e96ae56a2d..2a079b30a308 100644 --- a/test/manual/test_triton_moe_wna16.py +++ b/test/manual/test_triton_moe_wna16.py @@ -7,9 +7,8 @@ from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe from sglang.srt.layers.moe.topk import TopKConfig, select_experts from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler +from sglang.srt.utils import get_device -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] @@ -162,10 +161,10 @@ def test_fused_moe_wn16( weight_bits: int, ): print(m, n, k, e, topk, dtype, group_size, has_zp, weight_bits) - a = torch.randn((m, k), device=device_type, dtype=dtype) / 10 - w1 = torch.randn((e, 2 * n, k), device=device_type, dtype=dtype) / 10 - w2 = torch.randn((e, k, n), device=device_type, dtype=dtype) / 10 - score = torch.randn((m, e), device=device_type, dtype=dtype) + a = torch.randn((m, k), device=get_device(), dtype=dtype) / 10 + w1 = torch.randn((e, 2 * n, k), device=get_device(), dtype=dtype) / 10 + w2 = torch.randn((e, k, n), device=get_device(), dtype=dtype) / 10 + score = torch.randn((m, e), device=get_device(), dtype=dtype) if weight_bits == 4: pack_factor = 2 @@ -177,22 +176,22 @@ def test_fused_moe_wn16( w1_ref = w1.clone() w2_ref = w2.clone() w1_qweight = torch.empty( - (e, 2 * n, k // pack_factor), device=device_type, dtype=torch.uint8 + (e, 2 * n, k // pack_factor), device=get_device(), dtype=torch.uint8 ) w2_qweight = torch.empty( - (e, k, n // pack_factor), device=device_type, dtype=torch.uint8 + (e, k, n // pack_factor), device=get_device(), dtype=torch.uint8 ) w1_scales = torch.empty( - (e, 2 * n, k // group_size), device=device_type, dtype=dtype + (e, 2 * n, k // group_size), device=get_device(), dtype=dtype ) - w2_scales = torch.empty((e, k, n // group_size), device=device_type, dtype=dtype) + w2_scales = torch.empty((e, k, n // group_size), device=get_device(), dtype=dtype) w1_qzeros = torch.empty( (e, 2 * n // pack_factor, k // group_size), - device=device_type, + device=get_device(), dtype=torch.uint8, ) w2_qzeros = torch.empty( - (e, k // pack_factor, n // group_size), device=device_type, dtype=torch.uint8 + (e, k // pack_factor, n // group_size), device=get_device(), dtype=torch.uint8 ) for i in range(e * 2): diff --git a/test/nightly/test_batch_invariant_ops.py b/test/nightly/test_batch_invariant_ops.py index 115e7f0fa5d9..a7ac80ae3afc 100644 --- a/test/nightly/test_batch_invariant_ops.py +++ b/test/nightly/test_batch_invariant_ops.py @@ -7,9 +7,8 @@ from sglang.srt.batch_invariant_ops import batch_invariant_ops from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import get_device -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) # Just to get the logging out of the way with 
set_batch_invariant_mode(True): diff --git a/test/srt/layers/attention/mamba/test_causal_conv1d.py b/test/srt/layers/attention/mamba/test_causal_conv1d.py index 0300f6c9d686..8c645590442d 100644 --- a/test/srt/layers/attention/mamba/test_causal_conv1d.py +++ b/test/srt/layers/attention/mamba/test_causal_conv1d.py @@ -14,9 +14,7 @@ causal_conv1d_update, ) from sglang.test.test_utils import empty_gpu_cache - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) +from sglang.srt.utils import get_device def causal_conv1d_ref( @@ -154,7 +152,7 @@ def causal_conv1d_opcheck_fn( @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): - device = device_type + device = get_device() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -191,7 +189,7 @@ def test_causal_conv1d_update_with_batch_gather( batch_size, with_padding, dim, width, seqlen, has_bias, silu_activation, itype ): - device = device_type + device = get_device() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 @@ -269,7 +267,7 @@ def test_causal_conv1d_varlen( batch, with_padding, dim, seqlen, width, has_bias, silu_activation, itype ): - device = device_type + device = get_device() empty_gpu_cache() rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: @@ -329,7 +327,7 @@ def test_causal_conv1d_varlen( weight, bias=bias, conv_states=final_states, - query_start_loc=cumsum.to(device_type), + query_start_loc=cumsum.to(get_device()), seq_lens_cpu=torch.tensor(seqlens[0]), cache_indices=padded_state_indices, has_initial_state=has_initial_states, diff --git a/test/srt/test_build_eagle_tree.py b/test/srt/test_build_eagle_tree.py index 50128c639d65..cae8838b0fff 100644 --- a/test/srt/test_build_eagle_tree.py +++ b/test/srt/test_build_eagle_tree.py @@ -6,8 +6,7 @@ build_tree_kernel_efficient, organize_draft_results, ) - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +from sglang.srt.utils import get_device class TestBuildEagleTree(unittest.TestCase): @@ -15,7 +14,7 @@ class TestBuildEagleTree(unittest.TestCase): def test_build_tree_kernel_efficient(self): """Test the build_tree_kernel_efficient function with known inputs and expected outputs.""" - verified_id = torch.tensor([29974, 13], device=device_type, dtype=torch.int32) + verified_id = torch.tensor([29974, 13], device=get_device(), dtype=torch.int32) score_list = [ torch.tensor( [ @@ -23,7 +22,7 @@ def test_build_tree_kernel_efficient(self): [[9.7476e-01, 2.2219e-02, 6.5031e-04, 1.3212e-04]], ], dtype=torch.float32, - device=device_type, + device=get_device(), ), torch.tensor( [ @@ -41,7 +40,7 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device=device_type, + device=get_device(), ), torch.tensor( [ @@ -59,7 +58,7 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device=device_type, + device=get_device(), ), torch.tensor( [ @@ -77,14 +76,14 @@ def test_build_tree_kernel_efficient(self): ], ], dtype=torch.float32, - device=device_type, + device=get_device(), ), ] token_list = [ torch.tensor( [[29896, 29906, 29900, 29945], [13, 2, 29871, 28956]], dtype=torch.int64, - device=device_type, + device=get_device(), ), torch.tensor( [ @@ -125,7 +124,7 @@ def 
test_build_tree_kernel_efficient(self): 259, ], ], - device=device_type, + device=get_device(), ), torch.tensor( [ @@ -166,7 +165,7 @@ def test_build_tree_kernel_efficient(self): 2186, ], ], - device=device_type, + device=get_device(), ), torch.tensor( [ @@ -207,7 +206,7 @@ def test_build_tree_kernel_efficient(self): 13, ], ], - device=device_type, + device=get_device(), ), ] parents_list = [ diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 38ce247caf7d..0370c88343d6 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -5,34 +5,34 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import get_device -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") class TestCreateKvIndices(CustomTestCase): @classmethod def setUpClass(cls): - torch.set_default_device(device_type) + torch.set_default_device(get_device()) def _run_test(self, batch, max_batch, max_context_len): req_to_token = torch.arange( - max_batch * max_context_len, dtype=torch.int32, device=device_type + max_batch * max_context_len, dtype=torch.int32, device=get_device() ).reshape((max_batch, max_context_len)) req_pool_indices = torch.tensor( torch.from_numpy( np.random.choice(range(max_batch), size=batch, replace=False) ), dtype=torch.int32, - device=device_type, + device=get_device(), ) paged_kernel_lens = torch.tensor( torch.from_numpy( np.random.choice(range(max_context_len), size=batch, replace=False) ), dtype=torch.int32, - device=device_type, + device=get_device(), ) - kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device=device_type) + kv_indptr = torch.zeros((batch + 1,), dtype=torch.int32, device=get_device()) kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0) # ref @@ -47,7 +47,7 @@ def _run_test(self, batch, max_batch, max_context_len): ).contiguous() # triton - kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device=device_type) + kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device=get_device()) create_flashinfer_kv_indices_triton[(batch,)]( req_to_token, req_pool_indices, diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index d03db0d63e5e..0ef32dcf3807 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -9,15 +9,12 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler -from sglang.srt.utils import is_hip -from sglang.test.test_utils import CustomTestCase, empty_gpu_cache, get_gpu_capability +from sglang.srt.utils import is_hip, get_device, get_device_capability +from sglang.test.test_utils import CustomTestCase, empty_gpu_cache _is_hip = is_hip() _is_fp8_fnuz = is_fp8_fnuz() -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) - class TestFusedMOE(CustomTestCase): NUM_EXPERTS = [8, 64] @@ -36,7 +33,7 @@ def create_random_gpu_tensor(shape, dtype, mean=0, std=0.01): Returns: torch.Tensor: Randomly initialized Torch(device) tensor """ - return torch.empty(shape, dtype=dtype, device=device_type).normal_(mean, std) + return torch.empty(shape, dtype=dtype, device=get_device()).normal_(mean, std) def get_tolerance(self, dtype): """Get tolerance values for different data types @@ -108,7 
+105,7 @@ def _test_case(self, m, n, k, e, topk, dtype, use_fp8_w8a8=False): if use_fp8_w8a8: # AssertionError: fp8e4nv data type is not supported on CUDA arch < 89 - capability = get_gpu_capability() + capability = get_device_capability() if not _is_hip and not (capability[0] >= 9 or capability == (8, 9)): return diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index eb939eaa12e3..7bd0d79c07ab 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -5,7 +5,7 @@ import torch from sglang.srt.server_args import set_global_server_args_for_scheduler -from sglang.srt.utils import kill_process_tree +from sglang.srt.utils import kill_process_tree, get_device from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -13,9 +13,6 @@ popen_launch_server, ) -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) - def check_quant_method(model_path: str, use_marlin_kernel: bool): from sglang.srt.configs.device_config import DeviceConfig @@ -49,7 +46,7 @@ def check_quant_method(model_path: str, use_marlin_kernel: bool): model_config = ModelConfig.from_server_args(server_args) load_config = LoadConfig() - device_config = DeviceConfig(device_type) + device_config = DeviceConfig(get_device()) model = get_model( model_config=model_config, load_config=load_config, device_config=device_config ) diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 66d832e5c246..621645008070 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -5,8 +5,7 @@ import sglang as sgl from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +from sglang.srt.utils import get_device class TestHiddenState(CustomTestCase): @@ -48,7 +47,7 @@ def test_return_hidden_states(self): ) model = AutoModelForCausalLM.from_pretrained( - model_path, torch_dtype=torch.bfloat16, device_map=device_type + model_path, torch_dtype=torch.bfloat16, device_map=get_device() ) for input_id, output in zip(input_ids, outputs): @@ -66,7 +65,7 @@ def test_return_hidden_states(self): i.unsqueeze(0) if len(i.shape) == 1 else i for i in output["meta_info"]["hidden_states"] ] - ).to(device_type) + ).to(get_device()) print("=== SRT Hiddens ===") print(sg_hidden_states) diff --git a/test/srt/test_mamba_unittest.py b/test/srt/test_mamba_unittest.py index 72f46a560a9e..6bee08df2019 100644 --- a/test/srt/test_mamba_unittest.py +++ b/test/srt/test_mamba_unittest.py @@ -10,8 +10,7 @@ from sglang.srt.mem_cache.memory_pool import HybridLinearKVPool, HybridReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.sampling.sampling_params import SamplingParams - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +from sglang.srt.utils import get_device class TestMamba(unittest.TestCase): @@ -30,7 +29,7 @@ def test_hybrid_linear_kv_pool(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = device_type + device = get_device() full_attention_layer_ids = [ i for i in range(global_interval - 1, num_layers, global_interval) ] @@ -58,7 +57,7 @@ def test_mamba_pool(self): max_num_reqs = 10 mamba_cache_size = 20 max_context_len = 128 - device = device_type + device = get_device() global_interval = 4 num_layers = 48 full_attention_layer_ids = [ @@ -136,7 +135,7 @@ def 
test_mamba_radix_cache_1(self): max_num_reqs = 10 mamba_cache_size = 20 max_context_len = 128 - device = device_type + device = get_device() full_attention_layer_ids = [ i for i in range(global_interval - 1, num_layers, global_interval) ] diff --git a/test/srt/test_modelopt_loader.py b/test/srt/test_modelopt_loader.py index f0171205e4e2..d133e2fc63e9 100644 --- a/test/srt/test_modelopt_loader.py +++ b/test/srt/test_modelopt_loader.py @@ -28,8 +28,7 @@ from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES from sglang.srt.model_loader.loader import ModelOptModelLoader from sglang.test.test_utils import CustomTestCase - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") +from sglang.srt.utils import get_device class TestModelOptModelLoader(CustomTestCase): @@ -66,7 +65,7 @@ def setUp(self): self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" self.load_config = LoadConfig() - self.device_config = DeviceConfig(device=device_type) + self.device_config = DeviceConfig(device=get_device()) # Create a basic model config with unified quantization flag self.model_config = ModelConfig( diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index 7e51dd7034f6..5045167e46b7 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -48,12 +48,11 @@ get_gpu_memory_gb, get_gpu_rank, ) +from sglang.srt.utils import get_device # (temporarily) set to true to observe memory usage in nvidia-smi more clearly _DEBUG_EXTRA = False -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") - class TestReleaseMemoryOccupation(CustomTestCase): def _setup_engine( @@ -150,7 +149,7 @@ def test_release_and_resume_occupation(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device=device_type, + device=get_device(), ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) @@ -305,7 +304,7 @@ def test_multi_stage_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device=device_type, + device=get_device(), ) gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb() engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) @@ -384,7 +383,7 @@ def test_moe_model_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device=device_type, + device=get_device(), ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index 818f8f940e3e..37cacfc524df 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -21,9 +21,7 @@ from sglang.srt.layers.attention.wave_ops.prefill_attention import ( prefill_attention_wave, ) - -device_type = getattr(torch.accelerator.current_accelerator(), "type", "cpu") -torch.set_default_device(device_type) +from sglang.srt.utils import get_device class TestWaveAttention(unittest.TestCase): @@ -46,24 +44,24 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): extend_seq_len = 1024 b_seq_len_prefix = torch.full( - (B,), N_CTX // B, dtype=torch.int32, device=device_type + (B,), N_CTX // B, dtype=torch.int32, device=get_device() ) b_seq_len_extend = torch.full( - (B,), extend_seq_len, 
dtype=torch.int32, device=device_type + (B,), extend_seq_len, dtype=torch.int32, device=get_device() ) b_seq_len = b_seq_len_prefix + b_seq_len_extend max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - b_req_idx = torch.arange(B, dtype=torch.int32, device=device_type) - b_start_loc = torch.zeros((B,), dtype=torch.int32, device=device_type) + b_req_idx = torch.arange(B, dtype=torch.int32, device=get_device()) + b_start_loc = torch.zeros((B,), dtype=torch.int32, device=get_device()) b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) - b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=device_type) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device=get_device()) b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device_type) + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=get_device()) kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) kv_indices = torch.zeros( - (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=device_type + (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device=get_device() ) for i in range(B): @@ -74,20 +72,20 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): total_token_num = torch.sum(b_seq_len).item() extend_token_num = torch.sum(b_seq_len_extend).item() k_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device=device_type + (total_token_num, H_KV, D), dtype=dtype, device=get_device() ).normal_(mean=0.1, std=0.2) v_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device=device_type + (total_token_num, H_KV, D), dtype=dtype, device=get_device() ).normal_(mean=0.1, std=0.2) k_extend = torch.empty( - (extend_token_num, H_KV, D), dtype=dtype, device=device_type + (extend_token_num, H_KV, D), dtype=dtype, device=get_device() ) v_extend = torch.empty( - (extend_token_num, H_KV, D), dtype=dtype, device=device_type + (extend_token_num, H_KV, D), dtype=dtype, device=get_device() ) q_extend = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device=device_type + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) for i in range(B): extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] @@ -101,22 +99,22 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): extend_start_in_buffer:extend_end_in_buffer ] q_extend[extend_start:extend_end] = torch.empty( - (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=device_type + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device=get_device() ).normal_(mean=0.1, std=0.2) o_extend = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device=device_type + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) o_extend_mask = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device=device_type + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) o_redundant = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device=device_type + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) b_seq_len_extend = b_seq_len - b_seq_len_prefix max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() - qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=device_type) + qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device=get_device()) qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) custom_mask = None @@ -137,7 +135,7 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): is_causal = True o_extend = torch.empty( - (extend_token_num, H_Q, D), 
dtype=dtype, device=device_type + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) extend_attention_fwd( q_extend, @@ -156,7 +154,7 @@ def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): ) o_wave = torch.empty( - (extend_token_num, H_Q, D), dtype=dtype, device=device_type + (extend_token_num, H_Q, D), dtype=dtype, device=get_device() ) extend_attention_wave( q_extend, @@ -192,33 +190,33 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) max_kv_splits = 8 - num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=device_type) + num_kv_splits = torch.full((B,), 4, dtype=torch.int32, device=get_device()) # q represents the new token being generated, one per batch - q = torch.randn(B, H_Q, D, dtype=dtype, device=device_type) + q = torch.randn(B, H_Q, D, dtype=dtype, device=get_device()) # k_buffer and v_buffer represent all previous tokens - k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device_type) - v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device_type) + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=get_device()) + v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=get_device()) # o will have the same shape as q - o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device_type) - o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device_type) + o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device=get_device()) + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=get_device()) - req_to_token = torch.arange(total_tokens, device=device_type, dtype=torch.int32) - b_req_idx = torch.zeros(B + 1, device=device_type, dtype=torch.int32) - b_seq_len = torch.full((B,), seq_len, device=device_type, dtype=torch.int32) + req_to_token = torch.arange(total_tokens, device=get_device(), dtype=torch.int32) + b_req_idx = torch.zeros(B + 1, device=get_device(), dtype=torch.int32) + b_seq_len = torch.full((B,), seq_len, device=get_device(), dtype=torch.int32) b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0) attn_logits = torch.empty( (B, H_Q, max_kv_splits, D_V + 1), dtype=torch.float32, - device=device_type, + device=get_device(), ) attn_lse = torch.empty( (B, H_Q, max_kv_splits), dtype=torch.float32, - device=device_type, + device=get_device(), ) logit_cap = 0.0 @@ -244,13 +242,13 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): attn_logits = torch.empty( attn_logits_shape, dtype=torch.float32, - device=device_type, + device=get_device(), ) attn_logits_max = torch.empty( attn_logits_max_shape, dtype=torch.float32, - device=device_type, + device=get_device(), ) decode_attention_wave( @@ -300,24 +298,24 @@ def _test_context_attention_once(self, head_dim, is_causal): # Create random input tensors q = torch.randn( - sum(seq_lens), num_heads, head_dim, dtype=dtype, device=device_type + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=get_device() ) k = torch.randn( - sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=device_type + sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=get_device() ) v = torch.randn( - sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=device_type + sum(seq_lens), kv_heads, head_dim, dtype=dtype, device=get_device() ) o_triton = torch.zeros( - sum(seq_lens), num_heads, head_dim, dtype=dtype, device=device_type + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=get_device() ) o = torch.zeros( - sum(seq_lens), num_heads, head_dim, dtype=dtype, 
device=device_type + sum(seq_lens), num_heads, head_dim, dtype=dtype, device=get_device() ) # Create b_start_loc and b_seq_len tensors - b_start_loc = torch.tensor([0, seq_lens[0]], device=device_type) - b_seq_len = torch.tensor(seq_lens, device=device_type) + b_start_loc = torch.tensor([0, seq_lens[0]], device=get_device()) + b_seq_len = torch.tensor(seq_lens, device=get_device()) context_attention_fwd( q, k, v, o_triton, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal From aecef41285d13f78657e3c6b94de6f98f9f910f6 Mon Sep 17 00:00:00 2001 From: jundu Date: Fri, 21 Nov 2025 07:47:46 +0000 Subject: [PATCH 03/13] add Rocm empty cache --- python/sglang/test/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 39fb7c6fc40e..34a9f02b7e69 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1937,8 +1937,10 @@ def get_gpu_count(): def empty_gpu_cache(): if is_xpu(): torch.xpu.empty_cache() - elif is_cuda(): + elif is_cuda() or is_rocm(): torch.cuda.empty_cache() + else: + print("There is no suitable GPU") def get_gpu_memory_gb(): From c1be89913e2fda6f4e8da85a24badc4ccd63e56d Mon Sep 17 00:00:00 2001 From: jundu Date: Mon, 24 Nov 2025 02:06:06 +0000 Subject: [PATCH 04/13] pre-commit reformatted --- conftest.py | 1 + python/sglang/test/runners.py | 3 +-- python/sglang/test/test_utils.py | 1 - test/manual/test_expert_location_updater.py | 2 +- test/manual/test_forward_split_prefill.py | 2 +- test/manual/test_get_weights_by_name.py | 4 ++-- test/manual/test_triton_moe_wna16.py | 1 - test/nightly/test_batch_invariant_ops.py | 2 -- test/srt/layers/attention/mamba/test_causal_conv1d.py | 2 +- test/srt/test_create_kvindices.py | 6 ++++-- test/srt/test_fused_moe.py | 2 +- test/srt/test_gptqmodel_dynamic.py | 2 +- test/srt/test_hidden_states.py | 2 +- test/srt/test_modelopt_loader.py | 2 +- test/srt/test_release_memory_occupation.py | 3 +-- test/srt/test_wave_attention_kernels.py | 8 ++++++-- 16 files changed, 22 insertions(+), 21 deletions(-) diff --git a/conftest.py b/conftest.py index ab86886c1666..a4a5121b192a 100644 --- a/conftest.py +++ b/conftest.py @@ -1,5 +1,6 @@ import pytest import torch + from sglang.srt.utils import get_device diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 7ee8ef640e48..3dff79487c45 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -31,11 +31,10 @@ ) from sglang.srt.entrypoints.engine import Engine -from sglang.srt.utils import load_image, get_device +from sglang.srt.utils import get_device, load_image from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l - DEFAULT_PROMPTS = [ "Apple is red. Banana is Yellow. 
" * 800 + "Apple is", "The capital of the United Kingdom is", diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index a9fd0a12fbe2..43fe74707c28 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1950,4 +1950,3 @@ def get_gpu_memory_gb(): return torch.xpu.memory_allocated() / 1024**3 else: return 0 - diff --git a/test/manual/test_expert_location_updater.py b/test/manual/test_expert_location_updater.py index 3f3be519e5ca..513205e72ff1 100644 --- a/test/manual/test_expert_location_updater.py +++ b/test/manual/test_expert_location_updater.py @@ -10,9 +10,9 @@ from torch.multiprocessing import Process from sglang.srt.eplb import expert_location_updater +from sglang.srt.utils import get_device from sglang.test.test_utils import CustomTestCase, find_available_port from sglang.utils import is_in_ci -from sglang.srt.utils import get_device @dataclass diff --git a/test/manual/test_forward_split_prefill.py b/test/manual/test_forward_split_prefill.py index 7e8161d0476b..1dc58900b4e7 100644 --- a/test/manual/test_forward_split_prefill.py +++ b/test/manual/test_forward_split_prefill.py @@ -20,8 +20,8 @@ from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.srt.utils import get_device +from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase diff --git a/test/manual/test_get_weights_by_name.py b/test/manual/test_get_weights_by_name.py index 509b00872c92..fdfc073c7e11 100644 --- a/test/manual/test_get_weights_by_name.py +++ b/test/manual/test_get_weights_by_name.py @@ -7,16 +7,16 @@ from transformers import AutoModelForCausalLM import sglang as sgl -from sglang.srt.utils import is_cuda, is_xpu, get_device +from sglang.srt.utils import get_device, is_cuda, is_xpu from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + get_gpu_count, is_in_ci, popen_launch_server, - get_gpu_count ) from sglang.utils import terminate_process diff --git a/test/manual/test_triton_moe_wna16.py b/test/manual/test_triton_moe_wna16.py index 2a079b30a308..35983a04c240 100644 --- a/test/manual/test_triton_moe_wna16.py +++ b/test/manual/test_triton_moe_wna16.py @@ -9,7 +9,6 @@ from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler from sglang.srt.utils import get_device - NUM_EXPERTS = [8, 64] TOP_KS = [2, 6] diff --git a/test/nightly/test_batch_invariant_ops.py b/test/nightly/test_batch_invariant_ops.py index a7ac80ae3afc..80e373bfbc5f 100644 --- a/test/nightly/test_batch_invariant_ops.py +++ b/test/nightly/test_batch_invariant_ops.py @@ -7,8 +7,6 @@ from sglang.srt.batch_invariant_ops import batch_invariant_ops from sglang.srt.batch_invariant_ops.batch_invariant_ops import set_batch_invariant_mode from sglang.test.test_utils import CustomTestCase -from sglang.srt.utils import get_device - # Just to get the logging out of the way with set_batch_invariant_mode(True): diff --git a/test/srt/layers/attention/mamba/test_causal_conv1d.py b/test/srt/layers/attention/mamba/test_causal_conv1d.py index 8c645590442d..f59f7307b18d 100644 --- a/test/srt/layers/attention/mamba/test_causal_conv1d.py +++ 
b/test/srt/layers/attention/mamba/test_causal_conv1d.py @@ -13,8 +13,8 @@ causal_conv1d_fn, causal_conv1d_update, ) -from sglang.test.test_utils import empty_gpu_cache from sglang.srt.utils import get_device +from sglang.test.test_utils import empty_gpu_cache def causal_conv1d_ref( diff --git a/test/srt/test_create_kvindices.py b/test/srt/test_create_kvindices.py index 0370c88343d6..b2da977378b8 100644 --- a/test/srt/test_create_kvindices.py +++ b/test/srt/test_create_kvindices.py @@ -4,8 +4,8 @@ import torch from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton -from sglang.test.test_utils import CustomTestCase from sglang.srt.utils import get_device +from sglang.test.test_utils import CustomTestCase class TestCreateKvIndices(CustomTestCase): @@ -47,7 +47,9 @@ def _run_test(self, batch, max_batch, max_context_len): ).contiguous() # triton - kv_indices_triton = torch.empty(kv_indptr[-1], dtype=torch.int32, device=get_device()) + kv_indices_triton = torch.empty( + kv_indptr[-1], dtype=torch.int32, device=get_device() + ) create_flashinfer_kv_indices_triton[(batch,)]( req_to_token, req_pool_indices, diff --git a/test/srt/test_fused_moe.py b/test/srt/test_fused_moe.py index 0ef32dcf3807..8e59a2d6315b 100644 --- a/test/srt/test_fused_moe.py +++ b/test/srt/test_fused_moe.py @@ -9,7 +9,7 @@ from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz from sglang.srt.server_args import ServerArgs, set_global_server_args_for_scheduler -from sglang.srt.utils import is_hip, get_device, get_device_capability +from sglang.srt.utils import get_device, get_device_capability, is_hip from sglang.test.test_utils import CustomTestCase, empty_gpu_cache _is_hip = is_hip() diff --git a/test/srt/test_gptqmodel_dynamic.py b/test/srt/test_gptqmodel_dynamic.py index 7bd0d79c07ab..a95c4b9e4b24 100644 --- a/test/srt/test_gptqmodel_dynamic.py +++ b/test/srt/test_gptqmodel_dynamic.py @@ -5,7 +5,7 @@ import torch from sglang.srt.server_args import set_global_server_args_for_scheduler -from sglang.srt.utils import kill_process_tree, get_device +from sglang.srt.utils import get_device, kill_process_tree from sglang.test.test_utils import ( DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, diff --git a/test/srt/test_hidden_states.py b/test/srt/test_hidden_states.py index 621645008070..81c72c955cac 100644 --- a/test/srt/test_hidden_states.py +++ b/test/srt/test_hidden_states.py @@ -4,8 +4,8 @@ from transformers import AutoModelForCausalLM, AutoTokenizer import sglang as sgl -from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase from sglang.srt.utils import get_device +from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST, CustomTestCase class TestHiddenState(CustomTestCase): diff --git a/test/srt/test_modelopt_loader.py b/test/srt/test_modelopt_loader.py index d133e2fc63e9..56b6cd452885 100644 --- a/test/srt/test_modelopt_loader.py +++ b/test/srt/test_modelopt_loader.py @@ -27,8 +27,8 @@ from sglang.srt.configs.model_config import ModelConfig from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES from sglang.srt.model_loader.loader import ModelOptModelLoader -from sglang.test.test_utils import CustomTestCase from sglang.srt.utils import get_device +from sglang.test.test_utils import CustomTestCase class TestModelOptModelLoader(CustomTestCase): diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index 
5045167e46b7..c0d05be4aa1f 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -29,7 +29,6 @@ import time import unittest -import torch from transformers import AutoModelForCausalLM import sglang as sgl @@ -38,6 +37,7 @@ GPU_MEMORY_TYPE_KV_CACHE, GPU_MEMORY_TYPE_WEIGHTS, ) +from sglang.srt.utils import get_device from sglang.test.test_utils import ( DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, @@ -48,7 +48,6 @@ get_gpu_memory_gb, get_gpu_rank, ) -from sglang.srt.utils import get_device # (temporarily) set to true to observe memory usage in nvidia-smi more clearly _DEBUG_EXTRA = False diff --git a/test/srt/test_wave_attention_kernels.py b/test/srt/test_wave_attention_kernels.py index 37cacfc524df..420ab63d2df6 100644 --- a/test/srt/test_wave_attention_kernels.py +++ b/test/srt/test_wave_attention_kernels.py @@ -197,13 +197,17 @@ def _test_grouped_decode_attention_once(self, B, S, H_Q, H_KV, D, D_V): # k_buffer and v_buffer represent all previous tokens k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=get_device()) - v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=get_device()) + v_buffer = torch.randn( + total_tokens, H_KV, D_V, dtype=dtype, device=get_device() + ) # o will have the same shape as q o_triton = torch.zeros(B, H_Q, D_V, dtype=dtype, device=get_device()) o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=get_device()) - req_to_token = torch.arange(total_tokens, device=get_device(), dtype=torch.int32) + req_to_token = torch.arange( + total_tokens, device=get_device(), dtype=torch.int32 + ) b_req_idx = torch.zeros(B + 1, device=get_device(), dtype=torch.int32) b_seq_len = torch.full((B,), seq_len, device=get_device(), dtype=torch.int32) b_req_idx[1 : B + 1] = torch.cumsum(b_seq_len, dim=0) From 40968f4e1e7354f4b795929d0782d9e267136125 Mon Sep 17 00:00:00 2001 From: jundu Date: Mon, 24 Nov 2025 08:58:32 +0000 Subject: [PATCH 05/13] Made the code simpler and fixed the import issue --- python/sglang/srt/layers/moe/topk.py | 2 +- python/sglang/test/test_utils.py | 19 ++++++------------- conftest.py => test/srt/conftest.py | 0 3 files changed, 7 insertions(+), 14 deletions(-) rename conftest.py => test/srt/conftest.py (100%) diff --git a/python/sglang/srt/layers/moe/topk.py b/python/sglang/srt/layers/moe/topk.py index 421faa1b8f6c..e09327cd8313 100644 --- a/python/sglang/srt/layers/moe/topk.py +++ b/python/sglang/srt/layers/moe/topk.py @@ -73,7 +73,7 @@ _is_npu = is_npu() _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip -if _is_cuda or _is_hip or _is_xpu: +if _is_cuda: from sgl_kernel import kimi_k2_moe_fused_gate, moe_fused_gate @torch.library.register_fake("sgl_kernel::kimi_k2_moe_fused_gate") diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 43fe74707c28..4467f10a2975 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -35,7 +35,6 @@ get_device, is_cuda, is_port_available, - is_rocm, is_xpu, kill_process_tree, retry, @@ -1923,24 +1922,18 @@ def wrapper(self): def get_gpu_count(): - if is_xpu(): - gpu_count = torch.xpu.device_count() - elif is_cuda(): - gpu_count = torch.cuda.device_count() - elif is_rocm(): - gpu_count = torch.rocm.device_count() - else: + if get_device() == "cpu": gpu_count = 0 + else: + gpu_count = torch.accelerator.device_count() return gpu_count def empty_gpu_cache(): - if is_xpu(): - torch.xpu.empty_cache() - elif is_cuda() or is_rocm(): - 
torch.cuda.empty_cache() - else: + if get_device() == "cpu": print("There is no suitable GPU") + else: + torch.accelerator.empty_cache() def get_gpu_memory_gb(): diff --git a/conftest.py b/test/srt/conftest.py similarity index 100% rename from conftest.py rename to test/srt/conftest.py From 544f29c6202117ee0771fbf2000a4a14bac8d6f4 Mon Sep 17 00:00:00 2001 From: jundu Date: Tue, 25 Nov 2025 01:07:11 +0000 Subject: [PATCH 06/13] Fix the issue where torch.accelerator doesn't support empty_cache when the version < 2.9 --- python/sglang/test/test_utils.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 4467f10a2975..91dc7b4463a9 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1930,10 +1930,24 @@ def get_gpu_count(): def empty_gpu_cache(): - if get_device() == "cpu": - print("There is no suitable GPU") - else: - torch.accelerator.empty_cache() + """ + Unified empty_cache for PyTorch 2.8 (no torch.accelerator) + and PyTorch 2.9+ (where torch.accelerator.empty_cache() exists). + """ + if hasattr(torch, "accelerator") and hasattr(torch.accelerator, "empty_cache"): + return torch.accelerator.empty_cache() + + # CUDA + if hasattr(torch, "cuda") and torch.cuda.is_available(): + torch.cuda.empty_cache() + return + + # XPU (Intel) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + torch.xpu.empty_cache() + return + + return def get_gpu_memory_gb(): From 630c70dd0599cc3947b8b965fe2ac0bc3443a6ff Mon Sep 17 00:00:00 2001 From: jundu Date: Tue, 25 Nov 2025 02:39:15 +0000 Subject: [PATCH 07/13] fix parameter command error --- test/manual/test_get_weights_by_name.py | 9 +++------ test/srt/test_release_memory_occupation.py | 6 +++--- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/test/manual/test_get_weights_by_name.py b/test/manual/test_get_weights_by_name.py index fdfc073c7e11..fa97c7df8070 100644 --- a/test/manual/test_get_weights_by_name.py +++ b/test/manual/test_get_weights_by_name.py @@ -3,17 +3,17 @@ import numpy as np import requests -import torch from transformers import AutoModelForCausalLM import sglang as sgl -from sglang.srt.utils import get_device, is_cuda, is_xpu +from sglang.srt.utils import get_device from sglang.test.test_utils import ( DEFAULT_MODEL_NAME_FOR_TEST, DEFAULT_SMALL_MODEL_NAME_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, + empty_gpu_cache, get_gpu_count, is_in_ci, popen_launch_server, @@ -63,10 +63,7 @@ def init_backend(self, backend, dp, tp, model_name): def clean_up(self): del self.hf_model gc.collect() - if is_cuda(): - torch.cuda.empty_cache() - elif is_xpu(): - torch.xpu.empty_cache() + empty_gpu_cache() if self.backend == "Engine": self.engine.shutdown() else: diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index c0d05be4aa1f..37d00314152a 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -148,7 +148,7 @@ def test_release_and_resume_occupation(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device=get_device(), + device_map=get_device(), ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) @@ -303,7 +303,7 @@ def test_multi_stage_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MODEL_NAME_FOR_TEST_BASE, 
torch_dtype="bfloat16", - device=get_device(), + device_map=get_device(), ) gpu_memory_usage_after_loaded_hf_model = get_gpu_memory_gb() engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) @@ -382,7 +382,7 @@ def test_moe_model_release_and_resume(self): hf_model_new = AutoModelForCausalLM.from_pretrained( DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_BASE, torch_dtype="bfloat16", - device=get_device(), + device_map=get_device(), ) engine.update_weights_from_tensor(list(hf_model_new.named_parameters())) From 4add2ca609d389c06be1731378f96ef97c65a3bf Mon Sep 17 00:00:00 2001 From: jundu Date: Tue, 25 Nov 2025 09:02:36 +0000 Subject: [PATCH 08/13] fix parameter command error --- test/srt/test_release_memory_occupation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/srt/test_release_memory_occupation.py b/test/srt/test_release_memory_occupation.py index 37d00314152a..1a27e0dba8c3 100644 --- a/test/srt/test_release_memory_occupation.py +++ b/test/srt/test_release_memory_occupation.py @@ -45,8 +45,8 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST_CHAT, CustomTestCase, empty_gpu_cache, + get_gpu_count, get_gpu_memory_gb, - get_gpu_rank, ) # (temporarily) set to true to observe memory usage in nvidia-smi more clearly @@ -105,7 +105,7 @@ def _test_initial_generation( def test_release_and_resume_occupation(self): # Without multi-stage release and resume, we need to carefully control the memory fraction to avoid OOM model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST - assert get_gpu_rank() >= 2, "Need at least 2 GPUs for tensor parallel tests" + assert get_gpu_count() >= 2, "Need at least 2 GPUs for tensor parallel tests" for tp_size in [1, 2]: @@ -215,7 +215,7 @@ def test_multi_stage_release_and_resume(self): model_name = DEFAULT_SMALL_MODEL_NAME_FOR_TEST for tp_size in [1, 2]: - if tp_size == 2 and get_gpu_rank() < 2: + if tp_size == 2 and get_gpu_count() < 2: continue print(f"Testing tp_size={tp_size} for test_multi_stage_release_and_resume") From 4a369177aa6e06255bb38066da369c9ca410cda7 Mon Sep 17 00:00:00 2001 From: dujun Date: Mon, 8 Dec 2025 02:03:29 +0000 Subject: [PATCH 09/13] remove the global device setting to avoid some unkown errors --- test/srt/conftest.py | 10 ---------- 1 file changed, 10 deletions(-) delete mode 100644 test/srt/conftest.py diff --git a/test/srt/conftest.py b/test/srt/conftest.py deleted file mode 100644 index a4a5121b192a..000000000000 --- a/test/srt/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest -import torch - -from sglang.srt.utils import get_device - - -@pytest.fixture(scope="session", autouse=True) -def setup_session(): - torch.set_default_device(get_device()) - yield From 77ba06292bfc0b18df870ae5bd5656680acff24a Mon Sep 17 00:00:00 2001 From: dujun Date: Mon, 8 Dec 2025 02:12:48 +0000 Subject: [PATCH 10/13] Avoid falling back to CPU when the device is not found --- python/sglang/test/test_utils.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 91dc7b4463a9..0f2e5f261728 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1176,11 +1176,6 @@ def run_bench_one_batch(model, other_args): device: Device type ("auto", "cuda", "rocm" or "cpu"). If "auto", will detect available platforms automatically. 
""" - # Auto-detect device if needed - - device = auto_config_device() - print(f"Auto-configed device: {device}", flush=True) - other_args += ["--device", str(device)] command = [ "python3", From e146d0e8ec87da50b6b79be5a6c2adef0246e46c Mon Sep 17 00:00:00 2001 From: "Gao, Pengfei" Date: Wed, 10 Dec 2025 22:30:57 -0800 Subject: [PATCH 11/13] add xpu support for ut --- test/srt/quant/test_awq_dequant.py | 3 ++- test/srt/rl/test_fp32_lm_head.py | 9 +++++---- test/srt/test_swa_unittest.py | 7 ++++--- 3 files changed, 11 insertions(+), 8 deletions(-) diff --git a/test/srt/quant/test_awq_dequant.py b/test/srt/quant/test_awq_dequant.py index ec1f2b16a3d2..6cd1c95d41a0 100644 --- a/test/srt/quant/test_awq_dequant.py +++ b/test/srt/quant/test_awq_dequant.py @@ -17,8 +17,9 @@ awq_gemm_triton, ) from sglang.test.test_utils import CustomTestCase +from sglang.srt.utils import get_device -device = "cuda" +device = get_device() def reverse_awq_order(t: torch.Tensor) -> torch.Tensor: diff --git a/test/srt/rl/test_fp32_lm_head.py b/test/srt/rl/test_fp32_lm_head.py index cf6dd28398f1..fec15c789331 100644 --- a/test/srt/rl/test_fp32_lm_head.py +++ b/test/srt/rl/test_fp32_lm_head.py @@ -12,10 +12,11 @@ get_global_server_args, set_global_server_args_for_scheduler, ) +from sglang.srt.utils import get_device class LMHeadStub(nn.Module): - def __init__(self, vocab, hidden, dtype, device="cuda"): + def __init__(self, vocab, hidden, dtype, device=get_device()): super().__init__() self.weight = nn.Parameter( torch.randn(vocab, hidden, dtype=dtype, device=device) @@ -32,8 +33,8 @@ def compute_dp_attention_metadata(self): ... class TestLMHeadFP32(unittest.TestCase): @classmethod def setUpClass(cls): - if not torch.cuda.is_available(): - raise unittest.SkipTest("needs CUDA GPU") + if not torch.cuda.is_available() and not(hasattr(torch, "xpu") and torch.xpu.is_available()): + raise unittest.SkipTest("needs CUDA GPU or XPU") def _make_logprocessor(self, vocab_size, enable_fp32): set_global_server_args_for_scheduler(ServerArgs(model_path="dummy")) @@ -50,7 +51,7 @@ def _run_case( expected_a_dtype, expected_b_dtype, ): - device = "cuda" + device = get_device() BATCH_SIZE, HIDDEN_SIZE, VOCAB_SIZE = 2, 64, 128 hidden_state = torch.randn( BATCH_SIZE, HIDDEN_SIZE, dtype=hidden_state_dtype, device=device diff --git a/test/srt/test_swa_unittest.py b/test/srt/test_swa_unittest.py index 2d01f90bd05d..ccebe0625c18 100644 --- a/test/srt/test_swa_unittest.py +++ b/test/srt/test_swa_unittest.py @@ -7,6 +7,7 @@ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool from sglang.srt.mem_cache.radix_cache import RadixKey from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache +from sglang.srt.utils import get_device class TestSWA(unittest.TestCase): @@ -26,7 +27,7 @@ def test_swa_memory_pool(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = get_device() full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] full_attention_layer_ids_set = set(full_attention_layer_ids) swa_attention_layer_ids = [ @@ -75,7 +76,7 @@ def test_swa_radix_cache_1(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = get_device() full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] full_attention_layer_ids_set = set(full_attention_layer_ids) swa_attention_layer_ids = [ @@ -208,7 +209,7 @@ def test_swa_radix_cache_eagle(self): num_layers = 48 global_interval = 4 dtype = torch.bfloat16 - device = "cuda" + device = 
get_device() full_attention_layer_ids = [i for i in range(0, num_layers, global_interval)] full_attention_layer_ids_set = set(full_attention_layer_ids) swa_attention_layer_ids = [ From ad855f28970489a7b705f18772cfab0d6707a530 Mon Sep 17 00:00:00 2001 From: "Du, Jun" Date: Mon, 22 Dec 2025 03:09:00 +0000 Subject: [PATCH 12/13] fix lint check issue --- python/sglang/test/runners.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index 839715a5ff43..79dce8a4207d 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -32,7 +32,6 @@ from sglang.srt.entrypoints.engine import Engine from sglang.srt.utils import get_device, is_npu, load_image - from sglang.srt.utils.hf_transformers_utils import get_tokenizer from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l From 1bd4d2e45fc9f9b888f8c9ab9cd1f6d2c530471f Mon Sep 17 00:00:00 2001 From: dujun Date: Fri, 30 Jan 2026 07:47:03 +0000 Subject: [PATCH 13/13] fallback automatic device detection --- python/sglang/test/test_utils.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index f592742749a7..7be56b3783a8 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -1476,6 +1476,11 @@ def run_bench_one_batch(model, other_args): device: Device type ("auto", "cuda", "rocm" or "cpu"). If "auto", will detect available platforms automatically. """ + # Auto-detect device if needed + + device = auto_config_device() + print(f"Auto-configed device: {device}", flush=True) + other_args += ["--device", str(device)] command = [ "python3",