diff --git a/GNNs/PyG/gcn_link_prediction.ipynb b/GNNs/PyG/gcn_link_prediction.ipynb index 8efda3f..3d6617a 100644 --- a/GNNs/PyG/gcn_link_prediction.ipynb +++ b/GNNs/PyG/gcn_link_prediction.ipynb @@ -672,7 +672,7 @@ "train_edge_neighbor_loader = conn.gds.edgeNeighborLoader(\n", " v_in_feats=[\"x\"],\n", " v_out_labels=[\"y\"],\n", - " num_batches=5,\n", + " batch_size=1000,\n", " e_extra_feats=[\"is_train\",\"is_val\"],\n", " output_format=\"PyG\",\n", " num_neighbors=10,\n", @@ -692,7 +692,7 @@ "val_edge_neighbor_loader = conn.gds.edgeNeighborLoader(\n", " v_in_feats=[\"x\"],\n", " v_out_labels=[\"y\"],\n", - " num_batches=5,\n", + " batch_size=500,\n", " e_extra_feats=[\"is_train\",\"is_val\"],\n", " output_format=\"PyG\",\n", " num_neighbors=10,\n", diff --git a/applications/nodepiece/nodepiece.ipynb b/applications/nodepiece/nodepiece.ipynb index 13f58b5..bac92e8 100644 --- a/applications/nodepiece/nodepiece.ipynb +++ b/applications/nodepiece/nodepiece.ipynb @@ -368,12 +368,12 @@ "source": [ "import time\n", "import numpy as np\n", - "from pyTigerGraph.gds.metrics import Accuracy\n", + "from pyTigerGraph.gds.metrics import Accuracy, Accumulator\n", "\n", "\n", "for i in range(10):\n", " acc = Accuracy()\n", - " epoch_loss = 0\n", + " epoch_loss = Accumulator()\n", " start = time.time()\n", " for batch in np_loader:\n", " labels = batch[\"y\"]\n", @@ -383,21 +383,21 @@ " optimizer.zero_grad()\n", " loss_val.backward()\n", " optimizer.step()\n", - " epoch_loss += loss_val.item()\n", + " epoch_loss.update(loss_val.item())\n", " end = time.time()\n", " val_acc = Accuracy()\n", - " val_epoch_loss = 0\n", + " val_epoch_loss = Accumulator()\n", " for val_batch in valid_loader:\n", " with torch.no_grad():\n", " labels = val_batch[\"y\"]\n", " out = model(val_batch)\n", " loss_val = loss(out, labels)\n", " val_acc.update(out.argmax(dim=1), labels)\n", - " val_epoch_loss += loss_val.item()\n", - " print(\"EPOCH {}: {}\".format(i, epoch_loss/np_loader.num_batches), \n", + " val_epoch_loss.update(loss_val.item())\n", + " print(\"EPOCH {}: {}\".format(i, epoch_loss.mean), \n", " \"Train Accuracy:\", acc.value, \n", " \"Time:\", end-start, \n", - " \"Valid Loss: {}\".format(val_epoch_loss/valid_loader.num_batches), \n", + " \"Valid Loss: {}\".format(val_epoch_loss.mean), \n", " \"Valid Accuracy:\", val_acc.value)" ] }, @@ -448,7 +448,7 @@ "source": [ "acc = Accuracy()\n", "\n", - "epoch_loss = 0\n", + "epoch_loss = Accumulator()\n", "start = time.time()\n", "model.eval()\n", "for batch in test_loader:\n", @@ -456,9 +456,9 @@ " out = model(batch)\n", " loss_val = loss(out, labels)\n", " acc.update(out.argmax(dim=1), labels)\n", - " epoch_loss += loss_val.item()\n", + " epoch_loss.update(loss_val.item())\n", "end = time.time()\n", - "print(\"Loss: {}, Accuracy: {}\".format(epoch_loss/test_loader.num_batches, acc.value), \"Time:\", end-start)" + "print(\"Loss: {}, Accuracy: {}\".format(epoch_loss.mean, acc.value), \"Time:\", end-start)" ] }, {