We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 86e7ac9 commit 70aafabCopy full SHA for 70aafab
wavetorch/utils.py
@@ -34,23 +34,3 @@ def accuracy_onehot(y_pred, y_label):
34
35
def normalize_power(X):
36
return X / torch.sum(X, dim=1, keepdim=True)
37
-
38
39
-def calc_cm(model, dataloader, verbose=True):
40
- """Calculate the confusion matrix
41
- """
42
- with torch.no_grad():
43
- list_yb_pred = []
44
- list_yb = []
45
- i = 1
46
- for xb, yb in dataloader:
47
- yb_pred = model(xb)
48
- list_yb_pred.append(yb_pred)
49
- list_yb.append(yb)
50
- if verbose: print("cm: processing batch %d" % i)
51
- i += 1
52
53
- y_pred = torch.cat(list_yb_pred, dim=0)
54
- y_truth = torch.cat(list_yb, dim=0)
55
56
- return confusion_matrix(y_truth.argmax(dim=1).numpy(), y_pred.argmax(dim=1).numpy())
0 commit comments