Skip to content

Commit f4bce2e

Browse files
committed
Update and fix subsetting logic
- Update regex to allow negative indices, e.g. [-5] - Fix logic with -1 used as the default for 'all' when it is omitted from the parsed slice. - Change the constructor of SubsetInfo to use optionals - Update the types of m_start, m_stop and m_dim_size to be unsigned when the object is created. Wrapping is handled in the constructor - Fix some tests in parse_slices_test.cpp - And finally, add a boatload of tests for subsetting
1 parent ed61209 commit f4bce2e

File tree

6 files changed

+194
-30
lines changed

6 files changed

+194
-30
lines changed

src/utils/subset.cpp

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,25 +23,41 @@ int64_t to_int(std::optional<std::string> value, int64_t default_value)
2323
}
2424

2525
constexpr auto token_re = ctll::fixed_string{R"(\[([^\[\]]*)\])"};
26-
constexpr auto index_re = ctll::fixed_string{R"((\d+))"};
27-
constexpr auto slice_re = ctll::fixed_string{R"((\d*)(:(-?\d*)(:(-?\d*))?)?)"};
26+
constexpr auto index_re = ctll::fixed_string{R"((-?\d+))"};
27+
constexpr auto slice_re = ctll::fixed_string{R"((-?\d*)(:(-?\d*)(:(-?\d*))?)?)"};
2828

2929
template <typename T> libtokamap::SubsetInfo parse_slice(const T& slice, size_t dimension)
3030
{
3131
const auto& index_match = ctre::match<index_re>(slice);
3232
if (index_match) {
3333
int64_t index = to_int(index_match.template get<1>().to_optional_string(), 0);
34-
auto subset = libtokamap::SubsetInfo{index, index + 1, 1, dimension};
34+
35+
// AP: do we want to normalise here instead of in the actual subset class??
36+
auto normalised_index = (index < 0) ? dimension + index : index;
37+
auto subset = libtokamap::SubsetInfo{normalised_index, normalised_index + 1, 1, dimension};
3538
if (!subset.validate()) {
3639
throw libtokamap::ProcessingError{"invalid subset: " + slice.to_string()};
3740
}
3841
return subset;
3942
}
4043
const auto& slice_match = ctre::match<slice_re>(slice);
4144
if (slice_match) {
42-
int64_t start = to_int(slice_match.template get<1>().to_optional_string(), 0);
43-
int64_t stop = to_int(slice_match.template get<3>().to_optional_string(), -1);
4445
int64_t stride = to_int(slice_match.template get<5>().to_optional_string(), 1);
46+
47+
// m_start: if omitted, pass std::nullopt
48+
auto start_str = slice_match.template get<1>().to_optional_string();
49+
std::optional<int64_t> start;
50+
if (start_str && !start_str.value().empty()) {
51+
start = std::stoi(start_str.value());
52+
}
53+
54+
// m_stop: if omitted, pass std::nullopt
55+
auto stop_str = slice_match.template get<3>().to_optional_string();
56+
std::optional<int64_t> stop;
57+
if (stop_str && !stop_str.value().empty()) {
58+
stop = std::stoi(stop_str.value());
59+
}
60+
4561
auto subset = libtokamap::SubsetInfo{start, stop, stride, dimension};
4662
if (!subset.validate()) {
4763
throw libtokamap::ProcessingError{"invalid subset: " + slice.to_string()};

src/utils/typed_data_array.cpp

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -38,22 +38,26 @@ IndicesList generate_indices(const std::vector<libtokamap::SubsetInfo>& subsets)
3838
for (int64_t i = static_cast<int64_t>(n_dims) - 1; i >= 0; --i) {
3939
current[i] += subsets[i].stride();
4040

41-
// Check if we're still within bounds (handles both positive and negative strides)
4241
bool within_bounds = false;
4342
if (subsets[i].stride() > 0) {
4443
within_bounds = (current[i] < subsets[i].stop());
4544
} else {
46-
// stops unsigned integer wraparound --> infinite loop
47-
within_bounds = (current[i] <= subsets[i].dim_size() && current[i] > subsets[i].stop());
45+
// For negative stride, stop can be UINT64_MAX - handle separately!
46+
if (subsets[i].stop() == std::numeric_limits<uint64_t>::max()) {
47+
// Go down to 0 inclusive
48+
within_bounds = (current[i] < subsets[i].dim_size());
49+
} else {
50+
within_bounds = (current[i] <= subsets[i].dim_size() && current[i] > subsets[i].stop());
51+
}
4852
}
4953

5054
if (within_bounds) {
51-
break; // no carry needed
55+
break;
5256
}
5357
if (i == 0) {
54-
done = true; // we're finished
58+
done = true;
5559
} else {
56-
current[i] = subsets[i].start(); // reset and carry to next dimension
60+
current[i] = subsets[i].start();
5761
}
5862
}
5963
}

src/utils/typed_data_array.hpp

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -115,18 +115,36 @@ inline DataType type_index_map(std::type_index type_index)
115115
class SubsetInfo
116116
{
117117
public:
118-
SubsetInfo(int64_t start, int64_t stop, int64_t stride, size_t size)
119-
: m_start{start}, m_stop{stop}, m_stride{stride}, m_dim_size{static_cast<int64_t>(size)}
118+
SubsetInfo(std::optional<int64_t> start, std::optional<int64_t> stop, int64_t stride, size_t size)
119+
: m_stride{stride}, m_dim_size{size}
120120
{
121-
if (size > std::numeric_limits<int64_t>::max()) {
122-
throw libtokamap::ProcessingError{"dimension size too large"};
121+
if (stride == 0) {
122+
throw libtokamap::ProcessingError{"stride of 0 is not allowed, apologies"};
123123
}
124-
// negative indexes mean that many elements from the end
125-
if (start < 0) {
126-
m_start = m_dim_size + start;
124+
125+
// m_start
126+
if (!start.has_value()) {
127+
// If start omitted, need to default to values based on stride
128+
m_start = (stride > 0) ? 0 : m_dim_size - 1;
129+
} else if (start.value() < 0) {
130+
m_start = m_dim_size + start.value();
131+
} else {
132+
m_start = start.value();
127133
}
128-
if (stop < 0) {
129-
m_stop = m_dim_size + stop + 1;
134+
135+
// m_stop
136+
if (!stop.has_value()) {
137+
// If stop omitted, need to default to values based on stride
138+
if (stride > 0) {
139+
m_stop = m_dim_size;
140+
} else {
141+
// Dummy flag value to know when to go all the way to INCLUDE zeroth index
142+
m_stop = std::numeric_limits<uint64_t>::max();
143+
}
144+
} else if (stop.value() < 0) {
145+
m_stop = m_dim_size + stop.value();
146+
} else {
147+
m_stop = stop.value();
130148
}
131149
}
132150

@@ -139,7 +157,10 @@ class SubsetInfo
139157
size = (m_stop - m_start + m_stride - 1) / m_stride;
140158
}
141159
} else if (m_stride < 0) {
142-
if (m_start > m_stop) {
160+
if (m_stop == std::numeric_limits<uint64_t>::max()) {
161+
// Omitted stop (UINT64_MAX sentinel): count every element from m_start down to index 0 inclusive
162+
size = (m_start + (-m_stride)) / (-m_stride);
163+
} else if (m_start > m_stop) {
143164
size = (m_start - m_stop - m_stride - 1) / (-m_stride);
144165
}
145166
}
@@ -148,8 +169,10 @@ class SubsetInfo
148169

149170
[[nodiscard]] bool validate() const
150171
{
151-
bool valid_stride = m_stride >= 0 ? m_start <= m_stop : m_stop <= m_start;
152-
return m_start <= m_dim_size - 1 && m_stop <= m_dim_size && m_stride < m_dim_size && valid_stride;
172+
bool valid_stride = m_stride > 0
173+
? (m_start < m_dim_size && m_start <= m_stop && m_stop <= m_dim_size)
174+
: (m_stop == std::numeric_limits<uint64_t>::max() || (m_stop <= m_start && m_start < m_dim_size));
175+
return valid_stride;
153176
}
154177

155178
[[nodiscard]] uint64_t start() const { return m_start; }
@@ -161,10 +184,10 @@ class SubsetInfo
161184
[[nodiscard]] uint64_t dim_size() const { return m_dim_size; }
162185

163186
private:
164-
int64_t m_start;
165-
int64_t m_stop;
187+
uint64_t m_start;
188+
uint64_t m_stop;
166189
int64_t m_stride = 1;
167-
int64_t m_dim_size;
190+
uint64_t m_dim_size;
168191
};
169192

170193
std::vector<size_t> compute_offsets(const std::vector<size_t>& shape, const std::vector<SubsetInfo>& subsets);

test/src/parse_slices_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ TEST_CASE("Parsing of slice strings", "[slice]") {
5252

5353
REQUIRE(slices.size() == 1);
5454
REQUIRE(slices[0].start() == 3);
55-
REQUIRE(slices[0].stop() == 9);
55+
REQUIRE(slices[0].stop() == 8);
5656
REQUIRE(slices[0].stride() == 1);
5757
}
5858

test/src/subset_test.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,77 @@ TEST_CASE("Subset 2D to 1D array", "[subset]")
265265
std::vector<float> expected = {905.0, 705.0, 505.0, 305.0, 105.0};
266266
REQUIRE(result == expected);
267267
}
268+
269+
SECTION("Negative index for row [-1][:]")
270+
{
271+
TypedDataArray array{data, {rows, cols}};
272+
auto subsets = parse_slices("[-1][:]", array.shape());
273+
274+
array.slice<float>(subsets);
275+
276+
// Result: 1D array with last row (row 9)
277+
REQUIRE(array.rank() == 1);
278+
REQUIRE(array.size() == cols);
279+
REQUIRE(array.shape() == std::vector<size_t>{cols});
280+
281+
auto result = array.to_vector<float>();
282+
std::vector<float> expected = {900.0, 901.0, 902.0, 903.0, 904.0, 905.0, 906.0, 907.0, 908.0, 909.0, 910.0, 911.0, 912.0, 913.0, 914.0};
283+
REQUIRE(result == expected);
284+
}
285+
286+
SECTION("Negative index for column [:][-1]")
287+
{
288+
TypedDataArray array{data, {rows, cols}};
289+
auto subsets = parse_slices("[:][-1]", array.shape());
290+
291+
array.slice<float>(subsets);
292+
293+
// Result: 1D array with last column (column 14)
294+
REQUIRE(array.rank() == 1);
295+
REQUIRE(array.size() == rows);
296+
REQUIRE(array.shape() == std::vector<size_t>{rows});
297+
298+
auto result = array.to_vector<float>();
299+
std::vector<float> expected = {14.0, 114.0, 214.0, 314.0, 414.0, 514.0, 614.0, 714.0, 814.0, 914.0};
300+
REQUIRE(result == expected);
301+
}
302+
303+
SECTION("Negative indices in slice [-3:-1:1][:]")
304+
{
305+
TypedDataArray array{data, {rows, cols}};
306+
auto subsets = parse_slices("[-3:-1:1][:]", array.shape());
307+
308+
array.slice<float>(subsets);
309+
310+
// -3 = row 7, -1 = row 9, stop is exclusive so rows 7, 8
311+
REQUIRE(array.rank() == 2);
312+
REQUIRE(array.size() == 2 * cols);
313+
REQUIRE(array.shape() == std::vector<size_t>{2, cols});
314+
315+
auto result = array.to_vector<float>();
316+
std::vector<float> expected = {
317+
700.0, 701.0, 702.0, 703.0, 704.0, 705.0, 706.0, 707.0, 708.0, 709.0, 710.0, 711.0, 712.0, 713.0, 714.0,
318+
800.0, 801.0, 802.0, 803.0, 804.0, 805.0, 806.0, 807.0, 808.0, 809.0, 810.0, 811.0, 812.0, 813.0, 814.0
319+
};
320+
REQUIRE(result == expected);
321+
}
322+
323+
SECTION("Full reverse with negative stride [::-1][5]")
324+
{
325+
TypedDataArray array{data, {rows, cols}};
326+
auto subsets = parse_slices("[::-1][5]", array.shape());
327+
328+
array.slice<float>(subsets);
329+
330+
// Result: all rows in reverse order, column 5
331+
REQUIRE(array.rank() == 1);
332+
REQUIRE(array.size() == rows);
333+
REQUIRE(array.shape() == std::vector<size_t>{rows});
334+
335+
auto result = array.to_vector<float>();
336+
std::vector<float> expected = {905.0, 805.0, 705.0, 605.0, 505.0, 405.0, 305.0, 205.0, 105.0, 5.0};
337+
REQUIRE(result == expected);
338+
}
268339
}
269340

270341
TEST_CASE("Subset 3D to 2D array", "[subset]")
@@ -874,4 +945,54 @@ TEST_CASE("Subset validation", "[subset]")
874945
// [:][:][1:5:-1] - stop > start with negative stride should fail
875946
REQUIRE_THROWS(parse_slices("[:][:][1:5:-1]", array.shape()));
876947
}
948+
949+
SECTION("Direct SubsetInfo validation - boundary cases")
950+
{
951+
// Test m_start and m_stop boundary conditions
952+
constexpr size_t dim_size = 10;
953+
954+
// Valid: m_start=0, m_stop=10 (full range, positive stride)
955+
SubsetInfo valid1(0, 10, 1, dim_size);
956+
REQUIRE(valid1.validate());
957+
958+
// Valid: m_start=9, m_stop=0 (negative stride)
959+
SubsetInfo valid2(9, 0, -1, dim_size);
960+
REQUIRE(valid2.validate());
961+
962+
// Valid: m_start=9, omitted stop (negative stride, goes to beginning)
963+
SubsetInfo valid3(9, std::nullopt, -1, dim_size);
964+
REQUIRE(valid3.validate());
965+
966+
// Invalid: m_start=10 (out of bounds, equals dim_size)
967+
SubsetInfo invalid1(10, 10, 1, dim_size);
968+
REQUIRE_FALSE(invalid1.validate());
969+
970+
// Invalid: m_start=-1 with positive stride (normalizes to 9, and 9 > 5, so start > stop)
971+
SubsetInfo invalid2(-1, 5, 1, dim_size);
972+
REQUIRE_FALSE(invalid2.validate()); // After normalization: start=9, stop=5, 9>5 with positive stride = invalid
973+
974+
// Invalid: m_stop=11 (out of bounds)
975+
SubsetInfo invalid3(0, 11, 1, dim_size);
976+
REQUIRE_FALSE(invalid3.validate());
977+
978+
// Invalid: m_start > m_stop with positive stride
979+
SubsetInfo invalid4(7, 3, 1, dim_size);
980+
REQUIRE_FALSE(invalid4.validate());
981+
982+
// Invalid: m_stop > m_start with negative stride
983+
SubsetInfo invalid5(3, 7, -1, dim_size);
984+
REQUIRE_FALSE(invalid5.validate());
985+
}
986+
987+
SECTION("Test omitted stop for negative strides going to beginning")
988+
{
989+
constexpr size_t dim_size = 10;
990+
991+
// Test that omitted stop is correctly handled for negative strides (go to beginning)
992+
SubsetInfo subset1(9, std::nullopt, -1, dim_size);
993+
REQUIRE(subset1.validate());
994+
REQUIRE(subset1.start() == 9);
995+
REQUIRE(subset1.stop() == std::numeric_limits<uint64_t>::max()); // Omitted stop with negative stride uses UINT64_MAX sentinel
996+
REQUIRE(subset1.size() == 10); // Should include all elements from 9 down to 0
997+
}
877998
}

test/src/typed_data_array_test.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ TEST_CASE("Test array slice")
154154
constexpr size_t element1 = 5;
155155
std::vector<libtokamap::SubsetInfo> subsets = {
156156
libtokamap::SubsetInfo{element1, element1 + 1, 1, dim1},
157-
libtokamap::SubsetInfo{0, -1, 1, dim2},
157+
libtokamap::SubsetInfo{0, std::nullopt, 1, dim2},
158158
};
159159
array.slice<float>(subsets);
160160

@@ -176,7 +176,7 @@ TEST_CASE("Test array slice")
176176
constexpr size_t stride = 1;
177177
std::vector<libtokamap::SubsetInfo> subsets = {
178178
libtokamap::SubsetInfo{start, stop, stride, dim1},
179-
libtokamap::SubsetInfo{0, -1, 1, dim2},
179+
libtokamap::SubsetInfo{0, std::nullopt, 1, dim2},
180180
};
181181
array.slice<float>(subsets);
182182

@@ -197,7 +197,7 @@ TEST_CASE("Test array slice")
197197
constexpr size_t stop = start + range_len;
198198
constexpr size_t stride = 1;
199199
std::vector<libtokamap::SubsetInfo> subsets = {
200-
libtokamap::SubsetInfo{0, -1, 1, dim1},
200+
libtokamap::SubsetInfo{0, std::nullopt, 1, dim1},
201201
libtokamap::SubsetInfo{start, stop, stride, dim2},
202202
};
203203
array.slice<float>(subsets);

0 commit comments

Comments
 (0)