Skip to content

Commit 29ba731

Browse files
Refactor SM90 radix_sort tuning (#3125)
1 parent 346a618 commit 29ba731

File tree

1 file changed

+15
-49
lines changed

1 file changed

+15
-49
lines changed

cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh

Lines changed: 15 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -761,20 +761,19 @@ struct policy_hub
761761
/// SM90
762762
struct Policy900 : ChainedPolicy<900, Policy900, Policy800>
763763
{
764-
enum
765-
{
766-
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5,
767-
SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
768-
SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5,
769-
ONESWEEP = true,
770-
ONESWEEP_RADIX_BITS = 8,
771-
OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0,
772-
FLOAT_KEYS = std::is_same<KeyT, float>::value ? 1 : 0,
773-
};
764+
static constexpr bool ONESWEEP = true;
765+
static constexpr int ONESWEEP_RADIX_BITS = 8;
774766

775767
using HistogramPolicy = AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>;
776768
using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>;
777769

770+
private:
771+
static constexpr int PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5;
772+
static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
773+
static constexpr int SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5;
774+
static constexpr int OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0;
775+
static constexpr int FLOAT_KEYS = ::cuda::std::is_same<KeyT, float>::value ? 1 : 0;
776+
778777
using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy<
779778
384,
780779
KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
@@ -796,11 +795,11 @@ struct policy_hub
796795
RADIX_SORT_STORE_DIRECT,
797796
ONESWEEP_RADIX_BITS>;
798797

799-
using OnesweepLargeKeyPolicy = //
800-
::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;
798+
using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof(KeyT) == 4, OnesweepPolicyKey32, OnesweepPolicyKey64>;
799+
800+
using OnesweepSmallKeyPolicySizes =
801+
sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;
801802

802-
using OnesweepSmallKeyPolicySizes = //
803-
detail::radix::sm90_small_key_tuning<sizeof(KeyT), KEYS_ONLY ? 0 : sizeof(ValueT), sizeof(OffsetT)>;
804803
using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
805804
OnesweepSmallKeyPolicySizes::threads,
806805
OnesweepSmallKeyPolicySizes::items,
@@ -810,42 +809,9 @@ struct policy_hub
810809
BLOCK_SCAN_RAKING_MEMOIZE,
811810
RADIX_SORT_STORE_DIRECT,
812811
8>;
813-
using OnesweepPolicy = //
814-
::cuda::std::_If<sizeof(KeyT) < 4, //
815-
OnesweepSmallKeyPolicy, //
816-
OnesweepLargeKeyPolicy>;
817-
818-
using ScanPolicy =
819-
AgentScanPolicy<512,
820-
23,
821-
OffsetT,
822-
BLOCK_LOAD_WARP_TRANSPOSE,
823-
LOAD_DEFAULT,
824-
BLOCK_STORE_WARP_TRANSPOSE,
825-
BLOCK_SCAN_RAKING_MEMOIZE>;
826-
827-
using DownsweepPolicy = AgentRadixSortDownsweepPolicy<
828-
512,
829-
23,
830-
DominantT,
831-
BLOCK_LOAD_TRANSPOSE,
832-
LOAD_DEFAULT,
833-
RADIX_RANK_MATCH,
834-
BLOCK_SCAN_WARP_SCANS,
835-
PRIMARY_RADIX_BITS>;
836-
837-
using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<
838-
(sizeof(KeyT) > 1) ? 256 : 128,
839-
47,
840-
DominantT,
841-
BLOCK_LOAD_TRANSPOSE,
842-
LOAD_DEFAULT,
843-
RADIX_RANK_MEMOIZE,
844-
BLOCK_SCAN_WARP_SCANS,
845-
PRIMARY_RADIX_BITS - 1>;
846812

847-
using UpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
848-
using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1>;
813+
public:
814+
using OnesweepPolicy = ::cuda::std::_If<sizeof(KeyT) < 4, OnesweepSmallKeyPolicy, OnesweepLargeKeyPolicy>;
849815

850816
using SingleTilePolicy = AgentRadixSortDownsweepPolicy<
851817
256,

0 commit comments

Comments
 (0)