Skip to content

Commit 70aafab

Browse files
committed
Remove unused calc_cm function
1 parent 86e7ac9 commit 70aafab

File tree

1 file changed

+0
-20
lines changed

1 file changed

+0
-20
lines changed

wavetorch/utils.py

-20
Original file line numberDiff line numberDiff line change
@@ -34,23 +34,3 @@ def accuracy_onehot(y_pred, y_label):
3434

3535
def normalize_power(X):
3636
return X / torch.sum(X, dim=1, keepdim=True)
37-
38-
39-
def calc_cm(model, dataloader, verbose=True):
40-
"""Calculate the confusion matrix
41-
"""
42-
with torch.no_grad():
43-
list_yb_pred = []
44-
list_yb = []
45-
i = 1
46-
for xb, yb in dataloader:
47-
yb_pred = model(xb)
48-
list_yb_pred.append(yb_pred)
49-
list_yb.append(yb)
50-
if verbose: print("cm: processing batch %d" % i)
51-
i += 1
52-
53-
y_pred = torch.cat(list_yb_pred, dim=0)
54-
y_truth = torch.cat(list_yb, dim=0)
55-
56-
return confusion_matrix(y_truth.argmax(dim=1).numpy(), y_pred.argmax(dim=1).numpy())

0 commit comments

Comments
 (0)