@@ -761,20 +761,19 @@ struct policy_hub
761761 // / SM90
762762 struct Policy900 : ChainedPolicy<900 , Policy900, Policy800>
763763 {
764- enum
765- {
766- PRIMARY_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 7 : 5 ,
767- SINGLE_TILE_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 6 : 5 ,
768- SEGMENTED_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 6 : 5 ,
769- ONESWEEP = true ,
770- ONESWEEP_RADIX_BITS = 8 ,
771- OFFSET_64BIT = sizeof (OffsetT) == 8 ? 1 : 0 ,
772- FLOAT_KEYS = std::is_same<KeyT, float >::value ? 1 : 0 ,
773- };
764+ static constexpr bool ONESWEEP = true ;
765+ static constexpr int ONESWEEP_RADIX_BITS = 8 ;
774766
775767 using HistogramPolicy = AgentRadixSortHistogramPolicy<128 , 16 , 1 , KeyT, ONESWEEP_RADIX_BITS>;
776768 using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256 , ONESWEEP_RADIX_BITS>;
777769
770+ private:
771+ static constexpr int PRIMARY_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 7 : 5 ;
772+ static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 6 : 5 ;
773+ static constexpr int SEGMENTED_RADIX_BITS = (sizeof (KeyT) > 1 ) ? 6 : 5 ;
774+ static constexpr int OFFSET_64BIT = sizeof (OffsetT) == 8 ? 1 : 0 ;
775+ static constexpr int FLOAT_KEYS = ::cuda::std::is_same<KeyT, float >::value ? 1 : 0 ;
776+
778777 using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy<
779778 384 ,
780779 KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS
@@ -796,11 +795,11 @@ struct policy_hub
796795 RADIX_SORT_STORE_DIRECT,
797796 ONESWEEP_RADIX_BITS>;
798797
799- using OnesweepLargeKeyPolicy = //
800- ::cuda::std::_If<sizeof (KeyT) == 4 , OnesweepPolicyKey32, OnesweepPolicyKey64>;
798+ using OnesweepLargeKeyPolicy = ::cuda::std::_If<sizeof (KeyT) == 4 , OnesweepPolicyKey32, OnesweepPolicyKey64>;
799+
800+ using OnesweepSmallKeyPolicySizes =
801+ sm90_small_key_tuning<sizeof (KeyT), KEYS_ONLY ? 0 : sizeof (ValueT), sizeof (OffsetT)>;
801802
802- using OnesweepSmallKeyPolicySizes = //
803- detail::radix::sm90_small_key_tuning<sizeof (KeyT), KEYS_ONLY ? 0 : sizeof (ValueT), sizeof (OffsetT)>;
804803 using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy<
805804 OnesweepSmallKeyPolicySizes::threads,
806805 OnesweepSmallKeyPolicySizes::items,
@@ -810,42 +809,9 @@ struct policy_hub
810809 BLOCK_SCAN_RAKING_MEMOIZE,
811810 RADIX_SORT_STORE_DIRECT,
812811 8 >;
813- using OnesweepPolicy = //
814- ::cuda::std::_If<sizeof (KeyT) < 4 , //
815- OnesweepSmallKeyPolicy, //
816- OnesweepLargeKeyPolicy>;
817-
818- using ScanPolicy =
819- AgentScanPolicy<512 ,
820- 23 ,
821- OffsetT,
822- BLOCK_LOAD_WARP_TRANSPOSE,
823- LOAD_DEFAULT,
824- BLOCK_STORE_WARP_TRANSPOSE,
825- BLOCK_SCAN_RAKING_MEMOIZE>;
826-
827- using DownsweepPolicy = AgentRadixSortDownsweepPolicy<
828- 512 ,
829- 23 ,
830- DominantT,
831- BLOCK_LOAD_TRANSPOSE,
832- LOAD_DEFAULT,
833- RADIX_RANK_MATCH,
834- BLOCK_SCAN_WARP_SCANS,
835- PRIMARY_RADIX_BITS>;
836-
837- using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy<
838- (sizeof (KeyT) > 1 ) ? 256 : 128 ,
839- 47 ,
840- DominantT,
841- BLOCK_LOAD_TRANSPOSE,
842- LOAD_DEFAULT,
843- RADIX_RANK_MEMOIZE,
844- BLOCK_SCAN_WARP_SCANS,
845- PRIMARY_RADIX_BITS - 1 >;
846812
847- using UpsweepPolicy = AgentRadixSortUpsweepPolicy< 256 , 23 , DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>;
848- using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy< 256 , 47 , DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1 >;
813+ public:
814+ using OnesweepPolicy = ::cuda::std::_If< sizeof (KeyT) < 4 , OnesweepSmallKeyPolicy, OnesweepLargeKeyPolicy >;
849815
850816 using SingleTilePolicy = AgentRadixSortDownsweepPolicy<
851817 256 ,
0 commit comments