Skip to content

Commit d573f87

Browse files
Add cub::DeviceTransform N->M API entrypoint (#7473)
1 parent d63e46e commit d573f87

File tree

3 files changed

+163
-9
lines changed

3 files changed

+163
-9
lines changed

cub/cub/device/device_transform.cuh

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,110 @@ private:
108108
}
109109
}
110110

111+
// TODO(bgruber): we want to eventually forward the output tuple to the kernel and optimize writing multiple streams
112+
template <detail::transform::requires_stable_address StableAddress = detail::transform::requires_stable_address::no,
113+
typename... RandomAccessIteratorsIn,
114+
typename... RandomAccessIteratorsOut,
115+
typename NumItemsT,
116+
typename Predicate,
117+
typename TransformOp,
118+
typename Env>
119+
CUB_RUNTIME_FUNCTION static cudaError_t TransformInternal(
120+
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
121+
::cuda::std::tuple<RandomAccessIteratorsOut...> outputs,
122+
NumItemsT num_items,
123+
Predicate predicate,
124+
TransformOp transform_op,
125+
Env env)
126+
{
127+
return TransformInternal<StableAddress>(
128+
::cuda::std::move(inputs),
129+
::cuda::make_zip_iterator(::cuda::std::move(outputs)),
130+
num_items,
131+
::cuda::std::move(predicate),
132+
::cuda::std::move(transform_op),
133+
::cuda::std::move(env));
134+
}
135+
111136
public:
137+
//! @rst
138+
//! Overview
139+
//! +++++++++++++++++++++++++++++++++++++++++++++
140+
//! Transforms many input sequences into many output sequence, by applying a transformation operation on corresponding
141+
//! input elements and writing the tuple result to the corresponding output elements. No guarantee is given on the
142+
//! identity (i.e. address) of the objects passed to the call operator of the transformation operation.
143+
//!
144+
//! A Simple Example
145+
//! +++++++++++++++++++++++++++++++++++++++++++++
146+
//!
147+
//! .. literalinclude:: ../../../cub/test/catch2_test_device_transform_api.cu
148+
//! :language: c++
149+
//! :dedent:
150+
//! :start-after: example-begin transform-many-many
151+
//! :end-before: example-end transform-many-many
152+
//!
153+
//! @endrst
154+
//!
155+
//! @param inputs A tuple of iterators to the input sequences where num_items elements are read from each. The
156+
//! iterators' value types must be trivially relocatable.
157+
//! @param outputs A tuple of iterators to the output sequences where num_items results are written to each. Each
158+
//! sequence may point to the beginning of one of the input sequences, performing the transformation inplace. Any
159+
//! output sequence must not overlap with any of the input sequence in any other way.
160+
//! @param num_items The number of elements in each input and output sequence.
161+
//! @param transform_op An n-ary function object, where n is the number of input sequences. The input iterators' value
162+
//! types must be convertible to the parameters of the function object's call operator. The return type of the call
163+
//! operator must be a tuple where each tuple element is assignable to the corresponding dereferenced output
164+
//! iterators.
165+
//! @param env Execution environment, or cudaStream_t. Default is ``cuda::std::execution::env{}``, which will run on
166+
//! stream\ :sub:`0`
167+
template <typename... RandomAccessIteratorsIn,
168+
typename... RandomAccessIteratorsOut,
169+
typename NumItemsT,
170+
typename TransformOp,
171+
typename Env = ::cuda::std::execution::env<>>
172+
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
173+
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
174+
::cuda::std::tuple<RandomAccessIteratorsOut...> outputs,
175+
NumItemsT num_items,
176+
TransformOp transform_op,
177+
Env env = {})
178+
{
179+
_CCCL_NVTX_RANGE_SCOPE("cub::DeviceTransform::Transform");
180+
return TransformInternal(
181+
::cuda::std::move(inputs),
182+
::cuda::std::move(outputs),
183+
num_items,
184+
detail::transform::always_true_predicate{},
185+
::cuda::std::move(transform_op),
186+
::cuda::std::move(env));
187+
}
188+
189+
#ifndef _CCCL_DOXYGEN_INVOKED // Do not document
190+
// Overload with additional parameters to specify temporary storage. Provided for compatibility with other CUB APIs.
191+
template <typename... RandomAccessIteratorsIn,
192+
typename... RandomAccessIteratorsOut,
193+
typename NumItemsT,
194+
typename TransformOp>
195+
CUB_RUNTIME_FUNCTION static cudaError_t Transform(
196+
void* d_temp_storage,
197+
size_t& temp_storage_bytes,
198+
::cuda::std::tuple<RandomAccessIteratorsIn...> inputs,
199+
::cuda::std::tuple<RandomAccessIteratorsOut...> outputs,
200+
NumItemsT num_items,
201+
TransformOp transform_op,
202+
cudaStream_t stream = nullptr)
203+
{
204+
if (d_temp_storage == nullptr)
205+
{
206+
temp_storage_bytes = 1;
207+
return cudaSuccess;
208+
}
209+
210+
return Transform(
211+
::cuda::std::move(inputs), ::cuda::std::move(outputs), num_items, ::cuda::std::move(transform_op), stream);
212+
}
213+
#endif // _CCCL_DOXYGEN_INVOKED
214+
112215
//! @rst
113216
//! Overview
114217
//! +++++++++++++++++++++++++++++++++++++++++++++

cub/test/catch2_test_device_transform.cu

Lines changed: 27 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -435,22 +435,42 @@ C2H_TEST("DeviceTransform::Transform fancy output iterator type with void value
435435
REQUIRE(result == c2h::device_vector<type>(num_items, 3));
436436
}
437437

438-
C2H_TEST("DeviceTransform::Transform mixed input iterator types", "[device][transform]")
438+
struct plus_mul_neg
439439
{
440-
using type = int;
440+
template <typename T>
441+
__host__ __device__ auto operator()(T a, T b) const
442+
{
443+
return cuda::std::tuple{a + b, a * b, -a};
444+
}
445+
};
446+
447+
C2H_TEST("DeviceTransform::Transform mixed iterator types 2 -> 3", "[device][transform]")
448+
{
449+
using type = unsigned; // overflow is defined
441450
const int num_items = GENERATE(100, 100'000); // try to hit the small and full tile code paths
442451
cuda::counting_iterator<type> a{0};
443452
c2h::device_vector<type> b(num_items, thrust::no_init);
444453
c2h::gen(C2H_SEED(1), b);
445454

446-
c2h::device_vector<type> result(num_items, thrust::no_init);
447-
transform_many(cuda::std::make_tuple(a, b.begin()), result.begin(), num_items, cuda::std::plus<type>{});
455+
c2h::device_vector<type> result_a(num_items, thrust::no_init);
456+
c2h::device_vector<type> result_b(num_items, thrust::no_init);
457+
c2h::device_vector<type> result_c(num_items, thrust::no_init);
458+
transform_many(
459+
cuda::std::make_tuple(a, b.begin()),
460+
cuda::std::make_tuple(
461+
result_a.begin(), result_b.begin(), thrust::make_transform_output_iterator(result_c.begin(), cuda::std::negate{})),
462+
num_items,
463+
plus_mul_neg{});
448464

449465
// compute reference and verify
450466
c2h::host_vector<type> b_h = b;
451-
c2h::host_vector<type> reference_h(num_items);
452-
std::transform(a, a + num_items, b_h.begin(), reference_h.begin(), std::plus<type>{});
453-
REQUIRE(reference_h == result);
467+
c2h::host_vector<type> reference_a_h(num_items, thrust::no_init);
468+
std::transform(a, a + num_items, b_h.begin(), reference_a_h.begin(), cuda::std::plus<type>{});
469+
c2h::host_vector<type> reference_b_h(num_items, thrust::no_init);
470+
std::transform(a, a + num_items, b_h.begin(), reference_b_h.begin(), cuda::std::multiplies<type>{});
471+
CHECK(reference_a_h == result_a);
472+
CHECK(reference_b_h == result_b);
473+
CHECK(thrust::equal(a, a + num_items, result_c.begin()));
454474
}
455475

456476
struct plus_needs_stable_address

cub/test/catch2_test_device_transform_api.cu

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,37 @@
1010

1111
#include <c2h/catch2_test_helper.h>
1212

13+
// need a separate function because the ext. lambda needs to be enclosed by a function with external linkage on Windows
14+
void test_transform_many_many_api()
15+
{
16+
// example-begin transform-many-many
17+
auto input1 = thrust::device_vector<int>{0, -1, 2, -3, 4, -5};
18+
auto input2 = thrust::device_vector<double>{5.2, 3.1, -1.1, 3.0, 3.2, 0.0};
19+
auto op = [] __device__(int a, double b) -> cuda::std::tuple<double, bool> {
20+
const double product = a * b;
21+
return {product, product < 0};
22+
};
23+
24+
auto result1 = thrust::device_vector<double>(input1.size(), thrust::no_init);
25+
auto result2 = thrust::device_vector<bool>(input1.size(), thrust::no_init);
26+
cub::DeviceTransform::Transform(
27+
cuda::std::tuple{input1.begin(), input2.begin()},
28+
cuda::std::tuple{result1.begin(), result2.begin()},
29+
input1.size(),
30+
op);
31+
32+
const auto expected1 = thrust::host_vector<double>{0, -3.1, -2.2, -9, 12.8, -0};
33+
const auto expected2 = thrust::host_vector<bool>{false, true, true, true, false, false};
34+
// example-end transform-many-many
35+
CHECK(result1 == expected1);
36+
CHECK(result2 == expected2);
37+
}
38+
39+
C2H_TEST("DeviceTransform::Transform many->many API example", "[device][device_transform]")
40+
{
41+
test_transform_many_many_api();
42+
}
43+
1344
// need a separate function because the ext. lambda needs to be enclosed by a function with external linkage on Windows
1445
void test_transform_api()
1546
{
@@ -21,7 +52,7 @@ void test_transform_api()
2152
return (a + b) * c;
2253
};
2354

24-
auto result = thrust::device_vector<int>(input1.size());
55+
auto result = thrust::device_vector<int>(input1.size(), thrust::no_init);
2556
cub::DeviceTransform::Transform(
2657
cuda::std::tuple{input1.begin(), input2.begin(), input3}, result.begin(), input1.size(), op);
2758

@@ -74,7 +105,7 @@ void test_transform_stable_api()
74105
return a + input2_ptr[i];
75106
};
76107

77-
auto result = thrust::device_vector<int>(input1.size());
108+
auto result = thrust::device_vector<int>(input1.size(), thrust::no_init);
78109
cub::DeviceTransform::TransformStableArgumentAddresses(
79110
cuda::std::tuple{input1_ptr}, result.begin(), input1.size(), op);
80111

0 commit comments

Comments
 (0)