Skip to content

Commit 6581e2b

Browse files
author
pjanowsk
committed
Add remaining tests
1 parent a57c97f commit 6581e2b

4 files changed

Lines changed: 902 additions & 0 deletions

File tree

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
import pytest
2+
3+
import cosipy
4+
if not cosipy.with_ml:
5+
pytest.skip(reason="Optional [ml] dependencies not installed", allow_module_level=True)
6+
7+
import numpy as np
8+
import torch
9+
from unittest.mock import MagicMock
10+
11+
from astropy import units as u
12+
from astropy.time import Time
13+
14+
from cosipy.data_io.EmCDSUnbinnedData import TimeTagEmCDSEventInSCFrameInterface
15+
from cosipy.background_estimation.ml.NFBackground import NFBackground
16+
from cosipy.interfaces.data_interface import TimeTagEmCDSEventDataInSCFrameInterface
17+
from cosipy.interfaces.event import EventInterface
18+
19+
from cosipy.background_estimation.ml.nf_unbinned_background import FreeNormNFUnbinnedBackground
20+
21+
22+
@pytest.fixture
23+
def mock_sc_history():
24+
"""Provides a realistic SpacecraftHistory mock using real Astropy units/times."""
25+
sc_history = MagicMock()
26+
27+
obstime = Time([1000.0, 1010.0, 1020.0], format='unix')
28+
sc_history.obstime = obstime
29+
sc_history.tstart = obstime[0]
30+
sc_history.tstop = obstime[-1]
31+
32+
sc_history.livetime = [9.0, 8.5] * u.s
33+
sc_history.intervals_duration = [10.0, 10.0] * u.s
34+
35+
sc_history.cumulative_livetime.return_value = 17.5 * u.s
36+
37+
return sc_history
38+
39+
@pytest.fixture
40+
def mock_data():
41+
"""Provides a realistic Data mock using real Astropy times."""
42+
data = MagicMock(spec=TimeTagEmCDSEventDataInSCFrameInterface)
43+
44+
data.energy_keV = [500.0, 1000.0]
45+
data.scattering_angle_rad = [0.5, 1.0]
46+
data.scattered_lon_rad_sc = [0.1, 0.2]
47+
data.scattered_lat_rad_sc = [0.3, 0.4]
48+
49+
data.time = Time([1005.0, 1015.0], format='unix')
50+
51+
return data
52+
53+
@pytest.fixture
54+
def mock_model():
55+
"""Provides a mocked NFBackground model that returns predictable tensors."""
56+
model = MagicMock(spec=NFBackground)
57+
58+
model.evaluate_rate.return_value = torch.tensor([2.0, 3.0])
59+
model.evaluate_density.return_value = torch.tensor([0.5, 0.6])
60+
61+
model.active_pool = True
62+
return model
63+
64+
65+
@pytest.fixture
66+
def background_instance(mock_model, mock_data, mock_sc_history):
67+
"""Instantiates the background class with all dependencies."""
68+
return FreeNormNFUnbinnedBackground(
69+
model=mock_model,
70+
data=mock_data,
71+
sc_history=mock_sc_history,
72+
label="test_bkg_norm"
73+
)
74+
75+
76+
class TestFreeNormNFUnbinnedBackground:
77+
78+
def test_init_and_properties(self, background_instance):
79+
"""Test initial state, event type, and parameters property."""
80+
assert background_instance.event_type == TimeTagEmCDSEventInSCFrameInterface
81+
assert background_instance.offset == 1e-12
82+
assert background_instance._label == "test_bkg_norm"
83+
assert background_instance._accum_livetime == 17.5
84+
85+
params = background_instance.parameters
86+
assert "test_bkg_norm" in params
87+
assert isinstance(params["test_bkg_norm"], u.Quantity)
88+
89+
def test_offset_validation(self, background_instance):
90+
"""Test the offset setter logic, including negative value guardrails."""
91+
background_instance.offset = 0.0
92+
assert background_instance.offset == 0.0
93+
94+
background_instance.offset = None
95+
assert background_instance.offset is None
96+
97+
with pytest.raises(ValueError, match="The offset cannot be negative."):
98+
background_instance.offset = -1.0
99+
100+
def test_integrate_rate(self, background_instance, mock_model):
101+
"""Test that the integration correctly multiplies rates by livetime."""
102+
expected_counts = background_instance._integrate_rate()
103+
104+
assert expected_counts == 43.5
105+
mock_model.evaluate_rate.assert_called_once()
106+
107+
passed_times = mock_model.evaluate_rate.call_args[0][0]
108+
assert torch.allclose(passed_times, torch.tensor([[1005.0], [1015.0]], dtype=torch.float64))
109+
110+
def test_compute_density(self, background_instance, mock_model):
111+
"""Test the complex density computation and bin mapping logic."""
112+
densities = background_instance._compute_density()
113+
114+
assert isinstance(densities, np.ndarray)
115+
assert densities.dtype == np.float64
116+
np.testing.assert_allclose(densities, [0.9, 1.53])
117+
118+
passed_source = mock_model.evaluate_density.call_args[0][1]
119+
assert passed_source.shape == (2, 4)
120+
assert torch.allclose(passed_source[0, 3], torch.tensor(np.pi/2 - 0.3, dtype=torch.float32))
121+
122+
def test_compute_density_time_bounds_error(self, background_instance, mock_data):
123+
"""Ensure evaluating events outside the spacecraft history raises an error."""
124+
mock_data.time = Time([900.0, 1050.0], format='unix')
125+
126+
with pytest.raises(ValueError, match="Input times are outside the spacecraft history range"):
127+
background_instance._compute_density()
128+
129+
def test_norm_setter_and_getter(self, background_instance):
130+
"""Test scaling logic when norm is manipulated."""
131+
132+
initial_norm_qty = background_instance.norm
133+
assert initial_norm_qty.unit == u.Hz
134+
np.testing.assert_allclose(initial_norm_qty.value, 43.5 / 17.5)
135+
136+
background_instance.norm = 5.0 * u.Hz
137+
138+
np.testing.assert_allclose(background_instance._norm, 5.0 * (17.5 / 43.5))
139+
140+
background_instance.set_parameters(test_bkg_norm=10.0 * u.Hz)
141+
assert background_instance.norm.value == 10.0
142+
143+
def test_expected_counts_method(self, background_instance):
144+
"""Test that expected counts scale by the internal norm."""
145+
assert background_instance.expected_counts() == 43.5
146+
147+
background_instance._norm = 2.0
148+
assert background_instance.expected_counts() == 87.0
149+
150+
def test_expectation_density_method(self, background_instance):
151+
"""Test final density calculations, including norm scaling and offset addition."""
152+
153+
res1 = background_instance.expectation_density()
154+
np.testing.assert_allclose(res1, [0.9 + 1e-12, 1.53 + 1e-12])
155+
156+
background_instance._norm = 2.0
157+
background_instance.offset = None
158+
res2 = background_instance.expectation_density()
159+
np.testing.assert_allclose(res2, [1.8, 3.06])
160+
161+
def test_compute_pool_management(self, background_instance, mock_model):
162+
"""Ensure lazy evaluation accurately initializes and shutdowns compute pools."""
163+
mock_model.active_pool = False
164+
165+
background_instance.expectation_density()
166+
167+
mock_model.init_compute_pool.assert_called_once()
168+
mock_model.shutdown_compute_pool.assert_called_once()
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
import pytest
2+
3+
import cosipy
4+
if not cosipy.with_ml:
5+
pytest.skip(reason="Optional [ml] dependencies not installed", allow_module_level=True)
6+
7+
import numpy as np
8+
import torch
9+
from unittest.mock import MagicMock, patch
10+
11+
from cosipy.background_estimation.ml.NFBackgroundModels import (
12+
TotalDC4BackgroundRate,
13+
TotalBackgroundDensityCMLPDGaussianCARQSFlow
14+
)
15+
16+
@pytest.fixture
17+
def dummy_rate_input():
18+
return {
19+
"slew_duration": 10.0,
20+
"obs_duration": 50.0,
21+
"start_time": 1000.0,
22+
"offset": 5.0,
23+
"slope": 0.1,
24+
"buildup": ((1.0, 2.0), (10.0, 20.0)),
25+
"scale": 0.5,
26+
"cutoff": (90.0, (1.0, 1.0, 1.0), (0.1, 0.1, 0.1), (0.0, 30.0, 60.0)),
27+
"outlocs": torch.tensor([500.0, 900.0, 1500.0]),
28+
"saa_decay": ((2.0, 3.0), (15.0, 30.0))
29+
}
30+
31+
@pytest.fixture
32+
def dummy_density_input():
33+
return {
34+
"model_state_dict": {},
35+
"bins": 8,
36+
"hidden_units": 64,
37+
"residual_blocks": 2,
38+
"total_layers": 3,
39+
"context_size": 5,
40+
"mlp_hidden_units": 32,
41+
"mlp_hidden_layers": 2,
42+
"menergy_cuts": (100.0, 10000.0),
43+
"phi_cuts": (0.0, np.pi),
44+
"start_time": 1000.0,
45+
"total_time": 10000.0,
46+
"period": 5400.0,
47+
"slew_duration": 600.0,
48+
"obs_duration": 3000.0,
49+
"outlocs": torch.tensor([500.0, 900.0, 1500.0])
50+
}
51+
52+
class TestTotalDC4BackgroundRate:
53+
54+
def test_context_dim_property(self, dummy_rate_input):
55+
model = TotalDC4BackgroundRate(dummy_rate_input)
56+
assert model.context_dim == 1
57+
58+
def test_unpack_rate_input(self, dummy_rate_input):
59+
"""Ensure all elements from the dictionary are mapped to the correct instance variables."""
60+
model = TotalDC4BackgroundRate(dummy_rate_input)
61+
62+
assert model._slew_duration == 10.0
63+
assert model._offset == 5.0
64+
assert model._buildup_A == (1.0, 2.0)
65+
assert model._cutoff_mu == (0.0, 30.0, 60.0)
66+
assert torch.allclose(model._outlocs, torch.tensor([500.0, 900.0, 1500.0]))
67+
68+
def test_static_math_methods(self):
69+
"""Test the pure mathematical equations for buildup and decay."""
70+
t = torch.tensor([10.0])
71+
72+
buildup_res = TotalDC4BackgroundRate._buildup(t, A=4.0, T=10.0)
73+
assert torch.allclose(buildup_res, torch.tensor([2.0]))
74+
75+
decay_res = TotalDC4BackgroundRate._decay(t, A=4.0, T=10.0)
76+
assert torch.allclose(decay_res, torch.tensor([2.0]))
77+
78+
vm_res = TotalDC4BackgroundRate._von_mises(torch.tensor([0.0]), T=10.0, A=2.0, kappa=1.0, mu=0.0)
79+
assert torch.allclose(vm_res, torch.tensor([2.0 * np.exp(1.0)], dtype=torch.float32))
80+
81+
def test_pointing_scale(self, dummy_rate_input):
82+
"""Test the sigmoid boundary logic inside pointing scale."""
83+
model = TotalDC4BackgroundRate(dummy_rate_input)
84+
85+
res = model._pointing_scale(torch.tensor([0.0, 60.0, 120.0]), scale=0.5, k0=10.0)
86+
assert res.shape == (3,)
87+
assert np.isclose(res[0], res[2])
88+
89+
def test_saa_decay(self, dummy_rate_input):
90+
"""Test SAA decay correctly identifies the proper last exit time using searchsorted."""
91+
model = TotalDC4BackgroundRate(dummy_rate_input)
92+
93+
time_mins = torch.tensor([0.0, 10.0])
94+
95+
decay = model._saa_decay(time_mins, A=(2.0, 3.0), T=(15.0, 30.0))
96+
assert decay.shape == (2,)
97+
assert np.all(np.isclose(decay,
98+
model._decay(time_mins - torch.tensor([-100/60, 500/60]), 2.0, 15.0) +
99+
model._decay(time_mins - torch.tensor([-100/60, 500/60]), 3.0, 30.0)))
100+
assert torch.all(decay > 0)
101+
102+
def test_evaluate_rate(self, dummy_rate_input):
103+
"""Test the full aggregation method."""
104+
model = TotalDC4BackgroundRate(dummy_rate_input)
105+
106+
abs_times = torch.tensor([1000.0, 1060.0, 1120.0])
107+
rates = model.evaluate_rate(abs_times)
108+
109+
assert rates.shape == (3,)
110+
assert rates.dtype == torch.float32
111+
112+
class TestTotalBackgroundDensity:
113+
114+
@patch('cosipy.background_estimation.ml.NFBackgroundModels.NNDensityInferenceWrapper')
115+
@patch('cosipy.background_estimation.ml.NFBackgroundModels.build_c_arqs_flow')
116+
@patch('cosipy.background_estimation.ml.NFBackgroundModels.build_cmlp_diaggaussian_base')
117+
def test_init_and_properties(self, mock_base_builder, mock_flow_builder, mock_wrapper, dummy_density_input):
118+
"""Test that the flow builds correctly from dict parameters and properties read out correctly."""
119+
model = TotalBackgroundDensityCMLPDGaussianCARQSFlow(
120+
density_input=dummy_density_input,
121+
worker_device="cpu",
122+
batch_size=128,
123+
compile_mode=None
124+
)
125+
126+
assert model.context_dim == 1
127+
assert model.source_dim == 4
128+
129+
assert model._menergy_cuts == (100.0, 10000.0)
130+
assert model._total_time == 10000.0
131+
132+
mock_base_builder.assert_called_once()
133+
mock_flow_builder.assert_called_once()
134+
mock_wrapper.assert_called_once()
135+
136+
@patch('cosipy.background_estimation.ml.NFBackgroundModels.TotalBackgroundDensityCMLPDGaussianCARQSFlow._load_model', return_value=None)
137+
def test_inverse_transform_coordinates(self, mock_load, dummy_density_input):
138+
"""Test the physics to normalized-coordinate inverse mappings."""
139+
model = TotalBackgroundDensityCMLPDGaussianCARQSFlow(dummy_density_input, "cpu", 128)
140+
141+
nem = torch.tensor([0.0])
142+
nphi = torch.tensor([0.5])
143+
npsi = torch.tensor([0.25])
144+
nchi = torch.tensor([0.5])
145+
dummy = torch.tensor([0.0])
146+
147+
res = model._inverse_transform_coordinates(nem, nphi, npsi, nchi, dummy)
148+
149+
assert res.shape == (1, 4)
150+
np.testing.assert_allclose(res[0, 0].item(), 100.0)
151+
np.testing.assert_allclose(res[0, 1].item(), np.pi / 2)
152+
np.testing.assert_allclose(res[0, 2].item(), np.pi / 2)
153+
np.testing.assert_allclose(res[0, 3].item(), np.pi / 2)
154+
155+
@patch('cosipy.background_estimation.ml.NFBackgroundModels.TotalBackgroundDensityCMLPDGaussianCARQSFlow._load_model', return_value=None)
156+
def test_transform_coordinates(self, mock_load, dummy_density_input):
157+
"""Test calculation of transformed context, source, and jacobian."""
158+
model = TotalBackgroundDensityCMLPDGaussianCARQSFlow(dummy_density_input, "cpu", 128)
159+
160+
time = torch.tensor([1000.0])
161+
em = torch.tensor([1000.0])
162+
phi = torch.tensor([np.pi])
163+
scatt_az = torch.tensor([np.pi])
164+
scatt_pol = torch.tensor([np.pi/2])
165+
166+
ctx, src, jac = model._transform_coordinates(time, em, phi, scatt_az, scatt_pol)
167+
168+
assert ctx.shape == (1, 5)
169+
assert src.shape == (1, 4)
170+
assert jac.shape == (1,)
171+
172+
np.testing.assert_allclose(src[0, 0].item(), 0.5)
173+
174+
@patch('cosipy.background_estimation.ml.NFBackgroundModels.TotalBackgroundDensityCMLPDGaussianCARQSFlow._load_model', return_value=None)
175+
def test_valid_samples(self, mock_load, dummy_density_input):
176+
"""Test the logical masking bounds for validation checks."""
177+
model = TotalBackgroundDensityCMLPDGaussianCARQSFlow(dummy_density_input, "cpu", 128)
178+
179+
nem = torch.tensor([0.0, -1.0, 0.5])
180+
nphi = torch.tensor([0.5, 1.5, 0.5])
181+
npsi = torch.tensor([0.5, 0.5, -0.1])
182+
nchi = torch.tensor([0.5, 0.5, 0.5])
183+
dummy = torch.tensor([0.0, 0.0, 0.0])
184+
185+
mask = model._valid_samples(nem, nphi, npsi, nchi, dummy)
186+
187+
assert mask.tolist() == [True, False, False]

0 commit comments

Comments
 (0)