Commit 456dfce

fix(gpu): indexes in gemm ks
1 parent fbe2f86 commit 456dfce

9 files changed: +258, -98 lines changed

backends/tfhe-cuda-backend/cuda/include/keyswitch/keyswitch.h

Lines changed: 1 addition & 1 deletion

@@ -17,7 +17,7 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
     void const *lwe_output_indexes, void const *lwe_array_in,
     void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
     uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
-    uint32_t num_samples, int8_t *ksk_tmp_buffer);
+    uint32_t num_samples, int8_t *ksk_tmp_buffer, bool uses_trivial_indexes);

 uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
     void *stream, uint32_t gpu_index, int8_t **fp_ks_buffer,
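
The new uses_trivial_indexes flag tells the keyswitch entry point whether the index arrays are the identity mapping, so it can take the faster non-indexed path. A minimal host-side sketch of how a caller might derive the flag (the helper name and the host-resident index copies are hypothetical, not part of this commit):

#include <cstdint>

// Hypothetical helper: the indexes are "trivial" when both arrays are the
// identity permutation 0, 1, ..., num_samples - 1; only then can the
// non-indexed kernels be used safely.
static bool indexes_are_trivial(const uint64_t *h_output_indexes,
                                const uint64_t *h_input_indexes,
                                uint32_t num_samples) {
  for (uint32_t i = 0; i < num_samples; ++i)
    if (h_output_indexes[i] != i || h_input_indexes[i] != i)
      return false;
  return true;
}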

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cu

Lines changed: 4 additions & 4 deletions

@@ -10,15 +10,15 @@ void cuda_keyswitch_lwe_ciphertext_vector_32(
     void *lwe_output_indexes, void *lwe_array_in, void *lwe_input_indexes,
     void *ksk, uint32_t lwe_dimension_in, uint32_t lwe_dimension_out,
     uint32_t base_log, uint32_t level_count, uint32_t num_samples,
-    void *ksk_tmp_buffer) {
+    void *ksk_tmp_buffer, bool uses_trivial_indices) {
   host_gemm_keyswitch_lwe_ciphertext_vector<uint32_t>(
       static_cast<cudaStream_t>(stream), gpu_index,
       static_cast<uint32_t *>(lwe_array_out),
       static_cast<uint32_t *>(lwe_output_indexes),
       static_cast<uint32_t *>(lwe_array_in),
       static_cast<uint32_t *>(lwe_input_indexes), static_cast<uint32_t *>(ksk),
       lwe_dimension_in, lwe_dimension_out, base_log, level_count, num_samples,
-      static_cast<uint32_t *>(ksk_tmp_buffer));
+      static_cast<uint32_t *>(ksk_tmp_buffer), uses_trivial_indices);
 }

 /* Perform keyswitch on a batch of 64 bits input LWE ciphertexts.
@@ -42,7 +42,7 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
     void const *lwe_output_indexes, void const *lwe_array_in,
     void const *lwe_input_indexes, void const *ksk, uint32_t lwe_dimension_in,
     uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
-    uint32_t num_samples, int8_t *ksk_tmp_buffer) {
+    uint32_t num_samples, int8_t *ksk_tmp_buffer, bool uses_trivial_indices) {
   host_gemm_keyswitch_lwe_ciphertext_vector<uint64_t>(
       static_cast<cudaStream_t>(stream), gpu_index,
       static_cast<uint64_t *>(lwe_array_out),
@@ -51,7 +51,7 @@ void cuda_keyswitch_lwe_ciphertext_vector_64(
       static_cast<const uint64_t *>(lwe_input_indexes),
       static_cast<const uint64_t *>(ksk), lwe_dimension_in, lwe_dimension_out,
       base_log, level_count, num_samples,
-      (uint64_t *)((ks_mem *)ksk_tmp_buffer)->buffer);
+      (uint64_t *)((ks_mem *)ksk_tmp_buffer)->buffer, uses_trivial_indices);
 }

 uint64_t scratch_packing_keyswitch_lwe_list_to_glwe_64(
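
Note that the 64-bit entry point reinterprets the opaque int8_t *ksk_tmp_buffer as a ks_mem and passes its inner buffer pointer down. The struct is defined elsewhere in the backend; a minimal sketch of the shape this cast relies on (anything beyond the buffer field is an assumption):

// Sketch only: the real ks_mem lives elsewhere in the backend. The cast in
// cuda_keyswitch_lwe_ciphertext_vector_64 only requires that it expose a
// raw device scratch pointer named `buffer`.
struct ks_mem {
  int8_t *buffer; // device scratch used by the gemm keyswitch
};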

backends/tfhe-cuda-backend/cuda/src/crypto/keyswitch.cuh

Lines changed: 68 additions & 31 deletions

@@ -113,6 +113,26 @@ __global__ void keyswitch_gemm_copy_message(const Torus *lwe_in, Torus *lwe_out,
       -lwe_in[lwe_id * (lwe_dimension_in + 1) + lwe_dimension_in];
 }

+template <typename Torus>
+__global__ void keyswitch_gemm_copy_message_with_indices(
+    const Torus *__restrict__ lwe_in,
+    const Torus *__restrict__ lwe_input_indices, Torus *__restrict__ lwe_out,
+    const Torus *__restrict__ lwe_output_indices,
+    uint32_t lwe_dimension_in, uint32_t num_lwes, uint32_t lwe_dimension_out) {
+
+  uint32_t lwe_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (lwe_id >= num_lwes)
+    return;
+
+  uint32_t lwe_in_idx = lwe_input_indices[lwe_id];
+  uint32_t lwe_out_idx = lwe_output_indices[lwe_id];
+
+  lwe_out[lwe_out_idx * (lwe_dimension_out + 1) + lwe_dimension_out] =
+      -lwe_in[lwe_in_idx * (lwe_dimension_in + 1) + lwe_dimension_in];
+}
+
 // Continue decomposition of an array of Torus elements in place. Supposes
 // that the array contains already decomposed elements and
 // computes the new decomposed level in place.
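
Read per thread, the kernel routes one ciphertext body through the two index arrays. A plain CPU restatement for illustration (hypothetical helper, not part of the commit):

// Hypothetical CPU reference for the indexed body copy: for each sample j,
// negate the body of input ciphertext lwe_input_indices[j] and write it as
// the body of output ciphertext lwe_output_indices[j].
template <typename Torus>
void copy_message_with_indices_ref(
    const Torus *lwe_in, const Torus *lwe_input_indices, Torus *lwe_out,
    const Torus *lwe_output_indices, uint32_t lwe_dimension_in,
    uint32_t num_lwes, uint32_t lwe_dimension_out) {
  for (uint32_t j = 0; j < num_lwes; ++j) {
    Torus body = lwe_in[lwe_input_indices[j] * (lwe_dimension_in + 1) +
                        lwe_dimension_in];
    lwe_out[lwe_output_indices[j] * (lwe_dimension_out + 1) +
            lwe_dimension_out] = static_cast<Torus>(0) - body;
  }
}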
@@ -256,10 +276,10 @@ __host__ void host_keyswitch_lwe_ciphertext_vector(
 template <typename Torus>
 __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
     cudaStream_t stream, uint32_t gpu_index, Torus *lwe_array_out,
-    Torus const *lwe_output_indexes, Torus const *lwe_array_in,
-    Torus const *lwe_input_indexes, Torus const *ksk, uint32_t lwe_dimension_in,
+    Torus const *lwe_output_indices, Torus const *lwe_array_in,
+    Torus const *lwe_input_indices, Torus const *ksk, uint32_t lwe_dimension_in,
     uint32_t lwe_dimension_out, uint32_t base_log, uint32_t level_count,
-    uint32_t num_samples, Torus *fp_tmp_buffer) {
+    uint32_t num_samples, Torus *fp_tmp_buffer, bool uses_trivial_indices) {
   cuda_set_device(gpu_index);
   check_cuda_error(cudaGetLastError());

@@ -280,10 +300,18 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
   // lwe_array_out is num_samples x (lwe_dimension_out + 1). copy the bodies
   // lwe_array_in[:,lwe_dimension_in] to lwe_array_out[:,lwe_dimension_out]
   // and negate
-  keyswitch_gemm_copy_message<Torus><<<grid_copy, threads_copy, 0, stream>>>(
-      lwe_array_in, lwe_array_out, lwe_dimension_in, num_samples,
-      lwe_dimension_out);
-  check_cuda_error(cudaGetLastError());
+  if (uses_trivial_indices) {
+    keyswitch_gemm_copy_message<Torus><<<grid_copy, threads_copy, 0, stream>>>(
+        lwe_array_in, lwe_array_out, lwe_dimension_in, num_samples,
+        lwe_dimension_out);
+    check_cuda_error(cudaGetLastError());
+  } else {
+    keyswitch_gemm_copy_message_with_indices<Torus>
+        <<<grid_copy, threads_copy, 0, stream>>>(
+            lwe_array_in, lwe_input_indices, lwe_array_out, lwe_output_indices,
+            lwe_dimension_in, num_samples, lwe_dimension_out);
+    check_cuda_error(cudaGetLastError());
+  }

   // dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
   // "lwe_out_only_body", prefix, stream, gpu_index);
@@ -322,10 +350,19 @@ __host__ int host_gemm_keyswitch_lwe_ciphertext_vector(
                      lwe_dimension_in, "state_init", prefix, stream,
                      gpu_index);*/

-  tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
-      num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0, ksk,
-      stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
-  check_cuda_error(cudaGetLastError());
+  if (uses_trivial_indices) {
+    tgemm<Torus><<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+        num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0, ksk,
+        stride_KSK_buffer, lwe_array_out, lwe_dimension_out + 1);
+    check_cuda_error(cudaGetLastError());
+  } else {
+    tgemm_with_indices<Torus>
+        <<<grid_gemm, threads_gemm, shared_mem_size, stream>>>(
+            num_samples, (lwe_dimension_out + 1), lwe_dimension_in, d_mem_0,
+            lwe_input_indices, ksk, stride_KSK_buffer, lwe_array_out,
+            lwe_dimension_out + 1, lwe_output_indices);
+    check_cuda_error(cudaGetLastError());
+  }

   /* dump_2d_gpu_to_file(lwe_array_out, num_samples, lwe_dimension_out + 1,
      "tgemm0", prefix, stream, gpu_index);*/
@@ -400,35 +437,35 @@ void execute_keyswitch_async(CudaStreams streams,
     Torus *current_lwe_input_indexes =
         get_variant_element(lwe_input_indexes, i);

-    if (uses_trivial_indices && num_samples >= 19) {
+    if (num_samples >= 144) {
       // Compute Keyswitch
-      /* Torus *dup_out = (Torus *)cuda_malloc_async(
-          num_samples_on_gpu * (lwe_dimension_out + 1) * sizeof(Torus),
-          streams.stream(i), streams.gpu_index(i));
-      uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
-          lwe_dimension_in, lwe_dimension_out, num_samples_on_gpu);
-      Torus *tmp_buf = (Torus *)cuda_malloc_async(
-          buffer_size, streams.stream(i), streams.gpu_index(i));*/
+      /* Torus *dup_out = (Torus *)cuda_malloc_async(
+         num_samples_on_gpu * (lwe_dimension_out + 1) * sizeof(Torus),
+         streams.stream(i), streams.gpu_index(i));
+      uint64_t buffer_size = scratch_cuda_keyswitch_size<Torus>(
+          lwe_dimension_in, lwe_dimension_out, num_samples_on_gpu);
+      Torus *tmp_buf = (Torus *)cuda_malloc_async(
+          buffer_size, streams.stream(i), streams.gpu_index(i));*/

       host_gemm_keyswitch_lwe_ciphertext_vector<Torus>(
           streams.stream(i), streams.gpu_index(i), current_lwe_array_out,
           current_lwe_output_indexes, current_lwe_array_in,
           current_lwe_input_indexes, ksks[i], lwe_dimension_in,
           lwe_dimension_out, base_log, level_count, num_samples_on_gpu,
-          fp_tmp_buffer[i]);
+          fp_tmp_buffer[i], uses_trivial_indices);

       // Compute Keyswitch
-      /* host_keyswitch_lwe_ciphertext_vector<Torus>(
-          streams.stream(i), streams.gpu_index(i), dup_out,
-          current_lwe_output_indexes, current_lwe_array_in,
-          current_lwe_input_indexes, ksks[i], lwe_dimension_in,
-          lwe_dimension_out, base_log, level_count, num_samples_on_gpu);*/
-
-      /* compare_2d_arrays(dup_out, current_lwe_array_out, num_samples_on_gpu,
-         lwe_dimension_out + 1, streams.stream(i),
-         streams.gpu_index(i));*/
-      /* cuda_drop_async(dup_out, streams.stream(i), streams.gpu_index(i));
-         cuda_drop_async(tmp_buf, streams.stream(i), streams.gpu_index(i));*/
+      /* host_keyswitch_lwe_ciphertext_vector<Torus>(
+          streams.stream(i), streams.gpu_index(i), dup_out,
+          current_lwe_output_indexes, current_lwe_array_in,
+          current_lwe_input_indexes, ksks[i], lwe_dimension_in,
+          lwe_dimension_out, base_log, level_count, num_samples_on_gpu);*/
+
+      /* compare_2d_arrays(dup_out, current_lwe_array_out,
+         num_samples_on_gpu, lwe_dimension_out + 1, streams.stream(i),
+         streams.gpu_index(i));*/
+      /* cuda_drop_async(dup_out, streams.stream(i), streams.gpu_index(i));
+         cuda_drop_async(tmp_buf, streams.stream(i), streams.gpu_index(i));*/
       ;
     } else {
       // Compute Keyswitch

backends/tfhe-cuda-backend/cuda/src/linearalgebra/multiplication.cuh

Lines changed: 102 additions & 0 deletions

@@ -192,4 +192,106 @@ __global__ void tgemm(uint M, uint N, uint K, const Torus *A, const Torus *B,
   }
 }

+// Multiply matrices A, B of size (M, K), (K, N) respectively,
+// with K as the inner dimension.
+//
+// A block of threads processes blocks of size (BLOCK_SIZE_GEMM,
+// BLOCK_SIZE_GEMM), splitting them into multiple tiles:
+// (BLOCK_SIZE_GEMM, THREADS_GEMM)-shaped tiles of values from A, and
+// (THREADS_GEMM, BLOCK_SIZE_GEMM)-shaped tiles of values from B.
+//
+// This code is adapted from the 1D block-tiling kernel in
+// https://github.com/siboehm/SGEMM_CUDA, generalized to arbitrary
+// matrix dimensions.
+template <typename Torus>
+__global__ void tgemm_with_indices(uint M, uint N, uint K, const Torus *A,
+                                   const Torus *__restrict__ A_indices,
+                                   const Torus *B, uint stride_B, Torus *C,
+                                   uint stride_C,
+                                   const Torus *__restrict__ C_indices) {
+
+  const int BM = BLOCK_SIZE_GEMM;
+  const int BN = BLOCK_SIZE_GEMM;
+  const int BK = THREADS_GEMM;
+  const int TM = THREADS_GEMM;
+
+  const uint cRow = blockIdx.y;
+  const uint cCol = blockIdx.x;
+
+  const int threadCol = threadIdx.x % BN;
+  const int threadRow = threadIdx.x / BN;
+
+  // Allocate space for the current block tile in shared memory
+  __shared__ Torus As[BM * BK];
+  __shared__ Torus Bs[BK * BN];
+
+  // Initialize the pointer to the input block from B; tiles from this
+  // block are loaded to shared memory. Rows of A are fetched through
+  // A_indices instead of advancing a pointer.
+  B += cCol * BN;
+
+  // Each thread will handle multiple sub-blocks
+  const uint innerColA = threadIdx.x % BK;
+  const uint innerRowA = threadIdx.x / BK;
+  const uint innerColB = threadIdx.x % BN;
+  const uint innerRowB = threadIdx.x / BN;
+
+  // allocate thread-local cache for results in registerfile
+  Torus threadResults[TM] = {0};
+
+  auto row_A = cRow * BM + innerRowA;
+  auto col_B = cCol * BN + innerColB;
+
+  // For each thread, loop over block tiles
+  for (uint bkIdx = 0; bkIdx < K; bkIdx += BK) {
+    auto col_A = bkIdx + innerColA;
+    auto row_B = bkIdx + innerRowB;
+
+    if (row_A < M && col_A < K) {
+      As[innerRowA * BK + innerColA] = A[A_indices[row_A] * K + col_A];
+    } else {
+      As[innerRowA * BK + innerColA] = 0;
+    }
+
+    if (col_B < N && row_B < K) {
+      Bs[innerRowB * BN + innerColB] = B[innerRowB * stride_B + innerColB];
+    } else {
+      Bs[innerRowB * BN + innerColB] = 0;
+    }
+    __syncthreads();
+
+    // Advance blocktile for the next iteration of this loop
+    B += BK * stride_B;
+
+    // calculate per-thread results
+    for (uint dotIdx = 0; dotIdx < BK; ++dotIdx) {
+      // we make the dotproduct loop the outside loop, which facilitates
+      // reuse of the Bs entry, which we can cache in a tmp var.
+      Torus tmp = Bs[dotIdx * BN + threadCol];
+      for (uint resIdx = 0; resIdx < TM; ++resIdx) {
+        threadResults[resIdx] +=
+            As[(threadRow * TM + resIdx) * BK + dotIdx] * tmp;
+      }
+    }
+    __syncthreads();
+  }
+
+  // Initialize the pointer to the output block of size (BLOCK_SIZE_GEMM,
+  // BLOCK_SIZE_GEMM)
+  // C += cRow * BM * stride_C + cCol * BN;
+
+  // write out the results, routing each output row through C_indices
+  for (uint resIdx = 0; resIdx < TM; ++resIdx) {
+    int outRow = cRow * BM + threadRow * TM + resIdx;
+    int outCol = cCol * BN + threadCol;
+
+    if (outRow >= M)
+      continue;
+    if (outCol >= N)
+      continue;
+
+    C[C_indices[outRow] * stride_C + cCol * BN + threadCol] +=
+        threadResults[resIdx];
+  }
+}
+
 #endif // CUDA_MULT_H
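
The kernel's tiling implies blockDim.x == BLOCK_SIZE_GEMM * BLOCK_SIZE_GEMM / THREADS_GEMM. A hedged launch sketch, with assumed values for BLOCK_SIZE_GEMM and THREADS_GEMM (the backend defines the real constants elsewhere), matching the SGEMM_CUDA 1D block-tiling layout the kernel is adapted from:

// Assumptions: BLOCK_SIZE_GEMM = 64 (BM = BN), THREADS_GEMM = 8 (BK = TM);
// each thread accumulates TM partial results of one output column.
constexpr uint32_t BM = 64, BN = 64, TM = 8;

uint32_t M = num_samples;           // rows of A: one per LWE sample
uint32_t N = lwe_dimension_out + 1; // columns of B: output mask + body
uint32_t K = lwe_dimension_in;      // inner (reduction) dimension

dim3 grid_gemm((N + BN - 1) / BN, (M + BM - 1) / BM);
dim3 threads_gemm(BM * BN / TM); // 512 threads per block

tgemm_with_indices<uint64_t><<<grid_gemm, threads_gemm, 0, stream>>>(
    M, N, K, d_mem_0, lwe_input_indices, ksk, stride_KSK_buffer,
    lwe_array_out, lwe_dimension_out + 1, lwe_output_indices);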

backends/tfhe-cuda-backend/src/bindings.rs

Lines changed: 1 addition & 0 deletions

@@ -1933,6 +1933,7 @@ unsafe extern "C" {
         level_count: u32,
         num_samples: u32,
         ksk_tmp_buffer: *mut i8,
+        uses_trivial_indexes: bool,
     );
 }
 unsafe extern "C" {
