@@ -82,16 +82,7 @@ enum class input_size
8282};
8383
// Primary template sm80_tuning, shown here as a diff hunk.
// Removed ('-') lines: the former generic definition, which supplied
//   threads = 128, nominal_4b_items_per_thread = 10,
//   items = min(nominal, max(1, nominal * 4 / sizeof(InputT))),
//   load_algorithm = BLOCK_LOAD_DIRECT and
//   delay_constructor = detail::fixed_delay_constructor_t<350, 450>.
// Added ('+') line: only a declaration remains, so an instantiation that is not
// covered by a specialization is an incomplete type with no members.
// policy_hub::Policy800 relies on this to choose between tuned values and
// DefaultPolicy<LOAD_DEFAULT>.
8484template <class InputT , flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
85- struct sm80_tuning
86- {
87- static constexpr int threads = 128 ;
88- static constexpr int nominal_4b_items_per_thread = 10 ;
89- // TODO(bgruber): use cuda::std::clamp() in C++14
90- static constexpr int items =
91- CUB_MIN (nominal_4b_items_per_thread, CUB_MAX(1 , (nominal_4b_items_per_thread * 4 / sizeof (InputT))));
92- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
93- using delay_constructor = detail::fixed_delay_constructor_t <350 , 450 >;
94- };
85+ struct sm80_tuning ;
9586
9687// select::if
9788template <class Input >
@@ -306,16 +297,7 @@ struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
306297#endif
307298
// Primary template sm90_tuning, shown here as a diff hunk.
// Removed ('-') lines: the former generic definition, which supplied
//   threads = 128, nominal_4b_items_per_thread = 10,
//   items = min(nominal, max(1, nominal * 4 / sizeof(InputT))),
//   load_algorithm = BLOCK_LOAD_DIRECT and
//   delay_constructor = detail::fixed_delay_constructor_t<350, 450>.
// Added ('+') line: only a declaration remains, so an instantiation that is not
// covered by a specialization is an incomplete type with no members.
// policy_hub::Policy900 relies on this to choose between tuned values and
// DefaultPolicy<LOAD_DEFAULT>.
308299template <class InputT , flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
309- struct sm90_tuning
310- {
311- static constexpr int threads = 128 ;
312- static constexpr int nominal_4b_items_per_thread = 10 ;
313- // TODO(bgruber): use cuda::std::clamp() in C++14
314- static constexpr int items =
315- CUB_MIN (nominal_4b_items_per_thread, CUB_MAX(1 , (nominal_4b_items_per_thread * 4 / sizeof (InputT))));
316- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
317- using delay_constructor = detail::fixed_delay_constructor_t <350 , 450 >;
318- };
300+ struct sm90_tuning ;
319301
320302// select::if
321303template <class Input >
@@ -568,6 +550,7 @@ constexpr offset_size classify_offset_size()
568550template <class InputT , class FlagT , class OffsetT , bool MayAlias, bool KeepRejects>
569551struct policy_hub
570552{
553+ template <CacheLoadModifier LoadModifier>
571554 struct DefaultPolicy
572555 {
573556 static constexpr int nominal_4B_items_per_thread = 10 ;
@@ -578,57 +561,65 @@ struct policy_hub
578561 AgentSelectIfPolicy<128 ,
579562 items_per_thread,
580563 BLOCK_LOAD_DIRECT,
581- MayAlias ? LOAD_CA : LOAD_LDG ,
564+ LoadModifier ,
582565 BLOCK_SCAN_WARP_SCANS,
583566 detail::fixed_delay_constructor_t <350 , 450 >>;
584567 };
585568
// Policy350: inherits the agent policy from DefaultPolicy and registers itself
// as ChainedPolicy<350, ...>, naming Policy350 in both type slots.
// Diff: DefaultPolicy became a template over CacheLoadModifier, so the base is
// now spelled DefaultPolicy<MayAlias ? LOAD_CA : LOAD_LDG>. That is the same
// expression the non-template DefaultPolicy previously passed to
// AgentSelectIfPolicy itself.
586569 struct Policy350
587- : DefaultPolicy
570+ : DefaultPolicy<MayAlias ? LOAD_CA : LOAD_LDG>
588571 , ChainedPolicy<350 , Policy350, Policy350>
589572 {};
590573
// Policy800: ChainedPolicy<800, ...> entry whose last type slot names Policy350.
// Diff summary:
//   Removed ('-'): SelectIfPolicyT was built unconditionally from the members of
//   a `tuning` alias for sm80_tuning<...>.
//   Added ('+'): SelectIfPolicyT is the return type of select_agent_policy,
//   a pair of declaration-only overloads evaluated inside decltype:
//     * (int)  - return type names Tuning::threads, Tuning::items,
//                Tuning::load_algorithm and Tuning::delay_constructor. The call
//                passes the int literal 0, so this overload is the better match
//                whenever its return type can be formed.
//     * (long) - returns DefaultPolicy<LOAD_DEFAULT>::SelectIfPolicyT. It is
//                intended to be selected when substitution into the (int)
//                overload fails because the sm80_tuning instantiation is only
//                declared (no specialization provides those members).
//   Both paths use LOAD_DEFAULT as the load modifier.
591574 struct Policy800 : ChainedPolicy<800 , Policy800, Policy350>
592575 {
593- using tuning =
594- sm80_tuning<InputT,
595- is_flagged<FlagT>(),
596- are_rejects_kept<KeepRejects>(),
597- classify_offset_size<OffsetT>(),
598- is_primitive<InputT>(),
599- classify_input_size<InputT>()>;
576+ // Use values from tuning if a specialization exists, otherwise pick the default
577+ template <typename Tuning>
578+ static auto select_agent_policy (int )
579+ -> AgentSelectIfPolicy<Tuning::threads,
580+ Tuning::items,
581+ Tuning::load_algorithm,
582+ LOAD_DEFAULT,
583+ BLOCK_SCAN_WARP_SCANS,
584+ typename Tuning::delay_constructor>;
585+ template <typename Tuning>
586+ static auto select_agent_policy (long ) -> typename DefaultPolicy<LOAD_DEFAULT>::SelectIfPolicyT;
600587
601588 using SelectIfPolicyT =
602- AgentSelectIfPolicy<tuning::threads ,
603- tuning::items ,
604- tuning::load_algorithm ,
605- LOAD_DEFAULT ,
606- BLOCK_SCAN_WARP_SCANS ,
607- typename tuning::delay_constructor> ;
589+ decltype (select_agent_policy<sm80_tuning<InputT ,
590+ is_flagged<FlagT>() ,
591+ are_rejects_kept<KeepRejects>() ,
592+ classify_offset_size<OffsetT>() ,
593+ is_primitive<InputT>() ,
594+ classify_input_size<InputT>()>>( 0 )) ;
608595 };
609596
// Policy860: inherits the agent policy from DefaultPolicy and registers itself
// as ChainedPolicy<860, ...>, naming Policy800 in the last type slot.
// Diff: DefaultPolicy became a template over CacheLoadModifier, so the base is
// now spelled DefaultPolicy<MayAlias ? LOAD_CA : LOAD_LDG>. That is the same
// expression the non-template DefaultPolicy previously passed to
// AgentSelectIfPolicy itself.
610597 struct Policy860
611- : DefaultPolicy
598+ : DefaultPolicy<MayAlias ? LOAD_CA : LOAD_LDG>
612599 , ChainedPolicy<860 , Policy860, Policy800>
613600 {};
614601
// Policy900: ChainedPolicy<900, ...> entry whose last type slot names Policy860.
// Diff summary:
//   Removed ('-'): SelectIfPolicyT was built unconditionally from the members of
//   a `tuning` alias for sm90_tuning<...>.
//   Added ('+'): SelectIfPolicyT is the return type of select_agent_policy,
//   a pair of declaration-only overloads evaluated inside decltype:
//     * (int)  - return type names Tuning::threads, Tuning::items,
//                Tuning::load_algorithm and Tuning::delay_constructor. The call
//                passes the int literal 0, so this overload is the better match
//                whenever its return type can be formed.
//     * (long) - returns DefaultPolicy<LOAD_DEFAULT>::SelectIfPolicyT. It is
//                intended to be selected when substitution into the (int)
//                overload fails because the sm90_tuning instantiation is only
//                declared (no specialization provides those members).
//   Both paths use LOAD_DEFAULT as the load modifier.
615602 struct Policy900 : ChainedPolicy<900 , Policy900, Policy860>
616603 {
617- using tuning =
618- sm90_tuning<InputT,
619- is_flagged<FlagT>(),
620- are_rejects_kept<KeepRejects>(),
621- classify_offset_size<OffsetT>(),
622- is_primitive<InputT>(),
623- classify_input_size<InputT>()>;
604+ // Use values from tuning if a specialization exists, otherwise pick the default
605+ template <typename Tuning>
606+ static auto select_agent_policy (int )
607+ -> AgentSelectIfPolicy<Tuning::threads,
608+ Tuning::items,
609+ Tuning::load_algorithm,
610+ LOAD_DEFAULT,
611+ BLOCK_SCAN_WARP_SCANS,
612+ typename Tuning::delay_constructor>;
613+ template <typename Tuning>
614+ static auto select_agent_policy (long ) -> typename DefaultPolicy<LOAD_DEFAULT>::SelectIfPolicyT;
624615
625616 using SelectIfPolicyT =
626- AgentSelectIfPolicy<tuning::threads ,
627- tuning::items ,
628- tuning::load_algorithm ,
629- LOAD_DEFAULT ,
630- BLOCK_SCAN_WARP_SCANS ,
631- typename tuning::delay_constructor> ;
617+ decltype (select_agent_policy<sm90_tuning<InputT ,
618+ is_flagged<FlagT>() ,
619+ are_rejects_kept<KeepRejects>() ,
620+ classify_offset_size<OffsetT>() ,
621+ is_primitive<InputT>() ,
622+ classify_input_size<InputT>()>>( 0 )) ;
632623 };
633624
634625 using MaxPolicy = Policy900;
0 commit comments