Skip to content

Commit c8258de

Browse files
committed
Refactor test optimizers
1 parent 47bb88a commit c8258de

File tree

1 file changed

+13
-18
lines changed

1 file changed

+13
-18
lines changed

tests/test_optimizers.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,19 @@
22

33
from diffpy.snmf.optimizers import get_weights
44

5-
test_matrix = [
6-
# ([stretched_component_gram_matrix, linear_coefficient, lower_bound, upper_bound], expected)
7-
([[[1, 0], [0, 1]], [1, 1], 0, 0], [0, 0]),
8-
([[[1, 0], [0, 1]], [1, 1], -1, 1], [-1, -1]),
9-
([[[1.75, 0], [0, 1.5]], [1, 1.2], -1, 1], [-0.571428571428571, -0.8]),
10-
([[[0.75, 0.2], [0.2, 0.75]], [-0.1, -0.2], -1, 1], [0.066985645933014, 0.248803827751196]),
11-
([[[2, -1, 0], [-1, 2, -1], [0, -1, 2]], [1, 1, 1], -10, 12], [-1.5, -2, -1.5]),
12-
([[[2, -1, 0], [-1, 2, -1], [0, -1, 2]], [1, -1, -1], -10, 12], [0, 1, 1]),
13-
([[[4, 0, 0, 0], [0, 3, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], [-2, -3, -4, -1], 0, 1000], [0.5, 1, 2, 1]),
14-
]
155

16-
17-
@pytest.mark.parametrize("tm", test_matrix)
18-
def test_get_weights(tm):
19-
stretched_component_gram_matrix = tm[0][0]
20-
linear_coefficient = tm[0][1]
21-
lower_bound = tm[0][2]
22-
upper_bound = tm[0][3]
23-
expected = tm[1]
6+
@pytest.mark.parametrize(
7+
"stretched_component_gram_matrix, linear_coefficient, lower_bound, upper_bound, expected",
8+
[
9+
([[1, 0], [0, 1]], [1, 1], 0, 0, [0, 0]),
10+
([[1, 0], [0, 1]], [1, 1], -1, 1, [-1, -1]),
11+
([[1.75, 0], [0, 1.5]], [1, 1.2], -1, 1, [-0.571428571428571, -0.8]),
12+
([[0.75, 0.2], [0.2, 0.75]], [-0.1, -0.2], -1, 1, [0.066985645933014, 0.248803827751196]),
13+
([[2, -1, 0], [-1, 2, -1], [0, -1, 2]], [1, 1, 1], -10, 12, [-1.5, -2, -1.5]),
14+
([[2, -1, 0], [-1, 2, -1], [0, -1, 2]], [1, -1, -1], -10, 12, [0, 1, 1]),
15+
([[4, 0, 0, 0], [0, 3, 0, 0], [0, 0, 2, 0], [0, 0, 0, 1]], [-2, -3, -4, -1], 0, 1000, [0.5, 1, 2, 1]),
16+
],
17+
)
18+
def test_get_weights(stretched_component_gram_matrix, linear_coefficient, lower_bound, upper_bound, expected):
2419
actual = get_weights(stretched_component_gram_matrix, linear_coefficient, lower_bound, upper_bound)
2520
assert actual == pytest.approx(expected, rel=1e-4, abs=1e-6)

0 commit comments

Comments
 (0)