Skip to content

Commit aadc0c4

Browse files
authored
[c.parallel]: migrate transform to use jit templates instead of string-based implementations (#7399)
* Plumb custom storage_t through jit templates to handle cases where operators can have different storage types for input and output
* Set value_type of output iterator template to actual type instead of void to ensure proper policy selection in device code
* Disable well-known operations for storage types
1 parent c2b7a73 commit aadc0c4

File tree

10 files changed

+474
-153
lines changed

10 files changed

+474
-153
lines changed

c/parallel/src/jit_templates/README.md

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ void func(/*..., */ cccl_iterator_t input_it /*, ...*/) {
5555
// `get_specialization` performs rudimentary type checking; if you pass the wrong type of an argument here
5656
// (one that doesn't match the types expected by the template), this call to `get_specialization` will fail to
5757
// compile.
58+
//
59+
// If different arguments require different storage types (e.g. input vs output CCCL_STORAGE), wrap them in
60+
// tagged_arg<StorageT, T> so parameter_mapping can select the correct storage type per argument.
5861
input_t
5962
);
6063

@@ -74,6 +77,10 @@ void func(/*..., */ cccl_iterator_t input_it /*, ...*/) {
7477
An argument mapping is a type, usable as a template argument, which carries the `cccl_*` values that are the
7578
arguments to the invocation of `get_specialization` in a format that will work in device code compiled with NVRTC.
7679
80+
If different runtime arguments require different storage types, wrap the runtime argument in
81+
`tagged_arg<StorageT, T>` (defined in `traits.h`). This allows `parameter_mapping` to select the correct storage type
82+
per argument.
83+
7784
Below is the contents of the `cccl_type_info` mapping header with annotations:
7885
7986
```cpp
@@ -162,6 +169,8 @@ A template traits type must provide:
162169
given arguments. The poster child for the use of this feature is the iterator JIT templates, where, if the kind of
163170
the iterator is determined to be a pointer, a simple pointer name is returned (instead of generating a specialization
164171
name of the template).
172+
Note: if callers wrap runtime arguments in `tagged_arg`, traits that want to participate in `special` handling
173+
should provide matching overloads.
165174

166175
#### Archetypes
167176

@@ -180,6 +189,9 @@ information provided by template traits and archetypes; performs basic type chec
180189
plus any auxiliary code necessary to compile it. See the example at the beginning of this document for an explanation of
181190
its various parameters and the return values.
182191

192+
`get_specialization` will call `Traits::special(...)` when that call is well-formed for the provided runtime arguments,
193+
including tagged arguments.
194+
183195
### CMake
184196

185197
The last bit necessary to make all of this work is the preprocessing and embedding step. In it, the `jit_entry.h` header

c/parallel/src/jit_templates/mappings/iterator.h

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -37,52 +37,61 @@ struct parameter_mapping<cccl_iterator_t>
3737
{
3838
static const constexpr auto archetype = cccl_iterator_t_mapping<int>{};
3939

40-
template <typename Traits>
41-
static std::string map(template_id<Traits>, cccl_iterator_t arg)
40+
template <typename Traits, typename ArgT>
41+
static std::string map(template_id<Traits>, ArgT arg)
4242
{
43-
if (arg.advance.type != cccl_op_kind_t::CCCL_STATEFUL && arg.advance.type != cccl_op_kind_t::CCCL_STATELESS)
43+
using traits = arg_traits<cuda::std::decay_t<ArgT>>;
44+
using storage_type = typename traits::storage_type;
45+
const auto& value = traits::unwrap(arg);
46+
47+
if (value.advance.type != cccl_op_kind_t::CCCL_STATEFUL && value.advance.type != cccl_op_kind_t::CCCL_STATELESS)
4448
{
4549
throw std::runtime_error("c.parallel: well-known operations are not allowed as an iterator's advance operation");
4650
}
47-
if (arg.dereference.type != cccl_op_kind_t::CCCL_STATEFUL && arg.dereference.type != cccl_op_kind_t::CCCL_STATELESS)
51+
if (value.dereference.type != cccl_op_kind_t::CCCL_STATEFUL
52+
&& value.dereference.type != cccl_op_kind_t::CCCL_STATELESS)
4853
{
4954
throw std::runtime_error("c.parallel: well-known operations are not allowed as an iterator's dereference "
5055
"operation");
5156
}
5257

5358
return std::format(
5459
"cccl_iterator_t_mapping<{}>{{.is_pointer = {}, .size = {}, .alignment = {}, .advance = {}, .{} = {}}}",
55-
cccl_type_enum_to_name(arg.value_type.type),
56-
arg.type == cccl_iterator_kind_t::CCCL_POINTER,
57-
arg.size,
58-
arg.alignment,
59-
arg.advance.name,
60+
cccl_type_enum_to_name<storage_type>(value.value_type.type),
61+
value.type == cccl_iterator_kind_t::CCCL_POINTER,
62+
value.size,
63+
value.alignment,
64+
value.advance.name,
6065
std::is_same_v<Traits, output_iterator_traits> ? "assign" : "dereference",
61-
arg.dereference.name);
66+
value.dereference.name);
6267
}
6368

64-
template <typename Traits>
65-
static std::string aux(template_id<Traits>, cccl_iterator_t arg)
69+
template <typename Traits, typename ArgT>
70+
static std::string aux(template_id<Traits>, ArgT arg)
6671
{
72+
using traits = arg_traits<cuda::std::decay_t<ArgT>>;
73+
using storage_type = typename traits::storage_type;
74+
const auto& value = traits::unwrap(arg);
75+
6776
if constexpr (std::is_same_v<Traits, output_iterator_traits>)
6877
{
6978
return std::format(
7079
R"output(
7180
extern "C" __device__ void {0}(void *, const void*);
7281
extern "C" __device__ void {1}(const void *, const void*);
7382
)output",
74-
arg.advance.name,
75-
arg.dereference.name);
83+
value.advance.name,
84+
value.dereference.name);
7685
}
7786

7887
return std::format(
7988
R"input(
8089
extern "C" __device__ void {0}(void *, const void*);
8190
extern "C" __device__ void {1}(const void *, {2}*);
8291
)input",
83-
arg.advance.name,
84-
arg.dereference.name,
85-
cccl_type_enum_to_name(arg.value_type.type));
92+
value.advance.name,
93+
value.dereference.name,
94+
cccl_type_enum_to_name<storage_type>(value.value_type.type));
8695
}
8796
};
8897
#endif

c/parallel/src/jit_templates/mappings/operation.h

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,26 @@ struct parameter_mapping<cccl_op_t>
3232
{
3333
static const constexpr auto archetype = cccl_op_t_mapping{};
3434

35-
template <typename Traits>
36-
static std::string map(template_id<Traits>, cccl_op_t op)
35+
template <typename Traits, typename ArgT>
36+
static std::string map(template_id<Traits>, ArgT arg)
3737
{
38+
const auto& value = arg_traits<cuda::std::decay_t<ArgT>>::unwrap(arg);
3839
return std::format(
3940
"cccl_op_t_mapping{{.is_stateless = {}, .size = {}, .alignment = {}, .operation = {}}}",
40-
op.type != cccl_op_kind_t::CCCL_STATEFUL,
41-
op.size,
42-
op.alignment,
43-
op.name);
41+
value.type != cccl_op_kind_t::CCCL_STATEFUL,
42+
value.size,
43+
value.alignment,
44+
value.name);
4445
}
4546

46-
template <typename Traits>
47-
static std::string aux(template_id<Traits>, cccl_op_t op)
47+
template <typename Traits, typename ArgT>
48+
static std::string aux(template_id<Traits>, ArgT arg)
4849
{
50+
const auto& value = arg_traits<cuda::std::decay_t<ArgT>>::unwrap(arg);
4951
return std::format(R"(
5052
extern "C" __device__ void {}();
5153
)",
52-
op.name);
54+
value.name);
5355
}
5456
};
5557
#endif

c/parallel/src/jit_templates/mappings/type_info.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,17 @@ struct parameter_mapping<cccl_type_info>
2424
{
2525
static const constexpr auto archetype = cccl_type_info_mapping<int>{};
2626

27-
template <typename TplId>
28-
static std::string map(TplId, cccl_type_info arg)
27+
template <typename TplId, typename ArgT>
28+
static std::string map(TplId, ArgT arg)
2929
{
30-
return std::format("cccl_type_info_mapping<{}>{{}}", cccl_type_enum_to_name(arg.type));
30+
using traits = arg_traits<cuda::std::decay_t<ArgT>>;
31+
using storage_type = typename traits::storage_type;
32+
const auto& value = traits::unwrap(arg);
33+
return std::format("cccl_type_info_mapping<{}>{{}}", cccl_type_enum_to_name<storage_type>(value.type));
3134
}
3235

33-
template <typename TplId>
34-
static std::string aux(TplId, cccl_type_info)
36+
template <typename TplId, typename ArgT>
37+
static std::string aux(TplId, ArgT)
3538
{
3639
return {};
3740
}

c/parallel/src/jit_templates/templates/input_iterator.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,5 +84,17 @@ struct input_iterator_traits
8484

8585
return cuda::std::nullopt;
8686
}
87+
88+
template <typename Tag, typename StorageT>
89+
static cuda::std::optional<specialization> special(tagged_arg<StorageT, cccl_iterator_t> it)
90+
{
91+
if (it.value.type == cccl_iterator_kind_t::CCCL_POINTER)
92+
{
93+
return cuda::std::make_optional(
94+
specialization{cccl_type_enum_to_name<StorageT>(it.value.value_type.type, true), ""});
95+
}
96+
97+
return cuda::std::nullopt;
98+
}
8799
};
88100
#endif

c/parallel/src/jit_templates/templates/operation.h

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -289,36 +289,67 @@ __device__ {1} operator{2}(const {3} & lhs, const {3} & rhs)
289289
}
290290
}
291291

292-
template <typename, typename... Args>
293-
static cuda::std::optional<specialization> special(cccl_op_t operation, cccl_type_info ret, Args... arguments)
292+
template <typename Tag, typename RetStorageT, typename... ArgStorageTs>
293+
static cuda::std::optional<specialization>
294+
special(cccl_op_t operation,
295+
tagged_arg<RetStorageT, cccl_type_info> ret,
296+
tagged_arg<ArgStorageTs, cccl_type_info>... arguments)
294297
{
298+
// We cannot use well-known operations with storage types as there
299+
// is currently no way to tell whether multiple storage types,
300+
// e.g., input and output, are the same type. This is necessary in
301+
// order for operations like `negate` to work, as it requires both
302+
// input and output to be the same type. For now, this check short
303+
// circuits the specialization process for well-known operations
304+
// with storage types. The code below that checks whether any of
305+
// the arguments or the return type are storage types will
306+
// currently not run, but is left here as it will be needed in the
307+
// future.
308+
if (ret.value.type == cccl_type_enum::CCCL_STORAGE
309+
|| ((arguments.value.type == cccl_type_enum::CCCL_STORAGE) || ...))
310+
{
311+
return cuda::std::nullopt;
312+
}
313+
295314
auto entry = well_known_operation_description(operation.type);
296315
if (!entry)
297316
{
298317
return cuda::std::nullopt;
299318
}
300319

301-
cccl_type_enum type_info_table[] = {ret.type, arguments.type...};
320+
cccl_type_enum type_info_table[] = {ret.value.type, arguments.value.type...};
302321
auto builder = entry->check(cuda::std::span(type_info_table), std::format(entry->name, +""));
303322

304323
std::string aux = "#include <cuda/std/functional>\n";
305324
if (entry->symbol)
306325
{
307-
if (((arguments.type == cccl_type_enum::CCCL_STORAGE) || ...))
326+
if (((arguments.value.type == cccl_type_enum::CCCL_STORAGE) || ...))
308327
{
309-
std::string type_names[] = {cccl_type_enum_to_name(ret.type), cccl_type_enum_to_name(arguments.type)...};
328+
std::string type_names[] = {cccl_type_enum_to_name<RetStorageT>(ret.value.type),
329+
cccl_type_enum_to_name<ArgStorageTs>(arguments.value.type)...};
310330
auto type_name_views = [&]<auto... Is>(std::index_sequence<Is...>) {
311331
return std::array<std::string_view, sizeof...(Is)>{{type_names[Is]...}};
312332
}(std::make_index_sequence<1 + sizeof...(arguments)>());
313333
aux += builder(cuda::std::span(type_name_views), *entry->symbol, operation.name);
314334
}
315-
else if (ret.type == cccl_type_enum::CCCL_STORAGE)
335+
else if (ret.value.type == cccl_type_enum::CCCL_STORAGE)
316336
{
317337
return cuda::std::nullopt;
318338
}
319339
}
320340

321-
return specialization{std::format(entry->name, cccl_type_enum_to_name(type_info_table[1]).c_str()), aux};
341+
// Format the specialization name using the appropriate storage type
342+
// type_info_table[1] corresponds to the first argument
343+
using FirstArgStorageT = typename cuda::std::tuple_element<0, cuda::std::tuple<ArgStorageTs...>>::type;
344+
std::string type_name = cccl_type_enum_to_name<FirstArgStorageT>(type_info_table[1]);
345+
346+
return specialization{std::format(entry->name, type_name.c_str()), aux};
347+
}
348+
349+
template <typename Tag, typename... Args>
350+
static cuda::std::optional<specialization> special(cccl_op_t operation, Args... args)
351+
{
352+
return special<Tag>(operation, arg_traits<cuda::std::decay_t<Args>>::wrap(args)...);
322353
}
323354
#endif
324355
};
@@ -330,10 +361,11 @@ struct binary_user_operation_traits
330361
using type = user_operation_traits::type<Tag, Operation, ValueT, ValueT, ValueT>;
331362

332363
#ifndef _CCCL_C_PARALLEL_JIT_TEMPLATES_PREPROCESS
333-
template <typename Tag, typename... Args>
334-
static cuda::std::optional<specialization> special(cccl_op_t operation, cccl_type_info arg_t)
364+
template <typename Tag, typename Arg>
365+
static cuda::std::optional<specialization> special(cccl_op_t operation, Arg arg)
335366
{
336-
return user_operation_traits::special<Tag>(operation, arg_t, arg_t, arg_t);
367+
auto wrapped = arg_traits<cuda::std::decay_t<Arg>>::wrap(arg);
368+
return user_operation_traits::special<Tag>(operation, wrapped, wrapped, wrapped);
337369
}
338370
#endif
339371
};
@@ -345,10 +377,11 @@ struct binary_user_predicate_traits
345377
using type = user_operation_traits::type<Tag, Operation, cccl_type_info_mapping<bool>{}, ValueT, ValueT>;
346378

347379
#ifndef _CCCL_C_PARALLEL_JIT_TEMPLATES_PREPROCESS
348-
template <typename Tag, typename... Args>
349-
static cuda::std::optional<specialization> special(cccl_op_t operation, cccl_type_info arg_t)
380+
template <typename Tag, typename Arg>
381+
static cuda::std::optional<specialization> special(cccl_op_t operation, Arg arg)
350382
{
351-
return user_operation_traits::special<Tag>(operation, arg_t, arg_t, arg_t);
383+
auto wrapped = arg_traits<cuda::std::decay_t<Arg>>::wrap(arg);
384+
return user_operation_traits::special<Tag>(operation, wrapped, wrapped, wrapped);
352385
}
353386
#endif
354387
};

c/parallel/src/jit_templates/templates/output_iterator.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ struct output_iterator_t
4747
{
4848
using iterator_category = cuda::std::random_access_iterator_tag;
4949
using difference_type = cuda::std::size_t;
50-
using value_type = void;
50+
using value_type = typename decltype(AssignTV)::Type;
5151
using reference =
5252
output_iterator_proxy_t<Tag, Iterator.size, Iterator.alignment, typename decltype(AssignTV)::Type, Iterator.assign>;
5353
using pointer = reference*;
@@ -98,5 +98,18 @@ struct output_iterator_traits
9898

9999
return cuda::std::nullopt;
100100
}
101+
102+
template <typename Tag, typename StorageT, typename AssignStorageT>
103+
static cuda::std::optional<specialization>
104+
special(tagged_arg<StorageT, cccl_iterator_t> it, tagged_arg<AssignStorageT, cccl_type_info> assign_t)
105+
{
106+
if (it.value.type == cccl_iterator_kind_t::CCCL_POINTER)
107+
{
108+
return cuda::std::make_optional(
109+
specialization{cccl_type_enum_to_name<AssignStorageT>(assign_t.value.type, true), ""});
110+
}
111+
112+
return cuda::std::nullopt;
113+
}
101114
};
102115
#endif

0 commit comments

Comments
 (0)