|
22 | 22 | #include <cuda_fp16.h> |
23 | 23 | #include <cuda_runtime.h> |
24 | 24 |
|
| 25 | +#include "src/turbomind/core/context.h" |
25 | 26 | #include "src/turbomind/kernels/apply_token_bitmask_inplace_cuda.h" |
26 | 27 | // clang-format on |
27 | 28 |
|
@@ -140,27 +141,28 @@ void ApplyTokenBitmaskInplaceDispatchToBitsPerThread(T* __restrict__ logits, |
140 | 141 | const int32_t num_blocks_per_row = CeilDiv(2048 / THREADS_PER_THREAD_BLOCK * 128, num_rows); |
141 | 142 | const int32_t num_bits_per_thread = CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * num_blocks_per_row); |
142 | 143 |
|
143 | | - const dim3 block(THREADS_PER_THREAD_BLOCK); |
| 144 | + const dim3 block(THREADS_PER_THREAD_BLOCK); |
| 145 | + const auto& stream = turbomind::core::Context::stream(); |
144 | 146 |
|
145 | 147 | if (num_bits_per_thread <= 4 && kAlignment <= 4) { |
146 | 148 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 4), num_rows); |
147 | 149 | LogitsBitmaskKernel<T, PackedT, 4> |
148 | | - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 150 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
149 | 151 | } |
150 | 152 | else if (num_bits_per_thread <= 8 && kAlignment <= 8) { |
151 | 153 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 8), num_rows); |
152 | 154 | LogitsBitmaskKernel<T, PackedT, 8> |
153 | | - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 155 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
154 | 156 | } |
155 | 157 | else if (num_bits_per_thread <= 16 && kAlignment <= 16) { |
156 | 158 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 16), num_rows); |
157 | 159 | LogitsBitmaskKernel<T, PackedT, 16> |
158 | | - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 160 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
159 | 161 | } |
160 | 162 | else { |
161 | 163 | const dim3 grid(CeilDiv(vocab_size, THREADS_PER_THREAD_BLOCK * 32), num_rows); |
162 | 164 | LogitsBitmaskKernel<T, PackedT, 32> |
163 | | - <<<grid, block, 0>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
| 165 | + <<<grid, block, 0, stream.handle()>>>(logits, bitmask, indices, vocab_size, logits_stride, bitmask_stride); |
164 | 166 | } |
165 | 167 | } |
166 | 168 |
|
|
0 commit comments