Skip to content

Commit 709304b

Browse files
committed
[python][UHI] Update the tests with the expected slicing logic
1 parent fc120a0 commit 709304b

File tree

1 file changed

+67
-30
lines changed

1 file changed

+67
-30
lines changed

bindings/pyroot/pythonizations/test/uhi_indexing.py

Lines changed: 67 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@
44

55
import pytest
66
import ROOT
7-
from ROOT._pythonization._uhi import (
8-
_get_axis,
9-
_get_processed_slices,
10-
_get_slice_indices,
11-
_shape,
12-
)
7+
from ROOT._pythonization._uhi import _get_axis, _get_processed_slices, _overflow, _shape, _underflow
138
from ROOT.uhi import loc, overflow, rebin, sum, underflow
149

1510

@@ -32,6 +27,14 @@ def _iterate_bins(hist):
3227
yield tuple(filter(None, (i, j, k)))
3328

3429

30+
def _get_slice_indices(slices):
31+
import numpy as np
32+
33+
ranges = [range(start, stop) for start, stop in slices]
34+
grids = np.meshgrid(*ranges, indexing="ij")
35+
return np.array(grids).reshape(len(slices), -1).T
36+
37+
3538
class TestTH1Indexing:
3639
def test_access_with_bin_number(self, hist_setup):
3740
for index in [0, 8]:
@@ -52,7 +55,7 @@ def test_access_flow_bins(self, hist_setup):
5255

5356
def test_access_with_len(self, hist_setup):
5457
len_indices = (len,) * hist_setup.GetDimension()
55-
bin_counts = (_get_axis(hist_setup, i).GetNbins() for i in range(hist_setup.GetDimension()))
58+
bin_counts = (_get_axis(hist_setup, i).GetNbins() + 1 for i in range(hist_setup.GetDimension()))
5659
assert hist_setup[len_indices] == hist_setup.GetBinContent(*bin_counts)
5760

5861
def test_access_with_ellipsis(self, hist_setup):
@@ -104,23 +107,48 @@ def test_setting_with_scalar(self, hist_setup):
104107
if _special_setting(hist_setup):
105108
pytest.skip("Setting cannot be tested here")
106109

110+
hist_setup.Reset()
107111
hist_setup[...] = 3
108112
for bin_indices in _iterate_bins(hist_setup):
109113
assert hist_setup.GetBinContent(*bin_indices) == 3
110114

115+
# Check that flow bins are not set
116+
for flow_type in [underflow, overflow]:
117+
flow_indices = (flow_type,) * hist_setup.GetDimension()
118+
assert hist_setup[flow_indices] == 0, (
119+
f"{hist_setup.values()}, {hist_setup[underflow]}, {hist_setup[overflow]}"
120+
)
121+
111122
def _test_slices_match(self, hist_setup, slice_ranges, processed_slices):
112123
dim = hist_setup.GetDimension()
113-
slices, _, _ = _get_processed_slices(hist_setup, processed_slices[dim])
124+
slices, _ = _get_processed_slices(hist_setup, processed_slices[dim])
114125
expected_indices = _get_slice_indices(slices)
115126
sliced_hist = hist_setup[tuple(slice_ranges[dim])]
116127

117128
for bin_indices in expected_indices:
118129
bin_indices = tuple(map(int, bin_indices))
119-
assert sliced_hist.GetBinContent(*bin_indices) == hist_setup.GetBinContent(*bin_indices)
130+
shifted_indices = []
131+
is_flow_bin = False
132+
for i, idx in enumerate(bin_indices):
133+
shift = slice_ranges[dim][i].start
134+
if callable(shift):
135+
shift = shift(hist_setup, i)
136+
elif shift is None:
137+
shift = 1
138+
else:
139+
shift += 1
120140

121-
for bin_indices in _iterate_bins(hist_setup):
122-
if list(bin_indices) not in expected_indices.tolist():
123-
assert sliced_hist.GetBinContent(*bin_indices) == 0
141+
shifted_idx = idx - shift + 1
142+
if shifted_idx <= 0 or shifted_idx == _overflow(hist_setup, i):
143+
is_flow_bin = True
144+
break
145+
146+
shifted_indices.append(shifted_idx)
147+
148+
if is_flow_bin:
149+
continue
150+
151+
assert sliced_hist.GetBinContent(*tuple(shifted_indices)) == hist_setup.GetBinContent(*bin_indices)
124152

125153
def test_slicing_with_endpoints(self, hist_setup):
126154
if _special_setting(hist_setup):
@@ -144,13 +172,13 @@ def test_slicing_without_endpoints(self, hist_setup):
144172

145173
processed_slices = {
146174
1: [slice(0, 8)],
147-
2: [slice(0, 8), slice(4, 11)],
148-
3: [slice(0, 8), slice(4, 11), slice(3, 6)],
175+
2: [slice(0, 8), slice(0, 8)],
176+
3: [slice(0, 8), slice(0, 8), slice(3, 6)],
149177
}
150178
slice_ranges = {
151179
1: [slice(None, 7)],
152-
2: [slice(None, 7), slice(3, None)],
153-
3: [slice(None, 7), slice(3, None), slice(2, 5)],
180+
2: [slice(None, 7), slice(None, 7)],
181+
3: [slice(None, 7), slice(None, 7), slice(2, 5)],
154182
}
155183
self._test_slices_match(hist_setup, slice_ranges, processed_slices)
156184

@@ -160,17 +188,17 @@ def test_slicing_with_data_coordinates(self, hist_setup):
160188

161189
processed_slices = {
162190
1: [slice(hist_setup.FindBin(2), 11)],
163-
2: [slice(hist_setup.FindBin(2), 11), slice(hist_setup.FindBin(3), 11)],
191+
2: [slice(hist_setup.FindBin(2) - 1, 11), slice(2, 11)],
164192
3: [
165193
slice(hist_setup.FindBin(2), 11),
166-
slice(hist_setup.FindBin(3), 11),
167-
slice(hist_setup.FindBin(1.5), 11),
194+
slice(2, 11),
195+
slice(2, 11),
168196
],
169197
}
170198
slice_ranges = {
171199
1: [slice(loc(2), None)],
172-
2: [slice(loc(2), None), slice(loc(3), None)],
173-
3: [slice(loc(2), None), slice(loc(3), None), slice(loc(1.5), None)],
200+
2: [slice(loc(2), None), slice(3, None)],
201+
3: [slice(loc(2), None), slice(3, None), slice(3, None)],
174202
}
175203
self._test_slices_match(hist_setup, slice_ranges, processed_slices)
176204

@@ -191,7 +219,10 @@ def test_slicing_over_everything_with_action_sum(self, hist_setup):
191219
dim = hist_setup.GetDimension()
192220

193221
if dim == 1:
194-
integral = hist_setup[::sum]
222+
full_integral = hist_setup[::sum]
223+
assert full_integral == hist_setup.Integral(_underflow(hist_setup, 0), _overflow(hist_setup, 0))
224+
225+
integral = hist_setup[0:len:sum]
195226
assert integral == hist_setup.Integral()
196227

197228
if dim == 2:
@@ -225,7 +256,7 @@ def test_slicing_with_action_rebin_and_sum(self, hist_setup):
225256
if dim == 1:
226257
sliced_hist_rebin = hist_setup[5 : 9 : rebin(2)]
227258
assert isinstance(sliced_hist_rebin, ROOT.TH1)
228-
assert sliced_hist_rebin.GetNbinsX() == hist_setup.GetNbinsX() // 2
259+
assert sliced_hist_rebin.GetNbinsX() == 2
229260

230261
sliced_hist_sum = hist_setup[5:9:sum]
231262
assert isinstance(sliced_hist_sum, float)
@@ -237,10 +268,10 @@ def test_slicing_with_action_rebin_and_sum(self, hist_setup):
237268
assert sliced_hist.GetNbinsX() == hist_setup.GetNbinsX() // 2
238269

239270
if dim == 3:
240-
sliced_hist = hist_setup[:: rebin(2), ::sum, 5 : 9 : rebin(3)]
271+
sliced_hist = hist_setup[:: rebin(2), ::sum, 3 : 9 : rebin(3)]
241272
assert isinstance(sliced_hist, ROOT.TH2)
242273
assert sliced_hist.GetNbinsX() == hist_setup.GetNbinsX() // 2
243-
assert sliced_hist.GetNbinsY() == hist_setup.GetNbinsZ() // 3
274+
assert sliced_hist.GetNbinsY() == 2
244275

245276
def test_slicing_with_dict_syntax(self, hist_setup):
246277
if _special_setting(hist_setup):
@@ -262,19 +293,25 @@ def test_integral_full_slice(self, hist_setup):
262293
assert hist_setup.Integral() == pytest.approx(sliced_hist.Integral(), rel=10e-6)
263294

264295
def test_statistics_slice(self, hist_setup):
265-
if _special_setting(hist_setup):
296+
if _special_setting(hist_setup) or isinstance(hist_setup, (ROOT.TH1C, ROOT.TH2C, ROOT.TH3C)):
266297
pytest.skip("Setting cannot be tested here")
267298

299+
# Check if slicing over everything preserves the statistics
300+
sliced_hist_full = hist_setup[...]
301+
302+
assert hist_setup.GetEffectiveEntries() == sliced_hist_full.GetEffectiveEntries()
303+
assert hist_setup.Integral() == sliced_hist_full.Integral()
304+
268305
# Check if slicing over a range updates the statistics
269306
dim = hist_setup.GetDimension()
270-
[_get_axis(hist_setup, i).SetRange(3, 5) for i in range(dim)]
307+
[_get_axis(hist_setup, i).SetRange(3, 7) for i in range(dim)]
271308
slice_indices = tuple(slice(2, 7) for _ in range(dim))
272309
sliced_hist = hist_setup[slice_indices]
273310

274-
assert hist_setup.Integral() == pytest.approx(sliced_hist.Integral(), rel=1e-6)
275-
assert hist_setup.GetMean() == pytest.approx(sliced_hist.GetMean(), abs=1e-3)
276-
assert hist_setup.GetStdDev() == pytest.approx(sliced_hist.GetStdDev(), abs=1e-3)
311+
assert hist_setup.Integral() == sliced_hist.Integral()
277312
assert hist_setup.GetEffectiveEntries() == sliced_hist.GetEffectiveEntries()
313+
assert hist_setup.GetStdDev() == pytest.approx(sliced_hist.GetStdDev(), rel=10e-5)
314+
assert hist_setup.GetMean() == pytest.approx(sliced_hist.GetMean(), rel=10e-5)
278315

279316

280317
if __name__ == "__main__":

0 commit comments

Comments
 (0)