|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
| 5 | +import pytest |
5 | 6 | import torch |
6 | 7 |
|
7 | 8 | import torchio as tio |
@@ -194,3 +195,104 @@ def test_batch_copy_preserves_original(self) -> None: |
194 | 195 | tio.Noise(std=1.0)(batch) |
195 | 196 | # Original should be unchanged (copy=True default) |
196 | 197 | torch.testing.assert_close(batch.data, original) |
| 198 | + |
| 199 | + |
| 200 | +# ── Coverage gap tests ─────────────────────────────────────────────── |
| 201 | + |
| 202 | + |
| 203 | +class TestImagesBatchValidation: |
| 204 | + def test_non_5d_raises(self) -> None: |
| 205 | + from torchio.data.batch import ImagesBatch |
| 206 | + |
| 207 | + with pytest.raises(ValueError, match="5"): |
| 208 | + ImagesBatch( |
| 209 | + data=torch.rand(1, 10, 10), |
| 210 | + affines=[tio.AffineMatrix()], |
| 211 | + image_class=tio.ScalarImage, |
| 212 | + ) |
| 213 | + |
| 214 | + def test_affine_count_mismatch_raises(self) -> None: |
| 215 | + from torchio.data.batch import ImagesBatch |
| 216 | + |
| 217 | + with pytest.raises(ValueError, match="affines"): |
| 218 | + ImagesBatch( |
| 219 | + data=torch.rand(2, 1, 5, 5, 5), |
| 220 | + affines=[tio.AffineMatrix()], # only 1 for batch of 2 |
| 221 | + image_class=tio.ScalarImage, |
| 222 | + ) |
| 223 | + |
| 224 | + def test_from_images_empty_raises(self) -> None: |
| 225 | + from torchio.data.batch import ImagesBatch |
| 226 | + |
| 227 | + with pytest.raises(ValueError, match="empty"): |
| 228 | + ImagesBatch.from_images([]) |
| 229 | + |
| 230 | + def test_data_setter_non_5d_raises(self) -> None: |
| 231 | + from torchio.data.batch import ImagesBatch |
| 232 | + |
| 233 | + batch = ImagesBatch( |
| 234 | + data=torch.rand(1, 1, 5, 5, 5), |
| 235 | + affines=[tio.AffineMatrix()], |
| 236 | + image_class=tio.ScalarImage, |
| 237 | + ) |
| 238 | + with pytest.raises(ValueError, match="5"): |
| 239 | + batch.data = torch.rand(1, 5, 5) |
| 240 | + |
| 241 | + def test_device_property(self) -> None: |
| 242 | + from torchio.data.batch import ImagesBatch |
| 243 | + |
| 244 | + batch = ImagesBatch( |
| 245 | + data=torch.rand(1, 1, 5, 5, 5), |
| 246 | + affines=[tio.AffineMatrix()], |
| 247 | + image_class=tio.ScalarImage, |
| 248 | + ) |
| 249 | + assert batch.device.type == "cpu" |
| 250 | + |
| 251 | + def test_len(self) -> None: |
| 252 | + from torchio.data.batch import ImagesBatch |
| 253 | + |
| 254 | + batch = ImagesBatch( |
| 255 | + data=torch.rand(3, 1, 5, 5, 5), |
| 256 | + affines=[tio.AffineMatrix() for _ in range(3)], |
| 257 | + image_class=tio.ScalarImage, |
| 258 | + ) |
| 259 | + assert len(batch) == 3 |
| 260 | + |
| 261 | + |
| 262 | +class TestSubjectsBatchEdgeCases: |
| 263 | + def test_from_subjects_empty_raises(self) -> None: |
| 264 | + from torchio.data.batch import SubjectsBatch |
| 265 | + |
| 266 | + with pytest.raises(ValueError, match="empty"): |
| 267 | + SubjectsBatch.from_subjects([]) |
| 268 | + |
| 269 | + def test_device_property(self) -> None: |
| 270 | + subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 5, 5, 5))) |
| 271 | + from torchio.data.batch import SubjectsBatch |
| 272 | + |
| 273 | + batch = SubjectsBatch.from_subjects([subject]) |
| 274 | + assert batch.device.type == "cpu" |
| 275 | + |
| 276 | + def test_getattr_invalid_raises(self) -> None: |
| 277 | + subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 5, 5, 5))) |
| 278 | + from torchio.data.batch import SubjectsBatch |
| 279 | + |
| 280 | + batch = SubjectsBatch.from_subjects([subject]) |
| 281 | + with pytest.raises(AttributeError): |
| 282 | + _ = batch.nonexistent_image |
| 283 | + |
| 284 | + def test_len(self) -> None: |
| 285 | + subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 5, 5, 5))) |
| 286 | + from torchio.data.batch import SubjectsBatch |
| 287 | + |
| 288 | + batch = SubjectsBatch.from_subjects([subject]) |
| 289 | + assert len(batch) == 1 |
| 290 | + |
| 291 | + def test_repr(self) -> None: |
| 292 | + subject = tio.Subject(t1=tio.ScalarImage(torch.rand(1, 5, 5, 5))) |
| 293 | + from torchio.data.batch import SubjectsBatch |
| 294 | + |
| 295 | + batch = SubjectsBatch.from_subjects([subject]) |
| 296 | + r = repr(batch) |
| 297 | + assert "SubjectsBatch" in r |
| 298 | + assert "t1" in r |
0 commit comments