Skip to content

Commit aefe7db

Browse files
authored
Update data path and permute dimensions in train.py
1 parent c849223 commit aefe7db

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def train():
110110
data_set = Loader(
111111
input_shape=input_shape,
112112
val_keys=val_keys,
113-
data_path="/home/lveerama/Courses/Foundations_of_Machine_Learning/day_13_exercise_segmentation_solution/data/",
113+
data_path="./data/",
114114
)
115115
epochs = 125
116116
batch_size = 2
@@ -164,7 +164,7 @@ def train():
164164

165165
opt.zero_grad()
166166
preds = model(input_x)
167-
preds = preds.permute((0, 2, 3, 4, 1))
167+
preds = preds.permute((0, 3, 4, 2, 1))
168168
loss = th.mean(
169169
softmax_focal_loss(
170170
preds, labels_y, th.ones((preds.shape[-1])).to(device)
@@ -186,7 +186,7 @@ def train():
186186
).to(device)
187187
with th.no_grad():
188188
val_out = model(input_val)
189-
val_out = val_out.permute((0, 2, 3, 4, 1))
189+
val_out = val_out.permute((0, 3, 4, 2, 1))
190190
val_loss = th.mean(
191191
softmax_focal_loss(
192192
val_out, label_val, th.ones((val_out.shape[-1])).to(device)

0 commit comments

Comments
 (0)