@@ -101,7 +101,7 @@ def clean(self):

         self.prev_batch: Optional[ScheduledBatch] = None

-        self.pre_num_decode_token_pre_seq = 1
+        self.pre_num_decode_token_per_seq = 1
         self.draft_token_ids: Optional[torch.Tensor] = None

     def prepare_sampled_ids(
@@ -152,7 +152,7 @@ def get_prev_alive_locations(self, batch: ScheduledBatch) -> tuple[list[int], in
             i for i, seq_id in enumerate(self.prev_batch.req_ids)
             if seq_id in batch.req_ids
         ]
-        num_deferred_tokens = len(alive_seq_indices) * self.pre_num_decode_token_pre_seq
+        num_deferred_tokens = len(alive_seq_indices) * self.pre_num_decode_token_per_seq
         is_all_alive = len(alive_seq_indices) == len(self.prev_batch.req_ids)
         return alive_seq_indices, num_deferred_tokens, is_all_alive

@@ -212,7 +212,7 @@ def prepare_input_ids(
         self.input_ids.np[:num_norm_tokens] = token_ids
         self.input_ids.copy_to_gpu(num_norm_tokens)
         # no new requests added and old requests finished
-        if self.draft_token_ids is not None and self.pre_num_decode_token_pre_seq > 1:
+        if self.draft_token_ids is not None and self.pre_num_decode_token_per_seq > 1:
             alive_prev = self.prev_token_ids[alive_seq_indices]
             alive_draft = self.draft_token_ids[alive_seq_indices]
             combined = torch.cat([
@@ -236,7 +236,7 @@ def prepare_input_ids(
         # self.input_ids_loc.gpu[:num_deferred_tokens],
         # out=self.input_ids.gpu[:num_deferred_tokens],
         # )
-        if self.draft_token_ids is not None and self.pre_num_decode_token_pre_seq > 1:
+        if self.draft_token_ids is not None and self.pre_num_decode_token_per_seq > 1:
             alive_prev = self.prev_token_ids[alive_seq_indices]  # (num_alive_seqs,)
             alive_draft = self.draft_token_ids[alive_seq_indices]  # (num_alive_seqs, mtp_n_grams-1)
             combined = torch.cat([
@@ -1084,7 +1084,7 @@ def propose_draft_token_ids(
         self.forward_vars["draft_tokens"].gpu[:bs, :self.drafter.mtp_k] = draft_token
         self.forward_vars["draft_tokens"].copy_to_cpu()
         self.tokenID_processor.draft_token_ids = draft_token
-        self.tokenID_processor.pre_num_decode_token_pre_seq = 2
+        self.tokenID_processor.pre_num_decode_token_per_seq = 2

         return None

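Note (illustrative, not part of the diff): a minimal sketch of how the renamed pre_num_decode_token_per_seq field ties the hunks above together, assuming a 1-D prev_token_ids tensor, a (num_seqs, per_seq - 1) draft_token_ids tensor, and a hypothetical helper name build_decode_input_ids.

import torch

def build_decode_input_ids(prev_token_ids, draft_token_ids, alive_seq_indices,
                           pre_num_decode_token_per_seq):
    # Without draft tokens, each alive sequence contributes exactly one token.
    if draft_token_ids is None or pre_num_decode_token_per_seq <= 1:
        return prev_token_ids[alive_seq_indices]
    # With MTP drafting enabled, each alive sequence contributes its previously
    # sampled token plus (pre_num_decode_token_per_seq - 1) draft tokens.
    alive_prev = prev_token_ids[alive_seq_indices]    # (num_alive_seqs,)
    alive_draft = draft_token_ids[alive_seq_indices]  # (num_alive_seqs, per_seq - 1)
    combined = torch.cat([alive_prev.unsqueeze(1), alive_draft], dim=1)
    return combined.reshape(-1)  # len = num_alive_seqs * pre_num_decode_token_per_seq

# num_deferred_tokens in get_prev_alive_locations uses the same accounting:
#   num_deferred_tokens = len(alive_seq_indices) * pre_num_decode_token_per_seq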