diff --git a/pyproject.toml b/pyproject.toml index eb772f9..efbb11c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ classifiers = [ keywords = ["GNN", "TGNN", "GPU Programming"] requires-python = ">=3.8" dependencies = [ - 'Jinja2 >= 2', + 'Jinja2 >= 3.1.3', 'pynvrtc >= 9.2', 'pydot', 'networkx >= 3.1', @@ -46,7 +46,7 @@ dev = [ "black", "pytest >= 7.4.3", "pytest-cov >= 4.1.0", - "tqdm >= 4.64.1", + "tqdm >= 4.66.3", "build >= 0.10.0", "gdown >= 4.6.6", "pynvml >= 11.5.0", diff --git a/stgraph/dataset/temporal/metrla_dataloader.py b/stgraph/dataset/temporal/metrla_dataloader.py index fef0bc1..d72fef5 100644 --- a/stgraph/dataset/temporal/metrla_dataloader.py +++ b/stgraph/dataset/temporal/metrla_dataloader.py @@ -9,7 +9,7 @@ class METRLADataLoader(STGraphTemporalDataset): - r"""Traffic forecasting dataset based on the Los Angeles city.. + r"""Traffic forecasting dataset based on the Los Angeles city. A dataset for predicting traffic patterns in the Los Angeles Metropolitan area, comprising traffic data obtained from 207 loop detectors on highways in Los @@ -72,7 +72,7 @@ class METRLADataLoader(STGraphTemporalDataset): def __init__( self: METRLADataLoader, - verbose: bool = True, + verbose: bool = False, num_timesteps_in: int = 12, num_timesteps_out: int = 12, cutoff_time: int | None = None, diff --git a/tests/scripts/stgraph_script.py b/tests/scripts/stgraph_script.py index 862231a..b1236fa 100644 --- a/tests/scripts/stgraph_script.py +++ b/tests/scripts/stgraph_script.py @@ -10,6 +10,14 @@ def main(args): for testpack in testpack_names: script_path = "v" + version_number + "/" + testpack + "/" + testpack + ".py" output_folder_path = "v" + version_number + "/" + testpack + "/outputs" + + # create the script outputs folder if it doesn't exist + if not os.path.exists(output_folder_path): + try: + os.makedirs(output_folder_path) + except OSError as e: + print(f"Failed to create the script outputs folder at {output_folder_path}: {e}") + if os.path.exists(script_path): subprocess.run(["python3", script_path, "-o", output_folder_path]) else: diff --git a/tests/scripts/v1_1_0/gcn_dataloaders/gcn/utils.py b/tests/scripts/v1_1_0/gcn_dataloaders/gcn/utils.py index 0f2d52c..61d7107 100644 --- a/tests/scripts/v1_1_0/gcn_dataloaders/gcn/utils.py +++ b/tests/scripts/v1_1_0/gcn_dataloaders/gcn/utils.py @@ -22,11 +22,11 @@ def to_default_device(data): return data.to(get_default_device(), non_blocking=True) -def generate_train_mask(size: int, train_test_split: int) -> list: +def generate_train_mask(size: int, train_test_split: float) -> list: cutoff = size * train_test_split return [1 if i < cutoff else 0 for i in range(size)] -def generate_test_mask(size: int, train_test_split: int) -> list: +def generate_test_mask(size: int, train_test_split: float) -> list: cutoff = size * train_test_split return [0 if i < cutoff else 1 for i in range(size)] diff --git a/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py b/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py index 7c64612..e15737c 100644 --- a/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py +++ b/tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py @@ -21,7 +21,7 @@ def main(args): f"[bold yellow]{testpack_properties['Name']}: {testpack_properties['Description']}" ) - # if the value if set to "Y", then the tests are executed for the given + # if the value is set to "Y", then the tests are executed for the given # dataset. Else if set to "N", then it is ignored. gcn_datasets = { "Cora": "Y", @@ -31,7 +31,7 @@ def main(args): for dataset_name, execute_choice in gcn_datasets.items(): if execute_choice == "Y": - print(f"Started training TGCN on {dataset_name}") + print(f"Started training {testpack_properties['Name']} on {dataset_name}") output_file_path = output_folder_path + "/" + dataset_name + ".txt" if os.path.exists(output_file_path): @@ -50,7 +50,7 @@ def main(args): dataset_results[dataset_name] = result - print(f"Finished training TGCN on {dataset_name}") + print(f"Finished training {testpack_properties['Name']} on {dataset_name}") table = Table(title="GCN Results") diff --git a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py index 421e045..536253b 100644 --- a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py +++ b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/temporal_tgcn_dataloaders.py @@ -24,11 +24,11 @@ def main(args): # for prop_name, prop_value in testpack_properties.items(): # console.print(f"[cyan bold]{prop_name}[/cyan bold] : {prop_value}") - # if the value if set to "Y", then the tests are executed for the given + # if the value is set to "Y", then the tests are executed for the given # dataset. Else if set to "N", then it is ignored. temporal_datasets = { "Hungary_Chickenpox": "Y", - "METRLA": "N", + "METRLA": "Y", "Montevideo_Bus": "Y", "PedalMe": "Y", "WikiMath": "Y", diff --git a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py index 2ad4832..494cbe8 100644 --- a/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py +++ b/tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py @@ -131,6 +131,9 @@ def train( if t >= total_timestamps - dataloader._lags: break + if dataset == "METRLA" and t >= total_timestamps - (dataloader._num_timesteps_out + dataloader._num_timesteps_in): + break + y_out, y_hat, hidden_state = model( G, y_hat, edge_weight, hidden_state )