diff --git a/CMakeLists.txt b/CMakeLists.txt
index fd5577c95..508a70783 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -183,45 +183,6 @@ find_package(AATM)
 
 find_package(SuiteSparse)
 
-# We require libFLAC >= 1.4.0 for 32bit integer support
-set(USE_FLAC FALSE)
-if(DISABLE_FLAC)
-    message(STATUS "FLAC support disabled")
-else()
-    find_package(FLAC)
-    if(FLAC_FOUND)
-        if(DEFINED FLAC_VERSION)
-            if(FLAC_VERSION STREQUAL "")
-                message(STATUS "Cannot determine FLAC version- assuming it is >= 1.4.0")
-                set(USE_FLAC TRUE)
-            else()
-                string(REGEX REPLACE "^([0-9]+)\\.[0-9]+\\..*" "\\1"
-                    FLAC_MAJ_VERSION "${FLAC_VERSION}")
-                string(REGEX REPLACE "^[0-9]+\\.([0-9]+)\\..*" "\\1"
-                    FLAC_MIN_VERSION "${FLAC_VERSION}")
-                if(FLAC_MAJ_VERSION GREATER 1)
-                    # Future proofing
-                    message(STATUS
-                        "Found FLAC version ${FLAC_VERSION}, enabling support")
-                    set(USE_FLAC TRUE)
-                else()
-                    if(FLAC_MIN_VERSION GREATER_EQUAL 4)
-                        message(STATUS
-                            "Found FLAC version ${FLAC_VERSION}, enabling support")
-                        set(USE_FLAC TRUE)
-                    endif()
-                endif()
-            endif()
-        else()
-            message(STATUS "Cannot determine FLAC version- assuming it is >= 1.4.0")
-            set(USE_FLAC TRUE)
-        endif()
-    endif()
-    if(NOT USE_FLAC)
-        message(STATUS "Did not find FLAC >= 1.4.0")
-    endif()
-endif()
-
 find_package(Python3 3.9 COMPONENTS Interpreter Development.Module REQUIRED)
 
 # Internal products
diff --git a/setup.py b/setup.py
index 108daa4b0..1c5929738 100644
--- a/setup.py
+++ b/setup.py
@@ -301,6 +301,7 @@ def readme():
         "psutil",
         "h5py",
         "pshmem>=1.3.0",
+        "flacarray>=0.3.4",
         "ruamel.yaml",
         "astropy",
         "healpy",
diff --git a/src/toast/CMakeLists.txt b/src/toast/CMakeLists.txt
index b434667d6..e055866b9 100644
--- a/src/toast/CMakeLists.txt
+++ b/src/toast/CMakeLists.txt
@@ -49,7 +49,6 @@ pybind11_add_module(_libtoast MODULE
     _libtoast/template_offset.cpp
     _libtoast/accelerator.cpp
     _libtoast/qarray_core.cpp
-    _libtoast/io_compression_flac.cpp
     _libtoast/ops_pointing_detector.cpp
     _libtoast/ops_stokes_weights.cpp
     _libtoast/ops_pixels_healpix.cpp
@@ -86,13 +85,6 @@ if(CHOLMOD_FOUND)
     target_include_directories(_libtoast PUBLIC "${CHOLMOD_INCLUDE_DIR}")
 endif(CHOLMOD_FOUND)
 
-if(USE_FLAC)
-    target_compile_definitions(_libtoast PRIVATE HAVE_FLAC=1)
-    target_compile_options(_libtoast PRIVATE "${FLAC_DEFINITIONS}")
-    target_include_directories(_libtoast PUBLIC "${FLAC_INCLUDE_DIRS}")
-    target_link_libraries(_libtoast PRIVATE "${FLAC_LIBRARIES}")
-endif(USE_FLAC)
-
 if(CUDAToolkit_FOUND AND NOT CUDA_DISABLED)
     target_compile_definitions(_libtoast PRIVATE HAVE_CUDALIBS=1)
     target_include_directories(_libtoast PRIVATE "${CUDAToolkit_INCLUDE_DIRS}")
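The build-system changes above drop the direct libFLAC dependency in favor of the `flacarray` Python package added to `setup.py`. As a quick orientation, here is a sketch of the compress/decompress round trip that package provides, with argument names and return values inferred purely from the call sites later in this diff (treat the exact signatures as illustrative, not authoritative):

```python
import numpy as np
from flacarray.compress import array_compress
from flacarray.decompress import array_decompress

rng = np.random.default_rng(1234)
data = rng.normal(size=(4, 1024))  # (n_detector, n_sample) float64 streams

# One compressed stream per row, plus per-stream quantization offsets
# and gains used for the float -> int32 conversion.
comp_bytes, starts, nbytes, offsets, gains = array_compress(
    data, level=5, quanta=None, precision=None
)

restored = array_decompress(
    comp_bytes,
    data.shape[1],  # samples per stream
    starts,
    nbytes,
    stream_offsets=offsets,
    stream_gains=gains,
)
# The round trip is lossy only up to one quantization step per sample.
```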
diff --git a/src/toast/_libtoast/io_compression_flac.cpp b/src/toast/_libtoast/io_compression_flac.cpp
deleted file mode 100644
index 4b521e404..000000000
--- a/src/toast/_libtoast/io_compression_flac.cpp
+++ /dev/null
@@ -1,576 +0,0 @@
-
-// Copyright (c) 2015-2023 by the parties listed in the AUTHORS file.
-// All rights reserved. Use of this source code is governed by
-// a BSD-style license that can be found in the LICENSE file.
-
-#include <module.hpp>
-
-
-#ifdef HAVE_FLAC
-# include <FLAC/stream_encoder.h>
-# include <FLAC/stream_decoder.h>
-
-
-typedef struct {
-    toast::AlignedU8 * compressed;
-} enc_write_callback_data;
-
-
-FLAC__StreamEncoderWriteStatus enc_write_callback(
-    const FLAC__StreamEncoder * encoder,
-    const FLAC__byte buffer[],
-    size_t bytes,
-    uint32_t samples,
-    uint32_t current_frame,
-    void * client_data
-) {
-    enc_write_callback_data * data = (enc_write_callback_data *)client_data;
-    data->compressed->insert(
-        data->compressed->end(),
-        buffer,
-        buffer + bytes
-    );
-    return FLAC__STREAM_ENCODER_WRITE_STATUS_OK;
-}
-
-void encode_flac(
-    int32_t * const data,
-    size_t n_data,
-    toast::AlignedU8 & bytes,
-    toast::AlignedI64 & offsets,
-    uint32_t level,
-    int64_t stride
-) {
-    // If stride is specified, check consistency.
-    int64_t n_sub;
-    if (stride > 0) {
-        if (n_data % stride != 0) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Stride " << stride << " does not evenly divide into " << n_data;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-        n_sub = (int64_t)(n_data / stride);
-    } else {
-        n_sub = 1;
-        stride = n_data;
-    }
-    offsets.resize(n_sub);
-    bytes.clear();
-
-    enc_write_callback_data write_callback_data;
-    write_callback_data.compressed = &bytes;
-
-    bool success;
-    FLAC__StreamEncoderInitStatus status;
-    FLAC__StreamEncoder * encoder;
-
-    for (int64_t sub = 0; sub < n_sub; ++sub) {
-        offsets[sub] = bytes.size();
-
-        // std::cerr << "Encoding " << stride << " samples at byte offset " <<
-        // offsets[sub] << " starting at data element " << (sub * stride) << std::endl;
-
-        encoder = FLAC__stream_encoder_new();
-
-        success = FLAC__stream_encoder_set_compression_level(encoder, level);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed to set compression level to " << level;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        success = FLAC__stream_encoder_set_blocksize(encoder, 0);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed to set encoder blocksize to " << 0;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        success = FLAC__stream_encoder_set_channels(encoder, 1);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed to set encoder channels to " << 1;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        success = FLAC__stream_encoder_set_bits_per_sample(encoder, 32);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed to set encoder bits per sample to " << 32;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        status = FLAC__stream_encoder_init_stream(
-            encoder,
-            enc_write_callback,
-            NULL,
-            NULL,
-            NULL,
-            (void *)&write_callback_data
-        );
-        if (status != FLAC__STREAM_ENCODER_INIT_STATUS_OK) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed to initialize stream encoder, status = " << status;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        success = FLAC__stream_encoder_process_interleaved(
-            encoder,
-            &(data[sub * stride]),
-            stride
-        );
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed on encoder_process_interleaved for chunk " << sub;
-            o << ", elements " << sub * stride << " - " << (sub + 1) * stride;
-            o << ", at byte offset " << offsets[sub];
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-        success = FLAC__stream_encoder_finish(encoder);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed on encoder_finish";
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        FLAC__stream_encoder_delete(encoder);
-    }
-
-    return;
-}
-
-typedef struct {
-    uint8_t const * input;
-    size_t in_nbytes;
-    size_t in_offset;
-    size_t in_end;
-    toast::AlignedI32 * output;
-} dec_callback_data;
-
-
-FLAC__StreamDecoderReadStatus dec_read_callback(
-    const FLAC__StreamDecoder * decoder,
-    FLAC__byte buffer[],
-    size_t * bytes,
-    void * client_data
-) {
-    dec_callback_data * callback_data = (dec_callback_data *)client_data;
-    uint8_t const * input = callback_data->input;
-    size_t offset = callback_data->in_offset;
-    size_t remaining = callback_data->in_end - offset;
-
-    // std::cerr << "Decode read: " << remaining << " bytes remaining" << std::endl;
-
-    // The bytes requested by the decoder
-    size_t n_buffer = (*bytes);
-
-    if (remaining == 0) {
-        // No data left
-        (*bytes) = 0;
-
-        // std::cerr << "Decode read: 0 bytes remaining, END_OF_STREAM" << std::endl;
-        return FLAC__STREAM_DECODER_READ_STATUS_END_OF_STREAM;
-    } else {
-        // We have some data left
-        if (n_buffer == 0) {
-            // ... but there is no place to put it!
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Stream decoder gave us zero length buffer, but we have ";
-            o << remaining << " bytes left";
-            log.error(o.str().c_str());
-
-            // std::cerr << "Decode read: 0 bytes in buffer, ABORT" << std::endl;
-            return FLAC__STREAM_DECODER_READ_STATUS_ABORT;
-        } else {
-            if (remaining > n_buffer) {
-                // Only copy in what there is space for
-                // std::cerr << "Decode read: putting " << n_buffer << " bytes at
-                // offset " << offset << " into buffer, CONTINUE" << std::endl;
-                for (size_t i = 0; i < n_buffer; ++i) {
-                    buffer[i] = input[offset + i];
-                }
-                callback_data->in_offset += n_buffer;
-                return FLAC__STREAM_DECODER_READ_STATUS_CONTINUE;
-            } else {
-                // Copy in the rest of the buffer and reset the number of bytes
-                // std::cerr << "Decode read: putting remainder of " << remaining << "
-                // bytes at offset " << offset << " into buffer, CONTINUE" << std::endl;
-                for (size_t i = 0; i < remaining; ++i) {
-                    buffer[i] = input[offset + i];
-                }
-                callback_data->in_offset += remaining;
-                (*bytes) = remaining;
-                return FLAC__STREAM_DECODER_READ_STATUS_CONTINUE;
-            }
-        }
-    }
-
-    // Should never get here...
-    return FLAC__STREAM_DECODER_READ_STATUS_ABORT;
-}
-
-FLAC__StreamDecoderWriteStatus dec_write_callback(
-    const FLAC__StreamDecoder * decoder,
-    const FLAC__Frame * frame,
-    const FLAC__int32 * const buffer[],
-    void * client_data
-) {
-    dec_callback_data * data = (dec_callback_data *)client_data;
-    size_t offset = data->output->size();
-    uint32_t blocksize = frame->header.blocksize;
-    data->output->resize(offset + blocksize);
-    for (size_t i = 0; i < blocksize; ++i) {
-        (*data->output)[offset + i] = buffer[0][i];
-    }
-    return FLAC__STREAM_DECODER_WRITE_STATUS_CONTINUE;
-}
-
-void dec_err_callback(
-    const FLAC__StreamDecoder * decoder,
-    FLAC__StreamDecoderErrorStatus status,
-    void * client_data
-) {
-    dec_callback_data * data = (dec_callback_data *)client_data;
-
-    auto log = toast::Logger::get();
-    std::ostringstream o;
-    o << "Stream decode error (" << status << ") at input byte range ";
-    o << data->in_offset << " - " << data->in_end << ", output size = ";
-    o << data->output->size();
-    log.error(o.str().c_str());
-    throw std::runtime_error(o.str().c_str());
-    return;
-}
-
-void decode_flac(
-    uint8_t * const bytes,
-    size_t n_bytes,
-    int64_t * const offsets,
-    size_t n_offset,
-    toast::AlignedI32 & data
-) {
-    dec_callback_data callback_data;
-    callback_data.input = bytes;
-    callback_data.in_nbytes = n_bytes;
-    callback_data.output = &data;
-
-    FLAC__StreamDecoder * decoder;
-    bool success;
-    FLAC__StreamDecoderInitStatus status;
-
-    size_t n_sub = n_offset;
-
-    for (size_t sub = 0; sub < n_sub; ++sub) {
-        callback_data.in_offset = offsets[sub];
-        if (sub == n_sub - 1) {
-            callback_data.in_end = n_bytes;
-        } else {
-            callback_data.in_end = offsets[sub + 1];
-        }
-
-        // std::cerr << "Decoding chunk " << sub << " at byte offset " <<
-        // callback_data.in_offset << " with " << (callback_data.in_end -
-        // callback_data.in_offset) << " bytes" << std::endl;
-
-        decoder = FLAC__stream_decoder_new();
-
-        status = FLAC__stream_decoder_init_stream(
-            decoder,
-            dec_read_callback,
-            NULL,
-            NULL,
-            NULL,
-            NULL,
-            dec_write_callback,
-            NULL,
-            dec_err_callback,
-            (void *)&callback_data
-        );
-        if (status != FLAC__STREAM_DECODER_INIT_STATUS_OK) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed to initialize decoder, status = " << status;
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        success = FLAC__stream_decoder_process_until_end_of_stream(decoder);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed on decoder_process_until_end_of_stream for chunk " << sub;
-            o << ", byte range " << callback_data.in_offset << " - ";
-            o << callback_data.in_end;
-            o << ", output size = " << callback_data.output->size();
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        success = FLAC__stream_decoder_finish(decoder);
-        if (!success) {
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "Failed on decoder_finish";
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-        }
-
-        FLAC__stream_decoder_delete(decoder);
-    }
-
-    return;
-}
-
-#endif // ifdef HAVE_FLAC
-
-
-void init_io_compression_flac(py::module & m) {
-    // FLAC compression
-
-    m.def(
-        "have_flac_support", []() {
-            #ifdef HAVE_FLAC
-            return true;
-
-            #else // ifdef HAVE_FLAC
-            return false;
-
-            #endif // ifdef HAVE_FLAC
-        }, R"(
-        Return True if TOAST is compiled with FLAC support.
-    )");
-
-    m.def(
-        "compress_flac_2D", [](
-            py::buffer data,
-            uint32_t level
-        ) {
-            // This is used to return the actual shape of each buffer
-            std::vector <ssize_t> temp_shape(3);
-
-            int32_t * raw_data = extract_buffer <int32_t> (
-                data, "int32 data", 2, temp_shape, {-1, -1}
-            );
-            int64_t n_chunk = temp_shape[0];
-            int64_t n_chunk_elem = temp_shape[1];
-
-            toast::AlignedU8 bytes;
-            toast::AlignedI64 offsets;
-
-            #ifdef HAVE_FLAC
-            encode_flac(
-                raw_data,
-                n_chunk * n_chunk_elem,
-                bytes,
-                offsets,
-                level,
-                n_chunk_elem
-            );
-            #else // ifdef HAVE_FLAC
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "TOAST was not built with libFLAC support";
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-            #endif // ifdef HAVE_FLAC
-
-            // std::cout << "compress_flac_2D returning buffer @ " <<
-            // (int64_t)bytes.data() << std::endl;
-
-            return py::make_tuple(py::cast(bytes), py::cast(offsets));
-        }, py::arg("data"), py::arg(
-            "level"),
-        R"(
-        Compress 2D 32bit integer data with FLAC.
-
-        Each row of the input is compressed separately, and the byte offset
-        into the output stream is returned.
-
-        Args:
-            data (array, int32): The 2D array of integer data.
-            level (uint32): The compression level (0-8).
-
-        Returns:
-            (tuple): The (byte array, offsets).
-
-    )");
-
-    m.def(
-        "decompress_flac_2D", [](
-            py::buffer data,
-            py::buffer offsets
-        ) {
-            // This is used to return the actual shape of each buffer
-            std::vector <ssize_t> temp_shape(3);
-
-            uint8_t * raw_data = extract_buffer <uint8_t> (
-                data, "FLAC bytes", 1, temp_shape, {-1}
-            );
-            int64_t n_bytes = temp_shape[0];
-
-            int64_t * raw_offsets = extract_buffer <int64_t> (
-                offsets, "FLAC offsets", 1, temp_shape, {-1}
-            );
-            int64_t n_offset = temp_shape[0];
-
-            toast::AlignedI32 output;
-
-            #ifdef HAVE_FLAC
-            decode_flac(
-                raw_data,
-                n_bytes,
-                raw_offsets,
-                n_offset,
-                output
-            );
-            #else // ifdef HAVE_FLAC
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "TOAST was not built with libFLAC support";
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-            #endif // ifdef HAVE_FLAC
-
-            // std::cout << "decompress_flac_2D returning buffer @ " <<
-            // (int64_t)output.data() << std::endl;
-
-            return py::cast(output);
-        }, py::arg("data"), py::arg(
-            "offsets"),
-        R"(
-        Decompress FLAC bytes into 2D 32bit integer data.
-
-        The array of bytes is decompressed and returned.
-
-        Args:
-            data (array, uint8): The 1D array of bytes.
-            offsets (array, int64): The array of offsets into the byte array.
-
-        Returns:
-            (array): The array of 32bit integers.
-
-    )");
-
-    m.def(
-        "compress_flac", [](
-            py::buffer data,
-            uint32_t level
-        ) {
-            // This is used to return the actual shape of each buffer
-            std::vector <ssize_t> temp_shape(3);
-
-            int32_t * raw_data = extract_buffer <int32_t> (
-                data, "int32 data", 1, temp_shape, {-1}
-            );
-            int64_t n = temp_shape[0];
-
-            toast::AlignedU8 bytes;
-            toast::AlignedI64 offsets;
-
-            #ifdef HAVE_FLAC
-            encode_flac(
-                raw_data,
-                n,
-                bytes,
-                offsets,
-                level,
-                0
-            );
-            #else // ifdef HAVE_FLAC
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "TOAST was not built with libFLAC support";
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-            #endif // ifdef HAVE_FLAC
-
-            // std::cout << "compress_flac returning buffer @ " << (int64_t)bytes.data()
-            // << std::endl;
-
-            return py::cast(bytes);
-        }, py::arg("data"), py::arg(
-            "level"),
-        R"(
-        Compress 1D 32bit integer data with FLAC.
-
-        The 1D array is compressed and the byte array is returned.
-
-        Args:
-            data (array, int32): The 1D array of integer data.
-            level (uint32): The compression level (0-8).
-
-        Returns:
-            (array): The byte array.
-
-    )");
-
-    m.def(
-        "decompress_flac", [](
-            py::buffer data
-        ) {
-            // This is used to return the actual shape of each buffer
-            std::vector <ssize_t> temp_shape(3);
-
-            uint8_t * raw_data = extract_buffer <uint8_t> (
-                data, "FLAC bytes", 1, temp_shape, {-1}
-            );
-            int64_t n_bytes = temp_shape[0];
-
-            toast::AlignedI32 output;
-
-            int64_t offset = 0;
-
-            #ifdef HAVE_FLAC
-            decode_flac(
-                raw_data,
-                n_bytes,
-                &offset,
-                1,
-                output
-            );
-            #else // ifdef HAVE_FLAC
-            auto log = toast::Logger::get();
-            std::ostringstream o;
-            o << "TOAST was not built with libFLAC support";
-            log.error(o.str().c_str());
-            throw std::runtime_error(o.str().c_str());
-            #endif // ifdef HAVE_FLAC
-
-            // std::cout << "decompress_flac returning buffer @ " <<
-            // (int64_t)output.data() << std::endl;
-
-            return py::cast(output);
-        }, py::arg(
-            "data"),
-        R"(
-        Decompress FLAC bytes into 1D 32bit integer data.
-
-        The array of bytes is decompressed and returned.
-
-        Args:
-            data (array, uint8): The 1D array of bytes.
-
-        Returns:
-            (array): The array of 32bit integers.
-
-    )");
-
-    return;
-}
diff --git a/src/toast/_libtoast/module.cpp b/src/toast/_libtoast/module.cpp
index 8fb08d53e..eb7cdc81b 100644
--- a/src/toast/_libtoast/module.cpp
+++ b/src/toast/_libtoast/module.cpp
@@ -49,7 +49,6 @@ PYBIND11_MODULE(_libtoast, m) {
     init_ops_mapmaker_utils(m);
     init_ops_noise_weight(m);
     init_ops_scan_map(m);
-    init_io_compression_flac(m);
 
     // Internal unit test runner
     m.def(
diff --git a/src/toast/_libtoast/module.hpp b/src/toast/_libtoast/module.hpp
index f7f8506ed..ad9fa69ce 100644
--- a/src/toast/_libtoast/module.hpp
+++ b/src/toast/_libtoast/module.hpp
@@ -422,6 +422,5 @@ void init_ops_pixels_healpix(py::module & m);
 void init_ops_mapmaker_utils(py::module & m);
 void init_ops_noise_weight(py::module & m);
 void init_ops_scan_map(py::module & m);
-void init_io_compression_flac(py::module & m);
 
 #endif // ifndef LIBTOAST_HPP
diff --git a/src/toast/io/CMakeLists.txt b/src/toast/io/CMakeLists.txt
index c714a5373..454c8345d 100644
--- a/src/toast/io/CMakeLists.txt
+++ b/src/toast/io/CMakeLists.txt
@@ -4,10 +4,11 @@
 install(FILES
     __init__.py
     observation_hdf_save.py
+    observation_hdf_save_v1.py
     observation_hdf_load.py
     observation_hdf_load_v0.py
+    observation_hdf_load_v1.py
     hdf_utils.py
-    compression.py
-    compression_flac.py
+    deprecated_compression.py
     DESTINATION ${PYTHON_SITE}/toast/io
 )
diff --git a/src/toast/io/__init__.py b/src/toast/io/__init__.py
index 8ccec08a0..0d202a35c 100644
--- a/src/toast/io/__init__.py
+++ b/src/toast/io/__init__.py
@@ -4,7 +4,6 @@
 
 # Namespace imports
 
-from .compression import compress_detdata, decompress_detdata
 from .hdf_utils import H5File, have_hdf5_parallel, hdf5_config, hdf5_open
 from .observation_hdf_load import (
     load_hdf5,
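With the compiled `have_flac_support()` binding deleted above, runtime feature detection reduces to checking that the `flacarray` package is importable. A minimal sketch, assuming callers only need a boolean (the helper name here is hypothetical, not part of this diff):

```python
def have_flac_support():
    """Return True if the flacarray package is available."""
    try:
        import flacarray  # noqa: F401
        return True
    except ImportError:
        return False
```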
diff --git a/src/toast/io/compression_flac.py b/src/toast/io/compression_flac.py
deleted file mode 100644
index d54a0a28c..000000000
--- a/src/toast/io/compression_flac.py
+++ /dev/null
@@ -1,294 +0,0 @@
-# Copyright (c) 2023-2023 by the parties listed in the AUTHORS file.
-# All rights reserved. Use of this source code is governed by
-# a BSD-style license that can be found in the LICENSE file.
-
-import time
-
-import numpy as np
-
-from .._libtoast import (
-    compress_flac,
-    compress_flac_2D,
-    decompress_flac,
-    decompress_flac_2D,
-    have_flac_support,
-)
-from ..utils import AlignedU8, Logger, dtype_to_aligned
-
-
-def float2int(data, quanta=None, precision=None):
-    """Convert floating point data to integers.
-
-    This function subtracts the mean and rescales data before rounding to 32bit
-    integer values.
-
-    Args:
-        data (array): The floating point data.
-        quanta (float): The floating point quantity corresponding to one integer
-            resolution amount in the output. If `None`, quanta will be
-            based on the full dynamic range of the data.
-        precision (int): Number of significant digits to preserve. If
-            provided, `quanta` will be estimated accordingly.
-
-    Returns:
-        (tuple): The (integer data, offset, gain)
-
-    """
-    if np.any(np.isnan(data)):
-        raise RuntimeError("Cannot convert data with NaNs to integers")
-    dmin = np.amin(data)
-    dmax = np.amax(data)
-    offset = 0.5 * (dmin + dmax)
-    amp = 1.01 * max(np.abs(dmin - offset), np.abs(dmax - offset))
-    # Use the full bit range of int32 FLAC. Actually we lose one bit
-    # Due to internal FLAC implementation.
-    max_flac = np.iinfo(np.int32).max // 2
-    min_quanta = amp / max_flac
-    if precision is not None:
-        rms = np.std(data)
-        quanta = rms / 10**precision
-        if quanta < min_quanta:
-            msg = f"Set precision (={precision}) cannot be supported with 32bit FLAC: "
-            msg += f"{quanta} < {min_quanta}"
-            raise RuntimeError(msg)
-    if quanta is None:
-        quanta = min_quanta
-    if quanta == 0:
-        # This can happen if fed a vector of all zeros
-        quanta = 1.0
-    gain = 1.0 / quanta
-
-    return (
-        np.array(np.around(gain * (data - offset)), dtype=np.int32),
-        offset,
-        gain,
-    )
-
-
-def int64to32(data):
-    """Convert 64bit integers to 32bit.
-
-    This finds the 64bit integer mean and subtracts it. It then checks that the value
-    will fit in 32bit integers. If you want to treat the integer values as floating
-    point data, use float2int instead.
-
-    Args:
-        data (array): The 64bit integer data.
-
-    Returns:
-        (tuple): The (integer data, offset)
-
-    """
-    if data.dtype != np.dtype(np.int64):
-        raise ValueError("Only int64 data is supported")
-
-    dmin = np.amin(data)
-    dmax = np.amax(data)
-    offset = int(round(0.5 * (dmin + dmax)))
-
-    temp = np.array(data - offset, dtype=np.int64)
-
-    # FLAC uses an extra bit...
-    flac_max = np.iinfo(np.int32).max // 2
-
-    bad = np.logical_or(temp > flac_max, temp < -flac_max)
-    n_bad = np.count_nonzero(bad)
-    if n_bad > 0:
-        msg = f"64bit integers minus {offset} has {n_bad} values outside 32bit range"
-        raise RuntimeError(msg)
-
-    return (
-        temp.astype(np.int32),
-        offset,
-    )
-
-
-def int2float(idata, offset, gain):
-    """Restore floating point data from integers.
-
-    The gain and offset are applied and the resulting float32 data is returned.
-
-    Args:
-        idata (array): The 32bit integer data.
-        offset (float): The offset used in the original conversion.
-        gain (float): The gain used in the original conversion.
-
-    Returns:
-        (array): The restored float32 data.
-
-    """
-    if len(idata.shape) > 1:
-        raise ValueError("Only works with flat packed arrays")
-    coeff = 1.0 / gain
-    return np.array((idata * coeff) + offset, dtype=np.float32)
-
-
-def compress_detdata_flac(detdata, level=5, quanta=None, precision=None):
-    """Compress a 2D DetectorData array into FLAC bytes.
-
-    The input data is converted to 32bit integers. The "quanta" value is used
-    for floating point data conversion and represents the floating point increment
-    for a single integer value. If quanta is None, each detector timestream is
-    scaled independently based on its data range. If quanta is a scalar, all
-    detectors are scaled with the same value. If quanta is an array, it specifies
-    the scaling independently for each detector.
-
-    Alternatively, if "precision" is provided, each data vector is scaled to retain
-    the prescribed number of significant digits.
-
-    The following rules specify the data conversion that is performed depending on
-    the input type:
-
-    int32: No conversion.
-
-    int64: Subtract the integer closest to the mean, then truncate to lower
-        32 bits, and check that the higher bits were zero.
-
-    float32: Subtract the mean and scale data based on the quanta value (see
-        above). Then round to nearest 32bit integer.
-
-    float64: Subtract the mean and scale data based on the quanta value (see
-        above). Then round to nearest 32bit integer.
-
-    After conversion to 32bit integers, each detector's data is separately compressed
-    into a sequence of FLAC bytes, which is appended to the total. The offset in
-    bytes for each detector is recorded.
-
-    Args:
-        detdata (DetectorData): The input detector data.
-        level (int): Compression level
-        quanta (array): For floating point data, the increment of each integer.
-        precision (int): Number of significant digits to retain in float-to-int
-            conversion. Alternative to `quanta`.
-
-    Returns:
-        (tuple): The (compressed bytes, byte ranges,
-            [detector value offset, detector value gain])
-
-    """
-    if not have_flac_support():
-        raise RuntimeError("TOAST was not compiled with libFLAC support")
-
-    if quanta is not None:
-        if precision is not None:
-            raise RuntimeError("Cannot set both quanta and precision")
-        try:
-            nq = len(quanta)
-            # This is a sequence
-            if nq != len(detdata.detectors):
-                raise ValueError(
-                    "If not a scalar, quanta must have a value for each detector"
-                )
-            dquanta = quanta
-        except TypeError:
-            # This is a scalar, applied to all detectors
-            dquanta = quanta * np.ones(len(detdata.detectors), dtype=np.float64)
-    else:
-        dquanta = [None for x in detdata.detectors]
-
-    comp_bytes = AlignedU8()
-
-    if detdata.dtype == np.dtype(np.int32):
-        pass
-    elif detdata.dtype == np.dtype(np.int64):
-        data_ioffsets = np.zeros(len(detdata.detectors), dtype=np.int64)
-    elif detdata.dtype == np.dtype(np.float32) or detdata.dtype == np.dtype(np.float64):
-        data_offsets = np.zeros(len(detdata.detectors), dtype=np.float64)
-        data_gains = np.ones(len(detdata.detectors), dtype=np.float64)
-    else:
-        raise ValueError(f"Unsupported data type '{detdata.dtype}'")
-
-    start_bytes = list()
-    for idet, det in enumerate(detdata.detectors):
-        cur = comp_bytes.size()
-        start_bytes.append(cur)
-
-        if detdata.dtype == np.dtype(np.int32):
-            intdata = detdata[idet].reshape(-1)
-        elif detdata.dtype == np.dtype(np.int64):
-            intdata, ioff = int64to32(detdata[idet, :].reshape(-1))
-            data_ioffsets[idet] = ioff
-        else:
-            intdata, foff, fgain = float2int(
-                detdata[idet, :].reshape(-1),
-                quanta=dquanta[idet],
-                precision=precision,
-            )
-            data_offsets[idet] = foff
-            data_gains[idet] = fgain
-
-        dbytes = compress_flac(intdata, level)
-        ext = len(dbytes)
-        comp_bytes.resize(cur + ext)
-        comp_bytes[cur : cur + ext] = dbytes
-
-    comp_ranges = list()
-    for idet in range(len(detdata.detectors)):
-        if idet == len(detdata.detectors) - 1:
-            comp_ranges.append((start_bytes[idet], comp_bytes.size()))
-        else:
-            comp_ranges.append((start_bytes[idet], start_bytes[idet + 1]))
-
-    if detdata.dtype == np.dtype(np.int32):
-        return (comp_bytes.array(), comp_ranges)
-    elif detdata.dtype == np.dtype(np.int64):
-        return (comp_bytes.array(), comp_ranges, data_ioffsets)
-    else:
-        return (comp_bytes.array(), comp_ranges, data_offsets, data_gains)
-
-
-def decompress_detdata_flac(
-    detdata, flacbytes, byte_ranges, det_offsets=None, det_gains=None
-):
-    """Decompress FLAC bytes into a DetectorData array.
-
-    Given an existing DetectorData object, decompress individual detector data from
-    the FLAC byte stream given the starting byte for each detector and optionally
-    the offset and gain factor to apply to convert the native 32bit integers into
-    the output type.
-
-    Args:
-        detdata (DetectorData): The object to fill with decompressed data.
-        flacbytes (array): Compressed FLAC bytes
-        byte_ranges (list): The byte range for each detector.
-        det_offsets (array): The offset to apply to each detector during type
-            conversion.
-        det_gains (array): The scale factor to apply to each detector during
-            type conversion.
-
-    Returns:
-        None
-
-    """
-    if not have_flac_support():
-        raise RuntimeError("TOAST was not compiled with libFLAC support")
-
-    # Since we may have discontiguous slices into the bytestream, we decompress
-    # all data types one detector at a time.
-
-    for idet, det in enumerate(detdata.detectors):
-        slc = slice(byte_ranges[idet][0], byte_ranges[idet][1], 1)
-        idata = decompress_flac(flacbytes[slc])
-        if detdata.dtype == np.dtype(np.int32):
-            # Just copy it into place
-            detdata[idet] = idata.array().reshape(detdata.detector_shape)
-        elif detdata.dtype == np.dtype(np.int64):
-            detdata[idet] = idata.array().reshape(detdata.detector_shape)
-            if det_offsets is not None:
-                detdata[idet] += det_offsets[idet]
-        elif detdata.dtype == np.dtype(np.float32) or detdata.dtype == np.dtype(
-            np.float64
-        ):
-            if det_offsets is None:
-                doff = 0.0
-            else:
-                doff = det_offsets[idet]
-            if det_gains is None:
-                dgain = 1.0
-            else:
-                dgain = det_gains[idet]
-            detdata[idet] = int2float(idata.array(), doff, dgain).reshape(
-                detdata.detector_shape
-            )
-        else:
-            raise ValueError(f"Unsupported data type '{detdata.dtype}'")
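For reference, this is the quantization round trip that the deleted `float2int()`/`int2float()` helpers above implemented, written as a standalone sketch (helper names are illustrative):

```python
import numpy as np

def float_to_int32(data, quanta=None):
    """Center the data and scale it into half the int32 range."""
    offset = 0.5 * (np.amin(data) + np.amax(data))
    amp = 1.01 * np.max(np.abs(data - offset))
    # FLAC effectively loses one bit internally, so use half the range.
    max_flac = np.iinfo(np.int32).max // 2
    if quanta is None:
        quanta = amp / max_flac if amp > 0 else 1.0
    gain = 1.0 / quanta
    return np.around(gain * (data - offset)).astype(np.int32), offset, gain

def int32_to_float(idata, offset, gain):
    """Invert the scaling, returning float32 as the original code did."""
    return np.array(idata / gain + offset, dtype=np.float32)

x = np.sin(np.linspace(0, 10, 1000))
ints, off, g = float_to_int32(x)
restored = int32_to_float(ints, off, g)
# Rounding introduces at most half a quantum of error per sample.
assert np.allclose(restored, x, atol=2.0 / g)
```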
diff --git a/src/toast/io/compression.py b/src/toast/io/deprecated_compression.py
similarity index 78%
rename from src/toast/io/compression.py
rename to src/toast/io/deprecated_compression.py
index 99f3e4781..e90621ec7 100644
--- a/src/toast/io/compression.py
+++ b/src/toast/io/deprecated_compression.py
@@ -5,12 +5,12 @@
 import gzip
 
 import numpy as np
-from astropy import units as u
+from flacarray.compress import array_compress
+from flacarray.decompress import array_decompress
 
 from ..observation_data import DetectorData
 from ..timing import function_timer
 from ..utils import AlignedU8, Logger
-from .compression_flac import compress_detdata_flac, decompress_detdata_flac
 
 
 @function_timer
@@ -135,33 +135,29 @@ def compress_detdata(detdata, comp_params=None):
         if quanta is not None and precision is not None:
             raise RuntimeError("Cannot set both quanta and precision")
 
-        if detdata.dtype in ftypes:
-            (
-                comp_bytes,
-                comp_ranges,
-                comp_params["data_offsets"],
-                comp_params["data_gains"],
-            ) = compress_detdata_flac(
-                detdata, level=comp_level, quanta=quanta, precision=precision
+        if detdata.dtype in ftypes or detdata.dtype in itypes:
+            (comp_bytes, stream_starts, stream_nbytes, stream_offsets, stream_gains) = (
+                array_compress(
+                    detdata.data.reshape((n_det, -1)),
+                    level=comp_level,
+                    quanta=quanta,
+                    precision=precision,
+                )
             )
-        elif detdata.dtype == np.dtype(np.int64):
-            (
-                comp_bytes,
-                comp_ranges,
-                comp_params["data_offsets"],
-            ) = compress_detdata_flac(detdata, level=comp_level)
-        elif detdata.dtype == np.dtype(np.int32):
-            (
-                comp_bytes,
-                comp_ranges,
-            ) = compress_detdata_flac(detdata, level=comp_level)
+
+            comp_ranges = [(x, x + y) for x, y in zip(stream_starts, stream_nbytes)]
+            if detdata.dtype in ftypes:
+                comp_params["data_offsets"] = stream_offsets
+                comp_params["data_gains"] = stream_gains
+            elif detdata.dtype == np.dtype(np.int64):
+                comp_params["data_offsets"] = stream_offsets
         else:
-            msg = f"FLAC Compression of type '{detdata.dtype}' is not supported"
+            msg = f"Legacy FLAC Compression of type '{detdata.dtype}' is not supported"
             raise RuntimeError(msg)
         return (comp_bytes, comp_ranges, comp_params)
     else:
-        msg = f"Compression type \"{comp_params['type']}\" is not supported"
+        msg = f'Compression type "{comp_params["type"]}" is not supported'
         raise NotImplementedError(msg)
 
 
@@ -218,22 +214,22 @@ def decompress_detdata(comp_bytes, comp_ranges, comp_params, detdata=None):
         # We are populating an existing data object. Verify consistent
         # properties.
         if detectors != detdata.detectors:
-            msg = f"Input detdata container has different detectors "
+            msg = "Input detdata container has different detectors "
             msg += f"({detdata.detectors}) than compressed data "
             msg += f"({detectors})"
             raise RuntimeError(msg)
         if detector_shape != detdata.detector_shape:
-            msg = f"Input detdata container has different det shape "
+            msg = "Input detdata container has different det shape "
             msg += f"({detdata.detector_shape}) than compressed data "
            msg += f"({detector_shape})"
             raise RuntimeError(msg)
         if dtype != detdata.dtype:
-            msg = f"Input detdata container has different dtype "
+            msg = "Input detdata container has different dtype "
             msg += f"({detdata.dtype}) than compressed data "
             msg += f"({dtype})"
             raise RuntimeError(msg)
         if units != detdata.units:
-            msg = f"Input detdata container has different units "
+            msg = "Input detdata container has different units "
             msg += f"({detdata.units}) than compressed data "
             msg += f"({units})"
             raise RuntimeError(msg)
@@ -264,23 +260,35 @@ def decompress_detdata(comp_bytes, comp_ranges, comp_params, detdata=None):
     elif comp_params["type"] == "flac":
         data_offsets = None
-        if "data_offsets" in comp_params:
-            data_offsets = comp_params["data_offsets"]
-            dslices = list()
+        if "data_offsets" in comp_params and comp_params["data_offsets"] is not None:
+            data_offsets = comp_params["data_offsets"].astype(np.float32)
         data_gains = None
-        if "data_gains" in comp_params:
-            data_gains = comp_params["data_gains"]
-        decompress_detdata_flac(
-            detdata,
+        if "data_gains" in comp_params and comp_params["data_gains"] is not None:
+            data_gains = comp_params["data_gains"].astype(np.float32)
+
+        stream_starts = np.array([x[0] for x in comp_ranges])
+        stream_stops = np.array([x[1] for x in comp_ranges])
+        stream_nbytes = stream_stops - stream_starts
+
+        temp_array = array_decompress(
             comp_bytes,
-            comp_ranges,
-            det_offsets=data_offsets,
-            det_gains=data_gains,
+            det_stride,
+            stream_starts,
+            stream_nbytes,
+            stream_offsets=data_offsets,
+            stream_gains=data_gains,
+            first_stream_sample=None,
+            last_stream_sample=None,
+            is_int64=False,
+            use_threads=False,
+            no_flatten=True,
         )
+        for idet, det in enumerate(detectors):
+            detdata[det, :] = temp_array[idet].reshape(detector_shape)
     else:
-        msg = f"Compression type \"{comp_params['type']}\" is not supported"
+        msg = f'Compression type "{comp_params["type"]}" is not supported'
         raise NotImplementedError(msg)
 
     return detdata
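The legacy on-disk format stores per-detector `(first_byte, end_byte)` ranges, while flacarray reports `(start, nbytes)` pairs. These two helpers mirror the conversions performed inside `compress_detdata()` and `decompress_detdata()` above (the function names are illustrative):

```python
import numpy as np

def ranges_from_streams(stream_starts, stream_nbytes):
    """flacarray (start, nbytes) -> legacy (start, stop) byte ranges."""
    return [(int(x), int(x + y)) for x, y in zip(stream_starts, stream_nbytes)]

def streams_from_ranges(comp_ranges):
    """Legacy (start, stop) byte ranges -> flacarray (start, nbytes)."""
    starts = np.array([x[0] for x in comp_ranges])
    stops = np.array([x[1] for x in comp_ranges])
    return starts, stops - starts

starts, nbytes = streams_from_ranges(ranges_from_streams([0, 10], [10, 4]))
assert list(starts) == [0, 10] and list(nbytes) == [10, 4]
```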
diff --git a/src/toast/io/hdf_utils.py b/src/toast/io/hdf_utils.py
index 25d1a0281..4e30118fb 100644
--- a/src/toast/io/hdf_utils.py
+++ b/src/toast/io/hdf_utils.py
@@ -140,3 +140,168 @@ def __enter__(self):
 
     def __exit__(self, *args):
         self.close()
+
+
+def save_meta_object(parent, objname, obj):
+    """Recursive function to save python metadata objects.
+
+    This function attempts to make intelligent choices based on the object types:
+    - Scalars are written as attributes to the parent group
+    - Arrays are written as a dataset in the parent group
+    - Dictionaries: a new group is created with objname and this function is
+      called for each child key / value.
+    - Lists / Tuples: a new group is created with objname and the original
+      data type is recorded to a group attribute. Next each item is passed
+      to this function with a name "item_XXXX".
+
+    Args:
+        parent (h5py.Group): The parent group (if this process is participating)
+            else None.
+        objname (str): The name of the current object.
+        obj (object): A recognized object type (scalar, array, dict, list,
+            set, tuple)
+
+    Returns:
+        None
+
+    """
+    log = Logger.get()
+
+    def _type_to_str(obj):
+        if isinstance(obj, dict):
+            return "dict"
+        elif isinstance(obj, list):
+            return "list"
+        elif isinstance(obj, tuple):
+            return "tuple"
+        else:
+            msg = f"Unsupported container type '{type(obj)}'"
+            raise ValueError(msg)
+
+    if parent is None:
+        # Not participating
+        return
+
+    if "type" not in parent.attrs:
+        # This must be the root
+        parent.attrs["type"] = "dict"
+
+    if isinstance(obj, dict):
+        child = parent.create_group(objname)
+        child.attrs["type"] = _type_to_str(obj)
+        for k, v in obj.items():
+            save_meta_object(child, k, v)
+    elif isinstance(obj, (list, tuple)):
+        child = parent.create_group(objname)
+        child.attrs["type"] = _type_to_str(obj)
+        for indx, item in enumerate(obj):
+            k = f"item_{indx:04d}"
+            save_meta_object(child, k, item)
+    elif isinstance(obj, u.Quantity):
+        if isinstance(obj.value, np.ndarray):
+            # Array quantity
+            odata = parent.create_dataset(objname, data=obj.value)
+            odata.attrs["units"] = obj.unit.to_string()
+            del odata
+        else:
+            # Must be a scalar quantity
+            parent.attrs[f"{objname}_value"] = obj.value
+            parent.attrs[f"{objname}_units"] = obj.unit.to_string()
+    elif isinstance(obj, np.ndarray):
+        # Array
+        arr = parent.create_dataset(objname, data=obj)
+        del arr
+    else:
+        # This is a scalar or some kind of unknown object. Try
+        # to store it as an attribute and warn if something failed.
+        try:
+            parent.attrs[objname] = obj
+        except (ValueError, TypeError) as e:
+            msg = f"Failed to store metadata '{objname}' = '{obj}' as an attribute ({e})."
+            msg += " Ignoring."
+            log.warn(msg)
+
+
+def load_meta_object(parent):
+    """Recursive function to load HDF5 metadata objects.
+
+    This function recursively processes an HDF5 group and converts groups and
+    datasets into python objects.
+
+    Args:
+        parent (h5py.Group): The parent group (if this process is participating)
+            else None.
+
+    Returns:
+        (object): The populated python container
+
+    """
+    if "type" not in parent.attrs:
+        raise RuntimeError("metadata group does not contain 'type' attribute")
+
+    parsed = dict()
+    parsed["type"] = parent.attrs["type"]
+
+    # First process child groups / datasets
+    for child_name in list(sorted(parent.keys())):
+        if isinstance(parent[child_name], h5py.Group):
+            # Descend
+            child = parent[child_name]
+            parsed[child_name] = load_meta_object(child)
+            del child
+        elif isinstance(parent[child_name], h5py.Dataset):
+            child = parent[child_name]
+            if "units" in child.attrs:
+                # This is an array Quantity
+                arr = u.Quantity(child, u.Unit(child.attrs["units"]), copy=True)
+            else:
+                # Plain numpy array
+                arr = np.array(child, copy=True)
+            parsed[child_name] = arr
+            del child
+
+    # Now process parent attributes
+    units_pat = re.compile(r"(.*)_units")
+    value_pat = re.compile(r"(.*)_value")
+    for k, v in parent.attrs.items():
+        if k == "type":
+            continue
+        if value_pat.match(k) is not None:
+            # We will process this when matching units
+            continue
+        units_mat = units_pat.match(k)
+        if units_mat is not None:
+            # We have a quantity
+            kname = units_mat.group(1)
+            unit_str = v
+            kval = parent.attrs[f"{kname}_value"]
+            parsed[kname] = u.Quantity(kval, u.Unit(unit_str))
+        else:
+            # Simple scalar
+            parsed[k] = v
+
+    # If the parent container is a list or tuple, construct that now and sort the
+    # children into the original order.
+    ret = None
+    ctype = parsed["type"]
+    if ctype == "dict":
+        del parsed["type"]
+        return parsed
+    elif ctype == "list":
+        keys = list(sorted(parsed.keys()))
+        ret = list()
+        for k in keys:
+            if k == "type":
+                continue
+            ret.append(parsed[k])
+    elif ctype == "tuple":
+        keys = list(sorted(parsed.keys()))
+        ret = tuple()
+        for k in keys:
+            if k == "type":
+                continue
+            ret = ret + (parsed[k],)
+    else:
+        msg = f"Unsupported container format '{ctype}'"
+        raise RuntimeError(msg)
+    return ret
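A hedged usage sketch for the new `save_meta_object()` / `load_meta_object()` helpers, using an in-memory HDF5 file. It assumes both functions are importable from `toast.io.hdf_utils` as defined above:

```python
import io
import h5py
import numpy as np
from astropy import units as u
from toast.io.hdf_utils import save_meta_object, load_meta_object

meta = {
    "scalar": 1.5,
    "rate": 10.0 * u.Hz,          # scalar quantity -> paired attrs
    "arr": np.arange(4),          # array -> dataset
    "nested": {"names": ["a", "b"], "pair": (1, 2)},
}

with h5py.File(io.BytesIO(), "w") as hf:
    for key, val in meta.items():
        save_meta_object(hf, key, val)
    restored = load_meta_object(hf)

# Container types survive the round trip, including the tuple.
assert restored["nested"]["pair"] == (1, 2)
```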
diff --git a/src/toast/io/observation_hdf_load.py b/src/toast/io/observation_hdf_load.py
index f871ca798..9bf3a6244 100644
--- a/src/toast/io/observation_hdf_load.py
+++ b/src/toast/io/observation_hdf_load.py
@@ -12,16 +12,22 @@
 import numpy as np
 from astropy import units as u
 from astropy.table import QTable
+import flacarray
 
-from ..instrument import Focalplane, GroundSite, SpaceSite, Telescope
-from ..mpi import MPI
+from ..instrument import Focalplane
 from ..observation import Observation
-from ..timing import GlobalTimers, Timer, function_timer
-from ..utils import Environment, Logger, dtype_to_aligned, import_from_name
+from ..timing import Timer, function_timer
+from ..utils import Environment, Logger, import_from_name
 from ..weather import SimWeather, Weather
-from .compression import decompress_detdata
-from .hdf_utils import check_dataset_buffer_size, hdf5_config, hdf5_open
-from .observation_hdf_load_v0 import load_hdf5_detdata_v0
+from .hdf_utils import (
+    check_dataset_buffer_size,
+    hdf5_config,
+    hdf5_open,
+    load_meta_object,
+)
+
+from .observation_hdf_load_v1 import load_hdf5 as load_hdf5_v1
+from .observation_hdf_load_v1 import load_instrument as load_instrument_v1
 
 
 @function_timer
@@ -180,54 +186,40 @@ def load_hdf5_detdata(obs, hgrp, fields, log_prefix, parallel):
         units = None
         full_shape = None
         dtype = None
+        orig_dtype = None
 
         compressed = False
         cgrp = None
-        comp_params = dict()
-        comp_ranges = None
-        comp_offsets = None
-        comp_gains = None
-        comp_nbytes = None
-
         if hgrp is not None:
             if isinstance(hgrp[field], h5py.Dataset):
                 # This is uncompressed data
                 ds = hgrp[field]
                 full_shape = ds.shape
                 dtype = ds.dtype
+                orig_dtype = dtype
                 units = u.Unit(str(ds.attrs["units"]))
             else:
-                # This must be a group of datasets containing compressed data
+                # This must be a group of datasets containing compressed data.
+                # FIXME: we should have a flacarray helper function to get array
+                # properties without doing this manually.
                 cgrp = hgrp[field]
-                ds = cgrp["compressed"]
-                units = u.Unit(str(ds.attrs["units"]))
-                comp_params["units"] = units
-                dtype = np.dtype(ds.attrs["dtype"])
-                comp_params["dtype"] = dtype
-                det_shape = tuple(ast.literal_eval(ds.attrs["det_shape"]))
-                comp_params["det_shape"] = det_shape
-                comp_nbytes = len(ds)
-
-                ds_ranges = cgrp["ranges"]
-                n_det = len(ds_ranges)
-                comp_ranges = np.zeros((n_det, 2), dtype=np.int64)
-                slc = (slice(0, n_det, 1), slice(0, 2, 1))
-                ds_ranges.read_direct(comp_ranges, slc, slc)
-
-                if "offsets" in cgrp:
-                    ds_offsets = cgrp["offsets"]
-                    comp_offsets = np.zeros(n_det, np.float64)
-                    slc = (slice(0, n_det, 1),)
-                    ds_offsets.read_direct(comp_offsets, slc, slc)
-
-                if "gains" in cgrp:
-                    ds_gains = cgrp["gains"]
-                    comp_gains = np.zeros(n_det, np.float64)
-                    slc = (slice(0, n_det, 1),)
-                    ds_gains.read_direct(comp_gains, slc, slc)
-
-                full_shape = (n_det,) + det_shape
-                comp_params["type"] = ds.attrs["comp_type"]
+                units = u.Unit(str(cgrp.attrs["units"]))
+                orig_dtype = np.dtype(cgrp.attrs["dtype"])
+                detector_shape = tuple(ast.literal_eval(cgrp.attrs["detector_shape"]))
+                n_channel = int(cgrp.attrs["flac_channels"])
+                starts = cgrp["stream_starts"]
+                n_det = starts.shape[0]
+                full_shape = (n_det,) + detector_shape
+                if "stream_offsets" in cgrp:
+                    if n_channel == 2:
+                        dtype = np.dtype(np.float64)
+                    else:
+                        dtype = np.dtype(np.float32)
+                else:
+                    if n_channel == 2:
+                        dtype = np.dtype(np.int64)
+                    else:
+                        dtype = np.dtype(np.int32)
                 compressed = True
 
@@ -235,12 +227,7 @@ def load_hdf5_detdata(obs, hgrp, fields, log_prefix, parallel):
             units = obs.comm.comm_group.bcast(units, root=0)
             full_shape = obs.comm.comm_group.bcast(full_shape, root=0)
             dtype = obs.comm.comm_group.bcast(dtype, root=0)
-            if compressed:
-                comp_params = obs.comm.comm_group.bcast(comp_params, root=0)
-                comp_nbytes = obs.comm.comm_group.bcast(comp_nbytes, root=0)
-                comp_ranges = obs.comm.comm_group.bcast(comp_ranges, root=0)
-                comp_offsets = obs.comm.comm_group.bcast(comp_offsets, root=0)
-                comp_gains = obs.comm.comm_group.bcast(comp_gains, root=0)
+            orig_dtype = obs.comm.comm_group.bcast(orig_dtype, root=0)
 
         sample_shape = None
         if len(full_shape) > 2:
@@ -250,7 +237,7 @@ def load_hdf5_detdata(obs, hgrp, fields, log_prefix, parallel):
         obs.detdata.create(
             field,
             sample_shape=sample_shape,
-            dtype=dtype,
+            dtype=orig_dtype,
             detectors=obs.local_detectors,
             units=units,
         )
@@ -260,132 +247,70 @@ def load_hdf5_detdata(obs, hgrp, fields, log_prefix, parallel):
         # data.  We can do this since we previously checked that for serial loads
         # the data is distributed by detector.
 
-        if serial_load:
-            if compressed:
-                # Each process has the information about all detectors' byte
-                # range and compression parameters.  We just need to send
-                # the bytes.
-                for proc, detrange in enumerate(dist_dets):
-                    first_det = detrange.offset
-                    last_det = detrange.offset + detrange.n_elem - 1
-
-                    first_byte = comp_ranges[first_det][0]
-                    end_byte = comp_ranges[last_det][1]
-                    local_ranges = np.zeros(
-                        (detrange.n_elem, 2),
-                        dtype=np.int64,
-                    )
-                    for d in range(detrange.n_elem):
-                        local_ranges[d, 0] = comp_ranges[first_det + d][0] - first_byte
-                        local_ranges[d, 1] = comp_ranges[first_det + d][1] - first_byte
-
-                    if comp_offsets is not None:
-                        comp_params["data_offsets"] = comp_offsets[
-                            first_det : last_det + 1
-                        ]
-                    if comp_gains is not None:
-                        comp_params["data_gains"] = comp_gains[first_det : last_det + 1]
-
-                    pbytes = end_byte - first_byte
-                    pslice = (slice(0, pbytes, 1),)
-                    hslice = (slice(first_byte, end_byte, 1),)
-
-                    if obs.comm.group_rank == 0:
-                        buffer = np.zeros(pbytes, dtype=np.uint8)
-                        ds.read_direct(buffer, hslice, pslice)
-                        if proc == 0:
-                            # Decompress our own data
-                            decompress_detdata(
-                                buffer,
-                                local_ranges,
-                                comp_params,
-                                detdata=obs.detdata[field],
-                            )
-                        else:
-                            # Send
-                            obs.comm.comm_group.Send(buffer, dest=proc, tag=proc)
-                    elif obs.comm.group_rank == proc:
-                        # Receive and decompress
-                        buffer = np.zeros(pbytes, dtype=np.uint8)
-                        obs.comm.comm_group.Recv(buffer, source=0, tag=proc)
-                        decompress_detdata(
-                            buffer,
-                            local_ranges,
-                            comp_params,
-                            detdata=obs.detdata[field],
-                        )
-            else:
-                for proc, detrange in enumerate(dist_dets):
-                    first_det = detrange.offset
-                    end_det = detrange.offset + detrange.n_elem
-                    n_local_det = detrange.n_elem
-                    pslice = (slice(0, n_local_det, 1), slice(0, obs.n_all_samples, 1))
-                    hslice = (
-                        slice(first_det, end_det, 1),
-                        slice(0, obs.n_all_samples, 1),
-                    )
-                    if obs.comm.group_rank == 0:
-                        buffer = np.zeros((n_local_det, obs.n_all_samples), dtype=dtype)
-                        ds.read_direct(buffer, hslice, pslice)
-                        if proc == 0:
-                            # Copy data into place
-                            obs.detdata[field][:] = buffer
-                        else:
-                            # Send
-                            obs.comm.comm_group.Send(
-                                buffer.reshape(-1), dest=proc, tag=proc
-                            )
-                        del buffer
-                    elif obs.comm.group_rank == proc:
-                        # Receive and store
-                        buffer = np.zeros((n_local_det, obs.n_all_samples), dtype=dtype)
-                        obs.comm.comm_group.Recv(buffer, source=0, tag=proc)
+        if compressed:
+            # Load with flacarray
+            mpi_dist = [(x.offset, x.offset + x.n_elem) for x in dist_dets]
+            flcdata = (
+                flacarray.hdf5.read_array(
+                    cgrp,
+                    keep=None,
+                    stream_slice=None,
+                    keep_indices=False,
+                    mpi_comm=obs.comm.comm_group,
+                    mpi_dist=mpi_dist,
+                    use_threads=False,
+                )
+                .astype(orig_dtype)
+                .reshape(obs.detdata[field].shape)
+            )
+            for idet, det in enumerate(obs.detdata[field].detectors):
+                obs.detdata[field][idet] = flcdata[idet]
+            del flcdata
+        elif serial_load:
+            # Uncompressed read and distribute
+            for proc, detrange in enumerate(dist_dets):
+                first_det = detrange.offset
+                end_det = detrange.offset + detrange.n_elem
+                n_local_det = detrange.n_elem
+                pslice = (slice(0, n_local_det, 1), slice(0, obs.n_all_samples, 1))
+                hslice = (
+                    slice(first_det, end_det, 1),
+                    slice(0, obs.n_all_samples, 1),
+                )
+                if obs.comm.group_rank == 0:
+                    buffer = np.zeros((n_local_det, obs.n_all_samples), dtype=dtype)
+                    ds.read_direct(buffer, hslice, pslice)
+                    if proc == 0:
+                        # Copy data into place
                         obs.detdata[field][:] = buffer
+                    else:
+                        # Send
+                        obs.comm.comm_group.Send(
+                            buffer.reshape(-1), dest=proc, tag=proc
+                        )
                     del buffer
+                elif obs.comm.group_rank == proc:
+                    # Receive and store
+                    buffer = np.zeros((n_local_det, obs.n_all_samples), dtype=dtype)
+                    obs.comm.comm_group.Recv(buffer, source=0, tag=proc)
+                    obs.detdata[field][:] = buffer
+                    del buffer
         else:
-            if compressed:
-                last_det = det_off + det_nelem - 1
-                first_byte = comp_ranges[det_off][0]
-                end_byte = comp_ranges[last_det][1]
-                local_ranges = np.zeros(
-                    (det_nelem, 2),
-                    dtype=np.int64,
-                )
-                for d in range(det_nelem):
-                    local_ranges[d, 0] = comp_ranges[det_off + d][0] - first_byte
-                    local_ranges[d, 1] = comp_ranges[det_off + d][1] - first_byte
-
-                if comp_offsets is not None:
-                    comp_params["data_offsets"] = comp_offsets[det_off : last_det + 1]
-                if comp_gains is not None:
-                    comp_params["data_gains"] = comp_gains[det_off : last_det + 1]
-
-                pbytes = end_byte - first_byte
-                pslice = (slice(0, pbytes, 1),)
-                hslice = (slice(first_byte, end_byte, 1),)
-                buffer = np.zeros(pbytes, dtype=np.uint8)
-                ds.read_direct(buffer, hslice, pslice)
-                decompress_detdata(
-                    buffer,
-                    local_ranges,
-                    comp_params,
-                    detdata=obs.detdata[field],
-                )
-            else:
-                detdata_slice = [slice(0, det_nelem, 1), slice(0, samp_nelem, 1)]
-                hf_slice = [
-                    slice(det_off, det_off + det_nelem, 1),
-                    slice(samp_off, samp_off + samp_nelem, 1),
-                ]
-                if len(full_shape) > 2:
-                    for dim in full_shape[2:]:
-                        detdata_slice.append(slice(0, dim))
-                        hf_slice.append(slice(0, dim))
-                detdata_slice = tuple(detdata_slice)
-                hf_slice = tuple(hf_slice)
-                msg = f"Detdata field {field} (group rank {obs.comm.group_rank})"
-                check_dataset_buffer_size(msg, hf_slice, dtype, parallel)
-                ds.read_direct(obs.detdata[field].data, hf_slice, detdata_slice)
+            # Uncompressed read in parallel
+            detdata_slice = [slice(0, det_nelem, 1), slice(0, samp_nelem, 1)]
+            hf_slice = [
+                slice(det_off, det_off + det_nelem, 1),
+                slice(samp_off, samp_off + samp_nelem, 1),
+            ]
+            if len(full_shape) > 2:
+                for dim in full_shape[2:]:
+                    detdata_slice.append(slice(0, dim))
+                    hf_slice.append(slice(0, dim))
+            detdata_slice = tuple(detdata_slice)
+            hf_slice = tuple(hf_slice)
+            msg = f"Detdata field {field} (group rank {obs.comm.group_rank})"
+            check_dataset_buffer_size(msg, hf_slice, dtype, parallel)
+            ds.read_direct(obs.detdata[field].data, hf_slice, detdata_slice)
 
         if obs.comm.comm_group is not None:
             obs.comm.comm_group.barrier()
@@ -395,6 +320,7 @@ def load_hdf5_detdata(obs, hgrp, fields, log_prefix, parallel):
             timer=timer,
         )
         del ds
+        del cgrp
 
 
 @function_timer
@@ -447,8 +373,23 @@ def load_instrument(parent_group, detectors=None, file_det_sets=None, comm=None)
     telescope = None
     session = None
     new_detsets = file_det_sets
+    toast_version = None
     if parent_group is not None:
         inst_group = parent_group["instrument"]
+        toast_version = int(inst_group.attrs["toast_format_version"])
+    if comm is not None:
+        toast_version = comm.bcast(toast_version, root=0)
+
+    if toast_version < 2:
+        return load_instrument_v1(
+            parent_group, detectors=detectors, file_det_sets=file_det_sets, comm=comm
+        )
+    if toast_version > 2:
+        msg = "load_instrument() found invalid file format "
+        msg += f"version {toast_version}"
+        raise RuntimeError(msg)
+
+    if parent_group is not None:
         telescope_name = str(inst_group.attrs["telescope_name"])
         telescope_uid = int(inst_group.attrs["telescope_uid"])
         telescope_class = import_from_name(str(inst_group.attrs["telescope_class"]))
@@ -609,6 +550,7 @@ def load_hdf5_obs_meta(
     session = None
     obs_det_sets = None
     obs_sample_sets = None
+    all_det_flags = None
 
     if hgroup is not None:
         # Observation properties
@@ -631,6 +573,9 @@ def load_hdf5_obs_meta(
             hgroup, detectors=detectors, file_det_sets=file_det_sets, comm=None
         )
 
+        # Per detector flags.
+        all_det_flags = json.loads(hgroup.attrs["observation_detector_flags"])
+
     log.debug_rank(
         f"{log_prefix} Loaded instrument properties in",
         comm=comm.comm_group,
@@ -646,6 +591,7 @@ def load_hdf5_obs_meta(
         session = comm.comm_group.bcast(session, root=0)
         obs_det_sets = comm.comm_group.bcast(obs_det_sets, root=0)
         obs_sample_sets = comm.comm_group.bcast(obs_sample_sets, root=0)
+        all_det_flags = comm.comm_group.bcast(all_det_flags, root=0)
 
     # Create the observation
     obs = Observation(
@@ -660,6 +606,12 @@ def load_hdf5_obs_meta(
         process_rows=process_rows,
     )
 
+    # Set per-detector flags
+    local_det_flags = dict()
+    for det in obs.local_detectors:
+        local_det_flags[det] = all_det_flags[det]
+    obs.set_local_detector_flags(local_det_flags)
+
     # Load observation metadata.  This is complicated because a subset of processes
     # may have the file open, but the object loader may need the whole communicator
     # to load the object.  First we load all simple metadata and record the names
@@ -668,10 +620,14 @@ def load_hdf5_obs_meta(
     # collectively.
 
     meta_load = dict()
+    attr_load = dict()
 
     if hgroup is not None:
         meta_group = hgroup["metadata"]
         for obj_name in meta_group.keys():
+            if obj_name == "other":
+                # Simple python metadata - will process below
+                continue
             obj = meta_group[obj_name]
             if meta is not None and obj_name not in meta:
                 # The user restricted the list of things to load, and this is
@@ -689,47 +645,56 @@ def load_hdf5_obs_meta(
                 else:
                     msg = f"metadata object group '{obj_name}' has class "
                     msg += f"{obj.attrs['class']}, but instantiated "
-                    msg += f"object does not have a load_hdf5() method"
+                    msg += "object does not have a load_hdf5() method"
                     log.error(msg)
-                else:
-                    # This must be some other custom user dataset.  Ignore it.
-                    pass
-            else:
-                # Array-like dataset that we can load
-                if "units" in obj.attrs:
-                    # This array is a quantity
-                    meta_load[obj_name] = u.Quantity(
-                        np.array(obj), unit=u.Unit(obj.attrs["units"])
-                    )
-                else:
-                    meta_load[obj_name] = np.array(obj)
-            del obj
-
-        # Now extract attributes (scalars)
-        units_pat = re.compile(r"(.*)_units")
-        for k, v in meta_group.attrs.items():
-            if meta is not None and k not in meta:
-                continue
-            if units_pat.match(k) is not None:
-                # unit field, skip
-                continue
-            # Check for quantity
-            unit_name = f"{k}_units"
-            if unit_name in meta_group.attrs:
-                meta_load[k] = u.Quantity(v, unit=u.Unit(meta_group.attrs[unit_name]))
-            else:
-                meta_load[k] = v
+                continue
+            # Warn that we are not loading this object
+            msg = f"Found un-loadable metadata object '{obj_name}'.  Skipping."
+            log.warn(msg)
+
+        # Now load regular metadata into a python dictionary
+        meta_other = meta_group["other"]
+        other = load_meta_object(meta_other)
+        meta_load.update(other)
+        del other
+        del meta_other
         del meta_group
 
+        # Now process observation attribute objects
+        attr_group = hgroup["attr"]
+        for obj_name in attr_group.keys():
+            obj = attr_group[obj_name]
+            if isinstance(obj, h5py.Group):
+                # This might be an object to restore
+                if "class" in obj.attrs:
+                    objclass = import_from_name(obj.attrs["class"])
+                    test_obj = objclass()
+                    if hasattr(test_obj, "load_hdf5"):
+                        # Record this in the dictionary of things to load in the
+                        # next step.
+                        attr_load[obj_name] = test_obj
+                    else:
+                        msg = f"attr object group '{obj_name}' has class "
+                        msg += f"{obj.attrs['class']}, but instantiated "
+                        msg += "object does not have a load_hdf5() method"
+                        log.error(msg)
+                    continue
+            # Warn that we are not loading this object
+            msg = f"Found un-loadable attribute object '{obj_name}'.  Skipping."
+            log.warn(msg)
+        del attr_group
+
     # Communicate the partial metadata
     if not parallel and nproc > 1:
         meta_load = comm.comm_group.bcast(meta_load, root=0)
+        attr_load = comm.comm_group.bcast(attr_load, root=0)
 
     # Now load any remaining metadata objects
     meta_group = None
+    attr_group = None
     if hgroup is not None:
         meta_group = hgroup["metadata"]
+        attr_group = hgroup["attr"]
@@ -738,10 +703,22 @@ def load_hdf5_obs_meta(
             if hgroup is not None:
                 handle = meta_group[meta_key]
             meta_load[meta_key].load_hdf5(handle, obs)
             del handle
     del meta_group
+    for attr_key in list(attr_load.keys()):
+        if hasattr(attr_load[attr_key], "load_hdf5"):
+            handle = None
+            if hgroup is not None:
+                handle = attr_group[attr_key]
+            attr_load[attr_key].load_hdf5(handle, obs)
+            del handle
+    del attr_group
 
     # Assign the internal observation dictionary
     obs._internal = meta_load
 
+    # Assign all class attributes
+    for k, v in attr_load.items():
+        setattr(obs, k, v)
+
     log.debug_rank(
         f"{log_prefix} Finished other metadata in",
         comm=comm.comm_group,
@@ -834,25 +811,25 @@ def load_hdf5(
     if comm.comm_group is not None:
         file_version = comm.comm_group.bcast(file_version, root=0)
 
-    # As the file format evolves, we might close the file at this point and call
-    # an earlier version of the loader.  However, v0 and v1 only differ in the
-    # detector data loading, so we can just branch at that point.
-    #
-    # Example for future:
-    # if file_version == 12345:
-    #     # Close file and call older version
-    #     del hgroup
-    #     if hf is not None:
-    #         hf.close()
-    #     del hf
-    #     return load_hdf5_v12345(...)
-    #
-    if file_version == 0 and detectors is not None:
-        msg = f"HDF5 file '{path}' uses format v0 which does not support loading"
-        msg = " a subset of detectors"
-        log.error(msg)
-        raise RuntimeError(msg)
-    if file_version > 1:
+    if file_version < 2:
+        # The v1 loader also deals with v0 data.
+        del hgroup
+        if hf is not None:
+            hf.close()
+        del hf
+        return load_hdf5_v1(
+            path,
+            comm,
+            process_rows=process_rows,
+            meta=meta,
+            detdata=detdata,
+            shared=shared,
+            intervals=intervals,
+            detectors=detectors,
+            force_serial=force_serial,
+        )
+
+    if file_version > 2:
         msg = f"HDF5 file '{path}' using unsupported data format {file_version}"
         log.error(msg)
         raise RuntimeError(msg)
@@ -910,10 +887,7 @@ def load_hdf5(
     detdata_group = None
     if hgroup is not None:
         detdata_group = hgroup["detdata"]
-    if file_version == 0:
-        load_hdf5_detdata_v0(obs, detdata_group, detdata, log_prefix, parallel)
-    else:
-        load_hdf5_detdata(obs, detdata_group, detdata, log_prefix, parallel)
+    load_hdf5_detdata(obs, detdata_group, detdata, log_prefix, parallel)
     del detdata_group
     log.debug_rank(
         f"{log_prefix} Finished detector data in",
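The compressed-detdata read path added above delegates the collective HDF5 read to flacarray. A hedged sketch of that pattern, with function and keyword names taken verbatim from the `load_hdf5_detdata()` call site (treat them as illustrative of the call site, not as documented API):

```python
import flacarray.hdf5

def read_compressed_detdata(cgrp, dist_dets, comm, orig_dtype, local_shape):
    # (first_det, end_det) detector range for every process in the group
    mpi_dist = [(d.offset, d.offset + d.n_elem) for d in dist_dets]
    flcdata = flacarray.hdf5.read_array(
        cgrp,
        keep=None,
        stream_slice=None,
        keep_indices=False,
        mpi_comm=comm,
        mpi_dist=mpi_dist,
        use_threads=False,
    )
    # Cast back to the dtype recorded in the group attributes and restore
    # the local (n_det, ...) detector data shape.
    return flcdata.astype(orig_dtype).reshape(local_shape)
```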
+ +import ast +import json +import os +import re +from datetime import datetime, timezone + +import h5py +import numpy as np +from astropy import units as u +from astropy.table import QTable + +from ..instrument import Focalplane, GroundSite, SpaceSite, Telescope +from ..mpi import MPI +from ..observation import Observation +from ..timing import GlobalTimers, Timer, function_timer +from ..utils import Environment, Logger, dtype_to_aligned, import_from_name +from ..weather import SimWeather, Weather +from .deprecated_compression import decompress_detdata +from .hdf_utils import check_dataset_buffer_size, hdf5_config, hdf5_open +from .observation_hdf_load_v0 import load_hdf5_detdata_v0 + + +@function_timer +def load_hdf5_shared(obs, hgrp, fields, log_prefix, parallel): + log = Logger.get() + + timer = Timer() + timer.start() + + # Get references to the distribution of detectors and samples + proc_rows = obs.dist.process_rows + proc_cols = obs.dist.comm.group_size // proc_rows + dist_samps = obs.dist.samps + dist_dets = obs.dist.det_indices + + serial_load = False + if obs.comm.group_size > 1 and not parallel: + # We are doing a serial load, but we have multiple processes + # in the group. + serial_load = True + + field_list = None + if hgrp is not None: + field_list = list(hgrp.keys()) + if serial_load and obs.comm.comm_group is not None: + # Broadcast the field list + field_list = obs.comm.comm_group.bcast(field_list, root=0) + + for field in field_list: + if fields is not None and field not in fields: + continue + ds = None + comm_type = None + full_shape = None + dtype = None + if hgrp is not None: + ds = hgrp[field] + comm_type = ds.attrs["comm_type"] + full_shape = ds.shape + dtype = ds.dtype + if serial_load: + comm_type = obs.comm.comm_group.bcast(comm_type, root=0) + full_shape = obs.comm.comm_group.bcast(full_shape, root=0) + dtype = obs.comm.comm_group.bcast(dtype, root=0) + + slc = list() + shape = list() + if comm_type == "row": + off = dist_dets[obs.comm.group_rank].offset + nelem = dist_dets[obs.comm.group_rank].n_elem + slc.append(slice(off, off + nelem)) + shape.append(nelem) + elif comm_type == "column": + off = dist_samps[obs.comm.group_rank].offset + nelem = dist_samps[obs.comm.group_rank].n_elem + slc.append(slice(off, off + nelem)) + shape.append(nelem) + else: + slc.append(slice(0, full_shape[0])) + shape.append(full_shape[0]) + if len(full_shape) > 1: + for dim in full_shape[1:]: + slc.append(slice(0, dim)) + shape.append(dim) + slc = tuple(slc) + shape = tuple(shape) + + obs.shared.create_type(comm_type, field, shape, dtype) + shcomm = obs.shared[field].comm + + # Load the data on one process of the communicator if loading in parallel. + # If doing a serial load, the single reading process must communicate the + # data to the rank zero process on each object comm. + if (comm_type == "group") or (not serial_load): + # Load data on the rank zero process and set + data = None + if shcomm is None or shcomm.rank == 0: + msg = f"Shared field {field} ({comm_type})" + check_dataset_buffer_size(msg, slc, dtype, parallel) + data = np.array(ds[slc], copy=False).astype(obs.shared[field].dtype) + obs.shared[field].set(data, fromrank=0) + del data + else: + # More compilicated, since we have data distributed on along a process + # row or column, but are loading data on one process. First load full + # data on the reader. 
+ full_data = None + if obs.comm.group_rank == 0: + full_data = np.array(ds[:], copy=False) + + # Note: we could use a scatterv here instead of broadcasting the whole + # thing, if this ever becomes worth the additional book-keeping. + data = None + if comm_type == "row" and obs.comm_row_rank == 0: + # Distribute to the other rank zeros of the process rows + if obs.comm_col is not None: + full_data = obs.comm_col.bcast(full_data, root=0) + data = np.array(full_data[slc], dtype=obs.shared[field].dtype) + elif comm_type == "column" and obs.comm_col_rank == 0: + # Distribute to the other rank zeros of the process columns + if obs.comm_row is not None: + full_data = obs.comm_row.bcast(full_data, root=0) + data = np.array(full_data[slc], dtype=obs.shared[field].dtype) + del full_data + + # Now set the data within each row / column + obs.shared[field].set(data, fromrank=0) + del data + del ds + + if obs.comm.comm_group is not None: + obs.comm.comm_group.barrier() + log.verbose_rank( + f"{log_prefix} Shared finished {field} read in", + comm=obs.comm.comm_group, + timer=timer, + ) + + return + + +@function_timer +def load_hdf5_detdata(obs, hgrp, fields, log_prefix, parallel): + log = Logger.get() + + timer = Timer() + timer.start() + + # Get references to the distribution of detectors and samples + dist_samps = obs.dist.samps + dist_dets = obs.dist.det_indices + + # Data ranges for this process + samp_off = dist_samps[obs.comm.group_rank].offset + samp_nelem = dist_samps[obs.comm.group_rank].n_elem + det_off = dist_dets[obs.comm.group_rank].offset + det_nelem = dist_dets[obs.comm.group_rank].n_elem + + serial_load = False + if obs.comm.group_size > 1 and not parallel: + # We are doing a serial load, but we have multiple processes + # in the group. + serial_load = True + + field_list = None + if hgrp is not None: + field_list = list(hgrp.keys()) + if serial_load and obs.comm.comm_group is not None: + # Broadcast the field list + field_list = obs.comm.comm_group.bcast(field_list, root=0) + + for field in field_list: + if fields is not None and field not in fields: + continue + ds = None + units = None + full_shape = None + dtype = None + + compressed = False + cgrp = None + comp_params = dict() + comp_ranges = None + comp_offsets = None + comp_gains = None + comp_nbytes = None + + if hgrp is not None: + if isinstance(hgrp[field], h5py.Dataset): + # This is uncompressed data + ds = hgrp[field] + full_shape = ds.shape + dtype = ds.dtype + units = u.Unit(str(ds.attrs["units"])) + else: + # This must be a group of datasets containing compressed data + cgrp = hgrp[field] + ds = cgrp["compressed"] + units = u.Unit(str(ds.attrs["units"])) + comp_params["units"] = units + dtype = np.dtype(ds.attrs["dtype"]) + comp_params["dtype"] = dtype + det_shape = tuple(ast.literal_eval(ds.attrs["det_shape"])) + comp_params["det_shape"] = det_shape + comp_nbytes = len(ds) + + ds_ranges = cgrp["ranges"] + n_det = len(ds_ranges) + comp_ranges = np.zeros((n_det, 2), dtype=np.int64) + slc = (slice(0, n_det, 1), slice(0, 2, 1)) + ds_ranges.read_direct(comp_ranges, slc, slc) + + if "offsets" in cgrp: + ds_offsets = cgrp["offsets"] + comp_offsets = np.zeros(n_det, np.float64) + slc = (slice(0, n_det, 1),) + ds_offsets.read_direct(comp_offsets, slc, slc) + + if "gains" in cgrp: + ds_gains = cgrp["gains"] + comp_gains = np.zeros(n_det, np.float64) + slc = (slice(0, n_det, 1),) + ds_gains.read_direct(comp_gains, slc, slc) + + full_shape = (n_det,) + det_shape + comp_params["type"] = ds.attrs["comp_type"] + compressed = True + + 
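The attributes parsed above define the v1 on-disk layout for compressed detector data: one "compressed" byte blob plus a per-detector "ranges" dataset and optional "offsets" and "gains". A hedged sketch of walking that layout directly with h5py (the file and field names are hypothetical):

```python
import ast
import h5py

with h5py.File("obs_v1.h5", "r") as hf:          # hypothetical file
    cgrp = hf["detdata"]["signal"]               # hypothetical field
    blob = cgrp["compressed"][:]                 # concatenated byte streams
    ranges = cgrp["ranges"][:]                   # (n_det, 2) byte ranges
    det_shape = ast.literal_eval(cgrp["compressed"].attrs["det_shape"])
    for idet, (first, end) in enumerate(ranges):
        det_bytes = blob[first:end]              # one detector's stream
        # ... pass det_bytes (plus any offsets / gains entries) to the
        # decompressor matching the "comp_type" attribute
```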
if serial_load: + compressed = obs.comm.comm_group.bcast(compressed, root=0) + units = obs.comm.comm_group.bcast(units, root=0) + full_shape = obs.comm.comm_group.bcast(full_shape, root=0) + dtype = obs.comm.comm_group.bcast(dtype, root=0) + if compressed: + comp_params = obs.comm.comm_group.bcast(comp_params, root=0) + comp_nbytes = obs.comm.comm_group.bcast(comp_nbytes, root=0) + comp_ranges = obs.comm.comm_group.bcast(comp_ranges, root=0) + comp_offsets = obs.comm.comm_group.bcast(comp_offsets, root=0) + comp_gains = obs.comm.comm_group.bcast(comp_gains, root=0) + + sample_shape = None + if len(full_shape) > 2: + sample_shape = full_shape[2:] + + # All processes create their local detector data + obs.detdata.create( + field, + sample_shape=sample_shape, + dtype=dtype, + detectors=obs.local_detectors, + units=units, + ) + + # All processes independently load their data if running in parallel. + # If loading serially, one process reads and sends blocks of detector + # data. We can do this since we previously checked that for serial loads + # the data is distributed by detector. + + if serial_load: + if compressed: + # Each process has the information about all detectors' byte + # range and compression parameters. We just need to send + # the bytes. + for proc, detrange in enumerate(dist_dets): + first_det = detrange.offset + last_det = detrange.offset + detrange.n_elem - 1 + + first_byte = comp_ranges[first_det][0] + end_byte = comp_ranges[last_det][1] + local_ranges = np.zeros( + (detrange.n_elem, 2), + dtype=np.int64, + ) + for d in range(detrange.n_elem): + local_ranges[d, 0] = comp_ranges[first_det + d][0] - first_byte + local_ranges[d, 1] = comp_ranges[first_det + d][1] - first_byte + + if comp_offsets is not None: + comp_params["data_offsets"] = comp_offsets[ + first_det : last_det + 1 + ] + if comp_gains is not None: + comp_params["data_gains"] = comp_gains[first_det : last_det + 1] + + pbytes = end_byte - first_byte + pslice = (slice(0, pbytes, 1),) + hslice = (slice(first_byte, end_byte, 1),) + + if obs.comm.group_rank == 0: + buffer = np.zeros(pbytes, dtype=np.uint8) + ds.read_direct(buffer, hslice, pslice) + if proc == 0: + # Decompress our own data + decompress_detdata( + buffer, + local_ranges, + comp_params, + detdata=obs.detdata[field], + ) + else: + # Send + obs.comm.comm_group.Send(buffer, dest=proc, tag=proc) + elif obs.comm.group_rank == proc: + # Receive and decompress + buffer = np.zeros(pbytes, dtype=np.uint8) + obs.comm.comm_group.Recv(buffer, source=0, tag=proc) + decompress_detdata( + buffer, + local_ranges, + comp_params, + detdata=obs.detdata[field], + ) + else: + for proc, detrange in enumerate(dist_dets): + first_det = detrange.offset + end_det = detrange.offset + detrange.n_elem + n_local_det = detrange.n_elem + pslice = (slice(0, n_local_det, 1), slice(0, obs.n_all_samples, 1)) + hslice = ( + slice(first_det, end_det, 1), + slice(0, obs.n_all_samples, 1), + ) + if obs.comm.group_rank == 0: + buffer = np.zeros((n_local_det, obs.n_all_samples), dtype=dtype) + ds.read_direct(buffer, hslice, pslice) + if proc == 0: + # Copy data into place + obs.detdata[field][:] = buffer + else: + # Send + obs.comm.comm_group.Send( + buffer.reshape(-1), dest=proc, tag=proc + ) + del buffer + elif obs.comm.group_rank == proc: + # Receive and store + buffer = np.zeros((n_local_det, obs.n_all_samples), dtype=dtype) + obs.comm.comm_group.Recv(buffer, source=0, tag=proc) + obs.detdata[field][:] = buffer + del buffer + else: + if compressed: + last_det = det_off + det_nelem - 1 + 
first_byte = comp_ranges[det_off][0] + end_byte = comp_ranges[last_det][1] + local_ranges = np.zeros( + (det_nelem, 2), + dtype=np.int64, + ) + for d in range(det_nelem): + local_ranges[d, 0] = comp_ranges[det_off + d][0] - first_byte + local_ranges[d, 1] = comp_ranges[det_off + d][1] - first_byte + + if comp_offsets is not None: + comp_params["data_offsets"] = comp_offsets[det_off : last_det + 1] + if comp_gains is not None: + comp_params["data_gains"] = comp_gains[det_off : last_det + 1] + + pbytes = end_byte - first_byte + pslice = (slice(0, pbytes, 1),) + hslice = (slice(first_byte, end_byte, 1),) + buffer = np.zeros(pbytes, dtype=np.uint8) + ds.read_direct(buffer, hslice, pslice) + decompress_detdata( + buffer, + local_ranges, + comp_params, + detdata=obs.detdata[field], + ) + else: + detdata_slice = [slice(0, det_nelem, 1), slice(0, samp_nelem, 1)] + hf_slice = [ + slice(det_off, det_off + det_nelem, 1), + slice(samp_off, samp_off + samp_nelem, 1), + ] + if len(full_shape) > 2: + for dim in full_shape[2:]: + detdata_slice.append(slice(0, dim)) + hf_slice.append(slice(0, dim)) + detdata_slice = tuple(detdata_slice) + hf_slice = tuple(hf_slice) + msg = f"Detdata field {field} (group rank {obs.comm.group_rank})" + check_dataset_buffer_size(msg, hf_slice, dtype, parallel) + ds.read_direct(obs.detdata[field].data, hf_slice, detdata_slice) + + if obs.comm.comm_group is not None: + obs.comm.comm_group.barrier() + log.verbose_rank( + f"{log_prefix} Detdata finished {field} read in", + comm=obs.comm.comm_group, + timer=timer, + ) + del ds + + +@function_timer +def load_hdf5_intervals(obs, hgrp, times, fields, log_prefix, parallel): + log = Logger.get() + + timer = Timer() + timer.start() + + serial_load = False + if obs.comm.group_size > 1 and not parallel: + # We are doing a serial load, but we have multiple processes + # in the group. 
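The subtraction of first_byte above rebases global byte ranges into offsets within the process-local buffer. A small standalone illustration of the arithmetic:

```python
import numpy as np

comp_ranges = np.array([[0, 40], [40, 95], [95, 130]])  # global byte ranges
det_off, det_nelem = 1, 2             # this process owns detectors 1..2
first_byte = comp_ranges[det_off][0]                    # 40
end_byte = comp_ranges[det_off + det_nelem - 1][1]      # 130
local_ranges = comp_ranges[det_off : det_off + det_nelem] - first_byte
print(local_ranges)  # [[ 0 55], [55 90]] -> slices of a 90-byte local buffer
```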
+ serial_load = True + + field_list = None + if hgrp is not None: + field_list = list(hgrp.keys()) + if serial_load: + # Broadcast the field list + field_list = obs.comm.comm_group.bcast(field_list, root=0) + + if obs.comm.comm_group is not None: + obs.comm.comm_group.barrier() + + for field in field_list: + if fields is not None and field not in fields: + continue + # The dataset + ds = None + global_times = None + if obs.comm.group_rank == 0: + ds = hgrp[field] + global_times = np.transpose(ds[:]) + + obs.intervals.create(field, global_times, times, fromrank=0) + del ds + + if obs.comm.comm_group is not None: + obs.comm.comm_group.barrier() + log.verbose_rank( + f"{log_prefix} Intervals finished {field} read in", + comm=obs.comm.comm_group, + timer=timer, + ) + + +def load_instrument(parent_group, detectors=None, file_det_sets=None, comm=None): + """Load instrument information from an HDF5 group.""" + telescope = None + session = None + new_detsets = file_det_sets + if parent_group is not None: + inst_group = parent_group["instrument"] + toast_version = int(inst_group.attrs["toast_format_version"]) + if toast_version != 1: + msg = "Version 1 of load_instrument() called on file format " + msg += f"version {toast_version}" + raise RuntimeError(msg) + telescope_name = str(inst_group.attrs["telescope_name"]) + telescope_uid = int(inst_group.attrs["telescope_uid"]) + telescope_class = import_from_name(str(inst_group.attrs["telescope_class"])) + + site_name = str(inst_group.attrs["site_name"]) + site_uid = int(inst_group.attrs["site_uid"]) + site_class = import_from_name(str(inst_group.attrs["site_class"])) + + site = None + if "site_alt_m" in inst_group.attrs: + # This is a ground based site + site_alt_m = float(inst_group.attrs["site_alt_m"]) + site_lat_deg = float(inst_group.attrs["site_lat_deg"]) + site_lon_deg = float(inst_group.attrs["site_lon_deg"]) + + weather = None + if "site_weather_name" in inst_group.attrs: + weather_name = str(inst_group.attrs["site_weather_name"]) + weather_realization = int(inst_group.attrs["site_weather_realization"]) + weather_max_pwv = None + if inst_group.attrs["site_weather_max_pwv"] != "NONE": + weather_max_pwv = u.Quantity( + float(inst_group.attrs["site_weather_max_pwv"]), u.mm + ) + weather_time = datetime.fromtimestamp( + float(inst_group.attrs["site_weather_time"]), tz=timezone.utc + ) + weather_median = bool(inst_group.attrs["site_weather_median"]) + weather = SimWeather( + time=weather_time, + name=weather_name, + site_uid=site_uid, + realization=weather_realization, + max_pwv=weather_max_pwv, + median_weather=weather_median, + ) + elif "site_weather_time" in inst_group.attrs: + # This is a generic weather object + props = dict() + props["time"] = datetime.fromtimestamp( + float(inst_group.attrs["site_weather_time"]), tz=timezone.utc + ) + for attr_name in [ + "ice_water", + "liquid_water", + "pwv", + "humidity", + "surface_pressure", + "surface_temperature", + "air_temperature", + "west_wind", + "south_wind", + ]: + file_attr = f"site_weather_{attr_name}" + props[attr_name] = u.Quantity(inst_group.attrs[file_attr]) + weather = Weather(**props) + site = site_class( + site_name, + site_lat_deg * u.degree, + site_lon_deg * u.degree, + site_alt_m * u.meter, + uid=site_uid, + weather=weather, + ) + else: + site = site_class(site_name, uid=site_uid) + + session = None + if "session_name" in inst_group.attrs: + session_name = str(inst_group.attrs["session_name"]) + session_uid = int(inst_group.attrs["session_uid"]) + session_start = 
inst_group.attrs["session_start"] + if str(session_start) == "NONE": + session_start = None + else: + session_start = datetime.fromtimestamp( + float(inst_group.attrs["session_start"]), + tz=timezone.utc, + ) + session_end = inst_group.attrs["session_end"] + if str(session_end) == "NONE": + session_end = None + else: + session_end = datetime.fromtimestamp( + float(inst_group.attrs["session_end"]), + tz=timezone.utc, + ) + session_class = import_from_name(str(inst_group.attrs["session_class"])) + session = session_class( + session_name, uid=session_uid, start=session_start, end=session_end + ) + + raw_focalplane = Focalplane() + raw_focalplane.load_hdf5(inst_group, comm=None) + + if detectors is None: + focalplane = raw_focalplane + else: + # Slice focalplane to include only these detectors. Also modify + # detector sets to include only these detectors. + keep = set(detectors) + fp_rows = [x["name"] in keep for x in raw_focalplane.detector_data] + fp_data = QTable(raw_focalplane.detector_data[fp_rows]) + + focalplane = Focalplane( + detector_data=fp_data, + sample_rate=raw_focalplane.sample_rate, + field_of_view=raw_focalplane.field_of_view, + ) + new_detsets = list() + if isinstance(file_det_sets, list): + # List of lists + for ds in file_det_sets: + new_ds = list() + for d in ds: + if d in keep: + new_ds.append(d) + if len(new_ds) > 0: + new_detsets.append(new_ds) + else: + # Must be a dictionary + for dskey, ds in file_det_sets.items(): + new_ds = list() + for d in ds: + if d in keep: + new_ds.append(d) + if len(new_ds) > 0: + new_detsets.append(new_ds) + + telescope = telescope_class( + telescope_name, uid=telescope_uid, focalplane=focalplane, site=site + ) + del inst_group + return telescope, session, new_detsets + + +@function_timer +def load_hdf5_obs_meta( + comm, + hgroup, + parallel=False, + log_prefix="", + meta=None, + detectors=None, + process_rows=None, +): + log = Logger.get() + rank = comm.group_rank + nproc = comm.group_size + + timer = Timer() + timer.start() + + telescope = None + obs_samples = None + obs_name = None + obs_uid = None + session = None + obs_det_sets = None + obs_sample_sets = None + + if hgroup is not None: + # Observation properties + obs_name = str(hgroup.attrs["observation_name"]) + obs_uid = int(hgroup.attrs["observation_uid"]) + obs_dets = json.loads(hgroup.attrs["observation_detectors"]) + file_det_sets = None + if hgroup.attrs["observation_detector_sets"] != "NONE": + file_det_sets = json.loads(hgroup.attrs["observation_detector_sets"]) + obs_samples = int(hgroup.attrs["observation_samples"]) + obs_sample_sets = None + if hgroup.attrs["observation_sample_sets"] != "NONE": + obs_sample_sets = [ + [int(x) for x in y] + for y in json.loads(hgroup.attrs["observation_sample_sets"]) + ] + + # Instrument properties + telescope, session, obs_det_sets = load_instrument( + hgroup, detectors=detectors, file_det_sets=file_det_sets, comm=None + ) + + log.debug_rank( + f"{log_prefix} Loaded instrument properties in", + comm=comm.comm_group, + timer=timer, + ) + + # Broadcast the observation properties if needed + if not parallel and nproc > 1: + telescope = comm.comm_group.bcast(telescope, root=0) + obs_samples = comm.comm_group.bcast(obs_samples, root=0) + obs_name = comm.comm_group.bcast(obs_name, root=0) + obs_uid = comm.comm_group.bcast(obs_uid, root=0) + session = comm.comm_group.bcast(session, root=0) + obs_det_sets = comm.comm_group.bcast(obs_det_sets, root=0) + obs_sample_sets = comm.comm_group.bcast(obs_sample_sets, root=0) + + # Create the observation + 
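The detector subsetting above masks the focalplane table and rebuilds the detector sets before the observation is constructed below. A standalone sketch of the masking step with a toy table (column names are hypothetical):

```python
import numpy as np
from astropy.table import QTable

fp = QTable({"name": ["d00", "d01", "d02"], "pol": ["A", "B", "A"]})
keep = {"d00", "d02"}
mask = np.array([name in keep for name in fp["name"]])
fp_subset = QTable(fp[mask])     # rows for the requested detectors only
print(list(fp_subset["name"]))   # ['d00', 'd02']
```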
obs = Observation( + comm, + telescope, + obs_samples, + name=obs_name, + uid=obs_uid, + session=session, + detector_sets=obs_det_sets, + sample_sets=obs_sample_sets, + process_rows=process_rows, + ) + + # Load observation metadata. This is complicated because a subset of processes + # may have the file open, but the object loader may need the whole communicator + # to load the object. First we load all simple metadata and record the names + # of more complicated objects to load. Then we ensure that all processes have + # this information. Then all processes load more complicated objects + # collectively. + + meta_load = dict() + + if hgroup is not None: + meta_group = hgroup["metadata"] + for obj_name in meta_group.keys(): + obj = meta_group[obj_name] + if meta is not None and obj_name not in meta: + # The user restricted the list of things to load, and this is + # not in the list. + continue + if isinstance(obj, h5py.Group): + # This might be an object to restore + if "class" in obj.attrs: + objclass = import_from_name(obj.attrs["class"]) + test_obj = objclass() + if hasattr(test_obj, "load_hdf5"): + # Record this in the dictionary of things to load in the + # next step. + meta_load[obj_name] = test_obj + else: + msg = f"metadata object group '{obj_name}' has class " + msg += f"{obj.attrs['class']}, but instantiated " + msg += "object does not have a load_hdf5() method" + log.error(msg) + else: + # This must be some other custom user dataset. Ignore it. + pass + else: + # Array-like dataset that we can load + if "units" in obj.attrs: + # This array is a quantity + meta_load[obj_name] = u.Quantity( + np.array(obj), unit=u.Unit(obj.attrs["units"]) + ) + else: + meta_load[obj_name] = np.array(obj) + del obj + + # Now extract attributes (scalars) + units_pat = re.compile(r"(.*)_units") + for k, v in meta_group.attrs.items(): + if meta is not None and k not in meta: + continue + if units_pat.match(k) is not None: + # unit field, skip + continue + # Check for quantity + unit_name = f"{k}_units" + if unit_name in meta_group.attrs: + meta_load[k] = u.Quantity(v, unit=u.Unit(meta_group.attrs[unit_name])) + else: + meta_load[k] = v + del meta_group + + # Communicate the partial metadata + if not parallel and nproc > 1: + meta_load = comm.comm_group.bcast(meta_load, root=0) + + # Now load any remaining metadata objects + + meta_group = None + if hgroup is not None: + meta_group = hgroup["metadata"] + for meta_key in list(meta_load.keys()): + if hasattr(meta_load[meta_key], "load_hdf5"): + handle = None + if hgroup is not None: + handle = meta_group[meta_key] + meta_load[meta_key].load_hdf5(handle, obs) + del handle + del meta_group + + # Assign the internal observation dictionary + obs._internal = meta_load + + log.debug_rank( + f"{log_prefix} Finished other metadata in", + comm=comm.comm_group, + timer=timer, + ) + return obs + + +@function_timer +def load_hdf5( + path, + comm, + process_rows=None, + meta=None, + detdata=None, + shared=None, + intervals=None, + detectors=None, + force_serial=False, +): + """Load an HDF5 observation. + + By default, all detdata, shared, intervals, and noise models are loaded into + memory. A subset of objects may be specified with a list of names + passed to the corresponding function arguments. + + Args: + path (str): The path to the file on disk. + comm (toast.Comm): The toast communicator to use. + process_rows (int): (Optional) The size of the rectangular process grid + in the detector direction. This number must evenly divide into the size of + comm. 
If not specified, defaults to the size of the communicator. + meta (list): Only load this list of metadata objects. + detdata (list): Only load this list of detdata objects. + shared (list): Only load this list of shared objects. + intervals (list): Only load this list of intervals objects. + detectors (list): Only load this list of detectors from all detector data + objects. + force_serial (bool): If True, do not use HDF5 parallel support, + even if it is available. + + Returns: + (Observation): The constructed observation. + + """ + log = Logger.get() + env = Environment.get() + + rank = comm.group_rank + nproc = comm.group_size + if nproc == 1: + # Force serial usage in this case, to avoid any MPI overhead + force_serial = True + + timer = Timer() + timer.start() + log_prefix = f"HDF5 load {os.path.basename(path)}: " + + # Open the file and get the root group. + hf = None + hfgroup = None + + parallel, _, _ = hdf5_config(comm=comm.comm_group, force_serial=force_serial) + if ( + (not parallel) + and (process_rows is not None) + and (process_rows != comm.group_size) + ): + msg = "When loading observations with serial HDF5, process_rows must equal " + msg += "the group size" + log.error(msg) + raise RuntimeError(msg) + + hf = hdf5_open(path, "r", comm=comm.comm_group, force_serial=force_serial) + hgroup = hf + + log.debug_rank( + f"{log_prefix} Opened file {path} in", + comm=comm.comm_group, + timer=timer, + ) + + # The rank zero process gets the file format version and communicates to all + # processes in the group, regardless of whether they are participating in + # the load. + file_version = None + if rank == 0: + # Data format version check + file_version = int(hgroup.attrs["toast_format_version"]) + if comm.comm_group is not None: + file_version = comm.comm_group.bcast(file_version, root=0) + + # As the file format evolves, we might close the file at this point and call + # an earlier version of the loader. However, v0 and v1 only differ in the + # detector data loading, so we can just branch at that point. + # + # Example for future: + # if file_version == 12345: + # # Close file and call older version + # del hgroup + # if hf is not None: + # hf.close() + # del hf + # return load_hdf5_v12345(...) 
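The metadata pass earlier reconstructs scalar Quantities from paired attributes, "&lt;key&gt;" and "&lt;key&gt;_units". A round-trip sketch of that convention (the file and key names are hypothetical):

```python
import h5py
from astropy import units as u

with h5py.File("meta_demo.h5", "w") as hf:
    grp = hf.create_group("metadata")
    grp.attrs["scan_el"] = 1.5
    grp.attrs["scan_el_units"] = u.degree.to_string()

with h5py.File("meta_demo.h5", "r") as hf:
    grp = hf["metadata"]
    val = u.Quantity(
        grp.attrs["scan_el"], unit=u.Unit(grp.attrs["scan_el_units"])
    )
    print(val)  # 1.5 deg
```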
+    #
+    if file_version == 0 and detectors is not None:
+        msg = f"HDF5 file '{path}' uses format v0 which does not support loading"
+        msg += " a subset of detectors"
+        log.error(msg)
+        raise RuntimeError(msg)
+    if file_version > 1:
+        msg = f"HDF5 file '{path}' using unsupported data format {file_version}"
+        log.error(msg)
+        raise RuntimeError(msg)
+
+    # Load all metadata into an empty Observation
+    obs = load_hdf5_obs_meta(
+        comm,
+        hgroup,
+        parallel=parallel,
+        log_prefix="",
+        meta=meta,
+        detectors=detectors,
+        process_rows=process_rows,
+    )
+
+    # Load shared data
+
+    shared_group = None
+    if hgroup is not None:
+        shared_group = hgroup["shared"]
+    load_hdf5_shared(obs, shared_group, shared, log_prefix, parallel)
+    del shared_group
+    log.debug_rank(
+        f"{log_prefix} Finished shared data in",
+        comm=comm.comm_group,
+        timer=timer,
+    )
+
+    # Load intervals
+
+    intervals_group = None
+    intervals_times = None
+    if hgroup is not None:
+        intervals_group = hgroup["intervals"]
+        intervals_times = intervals_group.attrs["times"]
+    if not parallel and nproc > 1:
+        intervals_times = comm.comm_group.bcast(intervals_times, root=0)
+    load_hdf5_intervals(
+        obs,
+        intervals_group,
+        obs.shared[intervals_times],
+        intervals,
+        log_prefix,
+        parallel,
+    )
+    del intervals_group
+    log.debug_rank(
+        f"{log_prefix} Finished intervals in",
+        comm=comm.comm_group,
+        timer=timer,
+    )
+
+    # Load detector data
+
+    detdata_group = None
+    if hgroup is not None:
+        detdata_group = hgroup["detdata"]
+    if file_version == 0:
+        load_hdf5_detdata_v0(obs, detdata_group, detdata, log_prefix, parallel)
+    else:
+        load_hdf5_detdata(obs, detdata_group, detdata, log_prefix, parallel)
+    del detdata_group
+    log.debug_rank(
+        f"{log_prefix} Finished detector data in",
+        comm=comm.comm_group,
+        timer=timer,
+    )
+
+    # Clean up
+    del hgroup
+    if hf is not None:
+        hf.close()
+    del hf
+
+    return obs
+
+
+def load_instrument_file(path, detectors=None, obs_det_sets=None, comm=None):
+    """Load instrument information from an HDF5 file.
+
+    This function loads the telescope and session serially on one process.
+    It supports including a relative internal path inside the HDF5 file by separating
+    the filesystem path from the internal path with a colon.  For example:
+
+        path="/path/to/file.h5:/obs1"
+
+    The internal path should be to the *parent* group of the "instrument" group.
+ + """ + parts = path.split(":") + if len(parts) == 1: + file = parts[0] + internal = "/" + else: + file = parts[0] + internal = parts[1] + grouptree = internal.split(os.path.sep) + with h5py.File(file, "r") as hf: + parent = hf + for grp in grouptree: + if grp == "": + continue + parent = parent[grp] + telescope, session, _ = load_instrument(parent) + return telescope, session diff --git a/src/toast/io/observation_hdf_save.py b/src/toast/io/observation_hdf_save.py index 611615037..f0e452e07 100644 --- a/src/toast/io/observation_hdf_save.py +++ b/src/toast/io/observation_hdf_save.py @@ -9,14 +9,13 @@ import h5py import numpy as np from astropy import units as u -from astropy.table import Table + +import flacarray from ..instrument import GroundSite -from ..mpi import MPI from ..observation import default_values as defaults -from ..observation_data import DetectorData from ..observation_dist import global_interval_times -from ..timing import GlobalTimers, Timer, function_timer +from ..timing import Timer, function_timer from ..utils import ( Environment, Logger, @@ -24,8 +23,7 @@ hdf5_use_serial, object_fullname, ) -from .compression import compress_detdata, decompress_detdata -from .hdf_utils import check_dataset_buffer_size, hdf5_open +from .hdf_utils import check_dataset_buffer_size, hdf5_open, save_meta_object @function_timer @@ -188,7 +186,7 @@ def save_hdf5_shared(obs, hgrp, fields, log_prefix): @function_timer -def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place=False): +def save_hdf5_detdata(obs, hgrp, fields, log_prefix, in_place=False): log = Logger.get() timer = Timer() @@ -208,6 +206,14 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place # Are we doing serial I/O? use_serial = hdf5_use_serial(hgrp, comm) + # Valid flacarray dtypes + flac_dtypes = [ + np.dtype(np.float64), + np.dtype(np.float32), + np.dtype(np.int64), + np.dtype(np.int32), + ] + for ifield, (field, fieldcomp) in enumerate(fields): tag_offset = (obs.comm.group * 1000 + ifield) * obs.comm.group_size if field not in obs.detdata: @@ -226,27 +232,26 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place # by detector, since we will compress each detector independently. if fieldcomp is not None and proc_cols != 1: msg = f"Detector data '{field}' compression requested, but data for " - msg += f"individual channels is split between processes." + msg += "individual channels is split between processes." raise RuntimeError(msg) # Compute properties of the full set of data across the observation ddtype = local_data.dtype + if ddtype in flac_dtypes: + hdtype = ddtype + else: + # Cast flag bytes to int for compression + hdtype = np.int32 dshape = (len(obs.all_detectors), obs.n_all_samples) dvalshape = None if len(local_data.detector_shape) > 1: dvalshape = local_data.detector_shape[1:] dshape += dvalshape + local_n_det = local_data.shape[0] - fdtype = ddtype - if ddtype.char == "d" and use_float32: - # We are truncating to single precision - fdtype = np.dtype(np.float32) - - # If we are using our own internal compression, each process compresses their - # local data and sends it to one process for insertion into the overall blob - # of bytes. + if fieldcomp is None: + # We are not using compression. 
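The hdtype logic above widens any dtype outside the flacarray whitelist (for example uint8 flag bytes) to int32 before compression, which is lossless for byte-valued data. A standalone illustration of the cast:

```python
import numpy as np

flac_dtypes = [np.dtype(t) for t in (np.float64, np.float32, np.int64, np.int32)]
flags = np.array([[0, 1, 0, 3]], dtype=np.uint8)   # per-detector flags
hdtype = flags.dtype if flags.dtype in flac_dtypes else np.int32
print(flags.astype(hdtype).dtype)  # int32, exactly representable
```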
- if fieldcomp is None or "type_hdf5" in fieldcomp: # The buffer class to use for allocating receive buffers bufclass, _ = dtype_to_aligned(ddtype) @@ -254,14 +259,7 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place hdata = None if hgrp is not None: # This process is participating. - # - # Future NOTE: Here is where we could extract the "type_hdf5" parameter - # from the dictionary and create the dataset with appropriate compression - # and chunking settings. Detector data would then be written to the - # dataset in the usual way, with compression "under the hood" done by - # HDF5. - # - hdata = hgrp.create_dataset(field, dshape, dtype=fdtype) + hdata = hgrp.create_dataset(field, dshape, dtype=ddtype) hdata.attrs["units"] = local_data.units.to_string() if use_serial: @@ -290,7 +288,7 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place # Root process writes local data if rank == 0: hdata.write_direct( - local_data.data.astype(fdtype), detdata_slice, hf_slice + local_data.data.astype(ddtype), detdata_slice, hf_slice ) elif proc == rank: # We are sending @@ -300,7 +298,7 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place recv = bufclass(nflat) comm.Recv(recv, source=proc, tag=tag_offset + proc) hdata.write_direct( - recv.array().astype(fdtype).reshape(shp), + recv.array().astype(ddtype).reshape(shp), detdata_slice, hf_slice, ) @@ -328,7 +326,7 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place with hdata.collective: hdata.write_direct( - local_data.data.astype(fdtype), detdata_slice, hf_slice + local_data.data.astype(ddtype), detdata_slice, hf_slice ) del hdata log.verbose_rank( @@ -337,205 +335,68 @@ def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place timer=timer, ) else: - # Compress our local detector data. The starting dictionary of properties - # is passed in and additional metadata is appended. - if ddtype.char == "d" and use_float32: - temp_detdata = DetectorData( - obs.detdata[field].detectors, - obs.detdata[field].detector_shape, - np.float32, - units=obs.detdata[field].units, - ) - temp_detdata.data[:] = obs.detdata[field].data.astype(np.float32) - comp_bytes, comp_ranges, comp_props = compress_detdata( - temp_detdata, fieldcomp + fgrp = None + if hgrp is not None: + # This process is participating. Create a subgroup for this + # field. 
+                fgrp = hgrp.create_group(field)
+                # Add attributes for the original data properties
+                fgrp.attrs["units"] = local_data.units.to_string()
+                fgrp.attrs["dtype"] = ddtype.char
+                fgrp.attrs["detector_shape"] = str(
+                    [int(x) for x in local_data.detector_shape]
                 )
-            if in_place:
-                # Decompress
-                decompress_detdata(
-                    comp_bytes, comp_ranges, comp_props, detdata=temp_detdata
-                )
-                # upcast back to float64 and and overwrite the original detector data
-                obs.detdata[field].data[:] = temp_detdata.data[:]
-            del temp_detdata
+            if "level" in fieldcomp:
+                level = int(fieldcomp["level"])
             else:
-            temp_detdata = None
-            comp_bytes, comp_ranges, comp_props = compress_detdata(
-                obs.detdata[field], fieldcomp
-            )
-            if in_place:
-                # Decompress and overwrite the original detector data
-                decompress_detdata(
-                    comp_bytes, comp_ranges, comp_props, detdata=obs.detdata[field]
+                level = 5
+            quanta = None
+            precision = None
+
+            if ddtype.char == "d" or ddtype.char == "f":
+                # Floating point type
+                if "quanta" in fieldcomp:
+                    quanta = float(fieldcomp["quanta"])
+                elif "precision" in fieldcomp:
+                    precision = float(fieldcomp["precision"])
+                if quanta is None and precision is None:
+                    msg = "When compressing floating point data, you"
+                    msg += " must specify the quanta or precision."
+                    raise RuntimeError(msg)
+
+            # We flatten all the per-sample data when compressing
+            flacarray.hdf5.write_array(
+                local_data.data.astype(hdtype).reshape((local_n_det, -1)),
+                fgrp,
+                level=level,
+                quanta=quanta,
+                precision=precision,
+                mpi_comm=comm,
+                use_threads=False,
+            )
+            if in_place:
+                # Decompress data back into original location, to capture
+                # any truncation effects.
+                det_off = dist_dets[obs.comm.group_rank].offset
+                det_nelem = dist_dets[obs.comm.group_rank].n_elem
+                mpi_dist = [(x.offset, x.offset + x.n_elem) for x in dist_dets]
+
+                flcdata = (
+                    flacarray.hdf5.read_array(
+                        fgrp,
+                        keep=None,
+                        stream_slice=None,
+                        keep_indices=False,
+                        mpi_comm=comm,
+                        mpi_dist=mpi_dist,
+                        use_threads=False,
                     )
-
-            # Extract per-detector quantities for communicating / writing later
-            comp_data_offsets = None
-            if "data_offsets" in comp_props:
-                comp_data_offsets = comp_props["data_offsets"]
-            comp_data_gains = None
-            if "data_gains" in comp_props:
-                comp_data_gains = comp_props["data_gains"]
-
-            # Get the total number of bytes
-            n_local_bytes = len(comp_bytes)
-            if comm is None:
-                n_all_bytes = n_local_bytes
-            else:
-                n_all_bytes = comm.allreduce(n_local_bytes, op=MPI.SUM)
-
-            # Create the datasets
-            hdata_bytes = None
-            hdata_ranges = None
-            hdata_offsets = None
-            hdata_gains = None
-            cgrp = None
-            if hgrp is not None:
-                # This process is participating.
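A hedged serial sketch of the flacarray round trip used above, with the same keyword arguments that appear in this diff; the file name, group name, and quanta value are illustrative assumptions:

```python
import flacarray.hdf5
import h5py
import numpy as np

rng = np.random.default_rng(42)
data = rng.normal(size=(4, 1024)).astype(np.float64)

with h5py.File("flac_demo.h5", "w") as hf:
    grp = hf.create_group("signal")
    flacarray.hdf5.write_array(
        data, grp, level=5, quanta=1.0e-6, precision=None,
        mpi_comm=None, use_threads=False,
    )

with h5py.File("flac_demo.h5", "r") as hf:
    restored = flacarray.hdf5.read_array(
        hf["signal"], keep=None, stream_slice=None, keep_indices=False,
        mpi_comm=None, mpi_dist=None, use_threads=False,
    )
# restored should match data to within the digitization error set by quanta
print(np.max(np.abs(restored - data)))
```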
- cgrp = hgrp.create_group(field) - hdata_bytes = cgrp.create_dataset( - "compressed", n_all_bytes, dtype=np.uint8 - ) - hdata_bytes.attrs["units"] = local_data.units.to_string() - # Write common properties of many compression schemes - hdata_bytes.attrs["dtype"] = str(comp_props["dtype"]) - hdata_bytes.attrs["det_shape"] = str( - tuple([int(x) for x in comp_props["det_shape"]]) + .astype(ddtype) + .reshape(local_data.shape) ) - hdata_bytes.attrs["comp_type"] = comp_props["type"] - if "level" in comp_props: - hdata_bytes.attrs["comp_level"] = comp_props["level"] - hdata_ranges = cgrp.create_dataset( - "ranges", - (len(obs.all_detectors), 2), - dtype=np.int64, - ) - if comp_data_offsets is not None: - hdata_offsets = cgrp.create_dataset( - "offsets", - (len(obs.all_detectors),), - dtype=np.float64, - ) - if comp_data_gains is not None: - hdata_gains = cgrp.create_dataset( - "gains", - (len(obs.all_detectors),), - dtype=np.float64, - ) - - # Send data to rank zero of the group for writing. - hf_det = 0 - hf_bytes = 0 - for proc in range(nproc): - if rank == 0: - if proc == 0: - # Root process writes local data - det_ranges = np.array( - [(x[0] + hf_bytes, x[1] + hf_bytes) for x in comp_ranges], - dtype=np.int64, - ).reshape((-1, 2)) - dslc = ( - slice(0, len(det_ranges), 1), - slice(0, 2, 1), - ) - hslc = ( - slice(hf_det, hf_det + len(det_ranges), 1), - slice(0, 2, 1), - ) - hdata_ranges.write_direct(det_ranges, dslc, hslc) - - dslc = (slice(0, n_local_bytes, 1),) - hslc = (slice(hf_bytes, hf_bytes + n_local_bytes, 1),) - hdata_bytes.write_direct(comp_bytes, dslc, hslc) - - dslc = (slice(0, len(comp_ranges), 1),) - hslc = (slice(hf_det, hf_det + len(comp_ranges), 1),) - if comp_data_offsets is not None: - hdata_offsets.write_direct(comp_data_offsets, dslc, hslc) - if comp_data_gains is not None: - hdata_gains.write_direct(comp_data_gains, dslc, hslc) - - hf_bytes += n_local_bytes - hf_det += len(comp_ranges) - else: - # Receive data and write - n_recv_bytes = comm.recv( - source=proc, tag=tag_offset + 10 * proc - ) - n_recv_dets = comm.recv( - source=proc, tag=tag_offset + 10 * proc + 1 - ) - - recv_bytes = np.zeros(n_recv_bytes, dtype=np.uint8) - comm.Recv( - recv_bytes, source=proc, tag=tag_offset + 10 * proc + 2 - ) - dslc = (slice(0, n_recv_bytes, 1),) - hslc = (slice(hf_bytes, hf_bytes + n_recv_bytes, 1),) - hdata_bytes.write_direct(recv_bytes, dslc, hslc) - del recv_bytes - - recv_ranges = np.zeros(n_recv_dets * 2, dtype=np.int64) - comm.Recv( - recv_ranges, - source=proc, - tag=tag_offset + 10 * proc + 3, - ) - recv_ranges[:] += hf_bytes - dslc = ( - slice(0, n_recv_dets, 1), - slice(0, 2, 1), - ) - hslc = ( - slice(hf_det, hf_det + n_recv_dets, 1), - slice(0, 2, 1), - ) - hdata_ranges.write_direct( - recv_ranges.reshape((n_recv_dets, 2)), dslc, hslc - ) - del recv_ranges - - recv_buf = np.zeros(n_recv_dets, dtype=np.float64) - dslc = (slice(0, n_recv_dets, 1),) - hslc = (slice(hf_det, hf_det + n_recv_dets, 1),) - if comp_data_offsets is not None: - comm.Recv( - recv_buf, source=proc, tag=tag_offset + 10 * proc + 4 - ) - hdata_offsets.write_direct(recv_buf, dslc, hslc) - if comp_data_gains is not None: - comm.Recv( - recv_buf, source=proc, tag=tag_offset + 10 * proc + 5 - ) - hdata_gains.write_direct(recv_buf, dslc, hslc) - del recv_buf - - hf_bytes += n_recv_bytes - hf_det += n_recv_dets - - elif proc == rank: - # We are sending. 
First send the number of bytes and detectors - det_ranges = np.zeros( - (len(comp_ranges), 2), - dtype=np.int64, - ) - for d in range(len(comp_ranges)): - det_ranges[d, :] = comp_ranges[d] - comm.send(n_local_bytes, dest=0, tag=tag_offset + 10 * proc) - comm.send(len(det_ranges), dest=0, tag=tag_offset + 10 * proc + 1) - - comm.Send(comp_bytes, dest=0, tag=tag_offset + 10 * proc + 2) - comm.Send( - det_ranges.flatten(), dest=0, tag=tag_offset + 10 * proc + 3 - ) - if comp_data_offsets is not None: - comm.Send( - comp_data_offsets, dest=0, tag=tag_offset + 10 * proc + 4 - ) - if comp_data_gains is not None: - comm.Send( - comp_data_gains, dest=0, tag=tag_offset + 10 * proc + 5 - ) + for idet, det in enumerate(local_data.detectors): + local_data.data[idet] = flcdata[idet] + del flcdata @function_timer @@ -599,6 +460,7 @@ def save_instrument(parent_group, telescope, comm=None, session=None): if parent_group is not None: # Instrument properties inst_group = parent_group.create_group("instrument") + inst_group.attrs["toast_format_version"] = 2 inst_group.attrs["telescope_name"] = telescope.name inst_group.attrs["telescope_class"] = object_fullname(telescope.__class__) inst_group.attrs["telescope_uid"] = telescope.uid @@ -681,7 +543,6 @@ def save_hdf5( config=None, times=defaults.times, force_serial=False, - detdata_float32=False, detdata_in_place=False, ): """Save an observation to HDF5. @@ -719,8 +580,6 @@ def save_hdf5( times (str): The name of the shared timestamp field. force_serial (bool): If True, do not use HDF5 parallel support, even if it is available. - detdata_float32 (bool): If True, cast any float64 detector fields - to float32 on write. Integer detdata is not affected. detdata_in_place (bool): If True, input detdata will be replaced with a compressed and decompressed version that includes the digitization error. @@ -759,6 +618,18 @@ def save_hdf5( hf = hdf5_open(hfpath_temp, "w", comm=comm, force_serial=force_serial) hgroup = hf + # Gather the local detector flags to all writing processes + if comm is None: + all_det_flags = obs.local_detector_flags + else: + proc_det_flags = comm.gather(obs.local_detector_flags, root=0) + all_det_flags = None + if rank == 0: + all_det_flags = dict() + for pflags in proc_det_flags: + all_det_flags.update(pflags) + all_det_flags = comm.bcast(all_det_flags, root=0) + shared_group = None detdata_group = None intervals_group = None @@ -768,7 +639,7 @@ def save_hdf5( hgroup.attrs["toast_version"] = env.version() if config is not None: hgroup.attrs["job_config"] = json.dumps(config) - hgroup.attrs["toast_format_version"] = 1 + hgroup.attrs["toast_format_version"] = 2 # Observation properties hgroup.attrs["observation_name"] = obs.name @@ -788,6 +659,9 @@ def save_hdf5( hgroup.attrs["observation_samples"] = obs.n_all_samples hgroup.attrs["observation_sample_sets"] = obs_all_samp_sets + # Per detector flags. 
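Just below, the merged per-detector flags are serialized as a single JSON attribute; each process contributes the flags of its local detectors. A standalone sketch of the merge and round trip:

```python
import json

# One dict per process, as produced by comm.gather in the code below.
proc_det_flags = [{"d00": 0, "d01": 1}, {"d02": 2}]
all_det_flags = dict()
for pflags in proc_det_flags:
    all_det_flags.update(pflags)
encoded = json.dumps(all_det_flags)          # written as an HDF5 attribute
assert json.loads(encoded) == all_det_flags  # string keys round-trip losslessly
```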
+        hgroup.attrs["observation_detector_flags"] = json.dumps(all_det_flags)
+
     log.verbose_rank(
         f"{log_prefix} Wrote observation attributes in",
         comm=comm,
@@ -809,9 +683,12 @@
     )
 
     meta_group = None
+    meta_other = None
     if hgroup is not None:
         meta_group = hgroup.create_group("metadata")
+        meta_other = meta_group.create_group("other")
 
+    # Process all metadata in the observation dictionary
     for k, v in obs.items():
         if meta is not None and k not in meta:
             continue
@@ -822,36 +699,27 @@
             kgroup.attrs["class"] = object_fullname(v.__class__)
             v.save_hdf5(kgroup, obs)
             del kgroup
-        elif isinstance(v, u.Quantity):
-            if isinstance(v.value, np.ndarray):
-                # Array quantity
-                if meta_group is not None:
-                    qdata = meta_group.create_dataset(k, data=v.value)
-                    qdata.attrs["units"] = v.unit.to_string()
-                    del qdata
-            else:
-                # Must be a scalar
-                if meta_group is not None:
-                    meta_group.attrs[f"{k}"] = v.value
-                    meta_group.attrs[f"{k}_units"] = v.unit.to_string()
-        elif isinstance(v, np.ndarray):
-            if meta_group is not None:
-                marr = meta_group.create_dataset(k, data=v)
-                del marr
-        elif meta_group is not None:
-            try:
-                if isinstance(v, u.Quantity):
-                    meta_group.attrs[k] = v.value
-                else:
-                    meta_group.attrs[k] = v
-            except (ValueError, TypeError) as e:
-                msg = f"Failed to store obs key '{k}' = '{v}' as an attribute ({e})."
-                msg += " Try casting it to a supported type when storing in the "
-                msg += "observation dictionary or implement save_hdf5() and "
-                msg += "load_hdf5() methods."
-                log.verbose(msg)
+        else:
+            # Process this object recursively
+            save_meta_object(meta_other, k, v)
+    del meta_other
     del meta_group
 
+    # Now pass through observation attributes and look for things to save
+    attr_group = None
+    if hgroup is not None:
+        attr_group = hgroup.create_group("attr")
+    for k, v in vars(obs).items():
+        if k.startswith("_"):
+            continue
+        if hasattr(v, "save_hdf5"):
+            kgroup = None
+            if attr_group is not None:
+                kgroup = attr_group.create_group(k)
+                kgroup.attrs["class"] = object_fullname(v.__class__)
+            v.save_hdf5(kgroup, obs)
+            del kgroup
+
     log.verbose_rank(
         f"{log_prefix} Wrote other metadata in",
         comm=comm,
@@ -893,6 +761,7 @@
     )
 
     if detdata is None:
+        # We are writing all detector data without compression
         fields = [(x, None) for x in obs.detdata.keys()]
     else:
         fields = list()
@@ -909,7 +778,6 @@
         detdata_group,
         fields,
         log_prefix,
-        use_float32=detdata_float32,
         in_place=detdata_in_place,
     )
     del detdata_group
diff --git a/src/toast/io/observation_hdf_save_v1.py b/src/toast/io/observation_hdf_save_v1.py
new file mode 100644
index 000000000..e85104976
--- /dev/null
+++ b/src/toast/io/observation_hdf_save_v1.py
@@ -0,0 +1,984 @@
+# Copyright (c) 2021-2025 by the parties listed in the AUTHORS file.
+# All rights reserved. Use of this source code is governed by
+# a BSD-style license that can be found in the LICENSE file.
+"""This deprecated v1 format save functionality is kept here for
+unit tests that check the ability to load v1 data.
+""" + +import json +import os +from datetime import timezone + +import h5py +import numpy as np +from astropy import units as u + +from ..instrument import GroundSite +from ..mpi import MPI +from ..observation import default_values as defaults +from ..observation_data import DetectorData +from ..observation_dist import global_interval_times +from ..timing import Timer, function_timer +from ..utils import ( + Environment, + Logger, + dtype_to_aligned, + hdf5_use_serial, + object_fullname, +) +from .deprecated_compression import compress_detdata, decompress_detdata +from .hdf_utils import check_dataset_buffer_size, hdf5_open + + +@function_timer +def save_hdf5_shared(obs, hgrp, fields, log_prefix): + log = Logger.get() + + timer = Timer() + timer.start() + + # Get references to the distribution of detectors and samples + proc_rows = obs.dist.process_rows + proc_cols = obs.dist.comm.group_size // proc_rows + dist_samps = obs.dist.samps + dist_dets = obs.dist.det_indices + + # Are we doing serial I/O? + use_serial = hdf5_use_serial(hgrp, obs.comm.comm_group) + + for ifield, field in enumerate(fields): + tag_offset = (obs.comm.group * 1000 + ifield) * obs.comm.group_size + + if field not in obs.shared: + msg = f"Shared data '{field}' does not exist in observation " + msg += f"{obs.name}. Skipping." + log.warning_rank(msg, comm=obs.comm.comm_group) + continue + + # Compute properties of the full set of data across the observation + + scomm = obs.shared.comm_type(field) + sdata = obs.shared[field] + sdtype = sdata.dtype + + if scomm == "group": + sshape = sdata.shape + elif scomm == "column": + sshape = (obs.n_all_samples,) + sdata.shape[1:] + else: + sshape = (len(obs.all_detectors),) + sdata.shape[1:] + + # The buffer class to use for allocating receive buffers + bufclass, _ = dtype_to_aligned(sdtype) + + hdata = None + if hgrp is not None: + # This process is participating + hdata = hgrp.create_dataset(field, sshape, dtype=sdtype) + hdata.attrs["comm_type"] = scomm + + if use_serial: + # Send data to rank zero of the group for writing. + if scomm == "group": + # Easy... + if obs.comm.group_rank == 0: + hdata[:] = sdata.data + elif scomm == "column": + # Send data to root process + for proc in range(proc_cols): + # Process grid is indexed row-major, so the rank-zero process + # of each column is just the first row of the grid. + send_rank = proc + + # Leading data range for this process + off = dist_samps[send_rank].offset + nelem = dist_samps[send_rank].n_elem + nflat = nelem * np.prod(sshape[1:]) + shp = (nelem,) + sshape[1:] + if send_rank == 0: + # Root process writes local data + if obs.comm.group_rank == 0: + hdata[off : off + nelem] = sdata.data + elif send_rank == obs.comm.group_rank: + # We are sending + obs.comm.comm_group.Send( + sdata.data.flatten(), dest=0, tag=tag_offset + send_rank + ) + elif obs.comm.group_rank == 0: + # We are receiving and writing + recv = bufclass(nflat) + obs.comm.comm_group.Recv( + recv, source=send_rank, tag=tag_offset + send_rank + ) + hdata[off : off + nelem] = recv.array().reshape(shp) + recv.clear() + del recv + else: + # Send data to root process + for proc in range(proc_rows): + # Process grid is indexed row-major, so the rank-zero process + # of each row is strided by the number of columns. 
+ send_rank = proc * proc_cols + # Leading data range for this process + off = dist_dets[send_rank].offset + nelem = dist_dets[send_rank].n_elem + nflat = nelem * np.prod(sshape[1:]) + shp = (nelem,) + sshape[1:] + if send_rank == 0: + # Root process writes local data + if obs.comm.group_rank == 0: + hdata[off : off + nelem] = sdata.data + elif send_rank == obs.comm.group_rank: + # We are sending + obs.comm.comm_group.Send( + sdata.data.flatten(), dest=0, tag=tag_offset + send_rank + ) + elif obs.comm.group_rank == 0: + # We are receiving and writing + recv = bufclass(nflat) + obs.comm.comm_group.Recv( + recv, source=send_rank, tag=tag_offset + send_rank + ) + hdata[off : off + nelem] = recv.array().reshape(shp) + recv.clear() + del recv + else: + # If we have parallel support, the rank zero of each comm can write + # independently. + if scomm == "group": + # Easy... + if obs.comm.group_rank == 0: + msg = f"Shared field {field} ({scomm})" + slices = tuple([slice(0, x) for x in sshape]) + check_dataset_buffer_size(msg, slices, sdtype, True) + hdata.write_direct(sdata.data, slices, slices) + elif scomm == "column": + # Rank zero of each column writes + if sdata.comm is None or sdata.comm.rank == 0: + sh_slices = tuple([slice(0, x) for x in sshape]) + offset = dist_samps[obs.comm.group_rank].offset + nelem = dist_samps[obs.comm.group_rank].n_elem + hf_slices = [ + slice(offset, offset + nelem), + ] + hf_slices.extend([slice(0, x) for x in sdata.shape[1:]]) + hf_slices = tuple(hf_slices) + msg = f"Shared field {field} ({scomm})" + check_dataset_buffer_size(msg, hf_slices, sdtype, True) + hdata.write_direct(sdata.data, sh_slices, hf_slices) + else: + # Rank zero of each row writes + if sdata.comm is None or sdata.comm.rank == 0: + sh_slices = tuple([slice(0, x) for x in sshape]) + offset = dist_dets[obs.comm.group_rank].offset + nelem = dist_dets[obs.comm.group_rank].n_elem + hf_slices = [ + slice(offset, offset + nelem), + ] + hf_slices.extend([slice(0, x) for x in sdata.shape[1:]]) + hf_slices = tuple(hf_slices) + msg = f"Shared field {field} ({scomm})" + check_dataset_buffer_size(msg, hf_slices, sdtype, True) + hdata.write_direct(sdata.data, sh_slices, hf_slices) + + log.verbose_rank( + f"{log_prefix} Shared finished {field} write in", + comm=obs.comm.comm_group, + timer=timer, + ) + del hdata + + +@function_timer +def save_hdf5_detdata(obs, hgrp, fields, log_prefix, use_float32=False, in_place=False): + log = Logger.get() + + timer = Timer() + timer.start() + + # Get references to the distribution of detectors and samples + proc_rows = obs.dist.process_rows + proc_cols = obs.dist.comm.group_size // proc_rows + dist_samps = obs.dist.samps + dist_dets = obs.dist.det_indices + + # We are using the group communicator + comm = obs.comm.comm_group + nproc = obs.comm.group_size + rank = obs.comm.group_rank + + # Are we doing serial I/O? + use_serial = hdf5_use_serial(hgrp, comm) + + for ifield, (field, fieldcomp) in enumerate(fields): + tag_offset = (obs.comm.group * 1000 + ifield) * obs.comm.group_size + if field not in obs.detdata: + msg = f"Detector data '{field}' does not exist in observation " + msg += f"{obs.name}. Skipping." + log.warning_rank(msg, comm=comm) + continue + + local_data = obs.detdata[field] + if local_data.detectors != obs.local_detectors: + msg = f"Detector data '{field}' does not contain all local detectors." 
+ log.error(msg) + raise RuntimeError(msg) + + # If we are using compression, we require the data to be distributed by + # by detector, since we will compress each detector independently. + if fieldcomp is not None and proc_cols != 1: + msg = f"Detector data '{field}' compression requested, but data for " + msg += "individual channels is split between processes." + raise RuntimeError(msg) + + # Compute properties of the full set of data across the observation + ddtype = local_data.dtype + dshape = (len(obs.all_detectors), obs.n_all_samples) + dvalshape = None + if len(local_data.detector_shape) > 1: + dvalshape = local_data.detector_shape[1:] + dshape += dvalshape + + fdtype = ddtype + if ddtype.char == "d" and use_float32: + # We are truncating to single precision + fdtype = np.dtype(np.float32) + + # If we are using our own internal compression, each process compresses their + # local data and sends it to one process for insertion into the overall blob + # of bytes. + + if fieldcomp is None or "type_hdf5" in fieldcomp: + # The buffer class to use for allocating receive buffers + bufclass, _ = dtype_to_aligned(ddtype) + + # Group handle + hdata = None + if hgrp is not None: + # This process is participating. + # + # Future NOTE: Here is where we could extract the "type_hdf5" parameter + # from the dictionary and create the dataset with appropriate compression + # and chunking settings. Detector data would then be written to the + # dataset in the usual way, with compression "under the hood" done by + # HDF5. + # + hdata = hgrp.create_dataset(field, dshape, dtype=fdtype) + hdata.attrs["units"] = local_data.units.to_string() + + if use_serial: + # Send data to rank zero of the group for writing. + for proc in range(nproc): + # Data ranges for this process + samp_off = dist_samps[proc].offset + samp_nelem = dist_samps[proc].n_elem + det_off = dist_dets[proc].offset + det_nelem = dist_dets[proc].n_elem + nflat = det_nelem * samp_nelem + shp = (det_nelem, samp_nelem) + detdata_slice = [slice(0, det_nelem, 1), slice(0, samp_nelem, 1)] + hf_slice = [ + slice(det_off, det_off + det_nelem, 1), + slice(samp_off, samp_off + samp_nelem, 1), + ] + if dvalshape is not None: + nflat *= np.prod(dvalshape) + shp += dvalshape + detdata_slice.extend([slice(0, x) for x in dvalshape]) + hf_slice.extend([slice(0, x) for x in dvalshape]) + detdata_slice = tuple(detdata_slice) + hf_slice = tuple(hf_slice) + if proc == 0: + # Root process writes local data + if rank == 0: + hdata.write_direct( + local_data.data.astype(fdtype), detdata_slice, hf_slice + ) + elif proc == rank: + # We are sending + comm.Send(local_data.flatdata, dest=0, tag=tag_offset + proc) + elif rank == 0: + # We are receiving and writing + recv = bufclass(nflat) + comm.Recv(recv, source=proc, tag=tag_offset + proc) + hdata.write_direct( + recv.array().astype(fdtype).reshape(shp), + detdata_slice, + hf_slice, + ) + recv.clear() + del recv + else: + # If we have parallel support, every process can write independently. 
+            samp_off = dist_samps[obs.comm.group_rank].offset
+            samp_nelem = dist_samps[obs.comm.group_rank].n_elem
+            det_off = dist_dets[obs.comm.group_rank].offset
+            det_nelem = dist_dets[obs.comm.group_rank].n_elem
+
+            detdata_slice = [slice(0, det_nelem, 1), slice(0, samp_nelem, 1)]
+            hf_slice = [
+                slice(det_off, det_off + det_nelem, 1),
+                slice(samp_off, samp_off + samp_nelem, 1),
+            ]
+            if dvalshape is not None:
+                detdata_slice.extend([slice(0, x) for x in dvalshape])
+                hf_slice.extend([slice(0, x) for x in dvalshape])
+            detdata_slice = tuple(detdata_slice)
+            hf_slice = tuple(hf_slice)
+            msg = f"Detector data field {field} (group rank {obs.comm.group_rank})"
+            check_dataset_buffer_size(msg, hf_slice, ddtype, True)
+
+            with hdata.collective:
+                hdata.write_direct(
+                    local_data.data.astype(fdtype), detdata_slice, hf_slice
+                )
+            del hdata
+            log.verbose_rank(
+                f"{log_prefix} Detdata finished {field} serial write in",
+                comm=comm,
+                timer=timer,
+            )
+        else:
+            # Compress our local detector data.  The starting dictionary of
+            # properties is passed in and additional metadata is appended.
+            if ddtype.char == "d" and use_float32:
+                temp_detdata = DetectorData(
+                    obs.detdata[field].detectors,
+                    obs.detdata[field].detector_shape,
+                    np.float32,
+                    units=obs.detdata[field].units,
+                )
+                temp_detdata.data[:] = obs.detdata[field].data.astype(np.float32)
+                comp_bytes, comp_ranges, comp_props = compress_detdata(
+                    temp_detdata, fieldcomp
+                )
+                if in_place:
+                    # Decompress
+                    decompress_detdata(
+                        comp_bytes, comp_ranges, comp_props, detdata=temp_detdata
+                    )
+                    # upcast back to float64 and overwrite the original detector data
+                    obs.detdata[field].data[:] = temp_detdata.data[:]
+                del temp_detdata
+            else:
+                temp_detdata = None
+                comp_bytes, comp_ranges, comp_props = compress_detdata(
+                    obs.detdata[field], fieldcomp
+                )
+                if in_place:
+                    # Decompress and overwrite the original detector data
+                    decompress_detdata(
+                        comp_bytes, comp_ranges, comp_props, detdata=obs.detdata[field]
+                    )
+
+            # Extract per-detector quantities for communicating / writing later
+            comp_data_offsets = None
+            if "data_offsets" in comp_props:
+                comp_data_offsets = comp_props["data_offsets"]
+            comp_data_gains = None
+            if "data_gains" in comp_props:
+                comp_data_gains = comp_props["data_gains"]
+
+            # Get the total number of bytes
+            n_local_bytes = len(comp_bytes)
+            if comm is None:
+                n_all_bytes = n_local_bytes
+            else:
+                n_all_bytes = comm.allreduce(n_local_bytes, op=MPI.SUM)
+
+            # Create the datasets
+            hdata_bytes = None
+            hdata_ranges = None
+            hdata_offsets = None
+            hdata_gains = None
+            cgrp = None
+            if hgrp is not None:
+                # This process is participating.
+ cgrp = hgrp.create_group(field) + hdata_bytes = cgrp.create_dataset( + "compressed", n_all_bytes, dtype=np.uint8 + ) + hdata_bytes.attrs["units"] = local_data.units.to_string() + # Write common properties of many compression schemes + hdata_bytes.attrs["dtype"] = str(comp_props["dtype"]) + hdata_bytes.attrs["det_shape"] = str( + tuple([int(x) for x in comp_props["det_shape"]]) + ) + hdata_bytes.attrs["comp_type"] = comp_props["type"] + if "level" in comp_props: + hdata_bytes.attrs["comp_level"] = comp_props["level"] + hdata_ranges = cgrp.create_dataset( + "ranges", + (len(obs.all_detectors), 2), + dtype=np.int64, + ) + if comp_data_offsets is not None: + hdata_offsets = cgrp.create_dataset( + "offsets", + (len(obs.all_detectors),), + dtype=np.float64, + ) + if comp_data_gains is not None: + hdata_gains = cgrp.create_dataset( + "gains", + (len(obs.all_detectors),), + dtype=np.float64, + ) + + # Send data to rank zero of the group for writing. + hf_det = 0 + hf_bytes = 0 + for proc in range(nproc): + if rank == 0: + if proc == 0: + # Root process writes local data + det_ranges = np.array( + [(x[0] + hf_bytes, x[1] + hf_bytes) for x in comp_ranges], + dtype=np.int64, + ).reshape((-1, 2)) + dslc = ( + slice(0, len(det_ranges), 1), + slice(0, 2, 1), + ) + hslc = ( + slice(hf_det, hf_det + len(det_ranges), 1), + slice(0, 2, 1), + ) + hdata_ranges.write_direct(det_ranges, dslc, hslc) + + dslc = (slice(0, n_local_bytes, 1),) + hslc = (slice(hf_bytes, hf_bytes + n_local_bytes, 1),) + hdata_bytes.write_direct(comp_bytes, dslc, hslc) + + dslc = (slice(0, len(comp_ranges), 1),) + hslc = (slice(hf_det, hf_det + len(comp_ranges), 1),) + if comp_data_offsets is not None: + hdata_offsets.write_direct(comp_data_offsets, dslc, hslc) + if comp_data_gains is not None: + hdata_gains.write_direct(comp_data_gains, dslc, hslc) + + hf_bytes += n_local_bytes + hf_det += len(comp_ranges) + else: + # Receive data and write + n_recv_bytes = comm.recv( + source=proc, tag=tag_offset + 10 * proc + ) + n_recv_dets = comm.recv( + source=proc, tag=tag_offset + 10 * proc + 1 + ) + + recv_bytes = np.zeros(n_recv_bytes, dtype=np.uint8) + comm.Recv( + recv_bytes, source=proc, tag=tag_offset + 10 * proc + 2 + ) + dslc = (slice(0, n_recv_bytes, 1),) + hslc = (slice(hf_bytes, hf_bytes + n_recv_bytes, 1),) + hdata_bytes.write_direct(recv_bytes, dslc, hslc) + del recv_bytes + + recv_ranges = np.zeros(n_recv_dets * 2, dtype=np.int64) + comm.Recv( + recv_ranges, + source=proc, + tag=tag_offset + 10 * proc + 3, + ) + recv_ranges[:] += hf_bytes + dslc = ( + slice(0, n_recv_dets, 1), + slice(0, 2, 1), + ) + hslc = ( + slice(hf_det, hf_det + n_recv_dets, 1), + slice(0, 2, 1), + ) + hdata_ranges.write_direct( + recv_ranges.reshape((n_recv_dets, 2)), dslc, hslc + ) + del recv_ranges + + recv_buf = np.zeros(n_recv_dets, dtype=np.float64) + dslc = (slice(0, n_recv_dets, 1),) + hslc = (slice(hf_det, hf_det + n_recv_dets, 1),) + if comp_data_offsets is not None: + comm.Recv( + recv_buf, source=proc, tag=tag_offset + 10 * proc + 4 + ) + hdata_offsets.write_direct(recv_buf, dslc, hslc) + if comp_data_gains is not None: + comm.Recv( + recv_buf, source=proc, tag=tag_offset + 10 * proc + 5 + ) + hdata_gains.write_direct(recv_buf, dslc, hslc) + del recv_buf + + hf_bytes += n_recv_bytes + hf_det += n_recv_dets + + elif proc == rank: + # We are sending. 
First send the number of bytes and detectors + det_ranges = np.zeros( + (len(comp_ranges), 2), + dtype=np.int64, + ) + for d in range(len(comp_ranges)): + det_ranges[d, :] = comp_ranges[d] + comm.send(n_local_bytes, dest=0, tag=tag_offset + 10 * proc) + comm.send(len(det_ranges), dest=0, tag=tag_offset + 10 * proc + 1) + + comm.Send(comp_bytes, dest=0, tag=tag_offset + 10 * proc + 2) + comm.Send( + det_ranges.flatten(), dest=0, tag=tag_offset + 10 * proc + 3 + ) + if comp_data_offsets is not None: + comm.Send( + comp_data_offsets, dest=0, tag=tag_offset + 10 * proc + 4 + ) + if comp_data_gains is not None: + comm.Send( + comp_data_gains, dest=0, tag=tag_offset + 10 * proc + 5 + ) + + +@function_timer +def save_hdf5_intervals(obs, hgrp, fields, log_prefix): + log = Logger.get() + + timer = Timer() + timer.start() + + # We are using the group communicator + comm = obs.comm.comm_group + nproc = obs.comm.group_size + rank = obs.comm.group_rank + + for field in fields: + if field not in obs.intervals: + msg = f"Intervals '{field}' does not exist in observation " + msg += f"{obs.name}. Skipping." + log.warning_rank(msg, comm=comm) + continue + + if field == obs.intervals.all_name: + # This is the internal fake interval for all samples. We don't + # save this because it is re-created on demand. + continue + + # Get the list of start / stop tuples on the rank zero process + ilist = global_interval_times(obs.dist, obs.intervals, field, join=False) + + n_list = None + if rank == 0: + n_list = len(ilist) + if comm is not None: + n_list = comm.bcast(n_list, root=0) + + # Participating processes create the dataset + hdata = None + if hgrp is not None: + hdata = hgrp.create_dataset(field, (2, n_list), dtype=np.float64) + # Only the root process writes + if rank == 0: + hdata[:, :] = np.transpose(np.array(ilist)) + del hdata + + log.verbose_rank( + f"{log_prefix} Intervals finished {field} write in", + comm=comm, + timer=timer, + ) + + +def save_instrument(parent_group, telescope, comm=None, session=None): + """Save instrument information to an HDF5 group. + + Given the parent group (which might exist on multiple processes in the case of + MPI use), create an instrument sub group and write telescope and optionally + session information to that group. + + """ + inst_group = None + if parent_group is not None: + # Instrument properties + inst_group = parent_group.create_group("instrument") + inst_group.attrs["toast_format_version"] = 1 + inst_group.attrs["telescope_name"] = telescope.name + inst_group.attrs["telescope_class"] = object_fullname(telescope.__class__) + inst_group.attrs["telescope_uid"] = telescope.uid + site = telescope.site + inst_group.attrs["site_name"] = site.name + inst_group.attrs["site_class"] = object_fullname(site.__class__) + inst_group.attrs["site_uid"] = site.uid + if isinstance(site, GroundSite): + inst_group.attrs["site_lat_deg"] = float( + site.earthloc.lat.to_value(u.degree) + ) + inst_group.attrs["site_lon_deg"] = float( + site.earthloc.lon.to_value(u.degree) + ) + inst_group.attrs["site_alt_m"] = float( + site.earthloc.height.to_value(u.meter) + ) + if site.weather is not None: + if hasattr(site.weather, "name"): + # This is a simulated weather object, dump it. 
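+                    # Store only its defining scalar attributes: name, realization,
+                    # maximum PWV, time stamp, and median-weather flag.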
+ inst_group.attrs["site_weather_name"] = str(site.weather.name) + inst_group.attrs["site_weather_realization"] = int( + site.weather.realization + ) + if site.weather.max_pwv is None: + inst_group.attrs["site_weather_max_pwv"] = "NONE" + else: + inst_group.attrs["site_weather_max_pwv"] = float( + site.weather.max_pwv.to_value(u.mm) + ) + inst_group.attrs["site_weather_time"] = ( + site.weather.time.timestamp() + ) + inst_group.attrs["site_weather_median"] = ( + site.weather.median_weather + ) + else: + # This is a generic weather object + inst_group.attrs["site_weather_time"] = ( + site.weather.time.astimezone(timezone.utc).timestamp() + ) + for attr_name in [ + "ice_water", + "liquid_water", + "pwv", + "humidity", + "surface_pressure", + "surface_temperature", + "air_temperature", + "west_wind", + "south_wind", + ]: + file_attr = f"site_weather_{attr_name}" + attr_val = getattr(site.weather, attr_name) + inst_group.attrs[file_attr] = str(attr_val) + if session is not None: + inst_group.attrs["session_name"] = session.name + inst_group.attrs["session_class"] = object_fullname(session.__class__) + inst_group.attrs["session_uid"] = session.uid + if session.start is None: + inst_group.attrs["session_start"] = "NONE" + else: + inst_group.attrs["session_start"] = session.start.timestamp() + if session.end is None: + inst_group.attrs["session_end"] = "NONE" + else: + inst_group.attrs["session_end"] = session.end.timestamp() + telescope.focalplane.save_hdf5(inst_group, comm=comm) + del inst_group + + +@function_timer +def save_hdf5( + obs, + dir, + meta=None, + detdata=None, + shared=None, + intervals=None, + config=None, + times=defaults.times, + force_serial=False, + detdata_float32=False, + detdata_in_place=False, +): + """Save an observation to HDF5. + + This function writes an observation to a new file in the specified directory. The + name is built from the observation name and the observation UID. + + The telescope information is written to a sub-dataset. + + By default, all shared, intervals, and noise models are dumped as individual + datasets. A subset of objects may be specified with a list of names passed to + the corresponding function arguments. + + For detector data, by default, all objects will be dumped uncompressed into + individual datasets. If you wish to specify a subset you can provide a list of + names and only these will be dumped uncompressed. To enable compression, provide + a list of tuples, (detector data name, compression properties), where compression + properties is a dictionary accepted by the `compress_detdata()` function. + + When dumping arbitrary metadata, scalars are stored as attributes of the observation + "meta" group. Any objects in the metadata which have a `save_hdf5()` method are + passed a group and the name of the new dataset to create. Other objects are + attempted to be dumped by h5py and a warning is printed if it fails. The list of + metadata objects to dump can be given explicitly. + + Args: + obs (Observation): The observation to write. + dir (str): The parent directory containing the file. + meta (list): Only save this list of metadata objects. + detdata (list): Only save this list of detdata objects, optionally with + compression. + shared (list): Only save this list of shared objects. + intervals (list): Only save this list of intervals objects. + config (dict): The job config dictionary to save. + times (str): The name of the shared timestamp field. + force_serial (bool): If True, do not use HDF5 parallel support, + even if it is available. 
+ detdata_float32 (bool): If True, cast any float64 detector fields + to float32 on write. Integer detdata is not affected. + detdata_in_place (bool): If True, input detdata will be replaced + with a compressed and decompressed version that includes the + digitization error. + + + Returns: + (str): The full path of the file that was written. + + """ + log = Logger.get() + env = Environment.get() + if obs.comm.group_size == 1: + # Force serial usage in this case, to avoid any MPI overhead + force_serial = True + + if obs.name is None: + raise RuntimeError("Cannot save observations that have no name") + + timer = Timer() + timer.start() + log_prefix = f"HDF5 save {obs.name}: " + + comm = obs.comm.comm_group + rank = obs.comm.group_rank + + namestr = f"{obs.name}_{obs.uid}" + hfpath = os.path.join(dir, f"obs_{namestr}.h5") + hfpath_temp = f"{hfpath}.tmp" + + # Create the file and get the root group + hf = None + hgroup = None + vtimer = Timer() + vtimer.start() + + hf = hdf5_open(hfpath_temp, "w", comm=comm, force_serial=force_serial) + hgroup = hf + + shared_group = None + detdata_group = None + intervals_group = None + if hgroup is not None: + # This process is participating + # Record the software versions and config + hgroup.attrs["toast_version"] = env.version() + if config is not None: + hgroup.attrs["job_config"] = json.dumps(config) + hgroup.attrs["toast_format_version"] = 1 + + # Observation properties + hgroup.attrs["observation_name"] = obs.name + hgroup.attrs["observation_uid"] = obs.uid + + obs_all_dets = json.dumps(obs.all_detectors) + obs_all_det_sets = "NONE" + if obs.all_detector_sets is not None: + obs_all_det_sets = json.dumps(obs.all_detector_sets) + obs_all_samp_sets = "NONE" + if obs.all_sample_sets is not None: + obs_all_samp_sets = json.dumps( + [[str(x) for x in y] for y in obs.all_sample_sets] + ) + hgroup.attrs["observation_detectors"] = obs_all_dets + hgroup.attrs["observation_detector_sets"] = obs_all_det_sets + hgroup.attrs["observation_samples"] = obs.n_all_samples + hgroup.attrs["observation_sample_sets"] = obs_all_samp_sets + + log.verbose_rank( + f"{log_prefix} Wrote observation attributes in", + comm=comm, + timer=vtimer, + ) + + save_instrument(hgroup, obs.telescope, comm=comm, session=obs.session) + + log.verbose_rank( + f"{log_prefix} Wrote instrument in", + comm=comm, + timer=vtimer, + ) + + log.debug_rank( + f"{log_prefix} Finished instrument model", + comm=comm, + timer=timer, + ) + + meta_group = None + if hgroup is not None: + meta_group = hgroup.create_group("metadata") + + for k, v in obs.items(): + if meta is not None and k not in meta: + continue + if hasattr(v, "save_hdf5"): + kgroup = None + if meta_group is not None: + kgroup = meta_group.create_group(k) + kgroup.attrs["class"] = object_fullname(v.__class__) + v.save_hdf5(kgroup, obs) + del kgroup + elif isinstance(v, u.Quantity): + if isinstance(v.value, np.ndarray): + # Array quantity + if meta_group is not None: + qdata = meta_group.create_dataset(k, data=v.value) + qdata.attrs["units"] = v.unit.to_string() + del qdata + else: + # Must be a scalar + if meta_group is not None: + meta_group.attrs[f"{k}"] = v.value + meta_group.attrs[f"{k}_units"] = v.unit.to_string() + elif isinstance(v, np.ndarray): + if meta_group is not None: + marr = meta_group.create_dataset(k, data=v) + del marr + elif meta_group is not None: + try: + if isinstance(v, u.Quantity): + meta_group.attrs[k] = v.value + else: + meta_group.attrs[k] = v + except (ValueError, TypeError) as e: + msg = f"Failed to store obs key 
'{k}' = '{v}' as an attribute ({e})."
+                msg += " Try casting it to a supported type when storing in the "
+                msg += "observation dictionary or implement save_hdf5() and "
+                msg += "load_hdf5() methods."
+                log.verbose(msg)
+    del meta_group
+
+    log.verbose_rank(
+        f"{log_prefix} Wrote other metadata in",
+        comm=comm,
+        timer=vtimer,
+    )
+
+    log.debug_rank(
+        f"{log_prefix} Finished metadata",
+        comm=comm,
+        timer=timer,
+    )
+
+    # Dump data
+
+    if shared is None:
+        fields = list(obs.shared.keys())
+    else:
+        fields = list(shared)
+
+    dump_intervals = True
+    if times not in obs.shared:
+        msg = f"Timestamp field '{times}' does not exist. Not saving intervals."
+        log.warning_rank(msg, comm=comm)
+        dump_intervals = False
+    else:
+        if times not in fields:
+            fields.append(times)
+
+    shared_group = None
+    if hgroup is not None:
+        shared_group = hgroup.create_group("shared")
+    save_hdf5_shared(obs, shared_group, fields, log_prefix)
+    del shared_group
+
+    log.debug_rank(
+        f"{log_prefix} Finished shared data",
+        comm=comm,
+        timer=timer,
+    )
+
+    if detdata is None:
+        fields = [(x, None) for x in obs.detdata.keys()]
+    else:
+        fields = list()
+        for df in detdata:
+            if isinstance(df, (tuple, list)):
+                fields.append(df)
+            else:
+                fields.append((df, None))
+    detdata_group = None
+    if hgroup is not None:
+        detdata_group = hgroup.create_group("detdata")
+    save_hdf5_detdata(
+        obs,
+        detdata_group,
+        fields,
+        log_prefix,
+        use_float32=detdata_float32,
+        in_place=detdata_in_place,
+    )
+    del detdata_group
+    log.debug_rank(
+        f"{log_prefix} Finished detector data",
+        comm=comm,
+        timer=timer,
+    )
+
+    if intervals is None:
+        fields = list(obs.intervals.keys())
+    else:
+        fields = list(intervals)
+    if dump_intervals:
+        intervals_group = None
+        if hgroup is not None:
+            intervals_group = hgroup.create_group("intervals")
+            intervals_group.attrs["times"] = times
+        save_hdf5_intervals(obs, intervals_group, fields, log_prefix)
+        del intervals_group
+        log.debug_rank(
+            f"{log_prefix} Finished intervals data",
+            comm=comm,
+            timer=timer,
+        )
+
+    # Close file if we opened it
+    del hgroup
+    if hf is not None:
+        hf.close()
+    del hf
+
+    if comm is not None:
+        comm.barrier()
+
+    # Move file into place
+    if rank == 0:
+        os.rename(hfpath_temp, hfpath)
+
+    return hfpath
+
+
+def save_instrument_file(path, telescope, session):
+    """Save instrument data to an HDF5 file.
+
+    This function writes the telescope and session serially from one process.
+    It supports including a relative internal path inside the HDF5 file by separating
+    the filesystem path from the internal path with a colon. For example:
+
+        path="/path/to/file.h5:/obs1"
+
+    The internal path should be to the *parent* group of the "instrument" group.
+
+    """
+    parts = path.split(":")
+    if len(parts) == 1:
+        file = parts[0]
+        internal = "/"
+    else:
+        file = parts[0]
+        internal = parts[1]
+    grouptree = internal.split(os.path.sep)
+    with h5py.File(file, "w") as hf:
+        parent = hf
+        for grp in grouptree:
+            if grp == "":
+                continue
+            parent = parent.create_group(grp)
+        save_instrument(parent, telescope, session=session)
diff --git a/src/toast/observation.py b/src/toast/observation.py
index a29ed4036..3d37484d0 100644
--- a/src/toast/observation.py
+++ b/src/toast/observation.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2015-2024 by the parties listed in the AUTHORS file.
+# Copyright (c) 2015-2025 by the parties listed in the AUTHORS file.
 # All rights reserved. Use of this source code is governed by
 # a BSD-style license that can be found in the LICENSE file.
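The `meta_equal()` helper added in the next hunk recursively compares nested dict / list / tuple metadata, using `np.allclose` for numeric leaves and `==` otherwise, and logs the key path of the first mismatch. A minimal usage sketch, with hypothetical observation variables:

    # meta_equal() returns a bool; details of any mismatch go to the verbose
    # log, tagged with the given prefix plus the nested key path.
    if not obs_a.meta_equal(obs_b, "verify"):
        print("observation metadata differs")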
@@ -197,6 +197,7 @@ class Observation(MutableMapping): """ view = ViewInterface() + _reserved = set(["dist", "detdata", "shared", "intervals"]) @function_timer def __init__( @@ -574,6 +575,76 @@ def __repr__(self): val += "\n>" return val + def meta_equal(self, other, prefix): + """Test if observation metadata is equal between instances. + + This compares the `_internal` dictionary of metadata between two + observations. + + Args: + other (Observation): The other instance to compare + prefix (str): The top level prefix string for logging + + Returns: + (bool): True if the metadata is equal, else False + + """ + log = Logger.get() + + def _compare_nodes(self_obj, other_obj, prefix): + if type(self_obj) is not type(other_obj): + if np.ndim(self_obj) == 0 and np.ndim(other_obj) == 0: + # Both objects are scalars, but one might be a native python + # type and the other a numpy type. Continue with testing + # these values. + pass + else: + msg = f"{prefix} meta_equal type {type(self_obj)} != " + msg += f"{type(other_obj)}" + log.verbose(msg) + return False + if isinstance(self_obj, dict): + if set(self_obj.keys()) != set(other_obj.keys()): + msg = f"{prefix} meta_equal dict keys mismatch" + log.verbose(msg) + return False + result = True + for k, v in self_obj.items(): + v_other = other_obj[k] + child_prefix = f"{prefix}_{k}" + check = _compare_nodes(v, v_other, child_prefix) + if not check: + result = False + return result + if isinstance(self_obj, (list, tuple)): + if len(self_obj) != len(other_obj): + msg = f"{prefix} meta_equal container length mismatch" + log.verbose(msg) + return False + result = True + for index, val in enumerate(self_obj): + other_val = other_obj[index] + child_prefix = f"{prefix}_{index:04d}" + check = _compare_nodes(val, other_val, child_prefix) + if not check: + result = False + return result + try: + is_eq = np.allclose(self_obj, other_obj) + if not is_eq: + msg = f"{prefix} meta_equal arrays are not close" + log.verbose(msg) + result = is_eq + except Exception: + # Not arrays + result = self_obj == other_obj + if not result: + msg = f"{prefix} meta_equal scalars are not equal" + log.verbose(msg) + return result + + return _compare_nodes(self._internal, other._internal, prefix) + def __eq__(self, other): # Note that testing for equality is quite expensive, since it means testing all # metadata and also all detector, shared, and interval data. 
This is mainly
@@ -599,23 +670,15 @@ def __eq__(self, other):
         if self.dist != other.dist:
             fail = 1
             log.verbose(f"Proc {self.comm.world_rank}: Obs distributions not equal")
-        if set(self._internal.keys()) != set(other._internal.keys()):
+        if self.local_detector_flags != other.local_detector_flags:
             fail = 1
-            log.verbose(f"Proc {self.comm.world_rank}: Obs metadata keys not equal")
-        for k, v in self._internal.items():
-            if v != other._internal[k]:
-                feq = True
-                try:
-                    feq = np.allclose(v, other._internal[k])
-                except Exception:
-                    # Not floating point data
-                    feq = False
-                if not feq:
-                    fail = 1
-                    log.verbose(
-                        f"Proc {self.comm.world_rank}: Obs metadata[{k}]: {v} != {other[k]}"
-                    )
-                break
+            log.verbose(
+                f"Proc {self.comm.world_rank}: Obs local_detector_flags not equal"
+            )
+
+        if not self.meta_equal(other, f"Proc {self.comm.world_rank}: Obs _internal"):
+            fail = 1
+
         if self.shared != other.shared:
             fail = 1
             log.verbose(f"Proc {self.comm.world_rank}: Obs shared data not equal")
@@ -625,6 +688,37 @@ def __eq__(self, other):
         if self.intervals != other.intervals:
             fail = 1
             log.verbose(f"Proc {self.comm.world_rank}: Obs intervals not equal")
+
+        # Handle other arbitrary attributes.
+        self_attrs = list()
+        for k, v in vars(self).items():
+            if k.startswith("_"):
+                continue
+            if k in self._reserved:
+                continue
+            self_attrs.append(k)
+        other_attrs = list()
+        for k, v in vars(other).items():
+            if k.startswith("_"):
+                continue
+            if k in self._reserved:
+                continue
+            other_attrs.append(k)
+        if other_attrs != self_attrs:
+            fail = 1
+            msg = f"Proc {self.comm.world_rank}: Obs attr lists not equal "
+            msg += f"{other_attrs} != {self_attrs}"
+            log.verbose(msg)
+        else:
+            for attr in self_attrs:
+                self_obj = getattr(self, attr)
+                other_obj = getattr(other, attr)
+                if self_obj != other_obj:
+                    fail = 1
+                    log.verbose(
+                        f"Proc {self.comm.world_rank}: Obs attr {attr} not equal"
+                    )
+
         if self.comm.comm_group is not None:
             fail = self.comm.comm_group.allreduce(fail, op=MPI.SUM)
         return fail == 0
@@ -633,7 +727,13 @@ def __ne__(self, other):
         return not self.__eq__(other)
 
     def duplicate(
-        self, times=None, meta=None, shared=None, detdata=None, intervals=None
+        self,
+        times=None,
+        meta=None,
+        attr=None,
+        shared=None,
+        detdata=None,
+        intervals=None,
     ):
         """Return a copy of the observation and all its data.
 
@@ -650,6 +750,7 @@ def duplicate(
         Args:
             times (str): The name of the timestamps shared field.
             meta (list): List of metadata objects to copy, or None.
+            attr (list): List of other observation attributes to copy, or None.
             shared (list): List of shared objects to copy, or None.
             detdata (list): List of detdata objects to copy, or None.
             intervals (list): List of intervals objects to copy, or None.
@@ -697,6 +798,17 @@ def duplicate(
                 new_obs.intervals[name] = IntervalList(
                     new_obs.shared[times], timespans=timespans
                 )
+        # Handle other arbitrary attributes
+        for k, v in vars(self).items():
+            if k.startswith("_"):
+                # We skip internal objects
+                continue
+            if hasattr(new_obs, k):
+                # This is some object already instantiated above
+                continue
+            if attr is None or k in attr:
+                # Copy this object
+                setattr(new_obs, k, copy.deepcopy(v))
         return new_obs
 
     def memory_use(self):
diff --git a/src/toast/ops/load_hdf5.py b/src/toast/ops/load_hdf5.py
index b21ef223f..a176bb8ba 100644
--- a/src/toast/ops/load_hdf5.py
+++ b/src/toast/ops/load_hdf5.py
@@ -1,14 +1,11 @@
-# Copyright (c) 2021-2023 by the parties listed in the AUTHORS file.
+# Copyright (c) 2021-2025 by the parties listed in the AUTHORS file.
 # All rights reserved.
Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. -import glob import os import re import h5py -import numpy as np -import traitlets from ..dist import distribute_discrete from ..io import load_hdf5 @@ -35,7 +32,7 @@ class LoadHDF5(Operator): None, allow_none=True, help="Top-level directory containing the data volume" ) - pattern = Unicode("obs_.*_.*\.h5", help="Regexp pattern to match files against") + pattern = Unicode(r"obs_.*_.*\.h5", help="Regexp pattern to match files against") files = List([], help="Override `volume` and load a list of files") diff --git a/src/toast/ops/mapmaker_templates.py b/src/toast/ops/mapmaker_templates.py index 4094512f5..2d3769313 100644 --- a/src/toast/ops/mapmaker_templates.py +++ b/src/toast/ops/mapmaker_templates.py @@ -418,7 +418,7 @@ class SolveAmplitudes(Operator): that model the timestream contributions from noise, systematics, etc: .. math:: - \left[ M^T N^{-1} Z M + M_p \right] a = M^T N^{-1} Z d + \\left[ M^T N^{-1} Z M + M_p \\right] a = M^T N^{-1} Z d Where `a` are the solved amplitudes and `d` is the input data. `N` is the diagonal time domain noise covariance. `M` is a matrix of templates that diff --git a/src/toast/ops/save_hdf5.py b/src/toast/ops/save_hdf5.py index e0205ca87..b4f54728d 100644 --- a/src/toast/ops/save_hdf5.py +++ b/src/toast/ops/save_hdf5.py @@ -1,14 +1,14 @@ -# Copyright (c) 2021-2021 by the parties listed in the AUTHORS file. +# Copyright (c) 2021-2025 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. import os import numpy as np -import traitlets +import warnings from ..io import load_hdf5, save_hdf5 -from ..mpi import MPI, comm_equal +from ..mpi import MPI from ..observation import default_values as defaults from ..timing import function_timer from ..traits import Bool, Dict, Int, List, Unicode, trait_docs @@ -51,23 +51,40 @@ def obs_approx_equal(obs1, obs2): if obs1.dist != obs2.dist: fail = 1 log.verbose(f"Proc {obs1.comm.world_rank}: Obs distributions not equal") - if set(obs1._internal.keys()) != set(obs2._internal.keys()): + + if not obs1.meta_equal(obs2, f"Proc {obs1.comm.world_rank}: Obs _internal"): + fail = 1 + log.verbose(f"Proc {obs1.comm.world_rank}: Obs metadata not equal") + + # Compare any extra metadata class instances + extra_objs1 = list() + for k, v in vars(obs1).items(): + if k.startswith("_"): + continue + if hasattr(v, "save_hdf5"): + extra_objs1.append(k) + extra_objs2 = list() + for k, v in vars(obs2).items(): + if k.startswith("_"): + continue + if hasattr(v, "save_hdf5"): + extra_objs2.append(k) + + if extra_objs1 != extra_objs2: fail = 1 - log.verbose(f"Proc {obs1.comm.world_rank}: Obs metadata keys not equal") - for k, v in obs1._internal.items(): - if v != obs2._internal[k]: - feq = True - try: - feq = np.allclose(v, obs2._internal[k]) - except Exception: - # Not floating point data - feq = False - if not feq: + log.verbose( + f"Proc {obs1.comm.world_rank}: Obs extra metadata obj lists not equal" + ) + else: + for exobj in extra_objs1: + obj1 = getattr(obs1, exobj) + obj2 = getattr(obs2, exobj) + if obj1 != obj2: fail = 1 log.verbose( - f"Proc {obs1.comm.world_rank}: Obs metadata[{k}]: {v} != {obs2[k]}" + f"Proc {obs1.comm.world_rank}: Obs extra {exobj} not equal" ) - break + if obs1.shared != obs2.shared: fail = 1 log.verbose(f"Proc {obs1.comm.world_rank}: Obs shared data not equal") @@ -114,13 +131,8 @@ def 
obs_approx_equal(obs1, obs2):
             msg += f"{o1d[k].units} != {o2d[k].units}"
             log.verbose(msg)
             fail = 1
-        if o1d[k].dtype == np.dtype(np.float64):
-            if not np.allclose(o1d[k].data, o2d[k].data, rtol=1.0e-5, atol=1.0e-8):
-                msg = f"Proc {obs1.comm.world_rank}: Obs detdata {k} array "
-                msg += f"{o1d[k].data} != {o2d[k].data}"
-                log.verbose(msg)
-                fail = 1
-        elif o1d[k].dtype == np.dtype(np.float32):
+        if o1d[k].dtype == np.dtype(np.float64) or o1d[k].dtype == np.dtype(np.float32):
+            # Only compare to 32bit precision
             if not np.allclose(o1d[k].data, o2d[k].data, rtol=1.0e-3, atol=1.0e-5):
                 msg = f"Proc {obs1.comm.world_rank}: Obs detdata {k} array "
                 msg += f"{o1d[k].data} != {o2d[k].data}"
@@ -140,7 +152,18 @@ def obs_approx_equal(obs1, obs2):
 class SaveHDF5(Operator):
     """Operator which saves observations to HDF5.
 
-    This creates a file for each observation.
+    This creates a file for each observation. Detector data compression can be enabled
+    by specifying a tuple for each item in the detdata list. The first item in the
+    tuple is the field name. The second item is either None, or a dictionary of FLAC
+    compression properties. Allowed compression parameters are:
+
+        "level": (int) the compression level
+        "quanta": (float) the quantization value, only for floating point data
+        "precision": (float) the fixed precision, only for floating point data
+
+    For integer data, an empty dictionary may be passed, and FLAC compression
+    will use the default level (5). Floating point data *must* specify either the
+    quanta or precision parameters.
 
     """
 
@@ -177,7 +200,7 @@ class SaveHDF5(Operator):
     )
 
     detdata_float32 = Bool(
-        False, help="If True, convert any float64 detector data to float32 on write."
+        False, help="(Deprecated) Specify the per-field compression parameters."
     )
 
     detdata_in_place = Bool(
@@ -186,12 +209,14 @@ class SaveHDF5(Operator):
         "over the input data.",
     )
 
-    compress_detdata = Bool(False, help="If True, use FLAC to compress detector signal")
+    compress_detdata = Bool(
+        False, help="(Deprecated) Specify the per-field compression parameters"
+    )
 
     compress_precision = Int(
         None,
         allow_none=True,
-        help="Number of significant digits to retain in detdata compression",
+        help="(Deprecated) Specify the per-field compression parameters",
     )
 
     verify = Bool(False, help="If True, immediately load data back in and verify")
@@ -208,6 +233,26 @@ def _exec(self, data, detectors=None, **kwargs):
             log.error(msg)
             raise RuntimeError(msg)
 
+        # Warn for deprecated traits that will be removed eventually.
+
+        if self.detdata_float32:
+            msg = "The detdata_float32 option is deprecated. Instead, specify"
+            msg += " a compression quanta / precision that is appropriate for"
+            msg += " each detdata field."
+            warnings.warn(msg, DeprecationWarning)
+
+        if self.compress_detdata:
+            msg = "The compress_detdata option is deprecated. Instead, specify"
+            msg += " a compression quanta / precision that is appropriate for"
+            msg += " each detdata field."
+            warnings.warn(msg, DeprecationWarning)
+
+        if self.compress_precision is not None:
+            msg = "The compress_precision option is deprecated. Instead, specify"
+            msg += " a compression quanta / precision that is appropriate for"
+            msg += " each detdata field."
+ warnings.warn(msg, DeprecationWarning) + # One process creates the top directory if data.comm.world_rank == 0: os.makedirs(self.volume, exist_ok=True) @@ -226,6 +271,38 @@ def _exec(self, data, detectors=None, **kwargs): if len(self.intervals) > 0: intervals_fields = list(self.intervals) + if len(self.detdata) > 0: + detdata_fields = list(self.detdata) + else: + detdata_fields = list() + + # Handle parsing of deprecated global compression options. All + # new code should specify the FLAC compression parameters per + # field. + for ifield, field in enumerate(detdata_fields): + if not isinstance(field, str): + # User already specified compression parameters + continue + cprops = {"level": 5} + if self.compress_detdata: + # Try to guess what to do. + if "flag" not in field: + # Might be float data + if self.compress_precision is None: + # Compress to 32bit floats + cprops["quanta"] = np.finfo(np.float32).eps + else: + cprops["precision"] = self.compress_precision + detdata_fields[ifield] = (field, cprops) + elif self.detdata_float32: + # Implement this truncation as just compression to 32bit float + # precision + cprops["quanta"] = np.finfo(np.float32).eps + detdata_fields[ifield] = (field, cprops) + else: + # No compression + detdata_fields[ifield] = (field, None) + for ob in data.obs: # Observations must have a name for this to work if ob.name is None: @@ -235,36 +312,10 @@ def _exec(self, data, detectors=None, **kwargs): # Check to see if any detector data objects are temporary and have just # a partial list of detectors. Delete these. - for dd in list(ob.detdata.keys()): if ob.detdata[dd].detectors != ob.local_detectors: del ob.detdata[dd] - if len(self.detdata) > 0: - detdata_fields = list(self.detdata) - else: - detdata_fields = list() - - if self.compress_detdata: - # Add generic compression instructions to detdata fields - for ifield, field in enumerate(detdata_fields): - if not isinstance(field, str): - # Assume user already supplied instructions for this field - continue - if "flag" in field: - # Flags are ZIP-compressed - detdata_fields[ifield] = (field, {"type": "gzip"}) - else: - # Everything else is FLAC-compressed - detdata_fields[ifield] = ( - field, - { - "type": "flac", - "level": 5, - "precision": self.compress_precision, - }, - ) - outpath = save_hdf5( ob, self.volume, @@ -275,7 +326,6 @@ def _exec(self, data, detectors=None, **kwargs): config=self.config, times=str(self.times), force_serial=self.force_serial, - detdata_float32=self.detdata_float32, detdata_in_place=self.detdata_in_place, ) @@ -299,47 +349,13 @@ def _exec(self, data, detectors=None, **kwargs): # We saved nothing verify_fields = list() - if self.detdata_float32: - # We want to duplicate everything *except* float64 detdata - # fields. 
- dup_detdata = list() - conv_detdata = list() - for fld in verify_fields: - if ob.detdata[fld].dtype == np.dtype(np.float64): - conv_detdata.append(fld) - else: - dup_detdata.append(fld) - original = ob.duplicate( - times=str(self.times), - meta=meta_fields, - shared=shared_fields, - detdata=dup_detdata, - intervals=intervals_fields, - ) - for fld in conv_detdata: - if len(ob.detdata[fld].detector_shape) == 1: - sample_shape = None - else: - sample_shape = ob.detdata[fld].detector_shape[1:] - original.detdata.create( - fld, - sample_shape=sample_shape, - dtype=np.float32, - detectors=ob.detdata[fld].detectors, - units=ob.detdata[fld].units, - ) - original.detdata[fld].data[:] = ( - ob.detdata[fld].data[:].astype(np.float32) - ) - else: - # Duplicate detdata - original = ob.duplicate( - times=str(self.times), - meta=meta_fields, - shared=shared_fields, - detdata=verify_fields, - intervals=intervals_fields, - ) + original = ob.duplicate( + times=str(self.times), + meta=meta_fields, + shared=shared_fields, + detdata=verify_fields, + intervals=intervals_fields, + ) compare = load_hdf5( loadpath, @@ -355,7 +371,9 @@ def _exec(self, data, detectors=None, **kwargs): if not obs_approx_equal(compare, original): msg = "Observation HDF5 verify failed:\n" msg += f"Input = {original}\n" - msg += f"Loaded = {compare}" + msg += f"Loaded = {compare}\n" + msg += f"Input signal[0] = {original.detdata['signal'][0]}\n" + msg += f"Loaded signal[0] = {compare.detdata['signal'][0]}" log.error(msg) raise RuntimeError(msg) diff --git a/src/toast/ops/sim_cosmic_rays.py b/src/toast/ops/sim_cosmic_rays.py index 9096df3f0..09fc6dd75 100644 --- a/src/toast/ops/sim_cosmic_rays.py +++ b/src/toast/ops/sim_cosmic_rays.py @@ -36,7 +36,7 @@ class InjectCosmicRays(Operator): We assume the glitch to be described as .. math:: - \gamma (t) = C_1 +C_2 e^{-t/\tau } + \\gamma (t) = C_1 +C_2 e^{-t/\\tau } where :math:C_1 and :math:C_2 and the time constant :math:\tau are drawn from a distribution of estimated values from simulations. For each observation and each detector, we estimate the number of hits expected diff --git a/src/toast/tests/CMakeLists.txt b/src/toast/tests/CMakeLists.txt index a0b0326dc..8ad466010 100644 --- a/src/toast/tests/CMakeLists.txt +++ b/src/toast/tests/CMakeLists.txt @@ -79,7 +79,6 @@ install(FILES ops_perturbhwp.py ops_filterbin.py io_hdf5.py - io_compression.py ops_noise_estim.py ops_noise_filter.py ops_yield_cut.py diff --git a/src/toast/tests/io_compression.py b/src/toast/tests/io_compression.py deleted file mode 100644 index 8e6dd33ed..000000000 --- a/src/toast/tests/io_compression.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright (c) 2023-2023 by the parties listed in the AUTHORS file. -# All rights reserved. Use of this source code is governed by -# a BSD-style license that can be found in the LICENSE file. - -import os -import sys - -import numpy as np -from astropy import units as u - -from .. 
import ops as ops -from ..data import Data -from ..io import compress_detdata, decompress_detdata, load_hdf5, save_hdf5 -from ..io.compression_flac import ( - compress_detdata_flac, - compress_flac, - compress_flac_2D, - decompress_detdata_flac, - decompress_flac, - decompress_flac_2D, - float2int, - have_flac_support, - int2float, - int64to32, -) -from ..mpi import Comm -from ..observation import default_values as defaults -from ..observation_data import DetectorData -from ..ops import LoadHDF5, SaveHDF5 -from ..timing import Timer -from ..utils import AlignedI32, AlignedU8 -from .helpers import close_data, create_ground_data, create_outdir -from .mpi import MPITestCase - - -class IoCompressionTest(MPITestCase): - def setUp(self): - fixture_name = os.path.splitext(os.path.basename(__file__))[0] - self.outdir = create_outdir(self.comm, subdir=fixture_name) - self.types = { - "f64": np.float64, - "f32": np.float32, - "i64": np.int64, - "i32": np.int32, - } - self.fakedets = ["D00A", "D00B", "D01A", "D01B"] - - def test_type_conversion(self): - rng = np.random.default_rng(12345) - - n_test = 10000 - - off = 1.0e6 - scale = 100.0 - data = scale * rng.random(size=n_test, dtype=np.float64) + off - idata, doff, dgain = float2int(data) - check = int2float(idata, doff, dgain) - self.assertTrue(np.allclose(check, data, rtol=1.0e-6, atol=1.0e-5)) - - rng_max = np.iinfo(np.int32).max // 2 - i64data = rng.integers( - -rng_max, - rng_max, - size=n_test, - dtype=np.int64, - ) - idata, ioff = int64to32(i64data) - check = np.array(idata, dtype=np.int64) + ioff - self.assertTrue(np.array_equal(check, i64data)) - - def test_flac_lowlevel(self): - if not have_flac_support(): - print("FLAC disabled, skipping...") - return - - timer1 = Timer() - timer2 = Timer() - - n_det = 20 - n_samp = 100000 - - rng = np.random.default_rng(12345) - - rng_max = np.iinfo(np.int32).max // 2 - input = rng.integers( - -rng_max, - rng_max, - size=(n_det * n_samp), - dtype=np.int32, - ).reshape((n_det, n_samp)) - - # Compare results of 1D and 2D compression - - timer2.start() - fbytes2, foffs2 = compress_flac_2D(input, 5) - timer2.stop() - - # print(f"Compress 2D one shot in {timer2.seconds()} s") - timer2.clear() - - fbytes1 = AlignedU8() - foffs1 = np.zeros(n_det, dtype=np.int64) - timer1.start() - for d in range(n_det): - cur = fbytes1.size() - foffs1[d] = cur - dbytes = compress_flac(input[d], 5) - ext = len(dbytes) - fbytes1.resize(cur + ext) - fbytes1[cur : cur + ext] = dbytes - timer1.stop() - - # print(f"Compress {n_det} dets with 1D in {timer1.seconds()} s") - timer1.clear() - - self.assertTrue(len(fbytes1) == len(fbytes2)) - self.assertTrue(np.array_equal(fbytes1, fbytes2)) - self.assertTrue(np.array_equal(foffs1, foffs2)) - - timer2.start() - output2 = decompress_flac_2D(fbytes2, foffs2) - timer2.stop() - - # print(f"Decompress 2D one shot in {timer2.seconds()} s") - timer2.clear() - - output1 = AlignedI32() - timer1.start() - for d in range(n_det): - cur = output1.size() - if d == n_det - 1: - slc = slice(foffs1[d], len(fbytes1), 1) - else: - slc = slice(foffs1[d], foffs1[d + 1], 1) - dout = decompress_flac(fbytes1[slc]) - ext = len(dout) - output1.resize(cur + ext) - output1[cur : cur + ext] = dout - timer1.stop() - - # print(f"Decompress {n_det} dets with 1D in {timer1.seconds()} s") - timer1.clear() - - self.assertTrue(np.array_equal(output1, output2)) - - def test_roundtrip_detdata(self): - rank = 0 - if self.comm is not None: - rank = self.comm.rank - n_samp = 100000 - - rng = np.random.default_rng(12345) - - 
comp_types = ["none", "gzip"] - if have_flac_support(): - comp_types.append("flac") - else: - print("FLAC disabled, skipping 'flac' compression type") - - for comp_type in comp_types: - for dtname, dt in self.types.items(): - # Use fake data with multiple elements per sample, just to test - # correct handling. - detdata = DetectorData( - self.fakedets, - (n_samp, 4), - dtype=dt, - units=u.K, - ) - if dtname == "f32" or dtname == "f64": - detdata.flatdata[:] = rng.random( - size=(4 * n_samp * len(self.fakedets)), dtype=dt - ) - else: - rng_max = np.iinfo(np.int32).max // 2 - detdata.flatdata[:] = rng.integers( - 0, - rng_max, - size=(4 * n_samp * len(self.fakedets)), - dtype=dt, - ) - - # print( - # f"Uncompressed {comp_type}:{dtname} is {detdata.memory_use()} bytes" - # ) - comp_data = compress_detdata(detdata, {"type": comp_type}) - # print(f" Compressed {comp_type}:{dtname} is {len(comp_data[0])} bytes") - new_detdata = decompress_detdata( - comp_data[0], comp_data[1], comp_data[2] - ) - check = np.allclose(new_detdata[:], detdata[:], atol=1.0e-5) - if not check: - print(f"Orig: {detdata}") - print(f"New: {new_detdata}") - self.assertTrue(False) - del new_detdata - del detdata - - def create_data(self, pixel_per_process=1, single_group=False): - data = create_ground_data( - self.comm, - pixel_per_process=pixel_per_process, - single_group=single_group, - flagged_pixels=False, - sample_rate=10.0 * u.Hz, - hwp_rpm=1.0, - ) - - # Simple detector pointing - detpointing_azel = ops.PointingDetectorSimple( - boresight="boresight_azel", quats="quats_azel" - ) - - # Create a noise model from focalplane detector properties - default_model = ops.DefaultNoiseModel() - default_model.apply(data) - - # Make an elevation-dependent noise model - el_model = ops.ElevationNoise( - noise_model="noise_model", - out_model="el_weighted", - detector_pointing=detpointing_azel, - ) - el_model.apply(data) - - # Simulate noise and accumulate to signal - sim_noise = ops.SimNoise(noise_model=el_model.out_model) - sim_noise.apply(data) - - # Delete temporary object. 
- for ob in data.obs: - del ob.detdata["quats_azel"] - - return data - - def test_roundtrip_memory(self): - rank = 0 - if self.comm is not None: - rank = self.comm.rank - - data = self.create_data() - - comp_types = ["none", "gzip"] - if have_flac_support(): - comp_types.append("flac") - else: - print("FLAC disabled, skipping 'flac' compression type") - - for comp_type in comp_types: - # Extract compressed versions of signal and flags - for ob in data.obs: - for key in [defaults.det_data]: - # msg = f"{ob.name} uncompressed {comp_type}:{key} is " - # msg += f"{ob.detdata[key].memory_use()} bytes" - # print(msg) - comp_data = compress_detdata(ob.detdata[key], {"type": comp_type}) - # msg = f"{ob.name} compressed {comp_type}:{key} is " - # msg += f"{len(comp_data[0])} bytes" - # print(msg) - new_detdata = decompress_detdata( - comp_data[0], comp_data[1], comp_data[2] - ) - check = np.allclose(new_detdata[:], ob.detdata[key][:], atol=1.0e-5) - if not check: - print(f"Orig: {ob.detdata[key]}") - print(f"New: {new_detdata}") - self.assertTrue(False) - - if comp_type == "flac": - continue - - comp_data = compress_detdata( - ob.detdata[defaults.det_flags], {"type": comp_type} - ) - new_detdata = decompress_detdata( - comp_data[0], comp_data[1], comp_data[2] - ) - if new_detdata != ob.detdata[defaults.det_flags]: - print(f"Orig: {ob.detdata[defaults.det_flags]}") - print(f"New: {new_detdata}") - self.assertTrue(False) - - close_data(data) - - def test_precision_vs_quanta(self): - if not have_flac_support(): - print("FLAC disabled, skipping...") - return - - rank = 0 - if self.comm is not None: - rank = self.comm.rank - - data = self.create_data() - - # Compress data using either "quanta" or "precision" - for ob in data.obs: - key = defaults.det_data - rms_in = np.std(ob.detdata[key].data) - precision = 5 - quanta = rms_in / 10**precision - comp_data_precision = compress_detdata( - ob.detdata[key], {"type": "flac", "precision": precision} - ) - comp_data_quanta = compress_detdata( - ob.detdata[key], {"type": "flac", "quanta": quanta} - ) - new_detdata_float32 = ob.detdata[key].data.astype(np.float32) - new_detdata_precision = decompress_detdata(*comp_data_precision) - new_detdata_quanta = decompress_detdata(*comp_data_quanta) - - rms_float32 = np.std(ob.detdata[key].data - new_detdata_float32.data) - rms_precision = np.std( - new_detdata_float32.data - new_detdata_precision.data - ) - rms_quanta = np.std(new_detdata_float32.data - new_detdata_quanta.data) - - # print(f"RMS (in) = {rms_in}") - # print(f"RMS (float32) = {rms_float32} abs, {rms_float32 / rms_in} rel") - # print(f"RMS (precision) = {rms_precision} abs, {rms_precision / rms_in} rel") - # print(f"RMS (quanta) = {rms_quanta} abs, {rms_quanta / rms_in} rel") - - check = np.allclose( - new_detdata_precision[:], - new_detdata_quanta[:], - rtol=10 ** (-(precision - 1)), - atol=1.0e-5, - ) - if not check: - print(f"RMS (in) = {rms_in}") - print(f"RMS (precision) = {rms_precision}, RMS (quanta) = {rms_quanta}") - print(f"Precision = {precision}: {new_detdata_precision}") - print(f"Quanta = {quanta}: {new_detdata_quanta}") - self.assertTrue(False) - - if rms_in / rms_precision < 0.9 * 10**precision: - print( - f"RMS(in) / RMS(precision) = {rms_in / rms_precision} but precision = {precision}" - ) - self.assertTrue(False) - - close_data(data) - - def test_compression_in_place(self): - if not have_flac_support(): - print("FLAC disabled, skipping...") - return - - rank = 0 - if self.comm is not None: - rank = self.comm.rank - - testdir = 
os.path.join(self.outdir, "test_in_place") - if rank == 0: - os.makedirs(testdir) - - # We use a single process group in this test to avoid having the - # data shuffled around between saving and loading - - data = self.create_data(pixel_per_process=4, single_group=True) - save_hdf5 = SaveHDF5( - volume=testdir, - detdata_float32=True, - compress_detdata=True, - detdata_in_place=True, - compress_precision=3, - ) - save_hdf5.apply(data) - - compressed_data = Data(Comm(self.comm)) - load_hdf5 = LoadHDF5(volume=testdir) - load_hdf5.apply(compressed_data) - - for ob_in, ob_out in zip(data.obs, compressed_data.obs): - key = defaults.det_data - signal_in = ob_in.detdata[key].data - signal_out = ob_out.detdata[key].data - check = np.allclose(signal_in, signal_out, rtol=1e-6) - if not check: - print(f"signal in = {signal_in}") - print(f"signal out = {signal_out}") - self.assertTrue(False) - - close_data(data) - close_data(compressed_data) - - def test_roundtrip_hdf5(self): - rank = 0 - if self.comm is not None: - rank = self.comm.rank - - testdir = os.path.join(self.outdir, "test_hdf5") - if rank == 0: - os.makedirs(testdir) - nocompdir = os.path.join(self.outdir, "test_hdf5_nocomp") - if rank == 0: - os.makedirs(nocompdir) - - data = self.create_data() - - obfiles = list() - for obs in data.obs: - _ = save_hdf5( - obs, - nocompdir, - meta=None, - detdata=[ - defaults.det_data, - defaults.det_flags, - ], - shared=None, - intervals=None, - config=None, - times=defaults.times, - force_serial=False, - detdata_float32=True, - ) - if have_flac_support(): - dcomp = (defaults.det_data, {"type": "flac"}) - else: - print("FLAC disabled, default to detdata compression='none'") - dcomp = (defaults.det_data, {"type": "none"}) - obf = save_hdf5( - obs, - testdir, - meta=None, - detdata=[ - dcomp, - (defaults.det_flags, {"type": "gzip"}), - ], - shared=None, - intervals=None, - config=None, - times=defaults.times, - force_serial=False, - detdata_float32=False, - ) - obfiles.append(obf) - - if data.comm.comm_world is not None: - data.comm.comm_world.barrier() - - # Load the data and check - check_data = Data(comm=data.comm) - - for hfile in obfiles: - check_data.obs.append(load_hdf5(hfile, check_data.comm)) - - # Verify. The other unit tests will check general HDF5 I/O in the case without - # compression. Here we are testing the round trip of DetectorData objects. 
- for ob, orig in zip(check_data.obs, data.obs): - if ob.detdata[defaults.det_flags] != orig.detdata[defaults.det_flags]: - msg = f"---- Proc {data.comm.world_rank} flags not equal ---\n" - msg += f"{orig.detdata[defaults.det_flags]}\n" - msg += f"{ob.detdata[defaults.det_flags]}" - print(msg) - self.assertTrue(False) - if not np.allclose( - ob.detdata[defaults.det_data], - orig.detdata[defaults.det_data], - atol=1.0e-4, - rtol=1.0e-5, - ): - msg = f"---- Proc {data.comm.world_rank} signal not equal ---\n" - msg += f"{orig.detdata[defaults.det_data]}\n" - msg += f"{ob.detdata[defaults.det_data]}" - print(msg) - self.assertTrue(False) - - close_data(data) - - def test_hdf5_verify(self): - rank = 0 - if self.comm is not None: - rank = self.comm.rank - - testdir = os.path.join(self.outdir, "verify_hdf5") - if rank == 0: - os.makedirs(testdir) - - data = self.create_data() - - if have_flac_support(): - dcomp = (defaults.det_data, {"type": "flac"}) - else: - print("FLAC disabled, default to detdata compression='none'") - dcomp = (defaults.det_data, {"type": "none"}) - - saver = ops.SaveHDF5( - volume=testdir, - detdata=[ - dcomp, - (defaults.det_flags, {"type": "gzip"}), - ], - detdata_float32=True, - verify=True, - ) - saver.apply(data) - - close_data(data) diff --git a/src/toast/tests/io_hdf5.py b/src/toast/tests/io_hdf5.py index 12a66facb..1b9147a74 100644 --- a/src/toast/tests/io_hdf5.py +++ b/src/toast/tests/io_hdf5.py @@ -1,10 +1,9 @@ -# Copyright (c) 2015-2021 by the parties listed in the AUTHORS file. +# Copyright (c) 2015-2025 by the parties listed in the AUTHORS file. # All rights reserved. Use of this source code is governed by # a BSD-style license that can be found in the LICENSE file. import glob import os -import sys import numpy as np from astropy import units as u @@ -13,19 +12,111 @@ from ..config import build_config from ..data import Data from ..io import load_hdf5, save_hdf5 -from ..mpi import MPI -from ..observation_data import DetectorData +from ..ops.save_hdf5 import obs_approx_equal from ..weather import Weather from .helpers import close_data, create_ground_data, create_outdir from .mpi import MPITestCase +class ExtraMeta(object): + """Class to test Observation attribute save / load.""" + + def __init__(self): + self._data = np.random.normal(size=100) + + def save_hdf5(self, group, obs): + if group is not None: + hdata = group.create_dataset( + "ExtraMeta", self._data.shape, dtype=self._data.dtype + ) + if obs.comm.group_rank == 0: + hdata.write_direct(self._data, (slice(0, 100, 1),), (slice(0, 100, 1),)) + + def load_hdf5(self, group, obs): + if group is not None: + ds = group["ExtraMeta"] + if obs.comm.group_rank == 0: + self._data = np.empty(ds.shape, dtype=ds.dtype) + hslc = tuple([slice(0, x, 1) for x in ds.shape]) + ds.read_direct(self._data, hslc, hslc) + if obs.comm.comm_group is not None: + self._data = obs.comm.comm_group.bcast(self._data, root=0) + + def __eq__(self, other): + if np.allclose(self._data, other._data): + return True + else: + return False + + +def create_other_meta(): + """Helper function to generate python containers of metadata for testing""" + # Create nested containers of all types, with scalars and arrays + scalar = 1.234 + qscalar = 1.234 * u.second + arr = np.arange(10, dtype=np.float64) + qarr = arr * u.meter + + def _leaf_dict(): + return { + "scalar": scalar, + "qscalar": qscalar, + "arr": arr, + "qarr": qarr, + } + + def _leaf_list(): + return [ + scalar, + qscalar, + arr, + qarr, + ] + + def _leaf_tuple(): + return ( + scalar, + 
qscalar, + arr, + qarr, + ) + + def _node_dict(): + return { + "dict": _leaf_dict(), + "list": _leaf_list(), + "tuple": _leaf_tuple(), + } + + def _node_list(): + return [ + _leaf_dict(), + _leaf_list(), + _leaf_tuple(), + ] + + def _node_tuple(): + return ( + _leaf_dict(), + _leaf_list(), + _leaf_tuple(), + ) + + root = { + "top_dict": _node_dict(), + "top_list": _node_list(), + "top_tuple": _node_tuple(), + } + + return root + + class IoHdf5Test(MPITestCase): def setUp(self): fixture_name = os.path.splitext(os.path.basename(__file__))[0] self.outdir = create_outdir(self.comm, subdir=fixture_name) - def create_data(self, split=False, base_weather=False): + def create_data(self, split=False, base_weather=False, no_meta=False): # Create fake observing of a small patch. Use a multifrequency # focalplane so we can test split sessions. @@ -38,6 +129,13 @@ def create_data(self, split=False, base_weather=False): split=split, ) + # Add extra metadata attribute + if not no_meta: + for ob in data.obs: + ob.extra = ExtraMeta() + other = create_other_meta() + ob.update(other) + if base_weather: # Replace the simulated weather with the base class for testing for ob in data.obs: @@ -138,66 +236,14 @@ def test_save_load(self): # Verify for ob, orig in zip(check_data.obs, original): - if ob != orig: + if not obs_approx_equal(ob, orig): print( f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}" ) - self.assertTrue(ob == orig) + self.assertTrue(False) close_data(data) - def test_save_load_float32(self): - rank = 0 - if self.comm is not None: - rank = self.comm.rank - - datadir = os.path.join(self.outdir, "save_load_float32") - if rank == 0: - os.makedirs(datadir) - if self.comm is not None: - self.comm.barrier() - - data, config = self.create_data() - det_data_fields = ["signal", "flags", "alt_signal"] - - # Make a copy for later comparison. Convert float64 detdata to - # float32. 
- original = dict() - for ob in data.obs: - original[ob.name] = ob.duplicate(times="times") - for field, ddata in original[ob.name].detdata.items(): - if ddata.dtype.char == "d": - # Hack in a replacement - new_dd = DetectorData( - ddata.detectors, - ddata.detector_shape, - np.float32, - units=ddata.units, - ) - new_dd[:] = original[ob.name].detdata[field][:] - original[ob.name].detdata._internal[field] = new_dd - - saver = ops.SaveHDF5( - volume=datadir, detdata=det_data_fields, config=config, detdata_float32=True - ) - saver.apply(data) - - if data.comm.comm_world is not None: - data.comm.comm_world.barrier() - - check_data = Data(data.comm) - loader = ops.LoadHDF5(volume=datadir, detdata=det_data_fields) - loader.apply(check_data) - - # Verify - for ob in check_data.obs: - orig = original[ob.name] - if ob != orig: - print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") - self.assertTrue(ob == orig) - - close_data(data) - def test_save_load_ops(self): rank = 0 if self.comm is not None: @@ -232,9 +278,9 @@ def test_save_load_ops(self): # Verify for ob in check_data.obs: orig = original[ob.name] - if ob != orig: + if not obs_approx_equal(ob, orig): print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") - self.assertTrue(ob == orig) + self.assertTrue(False) del check_data # Also test loading explicit files @@ -245,22 +291,23 @@ def test_save_load_ops(self): for ob in check_data.obs: orig = original[ob.name] - if ob != orig: + if not obs_approx_equal(ob, orig): print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") - self.assertTrue(ob == orig) + self.assertTrue(False) del check_data # Also check loading by regex, in this case only one frequency check_data = Data(data.comm) + loader.files = [] loader.volume = datadir loader.pattern = r".*100\.0-GHz.*\.h5" loader.apply(check_data) for ob in check_data.obs: orig = original[ob.name] - if ob != orig: + if not obs_approx_equal(ob, orig): print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") - self.assertTrue(ob == orig) + self.assertTrue(False) del check_data close_data(data) @@ -302,26 +349,31 @@ def test_save_load_empty_detdata(self): for ob in check_data.obs: orig = original[ob.name] orig.detdata.clear() - if ob != orig: + if not obs_approx_equal(ob, orig): print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") - self.assertTrue(ob == orig) + self.assertTrue(False) del check_data close_data(data) - def test_save_load_ops_f32(self): + def test_save_load_ops_compression(self): rank = 0 if self.comm is not None: rank = self.comm.rank - datadir = os.path.join(self.outdir, "save_load_ops_f32") + datadir = os.path.join(self.outdir, "save_load_ops_compression") if rank == 0: os.makedirs(datadir) if self.comm is not None: self.comm.barrier() data, config = self.create_data(split=True) - det_data_fields = ["signal", "flags", "alt_signal"] + det_data_fields = [ + ("signal", {"quanta": 1.0e-7}), + ("flags", {}), + ("alt_signal", {"quanta": 1.0e-12}), + ] + det_data_names = ["signal", "flags", "alt_signal"] # Make a copy for later comparison. 
original = dict() @@ -329,15 +381,80 @@ def test_save_load_ops_f32(self): original[ob.name] = ob.duplicate(times="times") saver = ops.SaveHDF5( - volume=datadir, - detdata=det_data_fields, - config=config, - detdata_float32=True, - verify=True, + volume=datadir, detdata=det_data_fields, config=config, verify=True ) saver.apply(data) if data.comm.comm_world is not None: data.comm.comm_world.barrier() + check_data = Data(data.comm) + loader = ops.LoadHDF5(volume=datadir, detdata=det_data_names) + loader.apply(check_data) + + # Verify + for ob in check_data.obs: + orig = original[ob.name] + if not obs_approx_equal(ob, orig): + print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") + self.assertTrue(False) + del check_data + + close_data(data) + + def test_save_load_version1(self): + # Here we test loading an old version 1 format file (the original + # version 1 saving code is kept around for this purpose). + from ..io.observation_hdf_save_v1 import save_hdf5 as save_v1 + + rank = 0 + if self.comm is not None: + rank = self.comm.rank + + datadir = os.path.join(self.outdir, "save_load_v1") + if rank == 0: + os.makedirs(datadir) + if self.comm is not None: + self.comm.barrier() + + data, config = self.create_data(split=True, no_meta=True) + det_data_names = ["signal", "flags", "alt_signal"] + det_data_fields = [ + ("signal", {"type": "flac", "quanta": 1.0e-7}), + ("flags", {"type": "gzip"}), + ("alt_signal", {"type": "flac", "quanta": 1.0e-7}), + ] + + # Export the data, and make a copy for later comparison. + original = list() + obfiles = list() + for ob in data.obs: + original.append(ob.duplicate(times="times")) + obf = save_v1( + ob, + datadir, + detdata=det_data_fields, + config=config, + force_serial=False, + detdata_float32=False, + ) + obfiles.append(obf) + + if self.comm is not None: + self.comm.barrier() + + # Import the data + check_data = Data(comm=data.comm) + + for hfile in obfiles: + check_data.obs.append( + load_hdf5(hfile, check_data.comm, detdata=det_data_names) + ) + + # Verify + for ob, orig in zip(check_data.obs, original): + if not obs_approx_equal(ob, orig): + print(f"-------- Proc {data.comm.world_rank} ---------\n{orig}\n{ob}") + self.assertTrue(False) + close_data(data) diff --git a/src/toast/tests/runner.py b/src/toast/tests/runner.py index c29f0b6cc..642d6aa3d 100644 --- a/src/toast/tests/runner.py +++ b/src/toast/tests/runner.py @@ -23,7 +23,6 @@ from . import instrument as test_instrument from . import intervals as test_intervals from . import ops_flag_intervals as test_ops_flag_intervals -from . import io_compression as test_io_compression from . import io_hdf5 as test_io_hdf5 from . import math_misc as test_math_misc from . import noise as test_noise @@ -265,7 +264,6 @@ def test(name=None, verbosity=2): suite.addTest(loader.loadTestsFromModule(test_template_gain)) suite.addTest(loader.loadTestsFromModule(test_io_hdf5)) - suite.addTest(loader.loadTestsFromModule(test_io_compression)) suite.addTest(loader.loadTestsFromModule(test_accelerator))
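As a usage sketch of the per-field compression interface exercised by the tests above (the output directory and quanta values are illustrative, and `data` is assumed to be a populated `toast.Data` container):

    from toast import ops

    # Each detdata entry is either a bare field name (saved uncompressed) or a
    # (name, properties) tuple. Floating point fields must give "quanta" or
    # "precision"; integer fields may pass an empty dict to get the default
    # FLAC compression level (5).
    saver = ops.SaveHDF5(
        volume="out_hdf5",
        detdata=[
            ("signal", {"quanta": 1.0e-7}),
            ("flags", {}),
        ],
        verify=True,
    )
    saver.apply(data)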