|
418 | 418 | " # perform a single optimization step (parameter update)\n",
|
419 | 419 | " optimizer.step()\n",
|
420 | 420 | " # update running training loss\n",
|
421 |
| - " train_loss += loss.item()*data.size(0)\n", |
| 421 | + " train_loss += loss.item()\n", |
422 | 422 | " \n",
|
423 | 423 | " ###################### \n",
|
424 | 424 | " # validate the model #\n",
|
|
430 | 430 | " # calculate the loss\n",
|
431 | 431 | " loss = criterion(output, target)\n",
|
432 | 432 | " # update running validation loss \n",
|
433 |
| - " valid_loss += loss.item()*data.size(0)\n", |
| 433 | + " valid_loss += loss.item()\n", |
434 | 434 | " \n",
|
435 | 435 | " # print training/validation statistics \n",
|
436 | 436 | " # calculate average loss over an epoch\n",
|
437 |
| - " train_loss = train_loss/len(train_loader.dataset)\n", |
438 |
| - " valid_loss = valid_loss/len(valid_loader.dataset)\n", |
| 437 | + " train_loss = train_loss/len(train_loader)\n", |
| 438 | + " valid_loss = valid_loss/len(valid_loader)\n", |
439 | 439 | " \n",
|
440 | 440 | " print('Epoch: {} \\tTraining Loss: {:.6f} \\tValidation Loss: {:.6f}'.format(\n",
|
441 | 441 | " epoch+1, \n",
|
|
0 commit comments