@@ -52,27 +52,33 @@ def get_total_slots():
5252
5353 kv_start_indices , attention_mask = [], []
5454 block_num , block_size , _ , _ = step_context .kv_caches [0 ][1 ].shape
55- device = step_context .block_offsets .device
5655
5756 is_unpaged_prefill = False
5857 if not step_context .is_decoding :
5958 is_unpaged_prefill = \
6059 all ((step_context .q_seqlens ==
6160 step_context .kv_seqlens ).tolist ())
62- q_start_loc = torch .cat ((torch .tensor ([0 ], device = device ), step_context .q_seqlens .cumsum (0 ))).int ()
61+ q_start_loc = step_context .q_start_loc
62+ cu_seqlens = torch .cat ((q_start_loc , step_context .q_seqlens .sum ().unsqueeze (0 ))).int ()
63+
6364 q_seqlens = step_context .q_seqlens .int ()
6465 kv_seqlens = step_context .kv_seqlens .int ()
65- max_q_seq_len = torch .max (q_seqlens ).item ()
66- max_kv_seq_len = torch .max (kv_seqlens ).item ()
6766
6867 if step_context .is_decoding :
68+ # max_q_seq_len and max_kv_seq_len are not used in the decoding stage
69+ max_q_seq_len = - 1
70+ max_kv_seq_len = - 1
71+
6972 # collect kv_start_indices without using a for-loop,
7073 # (fill kv-cache for just ONE token during the decoding phase)
7174 idx = (step_context .kv_seqlens - 1 ) % block_size
7275 b_num = (step_context .kv_seqlens - 1 ) // block_size
7376 last_block = step_context .block_offsets .gather (1 , b_num .view (- 1 , 1 )).view (- 1 )
7477 kv_start_indices = (last_block * block_size + idx ).reshape ((- 1 , 1 ))
7578 else :
79+ max_q_seq_len = torch .max (q_seqlens ).cpu ().item ()
80+ max_kv_seq_len = torch .max (kv_seqlens ).cpu ().item ()
81+
7682 for i in range (step_context .q_start_loc .size (0 )):
7783 q_seq_len = int (step_context .q_seqlens [i ])
7884 kv_seq_len = int (step_context .kv_seqlens [i ])
@@ -88,7 +94,7 @@ def get_total_slots():
8894 attn_metadata = attn_meta_cls (
8995 step_context .is_decoding ,
9096 step_context .block_offsets .int (),
91- q_start_loc = q_start_loc ,
97+ q_start_loc = cu_seqlens ,
9298 q_seqlens = q_seqlens ,
9399 kv_seqlens = kv_seqlens ,
94100 kv_start_indices = kv_start_indices ,
0 commit comments