-
Notifications
You must be signed in to change notification settings - Fork 1k
Expand file tree
/
Copy pathdecode_fixed.cu
More file actions
1376 lines (1207 loc) · 59.5 KB
/
decode_fixed.cu
File metadata and controls
1376 lines (1207 loc) · 59.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/*
* SPDX-FileCopyrightText: Copyright (c) 2024-2025, NVIDIA CORPORATION.
* SPDX-License-Identifier: Apache-2.0
*/
#include "page_data.cuh"
#include "page_decode.cuh"
#include "page_string_utils.cuh"
#include "parquet_gpu.hpp"
#include "rle_stream.cuh"
#include <cudf/detail/utilities/cuda.cuh>
#include <cooperative_groups.h>
#include <cuda/std/bit>
#include <cuda/std/iterator>
namespace cudf::io::parquet::detail {
namespace {
// Unlike cub's algorithm, this provides warp-wide and block-wide results simultaneously.
// Also, this provides the ability to compute warp_bits & lane_mask manually, which we need for
// lists.
struct block_scan_results {
  uint32_t warp_bits;             // ballot bits for this thread's warp (1 bit per lane)
  int thread_count_within_warp;   // exclusive count of set bits before this lane within the warp
  int warp_count;                 // total set bits within this thread's warp
  int thread_count_within_block;  // exclusive count of set bits before this thread within the block
  int block_count;                // total set bits across the whole block
};

// Scratch space for scan_block_exclusive_sum: one int per warp, used by lane 0 of each warp to
// publish its warp's count so all threads can combine them into block-wide results.
template <int decode_block_size>
using block_scan_temp_storage = int[decode_block_size / cudf::detail::warp_size];
// Convenience overload: derives warp/lane identifiers and the lane mask from threadIdx, ballots
// the per-thread bit, then forwards to the manual-mask overload below.
// Similar to CUB, must __syncthreads() after calling if reusing temp_storage.
template <int decode_block_size>
__device__ inline static void scan_block_exclusive_sum(
  int thread_bit,
  block_scan_results& results,
  block_scan_temp_storage<decode_block_size>& temp_storage)
{
  auto const tid  = static_cast<int>(threadIdx.x);
  auto const warp = tid / cudf::detail::warp_size;
  auto const lane = tid % cudf::detail::warp_size;
  // Mask selecting all lanes strictly below this one within the warp
  uint32_t const preceding_lanes = (uint32_t{1} << lane) - 1;
  scan_block_exclusive_sum<decode_block_size>(
    ballot(thread_bit), lane, warp, preceding_lanes, results, temp_storage);
}
// Computes warp-wide and block-wide exclusive sums of per-thread bits in a single pass.
// warp_bits/lane_mask are taken as parameters so callers (e.g. the list path) can supply
// manually-constructed masks. Similar to CUB, must __syncthreads() after calling if reusing
// temp_storage.
template <int decode_block_size>
__device__ static void scan_block_exclusive_sum(
  uint32_t warp_bits,
  int warp_lane,
  int warp_index,
  uint32_t lane_mask,
  block_scan_results& results,
  block_scan_temp_storage<decode_block_size>& temp_storage)
{
  constexpr int num_warps = decode_block_size / cudf::detail::warp_size;

  // Warp-local tallies fall straight out of the ballot bits
  results.warp_bits                = warp_bits;
  results.warp_count               = __popc(warp_bits);
  results.thread_count_within_warp = __popc(warp_bits & lane_mask);

  // Lane 0 of each warp publishes its warp's total
  if (warp_lane == 0) { temp_storage[warp_index] = results.warp_count; }
  __syncthreads();  // make the per-warp totals visible block-wide

  // Fold the published totals: every warp contributes to the block total, and warps preceding
  // ours additionally shift our block-relative exclusive offset.
  int block_total  = 0;
  int block_offset = results.thread_count_within_warp;
  for (int w = 0; w < num_warps; ++w) {
    int const count = temp_storage[w];
    block_total += count;
    if (w < warp_index) { block_offset += count; }
  }
  results.block_count               = block_total;
  results.thread_count_within_block = block_offset;
}
/**
 * @brief Decode a range of fixed-width values from the page into the output column buffer
 *
 * Each loop iteration processes a block-sized batch; thread t handles value position pos + t.
 * Dispatches on physical/logical type to the appropriate typed read helper.
 *
 * @tparam block_size Size of the thread block
 * @tparam has_lists_t Whether the column contains list data (adds skipped_leaf_values to src_pos)
 * @tparam copy_mode_t DIRECT computes the output index from position; otherwise it is read from
 *         the sb->nz_idx rolling buffer populated by the validity pass
 * @tparam state_buf State buffer type
 *
 * @param s Pointer to page state
 * @param sb Pointer to state buffer
 * @param start First value position (inclusive) to decode
 * @param end Last value position (exclusive) to decode
 * @param t Thread index
 */
template <int block_size, bool has_lists_t, copy_mode copy_mode_t, typename state_buf>
__device__ void decode_fixed_width_values(
  page_state_s* s, state_buf* const sb, int start, int end, int t)
{
  constexpr int num_warps      = block_size / cudf::detail::warp_size;
  constexpr int max_batch_size = num_warps * cudf::detail::warp_size;

  // nesting level that is storing actual leaf values
  int const leaf_level_index = s->col.max_nesting_depth - 1;
  auto const data_out        = s->nesting_info[leaf_level_index].data_out;

  Type const dtype         = s->col.physical_type;
  uint32_t const dtype_len = s->dtype_len;

  int const skipped_leaf_values = s->page.skipped_leaf_values;

  // decode values
  int pos = start;
  while (pos < end) {
    int const batch_size = min(max_batch_size, end - pos);
    int const target_pos = pos + batch_size;
    int const thread_pos = pos + t;

    // Index from value buffer (doesn't include nulls) to final array (has gaps for nulls)
    int const dst_pos = [&]() {
      if constexpr (copy_mode_t == copy_mode::DIRECT) {
        return thread_pos - s->first_row;
      } else {
        int dst_pos = sb->nz_idx[rolling_index<state_buf::nz_buf_size>(thread_pos)];
        if constexpr (!has_lists_t) { dst_pos -= s->first_row; }
        return dst_pos;
      }
    }();

    // target_pos will always be properly bounded by num_rows, but dst_pos may be negative (values
    // before first_row) in the flat hierarchy case.
    if (thread_pos < target_pos && dst_pos >= 0) {
      // nesting level that is storing actual leaf values
      // src_pos represents the logical row position we want to read from. But in the case of
      // nested hierarchies (lists), there is no 1:1 mapping of rows to values. So src_pos
      // has to take into account the # of values we have to skip in the page to get to the
      // desired logical row. For flat hierarchies, skipped_leaf_values will always be 0.
      int const src_pos = [&]() {
        if constexpr (has_lists_t) { return thread_pos + skipped_leaf_values; }
        return thread_pos;
      }();

      // destination address for this thread's value
      void* const dst = data_out + (static_cast<size_t>(dst_pos) * dtype_len);

      // decimals first: INT32/INT64-backed decimals use the fast path; other physical types
      // (fixed-len byte arrays) are widened into the smallest integer that fits dtype_len_in
      if (s->col.logical_type.has_value() && s->col.logical_type->type == LogicalType::DECIMAL) {
        switch (dtype) {
          case Type::INT32:
            read_fixed_width_value_fast(s, sb, src_pos, static_cast<uint32_t*>(dst));
            break;
          case Type::INT64:
            read_fixed_width_value_fast(s, sb, src_pos, static_cast<uint2*>(dst));
            break;
          default:
            if (s->dtype_len_in <= sizeof(int32_t)) {
              read_fixed_width_byte_array_as_int(s, sb, src_pos, static_cast<int32_t*>(dst));
            } else if (s->dtype_len_in <= sizeof(int64_t)) {
              read_fixed_width_byte_array_as_int(s, sb, src_pos, static_cast<int64_t*>(dst));
            } else {
              read_fixed_width_byte_array_as_int(s, sb, src_pos, static_cast<__int128_t*>(dst));
            }
            break;
        }
      } else if (dtype == Type::BOOLEAN) {
        read_boolean(sb, src_pos, static_cast<uint8_t*>(dst));
      } else if (dtype == Type::INT96) {
        read_int96_timestamp(s, sb, src_pos, static_cast<int64_t*>(dst));
      } else if (dtype_len == 8) {
        if (s->dtype_len_in == 4) {
          // Reading INT32 TIME_MILLIS into 64-bit DURATION_MILLISECONDS
          // TIME_MILLIS is the only duration type stored as int32:
          // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#deprecated-time-convertedtype
          auto const dst_ptr = static_cast<uint32_t*>(dst);
          read_fixed_width_value_fast(s, sb, src_pos, dst_ptr);
          // zero out most significant bytes
          cuda::std::memset(dst_ptr + 1, 0, sizeof(int32_t));
        } else if (s->ts_scale) {
          read_int64_timestamp(s, sb, src_pos, static_cast<int64_t*>(dst));
        } else {
          read_fixed_width_value_fast(s, sb, src_pos, static_cast<uint2*>(dst));
        }
      } else if (dtype_len == 4) {
        read_fixed_width_value_fast(s, sb, src_pos, static_cast<uint32_t*>(dst));
      } else {
        // arbitrary-width fallback (e.g. FLBA-backed types with no special handling above)
        read_nbyte_fixed_width_value(s, sb, src_pos, static_cast<uint8_t*>(dst), dtype_len);
      }
    }

    pos += batch_size;
  }
}
/**
 * @brief Decode a range of BYTE_STREAM_SPLIT-encoded fixed-width values into the output buffer
 *
 * Mirrors decode_fixed_width_values, but each value's bytes are scattered across num_values-sized
 * streams in the page data, so the gpuOutputByteStreamSplit helpers gather them back together.
 *
 * @tparam block_size Size of the thread block
 * @tparam has_lists_t Whether the column contains list data (adds skipped_leaf_values to src_pos)
 * @tparam copy_mode_t DIRECT computes the output index from position; otherwise it is read from
 *         the sb->nz_idx rolling buffer populated by the validity pass
 * @tparam state_buf State buffer type
 *
 * @param s Pointer to page state
 * @param sb Pointer to state buffer
 * @param start First value position (inclusive) to decode
 * @param end Last value position (exclusive) to decode
 * @param t Thread index
 */
template <int block_size, bool has_lists_t, copy_mode copy_mode_t, typename state_buf>
__device__ inline void decode_fixed_width_split_values(
  page_state_s* s, state_buf* const sb, int start, int end, int t)
{
  using cudf::detail::warp_size;
  constexpr int num_warps      = block_size / warp_size;
  constexpr int max_batch_size = num_warps * warp_size;

  // nesting level that is storing actual leaf values
  int const leaf_level_index = s->col.max_nesting_depth - 1;
  auto const data_out        = s->nesting_info[leaf_level_index].data_out;

  Type const dtype = s->col.physical_type;
  // stream stride: # of encoded values in the page (each value contributes one byte per stream)
  auto const data_len   = cuda::std::distance(s->data_start, s->data_end);
  auto const num_values = data_len / s->dtype_len_in;

  int const skipped_leaf_values = s->page.skipped_leaf_values;

  // decode values
  int pos = start;
  while (pos < end) {
    int const batch_size = min(max_batch_size, end - pos);
    int const target_pos = pos + batch_size;
    int const thread_pos = pos + t;

    // Index from value buffer (doesn't include nulls) to final array (has gaps for nulls)
    int const dst_pos = [&]() {
      if constexpr (copy_mode_t == copy_mode::DIRECT) {
        return thread_pos - s->first_row;
      } else {
        int dst_pos = sb->nz_idx[rolling_index<state_buf::nz_buf_size>(thread_pos)];
        if constexpr (!has_lists_t) { dst_pos -= s->first_row; }
        return dst_pos;
      }
    }();

    // target_pos will always be properly bounded by num_rows, but dst_pos may be negative (values
    // before first_row) in the flat hierarchy case.
    if (thread_pos < target_pos && dst_pos >= 0) {
      // src_pos represents the logical row position we want to read from. But in the case of
      // nested hierarchies (lists), there is no 1:1 mapping of rows to values. So src_pos
      // has to take into account the # of values we have to skip in the page to get to the
      // desired logical row. For flat hierarchies, skipped_leaf_values will always be 0.
      int const src_pos = [&]() {
        if constexpr (has_lists_t) {
          return thread_pos + skipped_leaf_values;
        } else {
          return thread_pos;
        }
      }();

      uint32_t const dtype_len = s->dtype_len;
      // src points at this value's byte in the first stream; helpers step by num_values
      uint8_t const* const src = s->data_start + src_pos;
      uint8_t* const dst       = data_out + static_cast<size_t>(dst_pos) * dtype_len;
      auto const is_decimal =
        s->col.logical_type.has_value() and s->col.logical_type->type == LogicalType::DECIMAL;

      // Note: non-decimal FIXED_LEN_BYTE_ARRAY will be handled in the string reader
      if (is_decimal) {
        switch (dtype) {
          case Type::INT32: gpuOutputByteStreamSplit<int32_t>(dst, src, num_values); break;
          case Type::INT64: gpuOutputByteStreamSplit<int64_t>(dst, src, num_values); break;
          case Type::FIXED_LEN_BYTE_ARRAY:
            // widen to the smallest integer type that fits the stored precision
            if (s->dtype_len_in <= sizeof(int32_t)) {
              gpuOutputSplitFixedLenByteArrayAsInt(
                reinterpret_cast<int32_t*>(dst), src, num_values, s->dtype_len_in);
              break;
            } else if (s->dtype_len_in <= sizeof(int64_t)) {
              gpuOutputSplitFixedLenByteArrayAsInt(
                reinterpret_cast<int64_t*>(dst), src, num_values, s->dtype_len_in);
              break;
            } else if (s->dtype_len_in <= sizeof(__int128_t)) {
              gpuOutputSplitFixedLenByteArrayAsInt(
                reinterpret_cast<__int128_t*>(dst), src, num_values, s->dtype_len_in);
              break;
            }
            // unsupported decimal precision
            [[fallthrough]];
          default: s->set_error_code(decode_error::UNSUPPORTED_ENCODING);
        }
      } else if (dtype_len == 8) {
        if (s->dtype_len_in == 4) {
          // Reading INT32 TIME_MILLIS into 64-bit DURATION_MILLISECONDS
          // TIME_MILLIS is the only duration type stored as int32:
          // https://github.com/apache/parquet-format/blob/master/LogicalTypes.md#deprecated-time-convertedtype
          gpuOutputByteStreamSplit<int32_t>(dst, src, num_values);
          // zero out most significant bytes
          cuda::std::memset(dst + sizeof(int32_t), 0, sizeof(int32_t));
        } else if (s->ts_scale) {
          gpuOutputSplitInt64Timestamp(
            reinterpret_cast<int64_t*>(dst), src, num_values, s->ts_scale);
        } else {
          gpuOutputByteStreamSplit<int64_t>(dst, src, num_values);
        }
      } else if (dtype_len == 4) {
        gpuOutputByteStreamSplit<int32_t>(dst, src, num_values);
      } else {
        s->set_error_code(decode_error::UNSUPPORTED_ENCODING);
      }
    }

    pos += batch_size;
  }
}
/**
* @brief Update validity and row indices for nested types
*
* @tparam decode_block_size Size of the thread block
* @tparam level_t Definition level type
* @tparam state_buf State buffer type
*
* @param target_value_count The target value count to process
* @param s Pointer to page state
* @param sb Pointer to state buffer
* @param def Pointer to the definition levels
* @param t Thread index
*
* @return Maximum depth valid count after processing
*/
template <int decode_block_size, typename level_t, typename state_buf>
__device__ int update_validity_and_row_indices_nested(
  int32_t target_value_count, page_state_s* s, state_buf* sb, level_t const* const def, int t)
{
  constexpr int num_warps      = decode_block_size / cudf::detail::warp_size;
  constexpr int max_batch_size = num_warps * cudf::detail::warp_size;

  // how many (input) values we've processed in the page so far
  int value_count = s->input_value_count;

  // cap by last row so that we don't process any rows past what we want to output.
  int const first_row                 = s->first_row;
  int const last_row                  = first_row + s->num_rows;
  int const capped_target_value_count = min(target_value_count, last_row);

  int const max_depth       = s->col.max_nesting_depth - 1;
  auto& max_depth_ni        = s->nesting_info[max_depth];
  int max_depth_valid_count = max_depth_ni.valid_count;

  __syncthreads();

  // process values in block-sized batches; thread t owns value (value_count + t) in each batch
  while (value_count < capped_target_value_count) {
    int const batch_size = min(max_batch_size, capped_target_value_count - value_count);

    // definition level: -1 for threads past the batch end, decoded level if def levels exist,
    // otherwise 1 (no nulls possible)
    int const d = [&]() {
      if (t >= batch_size) {
        return -1;
      } else if (def) {
        return static_cast<int>(def[rolling_index<state_buf::nz_buf_size>(value_count + t)]);
      }
      return 1;
    }();

    int const thread_value_count = t;
    int const block_value_count  = batch_size;

    // compute our row index, whether we're in row bounds, and validity
    // This ASSUMES that s->row_index_lower_bound is always -1!
    // Its purpose is to handle rows than span page boundaries, which only happen for lists.
    int const row_index            = thread_value_count + value_count;
    int const in_row_bounds        = (row_index < last_row);
    bool const in_write_row_bounds = in_row_bounds && (row_index >= first_row);
    uint32_t const in_write_row_bounds_mask = ballot(in_write_row_bounds);
    // NOTE: The below CANNOT be std::countr_zero(), because for zero start must be 0 not 32
    int const write_start = __ffs(in_write_row_bounds_mask) - 1;  // first bit in the warp to store

    // iterate by depth
    for (int d_idx = 0; d_idx <= max_depth; d_idx++) {
      auto& ni = s->nesting_info[d_idx];

      // value is present at this depth if its def level reaches the depth's max def level
      int const is_valid = ((d >= ni.max_def_level) && in_row_bounds) ? 1 : 0;

      // thread and block validity count
      using block_scan = cub::BlockScan<int, decode_block_size>;
      __shared__ typename block_scan::TempStorage scan_storage;
      int thread_valid_count, block_valid_count;
      block_scan(scan_storage).ExclusiveSum(is_valid, thread_valid_count, block_valid_count);

      // validity is processed per-warp
      //
      // nested schemas always read and write to the same bounds (that is, read and write
      // positions are already pre-bounded by first_row/num_rows). flat schemas will start reading
      // at the first value, even if that is before first_row, because we cannot trivially jump to
      // the correct position to start reading. since we are about to write the validity vector
      // here we need to adjust our computed mask to take into account the write row bounds.
      int warp_null_count = 0;
      if (ni.valid_map != nullptr) {
        uint32_t const warp_validity_mask = ballot(is_valid);
        // lane 0 from each warp writes out validity
        if ((write_start >= 0) && ((t % cudf::detail::warp_size) == 0)) {
          int const valid_map_offset = ni.valid_map_offset;
          int const vindex = value_count + thread_value_count;  // absolute input value index
          int const bit_offset = (valid_map_offset + vindex + write_start) -
                                 first_row;  // absolute bit offset into the output validity map
          int const write_end = cudf::detail::warp_size -
                                __clz(in_write_row_bounds_mask);  // last bit in the warp to store
          int const bit_count = write_end - write_start;
          warp_null_count     = bit_count - __popc(warp_validity_mask >> write_start);
          store_validity(bit_offset, ni.valid_map, warp_validity_mask >> write_start, bit_count);
        }
      }

      // sum null counts. we have to do it this way instead of just incrementing by (value_count -
      // valid_count) because valid_count also includes rows that potentially start before our row
      // bounds. if we could come up with a way to clean that up, we could remove this and just
      // compute it directly at the end of the kernel.
      size_type const block_null_count =
        cudf::detail::single_lane_block_sum_reduce<decode_block_size, 0>(warp_null_count);
      if (t == 0) { ni.null_count += block_null_count; }

      // if this is valid and we're at the leaf, output dst_pos
      if (d_idx == max_depth) {
        if (is_valid) {
          // map the compacted (null-free) source index to its gap-including destination index
          int const dst_pos = value_count + thread_value_count;
          int const src_pos = max_depth_valid_count + thread_valid_count;
          sb->nz_idx[rolling_index<state_buf::nz_buf_size>(src_pos)] = dst_pos;
        }
        // update stuff
        max_depth_valid_count += block_valid_count;
      }
    }  // end depth loop

    value_count += block_value_count;
  }  // end loop

  if (t == 0) {
    // update valid value count for decoding and total # of values we've processed
    max_depth_ni.valid_count = max_depth_valid_count;
    max_depth_ni.value_count = value_count;  // Needed AT LEAST for strings!
    s->nz_count              = max_depth_valid_count;
    s->input_value_count     = value_count;
    s->input_row_count       = value_count;
  }

  return max_depth_valid_count;
}
/**
* @brief Update validity and row indices for flat types
*
* @tparam decode_block_size Size of the thread block
* @tparam level_t Definition level type
* @tparam state_buf State buffer type
*
* @param target_value_count The target value count to process
* @param s Pointer to page state
* @param sb Pointer to state buffer
* @param def Pointer to the definition levels
* @param t Thread index
*
* @return Maximum depth valid count after processing
*/
template <int decode_block_size, typename level_t, typename state_buf>
__device__ int update_validity_and_row_indices_flat(
  int32_t target_value_count, page_state_s* s, state_buf* sb, level_t const* const def, int t)
{
  constexpr int num_warps      = decode_block_size / cudf::detail::warp_size;
  constexpr int max_batch_size = num_warps * cudf::detail::warp_size;

  // flat schemas only have a single nesting level
  auto& ni = s->nesting_info[0];

  // how many (input) values we've processed in the page so far
  int value_count = s->input_value_count;
  int valid_count = ni.valid_count;

  // cap by last row so that we don't process any rows past what we want to output.
  int const first_row                 = s->first_row;
  int const last_row                  = first_row + s->num_rows;
  int const capped_target_value_count = min(target_value_count, last_row);

  int const valid_map_offset = ni.valid_map_offset;

  __syncthreads();

  // process values in block-sized batches; thread t owns value (value_count + t) in each batch
  while (value_count < capped_target_value_count) {
    int const batch_size = min(max_batch_size, capped_target_value_count - value_count);

    int const thread_value_count = t;
    int const block_value_count  = batch_size;

    // compute our row index, whether we're in row bounds, and validity
    // This ASSUMES that s->row_index_lower_bound is always -1!
    // Its purpose is to handle rows than span page boundaries, which only happen for lists.
    int const row_index     = thread_value_count + value_count;
    int const in_row_bounds = (row_index < last_row);

    // use definition level & row bounds to determine if is valid
    int const is_valid = [&]() {
      if (t >= batch_size) {
        return 0;
      } else if (def) {
        int const def_level =
          static_cast<int>(def[rolling_index<state_buf::nz_buf_size>(value_count + t)]);
        // flat schema: any non-zero def level means the value is present
        return ((def_level > 0) && in_row_bounds) ? 1 : 0;
      }
      return in_row_bounds;
    }();

    // thread and block validity count
    using block_scan = cub::BlockScan<int, decode_block_size>;
    __shared__ typename block_scan::TempStorage scan_storage;
    int thread_valid_count, block_valid_count;
    block_scan(scan_storage).ExclusiveSum(is_valid, thread_valid_count, block_valid_count);
    uint32_t const warp_validity_mask = ballot(is_valid);

    // validity is processed per-warp
    //
    // nested schemas always read and write to the same bounds (that is, read and write
    // positions are already pre-bounded by first_row/num_rows). flat schemas will start reading
    // at the first value, even if that is before first_row, because we cannot trivially jump to
    // the correct position to start reading. since we are about to write the validity vector
    // here we need to adjust our computed mask to take into account the write row bounds.
    bool const in_write_row_bounds     = in_row_bounds && (row_index >= first_row);
    int const in_write_row_bounds_mask = ballot(in_write_row_bounds);
    // NOTE: The below CANNOT be std::countr_zero(), because for zero start must be 0 not 32
    int const write_start = __ffs(in_write_row_bounds_mask) - 1;  // first bit in the warp to store

    int warp_null_count = 0;
    // lane 0 from each warp writes out validity
    if ((write_start >= 0) && ((t % cudf::detail::warp_size) == 0)) {
      int const vindex = value_count + thread_value_count;  // absolute input value index
      int const bit_offset = (valid_map_offset + vindex + write_start) -
                             first_row;  // absolute bit offset into the output validity map
      int const write_end =
        cudf::detail::warp_size - __clz(in_write_row_bounds_mask);  // last bit in the warp to store
      int const bit_count = write_end - write_start;
      warp_null_count     = bit_count - __popc(warp_validity_mask >> write_start);
      store_validity(bit_offset, ni.valid_map, warp_validity_mask >> write_start, bit_count);
    }

    // sum null counts. we have to do it this way instead of just incrementing by (value_count -
    // valid_count) because valid_count also includes rows that potentially start before our row
    // bounds. if we could come up with a way to clean that up, we could remove this and just
    // compute it directly at the end of the kernel.
    size_type const block_null_count =
      cudf::detail::single_lane_block_sum_reduce<decode_block_size, 0>(warp_null_count);
    if (t == 0) { ni.null_count += block_null_count; }

    // output offset
    if (is_valid) {
      // map the compacted (null-free) source index to its gap-including destination index
      int const dst_pos = value_count + thread_value_count;
      int const src_pos = valid_count + thread_valid_count;
      sb->nz_idx[rolling_index<state_buf::nz_buf_size>(src_pos)] = dst_pos;
    }

    // update stuff
    value_count += block_value_count;
    valid_count += block_valid_count;
  }

  if (t == 0) {
    // update valid value count for decoding and total # of values we've processed
    ni.valid_count       = valid_count;
    ni.value_count       = value_count;
    s->nz_count          = valid_count;
    s->input_value_count = value_count;
    s->input_row_count   = value_count;
  }

  return valid_count;
}
/**
* @brief Update validity and row indices for list types
*
* @tparam decode_block_size Size of the thread block
* @tparam level_t Definition level type
* @tparam state_buf State buffer type
*
* @param target_value_count The target value count to process
* @param s Pointer to page state
* @param sb Pointer to state buffer
* @param def Pointer to the definition levels
* @param t Thread index
*
* @return Maximum depth valid count after processing
*/
template <int decode_block_size, bool nullable, typename level_t, typename state_buf>
__device__ int update_validity_and_row_indices_lists(int32_t target_value_count,
page_state_s* s,
state_buf* sb,
level_t const* const def,
level_t const* const rep,
int t)
{
constexpr int num_warps = decode_block_size / cudf::detail::warp_size;
constexpr int max_batch_size = num_warps * cudf::detail::warp_size;
// how many (input) values we've processed in the page so far, prior to this loop iteration
int value_count = s->input_value_count;
// how many rows we've processed in the page so far
int input_row_count = s->input_row_count;
// cap by last row so that we don't process any rows past what we want to output.
int const first_row = s->first_row;
int const last_row = first_row + s->num_rows;
int const row_index_lower_bound = s->row_index_lower_bound;
int const max_depth = s->col.max_nesting_depth - 1;
int max_depth_valid_count = s->nesting_info[max_depth].valid_count;
int const warp_index = t / cudf::detail::warp_size;
int const warp_lane = t % cudf::detail::warp_size;
bool const is_first_lane = (warp_lane == 0);
__syncthreads();
__shared__ block_scan_temp_storage<decode_block_size> temp_storage;
while (value_count < target_value_count) {
bool const within_batch = value_count + t < target_value_count;
// get definition level, use repetition level to get start/end depth
// different for each thread, as each thread has a different r/d
auto const [def_level, start_depth, end_depth] = [&]() {
if (!within_batch) { return cuda::std::make_tuple(-1, -1, -1); }
int const level_index = rolling_index<state_buf::nz_buf_size>(value_count + t);
int const rep_level = static_cast<int>(rep[level_index]);
int const start_depth = s->nesting_info[rep_level].start_depth;
if constexpr (!nullable) {
return cuda::std::make_tuple(-1, start_depth, max_depth);
} else {
if (def != nullptr) {
int const def_level = static_cast<int>(def[level_index]);
return cuda::std::make_tuple(
def_level, start_depth, s->nesting_info[def_level].end_depth);
} else {
return cuda::std::make_tuple(1, start_depth, max_depth);
}
}
}();
// Determine value count & row index
// track (page-relative) row index for the thread so we can compare against input bounds
// keep track of overall # of rows we've read.
int const is_new_row = start_depth == 0 ? 1 : 0;
int num_prior_new_rows, total_num_new_rows;
{
block_scan_results new_row_scan_results;
scan_block_exclusive_sum<decode_block_size>(is_new_row, new_row_scan_results, temp_storage);
__syncthreads();
num_prior_new_rows = new_row_scan_results.thread_count_within_block;
total_num_new_rows = new_row_scan_results.block_count;
}
int const row_index = input_row_count + ((num_prior_new_rows + is_new_row) - 1);
input_row_count += total_num_new_rows;
int const in_row_bounds = (row_index >= row_index_lower_bound) && (row_index < last_row);
// VALUE COUNT:
// in_nesting_bounds: if at a nesting level where we need to add value indices
// the bounds: from current rep to the rep AT the def depth
int in_nesting_bounds = ((0 >= start_depth && 0 <= end_depth) && in_row_bounds) ? 1 : 0;
int thread_value_count_within_warp, warp_value_count, thread_value_count, block_value_count;
{
block_scan_results value_count_scan_results;
scan_block_exclusive_sum<decode_block_size>(
in_nesting_bounds, value_count_scan_results, temp_storage);
__syncthreads();
thread_value_count_within_warp = value_count_scan_results.thread_count_within_warp;
warp_value_count = value_count_scan_results.warp_count;
thread_value_count = value_count_scan_results.thread_count_within_block;
block_value_count = value_count_scan_results.block_count;
}
// iterate by depth
for (int d_idx = 0; d_idx <= max_depth; d_idx++) {
auto& ni = s->nesting_info[d_idx];
// everything up to the max_def_level is a non-null value
int const is_valid = [&](int input_def_level) {
if constexpr (nullable) {
return ((input_def_level >= ni.max_def_level) && in_nesting_bounds) ? 1 : 0;
} else {
return in_nesting_bounds;
}
}(def_level);
// VALID COUNT:
// Not all values visited by this block will represent a value at this nesting level.
// the validity bit for thread t might actually represent output value t-6.
// the correct position for thread t's bit is thread_value_count.
uint32_t const warp_valid_mask = warp_reduce_or<cudf::detail::warp_size>(
static_cast<uint32_t>(is_valid) << thread_value_count_within_warp);
int thread_valid_count, block_valid_count;
{
auto thread_mask = (uint32_t(1) << thread_value_count_within_warp) - 1;
block_scan_results valid_count_scan_results;
scan_block_exclusive_sum<decode_block_size>(warp_valid_mask,
warp_lane,
warp_index,
thread_mask,
valid_count_scan_results,
temp_storage);
__syncthreads();
thread_valid_count = valid_count_scan_results.thread_count_within_block;
block_valid_count = valid_count_scan_results.block_count;
}
// compute warp and thread value counts for the -next- nesting level. we need to
// do this for lists so that we can emit an offset for the -current- nesting level.
// the offset for the current nesting level == current length of the next nesting level
int next_thread_value_count_within_warp = 0, next_warp_value_count = 0;
int next_thread_value_count = 0, next_block_value_count = 0;
int next_in_nesting_bounds = 0;
if (d_idx < max_depth) {
// NEXT DEPTH VALUE COUNT:
next_in_nesting_bounds =
((d_idx + 1 >= start_depth) && (d_idx + 1 <= end_depth) && in_row_bounds) ? 1 : 0;
{
block_scan_results next_value_count_scan_results;
scan_block_exclusive_sum<decode_block_size>(
next_in_nesting_bounds, next_value_count_scan_results, temp_storage);
__syncthreads();
next_thread_value_count_within_warp =
next_value_count_scan_results.thread_count_within_warp;
next_warp_value_count = next_value_count_scan_results.warp_count;
next_thread_value_count = next_value_count_scan_results.thread_count_within_block;
next_block_value_count = next_value_count_scan_results.block_count;
}
// STORE OFFSET TO THE LIST LOCATION
// if we're -not- at a leaf column and we're within nesting/row bounds
// and we have a valid data_out pointer, it implies this is a list column, so
// emit an offset.
if (in_nesting_bounds && ni.data_out != nullptr) {
const auto& next_ni = s->nesting_info[d_idx + 1];
int const idx = ni.value_count + thread_value_count;
cudf::size_type const ofs =
next_ni.value_count + next_thread_value_count + next_ni.page_start_value;
(reinterpret_cast<cudf::size_type*>(ni.data_out))[idx] = ofs;
}
}
// validity is processed per-warp (on lane 0's)
// thi is because when atomic writes are needed, they are 32-bit operations
//
// lists always read and write to the same bounds
// (that is, read and write positions are already pre-bounded by first_row/num_rows).
// since we are about to write the validity vector
// here we need to adjust our computed mask to take into account the write row bounds.
if constexpr (nullable) {
if (is_first_lane && (ni.valid_map != nullptr) && (warp_value_count > 0)) {
// absolute bit offset into the output validity map
// is cumulative sum of warp_value_count at the given nesting depth
// DON'T subtract by first_row: since it's lists it's not 1-row-per-value
int const bit_offset = ni.valid_map_offset + thread_value_count;
store_validity(bit_offset, ni.valid_map, warp_valid_mask, warp_value_count);
}
if (t == 0) { ni.null_count += block_value_count - block_valid_count; }
}
// if this is valid and we're at the leaf, output dst_pos
// Read value_count before the sync, so that when thread 0 modifies it we've already read its
// value
int const current_value_count = ni.value_count;
__syncthreads(); // guard against modification of ni.value_count below
if (d_idx == max_depth) {
if (is_valid) {
int const dst_pos = current_value_count + thread_value_count;
int const src_pos = max_depth_valid_count + thread_valid_count;
int const output_index = rolling_index<state_buf::nz_buf_size>(src_pos);
// Index from rolling buffer of values (which doesn't include nulls) to final array (which
// includes gaps for nulls)
sb->nz_idx[output_index] = dst_pos;
}
max_depth_valid_count += block_valid_count;
}
// update stuff
if (t == 0) {
ni.value_count += block_value_count;
ni.valid_map_offset += block_value_count;
}
__syncthreads(); // sync modification of ni.value_count
// propagate value counts for the next depth level
block_value_count = next_block_value_count;
thread_value_count = next_thread_value_count;
in_nesting_bounds = next_in_nesting_bounds;
warp_value_count = next_warp_value_count;
thread_value_count_within_warp = next_thread_value_count_within_warp;
} // END OF DEPTH LOOP
int const batch_size = min(max_batch_size, target_value_count - value_count);
value_count += batch_size;
}
if (t == 0) {
// update valid value count for decoding and total # of values we've processed
s->nesting_info[max_depth].valid_count = max_depth_valid_count;
s->nz_count = max_depth_valid_count;
s->input_value_count = value_count;
// If we have lists # rows != # values
s->input_row_count = input_row_count;
}
return max_depth_valid_count;
}
// is the page marked nullable or not
__device__ inline bool is_nullable(page_state_s* s)
{
  // a page can carry nulls only when definition levels are present, i.e. the
  // column's maximum definition level is non-zero
  return s->col.max_level[level_type::DEFINITION] > 0;
}
// for a nullable page, check to see if it could have nulls
__device__ inline bool maybe_has_nulls(page_state_s* s)
{
  auto const lvl       = level_type::DEFINITION;
  auto const first_run = s->initial_rle_run[lvl];

  // a literal run may mix levels, so conservatively assume it could hold nulls
  bool const literal = is_literal_run(first_run);

  // a repeated run that does not span every input value in the page also
  // cannot rule out nulls
  bool const covers_page = (first_run >> 1) == s->page.num_input_values;
  if (literal || !covers_page) { return true; }

  // a single repeated run covering the whole page: the page is all-null iff
  // the repeated level is not the max definition level
  auto const lvl_bits = s->col.level_bits[lvl];
  auto const rep_val  = (lvl_bits == 0) ? 0 : s->initial_rle_value[lvl];
  return rep_val != s->col.max_level[lvl];
}
template <typename state_buf, typename thread_group>
inline __device__ void bool_plain_decode(page_state_s* s,
                                         state_buf* sb,
                                         int target_pos,
                                         thread_group const& group)
{
  int const start = s->dict_pos;
  int const rank  = group.thread_rank();
  int const step  = group.size();

  // Ensure every thread observes the same dict_pos before striding
  group.sync();

  // each thread of the group unpacks every step-th bit in [start, target_pos)
  for (int bit = start + rank; bit < target_pos; bit += step) {
    uint8_t const packed = s->data_start[bit / 8];
    bool const value     = (packed >> (bit % 8)) & 1;
    sb->dict_idx[rolling_index<state_buf::dict_buf_size>(bit)] = value;
  }
}
template <int rolling_buf_size, typename stream_type>
__device__ int skip_decode(stream_type& parquet_stream, int num_to_skip, int t)
{
  // skip_decode() can stop short of the request: e.g. asking to skip 5000 may
  // land at 4000 when a run of length 2000 starts there. Decode the remainder
  // up front in rolling_buf_size batches, since that's as many values as the
  // stream can process at once.
  int skipped = parquet_stream.skip_decode(t, num_to_skip);
  while (skipped < num_to_skip) {
    // TODO: Instead of decoding, skip within the run to the appropriate location
    int const batch = min(rolling_buf_size, num_to_skip - skipped);
    skipped += parquet_stream.decode_next(t, batch);
    __syncthreads();
  }
  return skipped;
}
template <decode_kernel_mask kernel_mask_t>
constexpr bool has_dict()
{
  // true for every kernel variant that decodes dictionary-encoded data
  switch (kernel_mask_t) {
    case decode_kernel_mask::FIXED_WIDTH_DICT:
    case decode_kernel_mask::FIXED_WIDTH_DICT_NESTED:
    case decode_kernel_mask::FIXED_WIDTH_DICT_LIST:
    case decode_kernel_mask::STRING_DICT:
    case decode_kernel_mask::STRING_DICT_NESTED:
    case decode_kernel_mask::STRING_DICT_LIST: return true;
    default: return false;
  }
}
template <decode_kernel_mask kernel_mask_t>
constexpr bool has_bools()
{
  // true for every kernel variant that decodes boolean data
  switch (kernel_mask_t) {
    case decode_kernel_mask::BOOLEAN:
    case decode_kernel_mask::BOOLEAN_NESTED:
    case decode_kernel_mask::BOOLEAN_LIST: return true;
    default: return false;
  }
}
template <decode_kernel_mask kernel_mask_t>
constexpr bool has_nesting()
{
  // true for every kernel variant that handles nested (non-list) columns
  switch (kernel_mask_t) {
    case decode_kernel_mask::BOOLEAN_NESTED:
    case decode_kernel_mask::FIXED_WIDTH_DICT_NESTED:
    case decode_kernel_mask::FIXED_WIDTH_NO_DICT_NESTED:
    case decode_kernel_mask::BYTE_STREAM_SPLIT_FIXED_WIDTH_NESTED:
    case decode_kernel_mask::STRING_NESTED:
    case decode_kernel_mask::STRING_DICT_NESTED:
    case decode_kernel_mask::STRING_STREAM_SPLIT_NESTED: return true;
    default: return false;
  }
}
template <decode_kernel_mask kernel_mask_t>
constexpr bool has_lists()
{
  // true for every kernel variant that handles list columns
  switch (kernel_mask_t) {
    case decode_kernel_mask::BOOLEAN_LIST:
    case decode_kernel_mask::FIXED_WIDTH_DICT_LIST:
    case decode_kernel_mask::FIXED_WIDTH_NO_DICT_LIST:
    case decode_kernel_mask::BYTE_STREAM_SPLIT_FIXED_WIDTH_LIST:
    case decode_kernel_mask::STRING_LIST:
    case decode_kernel_mask::STRING_DICT_LIST:
    case decode_kernel_mask::STRING_STREAM_SPLIT_LIST: return true;
    default: return false;
  }
}
template <decode_kernel_mask kernel_mask_t>
constexpr bool is_split_decode()
{
  // true for every kernel variant that decodes BYTE_STREAM_SPLIT-encoded data
  switch (kernel_mask_t) {
    case decode_kernel_mask::BYTE_STREAM_SPLIT_FIXED_WIDTH_FLAT:
    case decode_kernel_mask::BYTE_STREAM_SPLIT_FIXED_WIDTH_NESTED:
    case decode_kernel_mask::BYTE_STREAM_SPLIT_FIXED_WIDTH_LIST:
    case decode_kernel_mask::STRING_STREAM_SPLIT:
    case decode_kernel_mask::STRING_STREAM_SPLIT_NESTED:
    case decode_kernel_mask::STRING_STREAM_SPLIT_LIST: return true;
    default: return false;
  }
}
/**
* @brief Kernel for computing fixed width non dictionary column data stored in the pages
*
* This function will write the page data and the page data's validity to the
* output specified in the page's column chunk. If necessary, additional
* conversion will be performed to translate from the Parquet datatype to
* desired output datatype.
*
* @param pages List of pages
* @param chunks List of column chunks
* @param min_row Row index to start reading at
* @param num_rows Maximum number of rows to read
* @param page_mask Boolean vector indicating which pages need to be decoded
* @param initial_str_offsets Vector to store the initial offsets for large nested string cols
* @param page_string_offset_indices Device span of offsets, indexed per-page, into the column's
* string offset buffer
* @param error_code Error code to set if an error is encountered
*/
template <typename level_t, int decode_block_size_t, decode_kernel_mask kernel_mask_t>
CUDF_KERNEL void __launch_bounds__(decode_block_size_t, 8)
decode_page_data_generic(PageInfo* pages,
device_span<ColumnChunkDesc const> chunks,
size_t min_row,
size_t num_rows,
cudf::device_span<bool const> page_mask,
cudf::device_span<size_t> initial_str_offsets,
cudf::device_span<size_t const> page_string_offset_indices,
kernel_error::pointer error_code)
{
constexpr bool has_dict_t = has_dict<kernel_mask_t>();
constexpr bool has_bools_t = has_bools<kernel_mask_t>();
constexpr bool has_nesting_t = has_nesting<kernel_mask_t>();
constexpr bool has_lists_t = has_lists<kernel_mask_t>();
constexpr bool split_decode_t = is_split_decode<kernel_mask_t>();
constexpr bool has_strings_t =
(static_cast<uint32_t>(kernel_mask_t) & STRINGS_MASK_NON_DELTA) != 0;
constexpr int rolling_buf_size = decode_block_size_t * 2;
constexpr int rle_run_buffer_size = rle_stream_required_run_buffer_size<decode_block_size_t>();
__shared__ __align__(16) page_state_s state_g;
constexpr bool use_dict_buffers = has_dict_t || has_bools_t;
using state_buf_t = page_state_buffers_s<rolling_buf_size, // size of nz_idx buffer
use_dict_buffers ? rolling_buf_size : 1,
1>;
__shared__ __align__(16) state_buf_t state_buffers;
auto const block = cg::this_thread_block();
page_state_s* const s = &state_g;
auto* const sb = &state_buffers;
int const page_idx = cg::this_grid().block_rank();
int const t = block.thread_rank();
PageInfo* pp = &pages[page_idx];
if (!(BitAnd(pages[page_idx].kernel_mask, kernel_mask_t))) { return; }
// must come after the kernel mask check
[[maybe_unused]] null_count_back_copier _{s, t};
// Exit super early for simple types if the page does not need to be decoded
if constexpr (not has_lists_t and not has_strings_t and not has_nesting_t) {
if (not page_mask[page_idx]) {
pp->num_nulls = pp->nesting[0].batch_size;
pp->num_valids = 0;
// Set s->nesting info = nullptr to bypass `null_count_back_copier` at return
s->nesting_info = nullptr;
return;
}
}
// Setup local page info
if (!setup_local_page_info(s,
pp,
chunks,
min_row,
num_rows,
mask_filter{kernel_mask_t},
page_processing_stage::DECODE)) {
return;
}
// Write list and/or string offsets and exit if the page does not need to be decoded