Skip to content

Commit 2cdd81a

Browse files
author
pjanowsk
committed
Add the next two tests.
1 parent 287e065 commit 2cdd81a

2 files changed

Lines changed: 263 additions & 0 deletions

File tree

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import cosipy
2+
import pytest
3+
4+
if not cosipy.with_ml:
5+
pytest.skip(reason="Optional [ml] dependencies not installed", allow_module_level=True)
6+
7+
from cosipy.response.ml import NFWorkerState
8+
9+
def test_nf_worker_state_initialization():
10+
"""
11+
Sanity check to ensure NFWorkerState is importable and
12+
variables are initialized to None as expected for coverage.
13+
"""
14+
assert NFWorkerState.worker_device is None
15+
assert NFWorkerState.density_module is None
16+
assert NFWorkerState.area_module is None
17+
assert NFWorkerState.progress_queue is None
18+
19+
def test_nf_worker_state_settable():
20+
"""
21+
Verify that the variables can be updated.
22+
"""
23+
NFWorkerState.worker_device = "cpu"
24+
assert NFWorkerState.worker_device == "cpu"
25+
26+
NFWorkerState.worker_device = None
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import cosipy
2+
import pytest
3+
4+
from unittest.mock import MagicMock, patch
5+
6+
from cosipy.interfaces import UnbinnedThreeMLSourceResponseInterface
7+
from typing import Iterable, Type
8+
from astromodels.sources import Source
9+
from cosipy.interfaces.event import TimeTagEmCDSEventInSCFrameInterface, EmCDSEventInSCFrameInterface
10+
from cosipy.interfaces import EventInterface
11+
from astromodels import PointSource, ExtendedSource
12+
from cosipy import test_data
13+
import shutil
14+
import numpy as np
15+
16+
from cosipy.threeml.unbinned_model_folding import (
17+
UnbinnedThreeMLModelFolding,
18+
CachedUnbinnedThreeMLModelFolding
19+
)
20+
21+
data_path = test_data.path
22+
23+
class MockResponse(UnbinnedThreeMLSourceResponseInterface):
24+
"""Simulates a source response."""
25+
def __init__(self, counts=10.0, density=None, event_type=TimeTagEmCDSEventInSCFrameInterface):
26+
self._counts = counts
27+
self._density = density if density is not None else [1.0, 1.0, 1.0]
28+
self._event_type = event_type
29+
self.source_set = None
30+
31+
def set_source(self, source):
32+
self.source_set = source
33+
34+
def copy(self):
35+
return MockResponse(self._counts, self._density, self._event_type)
36+
37+
def expected_counts(self) -> float:
38+
return self._counts
39+
40+
def expectation_density(self) -> Iterable[float]:
41+
return self._density
42+
43+
@property
44+
def event_type(self) -> Type[EventInterface]:
45+
return self._event_type
46+
47+
class MockCachedResponse(MockResponse):
48+
"""Simulates a response that supports caching to disk."""
49+
def __init__(self, **kwargs):
50+
super().__init__(**kwargs)
51+
self.init_called = False
52+
self.saved_path = None
53+
self.loaded_path = None
54+
55+
def init_cache(self):
56+
self.init_called = True
57+
58+
def cache_to_file(self, path):
59+
self.saved_path = path
60+
61+
def cache_from_file(self, path):
62+
self.loaded_path = path
63+
64+
def test_folding_init_event_type_mismatch():
65+
"""Verify that inconsistent event types raise a RuntimeError."""
66+
psr = MockResponse(counts = 5.0, density = [1.0, 2.0], event_type=TimeTagEmCDSEventInSCFrameInterface)
67+
esr = MockResponse(counts = 5.0, density = [1.0, 2.0], event_type=EmCDSEventInSCFrameInterface)
68+
69+
with pytest.raises(RuntimeError):
70+
UnbinnedThreeMLModelFolding(point_source_response=psr, extended_source_response=esr)
71+
72+
def test_cache_source_responses_no_model():
73+
"""Ensure RuntimeError if expected_counts is called before set_model."""
74+
folding = UnbinnedThreeMLModelFolding(point_source_response=MockResponse())
75+
with pytest.raises(RuntimeError):
76+
folding.expected_counts()
77+
78+
def test_cache_source_responses_logic():
79+
"""Test the full lifecycle of the Mixin: mapping sources to responses."""
80+
mock_model = MagicMock()
81+
mock_source = MagicMock(spec=PointSource)
82+
mock_model.sources = {"src1": mock_source}
83+
mock_model.to_dict.return_value = {"src1": "params_v1"}
84+
85+
psr = MockResponse(counts=10.0)
86+
folding = UnbinnedThreeMLModelFolding(point_source_response=psr)
87+
folding.set_model(mock_model)
88+
89+
assert folding.expected_counts() == 10.0
90+
assert "src1" in folding._source_responses
91+
assert folding._source_responses["src1"].source_set == mock_source
92+
93+
assert folding._cache_source_responses() is False
94+
95+
mock_model.to_dict.return_value = {"src1": "params_v2"}
96+
assert folding._cache_source_responses() is True
97+
98+
def test_mixin_missing_response_errors():
99+
"""Verify errors when model has a source type but the folding lacks the response."""
100+
mock_model = MagicMock()
101+
mock_model.sources = {"ext": MagicMock(spec=ExtendedSource)}
102+
mock_model.to_dict.return_value = {"ext": "data"}
103+
104+
folding = UnbinnedThreeMLModelFolding(point_source_response=MockResponse())
105+
folding.set_model(mock_model)
106+
107+
with pytest.raises(RuntimeError):
108+
folding.expected_counts()
109+
110+
def test_expectation_density_with_batching():
111+
"""Test the batching generator path in UnbinnedThreeMLModelFolding."""
112+
def gen_density():
113+
yield from [1.0, 2.0, 3.0, 4.0]
114+
115+
mock_model = MagicMock()
116+
mock_model.sources = {"s1": MagicMock(spec=PointSource)}
117+
mock_model.to_dict.return_value = {"s1": "v1"}
118+
119+
psr = MockResponse(density=gen_density())
120+
folding = UnbinnedThreeMLModelFolding(point_source_response=psr, batch_size=2)
121+
folding.set_model(mock_model)
122+
123+
result = list(folding.expectation_density())
124+
assert result == [1.0, 2.0, 3.0, 4.0]
125+
assert folding.event_type == TimeTagEmCDSEventInSCFrameInterface
126+
127+
def test_expectation_density_empty_model():
128+
"""Verify that a model with no sources returns an empty iterable."""
129+
folding = UnbinnedThreeMLModelFolding(point_source_response=MockResponse())
130+
131+
mock_model = MagicMock()
132+
mock_model.sources = {}
133+
mock_model.to_dict.return_value = {}
134+
folding.set_model(mock_model)
135+
136+
result = folding.expectation_density()
137+
138+
assert list(result) == []
139+
140+
def test_expectation_density_fast_track_multi_source():
141+
"""Test the 'fast path' where we sum multiple sources that have __len__."""
142+
s1_dens = np.array([1.0, 2.0, 3.0])
143+
s2_dens = np.array([0.5, 0.5, 0.5])
144+
145+
mock_model = MagicMock()
146+
mock_model.sources = {
147+
"src1": MagicMock(spec=PointSource),
148+
"src2": MagicMock(spec=PointSource)
149+
}
150+
mock_model.to_dict.return_value = {"src1": 1, "src2": 2}
151+
152+
psr = MockResponse(density=s1_dens)
153+
folding = UnbinnedThreeMLModelFolding(point_source_response=psr)
154+
folding.set_model(mock_model)
155+
156+
folding._cache_source_responses()
157+
folding._source_responses["src1"]._density = s1_dens
158+
folding._source_responses["src2"]._density = s2_dens
159+
160+
result = folding.expectation_density()
161+
162+
expected = np.array([1.5, 2.5, 3.5])
163+
np.testing.assert_allclose(result, expected)
164+
165+
def test_cached_folding_init_cache():
166+
"""Verify init_cache propagates to underlying responses."""
167+
res_a = MockCachedResponse()
168+
folding = CachedUnbinnedThreeMLModelFolding(point_source_response=res_a)
169+
170+
folding._source_responses = {"src_a": res_a}
171+
172+
with patch.object(folding, '_cache_source_responses'):
173+
folding.init_cache()
174+
assert res_a.init_called is True
175+
176+
def test_cached_folding_save_and_load_with_cleanup():
177+
"""
178+
Verify saving/loading logic using the library's test_data path.
179+
Ensures files are created, verified, and strictly cleaned up.
180+
"""
181+
output_dir = data_path / "temp_cache_test"
182+
183+
res_a = MockCachedResponse()
184+
res_b = MockCachedResponse()
185+
186+
folding = CachedUnbinnedThreeMLModelFolding(point_source_response=res_a)
187+
folding._source_responses = {"src_a": res_a, "src_b": res_b}
188+
189+
try:
190+
with patch.object(folding, '_cache_source_responses'):
191+
folding.save_caches(output_dir, cache_only=["src_a"])
192+
193+
expected_file_a = output_dir / "src_a_source_response_cache.h5"
194+
assert res_a.saved_path == expected_file_a
195+
assert res_b.saved_path is None
196+
197+
expected_file_a.touch()
198+
199+
folding.load_caches(output_dir, load_only=["src_a"])
200+
assert res_a.loaded_path == expected_file_a
201+
202+
folding.load_caches(output_dir, load_only=["src_b"])
203+
assert res_b.loaded_path is None
204+
205+
finally:
206+
if output_dir.exists():
207+
shutil.rmtree(output_dir)
208+
209+
def test_cached_folding_isinstance_branches():
210+
"""
211+
Targets the 'False' branch of isinstance(...) checks in
212+
init_cache, save_caches, and load_caches.
213+
"""
214+
output_dir = data_path / "branch_coverage_temp"
215+
216+
res_std = MockResponse()
217+
218+
folding = CachedUnbinnedThreeMLModelFolding(point_source_response=res_std)
219+
folding._source_responses = {"src_std": res_std}
220+
221+
try:
222+
with patch.object(folding, '_cache_source_responses'):
223+
folding.init_cache()
224+
225+
folding.save_caches(output_dir)
226+
dummy_file = output_dir / "src_std_source_response_cache.h5"
227+
assert not (dummy_file).exists()
228+
229+
output_dir.mkdir(parents=True, exist_ok=True)
230+
dummy_file.touch()
231+
232+
folding.load_caches(output_dir)
233+
234+
assert True
235+
finally:
236+
if output_dir.exists():
237+
shutil.rmtree(output_dir)

0 commit comments

Comments
 (0)