diff --git a/tests/grtestutils/googletest/check_allclose.cpp b/tests/grtestutils/googletest/check_allclose.cpp index ae6ca702b..b0f9d5602 100644 --- a/tests/grtestutils/googletest/check_allclose.cpp +++ b/tests/grtestutils/googletest/check_allclose.cpp @@ -12,81 +12,294 @@ #include "./check_allclose.hpp" #include "../view.hpp" +#include "gtest/gtest.h" +#include "status_reporting.h" #include +#include #include +#include #include #include -testing::AssertionResult check_allclose(const std::vector& actual, - const std::vector& desired, - double rtol, double atol) { - if (actual.size() != desired.size()) { - return testing::AssertionFailure() - << "the compared arrays have different lengths"; +namespace grtest::arraycmp_detail { + +namespace { // stuff inside an anonymous namespace is local to this file + +/// calls the unary function once for each element +/// +/// if selected is provided, we skip locations where selected[i] is false +template +void flat_for_each(Fn fn, int n_elements, const bool* selected) { + if (selected == nullptr) { + for (int flat_idx = 0; flat_idx < n_elements; flat_idx++) { + fn(flat_idx); + } + } else { + for (int flat_idx = 0; flat_idx < n_elements; flat_idx++) { + if (selected[flat_idx]) { + fn(flat_idx); + } + } } +} + +/// Specifies a noteworthy detail about a mismatching pairs of pointers +/// +/// There are 2 kinds of details: +/// - a detail describing a particular pair of mismatching elements (e.g. the +/// very first mismatch, the place where the size of the mismatch is most +/// significant). In these cases, the flattened index is tracked so that +/// the location of the mismatch and the values of the elements can be +/// printed +/// - a generic string that doesn't have any single associated location +struct MismatchDetail { + std::string description; + std::optional flat_idx; + + // the following constructor is defined in order to make this work with + // std::vector::emplace_back. 
Delete it, once we require C++20 or newer +#if __cpp_aggregate_paren_init < 201902L + MismatchDetail(const std::string& description, + const std::optional& flat_idx) + : description(description), flat_idx(flat_idx) {} +#endif +}; + +/// collects interesting details about mismatched elements in a pair of pointers +/// +/// This is called by @ref compare_ptrs_, if we determine that the compared +/// pointers contain at least one pair of mismatching elements. This loops back +/// over all pairs of elements and collects a vector of noteworthy mismatches. +/// +/// @note +/// It's ok if this is a little slow, as long as it provides useful messages. +/// (After all, this logic only gets invoked when comparisons fail). +template +std::vector collect_details_( + const T* actual, const T* desired, const IdxMapping& idx_mapping, + const bool* selection_mask, Cmp cmp_fn) { + // define some variables that we will fill as we loop over the array + int first_mismatch_idx = -1; + + int first_nan_mismatch_idx = -1; + int nan_mismatch_count = 0; + bool any_nan = false; - std::size_t num_mismatches = 0; double max_absDiff = 0.0; - std::size_t max_absDiff_ind = 0; + int max_absDiff_idx = -1; + double max_relDiff = 0.0; - std::size_t max_relDiff_ind = 0; - bool has_nan_mismatch = false; - - for (std::size_t i = 0; i < actual.size(); i++) { - double cur_absDiff = std::fabs(actual[i] - desired[i]); - - bool isnan_actual = std::isnan(actual[i]); - bool isnan_desired = std::isnan(desired[i]); - - if ((cur_absDiff > (atol + rtol * std::fabs(desired[i]))) || - (isnan_actual != isnan_desired)) { - num_mismatches++; - if (isnan_actual != isnan_desired) { - has_nan_mismatch = true; - max_absDiff = NAN; - max_absDiff_ind = i; - max_relDiff = NAN; - max_relDiff_ind = i; - } else if (!has_nan_mismatch) { + int max_relDiff_idx = -1; + + auto fn = [&](int flat_idx) { + bool either_nan = + std::isnan(actual[flat_idx]) || std::isnan(desired[flat_idx]); + any_nan = any_nan || either_nan; // <- record 
whether we have seen a NaN + + // record properties if there is a mismatch + if (!cmp_fn(actual[flat_idx], desired[flat_idx])) { + if (first_mismatch_idx == -1) { + first_mismatch_idx = flat_idx; + } + nan_mismatch_count += either_nan; + if (either_nan && first_nan_mismatch_idx == -1) { + first_nan_mismatch_idx = flat_idx; + } else if (!either_nan) { + double cur_absDiff = std::fabs(actual[flat_idx] - desired[flat_idx]); if (cur_absDiff > max_absDiff) { max_absDiff = cur_absDiff; - max_absDiff_ind = i; + max_absDiff_idx = flat_idx; } - if (cur_absDiff > (max_relDiff * std::fabs(desired[i]))) { - max_relDiff = cur_absDiff / std::fabs(desired[i]); - max_relDiff_ind = i; + if (cur_absDiff > (max_relDiff * std::fabs(desired[flat_idx]))) { + max_relDiff = cur_absDiff / std::fabs(desired[flat_idx]); + max_relDiff_idx = flat_idx; } } } + }; + flat_for_each(fn, idx_mapping.n_elements(), selection_mask); + + // now, let's construct the vector of details + std::vector details; + + if (first_mismatch_idx == -1) { + return details; // <- this is probably indicative of an error + } else { + details.emplace_back("first mismatch", + std::optional{first_mismatch_idx}); + } + + if (max_absDiff_idx == -1) { + details.emplace_back("Max abs diff: NaN (i.e. each mismatch involves NaN)", + std::nullopt); + } else { + details.emplace_back("Max abs diff: " + to_pretty_string(max_absDiff), + std::optional{max_absDiff_idx}); + } + + if (max_relDiff_idx == -1) { + details.emplace_back( + "Max rel diff: NaN (i.e. each mismatch involves NaN or has actual=0.0)", + std::nullopt); + } else { + details.emplace_back("Max rel diff: " + to_pretty_string(max_relDiff), + std::optional{max_relDiff_idx}); + } + + if (first_nan_mismatch_idx == -1) { + details.emplace_back(any_nan ? 
"all NaNs match" : "there are no NaNs", + std::nullopt); + } else { + details.emplace_back( + "First (of " + std::to_string(nan_mismatch_count) + ") NaN mismatch", + first_nan_mismatch_idx); } - if (num_mismatches == 0) { + return details; +} + +/// Returns a `testing::AssertionResult` instance specifying whether all pairs +/// of values from @p actual and @p desired pointers satisfy the comparison +/// operation specified by @p cmp_fn +/// +/// @tparam T is either `float` or `double` +/// @tparam Layout specifies the data-layout +/// @tparam Cmp Function-like type that does the underlying comparison. See the +/// description of the @p cmp_fn function for more details +/// +/// @param actual,desired The pointers being compared +/// @param idx_mapping Specifies information for treating the pointers as +/// contiguous multi-dimensional arrays. It maps between multi-dimensional +/// indices & pointer 1d offsets, and specifies all relevant information +/// for this mapping (i.e. extents and data layout) +/// @param selection_mask When specified, only the locations holding `true` +/// values are compared +/// @param cmp_fn "Callable" object that implements a function signature +/// equivalent to `bool fun(T actual, T desired)`. This signature is called +/// by passing pairs of values from the @p actual and @p desired pointers. 
+/// This should implement a member function called `describe_false` that +/// returns a `std::string` +template +testing::AssertionResult compare_ptrs_(const T* actual, const T* desired, + const IdxMapping& idx_mapping, + const bool* selection_mask, Cmp cmp_fn) { + GR_INTERNAL_REQUIRE(actual != nullptr && desired != nullptr, + "it's illegal to compare nullptr"); + // Part 1: perform the comparison (this is as fast as possible) + const int n_elements = idx_mapping.n_elements(); + int mismatch_num = 0; + int n_comparisons = 0; + auto loop_callback = [=, &mismatch_num, &n_comparisons](int flat_idx) { + n_comparisons++; + mismatch_num += !cmp_fn(actual[flat_idx], desired[flat_idx]); + }; + flat_for_each(loop_callback, n_elements, selection_mask); + + if (mismatch_num == 0) { return testing::AssertionSuccess(); } - std::string actual_vec_str = - grtest::ptr_to_string(actual.data(), actual.size()); - std::string ref_vec_str = - grtest::ptr_to_string(desired.data(), desired.size()); - - using grtest::to_pretty_string; - - return testing::AssertionFailure() - << "\narrays are unequal for the tolerance: " - << "rtol = " << to_pretty_string(rtol) << ", " - << "atol = " << to_pretty_string(atol) << '\n' - << "Mismatched elements: " << num_mismatches << " / " << actual.size() - << '\n' - << "Max absolute difference: " << to_pretty_string(max_absDiff) << ", " - << "ind = " << max_absDiff_ind << ", " - << "actual = " << to_pretty_string(actual[max_absDiff_ind]) << ", " - << "reference = " << to_pretty_string(desired[max_absDiff_ind]) << '\n' - << "Max relative difference: " << to_pretty_string(max_relDiff) << ", " - << "ind = " << max_absDiff_ind << ", " - << "actual = " << to_pretty_string(actual[max_relDiff_ind]) << ", " - << "desired = " << to_pretty_string(desired[max_relDiff_ind]) << '\n' - << "actual: " << actual_vec_str << '\n' - << "desired: " << ref_vec_str << '\n'; -} \ No newline at end of file + // Part 2: build the failure result and construct the detailed error 
message + // -> it's ok if this isn't extremely optimized. This logic shouldn't come up + // very frequently + testing::AssertionResult out = testing::AssertionFailure(); + + out << '\n' + << "arrays are " << cmp_fn.describe_false() << '\n' + << "index mapping: " << testing::PrintToString(idx_mapping) << '\n'; + out << "Mismatched elements: " << mismatch_num << " (" << n_comparisons + << " were compared"; + if (n_comparisons != n_elements) { + out << ", " << n_elements - n_comparisons << " ignored from masking"; + } + out << ")\n"; + + std::vector detail_vec = + collect_details_(actual, desired, idx_mapping, selection_mask, cmp_fn); + if (detail_vec.empty()) { + GR_INTERNAL_ERROR("something went wrong with finding mismatch details"); + } + + // now let's append the interesting mismatch details + for (const MismatchDetail& detail : detail_vec) { + if (!detail.flat_idx.has_value()) { + out << detail.description << '\n'; + continue; + } + int flat_idx = detail.flat_idx.value(); + + out << detail.description << ", "; + // write the index + int idx_components[IdxMapping::MAX_RANK]; + idx_mapping.offset_to_md_idx(flat_idx, idx_components); + out << "idx: {"; + for (int i = 0; i < idx_mapping.rank(); i++) { + out << idx_components[i]; + out << ((i + 1) < idx_mapping.rank() ? ',' : '}'); + } + out << ", "; + + // write the actual and desired value + out << "actual = " << to_pretty_string(actual[flat_idx]) << ", " + << "desired = " << to_pretty_string(desired[flat_idx]) << '\n'; + } + + // print out final summary details + bool has_mask = selection_mask != nullptr; + out << "Flattened Ptr Details" + << (has_mask ? 
" (selection mask is ignored):\n" : ":\n") + << " actual: " << ptr_to_string(actual, idx_mapping) << '\n' + << " desired: " << ptr_to_string(desired, idx_mapping); + return out; +} + +} // anonymous namespace + +testing::AssertionResult compare_(CmpPack pack) { + // this function launches the appropriate specialization of compare_ptrs_ + // -> there are 3 template parameters to consider + // -> (see the docstring in the header file for a little more context) + + // load either (f32_actual, f32_desired) OR (f64_actual, f64_desired) + const float *f32_actual = nullptr, *f32_desired = nullptr; + const double *f64_actual = nullptr, *f64_desired = nullptr; + bool use_f32 = + std::holds_alternative>(pack.actual_desired_pair); + if (use_f32) { + f32_actual = std::get>(pack.actual_desired_pair).first; + f32_desired = std::get>(pack.actual_desired_pair).second; + } else { + f64_actual = std::get>(pack.actual_desired_pair).first; + f64_desired = std::get>(pack.actual_desired_pair).second; + } + + // Either idx_map_L OR idx_map_R will not be a nullptr + const IdxMapping* idx_map_L = + std::get_if>(&pack.idx_mapping); + const IdxMapping* idx_map_R = + std::get_if>(&pack.idx_mapping); + + // dispatcher_ is a "generic lambda" + // -> it acts as if the type of cmp_fn is a template parameter. 
+ // -> when we pass it to std::visit, the cmp_fn argument is a copy of the + // alternative currently held by the `CmpPack::cmp_fn` variant + auto dispatcher_ = [&](auto cmp_fn) -> testing::AssertionResult { + const bool* smask = pack.selection_mask; + if (use_f32 && idx_map_L != nullptr) { + return compare_ptrs_(f32_actual, f32_desired, *idx_map_L, smask, cmp_fn); + } else if (use_f32 && idx_map_R != nullptr) { + return compare_ptrs_(f32_actual, f32_desired, *idx_map_R, smask, cmp_fn); + } else if (idx_map_L != nullptr) { + return compare_ptrs_(f64_actual, f64_desired, *idx_map_L, smask, cmp_fn); + } else if (idx_map_R != nullptr) { + return compare_ptrs_(f64_actual, f64_desired, *idx_map_R, smask, cmp_fn); + } else { + GR_INTERNAL_ERROR("should be unreachable"); + } + }; + return std::visit(dispatcher_, pack.cmp_fn); +} + +} // namespace grtest::arraycmp_detail \ No newline at end of file diff --git a/tests/grtestutils/googletest/check_allclose.hpp b/tests/grtestutils/googletest/check_allclose.hpp index dedb5049a..6b2686f59 100644 --- a/tests/grtestutils/googletest/check_allclose.hpp +++ b/tests/grtestutils/googletest/check_allclose.hpp @@ -12,18 +12,71 @@ #ifndef GRTESTUTILS_GOOGLETEST_CHECK_ALLCLOSE_HPP #define GRTESTUTILS_GOOGLETEST_CHECK_ALLCLOSE_HPP -#include #include + #include -/// this compares 2 std::vectors +#include "../view.hpp" +#include "./check_allclose_detail.hpp" + +#define COMPARE_(cmp_fn, ptr_pair, selection_mask, idx_mapping) \ + ::grtest::arraycmp_detail::compare_(::grtest::arraycmp_detail::CmpPack{ \ + {cmp_fn}, {ptr_pair}, selection_mask, {idx_mapping}}) + +/// Returns whether 2 pointers are exactly equal +/// +/// This draws a lot of inspiration from numpy.testing.assert_array_equal +template +testing::AssertionResult check_array_equal( + const float* actual, const float* desired, + grtest::IdxMapping idx_mapping, + const bool* selection_mask = nullptr) { + grtest::arraycmp_detail::FltIsEqual cmp_fn; + grtest::arraycmp_detail::PtrPair 
ptr_pair{actual, desired}; + return COMPARE_(cmp_fn, ptr_pair, selection_mask, idx_mapping); +} + +/// Returns whether 2 pointers are exactly equal +/// +/// This draws a lot of inspiration from numpy.testing.assert_array_equal +template +testing::AssertionResult check_array_equal( + const double* actual, const double* desired, + grtest::IdxMapping idx_mapping, + const bool* selection_mask = nullptr) { + grtest::arraycmp_detail::FltIsEqual cmp_fn; + grtest::arraycmp_detail::PtrPair ptr_pair{actual, desired}; + return COMPARE_(cmp_fn, ptr_pair, selection_mask, idx_mapping); +} + +/// compares 2 pointers /// /// This draws a lot of inspiration from numpy.testing.assert_allclose +template +testing::AssertionResult check_allclose(const float* actual, + const float* desired, + grtest::IdxMapping idx_mapping, + double rtol, double atol = 0.0, + const bool* selection_mask = nullptr) { + grtest::arraycmp_detail::FltIsClose cmp_fn(rtol, atol); + grtest::arraycmp_detail::PtrPair ptr_pair{actual, desired}; + return COMPARE_(cmp_fn, ptr_pair, selection_mask, idx_mapping); +} + +/// compares 2 pointers /// -/// Parts of this are fairly inefficient, partially because it is adapted from -/// code written from before we adopted googletest -testing::AssertionResult check_allclose(const std::vector& actual, - const std::vector& desired, - double rtol = 0.0, double atol = 0.0); +/// This draws a lot of inspiration from numpy.testing.assert_allclose +template +testing::AssertionResult check_allclose(const double* actual, + const double* desired, + grtest::IdxMapping idx_mapping, + double rtol, double atol = 0.0, + const bool* selection_mask = nullptr) { + grtest::arraycmp_detail::FltIsClose cmp_fn(rtol, atol); + grtest::arraycmp_detail::PtrPair ptr_pair{actual, desired}; + return COMPARE_(cmp_fn, ptr_pair, selection_mask, idx_mapping); +} + +#undef COMPARE_ #endif // GRTESTUTILS_GOOGLETEST_CHECK_ALLCLOSE_HPP diff --git a/tests/grtestutils/googletest/check_allclose_detail.hpp 
b/tests/grtestutils/googletest/check_allclose_detail.hpp new file mode 100644 index 000000000..cdeff25b4 --- /dev/null +++ b/tests/grtestutils/googletest/check_allclose_detail.hpp @@ -0,0 +1,149 @@ +//===----------------------------------------------------------------------===// +// +// See the LICENSE file for license and copyright information +// SPDX-License-Identifier: NCSA AND BSD-3-Clause +// +//===----------------------------------------------------------------------===// +/// +/// @file +/// Declares/implements a bunch of helper code to assist with implementing +/// logic in check_allclose.hpp +/// +//===----------------------------------------------------------------------===// +#ifndef GRTESTUTILS_GOOGLETEST_CHECK_ALLCLOSE_DETAIL_HPP +#define GRTESTUTILS_GOOGLETEST_CHECK_ALLCLOSE_DETAIL_HPP + +#include +#include +#include +#include +#include + +#include "../view.hpp" +#include "gtest/gtest.h" + +namespace grtest::arraycmp_detail { + +/// the goal is to avoid short-circuiting +template +[[gnu::always_inline]] inline bool both_nan_(T actual, T desired) { + bool actual_isnan = std::isnan(actual); + bool desired_isnan = std::isnan(desired); + return actual_isnan & desired_isnan; // use & to avoid &&'s short-circuiting +} + +template +[[gnu::always_inline]] inline bool isclose_(T actual, T desired, double rtol, + double atol) { + static_assert(std::is_floating_point_v); + + T abs_diff = std::fabs(actual - desired); + // the following variable is false if actual or desired (or both) is NaN + bool isclose = abs_diff <= atol + rtol * std::fabs(desired); + + if constexpr (EqualNan) { + return both_nan_(actual, desired) || isclose; + } else { + return isclose; + } +} + +/// "functor" to check if floating point values are equal within a tolerance +/// +/// This effectively implements numpy's isclose function (see +/// https://numpy.org/doc/stable/reference/generated/numpy.isclose.html). 
As in +/// the original function, the max allowed variations from the relative +/// difference tolerance and the absolute difference tolerance are summed and +/// compared against the absolute difference. +/// +/// @note +/// For less experienced C++ developers, `operator()` overloads the "function +/// call operation" (it is analogous to python's `__call__` method) +class FltIsClose { + double rtol; + double atol; + +public: + FltIsClose() = delete; + FltIsClose(double rtol, double atol) : rtol{rtol}, atol{atol} {} + + /// determines whether arguments are equal within the tolerance + bool operator()(float actual, float desired) const noexcept { + return isclose_(actual, desired, this->rtol, this->atol); + } + + /// determines whether arguments are equal within the tolerance + bool operator()(double actual, double desired) const noexcept { + return isclose_(actual, desired, this->rtol, this->atol); + } + + /// describe relationship between values for which this functor returns false + std::string describe_false() const { + std::string rtol_str = to_pretty_string(rtol); + std::string atol_str = to_pretty_string(atol); + return ("unequal for the tolerance (rtol = " + rtol_str + + ", atol = " + atol_str + ")"); + } +}; + +/// "functor" to check if floating point values are exactly equal +/// +/// @note +/// For less experienced C++ developers, `operator()` overloads the "function +/// call operation" (it is analogous to python's `__call__` method) +struct FltIsEqual { + /// determines whether arguments are exactly equal + bool operator()(float actual, float desired) const noexcept { + return (actual == desired) || both_nan_(actual, desired); + } + + /// determines whether arguments are exactly equal + bool operator()(double actual, double desired) const noexcept { + return (actual == desired) || both_nan_(actual, desired); + } + + /// describe relationship between values for which this functor returns false + std::string describe_false() const { return "not exactly 
equal"; } +}; + +template +using PtrPair = std::pair; + +/// Packages up the information for a comparison of 2 pointers +/// +/// See the docstring of @ref compare_ for an extended discussion for why +/// this type actually exists. +struct CmpPack { + std::variant cmp_fn; + std::variant, PtrPair> actual_desired_pair; + const bool* selection_mask; + std::variant, IdxMapping> + idx_mapping; +}; + +/// this dispatches the appropriate logic to drive the comparison +/// +/// The most pragmatic approach for implementing the underlying comparisons in +/// an extendable manner (without extensive code duplication or sacrificing +/// performance) is to implement them using templates and to make the datatype +/// a template parameter. +/// +/// This function was designed in a misguided attempt to shift most of the +/// implementation into source files in order to reduce compile times. In order +/// to hide all calls to a set of templates into a source file, this must be +/// a totally ordinary function that dispatches to the proper templates: +/// - thus, @ref CmpPack as a well-defined type to package up all of the +/// possible type combinations. This is achieved through the use of +/// std::variant (i.e. type-safe unions). +/// - the idea is that callers package up `CmpPack`, call this function, and +/// then function unpacks the values from `CmpPack` and dispatches to the +/// appropriate template function. +/// +/// With the benefit of hindsight, this was probably all a mistake... Reducing +/// the compilation cost was probably **NOT** worth the added complexity (if +/// nothing else, we probably should have measured it first...) 
+testing::AssertionResult compare_(CmpPack pack); + +} // namespace grtest::arraycmp_detail + +#endif // GRTESTUTILS_GOOGLETEST_CHECK_ALLCLOSE_DETAIL_HPP \ No newline at end of file diff --git a/tests/grtestutils/view.hpp b/tests/grtestutils/view.hpp index 88d03ec12..102cee1fc 100644 --- a/tests/grtestutils/view.hpp +++ b/tests/grtestutils/view.hpp @@ -16,11 +16,275 @@ #ifndef GRTESTUTILS_VIEW_HPP #define GRTESTUTILS_VIEW_HPP +#include #include // std::size_t +#include +#include #include +#include // std::pair + +#include "status_reporting.h" namespace grtest { +/// To be used with @ref IdxMapping +enum struct DataLayout { + LEFT, ///< the leftmost dimension has a stride 1 + RIGHT ///< the rightmost dimension has a stride 1 +}; + +/// Maps multi-dimensional indices to a 1D pointer offset +/// +/// Broader Context +/// =============== +/// To best describe this type, it's insightful to draw comparisons with C++ +/// conventions. +/// +/// Background +/// ---------- +/// For some background, C++23 introduced `std::mdspan` to describe +/// multi-dimensional views. A `std::mdspan` is parameterized by +/// - the data's extents (aka the shape) +/// - the data's layout, which dictates how a multidimensional index is mapped +/// to a 1D pointer offset +/// +/// For views of contiguous data there are 2 obvious layouts: +/// 1. layout-right: where the stride is `1` along the rightmost extent. +/// - for extents `{a,b,c}`, an optimal nested for-loop will iterates from +/// `0` up to `a` in the outermost loop and from `0` up to `c` +/// in the innermost loop +/// - this is the "natural layout" for a multidimensional c-style array +/// `arr[a][b][c]` +/// 2. 
layout-left: where the stride is `1` along the leftmost extent +/// - for extents `{a,b,c}`, an optimal nested for-loop will iterate from +/// `0` up to `c` in the outermost loop and from `0` up to `a` +/// in the innermost loop +/// - this is the "natural layout" for a multidimensional fortran array +/// `arr(a,b,c)` +/// +/// > Aside: More sophisticated layouts are possible when strides along each +/// > axis aren't directly tied to extents (this comes up when making subviews). +/// +/// About this type +/// --------------- +/// This type specifies the data layout, extents, and provides the mapping. You +/// can draw analogies with C++23 types: +/// - `IdxMapping<DataLayout::LEFT>` <---> `std::layout_left::mapping` +/// - `IdxMapping<DataLayout::RIGHT>` <---> `std::layout_right::mapping` +/// +/// At the time of writing, the type **only** represents contiguous data layouts +/// +/// @note +/// The template specialization using @ref DataLayout::RIGHT mostly exists for +/// exposition purposes +/// - it's useful to talk about this scenario since it is the "natural" layout +/// for C and C++ +/// - in practice, Grackle was written assuming @ref DataLayout::LEFT. 
Thus, +/// most loops will initially get written assuming that layout (but we always +/// have the option to conditionally provide the better kind of loop) +template +struct IdxMapping { + static_assert(Layout == DataLayout::LEFT || Layout == DataLayout::RIGHT, + "A new layout type was introduced that we don't yet support"); + static constexpr DataLayout layout = Layout; + + /// max rank value + static constexpr int MAX_RANK = 3; + +private: + // attributes: + int rank_; ///< the number of dimensions + int extents_[MAX_RANK]; ///< dimensions of the multi-dimensional index-space + + // private functions: + + // default constructor is only invoked by factory methods + // -> this explicitly set each extent_ to a value of 0 + IdxMapping() : rank_(0), extents_{} {} + + // the object returned by the create_ factory method: + // - `out.second == nullptr` indicates the method succeeded + // - otherwise, `out.second` is a string-literal specifying an error message + using InnerMappingMsgPair_ = std::pair, const char*>; + + static InnerMappingMsgPair_ create_(int rank, const int* extents) noexcept { + // arg checking + if ((rank < 1) || (rank > MAX_RANK)) { + return {IdxMapping(), "rank is invalid"}; + } else if (extents == nullptr) { + return {IdxMapping(), "extents is a nullptr"}; + } + for (int i = 0; i < rank; i++) { + if (extents[i] < 1) { + return {IdxMapping(), "extents must hold positive vals"}; + } + } + + // build and return the mapping + IdxMapping mapping; + mapping.rank_ = rank; + for (int i = 0; i < rank; i++) { + mapping.extents_[i] = extents[i]; + } + return {mapping, nullptr}; + } + + /// factory method that aborts with an error message upon failure + static IdxMapping create_or_abort_(int rank, const int* extents) { + InnerMappingMsgPair_ pair = IdxMapping::create_(rank, extents); + if (pair.second != nullptr) { + GR_INTERNAL_ERROR("%s", pair.second); + } + return pair.first; + } + +public: + /// factory method that returns an empty optional upon failure + 
static std::optional> try_create(int rank, + const int* extents) { + InnerMappingMsgPair_ pair = IdxMapping::create_(rank, extents); + return (pair.second == nullptr) + ? std::optional>{pair.first} + : std::optional>{}; + } + + explicit IdxMapping(int extent0) { + // for less experienced C++ devs: the `explicit` kwarg prevents the use of + // this constructor for implicit casts + int extents[1] = {extent0}; + *this = create_or_abort_(1, extents); + } + + IdxMapping(int extent0, int extent1) { + int extents[2] = {extent0, extent1}; + *this = create_or_abort_(2, extents); + } + + IdxMapping(int extent0, int extent1, int extent2) { + int extents[3] = {extent0, extent1, extent2}; + *this = create_or_abort_(3, extents); + } + + IdxMapping(const IdxMapping&) = default; + IdxMapping(IdxMapping&&) = default; + IdxMapping& operator=(const IdxMapping&) = default; + IdxMapping& operator=(IdxMapping&&) = default; + ~IdxMapping() = default; + + /// access the rank + int rank() const noexcept { return rank_; } + + /// access the extents pointer (the length is given by @ref rank) + const int* extents() const noexcept { return extents_; } + + /// construct an equivalent 3d IdxMapping + IdxMapping to_3d_mapping() const noexcept { + int out_rank = 3; + int rank_diff = out_rank - this->rank_; + GR_INTERNAL_REQUIRE(rank_diff >= 0, "current rank exceeds new rank"); + + IdxMapping out; + out.rank_ = out_rank; + for (int i = 0; i < out_rank; i++) { + out.extents_[i] = 1; + } + int offset = (Layout == DataLayout::RIGHT) ? 
rank_diff : 0; + for (int i = out; i < this->rank_; i++) { + out.extents_[i + offset] = this->extents_[i]; + } + return out; + } + + static constexpr bool is_contiguous() { return true; } + + int n_elements() const noexcept { + int product = 1; + for (int i = 0; i < this->rank_; i++) { + product *= this->extents_[i]; + } + return product; + } + + /** @{ */ // <- open the group of member functions with a shared docstring + /// compute the 1D pointer offset associated with the multidimensional index + /// + /// @note + /// For less experienced C++ devs, these methods overloads the "function call + /// operator". In python they would be named `__call__(self, ...)` + /// + /// Behavior is undefined if the number of arguments doesn't match the value + /// returned by `this->rank()` + [[gnu::always_inline]] int operator()(int i) const noexcept { + assert(this->rank_ == 1); + return i; + } + + [[gnu::always_inline]] int operator()(int i, int j) const noexcept { + assert(this->rank_ == 2); + if constexpr (Layout == DataLayout::LEFT) { + return i + this->extents_[0] * j; + } else { // Layout == DataLayout::RIGHT + return j + this->extents_[1] * i; + } + } + + [[gnu::always_inline]] int operator()(int i, int j, int k) const noexcept { + assert(this->rank_ == 3); + if constexpr (Layout == DataLayout::LEFT) { + return i + this->extents_[0] * (j + k * this->extents_[1]); + } else { // Layout == DataLayout::RIGHT + return k + this->extents_[2] * (j + i * this->extents_[1]); + } + } + /** @} */ // <- close the group of member functions with a shared docstring + + /// Convert a 1D pointer offset to a multidimensional index + /// + /// The number of components for the output index is given by calling the + /// @ref rank method. Behavior is undefined if @p out doesn't have enough + /// space. 
+ /// + /// @param[in] offset The 1D pointer offset (must be non-negative) + /// @param[out] out Buffer where result is stored + void offset_to_md_idx(int offset, int* out) const noexcept { + assert(out != nullptr); + + const int contig_ax = (Layout == DataLayout::LEFT) ? 0 : (this->rank_ - 1); + const int slowest_ax = (Layout == DataLayout::LEFT) ? (this->rank_ - 1) : 0; + + switch (this->rank_) { + case 1: + out[0] = offset; + return; + case 2: + out[slowest_ax] = offset / this->extents_[contig_ax]; + out[contig_ax] = offset % this->extents_[contig_ax]; + return; + case 3: { + int largest_stride = this->extents_[contig_ax] * this->extents_[1]; + out[slowest_ax] = offset / largest_stride; + int remainder = offset % largest_stride; + out[1] = remainder / this->extents_[contig_ax]; + out[contig_ax] = remainder % this->extents_[contig_ax]; + return; + } + default: + GR_INTERNAL_ERROR("should be unreachable"); + } + } + + // teach googletest how to print this type + friend void PrintTo(const IdxMapping& mapping, std::ostream* os) { + const char* layout = + (Layout == DataLayout::LEFT) ? "DataLayout::LEFT" : "DataLayout::RIGHT"; + *os << "IdxMapping<" << layout << ">("; + for (int i = 0; i < mapping.rank_; i++) { + *os << mapping.extents_[i]; + *os << ((i + 1) < mapping.rank_ ? 
',' : ')'); + } + } +}; + /// equivalent of converting output of printf("%g", val) to std::string std::string to_pretty_string(float val); std::string to_pretty_string(double val); @@ -29,6 +293,20 @@ std::string to_pretty_string(double val); std::string ptr_to_string(const float* ptr, std::size_t len); std::string ptr_to_string(const double* ptr, std::size_t len); +/// formats a pointer as a string (current implementation prints a 1D array) +template +std::string ptr_to_string(const float* ptr, IdxMapping idx_mapping) { + static_assert(IdxMapping::is_contiguous()); + return ptr_to_string(ptr, idx_mapping.n_elements()); +} + +/// formats a pointer as a string (current implementation prints a 1D array) +template +std::string ptr_to_string(const double* ptr, IdxMapping idx_mapping) { + static_assert(IdxMapping::is_contiguous()); + return ptr_to_string(ptr, idx_mapping.n_elements()); +} + } // namespace grtest #endif // GRTESTUTILS_VIEW_HPP \ No newline at end of file diff --git a/tests/unit/test_linalg.cpp b/tests/unit/test_linalg.cpp index 234e5dade..cdb31c841 100644 --- a/tests/unit/test_linalg.cpp +++ b/tests/unit/test_linalg.cpp @@ -6,7 +6,6 @@ #include "grtestutils/googletest/check_allclose.hpp" #include "grtestutils/view.hpp" - /// Records the paramters for a linear algebra test-case struct LinAlgCase { // attributes @@ -72,7 +71,10 @@ TEST_P(LinAlgSolve, Check) { ASSERT_EQ(rslt, 0) << "expected a return-code of 0, which indicates " << "that the linear equations were successfully solved"; - EXPECT_TRUE(check_allclose(/* actual: */ vec, my_case.solution_vector, + grtest::IdxMapping idx_mapping(vec.size()); + EXPECT_TRUE(check_allclose(/* actual: */ vec.data(), + /* desired: */ my_case.solution_vector.data(), + /* idx_mapping: */ idx_mapping, /* rtol: */1e-15, /* atol: */ 0.0)); }