Skip to content

Commit 2f57562

Browse files
committed
Implement PyOpenCL kernel framework.
* Add helper functions for PyOpenCL kernel loading, memory allocation and event tracking. * Add kernels to match the current OpenMP and JAX ones. * Expand tests to include this infrastructure. This work might not be merged, but it has served as a useful test to see what work is involved in the develop and debug cycle when using PyOpenCL as the backend.
1 parent defcf34 commit 2f57562

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+4530
-324
lines changed

src/toast/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,7 @@ add_subdirectory(io)
169169
add_subdirectory(accelerator)
170170
add_subdirectory(tests)
171171
add_subdirectory(jax)
172+
add_subdirectory(opencl)
172173
add_subdirectory(ops)
173174
add_subdirectory(templates)
174175
add_subdirectory(scripts)

src/toast/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@
3838
* Values "0", "false", or "no" will disable runtime support for hybrid GPU pipelines.
3939
* Requires TOAST_GPU_OPENMP or TOAST_GPU_JAX to be enabled.
4040
41+
TOAST_OPENCL=<value>
42+
* Values "1", "true", or "yes" will enable runtime support for pyopencl.
43+
* Requires pyopencl to be available / importable.
44+
45+
TOAST_OPENCL_DEFAULT=<value>
46+
* Default OpenCL device type, where supported values are "CPU", "GPU",
47+
and "OCLGRIND".
48+
4149
OMP_NUM_THREADS=<integer>
4250
* Toast uses OpenMP threading in several places and the concurrency is set by the
4351
usual environment variable.

src/toast/_libtoast/ops_pixels_healpix.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -598,7 +598,8 @@ void pixels_healpix_nest_inner(
598598
int64_t n_samp,
599599
int64_t idet,
600600
uint8_t mask,
601-
bool use_flags
601+
bool use_flags,
602+
bool compute_submaps
602603
) {
603604
const double zaxis[3] = {0.0, 0.0, 1.0};
604605
int32_t p_indx = pixel_index[idet];
@@ -618,8 +619,10 @@ void pixels_healpix_nest_inner(
618619
if (use_flags && ((flags[isamp] & mask) != 0)) {
619620
pixels[poff] = -1;
620621
} else {
621-
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
622-
hsub[sub_map] = 1;
622+
if (compute_submaps) {
623+
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
624+
hsub[sub_map] = 1;
625+
}
623626
}
624627

625628
return;
@@ -639,7 +642,9 @@ void pixels_healpix_ring_inner(
639642
int64_t n_samp,
640643
int64_t idet,
641644
uint8_t mask,
642-
bool use_flags) {
645+
bool use_flags,
646+
bool compute_submaps
647+
) {
643648
const double zaxis[3] = {0.0, 0.0, 1.0};
644649
int32_t p_indx = pixel_index[idet];
645650
int32_t q_indx = quat_index[idet];
@@ -658,8 +663,10 @@ void pixels_healpix_ring_inner(
658663
if (use_flags && ((flags[isamp] & mask) != 0)) {
659664
pixels[poff] = -1;
660665
} else {
661-
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
662-
hsub[sub_map] = 1;
666+
if (compute_submaps) {
667+
sub_map = (int64_t)(pixels[poff] / n_pix_submap);
668+
hsub[sub_map] = 1;
669+
}
663670
}
664671

665672
return;
@@ -1163,6 +1170,7 @@ void init_ops_pixels_healpix(py::module & m) {
11631170
int64_t n_pix_submap,
11641171
int64_t nside,
11651172
bool nest,
1173+
bool compute_submaps,
11661174
bool use_accel
11671175
) {
11681176
auto & omgr = OmpManager::get();
@@ -1195,10 +1203,14 @@ void init_ops_pixels_healpix(py::module & m) {
11951203
);
11961204
int64_t n_view = temp_shape[0];
11971205

1206+
// Optionally compute the hit submaps
11981207
uint8_t * raw_hsub = extract_buffer <uint8_t> (
11991208
hit_submaps, "hit_submaps", 1, temp_shape, {-1}
12001209
);
12011210
int64_t n_submap = temp_shape[0];
1211+
if (! compute_submaps) {
1212+
raw_hsub = omgr.null_ptr <uint8_t> ();
1213+
}
12021214

12031215
// Optionally use flags
12041216
bool use_flags = true;
@@ -1225,6 +1237,7 @@ void init_ops_pixels_healpix(py::module & m) {
12251237
int64_t * dev_pixels = omgr.device_ptr(raw_pixels);
12261238
Interval * dev_intervals = omgr.device_ptr(raw_intervals);
12271239
uint8_t * dev_flags = omgr.device_ptr(raw_flags);
1240+
uint8_t * dev_hsub = omgr.device_ptr(raw_hsub);
12281241

12291242
// Make sure the lookup table exists on device
12301243
size_t utab_bytes = 0x100 * sizeof(int64_t);
@@ -1258,9 +1271,9 @@ void init_ops_pixels_healpix(py::module & m) {
12581271
n_det, \
12591272
n_samp, \
12601273
shared_flag_mask, \
1274+
compute_submaps, \
12611275
use_flags \
1262-
) \
1263-
map(tofrom : raw_hsub[0 : n_submap])
1276+
)
12641277
{
12651278
if (nest) {
12661279
# pragma omp target teams distribute parallel for collapse(3) \
@@ -1269,6 +1282,7 @@ void init_ops_pixels_healpix(py::module & m) {
12691282
dev_pixels, \
12701283
dev_quats, \
12711284
dev_flags, \
1285+
dev_hsub, \
12721286
dev_intervals, \
12731287
dev_utab \
12741288
)
@@ -1293,14 +1307,15 @@ void init_ops_pixels_healpix(py::module & m) {
12931307
raw_pixel_index,
12941308
dev_quats,
12951309
dev_flags,
1296-
raw_hsub,
1310+
dev_hsub,
12971311
dev_pixels,
12981312
n_pix_submap,
12991313
adjusted_isamp,
13001314
n_samp,
13011315
idet,
13021316
shared_flag_mask,
1303-
use_flags
1317+
use_flags,
1318+
compute_submaps
13041319
);
13051320
}
13061321
}
@@ -1312,6 +1327,7 @@ void init_ops_pixels_healpix(py::module & m) {
13121327
dev_pixels, \
13131328
dev_quats, \
13141329
dev_flags, \
1330+
dev_hsub, \
13151331
dev_intervals, \
13161332
dev_utab \
13171333
)
@@ -1335,14 +1351,15 @@ void init_ops_pixels_healpix(py::module & m) {
13351351
raw_pixel_index,
13361352
dev_quats,
13371353
dev_flags,
1338-
raw_hsub,
1354+
dev_hsub,
13391355
dev_pixels,
13401356
n_pix_submap,
13411357
adjusted_isamp,
13421358
n_samp,
13431359
idet,
13441360
shared_flag_mask,
1345-
use_flags
1361+
use_flags,
1362+
compute_submaps
13461363
);
13471364
}
13481365
}
@@ -1376,7 +1393,8 @@ void init_ops_pixels_healpix(py::module & m) {
13761393
n_samp,
13771394
idet,
13781395
shared_flag_mask,
1379-
use_flags
1396+
use_flags,
1397+
compute_submaps
13801398
);
13811399
}
13821400
}
@@ -1404,7 +1422,8 @@ void init_ops_pixels_healpix(py::module & m) {
14041422
n_samp,
14051423
idet,
14061424
shared_flag_mask,
1407-
use_flags
1425+
use_flags,
1426+
compute_submaps
14081427
);
14091428
}
14101429
}

src/toast/accelerator/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
accel_data_update_device,
1616
accel_data_update_host,
1717
accel_enabled,
18+
accel_wait,
1819
accel_get_device,
1920
use_accel_jax,
2021
use_accel_omp,
22+
use_accel_opencl,
2123
use_hybrid_pipelines,
2224
)
2325
from .kernel_registry import ImplementationType, kernel

0 commit comments

Comments
 (0)