-
Notifications
You must be signed in to change notification settings - Fork 1.2k
/
Copy pathtest_gibbs_noise.py
74 lines (59 loc) · 2.65 KB
/
test_gibbs_noise.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import unittest
from copy import deepcopy
import numpy as np
from parameterized import parameterized
from monai.data.synthetic import create_test_image_2d, create_test_image_3d
from monai.transforms import GibbsNoise
from monai.utils.misc import set_determinism
from monai.utils.module import optional_import
from tests.test_utils import TEST_NDARRAYS, assert_allclose, dict_product
_, has_torch_fft = optional_import("torch.fft", name="fftshift")

# One 2D and one 3D spatial shape; every test below runs on both.
shapes = ((128, 64), (64, 48, 80))

# Exercise every array/tensor container from TEST_NDARRAYS only when torch.fft
# (probed via its ``fftshift`` attribute) is importable; otherwise numpy only.
if has_torch_fft:
    input_types = TEST_NDARRAYS
else:
    input_types = [np.array]

# Every (shape, input_type) combination, flattened to [shape, input_type] pairs
# so parameterized.expand can pass them positionally to the test methods.
TEST_CASES = [
    [combo["shape"], combo["input_type"]] for combo in dict_product(shape=shapes, input_type=input_types)
]
class TestGibbsNoise(unittest.TestCase):
    """Tests for the array-level ``GibbsNoise`` transform.

    Every test is parameterized over ``TEST_CASES`` (each combination of an image
    shape and an input container type). Each test checks one property of the
    transform's output for a fixed ``alpha``.
    """

    def setUp(self):
        # Seed global RNGs so each test starts from the same deterministic state.
        set_determinism(0)
        super().setUp()

    def tearDown(self):
        # Undo the seeding from setUp so it does not leak into other tests, and
        # chain to the base class so cooperative TestCase teardown still runs.
        set_determinism(None)
        super().tearDown()

    @staticmethod
    def get_data(im_shape, input_type):
        """Build a noise-free synthetic test image in the requested container type.

        Args:
            im_shape: spatial shape unpacked into the generator. Length 2 selects
                ``create_test_image_2d``; any other length selects
                ``create_test_image_3d``.
            input_type: callable that converts the generated numpy image into
                the container under test (``np.array`` or an entry of
                ``TEST_NDARRAYS``).

        Returns:
            ``input_type`` applied to the generated image, with a new leading
            axis added.
        """
        create_test_image = create_test_image_2d if len(im_shape) == 2 else create_test_image_3d
        # noise_max=0.0 keeps the image free of random noise. ``[0]`` takes the
        # first item the generator returns (presumably the image rather than the
        # segmentation; TODO confirm). ``[None]`` prepends a singleton axis,
        # presumably the channel dim.
        im = create_test_image(*im_shape, num_objs=4, rad_max=20, noise_max=0.0, num_seg_classes=5)[0][None]
        return input_type(im)

    @parameterized.expand(TEST_CASES)
    def test_same_result(self, im_shape, input_type):
        """Two calls of one transform on copies of the same image give identical output."""
        im = self.get_data(im_shape, input_type)
        alpha = 0.8
        t = GibbsNoise(alpha)
        # Each call gets its own copy, so an in-place modification of the input
        # by the first call cannot influence the second.
        out1 = t(deepcopy(im))
        out2 = t(deepcopy(im))
        assert_allclose(out1, out2, rtol=1e-7, atol=0, type_test="tensor")

    @parameterized.expand(TEST_CASES)
    def test_identity(self, im_shape, input_type):
        """With alpha=0 the output matches the input to within atol=1e-2."""
        im = self.get_data(im_shape, input_type)
        alpha = 0.0
        t = GibbsNoise(alpha)
        out = t(deepcopy(im))
        # Allow a loose absolute tolerance for numerical round-off in the transform.
        assert_allclose(out, im, atol=1e-2, rtol=1e-7, type_test="tensor")

    @parameterized.expand(TEST_CASES)
    def test_alpha_1(self, im_shape, input_type):
        """With alpha=1 the output is all zeros."""
        im = self.get_data(im_shape, input_type)
        alpha = 1.0
        t = GibbsNoise(alpha)
        out = t(deepcopy(im))
        # ``0 * im`` is an all-zero reference with the same shape and container type as the input.
        assert_allclose(out, 0 * im, rtol=1e-7, atol=0, type_test="tensor")
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()