@@ -49,10 +49,8 @@ CUB_NAMESPACE_BEGIN
4949
5050namespace detail
5151{
52-
5352namespace three_way_partition
5453{
55-
5654enum class input_size
5755{
5856 _1,
@@ -92,246 +90,193 @@ template <class InputT,
9290 class OffsetT ,
9391 input_size InputSize = classify_input_size<InputT>(),
9492 offset_size OffsetSize = classify_offset_size<OffsetT>()>
95- struct sm90_tuning
96- {
97- static constexpr int threads = 256 ;
98- static constexpr int items = Nominal4BItemsToItems<InputT>(9 );
99-
100- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
101-
102- using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t <OffsetT>;
103- using AccumPackT = typename AccumPackHelperT::pack_t ;
104- using delay_constructor = detail::default_delay_constructor_t <AccumPackT>;
105- };
93+ struct sm80_tuning ;
10694
10795template <class Input , class OffsetT >
108- struct sm90_tuning <Input, OffsetT, input_size::_1, offset_size::_4>
109- {
110- static constexpr int threads = 256 ;
111- static constexpr int items = 12 ;
112-
113- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
114-
115- using delay_constructor = detail::no_delay_constructor_t <445 >;
116- };
117-
118- template <class Input , class OffsetT >
119- struct sm90_tuning <Input, OffsetT, input_size::_2, offset_size::_4>
96+ struct sm80_tuning <Input, OffsetT, input_size::_2, offset_size::_4>
12097{
121- static constexpr int threads = 256 ;
122- static constexpr int items = 12 ;
123-
124- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
125-
126- using delay_constructor = detail::fixed_delay_constructor_t <104 , 512 >;
98+ static constexpr int threads = 256 ;
99+ static constexpr int items = 12 ;
100+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
101+ using delay_constructor = no_delay_constructor_t <910 >;
127102};
128103
129104template <class Input , class OffsetT >
130- struct sm90_tuning <Input, OffsetT, input_size::_4, offset_size::_4>
105+ struct sm80_tuning <Input, OffsetT, input_size::_4, offset_size::_4>
131106{
132- static constexpr int threads = 320 ;
133- static constexpr int items = 12 ;
134-
135- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
136-
137- using delay_constructor = detail::no_delay_constructor_t <1105 >;
107+ static constexpr int threads = 256 ;
108+ static constexpr int items = 11 ;
109+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
110+ using delay_constructor = no_delay_constructor_t <1120 >;
138111};
139112
140113template <class Input , class OffsetT >
141- struct sm90_tuning <Input, OffsetT, input_size::_8, offset_size::_4>
114+ struct sm80_tuning <Input, OffsetT, input_size::_8, offset_size::_4>
142115{
143- static constexpr int threads = 384 ;
144- static constexpr int items = 7 ;
145-
116+ static constexpr int threads = 224 ;
117+ static constexpr int items = 11 ;
146118 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
147-
148- using delay_constructor = detail::fixed_delay_constructor_t <464 , 1165 >;
119+ using delay_constructor = fixed_delay_constructor_t <264 , 1080 >;
149120};
150121
151122template <class Input , class OffsetT >
152- struct sm90_tuning <Input, OffsetT, input_size::_16, offset_size::_4>
123+ struct sm80_tuning <Input, OffsetT, input_size::_16, offset_size::_4>
153124{
154- static constexpr int threads = 128 ;
155- static constexpr int items = 7 ;
156-
125+ static constexpr int threads = 128 ;
126+ static constexpr int items = 10 ;
157127 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
158-
159- using delay_constructor = detail::no_delay_constructor_t <1040 >;
128+ using delay_constructor = fixed_delay_constructor_t <672 , 1120 >;
160129};
161130
131+ template <class InputT ,
132+ class OffsetT ,
133+ input_size InputSize = classify_input_size<InputT>(),
134+ offset_size OffsetSize = classify_offset_size<OffsetT>()>
135+ struct sm90_tuning ;
136+
162137template <class Input , class OffsetT >
163- struct sm90_tuning <Input, OffsetT, input_size::_1, offset_size::_8 >
138+ struct sm90_tuning <Input, OffsetT, input_size::_1, offset_size::_4 >
164139{
165- static constexpr int threads = 256 ;
166- static constexpr int items = 24 ;
167-
140+ static constexpr int threads = 256 ;
141+ static constexpr int items = 12 ;
168142 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
169-
170- using delay_constructor = detail::fixed_delay_constructor_t <4 , 285 >;
143+ using delay_constructor = no_delay_constructor_t <445 >;
171144};
172145
173146template <class Input , class OffsetT >
174- struct sm90_tuning <Input, OffsetT, input_size::_2, offset_size::_8 >
147+ struct sm90_tuning <Input, OffsetT, input_size::_2, offset_size::_4 >
175148{
176- static constexpr int threads = 640 ;
177- static constexpr int items = 24 ;
178-
179- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
180-
181- using delay_constructor = detail::no_delay_constructor_t <245 >;
149+ static constexpr int threads = 256 ;
150+ static constexpr int items = 12 ;
151+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
152+ using delay_constructor = fixed_delay_constructor_t <104 , 512 >;
182153};
183154
184155template <class Input , class OffsetT >
185- struct sm90_tuning <Input, OffsetT, input_size::_4, offset_size::_8 >
156+ struct sm90_tuning <Input, OffsetT, input_size::_4, offset_size::_4 >
186157{
187- static constexpr int threads = 256 ;
188- static constexpr int items = 23 ;
189-
190- static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
191-
192- using delay_constructor = detail::no_delay_constructor_t <910 >;
158+ static constexpr int threads = 320 ;
159+ static constexpr int items = 12 ;
160+ static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
161+ using delay_constructor = no_delay_constructor_t <1105 >;
193162};
194163
195164template <class Input , class OffsetT >
196- struct sm90_tuning <Input, OffsetT, input_size::_8, offset_size::_8 >
165+ struct sm90_tuning <Input, OffsetT, input_size::_8, offset_size::_4 >
197166{
198- static constexpr int threads = 256 ;
199- static constexpr int items = 18 ;
200-
167+ static constexpr int threads = 384 ;
168+ static constexpr int items = 7 ;
201169 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
202-
203- using delay_constructor = detail::no_delay_constructor_t <1145 >;
170+ using delay_constructor = fixed_delay_constructor_t <464 , 1165 >;
204171};
205172
206173template <class Input , class OffsetT >
207- struct sm90_tuning <Input, OffsetT, input_size::_16, offset_size::_8 >
174+ struct sm90_tuning <Input, OffsetT, input_size::_16, offset_size::_4 >
208175{
209- static constexpr int threads = 256 ;
210- static constexpr int items = 11 ;
211-
176+ static constexpr int threads = 128 ;
177+ static constexpr int items = 7 ;
212178 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
213-
214- using delay_constructor = detail::no_delay_constructor_t <1050 >;
179+ using delay_constructor = no_delay_constructor_t <1040 >;
215180};
216181
217- template <class InputT ,
218- class OffsetT ,
219- input_size InputSize = classify_input_size<InputT>(),
220- offset_size OffsetSize = classify_offset_size<OffsetT>()>
221- struct sm80_tuning
182+ template <class Input , class OffsetT >
183+ struct sm90_tuning <Input, OffsetT, input_size::_1, offset_size::_8>
222184{
223- static constexpr int threads = 256 ;
224- static constexpr int items = Nominal4BItemsToItems<InputT>(9 );
225-
185+ static constexpr int threads = 256 ;
186+ static constexpr int items = 24 ;
226187 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
227-
228- using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t <OffsetT>;
229- using AccumPackT = typename AccumPackHelperT::pack_t ;
230- using delay_constructor = detail::default_delay_constructor_t <AccumPackT>;
188+ using delay_constructor = fixed_delay_constructor_t <4 , 285 >;
231189};
232190
233191template <class Input , class OffsetT >
234- struct sm80_tuning <Input, OffsetT, input_size::_2, offset_size::_4 >
192+ struct sm90_tuning <Input, OffsetT, input_size::_2, offset_size::_8 >
235193{
236- static constexpr int threads = 256 ;
237- static constexpr int items = 12 ;
238-
194+ static constexpr int threads = 640 ;
195+ static constexpr int items = 24 ;
239196 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
240-
241- using delay_constructor = detail::no_delay_constructor_t <910 >;
197+ using delay_constructor = no_delay_constructor_t <245 >;
242198};
243199
244200template <class Input , class OffsetT >
245- struct sm80_tuning <Input, OffsetT, input_size::_4, offset_size::_4 >
201+ struct sm90_tuning <Input, OffsetT, input_size::_4, offset_size::_8 >
246202{
247- static constexpr int threads = 256 ;
248- static constexpr int items = 11 ;
249-
203+ static constexpr int threads = 256 ;
204+ static constexpr int items = 23 ;
250205 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
251-
252- using delay_constructor = detail::no_delay_constructor_t <1120 >;
206+ using delay_constructor = no_delay_constructor_t <910 >;
253207};
254208
255209template <class Input , class OffsetT >
256- struct sm80_tuning <Input, OffsetT, input_size::_8, offset_size::_4 >
210+ struct sm90_tuning <Input, OffsetT, input_size::_8, offset_size::_8 >
257211{
258- static constexpr int threads = 224 ;
259- static constexpr int items = 11 ;
260-
212+ static constexpr int threads = 256 ;
213+ static constexpr int items = 18 ;
261214 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
262-
263- using delay_constructor = detail::fixed_delay_constructor_t <264 , 1080 >;
215+ using delay_constructor = no_delay_constructor_t <1145 >;
264216};
265217
266218template <class Input , class OffsetT >
267- struct sm80_tuning <Input, OffsetT, input_size::_16, offset_size::_4 >
219+ struct sm90_tuning <Input, OffsetT, input_size::_16, offset_size::_8 >
268220{
269- static constexpr int threads = 128 ;
270- static constexpr int items = 10 ;
271-
221+ static constexpr int threads = 256 ;
222+ static constexpr int items = 11 ;
272223 static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
273-
274- using delay_constructor = detail::fixed_delay_constructor_t <672 , 1120 >;
224+ using delay_constructor = no_delay_constructor_t <1050 >;
275225};
276226
277- } // namespace three_way_partition
278-
279227template <class InputT , class OffsetT >
280- struct device_three_way_partition_policy_hub
228+ struct policy_hub
281229{
282- struct DefaultTuning
230+ template <typename DelayConstructor>
231+ struct DefaultPolicy
283232 {
284- static constexpr int ITEMS_PER_THREAD = Nominal4BItemsToItems<InputT>(9 );
285-
286233 using ThreeWayPartitionPolicy =
287- cub::AgentThreeWayPartitionPolicy<256 ,
288- ITEMS_PER_THREAD,
289- cub::BLOCK_LOAD_DIRECT,
290- cub::LOAD_DEFAULT,
291- cub::BLOCK_SCAN_WARP_SCANS>;
234+ AgentThreeWayPartitionPolicy<256 ,
235+ Nominal4BItemsToItems<InputT>(9 ),
236+ BLOCK_LOAD_DIRECT,
237+ LOAD_DEFAULT,
238+ BLOCK_SCAN_WARP_SCANS,
239+ DelayConstructor>;
292240 };
293241
294- // / SM35
295242 struct Policy350
296- : DefaultTuning
243+ : DefaultPolicy< fixed_delay_constructor_t < 350 , 450 >>
297244 , ChainedPolicy<350 , Policy350, Policy350>
298245 {};
299246
247+ // Use values from tuning if a specialization exists, otherwise pick DefaultPolicy
248+ template <typename Tuning>
249+ static auto select_agent_policy (int )
250+ -> AgentThreeWayPartitionPolicy<Tuning::threads,
251+ Tuning::items,
252+ Tuning::load_algorithm,
253+ LOAD_DEFAULT,
254+ BLOCK_SCAN_WARP_SCANS,
255+ typename Tuning::delay_constructor>;
256+
257+ template <typename Tuning>
258+ static auto select_agent_policy (long ) ->
259+ typename DefaultPolicy<
260+ default_delay_constructor_t<typename accumulator_pack_t<OffsetT>::pack_t>>::ThreeWayPartitionPolicy;
261+
300262 struct Policy800 : ChainedPolicy<800 , Policy800, Policy350>
301263 {
302- using tuning = detail::three_way_partition::sm80_tuning<InputT, OffsetT>;
303-
304- using ThreeWayPartitionPolicy =
305- AgentThreeWayPartitionPolicy<tuning::threads,
306- tuning::items,
307- tuning::load_algorithm,
308- cub::LOAD_DEFAULT,
309- cub::BLOCK_SCAN_WARP_SCANS,
310- typename tuning::delay_constructor>;
264+ using ThreeWayPartitionPolicy = decltype (select_agent_policy<sm80_tuning<InputT, OffsetT>>(0 ));
311265 };
312266
313267 struct Policy860
314- : DefaultTuning
268+ : DefaultPolicy< fixed_delay_constructor_t < 350 , 450 >>
315269 , ChainedPolicy<860 , Policy860, Policy800>
316270 {};
317271
318- // / SM90
319272 struct Policy900 : ChainedPolicy<900 , Policy900, Policy860>
320273 {
321- using tuning = detail::three_way_partition::sm90_tuning<InputT, OffsetT>;
322-
323- using ThreeWayPartitionPolicy =
324- AgentThreeWayPartitionPolicy<tuning::threads,
325- tuning::items,
326- tuning::load_algorithm,
327- cub::LOAD_DEFAULT,
328- cub::BLOCK_SCAN_WARP_SCANS,
329- typename tuning::delay_constructor>;
274+ using ThreeWayPartitionPolicy = decltype (select_agent_policy<sm90_tuning<InputT, OffsetT>>(0 ));
330275 };
331276
332277 using MaxPolicy = Policy900;
333278};
334-
279+ } // namespace three_way_partition
335280} // namespace detail
336281
337282CUB_NAMESPACE_END
0 commit comments