Skip to content

Commit 346a618

Browse files
Refactor selecting default tuning (#3124)
1 parent 53f69a4 commit 346a618

File tree

1 file changed

+40
-49
lines changed

1 file changed

+40
-49
lines changed

cub/cub/device/dispatch/tuning/tuning_select_if.cuh

Lines changed: 40 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -82,16 +82,7 @@ enum class input_size
8282
};
8383

8484
template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
85-
struct sm80_tuning
86-
{
87-
static constexpr int threads = 128;
88-
static constexpr int nominal_4b_items_per_thread = 10;
89-
// TODO(bgruber): use cuda::std::clamp() in C++14
90-
static constexpr int items =
91-
CUB_MIN(nominal_4b_items_per_thread, CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT))));
92-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
93-
using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;
94-
};
85+
struct sm80_tuning;
9586

9687
// select::if
9788
template <class Input>
@@ -306,16 +297,7 @@ struct sm80_tuning<__uint128_t, flagged::yes, keep_rejects::yes, offset_size::_4
306297
#endif
307298

308299
template <class InputT, flagged, keep_rejects, offset_size OffsetSize, primitive, input_size InputSize>
309-
struct sm90_tuning
310-
{
311-
static constexpr int threads = 128;
312-
static constexpr int nominal_4b_items_per_thread = 10;
313-
// TODO(bgruber): use cuda::std::clamp() in C++14
314-
static constexpr int items =
315-
CUB_MIN(nominal_4b_items_per_thread, CUB_MAX(1, (nominal_4b_items_per_thread * 4 / sizeof(InputT))));
316-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
317-
using delay_constructor = detail::fixed_delay_constructor_t<350, 450>;
318-
};
300+
struct sm90_tuning;
319301

320302
// select::if
321303
template <class Input>
@@ -568,6 +550,7 @@ constexpr offset_size classify_offset_size()
568550
template <class InputT, class FlagT, class OffsetT, bool MayAlias, bool KeepRejects>
569551
struct policy_hub
570552
{
553+
template <CacheLoadModifier LoadModifier>
571554
struct DefaultPolicy
572555
{
573556
static constexpr int nominal_4B_items_per_thread = 10;
@@ -578,57 +561,65 @@ struct policy_hub
578561
AgentSelectIfPolicy<128,
579562
items_per_thread,
580563
BLOCK_LOAD_DIRECT,
581-
MayAlias ? LOAD_CA : LOAD_LDG,
564+
LoadModifier,
582565
BLOCK_SCAN_WARP_SCANS,
583566
detail::fixed_delay_constructor_t<350, 450>>;
584567
};
585568

586569
struct Policy350
587-
: DefaultPolicy
570+
: DefaultPolicy<MayAlias ? LOAD_CA : LOAD_LDG>
588571
, ChainedPolicy<350, Policy350, Policy350>
589572
{};
590573

591574
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
592575
{
593-
using tuning =
594-
sm80_tuning<InputT,
595-
is_flagged<FlagT>(),
596-
are_rejects_kept<KeepRejects>(),
597-
classify_offset_size<OffsetT>(),
598-
is_primitive<InputT>(),
599-
classify_input_size<InputT>()>;
576+
// Use values from tuning if a specialization exists, otherwise pick the default
577+
template <typename Tuning>
578+
static auto select_agent_policy(int)
579+
-> AgentSelectIfPolicy<Tuning::threads,
580+
Tuning::items,
581+
Tuning::load_algorithm,
582+
LOAD_DEFAULT,
583+
BLOCK_SCAN_WARP_SCANS,
584+
typename Tuning::delay_constructor>;
585+
template <typename Tuning>
586+
static auto select_agent_policy(long) -> typename DefaultPolicy<LOAD_DEFAULT>::SelectIfPolicyT;
600587

601588
using SelectIfPolicyT =
602-
AgentSelectIfPolicy<tuning::threads,
603-
tuning::items,
604-
tuning::load_algorithm,
605-
LOAD_DEFAULT,
606-
BLOCK_SCAN_WARP_SCANS,
607-
typename tuning::delay_constructor>;
589+
decltype(select_agent_policy<sm80_tuning<InputT,
590+
is_flagged<FlagT>(),
591+
are_rejects_kept<KeepRejects>(),
592+
classify_offset_size<OffsetT>(),
593+
is_primitive<InputT>(),
594+
classify_input_size<InputT>()>>(0));
608595
};
609596

610597
struct Policy860
611-
: DefaultPolicy
598+
: DefaultPolicy<MayAlias ? LOAD_CA : LOAD_LDG>
612599
, ChainedPolicy<860, Policy860, Policy800>
613600
{};
614601

615602
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
616603
{
617-
using tuning =
618-
sm90_tuning<InputT,
619-
is_flagged<FlagT>(),
620-
are_rejects_kept<KeepRejects>(),
621-
classify_offset_size<OffsetT>(),
622-
is_primitive<InputT>(),
623-
classify_input_size<InputT>()>;
604+
// Use values from tuning if a specialization exists, otherwise pick the default
605+
template <typename Tuning>
606+
static auto select_agent_policy(int)
607+
-> AgentSelectIfPolicy<Tuning::threads,
608+
Tuning::items,
609+
Tuning::load_algorithm,
610+
LOAD_DEFAULT,
611+
BLOCK_SCAN_WARP_SCANS,
612+
typename Tuning::delay_constructor>;
613+
template <typename Tuning>
614+
static auto select_agent_policy(long) -> typename DefaultPolicy<LOAD_DEFAULT>::SelectIfPolicyT;
624615

625616
using SelectIfPolicyT =
626-
AgentSelectIfPolicy<tuning::threads,
627-
tuning::items,
628-
tuning::load_algorithm,
629-
LOAD_DEFAULT,
630-
BLOCK_SCAN_WARP_SCANS,
631-
typename tuning::delay_constructor>;
617+
decltype(select_agent_policy<sm90_tuning<InputT,
618+
is_flagged<FlagT>(),
619+
are_rejects_kept<KeepRejects>(),
620+
classify_offset_size<OffsetT>(),
621+
is_primitive<InputT>(),
622+
classify_input_size<InputT>()>>(0));
632623
};
633624

634625
using MaxPolicy = Policy900;

0 commit comments

Comments
 (0)