
Commit 4091d01: fix comments

1 parent a509fad

3 files changed: 15 additions & 16 deletions

atom/model_engine/model_runner.py

Lines changed: 5 additions & 5 deletions
@@ -101,7 +101,7 @@ def clean(self):

         self.prev_batch: Optional[ScheduledBatch] = None

-        self.pre_num_decode_token_pre_seq = 1
+        self.pre_num_decode_token_per_seq = 1
         self.draft_token_ids: Optional[torch.Tensor] = None

     def prepare_sampled_ids(
@@ -152,7 +152,7 @@ def get_prev_alive_locations(self, batch: ScheduledBatch) -> tuple[list[int], in
             i for i, seq_id in enumerate(self.prev_batch.req_ids)
             if seq_id in batch.req_ids
         ]
-        num_deferred_tokens = len(alive_seq_indices) * self.pre_num_decode_token_pre_seq
+        num_deferred_tokens = len(alive_seq_indices) * self.pre_num_decode_token_per_seq
         is_all_alive = len(alive_seq_indices) == len(self.prev_batch.req_ids)
         return alive_seq_indices, num_deferred_tokens, is_all_alive

@@ -212,7 +212,7 @@ def prepare_input_ids(
         self.input_ids.np[:num_norm_tokens] = token_ids
         self.input_ids.copy_to_gpu(num_norm_tokens)
         # no new requests added and old requests finished
-        if self.draft_token_ids is not None and self.pre_num_decode_token_pre_seq > 1:
+        if self.draft_token_ids is not None and self.pre_num_decode_token_per_seq > 1:
             alive_prev = self.prev_token_ids[alive_seq_indices]
             alive_draft = self.draft_token_ids[alive_seq_indices]
             combined = torch.cat([
@@ -236,7 +236,7 @@ def prepare_input_ids(
         # self.input_ids_loc.gpu[:num_deferred_tokens],
         # out=self.input_ids.gpu[:num_deferred_tokens],
         # )
-        if self.draft_token_ids is not None and self.pre_num_decode_token_pre_seq > 1:
+        if self.draft_token_ids is not None and self.pre_num_decode_token_per_seq > 1:
             alive_prev = self.prev_token_ids[alive_seq_indices]  # (num_alive_seqs,)
             alive_draft = self.draft_token_ids[alive_seq_indices]  # (num_alive_seqs, mtp_n_grams-1)
             combined = torch.cat([
@@ -1084,7 +1084,7 @@ def propose_draft_token_ids(
         self.forward_vars["draft_tokens"].gpu[:bs, :self.drafter.mtp_k] = draft_token
         self.forward_vars["draft_tokens"].copy_to_cpu()
         self.tokenID_processor.draft_token_ids = draft_token
-        self.tokenID_processor.pre_num_decode_token_pre_seq = 2
+        self.tokenID_processor.pre_num_decode_token_per_seq = 2

         return None
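Review note: the rename is mechanical (pre_seq to per_seq), but the counter matters for deferred-output accounting. A minimal standalone sketch of the computation in get_prev_alive_locations, with plain req-id lists standing in for the real ScheduledBatch objects (this helper is illustrative, not the repo's code):

# Illustrative sketch (not the repo's code) of the deferred-token count that
# pre_num_decode_token_per_seq drives; req-id lists replace ScheduledBatch.
def get_prev_alive_locations(prev_req_ids, cur_req_ids, per_seq=1):
    cur = set(cur_req_ids)
    # indices of sequences from the previous step that are still running
    alive_seq_indices = [i for i, rid in enumerate(prev_req_ids) if rid in cur]
    num_deferred_tokens = len(alive_seq_indices) * per_seq
    is_all_alive = len(alive_seq_indices) == len(prev_req_ids)
    return alive_seq_indices, num_deferred_tokens, is_all_alive

# propose_draft_token_ids sets the counter to 2 once MTP drafting is active,
# doubling the deferred-token count for the surviving sequences:
print(get_prev_alive_locations(["a", "b", "c"], ["a", "c"], per_seq=2))
# -> ([0, 2], 4, False)

This is also why prepare_input_ids gates the draft-token merge on pre_num_decode_token_per_seq > 1: the counter only exceeds 1 after the drafter has run.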

atom/model_engine/scheduler.py

Lines changed: 9 additions & 10 deletions
@@ -194,19 +194,23 @@ def postprocess(
         # update token_ids with the actual sampled token ids
         finished_seqs = []
         stream_outputs = []
+        num_placeholder = (
+            2 * self.mtp_k if is_deferred_out and self.use_spec else
+            1 if is_deferred_out else
+            self.mtp_k if self.use_spec else
+            0
+        )

         for seq in self.running:
             if seq.id not in prev_token_ids:
                 continue
             token_ids = prev_token_ids[seq.id]
             new_tokens = []
             if is_deferred_out:
-                idx = seq.token_ids.index(self.eos_token_id)
-                seq.token_ids[idx:] = token_ids
+                seq.token_ids[-num_placeholder:] = token_ids

                 if seq.output_tokens:
-                    idx = seq.output_tokens.index(self.eos_token_id)
-                    seq.output_tokens[idx:] = token_ids
+                    seq.output_tokens[-num_placeholder:] = token_ids

             else:
                 seq.output_tokens.extend(token_ids)
@@ -256,12 +260,7 @@ def postprocess(
         if stream_output_queue is not None and stream_outputs:
             stream_output_queue.put_nowait(stream_outputs)

-        num_placeholder = (
-            2 * self.mtp_k if is_deferred_out and self.use_spec else
-            1 if is_deferred_out else
-            self.mtp_k if self.use_spec else
-            0
-        )
+
         if num_placeholder > 0:
             # placeholder for the each decode step
             for seq in seqs:
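
Review note: the num_placeholder computation moves above the loop so the deferred-output branch can splice by count instead of searching for the EOS placeholder. A minimal sketch of the new splice, assuming the mtp_k/use_spec placeholder convention from this diff (the helper is illustrative, not the repo's API):

# Minimal sketch of the new placeholder splice (assumed shapes; this helper
# is illustrative, not the repo's API).
def replace_placeholders(token_ids, sampled, mtp_k, is_deferred_out, use_spec):
    # number of placeholder tokens appended for the pending decode step(s)
    num_placeholder = (
        2 * mtp_k if is_deferred_out and use_spec else
        1 if is_deferred_out else
        mtp_k if use_spec else
        0
    )
    if is_deferred_out and num_placeholder > 0:
        # overwrite exactly the trailing placeholders; unlike the old
        # list.index(eos_token_id) search, this stays correct even when a
        # genuine EOS id appears earlier in the sequence
        token_ids[-num_placeholder:] = sampled
    return token_ids

# mtp_k = 1, deferred + speculative -> the last 2 entries are placeholders
print(replace_placeholders([5, 6, 0, 0], [7, 8], 1, True, True))  # [5, 6, 7, 8]

Slicing the trailing num_placeholder entries removes the dependence on finding self.eos_token_id, which list.index would match at the wrong position whenever a real EOS token precedes the placeholders.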

atom/model_ops/attentions/aiter_mla.py

Lines changed: 1 addition & 1 deletion
@@ -263,7 +263,7 @@ def prepare_kv_indices():
     def build_for_cudagraph_capture(self, bs: int) -> AttentionMetaData:
         var = self.model_runner.forward_vars
         sparse_kv_indptr = var["sparse_kv_indptr"].gpu if self.is_sparse else None
-        max_q_len= 1 if not hasattr(self, "drafter") else var["mtp_k"] + 1
+        max_q_len = 1 if not hasattr(self, "drafter") else var["mtp_k"] + 1
         ctx_mla_ps = self.set_mla_persistent_worker_buffers(bs, max_q_len)
         attn_matadata = AttentionMetaData(
             slot_mapping=var["slot_mapping"].gpu[: bs * max_q_len],
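
Review note: this one is purely a spacing fix, but for context the touched line picks the per-sequence query length used during CUDA graph capture: one query token for a plain decode step, plus mtp_k extra positions when an MTP drafter is attached. A tiny sketch with illustrative stand-in values:

# Sketch of the query-length rule this line encodes; values are illustrative.
mtp_k = 2              # stands in for var["mtp_k"]
has_drafter = True     # stands in for hasattr(self, "drafter")
max_q_len = 1 if not has_drafter else mtp_k + 1
assert max_q_len == 3  # slot_mapping is then sized bs * max_q_len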
