Skip to content

Commit 0ea508f

Browse files
Refactor three_way_parition tuning (#3140)
* Drop needless comments * Move and rename policy_hub * Drop unneeded namespace qualifications * Rename DefaultTuning * Eliminate redundancy * Swap sm80 and sm90 tuning
1 parent 7321a51 commit 0ea508f

File tree

2 files changed

+107
-162
lines changed

2 files changed

+107
-162
lines changed

cub/cub/device/dispatch/dispatch_three_way_partition.cuh

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -153,16 +153,16 @@ DeviceThreeWayPartitionInitKernel(ScanTileStateT tile_state, int num_tiles, NumS
153153
* Dispatch
154154
******************************************************************************/
155155

156-
template <typename InputIteratorT,
157-
typename FirstOutputIteratorT,
158-
typename SecondOutputIteratorT,
159-
typename UnselectedOutputIteratorT,
160-
typename NumSelectedIteratorT,
161-
typename SelectFirstPartOp,
162-
typename SelectSecondPartOp,
163-
typename OffsetT,
164-
typename SelectedPolicy =
165-
detail::device_three_way_partition_policy_hub<cub::detail::value_t<InputIteratorT>, OffsetT>>
156+
template <
157+
typename InputIteratorT,
158+
typename FirstOutputIteratorT,
159+
typename SecondOutputIteratorT,
160+
typename UnselectedOutputIteratorT,
161+
typename NumSelectedIteratorT,
162+
typename SelectFirstPartOp,
163+
typename SelectSecondPartOp,
164+
typename OffsetT,
165+
typename SelectedPolicy = detail::three_way_partition::policy_hub<cub::detail::value_t<InputIteratorT>, OffsetT>>
166166
struct DispatchThreeWayPartitionIf
167167
{
168168
/*****************************************************************************

cub/cub/device/dispatch/tuning/tuning_three_way_partition.cuh

Lines changed: 97 additions & 152 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,8 @@ CUB_NAMESPACE_BEGIN
4949

5050
namespace detail
5151
{
52-
5352
namespace three_way_partition
5453
{
55-
5654
enum class input_size
5755
{
5856
_1,
@@ -92,246 +90,193 @@ template <class InputT,
9290
class OffsetT,
9391
input_size InputSize = classify_input_size<InputT>(),
9492
offset_size OffsetSize = classify_offset_size<OffsetT>()>
95-
struct sm90_tuning
96-
{
97-
static constexpr int threads = 256;
98-
static constexpr int items = Nominal4BItemsToItems<InputT>(9);
99-
100-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
101-
102-
using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t<OffsetT>;
103-
using AccumPackT = typename AccumPackHelperT::pack_t;
104-
using delay_constructor = detail::default_delay_constructor_t<AccumPackT>;
105-
};
93+
struct sm80_tuning;
10694

10795
template <class Input, class OffsetT>
108-
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_4>
109-
{
110-
static constexpr int threads = 256;
111-
static constexpr int items = 12;
112-
113-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
114-
115-
using delay_constructor = detail::no_delay_constructor_t<445>;
116-
};
117-
118-
template <class Input, class OffsetT>
119-
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
96+
struct sm80_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
12097
{
121-
static constexpr int threads = 256;
122-
static constexpr int items = 12;
123-
124-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
125-
126-
using delay_constructor = detail::fixed_delay_constructor_t<104, 512>;
98+
static constexpr int threads = 256;
99+
static constexpr int items = 12;
100+
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
101+
using delay_constructor = no_delay_constructor_t<910>;
127102
};
128103

129104
template <class Input, class OffsetT>
130-
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
105+
struct sm80_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
131106
{
132-
static constexpr int threads = 320;
133-
static constexpr int items = 12;
134-
135-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
136-
137-
using delay_constructor = detail::no_delay_constructor_t<1105>;
107+
static constexpr int threads = 256;
108+
static constexpr int items = 11;
109+
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
110+
using delay_constructor = no_delay_constructor_t<1120>;
138111
};
139112

140113
template <class Input, class OffsetT>
141-
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
114+
struct sm80_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
142115
{
143-
static constexpr int threads = 384;
144-
static constexpr int items = 7;
145-
116+
static constexpr int threads = 224;
117+
static constexpr int items = 11;
146118
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
147-
148-
using delay_constructor = detail::fixed_delay_constructor_t<464, 1165>;
119+
using delay_constructor = fixed_delay_constructor_t<264, 1080>;
149120
};
150121

151122
template <class Input, class OffsetT>
152-
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
123+
struct sm80_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
153124
{
154-
static constexpr int threads = 128;
155-
static constexpr int items = 7;
156-
125+
static constexpr int threads = 128;
126+
static constexpr int items = 10;
157127
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
158-
159-
using delay_constructor = detail::no_delay_constructor_t<1040>;
128+
using delay_constructor = fixed_delay_constructor_t<672, 1120>;
160129
};
161130

131+
template <class InputT,
132+
class OffsetT,
133+
input_size InputSize = classify_input_size<InputT>(),
134+
offset_size OffsetSize = classify_offset_size<OffsetT>()>
135+
struct sm90_tuning;
136+
162137
template <class Input, class OffsetT>
163-
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
138+
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_4>
164139
{
165-
static constexpr int threads = 256;
166-
static constexpr int items = 24;
167-
140+
static constexpr int threads = 256;
141+
static constexpr int items = 12;
168142
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
169-
170-
using delay_constructor = detail::fixed_delay_constructor_t<4, 285>;
143+
using delay_constructor = no_delay_constructor_t<445>;
171144
};
172145

173146
template <class Input, class OffsetT>
174-
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_8>
147+
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
175148
{
176-
static constexpr int threads = 640;
177-
static constexpr int items = 24;
178-
179-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
180-
181-
using delay_constructor = detail::no_delay_constructor_t<245>;
149+
static constexpr int threads = 256;
150+
static constexpr int items = 12;
151+
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
152+
using delay_constructor = fixed_delay_constructor_t<104, 512>;
182153
};
183154

184155
template <class Input, class OffsetT>
185-
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_8>
156+
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
186157
{
187-
static constexpr int threads = 256;
188-
static constexpr int items = 23;
189-
190-
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
191-
192-
using delay_constructor = detail::no_delay_constructor_t<910>;
158+
static constexpr int threads = 320;
159+
static constexpr int items = 12;
160+
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
161+
using delay_constructor = no_delay_constructor_t<1105>;
193162
};
194163

195164
template <class Input, class OffsetT>
196-
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_8>
165+
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
197166
{
198-
static constexpr int threads = 256;
199-
static constexpr int items = 18;
200-
167+
static constexpr int threads = 384;
168+
static constexpr int items = 7;
201169
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
202-
203-
using delay_constructor = detail::no_delay_constructor_t<1145>;
170+
using delay_constructor = fixed_delay_constructor_t<464, 1165>;
204171
};
205172

206173
template <class Input, class OffsetT>
207-
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
174+
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
208175
{
209-
static constexpr int threads = 256;
210-
static constexpr int items = 11;
211-
176+
static constexpr int threads = 128;
177+
static constexpr int items = 7;
212178
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
213-
214-
using delay_constructor = detail::no_delay_constructor_t<1050>;
179+
using delay_constructor = no_delay_constructor_t<1040>;
215180
};
216181

217-
template <class InputT,
218-
class OffsetT,
219-
input_size InputSize = classify_input_size<InputT>(),
220-
offset_size OffsetSize = classify_offset_size<OffsetT>()>
221-
struct sm80_tuning
182+
template <class Input, class OffsetT>
183+
struct sm90_tuning<Input, OffsetT, input_size::_1, offset_size::_8>
222184
{
223-
static constexpr int threads = 256;
224-
static constexpr int items = Nominal4BItemsToItems<InputT>(9);
225-
185+
static constexpr int threads = 256;
186+
static constexpr int items = 24;
226187
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_DIRECT;
227-
228-
using AccumPackHelperT = detail::three_way_partition::accumulator_pack_t<OffsetT>;
229-
using AccumPackT = typename AccumPackHelperT::pack_t;
230-
using delay_constructor = detail::default_delay_constructor_t<AccumPackT>;
188+
using delay_constructor = fixed_delay_constructor_t<4, 285>;
231189
};
232190

233191
template <class Input, class OffsetT>
234-
struct sm80_tuning<Input, OffsetT, input_size::_2, offset_size::_4>
192+
struct sm90_tuning<Input, OffsetT, input_size::_2, offset_size::_8>
235193
{
236-
static constexpr int threads = 256;
237-
static constexpr int items = 12;
238-
194+
static constexpr int threads = 640;
195+
static constexpr int items = 24;
239196
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
240-
241-
using delay_constructor = detail::no_delay_constructor_t<910>;
197+
using delay_constructor = no_delay_constructor_t<245>;
242198
};
243199

244200
template <class Input, class OffsetT>
245-
struct sm80_tuning<Input, OffsetT, input_size::_4, offset_size::_4>
201+
struct sm90_tuning<Input, OffsetT, input_size::_4, offset_size::_8>
246202
{
247-
static constexpr int threads = 256;
248-
static constexpr int items = 11;
249-
203+
static constexpr int threads = 256;
204+
static constexpr int items = 23;
250205
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
251-
252-
using delay_constructor = detail::no_delay_constructor_t<1120>;
206+
using delay_constructor = no_delay_constructor_t<910>;
253207
};
254208

255209
template <class Input, class OffsetT>
256-
struct sm80_tuning<Input, OffsetT, input_size::_8, offset_size::_4>
210+
struct sm90_tuning<Input, OffsetT, input_size::_8, offset_size::_8>
257211
{
258-
static constexpr int threads = 224;
259-
static constexpr int items = 11;
260-
212+
static constexpr int threads = 256;
213+
static constexpr int items = 18;
261214
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
262-
263-
using delay_constructor = detail::fixed_delay_constructor_t<264, 1080>;
215+
using delay_constructor = no_delay_constructor_t<1145>;
264216
};
265217

266218
template <class Input, class OffsetT>
267-
struct sm80_tuning<Input, OffsetT, input_size::_16, offset_size::_4>
219+
struct sm90_tuning<Input, OffsetT, input_size::_16, offset_size::_8>
268220
{
269-
static constexpr int threads = 128;
270-
static constexpr int items = 10;
271-
221+
static constexpr int threads = 256;
222+
static constexpr int items = 11;
272223
static constexpr BlockLoadAlgorithm load_algorithm = BLOCK_LOAD_WARP_TRANSPOSE;
273-
274-
using delay_constructor = detail::fixed_delay_constructor_t<672, 1120>;
224+
using delay_constructor = no_delay_constructor_t<1050>;
275225
};
276226

277-
} // namespace three_way_partition
278-
279227
template <class InputT, class OffsetT>
280-
struct device_three_way_partition_policy_hub
228+
struct policy_hub
281229
{
282-
struct DefaultTuning
230+
template <typename DelayConstructor>
231+
struct DefaultPolicy
283232
{
284-
static constexpr int ITEMS_PER_THREAD = Nominal4BItemsToItems<InputT>(9);
285-
286233
using ThreeWayPartitionPolicy =
287-
cub::AgentThreeWayPartitionPolicy<256,
288-
ITEMS_PER_THREAD,
289-
cub::BLOCK_LOAD_DIRECT,
290-
cub::LOAD_DEFAULT,
291-
cub::BLOCK_SCAN_WARP_SCANS>;
234+
AgentThreeWayPartitionPolicy<256,
235+
Nominal4BItemsToItems<InputT>(9),
236+
BLOCK_LOAD_DIRECT,
237+
LOAD_DEFAULT,
238+
BLOCK_SCAN_WARP_SCANS,
239+
DelayConstructor>;
292240
};
293241

294-
/// SM35
295242
struct Policy350
296-
: DefaultTuning
243+
: DefaultPolicy<fixed_delay_constructor_t<350, 450>>
297244
, ChainedPolicy<350, Policy350, Policy350>
298245
{};
299246

247+
// Use values from tuning if a specialization exists, otherwise pick DefaultPolicy
248+
template <typename Tuning>
249+
static auto select_agent_policy(int)
250+
-> AgentThreeWayPartitionPolicy<Tuning::threads,
251+
Tuning::items,
252+
Tuning::load_algorithm,
253+
LOAD_DEFAULT,
254+
BLOCK_SCAN_WARP_SCANS,
255+
typename Tuning::delay_constructor>;
256+
257+
template <typename Tuning>
258+
static auto select_agent_policy(long) ->
259+
typename DefaultPolicy<
260+
default_delay_constructor_t<typename accumulator_pack_t<OffsetT>::pack_t>>::ThreeWayPartitionPolicy;
261+
300262
struct Policy800 : ChainedPolicy<800, Policy800, Policy350>
301263
{
302-
using tuning = detail::three_way_partition::sm80_tuning<InputT, OffsetT>;
303-
304-
using ThreeWayPartitionPolicy =
305-
AgentThreeWayPartitionPolicy<tuning::threads,
306-
tuning::items,
307-
tuning::load_algorithm,
308-
cub::LOAD_DEFAULT,
309-
cub::BLOCK_SCAN_WARP_SCANS,
310-
typename tuning::delay_constructor>;
264+
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm80_tuning<InputT, OffsetT>>(0));
311265
};
312266

313267
struct Policy860
314-
: DefaultTuning
268+
: DefaultPolicy<fixed_delay_constructor_t<350, 450>>
315269
, ChainedPolicy<860, Policy860, Policy800>
316270
{};
317271

318-
/// SM90
319272
struct Policy900 : ChainedPolicy<900, Policy900, Policy860>
320273
{
321-
using tuning = detail::three_way_partition::sm90_tuning<InputT, OffsetT>;
322-
323-
using ThreeWayPartitionPolicy =
324-
AgentThreeWayPartitionPolicy<tuning::threads,
325-
tuning::items,
326-
tuning::load_algorithm,
327-
cub::LOAD_DEFAULT,
328-
cub::BLOCK_SCAN_WARP_SCANS,
329-
typename tuning::delay_constructor>;
274+
using ThreeWayPartitionPolicy = decltype(select_agent_policy<sm90_tuning<InputT, OffsetT>>(0));
330275
};
331276

332277
using MaxPolicy = Policy900;
333278
};
334-
279+
} // namespace three_way_partition
335280
} // namespace detail
336281

337282
CUB_NAMESPACE_END

0 commit comments

Comments
 (0)