
Commit 16eb33f

Update vocab embedding deps and add TP switch (#1856)
1 parent 61cf00e

31 files changed: +602 additions, -101 deletions

python/sglang/srt/layers/quantization/base_config.py

Lines changed: 16 additions & 1 deletion
@@ -1,7 +1,8 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py

+import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, Type

 import torch
 from torch import nn
@@ -120,3 +121,17 @@ def get_scaled_act_names(self) -> List[str]:
         For now, this is only used by AWQ.
         """
         raise NotImplementedError
+
+def method_has_implemented_embedding(
+        method_class: Type[QuantizeMethodBase]) -> bool:
+    """
+    Not all quant methods have embedding implemented, so we need to check that
+    it exists for our given method. We check this by making sure the function
+    has been changed from the base implementation.
+    """
+    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding",
+                                            None)
+    class_embedding = inspect.getattr_static(method_class, "embedding", None)
+
+    return (class_embedding is not None
+            and class_embedding is not base_embedding)
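The helper relies on inspect.getattr_static, which fetches an attribute without triggering descriptors or __getattr__, so an override can be detected by comparing function-object identity against the base class. A minimal self-contained sketch of the check; the QuantizeMethodBase stand-in and the two subclasses below are simplified, hypothetical examples, not code from this commit:

import inspect
from typing import Type


class QuantizeMethodBase:
    """Simplified stand-in for the base class defined in base_config.py."""

    def embedding(self, layer, input_):
        raise NotImplementedError


class PlainMethod(QuantizeMethodBase):
    """Hypothetical quant method that inherits embedding() unchanged."""


class EmbeddingAwareMethod(QuantizeMethodBase):
    """Hypothetical quant method that overrides embedding()."""

    def embedding(self, layer, input_):
        return input_


def method_has_implemented_embedding(
        method_class: Type[QuantizeMethodBase]) -> bool:
    # getattr_static avoids invoking descriptors, so the identity test
    # against the base function reliably detects a real override.
    base_embedding = inspect.getattr_static(QuantizeMethodBase, "embedding", None)
    class_embedding = inspect.getattr_static(method_class, "embedding", None)
    return class_embedding is not None and class_embedding is not base_embedding


print(method_has_implemented_embedding(PlainMethod))           # False
print(method_has_implemented_embedding(EmbeddingAwareMethod))  # True

This gives the new embedding layer a safe way to decide whether a quantized LM head can route lookups through its quant method or must fall back to the unquantized path.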

python/sglang/srt/layers/vocab_parallel_embedding.py

Lines changed: 486 additions & 0 deletions
Large diffs are not rendered by default.
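This new module is where the commit title's "TP switch" lands: sglang now owns its vocab embedding layers instead of importing vllm's, which makes room for a per-layer flag that turns tensor-parallel vocab sharding on or off. Since the 486-line diff is not rendered, the following is only a rough sketch of the mechanism; every name here (enable_tp, tp_rank, tp_size) is an assumption, not taken from the file:

import torch
from torch import nn
from torch.nn import functional as F


class VocabParallelEmbedding(nn.Module):
    """Hypothetical sketch of a vocab embedding with a TP on/off switch.

    A single process stands in for one rank; the real implementation
    is the unrendered file above.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int,
                 tp_rank: int = 0, tp_size: int = 1, enable_tp: bool = True):
        super().__init__()
        if not enable_tp:
            # Switch off: keep one full, unsharded copy of the table.
            tp_rank, tp_size = 0, 1
        assert num_embeddings % tp_size == 0
        self.shard_size = num_embeddings // tp_size
        self.vocab_start = tp_rank * self.shard_size
        self.vocab_end = self.vocab_start + self.shard_size
        self.weight = nn.Parameter(torch.empty(self.shard_size, embedding_dim))
        nn.init.normal_(self.weight, std=0.02)

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Look up only ids in this rank's vocab shard and zero the rest;
        # a multi-rank setup would then all-reduce so every rank holds
        # the full embedding output.
        mask = (input_ >= self.vocab_start) & (input_ < self.vocab_end)
        local_ids = (input_ - self.vocab_start).clamp(min=0) * mask
        out = F.embedding(local_ids, self.weight)
        return out * mask.unsqueeze(-1)

With the switch off, each rank stores the whole table, which can be preferable when the vocabulary is small relative to the TP degree; with it on, each rank stores only num_embeddings / tp_size rows.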

python/sglang/srt/models/baichuan.py

Lines changed: 4 additions & 4 deletions
@@ -34,17 +34,17 @@
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
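This file and the remaining model diffs below all make the same mechanical change: the ParallelLMHead / VocabParallelEmbedding imports move from vllm.model_executor.layers.vocab_parallel_embedding to sglang's new module, leaving call sites untouched. For illustration only, a toy skeleton (not the baichuan code; instantiating it also requires an initialized tensor-parallel process group) showing where the two layers typically sit:

from torch import nn

# After this commit the layers come from sglang itself, not vllm:
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)


class ToyCausalLM(nn.Module):
    """Hypothetical skeleton, for illustration only."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        # Input token embedding, sharded over the vocab dimension under TP.
        self.embed_tokens = VocabParallelEmbedding(vocab_size, hidden_size)
        # Output projection over the (sharded) vocabulary for logits.
        self.lm_head = ParallelLMHead(vocab_size, hidden_size)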

python/sglang/srt/models/chatglm.py

Lines changed: 4 additions & 4 deletions
@@ -24,10 +24,6 @@
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig

@@ -41,6 +37,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 LoraConfig = None

python/sglang/srt/models/commandr.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,6 @@
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -62,6 +61,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs

python/sglang/srt/models/dbrx.py

Lines changed: 5 additions & 5 deletions
@@ -27,11 +27,6 @@
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    DEFAULT_VOCAB_PADDING_SIZE,
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

@@ -43,6 +38,11 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import set_weight_attrs

python/sglang/srt/models/deepseek.py

Lines changed: 4 additions & 4 deletions
@@ -28,10 +28,6 @@
 )
 from vllm.model_executor.layers.fused_moe import fused_moe
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -45,6 +41,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

python/sglang/srt/models/deepseek_v2.py

Lines changed: 4 additions & 4 deletions
@@ -27,10 +27,6 @@
 )
 from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -44,6 +40,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_flashinfer_available

python/sglang/srt/models/exaone.py

Lines changed: 4 additions & 4 deletions
@@ -23,10 +23,6 @@
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    ParallelLMHead,
-    VocabParallelEmbedding,
-)
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
@@ -39,6 +35,10 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

python/sglang/srt/models/gemma.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,6 @@
 from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import GeluAndMul
@@ -37,6 +36,7 @@
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
