25 changes: 4 additions & 21 deletions stgraph/nn/pytorch/temporal/tgcn.py
@@ -1,35 +1,16 @@
import torch
from stgraph.nn.pytorch.graph_conv import GraphConv


class TGCN(torch.nn.Module):

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
    ):
    def __init__(self, in_channels, out_channels):
        super(TGCN, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Update GCN Layer
        self.conv_z = GraphConv(self.in_channels, self.out_channels, activation=None)

        # Update linear layer
        self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels)

        # Reset GCN layer
        self.conv_r = GraphConv(self.in_channels, self.out_channels, activation=None)

        # Reset linear layer
        self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels)

        # Candidate (Current Memory Content) GCN layer
        self.conv_h = GraphConv(self.in_channels, self.out_channels, activation=None)

        # Candidate linear layer
        self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels)

    def _set_hidden_state(self, X, H):
@@ -39,20 +20,23 @@ def _set_hidden_state(self, X, H):

    def _calculate_update_gate(self, g, X, edge_weight, H):
        h = self.conv_z(g, X, edge_weight=edge_weight)
        h = torch.clamp(h, min=-1e6, max=1e6)  # Clamp to avoid extreme values
        Z = torch.cat((h, H), dim=1)
        Z = self.linear_z(Z)
        Z = torch.sigmoid(Z)
        return Z

    def _calculate_reset_gate(self, g, X, edge_weight, H):
        h = self.conv_r(g, X, edge_weight=edge_weight)
        h = torch.clamp(h, min=-1e6, max=1e6)  # Clamp to avoid extreme values
        R = torch.cat((h, H), dim=1)
        R = self.linear_r(R)
        R = torch.sigmoid(R)
        return R

    def _calculate_candidate_state(self, g, X, edge_weight, H, R):
        h = self.conv_h(g, X, edge_weight=edge_weight)
        h = torch.clamp(h, min=-1e6, max=1e6)  # Clamp to avoid extreme values
        H_tilde = torch.cat((h, H * R), dim=1)
        H_tilde = self.linear_h(H_tilde)
        H_tilde = torch.tanh(H_tilde)
@@ -63,7 +47,6 @@ def _calculate_hidden_state(self, Z, H, H_tilde):
        return H

    def forward(self, g, X, edge_weight=None, H=None):

        H = self._set_hidden_state(X, H)
        Z = self._calculate_update_gate(g, X, edge_weight, H)
        R = self._calculate_reset_gate(g, X, edge_weight, H)
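
The three torch.clamp calls above bound each GraphConv output to [-1e6, 1e6] before it is concatenated with the hidden state and fed to the gate's linear layer, so a runaway convolution output cannot blow up the sigmoid/tanh inputs. Below is a stand-alone sketch of the clamped update-gate path, not part of the diff; a plain torch.nn.Linear stands in for stgraph's GraphConv so the snippet runs without a graph object, and all shapes are illustrative.

# Stand-alone sketch of the clamped update-gate computation; torch.nn.Linear is
# an assumed stand-in for stgraph's GraphConv so no graph object is needed.
import torch

num_nodes, in_channels, out_channels = 4, 8, 16

conv_z = torch.nn.Linear(in_channels, out_channels)         # stand-in for GraphConv
linear_z = torch.nn.Linear(2 * out_channels, out_channels)  # mirrors self.linear_z

X = torch.randn(num_nodes, in_channels)   # node features at the current step
H = torch.zeros(num_nodes, out_channels)  # previous hidden state

h = conv_z(X)                              # "graph convolution" output
h = torch.clamp(h, min=-1e6, max=1e6)      # bound extreme values, as in the diff
Z = torch.sigmoid(linear_z(torch.cat((h, H), dim=1)))  # update gate, values in (0, 1)
print(Z.shape)                             # torch.Size([4, 16])
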
77 changes: 41 additions & 36 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py
@@ -1,20 +1,18 @@
import argparse
import time
import traceback

import numpy as np
import pynvml
import snoop
import torch
import torch.nn as nn
import torch.nn.functional as F
import traceback

from .model import GCN
from .utils import accuracy, generate_test_mask, generate_train_mask, to_default_device
from rich.progress import Progress

from stgraph.benchmark_tools.table import BenchmarkTable
from stgraph.dataset import CoraDataLoader
from stgraph.graph.static.static_graph import StaticGraph
from stgraph.benchmark_tools.table import BenchmarkTable
from .model import GCN
from .utils import accuracy, generate_test_mask, generate_train_mask, \
    to_default_device


def train(
@@ -90,35 +88,42 @@ def train(
    )

    try:
        for epoch in range(num_epochs):
            torch.cuda.reset_peak_memory_stats(0)
            model.train()
            if cuda:
                torch.cuda.synchronize()
            t0 = time.time()

            # forward
            logits = model(g, features)
            loss = loss_fcn(logits[train_mask], labels[train_mask])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            now_mem = torch.cuda.max_memory_allocated(0) + graph_mem
            Used_memory = max(now_mem, Used_memory)

            if cuda:
                torch.cuda.synchronize()

            run_time_this_epoch = time.time() - t0

            if epoch >= 3:
                dur.append(run_time_this_epoch)

            train_acc = accuracy(logits[train_mask], labels[train_mask])
            table.add_row(
                [epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2))]
        with Progress() as progress:
            epoch_progress = progress.add_task(
                f"{dataset}",
                total=num_epochs
            )
            for epoch in range(num_epochs):
                torch.cuda.reset_peak_memory_stats(0)
                model.train()
                if cuda:
                    torch.cuda.synchronize()
                t0 = time.time()

                # forward
                logits = model(g, features)
                loss = loss_fcn(logits[train_mask], labels[train_mask])
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                now_mem = torch.cuda.max_memory_allocated(0) + graph_mem
                Used_memory = max(now_mem, Used_memory)

                if cuda:
                    torch.cuda.synchronize()

                run_time_this_epoch = time.time() - t0

                if epoch >= 3:
                    dur.append(run_time_this_epoch)

                train_acc = accuracy(logits[train_mask], labels[train_mask])
                table.add_row(
                    [epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2))]
                )

                progress.update(epoch_progress, advance=1)

        table.display(output_file=f)
        print("Average Time taken: {:6f}".format(np.mean(dur)), file=f)
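
The training loop is now wrapped in a rich progress bar: one task is created per dataset with total=num_epochs and advanced once per epoch. Below is a stand-alone sketch of that pattern, not the project's code; time.sleep stands in for the real forward/backward pass.

# Stand-alone sketch of the rich.progress pattern used above.
import time
from rich.progress import Progress

num_epochs = 10
with Progress() as progress:
    task = progress.add_task("training", total=num_epochs)  # one bar for the run
    for epoch in range(num_epochs):
        time.sleep(0.05)                  # stand-in for one training epoch
        progress.update(task, advance=1)  # advance the bar by one epoch
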
2 changes: 0 additions & 2 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py
@@ -31,7 +31,6 @@ def main(args):

    for dataset_name, execute_choice in gcn_datasets.items():
        if execute_choice == "Y":
            print(f"Started training {testpack_properties['Name']} on {dataset_name}")

            output_file_path = output_folder_path + "/" + dataset_name + ".txt"
            if os.path.exists(output_file_path):
@@ -50,7 +49,6 @@

            dataset_results[dataset_name] = result

            print(f"Finished training {testpack_properties['Name']} on {dataset_name}")

    table = Table(title="GCN Results")

@@ -39,11 +39,6 @@ def main(args):

    for dataset_name, execute_choice in temporal_datasets.items():
        if execute_choice == "Y":
            print(f"Started training TGCN on {dataset_name}")

            # train_process = subprocess.run(
            # ["bash", "train_tgcn.sh", dataset_name, "8", "16"]
            # )

            output_file_path = output_folder_path + "/" + dataset_name + ".txt"
            if os.path.exists(output_file_path):
@@ -53,18 +48,14 @@
                dataset=dataset_name,
                num_hidden=16,
                feat_size=8,
                lr=0.01,
                lr=1e-4,
                backprop_every=0,
                num_epochs=30,
                num_epochs=15,
                output_file_path=output_file_path,
            )

            # breakpoint()

            dataset_results[dataset_name] = result

            print(f"Finished training TGCN on {dataset_name}")

    # printing the summary of the run
    table = Table(title="Temporal-TGCN Results")
