From 4d2703a390fa5e64d443d7fb4494f74ff9fb75db Mon Sep 17 00:00:00 2001 From: Stef Sijben Date: Wed, 22 Nov 2023 14:03:53 +0100 Subject: [PATCH 1/3] Add more `xaxis_*_iterator` test cases Test not only `xarray` inputs, but also `xtensor` and `xtensor_fixed`. These are currently failing. Also test some more cases with `column_major` inputs. --- test/test_xaxis_iterator.cpp | 49 +++++++++++++-------- test/test_xaxis_slice_iterator.cpp | 68 ++++++++++++++++-------------- 2 files changed, 67 insertions(+), 50 deletions(-) diff --git a/test/test_xaxis_iterator.cpp b/test/test_xaxis_iterator.cpp index 1cc764d7d..e11dd3278 100644 --- a/test/test_xaxis_iterator.cpp +++ b/test/test_xaxis_iterator.cpp @@ -8,25 +8,38 @@ ****************************************************************************/ #include "xtensor/xarray.hpp" +#include "xtensor/xtensor.hpp" +#include "xtensor/xfixed.hpp" #include "xtensor/xaxis_iterator.hpp" #include "test_common_macros.hpp" +#define ROW_TYPES \ + xarray, \ + xtensor, \ + xtensor_fixed, layout_type::row_major> +#define COL_TYPES \ + xarray, \ + xtensor, \ + xtensor_fixed, layout_type::column_major> +#define ALL_TYPES ROW_TYPES, COL_TYPES + namespace xt { using std::size_t; - xarray get_test_array() + template> + T get_test_array() { - xarray res = { + T res = { {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}; return res; } - TEST(xaxis_iterator, begin) + TEST_CASE_TEMPLATE("xaxis_iterator.begin", T, ALL_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter_begin = axis_begin(a); EXPECT_EQ(size_t(2), iter_begin->dimension()); EXPECT_EQ(a.shape()[1], iter_begin->shape()[0]); @@ -37,9 +50,9 @@ namespace xt EXPECT_EQ(a(0, 2, 3), (*iter_begin)(2, 3)); } - TEST(xaxis_iterator, increment) + TEST_CASE_TEMPLATE("xaxis_iterator.increment", T, ROW_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter = axis_begin(a); ++iter; @@ -52,9 +65,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(2, 3)); } - TEST(xaxis_iterator, end) + TEST_CASE_TEMPLATE("xaxis_iterator.end", T, ALL_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter_begin = axis_begin(a, 1u); auto iter_end = axis_end(a, 1u); auto dist = std::distance(iter_begin, iter_end); @@ -80,9 +93,9 @@ namespace xt EXPECT_EQ(iter_begin_row, iter_end_row); } - TEST(xaxis_iterator, nested) + TEST_CASE_TEMPLATE("xaxis_iterator.nested", T, ROW_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter = axis_begin(a); ++iter; auto niter = axis_begin(*iter); @@ -95,9 +108,9 @@ namespace xt EXPECT_EQ(a(1, 1, 3), (*niter)(3)); } - TEST(xaxis_iterator, const_array) + TEST_CASE_TEMPLATE("xaxis_iterator.const_array", T, ROW_TYPES) { - const xarray a = get_test_array(); + const T a = get_test_array(); auto iter = axis_begin(a); ++iter; @@ -110,9 +123,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(2, 3)); } - TEST(xaxis_iterator, axis_0) + TEST_CASE_TEMPLATE("xaxis_iterator.axis_0", T, ROW_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter = axis_begin(a, 0); EXPECT_EQ(a(0, 0, 0), (*iter)(0, 0)); @@ -142,9 +155,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(2, 3)); } - TEST(xaxis_iterator, axis_1) + TEST_CASE_TEMPLATE("xaxis_iterator.axis_1", T, ROW_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter = axis_begin(a, 1u); EXPECT_EQ(a(0, 0, 0), (*iter)(0, 0)); @@ -175,9 +188,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(1, 3)); } - TEST(xaxis_iterator, axis_2) + TEST_CASE_TEMPLATE("xaxis_iterator.axis_2", T, ROW_TYPES) { - xarray a = get_test_array(); + T a = get_test_array(); auto iter = axis_begin(a, 2u); EXPECT_EQ(a(0, 0, 0), (*iter)(0, 0)); diff --git a/test/test_xaxis_slice_iterator.cpp b/test/test_xaxis_slice_iterator.cpp index 242822c50..2494899ed 100644 --- a/test/test_xaxis_slice_iterator.cpp +++ b/test/test_xaxis_slice_iterator.cpp @@ -8,25 +8,39 @@ ****************************************************************************/ #include "xtensor/xarray.hpp" +#include "xtensor/xfixed.hpp" +#include "xtensor/xtensor.hpp" #include "xtensor/xaxis_slice_iterator.hpp" #include "test_common_macros.hpp" +#define ROW_TYPES \ + xarray, \ + xtensor, \ + xtensor_fixed, layout_type::row_major> +#define COL_TYPES \ + xarray, \ + xtensor, \ + xtensor_fixed, layout_type::column_major> +#define ALL_TYPES ROW_TYPES, COL_TYPES + namespace xt { using std::size_t; + constexpr auto _col = layout_type::column_major; - xarray get_slice_test_array() + template> + T get_slice_test_array() { - xarray res = { + T res = { {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, {{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}; return res; } - TEST(xaxis_slice_iterator, begin) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.begin", T, ALL_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter_begin = axis_slice_begin(a, 0); EXPECT_EQ(size_t(1), iter_begin->dimension()); EXPECT_EQ(a.shape()[0], iter_begin->shape()[0]); @@ -34,33 +48,23 @@ namespace xt EXPECT_EQ(a(1, 0, 0), (*iter_begin)(1)); } - TEST(xaxis_slice_iterator, end) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.end", T, ALL_TYPES) { - xarray a = get_slice_test_array(); - xarray a_col = get_slice_test_array(); + T a = get_slice_test_array(); auto dist = std::distance(axis_slice_begin(a, 0), axis_slice_end(a, 0)); EXPECT_EQ(12, dist); - dist = std::distance(axis_slice_begin(a_col), axis_slice_end(a_col)); - EXPECT_EQ(12, dist); - dist = std::distance(axis_slice_begin(a, 1), axis_slice_end(a, 1)); EXPECT_EQ(8, dist); - dist = std::distance(axis_slice_begin(a_col, 1), axis_slice_end(a_col, 1)); - EXPECT_EQ(8, dist); - dist = std::distance(axis_slice_begin(a, 2), axis_slice_end(a, 2)); EXPECT_EQ(6, dist); - - dist = std::distance(axis_slice_begin(a_col, 2), axis_slice_end(a_col, 2)); - EXPECT_EQ(6, dist); } - TEST(xaxis_slice_iterator, increment) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.increment", T, ROW_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, 0); ++iter; @@ -71,9 +75,9 @@ namespace xt EXPECT_EQ(a(1, 0, 1), (*iter)(1)); } - TEST(xaxis_slice_iterator, const_array) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.const_array", T, ROW_TYPES) { - const xarray a = get_slice_test_array(); + const T a = get_slice_test_array(); auto iter = axis_slice_begin(a, 2); ++iter; @@ -86,9 +90,9 @@ namespace xt EXPECT_EQ(a(0, 1, 3), (*iter)(3)); } - TEST(xaxis_slice_iterator, axis_0) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_0", T, ROW_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(0)); EXPECT_EQ(a(0, 0, 0), (*iter)(0)); @@ -128,9 +132,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(1)); } - TEST(xaxis_slice_iterator, axis_0_col) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_0_col", T, COL_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(0)); EXPECT_EQ(a(0, 0, 0), (*iter)(0)); @@ -170,9 +174,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(1)); } - TEST(xaxis_slice_iterator, axis_1) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_1", T, ROW_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(1)); EXPECT_EQ(a(0, 0, 0), (*iter)(0)); @@ -208,9 +212,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(2)); } - TEST(xaxis_slice_iterator, axis_1_col) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_1_col", T, COL_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(1)); EXPECT_EQ(a(0, 0, 0), (*iter)(0)); @@ -246,9 +250,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(2)); } - TEST(xaxis_slice_iterator, axis_2) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_2", T, ROW_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(2)); EXPECT_EQ(a(0, 0, 0), (*iter)(0)); @@ -282,9 +286,9 @@ namespace xt EXPECT_EQ(a(1, 2, 3), (*iter)(3)); } - TEST(xaxis_slice_iterator, axis_2_col) + TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_2_col", T, COL_TYPES) { - xarray a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(2)); EXPECT_EQ(a(0, 0, 0), (*iter)(0)); From bb95ed0d522c7bc0e8c6d00b5328005f1dac397b Mon Sep 17 00:00:00 2001 From: Stef Sijben Date: Wed, 22 Nov 2023 14:03:53 +0100 Subject: [PATCH 2/3] Fix `xaxis_*_iterator` shape/stride (#2116) An `xaxis_slice_iterator` always refers to a 1d view, so just use an array of size 1 for the shape and stride. Another type would probably be more optimal in case of compile-time fixed size (e.g. `xtensor_fixed`), but at least this is correct. Always use runtime dimensionality `xaxis_iterator` shape and strides. Other types would probably be more optimal in case of compile-time fixed dimension and/or size (e.g. `xtensor`, `xtensor_fixed`), but at least this is correct. --- include/xtensor/xaxis_iterator.hpp | 6 +++--- include/xtensor/xaxis_slice_iterator.hpp | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/xtensor/xaxis_iterator.hpp b/include/xtensor/xaxis_iterator.hpp index 953206aaf..ad8f16fc6 100644 --- a/include/xtensor/xaxis_iterator.hpp +++ b/include/xtensor/xaxis_iterator.hpp @@ -39,7 +39,7 @@ namespace xt using xexpression_type = std::decay_t; using size_type = typename xexpression_type::size_type; using difference_type = typename xexpression_type::difference_type; - using shape_type = typename xexpression_type::shape_type; + using shape_type = std::vector; using value_type = xstrided_view; using reference = std::remove_reference_t>; using pointer = xtl::xclosure_pointer>>; @@ -106,8 +106,8 @@ namespace xt ) { using xexpression_type = std::decay_t; - using shape_type = typename xexpression_type::shape_type; - using strides_type = typename xexpression_type::strides_type; + using shape_type = std::vector; + using strides_type = std::vector; const auto& e_shape = e.shape(); shape_type shape(e_shape.size() - 1); diff --git a/include/xtensor/xaxis_slice_iterator.hpp b/include/xtensor/xaxis_slice_iterator.hpp index 8a9814df6..7354e83a7 100644 --- a/include/xtensor/xaxis_slice_iterator.hpp +++ b/include/xtensor/xaxis_slice_iterator.hpp @@ -34,8 +34,8 @@ namespace xt using xexpression_type = std::decay_t; using size_type = typename xexpression_type::size_type; using difference_type = typename xexpression_type::difference_type; - using shape_type = typename xexpression_type::shape_type; - using strides_type = typename xexpression_type::strides_type; + using shape_type = std::array; + using strides_type = std::array; using value_type = xstrided_view; using reference = std::remove_reference_t>; using pointer = xtl::xclosure_pointer>>; From ab8326dfb18c8d8f068c20b649e7fbea6a0bd71f Mon Sep 17 00:00:00 2001 From: Stef Sijben Date: Thu, 23 Nov 2023 08:55:57 +0100 Subject: [PATCH 3/3] Fix clang-format errors. --- test/test_xaxis_iterator.cpp | 20 +++++++++----------- test/test_xaxis_slice_iterator.cpp | 20 +++++++++----------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/test/test_xaxis_iterator.cpp b/test/test_xaxis_iterator.cpp index e11dd3278..19633bf10 100644 --- a/test/test_xaxis_iterator.cpp +++ b/test/test_xaxis_iterator.cpp @@ -8,27 +8,25 @@ ****************************************************************************/ #include "xtensor/xarray.hpp" -#include "xtensor/xtensor.hpp" -#include "xtensor/xfixed.hpp" #include "xtensor/xaxis_iterator.hpp" +#include "xtensor/xfixed.hpp" +#include "xtensor/xtensor.hpp" #include "test_common_macros.hpp" -#define ROW_TYPES \ - xarray, \ - xtensor, \ - xtensor_fixed, layout_type::row_major> -#define COL_TYPES \ - xarray, \ - xtensor, \ - xtensor_fixed, layout_type::column_major> +#define ROW_TYPES \ + xarray, xtensor, \ + xtensor_fixed, layout_type::row_major> +#define COL_TYPES \ + xarray, xtensor, \ + xtensor_fixed, layout_type::column_major> #define ALL_TYPES ROW_TYPES, COL_TYPES namespace xt { using std::size_t; - template> + template > T get_test_array() { T res = { diff --git a/test/test_xaxis_slice_iterator.cpp b/test/test_xaxis_slice_iterator.cpp index 2494899ed..b165b2cf9 100644 --- a/test/test_xaxis_slice_iterator.cpp +++ b/test/test_xaxis_slice_iterator.cpp @@ -8,20 +8,18 @@ ****************************************************************************/ #include "xtensor/xarray.hpp" +#include "xtensor/xaxis_slice_iterator.hpp" #include "xtensor/xfixed.hpp" #include "xtensor/xtensor.hpp" -#include "xtensor/xaxis_slice_iterator.hpp" #include "test_common_macros.hpp" -#define ROW_TYPES \ - xarray, \ - xtensor, \ - xtensor_fixed, layout_type::row_major> -#define COL_TYPES \ - xarray, \ - xtensor, \ - xtensor_fixed, layout_type::column_major> +#define ROW_TYPES \ + xarray, xtensor, \ + xtensor_fixed, layout_type::row_major> +#define COL_TYPES \ + xarray, xtensor, \ + xtensor_fixed, layout_type::column_major> #define ALL_TYPES ROW_TYPES, COL_TYPES namespace xt @@ -29,7 +27,7 @@ namespace xt using std::size_t; constexpr auto _col = layout_type::column_major; - template> + template > T get_slice_test_array() { T res = { @@ -134,7 +132,7 @@ namespace xt TEST_CASE_TEMPLATE("xaxis_slice_iterator.axis_0_col", T, COL_TYPES) { - T a = get_slice_test_array(); + T a = get_slice_test_array(); auto iter = axis_slice_begin(a, size_t(0)); EXPECT_EQ(a(0, 0, 0), (*iter)(0));