Merged
35 commits
731b95f
🎉 Initial Commit for dataset-abstraction
nithinmanoj10 Oct 19, 2023
c15bda2
🚧 Started STGraphDataset implementation
nithinmanoj10 Oct 19, 2023
48e0dea
➕ Introduce STGraphStaticDataset class
nithinmanoj10 Oct 20, 2023
7a339fb
➕ Initialised the abstract methods for handling datasets
nithinmanoj10 Oct 20, 2023
6e3e5ee
➕ Added cached graph data handling methods
nithinmanoj10 Oct 20, 2023
0b7e7d6
✅ Completed implementing CoraDataLoader
nithinmanoj10 Oct 21, 2023
80bbaa2
📝 Documented STGraphDataset class
nithinmanoj10 Oct 21, 2023
151b607
✅ Completed CoraDatasetLoader documentation
nithinmanoj10 Oct 21, 2023
0ebda52
📝 Added adhoc docs for STGraphStaticDataset
nithinmanoj10 Oct 21, 2023
c741c0c
🧪 Started writing tests for CoraDataLoader
nithinmanoj10 Oct 22, 2023
a45caa4
➕ Added more test checks for CoraDataLoader
nithinmanoj10 Oct 22, 2023
37e1635
🎉 Started working on HungaryCPDataLoader
nithinmanoj10 Oct 22, 2023
9189c25
➕ Added __init__ for HungaryCPDataLoader
nithinmanoj10 Nov 4, 2023
ba19168
✨ Added HungaryCPDataLoader
nithinmanoj10 Nov 4, 2023
9fd6c9c
📝 Added docstrings for HungaryCPDataLoader
nithinmanoj10 Nov 5, 2023
a60224a
➕ Created STGraphDynamicDataset base class
nithinmanoj10 Nov 6, 2023
2e8a51d
🚧 EnglandCovidDataLoader in progress
nithinmanoj10 Nov 6, 2023
5e01c1a
➕ Added processing methods for EnglandCOVID
nithinmanoj10 Nov 7, 2023
a284736
✅ Done adding EnglandCovidDataLoader
nithinmanoj10 Nov 13, 2023
04d6e46
Merge branch 'main' of https://github.com/bfGraph/STGraph into datase…
nithinmanoj10 Nov 23, 2023
b6fb8d0
🚧 METRLADataLoader in progress
nithinmanoj10 Nov 23, 2023
59ef3f6
Merge branch 'main' of https://github.com/bfGraph/STGraph into datase…
nithinmanoj10 Nov 23, 2023
8198485
✅ Added METRLADataLoader
nithinmanoj10 Nov 23, 2023
1454d83
🚧 MontevideoBusDataLoader in progress
nithinmanoj10 Nov 25, 2023
79927df
✅ MontevideoBusDataLoader added
nithinmanoj10 Nov 25, 2023
8fae359
✅ PedalMeDataLoader added
nithinmanoj10 Nov 26, 2023
20b9429
➕ WikiMathDataLoader added
nithinmanoj10 Dec 26, 2023
0b54da7
✅ WindmillOutputDataLoader added
nithinmanoj10 Dec 28, 2023
bcd1b12
🧪 Test for CoraDataLoader Added
nithinmanoj10 Jan 1, 2024
734da01
🧪 EnglandCOVID and HungaryCP tests added
nithinmanoj10 Jan 3, 2024
3bcab9e
🧪 METRLA Tests Added
nithinmanoj10 Jan 4, 2024
fc748c1
🧪 MonteVideo Tests Added
nithinmanoj10 Jan 5, 2024
3425578
🧪 PedalMe tests added
nithinmanoj10 Jan 6, 2024
ceea4b3
🧪 WikiMath and Windmill Tests Added
nithinmanoj10 Jan 6, 2024
6c1fcee
Merge pull request #92 from bfGraph/v1.1.0
nithinmanoj10 Jan 6, 2024
4 changes: 3 additions & 1 deletion .gitignore
@@ -25,4 +25,6 @@ dist
 *.svg
 *.json
 *.npy
-dev-stgraph/
+.coverage
+dev-stgraph/
+htmlconv/
122 changes: 0 additions & 122 deletions stgraph/dataset/CoraDataLoader.py

This file was deleted.

49 changes: 30 additions & 19 deletions stgraph/dataset/HungaryCPDataLoader.py
@@ -2,70 +2,81 @@
 import json
 from rich.console import Console
 import numpy as np

 console = Console()


 class HungaryCPDataLoader:
-    def __init__(self, folder_name, dataset_name, lags, cutoff_time, verbose: bool = False, for_stgraph = False) -> None:
+    def __init__(
+        self,
+        folder_name,
+        dataset_name,
+        lags,
+        cutoff_time,
+        verbose: bool = False,
+        for_stgraph=False,
+    ) -> None:
         self.name = dataset_name
-        self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json'
+        self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json"
         self._verbose = verbose
         self.for_stgraph = for_stgraph
         self.lags = lags

         self._load_dataset()
         self.total_timestamps = min(len(self._dataset["FX"]), cutoff_time)

         self._get_num_nodes()
         self._get_num_edges()
         self._get_edges()
         self._get_edge_weights()
         self._get_targets_and_features()

     def _load_dataset(self):
         if os.path.exists(self._local_path):
             dataset_file = open(self._local_path)
             self._dataset = json.load(dataset_file)
             if self._verbose:
-                console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json')
+                console.log(
+                    f"Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json"
+                )
         else:
-            console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset')
+            console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset")
             quit()

     def _get_num_nodes(self):
         node_set = set()
         max_node_id = 0
         for edge in self._dataset["edges"]:
             node_set.add(edge[0])
             node_set.add(edge[1])
             max_node_id = max(max_node_id, edge[0], edge[1])

         assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous"
         self.num_nodes = len(node_set)

     def _get_num_edges(self):
         self.num_edges = len(self._dataset["edges"])

     def _get_edges(self):
         if self.for_stgraph:
             self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]]
         else:
             self._edge_list = np.array(self._dataset["edges"]).T

     def _get_edge_weights(self):
         self._edge_weights = np.ones(self.num_edges)

     def _get_targets_and_features(self):
         stacked_target = np.array(self._dataset["FX"])
-        self._all_targets = np.array([
-            stacked_target[i, :].T
-            for i in range(stacked_target.shape[0])
-        ])
+        self._all_targets = np.array(
+            [stacked_target[i, :].T for i in range(stacked_target.shape[0])]
+        )

     def get_edges(self):
         return self._edge_list

     def get_edge_weights(self):
         return self._edge_weights

     def get_all_targets(self):
-        return self._all_targets
+        return self._all_targets
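
For reference, a minimal usage sketch of the loader above. This is not part of the PR: the import path follows the file location (stgraph/dataset/HungaryCPDataLoader.py), but the folder name, dataset name, lags, and cutoff values are assumptions, and the loader resolves its JSON via the relative path ../../dataset/..., so the working directory matters.

# Minimal usage sketch (assumed values; not part of this PR).
from stgraph.dataset.HungaryCPDataLoader import HungaryCPDataLoader

loader = HungaryCPDataLoader(
    folder_name="hungarycp",   # hypothetical folder under ../../dataset/
    dataset_name="hungarycp",  # hypothetical; resolves to ../../dataset/hungarycp/hungarycp.json
    lags=4,
    cutoff_time=100,
    verbose=True,
    for_stgraph=True,          # tuple edge list for STGraph; transposed NumPy array otherwise
)

edges = loader.get_edges()           # [(src, dst), ...] when for_stgraph=True
weights = loader.get_edge_weights()  # np.ones(num_edges): the graph is unweighted
targets = loader.get_all_targets()   # one target row per timestamp, built from the "FX" field
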
75 changes: 44 additions & 31 deletions stgraph/dataset/METRLADataLoader.py
@@ -2,21 +2,32 @@
 import json
 from rich.console import Console
 import numpy as np

 console = Console()
 import torch

 from rich import inspect


 class METRLADataLoader:
-    def __init__(self , folder_name, dataset_name, num_timesteps_in, num_timesteps_out, cutoff_time, verbose: bool = False, for_stgraph: bool = False):
+    def __init__(
+        self,
+        folder_name,
+        dataset_name,
+        num_timesteps_in,
+        num_timesteps_out,
+        cutoff_time,
+        verbose: bool = False,
+        for_stgraph: bool = False,
+    ):
         self.name = dataset_name
-        self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json'
+        self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json"
         self._verbose = verbose
         self.for_stgraph = for_stgraph

         self.num_timesteps_in = num_timesteps_in
         self.num_timesteps_out = num_timesteps_out

         self._load_dataset()
         self.total_timestamps = min(self._dataset["time_periods"], cutoff_time)

@@ -25,61 +36,61 @@ def __init__(self , folder_name, dataset_name, num_timesteps_in, num_timesteps_o
         self._get_edges()
         self._get_edge_weights()
         self._get_targets_and_features()

     def _load_dataset(self):
         # loading the dataset locally
         if os.path.exists(self._local_path):
             dataset_file = open(self._local_path)
             self._dataset = json.load(dataset_file)
             if self._verbose:
-                console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/')
+                console.log(f"Loading [cyan]{self.name}[/cyan] dataset from dataset/")
         else:
-            console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset')
-            quit()
+            console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset")
+            quit()

     def _get_num_nodes(self):
         node_set = set()
         max_node_id = 0
         for edge in self._dataset["edges"]:
             node_set.add(edge[0])
             node_set.add(edge[1])
             max_node_id = max(max_node_id, edge[0], edge[1])

         assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous"
         self.num_nodes = len(node_set)

     def _get_num_edges(self):
         self.num_edges = len(self._dataset["edges"])

     def _get_edges(self):
         if self.for_stgraph:
             self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]]
         else:
-            self._edge_list = np.array(self._dataset["edges"]).T
+            self._edge_list = np.array(self._dataset["edges"]).T

     # TODO: We are sorting the edge weights accordingly, but are we doing
     # the same for edges in the edge list
     def _get_edge_weights(self):
         if self.for_stgraph:
             edges = self._dataset["edges"]
             edge_weights = self._dataset["weights"]
-            comb_edge_list = [(edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))]
+            comb_edge_list = [
+                (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))
+            ]
             comb_edge_list.sort(key=lambda x: (x[1], x[0]))
             self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list])
         else:
-            self._edge_weights = np.array(self._dataset["weights"])
+            self._edge_weights = np.array(self._dataset["weights"])

     def _get_targets_and_features(self):
         X = []
-        for timestamp in range(self._dataset['time_periods']):
-
+        for timestamp in range(self._dataset["time_periods"]):
             if timestamp < self.total_timestamps:
                 X.append(self._dataset[str(timestamp)])

         X = np.array(X)
-        X = X.transpose(
-            (1, 2, 0)
-        )
+        X = X.transpose((1, 2, 0))
         X = X.astype(np.float32)

         # Normalise as in DCRNN paper (via Z-Score Method)
@@ -89,12 +100,14 @@ def _get_targets_and_features(self):
         X = X / stds.reshape(1, -1, 1)

         X = torch.from_numpy(X)

         inspect(X)

         indices = [
             (i, i + (self.num_timesteps_in + self.num_timesteps_out))
-            for i in range(X.shape[2] - (self.num_timesteps_in + self.num_timesteps_out) + 1)
+            for i in range(
+                X.shape[2] - (self.num_timesteps_in + self.num_timesteps_out) + 1
+            )
         ]

         # Generate observations
@@ -110,15 +123,15 @@ def _get_targets_and_features(self):

         self._all_features = np.array(features)
         self._all_targets = np.array(target)

     def get_edges(self):
         return self._edge_list

     def get_edge_weights(self):
         return self._edge_weights

     def get_all_targets(self):
         return self._all_targets

     def get_all_features(self):
-        return self._all_features
+        return self._all_features
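
Likewise, a hypothetical usage sketch for METRLADataLoader. The parameter values are illustrative: 12-in/12-out windows are a common METR-LA setup, but nothing in this diff fixes them, and the loop that builds feature/target windows from indices is collapsed above, so the target shape noted below is an assumption.

# Hypothetical usage sketch (assumed values; not part of this PR).
from stgraph.dataset.METRLADataLoader import METRLADataLoader

loader = METRLADataLoader(
    folder_name="metrla",    # hypothetical folder under ../../dataset/
    dataset_name="metrla",   # hypothetical JSON dataset name
    num_timesteps_in=12,     # feature window length
    num_timesteps_out=12,    # target window length
    cutoff_time=288,         # e.g. cap at one day of 5-minute readings
    verbose=True,
    for_stgraph=True,
)

# With T usable timestamps, the sliding window above yields
# T - (num_timesteps_in + num_timesteps_out) + 1 samples.
features = loader.get_all_features()  # per-sample windows of shape (nodes, channels, 12)
targets = loader.get_all_targets()    # assumed shape (nodes, 12) per sample
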