Skip to content

Commit 8901444

Browse files
authored
Formatters for cuda::arch_id and cuda::compute_capability (#7335)
1 parent cd0538d commit 8901444

File tree

5 files changed

+186
-0
lines changed

5 files changed

+186
-0
lines changed

libcudacxx/include/cuda/__device/arch_id.h

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323

2424
#include <cuda/__device/compute_capability.h>
2525
#include <cuda/__fwd/devices.h>
26+
#include <cuda/std/__fwd/format.h>
2627
#include <cuda/std/__type_traits/always_false.h>
2728
#include <cuda/std/__utility/to_underlying.h>
2829
#include <cuda/std/array>
@@ -142,6 +143,40 @@ enum class arch_id : int
142143

143144
_CCCL_END_NAMESPACE_CUDA
144145

146+
#if __cpp_lib_format >= 201907L
147+
_CCCL_BEGIN_NAMESPACE_STD
148+
149+
template <class _CharT>
150+
struct formatter<::cuda::arch_id, _CharT> : private formatter<::cuda::compute_capability, _CharT>
151+
{
152+
template <class _ParseCtx>
153+
_CCCL_HOST_API constexpr auto parse(_ParseCtx& __ctx)
154+
{
155+
return __ctx.begin();
156+
}
157+
158+
template <class _FmtCtx>
159+
_CCCL_HOST_API auto format(const ::cuda::arch_id& __arch, _FmtCtx& __ctx) const
160+
{
161+
auto __it = __ctx.out();
162+
*__it++ = _CharT{'s'};
163+
*__it++ = _CharT{'m'};
164+
*__it++ = _CharT{'_'};
165+
__ctx.advance_to(__it);
166+
__it = formatter<::cuda::compute_capability, _CharT>::format(::cuda::compute_capability{__arch}, __ctx);
167+
if (::cuda::__is_specific_arch(__arch))
168+
{
169+
*__it++ = _CharT{'a'};
170+
}
171+
return __it;
172+
}
173+
};
174+
175+
_CCCL_END_NAMESPACE_STD
176+
#endif // __cpp_lib_format >= 201907L
177+
178+
// todo: specialize cuda::std::formatter for cuda::arch_id
179+
145180
#if _CCCL_CUDA_COMPILATION()
146181

147182
_CCCL_BEGIN_NAMESPACE_CUDA_DEVICE

libcudacxx/include/cuda/__device/compute_capability.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#endif // no system header
2323

2424
#include <cuda/__fwd/devices.h>
25+
#include <cuda/std/__fwd/format.h>
2526
#include <cuda/std/__utility/to_underlying.h>
2627

2728
#include <cuda/std/__cccl/prologue.h>
@@ -172,6 +173,30 @@ class compute_capability
172173

173174
_CCCL_END_NAMESPACE_CUDA
174175

176+
#if __cpp_lib_format >= 201907L
177+
_CCCL_BEGIN_NAMESPACE_STD
178+
179+
template <class _CharT>
180+
struct formatter<::cuda::compute_capability, _CharT> : private formatter<int, _CharT>
181+
{
182+
template <class _ParseCtx>
183+
_CCCL_HOST_API constexpr auto parse(_ParseCtx& __ctx)
184+
{
185+
return __ctx.begin();
186+
}
187+
188+
template <class _FmtCtx>
189+
_CCCL_HOST_API auto format(const ::cuda::compute_capability& __cc, _FmtCtx& __ctx) const
190+
{
191+
return formatter<int, _CharT>::format(__cc.get(), __ctx);
192+
}
193+
};
194+
195+
_CCCL_END_NAMESPACE_STD
196+
#endif // __cpp_lib_format >= 201907L
197+
198+
// todo: specialize cuda::std::formatter for cuda::compute_capability
199+
175200
#if _CCCL_CUDA_COMPILATION()
176201

177202
_CCCL_BEGIN_NAMESPACE_CUDA_DEVICE

libcudacxx/include/cuda/std/__fwd/format.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@
2424

2525
#include <cuda/std/__cccl/prologue.h>
2626

27+
#if __cpp_lib_format >= 201907L
28+
29+
_CCCL_BEGIN_NAMESPACE_STD
30+
31+
template <class, class>
32+
struct formatter;
33+
34+
_CCCL_END_NAMESPACE_STD
35+
36+
#endif // __cpp_lib_format >= 201907L
37+
2738
_CCCL_BEGIN_NAMESPACE_CUDA_STD
2839

2940
template <class _CharT>
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include <cuda/devices>
12+
13+
#if __cpp_lib_format >= 201907L
14+
# include <format>
15+
#endif // __cpp_lib_format >= 201907L
16+
17+
#include "literal.h"
18+
19+
#if __cpp_lib_format >= 201907L
20+
template <class C>
21+
void test()
22+
{
23+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_60) == TEST_STRLIT(C, "sm_60"));
24+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_61) == TEST_STRLIT(C, "sm_61"));
25+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_62) == TEST_STRLIT(C, "sm_62"));
26+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_70) == TEST_STRLIT(C, "sm_70"));
27+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_75) == TEST_STRLIT(C, "sm_75"));
28+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_80) == TEST_STRLIT(C, "sm_80"));
29+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_86) == TEST_STRLIT(C, "sm_86"));
30+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_87) == TEST_STRLIT(C, "sm_87"));
31+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_88) == TEST_STRLIT(C, "sm_88"));
32+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_89) == TEST_STRLIT(C, "sm_89"));
33+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_90) == TEST_STRLIT(C, "sm_90"));
34+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_100) == TEST_STRLIT(C, "sm_100"));
35+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_103) == TEST_STRLIT(C, "sm_103"));
36+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_110) == TEST_STRLIT(C, "sm_110"));
37+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_120) == TEST_STRLIT(C, "sm_120"));
38+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_121) == TEST_STRLIT(C, "sm_121"));
39+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_90a) == TEST_STRLIT(C, "sm_90a"));
40+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_100a) == TEST_STRLIT(C, "sm_100a"));
41+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_103a) == TEST_STRLIT(C, "sm_103a"));
42+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_110a) == TEST_STRLIT(C, "sm_110a"));
43+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_120a) == TEST_STRLIT(C, "sm_120a"));
44+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::arch_id::sm_121a) == TEST_STRLIT(C, "sm_121a"));
45+
}
46+
47+
void test()
48+
{
49+
test<char>();
50+
test<wchar_t>();
51+
}
52+
#endif // __cpp_lib_format >= 201907L
53+
54+
int main(int, char**)
55+
{
56+
#if __cpp_lib_format >= 201907L
57+
NV_IF_TARGET(NV_IS_HOST, (test();))
58+
#endif // __cpp_lib_format >= 201907L
59+
return 0;
60+
}
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of libcu++, the C++ Standard Library for your entire system,
4+
// under the Apache License v2.0 with LLVM Exceptions.
5+
// See https://llvm.org/LICENSE.txt for license information.
6+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
7+
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
8+
//
9+
//===----------------------------------------------------------------------===//
10+
11+
#include <cuda/devices>
12+
13+
#if __cpp_lib_format >= 201907L
14+
# include <format>
15+
#endif // __cpp_lib_format >= 201907L
16+
17+
#include "literal.h"
18+
19+
#if __cpp_lib_format >= 201907L
20+
template <class C>
21+
void test()
22+
{
23+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{0}) == TEST_STRLIT(C, "0"));
24+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{60}) == TEST_STRLIT(C, "60"));
25+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{61}) == TEST_STRLIT(C, "61"));
26+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{62}) == TEST_STRLIT(C, "62"));
27+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{70}) == TEST_STRLIT(C, "70"));
28+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{75}) == TEST_STRLIT(C, "75"));
29+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{80}) == TEST_STRLIT(C, "80"));
30+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{86}) == TEST_STRLIT(C, "86"));
31+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{87}) == TEST_STRLIT(C, "87"));
32+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{88}) == TEST_STRLIT(C, "88"));
33+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{89}) == TEST_STRLIT(C, "89"));
34+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{90}) == TEST_STRLIT(C, "90"));
35+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{100}) == TEST_STRLIT(C, "100"));
36+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{103}) == TEST_STRLIT(C, "103"));
37+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{110}) == TEST_STRLIT(C, "110"));
38+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{120}) == TEST_STRLIT(C, "120"));
39+
assert(std::format(TEST_STRLIT(C, "{}"), cuda::compute_capability{121}) == TEST_STRLIT(C, "121"));
40+
}
41+
42+
void test()
43+
{
44+
test<char>();
45+
test<wchar_t>();
46+
}
47+
#endif // __cpp_lib_format >= 201907L
48+
49+
int main(int, char**)
50+
{
51+
#if __cpp_lib_format >= 201907L
52+
NV_IF_TARGET(NV_IS_HOST, (test();))
53+
#endif // __cpp_lib_format >= 201907L
54+
return 0;
55+
}

0 commit comments

Comments
 (0)