@@ -289,36 +289,67 @@ __device__ {1} operator{2}(const {3} & lhs, const {3} & rhs)
289289 }
290290 }
291291
// Diff hunk: the `cccl_type_info`-based `special` signature (the `-` lines) is
// replaced by one taking `tagged_arg<StorageT, cccl_type_info>` for the return
// type and for every argument (the `+` lines).
//
// Returns a `specialization` built from the well-known operation entry for
// `operation.type`, or `cuda::std::nullopt` when the return type or any
// argument type is CCCL_STORAGE, or when no well-known entry is found.
292- template <typename , typename ... Args>
293- static cuda::std::optional<specialization> special (cccl_op_t operation, cccl_type_info ret, Args... arguments)
292+ template <typename Tag, typename RetStorageT, typename ... ArgStorageTs>
293+ static cuda::std::optional<specialization>
294+ special (cccl_op_t operation,
295+ tagged_arg<RetStorageT, cccl_type_info> ret,
296+ tagged_arg<ArgStorageTs, cccl_type_info>... arguments)
294297 {
298+ // We cannot use well-known operations with storage types as there
299+ // is currently no way to tell whether multiple storage types,
300+ // e.g., input and output, are the same type. This is necessary in
301+ // order for operations like `negate` to work, as it requires both
302+ // input and output to be the same type. For now, this check short
303+ // circuits the specialization process for well-known operations
304+ // with storage types. The code below that checks whether any of
305+ // the arguments or the return type are storage types will
306+ // currently not run, but is left here as it will be needed in the
307+ // future.
308+ if (ret.value .type == cccl_type_enum::CCCL_STORAGE
309+ || ((arguments.value .type == cccl_type_enum::CCCL_STORAGE) || ...))
310+ {
311+ return cuda::std::nullopt ;
312+ }
313+
// Look up the well-known description for this operation kind; a missing
// entry means there is nothing to specialize.
295314 auto entry = well_known_operation_description (operation.type );
296315 if (!entry)
297316 {
298317 return cuda::std::nullopt ;
299318 }
300319
// Index 0 holds the return type; indices 1.. hold the argument types in
// call order.
301- cccl_type_enum type_info_table[] = {ret.type , arguments.type ...};
320+ cccl_type_enum type_info_table[] = {ret.value . type , arguments. value .type ...};
302321 auto builder = entry->check (cuda::std::span (type_info_table), std::format (entry->name , +" " ));
303322
304323 std::string aux = " #include <cuda/std/functional>\n " ;
305324 if (entry->symbol )
306325 {
// NOTE(review): with the early return at the top of the function, neither
// the argument-storage branch nor the return-storage branch below can be
// taken on the `+` side; they are kept for future use per the comment above.
307- if (((arguments.type == cccl_type_enum::CCCL_STORAGE) || ...))
326+ if (((arguments.value . type == cccl_type_enum::CCCL_STORAGE) || ...))
308327 {
309- std::string type_names[] = {cccl_type_enum_to_name (ret.type ), cccl_type_enum_to_name (arguments.type )...};
328+ std::string type_names[] = {cccl_type_enum_to_name<RetStorageT>(ret.value .type ),
329+ cccl_type_enum_to_name<ArgStorageTs>(arguments.value .type )...};
// Build string_view aliases over `type_names` (return type plus one per
// argument) to pass to `builder` as a span.
310330 auto type_name_views = [&]<auto ... Is>(std::index_sequence<Is...>) {
311331 return std::array<std::string_view, sizeof ...(Is)>{{type_names[Is]...}};
312332 }(std::make_index_sequence<1 + sizeof ...(arguments)>());
313333 aux += builder (cuda::std::span (type_name_views), *entry->symbol , operation.name );
314334 }
315- else if (ret.type == cccl_type_enum::CCCL_STORAGE)
335+ else if (ret.value . type == cccl_type_enum::CCCL_STORAGE)
316336 {
317337 return cuda::std::nullopt ;
318338 }
319339 }
320340
321- return specialization{std::format (entry->name , cccl_type_enum_to_name (type_info_table[1 ]).c_str ()), aux};
341+ // Format the specialization name using the appropriate storage type
342+ // type_info_table[1] corresponds to the first argument
// NOTE(review): both `tuple_element<0, ...>` and `type_info_table[1]` assume
// at least one argument is passed — confirm no caller uses an empty pack.
343+ using FirstArgStorageT = typename cuda::std::tuple_element<0 , cuda::std::tuple<ArgStorageTs...>>::type;
344+ std::string type_name = cccl_type_enum_to_name<FirstArgStorageT>(type_info_table[1 ]);
345+
346+ return specialization{std::format (entry->name , type_name.c_str ()), aux};
347+ }
348+
// Convenience overload added by this hunk: wraps every argument with
// `arg_traits<decay_t<Arg>>::wrap` and forwards the results to `special<Tag>`.
// NOTE(review): presumably `wrap` yields `tagged_arg<..., cccl_type_info>`, so
// the forwarded call resolves to the tagged_arg overload instead of re-entering
// this one — confirm against the definition of `arg_traits`.
349+ template <typename Tag, typename ... Args>
350+ static cuda::std::optional<specialization> special (cccl_op_t operation, Args... args)
351+ {
352+ return special<Tag>(operation, arg_traits<cuda::std::decay_t <Args>>::wrap (args)...);
322353 }
323354#endif
324355};
@@ -330,10 +361,11 @@ struct binary_user_operation_traits
330361 using type = user_operation_traits::type<Tag, Operation, ValueT, ValueT, ValueT>;
331362
332363#ifndef _CCCL_C_PARALLEL_JIT_TEMPLATES_PREPROCESS
// Diff hunk: the parameter changes from a fixed `cccl_type_info arg_t` to a
// deduced `Arg`. On the `+` side the argument is wrapped once via
// `arg_traits<decay_t<Arg>>::wrap`, and the same wrapped value is passed three
// times to `user_operation_traits::special`.
333- template <typename Tag, typename ... Args >
334- static cuda::std::optional<specialization> special (cccl_op_t operation, cccl_type_info arg_t )
364+ template <typename Tag, typename Arg >
365+ static cuda::std::optional<specialization> special (cccl_op_t operation, Arg arg )
335366 {
336- return user_operation_traits::special<Tag>(operation, arg_t , arg_t , arg_t );
367+ auto wrapped = arg_traits<cuda::std::decay_t <Arg>>::wrap (arg);
368+ return user_operation_traits::special<Tag>(operation, wrapped, wrapped, wrapped);
337369 }
338370#endif
339371};
@@ -345,10 +377,11 @@ struct binary_user_predicate_traits
345377 using type = user_operation_traits::type<Tag, Operation, cccl_type_info_mapping<bool >{}, ValueT, ValueT>;
346378
347379#ifndef _CCCL_C_PARALLEL_JIT_TEMPLATES_PREPROCESS
// Diff hunk: the parameter changes from a fixed `cccl_type_info arg_t` to a
// deduced `Arg`. On the `+` side the argument is wrapped once via
// `arg_traits<decay_t<Arg>>::wrap`, and the same wrapped value is passed three
// times to `user_operation_traits::special`.
// NOTE(review): the `type` alias of this traits struct uses a `bool` mapping in
// the slot where the binary-operation traits use ValueT, yet this call passes
// `wrapped` in all three positions (as the `-` side did with `arg_t`) —
// confirm that is intended for predicates.
348- template <typename Tag, typename ... Args >
349- static cuda::std::optional<specialization> special (cccl_op_t operation, cccl_type_info arg_t )
380+ template <typename Tag, typename Arg >
381+ static cuda::std::optional<specialization> special (cccl_op_t operation, Arg arg )
350382 {
351- return user_operation_traits::special<Tag>(operation, arg_t , arg_t , arg_t );
383+ auto wrapped = arg_traits<cuda::std::decay_t <Arg>>::wrap (arg);
384+ return user_operation_traits::special<Tag>(operation, wrapped, wrapped, wrapped);
352385 }
353386#endif
354387};
0 commit comments