Skip to content

Commit cb8c7fe

Browse files
authored
Merge pull request #1 from Machine-Learning-Foundations/suggestions
fix: correct type annotations in get_indices function
2 parents d1f11bd + 034c65d commit cb8c7fe

File tree

2 files changed

+4
-12
lines changed

2 files changed

+4
-12
lines changed

src/custom_conv.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ def get_indices(image: torch.Tensor, kernel: torch.Tensor) -> tuple:
88
"""Get the indices to set up pixel vectors for convolution by matrix-multiplication.
99
1010
Args:
11-
image (jnp.ndarray): The input image of shape [height, width.]
12-
kernel (jnp.ndarray): A 2d-convolution kernel.
11+
image (torch.Tensor): The input image of shape [height, width].
12+
kernel (torch.Tensor): A 2d-convolution kernel.
1313
1414
Returns:
1515
tuple: An integer array with the indices, the number of rows in the result,

src/mnist.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,7 @@ def zero_grad(model: Net) -> Net:
157157

158158
# TODO: Train the model.
159159
# Use `loss.backward()`, `sgd_step` and `zero_grad`.
160-
preds = model(imgs)
161-
loss_val = cross_entropy(
162-
label=th.nn.functional.one_hot(labels, num_classes=10), out=preds
163-
)
164-
loss_val.backward()
165-
166-
model = sgd_step(model, learning_rate=args.lr)
167-
model = zero_grad(model)
168-
epoch_loss.append(loss_val.item())
160+
169161
print(f"Loss: {sum(epoch_loss)/len(epoch_loss):2.4f}")
170162

171163
train_acc = get_acc(model=model, dataloader=train_loader)
@@ -176,7 +168,7 @@ def zero_grad(model: Net) -> Net:
176168
test_acc = get_acc(model=model, dataloader=test_loader)
177169
train_accs.append(per_epoch_train_acc)
178170
val_accs.append(per_epoch_val_acc)
179-
test_accs.append(train_acc)
171+
test_accs.append(test_acc)
180172
train_accs_np = np.stack(train_accs, axis=0)
181173
val_accs_np = np.stack(val_accs, axis=0)
182174
test_accs_np = np.stack(test_accs)

0 commit comments

Comments
 (0)