diff --git a/stgraph/nn/pytorch/temporal/tgcn.py b/stgraph/nn/pytorch/temporal/tgcn.py index 94685ead..0e2d50ee 100644 --- a/stgraph/nn/pytorch/temporal/tgcn.py +++ b/stgraph/nn/pytorch/temporal/tgcn.py @@ -1,35 +1,16 @@ import torch from stgraph.nn.pytorch.graph_conv import GraphConv - class TGCN(torch.nn.Module): - - def __init__( - self, - in_channels: int, - out_channels: int, - ): + def __init__(self, in_channels, out_channels): super(TGCN, self).__init__() - self.in_channels = in_channels self.out_channels = out_channels - - # Update GCN Layer self.conv_z = GraphConv(self.in_channels, self.out_channels, activation=None) - - # Update linear layer self.linear_z = torch.nn.Linear(2 * self.out_channels, self.out_channels) - - # Reset GCN layer self.conv_r = GraphConv(self.in_channels, self.out_channels, activation=None) - - # Reset linear layer self.linear_r = torch.nn.Linear(2 * self.out_channels, self.out_channels) - - # Candidate (Current Memory Content) GCN layer self.conv_h = GraphConv(self.in_channels, self.out_channels, activation=None) - - # Candidate linear layer self.linear_h = torch.nn.Linear(2 * self.out_channels, self.out_channels) def _set_hidden_state(self, X, H): @@ -39,6 +20,7 @@ def _set_hidden_state(self, X, H): def _calculate_update_gate(self, g, X, edge_weight, H): h = self.conv_z(g, X, edge_weight=edge_weight) + h = torch.clamp(h, min=-1e6, max=1e6) # Clamp to avoid extreme values Z = torch.cat((h, H), dim=1) Z = self.linear_z(Z) Z = torch.sigmoid(Z) @@ -46,6 +28,7 @@ def _calculate_update_gate(self, g, X, edge_weight, H): def _calculate_reset_gate(self, g, X, edge_weight, H): h = self.conv_r(g, X, edge_weight=edge_weight) + h = torch.clamp(h, min=-1e6, max=1e6) # Clamp to avoid extreme values R = torch.cat((h, H), dim=1) R = self.linear_r(R) R = torch.sigmoid(R) @@ -53,6 +36,7 @@ def _calculate_reset_gate(self, g, X, edge_weight, H): def _calculate_candidate_state(self, g, X, edge_weight, H, R): h = self.conv_h(g, X, edge_weight=edge_weight) + h = torch.clamp(h, min=-1e6, max=1e6) # Clamp to avoid extreme values H_tilde = torch.cat((h, H * R), dim=1) H_tilde = self.linear_h(H_tilde) H_tilde = torch.tanh(H_tilde) @@ -63,7 +47,6 @@ def _calculate_hidden_state(self, Z, H, H_tilde): return H def forward(self, g, X, edge_weight=None, H=None): - H = self._set_hidden_state(X, H) Z = self._calculate_update_gate(g, X, edge_weight, H) R = self._calculate_reset_gate(g, X, edge_weight, H) diff --git a/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py b/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py index ee129dbc..37f126e6 100644 --- a/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py +++ b/tests/scripts/v1_1_0/gcn_dataloaders/gcn/train.py @@ -1,20 +1,18 @@ -import argparse import time +import traceback import numpy as np import pynvml -import snoop import torch -import torch.nn as nn import torch.nn.functional as F -import traceback - -from .model import GCN -from .utils import accuracy, generate_test_mask, generate_train_mask, to_default_device +from rich.progress import Progress +from stgraph.benchmark_tools.table import BenchmarkTable from stgraph.dataset import CoraDataLoader from stgraph.graph.static.static_graph import StaticGraph -from stgraph.benchmark_tools.table import BenchmarkTable +from .model import GCN +from .utils import accuracy, generate_test_mask, generate_train_mask, \ + to_default_device def train( @@ -90,35 +88,42 @@ def train( ) try: - for epoch in range(num_epochs): - torch.cuda.reset_peak_memory_stats(0) - model.train() - if cuda: - torch.cuda.synchronize() - t0 = time.time() - - # forward - logits = model(g, features) - loss = loss_fcn(logits[train_mask], labels[train_mask]) - optimizer.zero_grad() - loss.backward() - optimizer.step() - - now_mem = torch.cuda.max_memory_allocated(0) + graph_mem - Used_memory = max(now_mem, Used_memory) - - if cuda: - torch.cuda.synchronize() - - run_time_this_epoch = time.time() - t0 - - if epoch >= 3: - dur.append(run_time_this_epoch) - - train_acc = accuracy(logits[train_mask], labels[train_mask]) - table.add_row( - [epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2))] + with Progress() as progress: + epoch_progress = progress.add_task( + f"{dataset}", + total=num_epochs ) + for epoch in range(num_epochs): + torch.cuda.reset_peak_memory_stats(0) + model.train() + if cuda: + torch.cuda.synchronize() + t0 = time.time() + + # forward + logits = model(g, features) + loss = loss_fcn(logits[train_mask], labels[train_mask]) + optimizer.zero_grad() + loss.backward() + optimizer.step() + + now_mem = torch.cuda.max_memory_allocated(0) + graph_mem + Used_memory = max(now_mem, Used_memory) + + if cuda: + torch.cuda.synchronize() + + run_time_this_epoch = time.time() - t0 + + if epoch >= 3: + dur.append(run_time_this_epoch) + + train_acc = accuracy(logits[train_mask], labels[train_mask]) + table.add_row( + [epoch, run_time_this_epoch, train_acc, (now_mem * 1.0 / (1024**2))] + ) + + progress.update(epoch_progress, advance=1) table.display(output_file=f) print("Average Time taken: {:6f}".format(np.mean(dur)), file=f) diff --git a/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py b/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py index e15737ca..7e9acba1 100644 --- a/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py +++ b/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py @@ -31,7 +31,6 @@ def main(args): for dataset_name, execute_choice in gcn_datasets.items(): if execute_choice == "Y": - print(f"Started training {testpack_properties['Name']} on {dataset_name}") output_file_path = output_folder_path + "/" + dataset_name + ".txt" if os.path.exists(output_file_path): @@ -50,7 +49,6 @@ def main(args): dataset_results[dataset_name] = result - print(f"Finished training {testpack_properties['Name']} on {dataset_name}") table = Table(title="GCN Results") diff --git a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py index 536253b6..47961b8e 100644 --- a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py +++ b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py @@ -39,11 +39,6 @@ def main(args): for dataset_name, execute_choice in temporal_datasets.items(): if execute_choice == "Y": - print(f"Started training TGCN on {dataset_name}") - - # train_process = subprocess.run( - # ["bash", "train_tgcn.sh", dataset_name, "8", "16"] - # ) output_file_path = output_folder_path + "/" + dataset_name + ".txt" if os.path.exists(output_file_path): @@ -53,18 +48,14 @@ def main(args): dataset=dataset_name, num_hidden=16, feat_size=8, - lr=0.01, + lr=1e-4, backprop_every=0, - num_epochs=30, + num_epochs=15, output_file_path=output_file_path, ) - # breakpoint() - dataset_results[dataset_name] = result - print(f"Finished training TGCN on {dataset_name}") - # printing the summary of the run table = Table(title="Temporal-TGCN Results") diff --git a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py index 494cbe82..7737c95c 100644 --- a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py +++ b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py @@ -1,38 +1,31 @@ -import argparse import time -import numpy as np -import pandas as pd -import torch -import snoop -import pynvml -import sys -import os import traceback -from .model import STGraphTGCN -from stgraph.graph.static.static_graph import StaticGraph +import numpy as np +import pynvml +import torch +from rich.progress import Progress -from stgraph.dataset import WindmillOutputDataLoader -from stgraph.dataset import WikiMathDataLoader +from stgraph.benchmark_tools.table import BenchmarkTable from stgraph.dataset import HungaryCPDataLoader -from stgraph.dataset import PedalMeDataLoader from stgraph.dataset import METRLADataLoader from stgraph.dataset import MontevideoBusDataLoader - -from stgraph.benchmark_tools.table import BenchmarkTable -from .utils import to_default_device, get_default_device - -from rich import inspect +from stgraph.dataset import PedalMeDataLoader +from stgraph.dataset import WikiMathDataLoader +from stgraph.dataset import WindmillOutputDataLoader +from stgraph.graph.static.static_graph import StaticGraph +from .model import STGraphTGCN +from .utils import to_default_device, get_default_device, init_weights def train( - dataset: str, - num_hidden: int, - feat_size: int, - lr: float, - backprop_every: int, - num_epochs: int, - output_file_path: str, + dataset: str, + num_hidden: int, + feat_size: int, + lr: float, + backprop_every: int, + num_epochs: int, + output_file_path: str, ) -> int: with open(output_file_path, "w") as f: if torch.cuda.is_available(): @@ -62,21 +55,36 @@ def train( edge_list = dataloader.get_edges() edge_weight_list = dataloader.get_edge_weights() targets = dataloader.get_all_targets() + assert not np.isnan(targets).any(), "Targets contain NaN values" pynvml.nvmlInit() handle = pynvml.nvmlDeviceGetHandleByIndex(0) initial_used_gpu_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used - G = StaticGraph(edge_list, edge_weight_list, dataloader.gdata["num_nodes"]) - graph_mem = pynvml.nvmlDeviceGetMemoryInfo(handle).used - initial_used_gpu_mem + G = StaticGraph( + edge_list, + edge_weight_list, + dataloader.gdata["num_nodes"] + ) + graph_mem = pynvml.nvmlDeviceGetMemoryInfo( + handle).used - initial_used_gpu_mem edge_weight = to_default_device( torch.unsqueeze(torch.FloatTensor(edge_weight_list), 1) ) + + # Clamp edge weights to have min. value of 1e-6 + edge_weight = torch.clamp(edge_weight, min=1e-6) + targets = to_default_device(torch.FloatTensor(np.array(targets))) num_hidden_units = num_hidden num_outputs = 1 - model = to_default_device(STGraphTGCN(feat_size, num_hidden_units, num_outputs)) + model = to_default_device( + STGraphTGCN(feat_size, num_hidden_units, num_outputs)) + + # Apply custom weight initialization + model.apply(init_weights) + optimizer = torch.optim.Adam(model.parameters(), lr=lr) # Logging Output @@ -105,70 +113,89 @@ def train( norm = to_default_device(norm) G.set_ndata("norm", norm.unsqueeze(1)) - # train - print("Training...\n", file=f) + print("Starting training...\n", file=f) try: - for epoch in range(num_epochs): - torch.cuda.synchronize() - torch.cuda.reset_peak_memory_stats(0) - model.train() - - t0 = time.time() - gpu_mem_arr = [] - cost_arr = [] - - for index in range(num_iter): - optimizer.zero_grad() - cost = 0 - hidden_state = None - y_hat = torch.randn( - (dataloader.gdata["num_nodes"], feat_size), - device=get_default_device(), - ) - for k in range(backprop_every): - t = index * backprop_every + k - - if t >= total_timestamps - dataloader._lags: - break - - if dataset == "METRLA" and t >= total_timestamps - (dataloader._num_timesteps_out + dataloader._num_timesteps_in): - break - - y_out, y_hat, hidden_state = model( - G, y_hat, edge_weight, hidden_state - ) - # breakpoint() - cost = cost + torch.mean((y_out - targets[t]) ** 2) - - if cost == 0: - break - - cost = cost / (backprop_every + 1) - cost.backward() - optimizer.step() - torch.cuda.synchronize() - cost_arr.append(cost.item()) - - used_gpu_mem = torch.cuda.max_memory_allocated(0) + graph_mem - gpu_mem_arr.append(used_gpu_mem) - - run_time_this_epoch = time.time() - t0 - - if epoch >= 3: - dur.append(run_time_this_epoch) - max_gpu.append(max(gpu_mem_arr)) - - table.add_row( - [ - epoch, - "{:.5f}".format(run_time_this_epoch), - "{:.4f}".format(sum(cost_arr) / len(cost_arr)), - "{:.4f}".format((max(gpu_mem_arr) * 1.0 / (1024**2))), - ] + with Progress() as progress: + epoch_progress = progress.add_task( + f"{dataset}", + total=num_epochs ) + while not progress.finished: + for epoch in range(num_epochs): + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats(0) + model.train() + + t0 = time.time() + gpu_mem_arr = [] + cost_arr = [] + + for index in range(num_iter): + optimizer.zero_grad() + cost = 0 + hidden_state = None + y_hat = torch.randn( + (dataloader.gdata["num_nodes"], feat_size), + device=get_default_device(), + ) + for k in range(backprop_every): + t = index * backprop_every + k + + if t >= total_timestamps - dataloader._lags: + break + + if dataset == "METRLA" and t >= total_timestamps - ( + dataloader._num_timesteps_out + dataloader._num_timesteps_in): + break + + y_out, y_hat, hidden_state = model( + G, y_hat, edge_weight, hidden_state + ) + + cost = cost + torch.mean( + (y_out - targets[t]) ** 2) + + if cost == 0: + break + + cost = cost / (backprop_every + 1) + cost.backward() + + torch.nn.utils.clip_grad_norm_( + model.parameters(), + 1.0 + ) + + optimizer.step() + torch.cuda.synchronize() + cost_arr.append(cost.item()) + + used_gpu_mem = torch.cuda.max_memory_allocated( + 0) + graph_mem + gpu_mem_arr.append(used_gpu_mem) + + run_time_this_epoch = time.time() - t0 + + if epoch >= 3: + dur.append(run_time_this_epoch) + max_gpu.append(max(gpu_mem_arr)) + + table.add_row( + [ + epoch, + "{:.5f}".format(run_time_this_epoch), + "{:.4f}".format(sum(cost_arr) / len(cost_arr)), + "{:.4f}".format( + (max(gpu_mem_arr) * 1.0 / (1024 ** 2))), + ] + ) + + progress.update(epoch_progress, advance=1) + table.display(output_file=f) print("Average Time taken: {:6f}".format(np.mean(dur)), file=f) + return 0 except Exception as e: diff --git a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/utils.py b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/utils.py index b2f4091a..1c4a703d 100644 --- a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/utils.py +++ b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/utils.py @@ -12,3 +12,9 @@ def to_default_device(data): if isinstance(data, (list, tuple)): return [to_default_device(x) for x in data] return data.to(get_default_device(), non_blocking=True) + + +def init_weights(layer): + if isinstance(layer, torch.nn.Linear): + torch.nn.init.xavier_uniform_(layer.weight) + layer.bias.data.fill_(0.01) \ No newline at end of file