Skip to content

Commit 6bb02d6

Browse files
authored
Fix tensor permutation dimensions in validation and test
1 parent aefe7db commit 6bb02d6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/sample.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
input_val = normalize(val_data["images"].to(device), mean=mean, std=std)
2828
with th.no_grad():
2929
val_out = model(input_val)
30-
val_out = val_out.permute((0, 2, 3, 4, 1))
30+
val_out = val_out.permute((0, 3, 4, 2, 1))
3131
label_val = th.nn.functional.one_hot(
3232
val_data["annotation"].type(th.int64), num_classes=5
3333
).to(device)
@@ -63,7 +63,7 @@ def disp_result(
6363
input_test = normalize(test_data["images"].to(device), mean=mean, std=std)
6464
with th.no_grad():
6565
test_out = model(input_test)
66-
test_out = test_out.permute((0, 2, 3, 4, 1))
66+
test_out = test_out.permute((0, 3, 4, 2, 1))
6767
label_test = th.nn.functional.one_hot(
6868
test_data["annotation"].type(th.int64), num_classes=5
6969
).to(device)

0 commit comments

Comments
 (0)