@@ -45,30 +45,30 @@ def test__prepare_output():
45
45
metric = MeanAveragePrecision ()
46
46
47
47
metric ._type = "binary"
48
- scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 )). bool () ))
48
+ scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 ))))
49
49
assert scores .shape == y .shape == (1 , 120 )
50
50
51
51
metric ._type = "multiclass"
52
52
scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 4 , (5 , 3 , 2 ))))
53
53
assert scores .shape == (4 , 30 ) and y .shape == (30 ,)
54
54
55
55
metric ._type = "multilabel"
56
- scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 )). bool () ))
56
+ scores , y = metric ._prepare_output ((torch .rand ((5 , 4 , 3 , 2 )), torch .randint (0 , 2 , (5 , 4 , 3 , 2 ))))
57
57
assert scores .shape == y .shape == (4 , 30 )
58
58
59
59
60
60
def test_update ():
61
61
metric = MeanAveragePrecision ()
62
62
assert len (metric ._y_pred ) == len (metric ._y_true ) == 0
63
- metric .update ((torch .rand ((5 , 4 )), torch .randint (0 , 2 , (5 , 4 )). bool () ))
63
+ metric .update ((torch .rand ((5 , 4 )), torch .randint (0 , 2 , (5 , 4 ))))
64
64
assert len (metric ._y_pred ) == len (metric ._y_true ) == 1
65
65
66
66
67
67
def test__compute_recall_and_precision ():
68
68
m = MeanAveragePrecision ()
69
69
70
70
scores = torch .rand ((50 ,))
71
- y_true = torch .randint (0 , 2 , (50 ,)). bool ()
71
+ y_true = torch .randint (0 , 2 , (50 ,))
72
72
precision , recall , _ = precision_recall_curve (y_true .numpy (), scores .numpy ())
73
73
P = y_true .sum (dim = - 1 )
74
74
ignite_recall , ignite_precision = m ._compute_recall_and_precision (y_true , scores , P )
@@ -77,7 +77,7 @@ def test__compute_recall_and_precision():
77
77
78
78
# When there's no actual positive. Numpy expectedly raises warning.
79
79
scores = torch .rand ((50 ,))
80
- y_true = torch .zeros ((50 ,)). bool ()
80
+ y_true = torch .zeros ((50 ,))
81
81
precision , recall , _ = precision_recall_curve (y_true .numpy (), scores .numpy ())
82
82
P = torch .tensor (0 )
83
83
ignite_recall , ignite_precision = m ._compute_recall_and_precision (y_true , scores , P )
@@ -147,7 +147,7 @@ def test_compute_nonbinary_data(class_mean):
147
147
148
148
# Multilabel
149
149
m = MeanAveragePrecision (is_multilabel = True , class_mean = class_mean )
150
- y_true = torch .randint (0 , 2 , (130 , 5 , 2 , 2 )). bool ()
150
+ y_true = torch .randint (0 , 2 , (130 , 5 , 2 , 2 ))
151
151
m .update ((scores [:50 ], y_true [:50 ]))
152
152
m .update ((scores [50 :], y_true [50 :]))
153
153
ignite_map = m .compute ().numpy ()
0 commit comments