Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ classifiers = [
keywords = ["GNN", "TGNN", "GPU Programming"]
requires-python = ">=3.8"
dependencies = [
'Jinja2 >= 2',
'Jinja2 >= 3.1.3',
'pynvrtc >= 9.2',
'pydot',
'networkx >= 3.1',
Expand All @@ -46,7 +46,7 @@ dev = [
"black",
"pytest >= 7.4.3",
"pytest-cov >= 4.1.0",
"tqdm >= 4.64.1",
"tqdm >= 4.66.3",
"build >= 0.10.0",
"gdown >= 4.6.6",
"pynvml >= 11.5.0",
Expand Down
4 changes: 2 additions & 2 deletions stgraph/dataset/temporal/metrla_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class METRLADataLoader(STGraphTemporalDataset):
r"""Traffic forecasting dataset based on the Los Angeles city..
r"""Traffic forecasting dataset based on the Los Angeles city.

A dataset for predicting traffic patterns in the Los Angeles Metropolitan area,
comprising traffic data obtained from 207 loop detectors on highways in Los
Expand Down Expand Up @@ -72,7 +72,7 @@ class METRLADataLoader(STGraphTemporalDataset):

def __init__(
self: METRLADataLoader,
verbose: bool = True,
verbose: bool = False,
num_timesteps_in: int = 12,
num_timesteps_out: int = 12,
cutoff_time: int | None = None,
Expand Down
8 changes: 8 additions & 0 deletions tests/scripts/stgraph_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ def main(args):
for testpack in testpack_names:
script_path = "v" + version_number + "/" + testpack + "/" + testpack + ".py"
output_folder_path = "v" + version_number + "/" + testpack + "/outputs"

# create the script outputs folder if it doesn't exist
if not os.path.exists(output_folder_path):
try:
os.makedirs(output_folder_path)
except OSError as e:
print(f"Failed to create the script outputs folder at {output_folder_path}: {e}")

if os.path.exists(script_path):
subprocess.run(["python3", script_path, "-o", output_folder_path])
else:
Expand Down
4 changes: 2 additions & 2 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def to_default_device(data):
return data.to(get_default_device(), non_blocking=True)


def generate_train_mask(size: int, train_test_split: int) -> list:
def generate_train_mask(size: int, train_test_split: float) -> list:
cutoff = size * train_test_split
return [1 if i < cutoff else 0 for i in range(size)]


def generate_test_mask(size: int, train_test_split: int) -> list:
def generate_test_mask(size: int, train_test_split: float) -> list:
cutoff = size * train_test_split
return [0 if i < cutoff else 1 for i in range(size)]
6 changes: 3 additions & 3 deletions tests/scripts/v1_1_0/gcn_dataloaders/gcn_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def main(args):
f"[bold yellow]{testpack_properties['Name']}: {testpack_properties['Description']}"
)

# if the value if set to "Y", then the tests are executed for the given
# if the value is set to "Y", then the tests are executed for the given
# dataset. Else if set to "N", then it is ignored.
gcn_datasets = {
"Cora": "Y",
Expand All @@ -31,7 +31,7 @@ def main(args):

for dataset_name, execute_choice in gcn_datasets.items():
if execute_choice == "Y":
print(f"Started training TGCN on {dataset_name}")
print(f"Started training {testpack_properties['Name']} on {dataset_name}")

output_file_path = output_folder_path + "/" + dataset_name + ".txt"
if os.path.exists(output_file_path):
Expand All @@ -50,7 +50,7 @@ def main(args):

dataset_results[dataset_name] = result

print(f"Finished training TGCN on {dataset_name}")
print(f"Finished training {testpack_properties['Name']} on {dataset_name}")

table = Table(title="GCN Results")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,11 @@ def main(args):
# for prop_name, prop_value in testpack_properties.items():
# console.print(f"[cyan bold]{prop_name}[/cyan bold] : {prop_value}")

# if the value if set to "Y", then the tests are executed for the given
# if the value is set to "Y", then the tests are executed for the given
# dataset. Else if set to "N", then it is ignored.
temporal_datasets = {
"Hungary_Chickenpox": "Y",
"METRLA": "N",
"METRLA": "Y",
"Montevideo_Bus": "Y",
"PedalMe": "Y",
"WikiMath": "Y",
Expand Down
3 changes: 3 additions & 0 deletions tests/scripts/v1_1_0/temporal_tgcn_dataloaders/tgcn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,9 @@ def train(
if t >= total_timestamps - dataloader._lags:
break

if dataset == "METRLA" and t >= total_timestamps - (dataloader._num_timesteps_out + dataloader._num_timesteps_in):
break

y_out, y_hat, hidden_state = model(
G, y_hat, edge_weight, hidden_state
)
Expand Down