Skip to content

Commit 7cb8a5f

Browse files
committed
Unify `method` and `spec_method` into a single `method` attribute to avoid a bug
1 parent 2c1a59b commit 7cb8a5f

20 files changed

Lines changed: 40 additions & 45 deletions

fastdeploy/config.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -788,7 +788,6 @@ def _init_from_defaults(self):
788788
"""Initialize all config options from class defaults."""
789789
for key, value in self._DEFAULTS.items():
790790
setattr(self, key, value)
791-
self.spec_method = None # Will be set during validation
792791

793792
def _apply_user_args(self, args: Dict[str, Any]):
794793
"""Apply user-provided arguments."""
@@ -822,14 +821,14 @@ def _convert_and_validate(self):
822821
"""
823822
Convert string configs to enums and validate all parameters.
824823
"""
825-
# Parse spec_method from string to enum using the new from_string method
824+
# Convert method from string to SpecMethod enum
826825
if self.method is not None:
827826
from fastdeploy.spec_decode import SpecMethod
828827

829-
self.spec_method = SpecMethod.from_string(self.method)
828+
self.method = SpecMethod.from_string(self.method)
830829

831830
# Set method-specific computed values
832-
if self.spec_method == SpecMethod.MTP:
831+
if self.method == SpecMethod.MTP:
833832
self.num_extra_cache_layer = 1
834833

835834
# Run validation (includes dependency validation)
@@ -886,15 +885,15 @@ def check_legality_parameters(
886885
m.value for m in SpecMethod
887886
], f"speculative method only support {[m.value for m in SpecMethod]} now, but get {self.method}."
888887

889-
if self.spec_method != SpecMethod.NAIVE:
888+
if self.method != SpecMethod.NAIVE:
890889
assert (
891890
self.num_speculative_tokens >= 1 and self.num_speculative_tokens <= 5
892891
), f"num_speculative_tokens only support in range[1, 5], but get {self.num_speculative_tokens}."
893892
assert (
894893
self.num_model_steps >= 1 and self.num_model_steps <= 5
895894
), f"num_model_steps only support in range[1, 5], but get {self.num_model_steps}."
896895

897-
if self.spec_method == SpecMethod.MTP:
896+
if self.method == SpecMethod.MTP:
898897
if self.num_speculative_tokens < self.num_model_steps:
899898
logger.warning(
900899
f"Get num_model_steps > num_speculative_tokens. Reset num_speculative_tokens to {self.num_model_steps}"
@@ -968,8 +967,8 @@ def _validate_dependencies(self) -> None:
968967
],
969968
}
970969

971-
if self.spec_method in constraints:
972-
method_constraints = constraints[self.spec_method]
970+
if self.method in constraints:
971+
method_constraints = constraints[self.method]
973972
for constraint in method_constraints:
974973
if not constraint["check"]():
975974
if constraint["auto_fix"] is not None:
@@ -1820,7 +1819,7 @@ def __init__(
18201819

18211820
# Initialize cuda graph capture list
18221821
max_capture_shape = self.scheduler_config.max_num_seqs
1823-
if self.speculative_config is not None and self.speculative_config.spec_method in [
1822+
if self.speculative_config is not None and self.speculative_config.method in [
18241823
SpecMethod.MTP,
18251824
SpecMethod.SUFFIX,
18261825
]:
@@ -2054,7 +2053,7 @@ def postprocess(self):
20542053
)
20552054

20562055
# adjust speculative config
2057-
if self.speculative_config is not None and self.speculative_config.spec_method == SpecMethod.MTP:
2056+
if self.speculative_config is not None and self.speculative_config.method == SpecMethod.MTP:
20582057
if self.scheduler_config.splitwise_role == "prefill":
20592058
self.speculative_config.num_speculative_tokens = 1
20602059
self.speculative_config.num_model_steps = 1

fastdeploy/engine/common_engine.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -564,7 +564,7 @@ def _insert_prefilled_requests(self, request_outputs: List[RequestOutput]):
564564
cur_req.metrics = req_out.metrics
565565
cur_req.metrics.decode_inference_start_time = time.time()
566566
if (
567-
self.cfg.speculative_config.spec_method == SpecMethod.MTP
567+
self.cfg.speculative_config.method == SpecMethod.MTP
568568
and self.cfg.scheduler_config.splitwise_role == "decode"
569569
):
570570
cur_req.draft_token_ids = copy.deepcopy(req_out.outputs.draft_token_ids)

fastdeploy/engine/sched/resource_manager_v1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1361,7 +1361,7 @@ def add_prefilled_request(self, request_output: RequestOutput):
13611361
request.output_token_ids.append(request_output.outputs.token_ids[0])
13621362
request.num_cached_tokens = request_output.num_cached_tokens
13631363
if (
1364-
self.config.speculative_config.spec_method == SpecMethod.MTP
1364+
self.config.speculative_config.method == SpecMethod.MTP
13651365
and self.config.scheduler_config.splitwise_role == "decode"
13661366
):
13671367
request.draft_token_ids = copy.deepcopy(request_output.outputs.draft_token_ids)

fastdeploy/model_executor/layers/attention/append_attn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -144,10 +144,10 @@ def __init__(
144144
if fd_config.speculative_config.model_type != "main":
145145
self.rope_3d = False
146146
self.causal: bool = getattr(fd_config.model_config, "causal", True)
147-
self.speculative_method: str = fd_config.speculative_config.method
147+
self.speculative_method = fd_config.speculative_config.method
148148
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
149149
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
150-
self.num_layers_draft_model: int = int(fd_config.speculative_config.spec_method == SpecMethod.MTP)
150+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
151151

152152
self.kv_num_heads: int = kv_num_heads
153153
self.num_heads: int = num_heads

fastdeploy/model_executor/layers/attention/flash_attn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,11 +245,11 @@ def __init__(
245245
self.encoder_block_shape_q: int = encoder_block_shape_q
246246
self.decoder_block_shape_q: int = decoder_block_shape_q
247247

248-
self.speculative_method = fd_config.speculative_config.spec_method
248+
self.speculative_method = fd_config.speculative_config.method
249249
self.use_speculate = self.speculative_method is not None
250250
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
251251
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
252-
self.num_layers_draft_model: int = int(fd_config.speculative_config.spec_method == SpecMethod.MTP)
252+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
253253

254254
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
255255

fastdeploy/model_executor/layers/attention/flash_mask_attn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ def __init__(
103103
self.encoder_block_shape_q: int = encoder_block_shape_q
104104
self.decoder_block_shape_q: int = decoder_block_shape_q
105105

106-
self.speculative_method = fd_config.speculative_config.spec_method
106+
self.speculative_method = fd_config.speculative_config.method
107107
self.use_speculate = self.speculative_method is not None
108108
self.speculate_max_draft_token_num = fd_config.speculative_config.num_speculative_tokens
109109
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
110-
self.num_layers_draft_model: int = int(fd_config.speculative_config.spec_method == SpecMethod.MTP)
110+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
111111

112112
self.pd_disaggregation_mode: str = fd_config.parallel_config.pd_disaggregation_mode
113113

fastdeploy/model_executor/layers/attention/mla_attention_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -258,11 +258,11 @@ def __init__(
258258
)
259259
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
260260
self.causal: bool = getattr(fd_config.model_config, "causal", True)
261-
self.speculative_method: str = fd_config.speculative_config.method
261+
self.speculative_method = fd_config.speculative_config.method
262262
self.use_speculate: bool = self.speculative_method is not None
263263
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
264264
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
265-
self.num_layers_draft_model: int = int(fd_config.speculative_config.spec_method == SpecMethod.MTP)
265+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
266266

267267
self.num_heads: int = num_heads
268268
self.head_dim: int = fd_config.model_config.head_dim

fastdeploy/model_executor/layers/backends/intel_hpu/attention/hpu_attn_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ def __init__(
221221
self.rope_theta = 10000.0 if llm_config.model_config.rope_theta is None else llm_config.model_config.rope_theta
222222
self.rope_3d = getattr(llm_config.model_config, "rope_3d", False)
223223
self.causal = getattr(llm_config.model_config, "causal", True)
224-
self.speculative_method: str = llm_config.speculative_config.method
224+
self.speculative_method = llm_config.speculative_config.method
225225
self.use_speculate: bool = self.speculative_method is not None
226226
self.speculate_max_draft_token_num: int = llm_config.speculative_config.num_speculative_tokens
227227
self.keep_pd_step_flag: bool = llm_config.speculative_config.model_type == "mtp"

fastdeploy/model_executor/layers/backends/metax/attention/flash_attn_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,11 +103,11 @@ def __init__(
103103
)
104104
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
105105
self.causal: bool = getattr(fd_config.model_config, "causal", True)
106-
self.speculative_method: str = fd_config.speculative_config.method
106+
self.speculative_method = fd_config.speculative_config.method
107107
self.use_speculate: bool = self.speculative_method is not None
108108
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
109109
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
110-
self.num_layers_draft_model: int = int(fd_config.speculative_config.spec_method == SpecMethod.MTP)
110+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
111111
self.encoder_block_shape_q: int = encoder_block_shape_q
112112
self.decoder_block_shape_q: int = decoder_block_shape_q
113113

fastdeploy/model_executor/layers/backends/metax/attention/mla_attn_metax_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,11 @@ def __init__(
107107
)
108108
self.rope_3d: bool = getattr(fd_config.model_config, "rope_3d", False)
109109
self.causal: bool = getattr(fd_config.model_config, "causal", True)
110-
self.speculative_method: str = fd_config.speculative_config.method
110+
self.speculative_method = fd_config.speculative_config.method
111111
self.use_speculate: bool = self.speculative_method is not None
112112
self.speculate_max_draft_token_num: int = fd_config.speculative_config.num_speculative_tokens
113113
self.keep_pd_step_flag: bool = fd_config.speculative_config.model_type == "mtp"
114-
self.num_layers_draft_model: int = int(fd_config.speculative_config.spec_method == SpecMethod.MTP)
114+
self.num_layers_draft_model: int = int(fd_config.speculative_config.method == SpecMethod.MTP)
115115

116116
self.kv_num_heads: int = kv_num_heads
117117
self.num_heads: int = num_heads

0 commit comments

Comments (0)