diff --git a/.gitignore b/.gitignore index d5835c5c..99855ff3 100644 --- a/.gitignore +++ b/.gitignore @@ -25,4 +25,6 @@ dist *.svg *.json *.npy -dev-stgraph/ \ No newline at end of file +.coverage +dev-stgraph/ +htmlconv/ \ No newline at end of file diff --git a/stgraph/dataset/CoraDataLoader.py b/stgraph/dataset/CoraDataLoader.py deleted file mode 100644 index b9dbe00f..00000000 --- a/stgraph/dataset/CoraDataLoader.py +++ /dev/null @@ -1,122 +0,0 @@ -import os -import json -import urllib.request -import time -import random - -import numpy as np - -from rich import inspect -from rich.pretty import pprint -from rich.progress import track -from rich.console import Console - -console = Console() - -class CoraDataLoader: - def __init__(self, verbose:bool = False, split=0.75) -> None: - self.name = "Cora" - self.num_nodes = 0 - self.num_edges = 0 - - self._train_split = split - self._test_split = 1-split - - self._local_file_path = f"../../dataset/cora/cora.json" - self._url_path = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/cora.json" - self._verbose = verbose - - self._load_dataset() - self._get_edge_info() - self._get_targets_and_features() - self._get_graph_attributes() - - self._train_mask = [0] * self.num_nodes - self._test_mask = [0] * self.num_nodes - - self._get_mask_info() - - def _load_dataset(self) -> None: - - if self._is_local_exists(): - # loading the dataset from the local folder - if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset locally') - with open(self._local_file_path) as dataset_json: - self._dataset = json.load(dataset_json) - else: - # loading the dataset by downloading them online - if self._verbose: - console.log(f'Downloading [cyan]{self.name}[/cyan] dataset') - self._dataset = json.loads(urllib.request.urlopen(self._url_path).read()) - - def _get_edge_info(self): - edges = np.array(self._dataset["edges"]) - edge_list = [] - for i in range(len(edges)): - edge = edges[i] - edge_list.append((edge[0], edge[1])) - - self._edge_list = edge_list - - def _get_targets_and_features(self): - self._all_features = np.array(self._dataset["features"]) - self._all_targets = np.array(self._dataset["labels"]).T - - def _get_mask_info(self): - train_len = int(self.num_nodes * self._train_split) - - for i in range(0, train_len): - self._train_mask[i] = 1 - - random.shuffle(self._train_mask) - - for i in range(len(self._train_mask)): - if self._train_mask[i] == 0: - self._test_mask[i] = 1 - - self._train_mask = np.array(self._train_mask) - self._test_mask = np.array(self._test_mask) - - def get_edges(self) -> np.ndarray: - return self._edge_list - - def get_all_features(self) -> np.ndarray: - return self._all_features - - def get_all_targets(self) -> np.ndarray: - return self._all_targets - - def get_train_mask(self): - return self._train_mask - - def get_test_mask(self): - return self._test_mask - - def get_train_features(self) -> np.ndarray: - train_range = int(len(self._all_features) * self.split) - return self._all_features[:train_range] - - def get_train_targets(self) -> np.ndarray: - train_range = int(len(self._all_targets) * self.split) - return self._all_targets[:train_range] - - def get_test_features(self) -> np.ndarray: - test_range = int(len(self._all_features) * self.split) - return self._all_features[test_range:] - - def get_test_targets(self) -> np.ndarray: - test_range = int(len(self._all_targets) * self.split) - return self._all_targets[test_range:] - - def _get_graph_attributes(self): - node_set = set() - for edge in 
self._edge_list: - node_set.add(edge[0]) - node_set.add(edge[1]) - - self.num_nodes = len(node_set) - self.num_edges = len(self._edge_list) - - def _is_local_exists(self) -> bool: - return os.path.exists(self._local_file_path) \ No newline at end of file diff --git a/stgraph/dataset/HungaryCPDataLoader.py b/stgraph/dataset/HungaryCPDataLoader.py index e57da723..63824eb6 100644 --- a/stgraph/dataset/HungaryCPDataLoader.py +++ b/stgraph/dataset/HungaryCPDataLoader.py @@ -2,35 +2,47 @@ import json from rich.console import Console import numpy as np + console = Console() + class HungaryCPDataLoader: - def __init__(self, folder_name, dataset_name, lags, cutoff_time, verbose: bool = False, for_stgraph = False) -> None: + def __init__( + self, + folder_name, + dataset_name, + lags, + cutoff_time, + verbose: bool = False, + for_stgraph=False, + ) -> None: self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' + self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json" self._verbose = verbose self.for_stgraph = for_stgraph self.lags = lags - + self._load_dataset() self.total_timestamps = min(len(self._dataset["FX"]), cutoff_time) - + self._get_num_nodes() self._get_num_edges() self._get_edges() self._get_edge_weights() self._get_targets_and_features() - + def _load_dataset(self): if os.path.exists(self._local_path): dataset_file = open(self._local_path) self._dataset = json.load(dataset_file) if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json') + console.log( + f"Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json" + ) else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') + console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset") quit() - + def _get_num_nodes(self): node_set = set() max_node_id = 0 @@ -38,29 +50,28 @@ def _get_num_nodes(self): node_set.add(edge[0]) node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" self.num_nodes = len(node_set) - + def _get_num_edges(self): self.num_edges = len(self._dataset["edges"]) - + def _get_edges(self): if self.for_stgraph: self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] else: self._edge_list = np.array(self._dataset["edges"]).T - + def _get_edge_weights(self): self._edge_weights = np.ones(self.num_edges) - + def _get_targets_and_features(self): stacked_target = np.array(self._dataset["FX"]) - self._all_targets = np.array([ - stacked_target[i, :].T - for i in range(stacked_target.shape[0]) - ]) - + self._all_targets = np.array( + [stacked_target[i, :].T for i in range(stacked_target.shape[0])] + ) + def get_edges(self): return self._edge_list @@ -68,4 +79,4 @@ def get_edge_weights(self): return self._edge_weights def get_all_targets(self): - return self._all_targets \ No newline at end of file + return self._all_targets diff --git a/stgraph/dataset/METRLADataLoader.py b/stgraph/dataset/METRLADataLoader.py index 7f88a007..390b39c0 100644 --- a/stgraph/dataset/METRLADataLoader.py +++ b/stgraph/dataset/METRLADataLoader.py @@ -2,21 +2,32 @@ import json from rich.console import Console import numpy as np + console = Console() import torch from rich import inspect + class METRLADataLoader: - def __init__(self , folder_name, dataset_name, num_timesteps_in, num_timesteps_out, cutoff_time, verbose: bool = False, for_stgraph: bool = False): + def __init__( + self, + 
folder_name, + dataset_name, + num_timesteps_in, + num_timesteps_out, + cutoff_time, + verbose: bool = False, + for_stgraph: bool = False, + ): self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' + self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json" self._verbose = verbose self.for_stgraph = for_stgraph - + self.num_timesteps_in = num_timesteps_in self.num_timesteps_out = num_timesteps_out - + self._load_dataset() self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) @@ -25,18 +36,18 @@ def __init__(self , folder_name, dataset_name, num_timesteps_in, num_timesteps_o self._get_edges() self._get_edge_weights() self._get_targets_and_features() - + def _load_dataset(self): # loading the dataset locally if os.path.exists(self._local_path): dataset_file = open(self._local_path) self._dataset = json.load(dataset_file) if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/') + console.log(f"Loading [cyan]{self.name}[/cyan] dataset from dataset/") else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') - quit() - + console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset") + quit() + def _get_num_nodes(self): node_set = set() max_node_id = 0 @@ -44,42 +55,42 @@ def _get_num_nodes(self): node_set.add(edge[0]) node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" self.num_nodes = len(node_set) - + def _get_num_edges(self): self.num_edges = len(self._dataset["edges"]) - + def _get_edges(self): if self.for_stgraph: self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] else: - self._edge_list = np.array(self._dataset["edges"]).T - + self._edge_list = np.array(self._dataset["edges"]).T + # TODO: We are sorting the edge weights accordingly, but are we doing # the same for edges in the edge list def _get_edge_weights(self): if self.for_stgraph: edges = self._dataset["edges"] edge_weights = self._dataset["weights"] - comb_edge_list = [(edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) else: - self._edge_weights = np.array(self._dataset["weights"]) - + self._edge_weights = np.array(self._dataset["weights"]) + def _get_targets_and_features(self): X = [] - - for timestamp in range(self._dataset['time_periods']): + + for timestamp in range(self._dataset["time_periods"]): if timestamp < self.total_timestamps: X.append(self._dataset[str(timestamp)]) - + X = np.array(X) - X = X.transpose( - (1, 2, 0) - ) + X = X.transpose((1, 2, 0)) X = X.astype(np.float32) # Normalise as in DCRNN paper (via Z-Score Method) @@ -89,12 +100,14 @@ def _get_targets_and_features(self): X = X / stds.reshape(1, -1, 1) X = torch.from_numpy(X) - + inspect(X) - + indices = [ (i, i + (self.num_timesteps_in + self.num_timesteps_out)) - for i in range(X.shape[2] - (self.num_timesteps_in + self.num_timesteps_out) + 1) + for i in range( + X.shape[2] - (self.num_timesteps_in + self.num_timesteps_out) + 1 + ) ] # Generate observations @@ -110,15 +123,15 @@ def _get_targets_and_features(self): self._all_features = np.array(features) self._all_targets = np.array(target) - + def get_edges(self): return self._edge_list - + def 
get_edge_weights(self): return self._edge_weights - + def get_all_targets(self): return self._all_targets - + def get_all_features(self): - return self._all_features \ No newline at end of file + return self._all_features diff --git a/stgraph/dataset/PedalMeDataLoader.py b/stgraph/dataset/PedalMeDataLoader.py index f9fec7d5..75bef120 100644 --- a/stgraph/dataset/PedalMeDataLoader.py +++ b/stgraph/dataset/PedalMeDataLoader.py @@ -2,38 +2,50 @@ import json from rich.console import Console import numpy as np + console = Console() from rich import inspect + class PedalMeDataLoader: - def __init__(self, folder_name, dataset_name, lags, cutoff_time, verbose: bool = False, for_stgraph: bool = False): + def __init__( + self, + folder_name, + dataset_name, + lags, + cutoff_time, + verbose: bool = False, + for_stgraph: bool = False, + ): self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' + self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json" self._verbose = verbose self.for_stgraph = for_stgraph self.lags = lags - + self._load_dataset() self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) - + self._get_num_nodes() self._get_num_edges() self._get_edges() self._get_edge_weights() self._get_targets_and_features() - + def _load_dataset(self): if os.path.exists(self._local_path): dataset_file = open(self._local_path) self._dataset = json.load(dataset_file) - + if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json') + console.log( + f"Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json" + ) else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') + console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset") quit() - + def _get_num_nodes(self): node_set = set() max_node_id = 0 @@ -41,43 +53,42 @@ def _get_num_nodes(self): node_set.add(edge[0]) node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" self.num_nodes = len(node_set) - + def _get_num_edges(self): self.num_edges = len(self._dataset["edges"]) - + def _get_edges(self): if self.for_stgraph: self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] else: self._edge_list = np.array(self._dataset["edges"]).T - + def _get_edge_weights(self): if self.for_stgraph: edges = self._dataset["edges"] edge_weights = self._dataset["weights"] - comb_edge_list = [(edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) else: self._edge_weights = np.array(self._dataset["weights"]) - + def _get_targets_and_features(self): targets = [] for time in range(self.total_timestamps): targets.append(np.array(self._dataset[str(time)])) - + stacked_target = np.stack(targets) self._all_targets = np.array( - [ - stacked_target[i, :].T - for i in range(stacked_target.shape[0]) - ] + [stacked_target[i, :].T for i in range(stacked_target.shape[0])] ) - + def get_edges(self): return self._edge_list @@ -85,4 +96,4 @@ def get_edge_weights(self): return self._edge_weights def get_all_targets(self): - return self._all_targets \ No newline at end of file + return self._all_targets diff --git a/stgraph/dataset/STGraphDataset.py 
b/stgraph/dataset/STGraphDataset.py new file mode 100644 index 00000000..319a464e --- /dev/null +++ b/stgraph/dataset/STGraphDataset.py @@ -0,0 +1,215 @@ +"""Base class for all STGraph dataset loaders""" + +import os +import json +import ssl +import urllib.request + +from abc import ABC, abstractmethod +from rich.console import Console + +console = Console() + + +class STGraphDataset(ABC): + r"""Abstract base class for graph dataset loaders + + The dataset handling is done as follows + + 1. Checks whether the dataset is present in cache. + 2. If not present in the cache, it downloads it from the URL. + 3. It then saves the downloaded file inside the cache. + 4. Incase it is present inside the cache, it directly loads it from there + 5. Dataset specific graph processing is then done + + Attributes + ---------- + + name : str + The name of the dataset + gdata : dict + Meta data associated with the dataset + + _dataset : dict + The loaded graph dataset + _url : str + The URL from where the dataset is downloaded online + _verbose : bool + Flag to control whether to display verbose info + _cache_folder : str + Folder inside ~/.stgraph where the dataset cache is stored + _cache_file_type : str + The file type used for storing the cached dataset + + Methods + ------- + + _has_dataset_cache() + Checks if the dataset is stored in cache + + _get_cache_file_path() + Returns the absolute path of the cached dataset file + + _init_graph_data() + Initialises the ``gdata`` attribute with all necessary meta data + + _process_dataset() + Processes the dataset to be used by STGraph + + _download_dataset() + Downloads the dataset using the URL + + _save_dataset() + Saves the dataset to cache + + _load_dataset() + Loads the dataset from cache + """ + + def __init__(self) -> None: + self.name = "" + self.gdata = {} + + self._dataset = {} + self._url = "" + self._verbose = False + self._cache_folder = "/dataset_cache/" + self._cache_file_type = "json" + + def _has_dataset_cache(self) -> bool: + r"""Checks if the dataset is stored in cache + + This private method checks whether the graph dataset cache file exists + in the dataset cache folder. The cache .json file is found in the following + directory ``~/.stgraph/dataset_cache/. + + Returns + ------- + bool + ``True`` if the cache file exists, else ``False`` + + Notes + ----- + The cache file is usually stored as a json file named as ``dataset_name.json`` and is stored + inside the ``~/.stgraph/dataset_cache/``. Incase the directory does not exists, it + is created by this method. + + This private method is intended for internal use within the class and should not be + called directly from outside the class. + + Example + ------- + + .. code-block:: python + + if self._has_dataset_cache(): + # The dataset is cached, continue cached operations + else: + # The dataset is not cached, continue load and save operations + """ + + user_home_dir = os.path.expanduser("~") + stgraph_dir = user_home_dir + "/.stgraph" + cache_dir = stgraph_dir + self._cache_folder + + if os.path.exists(stgraph_dir) == False: + os.system("mkdir " + stgraph_dir) + + if os.path.exists(cache_dir) == False: + os.system("mkdir " + cache_dir) + + cache_file_name = self.name + "." 
+ self._cache_file_type + + return os.path.exists(cache_dir + cache_file_name) + + def _get_cache_file_path(self) -> str: + r"""Returns the absolute path of the cached dataset file + + Returns + ------- + str + The absolute path of the cached dataset file + """ + + user_home_dir = os.path.expanduser("~") + stgraph_dir = user_home_dir + "/.stgraph" + cache_dir = stgraph_dir + self._cache_folder + cache_file_name = self.name + "." + self._cache_file_type + + return cache_dir + cache_file_name + + def _delete_cached_dataset(self) -> None: + r"""Deletes the cached dataset file""" + + os.remove(self._get_cache_file_path()) + + @abstractmethod + def _init_graph_data(self) -> None: + r"""Initialises the ``gdata`` attribute with all necessary meta data + + This is an abstract method that is implemented by ``STGraphStaticDataset``. + The meta data is initialised based on the type of the graph dataset. The values + are calculated as key-value pairs by the respective dataloaders when they + are initialised. + """ + pass + + @abstractmethod + def _process_dataset(self) -> None: + r"""Processes the dataset to be used by STGraph + + This is an abstract method that is to be implemented by each dataset loader. + The implementation in specific to the nature of the dataset itself. The dataset is + processed in such a way that it can be smoothly used within STGraph. + """ + pass + + def _download_dataset(self) -> None: + r"""Downloads the dataset using the URL + + Downloads the dataset files from the URL set by default for each data loader or + by one provided by the user. If verbose mode is enabled, it displays download status. + """ + if self._verbose: + console.log( + f"[cyan bold]{self.name}[/cyan bold] not present in cache. Downloading right now." + ) + + context = ssl._create_unverified_context() + self._dataset = json.loads( + urllib.request.urlopen(self._url, context=context).read() + ) + + if self._verbose: + console.log(f"[cyan bold]{self.name}[/cyan bold] download complete.") + + def _save_dataset(self) -> None: + r"""Saves the dataset to cache + + Saves the downloaded dataset file to the cache folder. If verbose mode is enabled, + it displays the save information. + """ + with open(self._get_cache_file_path(), "w") as cache_file: + json.dump(self._dataset, cache_file) + + if self._verbose: + console.log( + f"[cyan bold]{self.name}[/cyan bold] dataset saved to cache" + ) + + def _load_dataset(self) -> None: + r"""Loads the dataset from cache + + Loads the caches dataset json file as a python dictionary. If verbose mode is enabled, + it displays the loading status. 
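Taken together with ``_has_dataset_cache``, ``_download_dataset`` and ``_save_dataset`` above, a concrete loader is expected to wire these helpers together itself. A minimal sketch of that pattern (the class name, URL and JSON keys below are purely illustrative; the concrete loaders added later in this diff follow the same flow):

.. code-block:: python

    from stgraph.dataset.STGraphDataset import STGraphDataset


    class MyJSONDataLoader(STGraphDataset):
        def __init__(self, verbose=False, url=None, redownload=False) -> None:
            super().__init__()
            self._init_graph_data()

            self.name = "MyJSON"  # cached as ~/.stgraph/dataset_cache/MyJSON.json
            self._verbose = verbose
            self._url = url or "https://example.org/my_json.json"  # assumed URL

            if redownload and self._has_dataset_cache():
                self._delete_cached_dataset()

            if self._has_dataset_cache():
                self._load_dataset()        # read the cached JSON
            else:
                self._download_dataset()    # fetch from self._url
                self._save_dataset()        # write it to the cache folder

            self._process_dataset()         # dataset specific processing

        def _init_graph_data(self) -> None:
            self.gdata["num_nodes"] = 0
            self.gdata["num_edges"] = 0

        def _process_dataset(self) -> None:
            # assumes the downloaded JSON exposes an "edges" list of [src, dst] pairs
            self.gdata["num_nodes"] = len({n for e in self._dataset["edges"] for n in e})
            self.gdata["num_edges"] = len(self._dataset["edges"])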
+ """ + if self._verbose: + console.log(f"Loading [cyan bold]{self.name}[/cyan bold] from cache") + + with open(self._get_cache_file_path()) as cache_file: + self._dataset = json.load(cache_file) + + if self._verbose: + console.log( + f"Successfully loaded [cyan bold]{self.name}[/cyan bold] from cache" + ) diff --git a/stgraph/dataset/__init__.py b/stgraph/dataset/__init__.py index 5d49c59a..9ae14678 100644 --- a/stgraph/dataset/__init__.py +++ b/stgraph/dataset/__init__.py @@ -1 +1,13 @@ -'''Dataset loader provided by STGraph''' \ No newline at end of file +"""Dataset loader provided by STGraph""" + +# TODO: Change this accordingly if needed +from stgraph.dataset.static.CoraDataLoader import CoraDataLoader + +from stgraph.dataset.temporal.HungaryCPDataLoader import HungaryCPDataLoader +from stgraph.dataset.temporal.METRLADataLoader import METRLADataLoader +from stgraph.dataset.temporal.MontevideoBusDataLoader import MontevideoBusDataLoader +from stgraph.dataset.temporal.PedalMeDataLoader import PedalMeDataLoader +from stgraph.dataset.temporal.WikiMathDataLoader import WikiMathDataLoader +from stgraph.dataset.temporal.WindmillOutputDataLoader import WindmillOutputDataLoader + +from stgraph.dataset.dynamic.EnglandCovidDataLoader import EnglandCovidDataLoader diff --git a/stgraph/dataset/dynamic/EnglandCovidDataLoader.py b/stgraph/dataset/dynamic/EnglandCovidDataLoader.py new file mode 100644 index 00000000..85130ffb --- /dev/null +++ b/stgraph/dataset/dynamic/EnglandCovidDataLoader.py @@ -0,0 +1,210 @@ +import numpy as np + +from stgraph.dataset.dynamic.STGraphDynamicDataset import STGraphDynamicDataset + + +class EnglandCovidDataLoader(STGraphDynamicDataset): + def __init__( + self, + verbose: bool = False, + url: str = None, + lags: int = 8, + cutoff_time: int = None, + redownload: bool = False, + ) -> None: + r"""Dynamic dataset tracking COVID-19 cases in England's NUTS3 regions + + This dataset captures the interplay between COVID-19 cases and mobility + in England's NUTS3 regions from March 3rd to May 12th. It is a directed + and weighted graph that offers daily case count and movement of people + between each region through node and edge features respectively. + + This class provides functionality for loading, processing, and accessing the England + Covid dataset for use in deep learning tasks such as predicting the COVID cases + in a region. + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import EnglandCovidDataLoader + + eng_covid = EnglandCovidDataLoader(verbose=True) + num_nodes_dict = eng_covid.gdata["num_nodes"] + num_edges_dict = eng_covid.gdata["num_edges"] + total_timestamps = eng_covid.gdata["total_timestamps"] + + edge_list = eng_covid.get_edges() + edge_weights = eng_covid.get_edge_weights() + feats = eng_covid.get_all_features() + targets = eng_covid.get_all_targets() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + lags : int, optional + The number of time lags (default is 8) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. 
+ _lags : int + The number of time lags + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset for each timestamp + _edge_weights : list + List of edge weights for each timestamp + _all_features : list + Node features for each timestamp minus lags + _all_targets : list + Node target value for each timestamp minus lags + """ + super().__init__() + + self.name = "England_COVID" + self._verbose = verbose + self._lags = lags + self._cutoff_time = cutoff_time + + if not url: + self._url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_targets_and_features() + self._set_edge_info() + self._presort_edge_weights() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + choosen by the user and the total time periods present in the + original dataset. + """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + self._dataset["time_periods"], self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = self._dataset["time_periods"] + + def _set_targets_and_features(self): + r"""Calculates and sets the target and feature attributes""" + stacked_target = np.array(self._dataset["y"]) + standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( + np.std(stacked_target, axis=0) + 10**-10 + ) + + self._all_features = [ + standardized_target[i : i + self._lags, :].T + for i in range(self.gdata["total_timestamps"] - self._lags) + ] + self._all_targets = [ + standardized_target[i + self._lags, :].T + for i in range(self.gdata["total_timestamps"] - self._lags) + ] + + def _set_edge_info(self): + r"""Sets edge info such as edge list and edge weights""" + self._edge_list = [] + self._edge_weights = [] + + for time in range(self.gdata["total_timestamps"]): + time_edge_list = [] + time_edge_weights = [] + + for edge in self._dataset["edge_mapping"]["edge_index"][str(time)]: + time_edge_list.append((edge[0], edge[1])) + + for weight in self._dataset["edge_mapping"]["edge_weight"][str(time)]: + time_edge_weights.append(weight) + + self._edge_list.append(time_edge_list) + self._edge_weights.append(time_edge_weights) + self.gdata["num_edges"][str(time)] = len(time_edge_list) + self.gdata["num_nodes"][str(time)] = len( + {node for edge in time_edge_list for node in edge} + ) + + def _presort_edge_weights(self): + r""" + Presorting edges according to (dest,src) since that is how eids are formed + allowing forward and backward kernel to access edge weights + """ + final_edges_lst = [] + final_edge_weights_lst = [] + + for i in range(len(self._edge_list)): + src_list = [edge[0] for edge in self._edge_list[i]] + dst_list = [edge[1] for edge in self._edge_list[i]] + weights = self._edge_weights[i] + + edge_info_list = [] + sorted_edge_weights_lst = [] + + for j in range(len(weights)): + edge_info = (src_list[j], dst_list[j], weights[j]) + edge_info_list.append(edge_info) + + # since it has to be sorted according to the reverse order + 
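            # Illustrative example with made-up values (not taken from the dataset):
            #   edge_info_list = [(2, 0, 0.5), (0, 1, 0.2), (1, 0, 0.9)]   # (src, dst, weight)
            #   sorted(edge_info_list, key=lambda e: (e[1], e[0]))
            #   # -> [(1, 0, 0.9), (2, 0, 0.5), (0, 1, 0.2)]
            # i.e. edges are ordered by (dst, src) and their weights follow, so the weight
            # array stays aligned with the edge ids used by the forward/backward kernels.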
sorted_edge_info_list = sorted( + edge_info_list, key=lambda element: (element[1], element[0]) + ) + + time_edge = [] + + for edge in sorted_edge_info_list: + time_edge.append((edge[0], edge[1])) + sorted_edge_weights_lst.append(edge[2]) + + final_edges_lst.append(time_edge) + final_edge_weights_lst.append(np.array(sorted_edge_weights_lst)) + + self._edge_list = final_edges_lst + self._edge_weights = final_edge_weights_lst + + def get_edges(self): + r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_features(self): + r"""Returns the features for each timestamp""" + return self._all_features + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets diff --git a/stgraph/dataset/dynamic/STGraphDynamicDataset.py b/stgraph/dataset/dynamic/STGraphDynamicDataset.py new file mode 100644 index 00000000..de1dc9f4 --- /dev/null +++ b/stgraph/dataset/dynamic/STGraphDynamicDataset.py @@ -0,0 +1,27 @@ +"""Base class for all STGraph dynamic graph datasets""" + +from stgraph.dataset.STGraphDataset import STGraphDataset + + +class STGraphDynamicDataset(STGraphDataset): + r"""Base class for dynamic graph datasets + + This class is a subclass of ``STGraphDataset`` and provides the base structure for + handling dynamic graph datasets.""" + + def __init__(self) -> None: + super().__init__() + + self._init_graph_data() + + def _init_graph_data(self) -> dict: + r"""Initialize graph meta data for a dynamic dataset. + + The ``num_nodes``, ``num_edges``, ``total_timestamps`` keys are set to value 0 + """ + self.gdata["num_nodes"] = {} + self.gdata["num_edges"] = {} + self.gdata["total_timestamps"] = 0 + + self._lags = 0 + self._cutoff_time = None diff --git a/stgraph/dataset/dynamic/__init__.py b/stgraph/dataset/dynamic/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stgraph/dataset/static/CoraDataLoader.py b/stgraph/dataset/static/CoraDataLoader.py new file mode 100644 index 00000000..0e5b310f --- /dev/null +++ b/stgraph/dataset/static/CoraDataLoader.py @@ -0,0 +1,139 @@ +import random + +import numpy as np +from rich.console import Console + +from stgraph.dataset.static.STGraphStaticDataset import STGraphStaticDataset + + +console = Console() + + +class CoraDataLoader(STGraphStaticDataset): + def __init__(self, verbose=False, url=None, redownload=False) -> None: + r"""Citation network consisting of scientific publications + + The Cora dataset consists of 2708 scientific publications classified into one of seven classes. + The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued + word vector indicating the absence/presence of the corresponding word from the dictionary. + The dictionary consists of 1433 unique words. + + This class provides functionality for loading, processing, and accessing the Cora dataset + for use in deep learning tasks such as graph-based node classification. + + .. list-table:: gdata + :widths: 25 25 25 25 + :header-rows: 1 + + * - num_nodes + - num_edges + - num_feats + - num_classes + * - 2708 + - 10556 + - 1433 + - 7 + + Example + ------- + + .. 
code-block:: python + + from stgraph.dataset import CoraDataLoader + + cora = CoraDataLoader() + num_nodes = cora.gdata["num_nodes"] + edge_list = cora.get_edges() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. + _edge_list : np.ndarray + The edge list of the graph dataset + _all_features : np.ndarray + Numpy array of the node features + _all_targets : np.ndarray + Numpy array of the node target features + """ + super().__init__() + + self.name = "Cora" + self._verbose = verbose + + if not url: + self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/cora.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + r"""Process the Cora dataset. + + Calls private methods to extract edge list, node features, target classes + and the train/test binary mask array. + """ + self._set_edge_info() + self._set_targets_and_features() + self._set_graph_attributes() + + def _set_edge_info(self) -> None: + r"""Extract edge information from the dataset""" + edges = np.array(self._dataset["edges"]) + edge_list = [] + for i in range(len(edges)): + edge = edges[i] + edge_list.append((edge[0], edge[1])) + + self._edge_list = edge_list + + def _set_targets_and_features(self): + r"""Extract targets and features from the dataset.""" + self._all_features = np.array(self._dataset["features"]) + self._all_targets = np.array(self._dataset["labels"]).T + + def _set_graph_attributes(self): + r"""Calculates and stores graph meta data inside ``gdata``""" + node_set = set() + for edge in self._edge_list: + node_set.add(edge[0]) + node_set.add(edge[1]) + + self.gdata["num_nodes"] = len(node_set) + self.gdata["num_edges"] = len(self._edge_list) + self.gdata["num_feats"] = len(self._all_features[0]) + self.gdata["num_classes"] = len(set(self._all_targets)) + + def get_edges(self) -> list: + r"""Get the edge list.""" + return self._edge_list + + def get_all_features(self) -> np.ndarray: + r"""Get all features.""" + return self._all_features + + def get_all_targets(self) -> np.ndarray: + r"""Get all targets.""" + return self._all_targets diff --git a/stgraph/dataset/static/STGraphStaticDataset.py b/stgraph/dataset/static/STGraphStaticDataset.py new file mode 100644 index 00000000..94edbd65 --- /dev/null +++ b/stgraph/dataset/static/STGraphStaticDataset.py @@ -0,0 +1,29 @@ +"""Base class for all STGraph static graph datasets""" + +from rich.console import Console + +from stgraph.dataset.STGraphDataset import STGraphDataset + + +console = Console() + + +class STGraphStaticDataset(STGraphDataset): + r"""Base class for static graph datasets + + This class is a subclass of ``STGraphDataset`` and provides the base structure for + handling static graph datasets. + """ + + def __init__(self) -> None: + super().__init__() + + self._init_graph_data() + + def _init_graph_data(self) -> dict: + r"""Initialize graph meta data for a static dataset. 
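A static dataset only tracks ``num_nodes`` and ``num_edges`` here; a concrete loader then fills these in and may add extra keys. For example, with ``CoraDataLoader`` above the dictionary ends up holding roughly the values from its gdata table:

.. code-block:: python

    cora = CoraDataLoader()
    cora.gdata
    # {'num_nodes': 2708, 'num_edges': 10556, 'num_feats': 1433, 'num_classes': 7}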
+ + The ``num_nodes`` and ``num_edges`` keys are set to value 0 + """ + self.gdata["num_nodes"] = 0 + self.gdata["num_edges"] = 0 diff --git a/stgraph/dataset/static/__init__.py b/stgraph/dataset/static/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/stgraph/dataset/temporal/HungaryCPDataLoader.py b/stgraph/dataset/temporal/HungaryCPDataLoader.py new file mode 100644 index 00000000..3df40635 --- /dev/null +++ b/stgraph/dataset/temporal/HungaryCPDataLoader.py @@ -0,0 +1,178 @@ +import numpy as np + +from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset + + +class HungaryCPDataLoader(STGraphTemporalDataset): + def __init__( + self, + verbose: bool = False, + url: str = None, + lags: int = 4, + cutoff_time: int = None, + redownload: bool = False, + ) -> None: + r"""County level chicken pox cases in Hungary + + This dataset comprises information on weekly occurrences of chickenpox + in Hungary from 2005 to 2015. The graph structure is static with nodes + representing the counties and edges are neighbourhoods between them. + Vertex features are lagged weekly counts of the chickenpox cases. + + This class provides functionality for loading, processing, and accessing the Hungary + Chickenpox dataset for use in deep learning tasks such as County level case count prediction. + + .. list-table:: gdata + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 20 + - 102 + - 521 + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import HungaryCPDataLoader + + hungary = HungaryCPDataLoader(verbose=True) + num_nodes = hungary.gdata["num_nodes"] + edge_list = hungary.get_edges() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + lags : int, optional + The number of time lags (default is 4) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. 
+ _lags : int + The number of time lags + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset + _edge_weights : numpy.ndarray + Numpy array of the edge weights + _all_targets : numpy.ndarray + Numpy array of the node target value + """ + + super().__init__() + + if type(lags) != int: + raise TypeError("lags must be of type int") + if lags < 0: + raise ValueError("lags must be a positive integer") + + if cutoff_time != None and type(cutoff_time) != int: + raise TypeError("cutoff_time must be of type int") + if cutoff_time != None and cutoff_time < 0: + raise ValueError("cutoff_time must be a positive integer") + + self.name = "Hungary_Chickenpox" + self._verbose = verbose + self._lags = lags + self._cutoff_time = cutoff_time + + if not url: + self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/HungaryCP.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_num_nodes() + self._set_num_edges() + self._set_edges() + self._set_edge_weights() + self._set_targets_and_features() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + choosen by the user and the total time periods present in the + original dataset. + """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + len(self._dataset["FX"]), self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = len(self._dataset["FX"]) + + def _set_num_nodes(self): + r"""Sets the total number of nodes present in the dataset""" + node_set = set() + max_node_id = 0 + for edge in self._dataset["edges"]: + node_set.add(edge[0]) + node_set.add(edge[1]) + max_node_id = max(max_node_id, edge[0], edge[1]) + + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + self.gdata["num_nodes"] = len(node_set) + + def _set_num_edges(self): + r"""Sets the total number of edges present in the dataset""" + self.gdata["num_edges"] = len(self._dataset["edges"]) + + def _set_edges(self): + r"""Sets the edge list of the dataset""" + self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] + + def _set_edge_weights(self): + r"""Sets the edge weights of the dataset""" + self._edge_weights = np.ones(self.gdata["num_edges"]) + + def _set_targets_and_features(self): + r"""Calculates and sets the target and feature attributes""" + stacked_target = np.array(self._dataset["FX"]) + + self._all_targets = [ + stacked_target[i + self._lags, :].T + for i in range(self.gdata["total_timestamps"] - self._lags) + ] + + def get_edges(self): + r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets diff --git a/stgraph/dataset/temporal/METRLADataLoader.py b/stgraph/dataset/temporal/METRLADataLoader.py new file mode 100644 index 00000000..734202f1 --- /dev/null +++ b/stgraph/dataset/temporal/METRLADataLoader.py @@ -0,0 +1,234 @@ +import torch +import numpy as np + +from 
stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset + + +class METRLADataLoader(STGraphTemporalDataset): + def __init__( + self, + verbose: bool = True, + url: str = None, + num_timesteps_in: int = 12, + num_timesteps_out: int = 12, + cutoff_time: int = None, + redownload: bool = False, + ): + r"""A traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions. + + A dataset for predicting traffic patterns in the Los Angeles Metropolitan area, + comprising traffic data obtained from 207 loop detectors on highways in Los Angeles County. + The dataset includes aggregated 5-minute interval readings spanning a four-month + period from March 2012 to June 2012. + + This class provides functionality for loading, processing, and accessing the METRLA + dataset for use in deep learning tasks such as traffic forecasting. + + .. list-table:: gdata + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 207 + - 1722 + - 100 + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import METRLADataLoader + + metrla = METRLADataLoader(verbose=True) + num_nodes = metrla.gdata["num_nodes"] + num_edges = metrla.gdata["num_edges"] + total_timestamps = metrla.gdata["total_timestamps"] + + edge_list = metrla.get_edges() + edge_weights = metrla.get_edge_weights() + feats = metrla.get_all_features() + targets = metrla.get_all_targets() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + num_timesteps_in : int, optional + The number of timesteps the sequence model sees (default is 12) + num_timesteps_out : int, optional + The number of timesteps the sequence model has to predict (default is 12) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. 
+ _num_timesteps_in : int + The number of timesteps the sequence model sees + _num_timesteps_out : int + The number of timesteps the sequence model has to predict + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset + _edge_weights : numpy.ndarray + Numpy array of the edge weights + _all_features : numpy.ndarray + Numpy array of the node feature value + _all_targets : numpy.ndarray + Numpy array of the node target value + """ + + super().__init__() + + if type(num_timesteps_in) != int: + raise TypeError("num_timesteps_in must be of type int") + if num_timesteps_in < 0: + raise ValueError("num_timesteps_in must be a positive integer") + + if type(num_timesteps_out) != int: + raise TypeError("num_timesteps_out must be of type int") + if num_timesteps_out < 0: + raise ValueError("num_timesteps_out must be a positive integer") + + if cutoff_time != None and type(cutoff_time) != int: + raise TypeError("cutoff_time must be of type int") + if cutoff_time != None and cutoff_time < 0: + raise ValueError("cutoff_time must be a positive integer") + + self.name = "METRLA" + self._verbose = verbose + self._num_timesteps_in = num_timesteps_in + self._num_timesteps_out = num_timesteps_out + self._cutoff_time = cutoff_time + + if not url: + self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/METRLA.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_num_nodes() + self._set_num_edges() + self._set_edges() + self._set_edge_weights() + self._set_targets_and_features() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + choosen by the user and the total time periods present in the + original dataset. 
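For instance (the numbers below are illustrative, taken only from the defaults and the gdata table above):

.. code-block:: python

    metrla = METRLADataLoader(cutoff_time=100)
    # if the raw JSON reports more time periods than the cutoff, the smaller value wins
    metrla.gdata["total_timestamps"]   # -> 100

    # downstream, _set_targets_and_features builds sliding windows, so with the default
    # num_timesteps_in = num_timesteps_out = 12 this yields
    # 100 - (12 + 12) + 1 = 77 (features, target) observations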
+ """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + self._dataset["time_periods"], self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = self._dataset["time_periods"] + + def _set_num_nodes(self): + r"""Sets the total number of nodes present in the dataset""" + node_set = set() + max_node_id = 0 + for edge in self._dataset["edges"]: + node_set.add(edge[0]) + node_set.add(edge[1]) + max_node_id = max(max_node_id, edge[0], edge[1]) + + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + self.gdata["num_nodes"] = len(node_set) + + def _set_num_edges(self): + r"""Sets the total number of edges present in the dataset""" + self.gdata["num_edges"] = len(self._dataset["edges"]) + + def _set_edges(self): + r"""Sets the edge list of the dataset""" + self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] + + def _set_edge_weights(self): + r"""Sets the edge weights of the dataset""" + edges = self._dataset["edges"] + edge_weights = self._dataset["weights"] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] + comb_edge_list.sort(key=lambda x: (x[1], x[0])) + self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) + + def _set_targets_and_features(self): + r"""Calculates and sets the target and feature attributes""" + X = [] + + for timestamp in range(self.gdata["total_timestamps"]): + X.append(self._dataset[str(timestamp)]) + + X = np.array(X) + X = X.transpose((1, 2, 0)) + X = X.astype(np.float32) + + # Normalise as in DCRNN paper (via Z-Score Method) + means = np.mean(X, axis=(0, 2)) + X = X - means.reshape(1, -1, 1) + stds = np.std(X, axis=(0, 2)) + X = X / stds.reshape(1, -1, 1) + + X = torch.from_numpy(X) + + indices = [ + (i, i + (self._num_timesteps_in + self._num_timesteps_out)) + for i in range( + X.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1 + ) + ] + + # Generate observations + features, target = [], [] + for i, j in indices: + features.append((X[:, :, i : i + self._num_timesteps_in]).numpy()) + target.append((X[:, 0, i + self._num_timesteps_in : j]).numpy()) + + self._all_features = np.array(features) + self._all_targets = np.array(target) + + def get_edges(self): + r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets + + def get_all_features(self): + r"""Returns the features for each timestamp""" + return self._all_features diff --git a/stgraph/dataset/temporal/MontevideoBusDataLoader.py b/stgraph/dataset/temporal/MontevideoBusDataLoader.py new file mode 100644 index 00000000..dd6c2e9f --- /dev/null +++ b/stgraph/dataset/temporal/MontevideoBusDataLoader.py @@ -0,0 +1,234 @@ +import numpy as np + +from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset + + +class MontevideoBusDataLoader(STGraphTemporalDataset): + def __init__( + self, + verbose: bool = False, + url: str = None, + lags: int = 4, + cutoff_time: int = None, + redownload: bool = False, + ) -> None: + r"""A dataset of inflow passenger at bus stop level from Montevideo city. + + This dataset compiles hourly passenger inflow data for 11 key bus lines + in Montevideo, Uruguay, during October 2020. 
Focused on routes to the city + center, it encompasses bus stop vertices, interlinked by edges representing + connections with weights indicating road distances. The target variable + is passenger inflow, sourced from diverse data outlets within Montevideo's + Metropolitan Transportation System (STM). + + This class provides functionality for loading, processing, and accessing the + Montevideo Bus dataset for use in deep learning tasks such as passenger inflow prediction. + + .. list-table:: gdata + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 675 + - 690 + - 744 + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import MontevideoBusDataLoader + + monte = MontevideoBusDataLoader(verbose=True) + num_nodes = monte.gdata["num_nodes"] + num_edges = monte.gdata["num_edges"] + total_timestamps = monte.gdata["total_timestamps"] + + edge_list = monte.get_edges() + edge_weights = monte.get_edge_weights() + feats = monte.get_all_features() + targets = monte.get_all_targets() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + lags : int, optional + The number of time lags (default is 4) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. + _lags : int + The number of time lags + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset + _edge_weights : numpy.ndarray + Numpy array of the edge weights + _all_targets : numpy.ndarray + Numpy array of the node target value + _all_features : numpy.ndarray + Numpy array of the node feature value + """ + + super().__init__() + + if type(lags) != int: + raise TypeError("lags must be of type int") + if lags < 0: + raise ValueError("lags must be a positive integer") + + if cutoff_time != None and type(cutoff_time) != int: + raise TypeError("cutoff_time must be of type int") + if cutoff_time != None and cutoff_time < 0: + raise ValueError("cutoff_time must be a positive integer") + if cutoff_time != None and cutoff_time <= lags: + raise ValueError("cutoff_time must be greater than lags") + + self.name = "Montevideo_Bus" + self._verbose = verbose + self._lags = lags + self._cutoff_time = cutoff_time + + if not url: + self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/montevideobus.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_num_nodes() + self._set_num_edges() + self._set_edges() + self._set_edge_weights() + self._set_features() + self._set_targets() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + choosen by the user and the total time periods present in the + original dataset. 
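The feature and target construction that follows first z-scores each series per bus stop and then builds lagged windows. A rough, self-contained sketch of that step (toy values, not from the dataset):

.. code-block:: python

    import numpy as np

    # stand-in for three stops observed over six hours
    per_node_series = [np.arange(6.0), np.arange(6.0) * 2, np.arange(6.0) ** 2]

    stacked = np.stack(per_node_series).T                      # (T, num_nodes) = (6, 3)
    z = (stacked - stacked.mean(axis=0)) / stacked.std(axis=0)

    lags = 4
    feats_0 = z[0:lags, :].T    # first input window, shape (num_nodes, lags) = (3, 4)
    target_0 = z[lags, :].T     # matching target, shape (num_nodes,) = (3,)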
+ """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + len(self._dataset["nodes"][0]["y"]), self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = len(self._dataset["nodes"][0]["y"]) + + def _set_num_nodes(self): + r"""Sets the total number of nodes present in the dataset""" + node_set = set() + max_node_id = 0 + for edge in self._dataset["edges"]: + node_set.add(edge[0]) + node_set.add(edge[1]) + max_node_id = max(max_node_id, edge[0], edge[1]) + + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + self.gdata["num_nodes"] = len(node_set) + + def _set_num_edges(self): + r"""Sets the total number of edges present in the dataset""" + self.gdata["num_edges"] = len(self._dataset["edges"]) + + def _set_edges(self): + r"""Sets the edge list of the dataset""" + self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] + + def _set_edge_weights(self): + r"""Sets the edge weights of the dataset""" + edges = self._dataset["edges"] + edge_weights = self._dataset["weights"] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] + comb_edge_list.sort(key=lambda x: (x[1], x[0])) + self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) + + def _set_features(self): + r"""Calculates and sets the feature attributes""" + features = [] + + for node in self._dataset["nodes"]: + X = node.get("X") + for feature_var in ["y"]: + features.append( + np.array(X.get(feature_var)[: self.gdata["total_timestamps"]]) + ) + + stacked_features = np.stack(features).T + standardized_features = ( + stacked_features - np.mean(stacked_features, axis=0) + ) / np.std(stacked_features, axis=0) + + self._all_features = np.array( + [ + standardized_features[i : i + self._lags, :].T + for i in range(len(standardized_features) - self._lags) + ] + ) + + def _set_targets(self): + r"""Calculates and sets the target attributes""" + targets = [] + for node in self._dataset["nodes"]: + y = node.get("y")[: self.gdata["total_timestamps"]] + targets.append(np.array(y)) + + stacked_targets = np.stack(targets).T + standardized_targets = ( + stacked_targets - np.mean(stacked_targets, axis=0) + ) / np.std(stacked_targets, axis=0) + + self._all_targets = np.array( + [ + standardized_targets[i + self._lags, :].T + for i in range(len(standardized_targets) - self._lags) + ] + ) + + def get_edges(self): + r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets + + def get_all_features(self): + r"""Returns the features for each timestamp""" + return self._all_features diff --git a/stgraph/dataset/temporal/PedalMeDataLoader.py b/stgraph/dataset/temporal/PedalMeDataLoader.py new file mode 100644 index 00000000..c149a244 --- /dev/null +++ b/stgraph/dataset/temporal/PedalMeDataLoader.py @@ -0,0 +1,197 @@ +import numpy as np + +from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset + + +class PedalMeDataLoader(STGraphTemporalDataset): + def __init__( + self, + verbose: bool = False, + url: str = None, + lags: int = 4, + cutoff_time: int = None, + redownload: bool = False, + ) -> None: + r"""A dataset of PedalMe Bicycle deliver orders in London. 
+ + This class provides functionality for loading, processing, and accessing the PedalMe + dataset for use in deep learning tasks such as node classification. + + .. list-table:: gdata + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 15 + - 225 + - 36 + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import PedalMeDataLoader + + pedal = PedalMeDataLoader(verbose=True) + num_nodes = pedal.gdata["num_nodes"] + num_edges = pedal.gdata["num_edges"] + total_timestamps = pedal.gdata["total_timestamps"] + + edge_list = pedal.get_edges() + edge_weights = pedal.get_edge_weights() + targets = pedal.get_all_targets() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + lags : int, optional + The number of time lags (default is 4) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. + _lags : int + The number of time lags + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset + _edge_weights : numpy.ndarray + Numpy array of the edge weights + _all_targets : numpy.ndarray + Numpy array of the node target value + """ + + super().__init__() + + if type(lags) != int: + raise TypeError("lags must be of type int") + if lags < 0: + raise ValueError("lags must be a positive integer") + + if cutoff_time != None and type(cutoff_time) != int: + raise TypeError("cutoff_time must be of type int") + if cutoff_time != None and cutoff_time < 0: + raise ValueError("cutoff_time must be a positive integer") + if cutoff_time != None and cutoff_time <= lags: + raise ValueError("cutoff_time must be greater than lags") + + self.name = "PedalMe" + self._verbose = verbose + self._lags = lags + self._cutoff_time = cutoff_time + + if not url: + self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/pedalme.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_num_nodes() + self._set_num_edges() + self._set_edges() + self._set_edge_weights() + self._set_targets() + self._set_features() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + choosen by the user and the total time periods present in the + original dataset. 
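With the values from the gdata table above this means, for example:

.. code-block:: python

    pedal = PedalMeDataLoader()            # lags defaults to 4, no cutoff given
    pedal.gdata["total_timestamps"]        # -> 36

    # _set_targets keeps one target snapshot per lag window, so
    # len(pedal.get_all_targets()) == 36 - 4 == 32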
+ """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + self._dataset["time_periods"], self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = self._dataset["time_periods"] + + def _set_num_nodes(self): + r"""Sets the total number of nodes present in the dataset""" + node_set = set() + max_node_id = 0 + for edge in self._dataset["edges"]: + node_set.add(edge[0]) + node_set.add(edge[1]) + max_node_id = max(max_node_id, edge[0], edge[1]) + + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + self.gdata["num_nodes"] = len(node_set) + + def _set_num_edges(self): + r"""Sets the total number of edges present in the dataset""" + self.gdata["num_edges"] = len(self._dataset["edges"]) + + def _set_edges(self): + r"""Sets the edge list of the dataset""" + self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] + + def _set_edge_weights(self): + r"""Sets the edge weights of the dataset""" + edges = self._dataset["edges"] + edge_weights = self._dataset["weights"] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] + comb_edge_list.sort(key=lambda x: (x[1], x[0])) + self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) + + def _set_targets(self): + r"""Calculates and sets the target attributes""" + targets = [] + for time in range(self.gdata["total_timestamps"]): + targets.append(np.array(self._dataset[str(time)])) + + stacked_target = np.stack(targets) + + self._all_targets = np.array( + [ + stacked_target[i + self._lags, :].T + for i in range(stacked_target.shape[0] - self._lags) + ] + ) + + def _set_features(self): + # TODO: + pass + + def get_edges(self): + r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets diff --git a/stgraph/dataset/temporal/STGraphTemporalDataset.py b/stgraph/dataset/temporal/STGraphTemporalDataset.py new file mode 100644 index 00000000..287e4659 --- /dev/null +++ b/stgraph/dataset/temporal/STGraphTemporalDataset.py @@ -0,0 +1,31 @@ +"""Base class for all STGraph temporal graph datasets""" + +from rich.console import Console +from stgraph.dataset.STGraphDataset import STGraphDataset + +console = Console() + + +class STGraphTemporalDataset(STGraphDataset): + r"""Base class for temporal graph datasets + + This class is a subclass of ``STGraphDataset`` and provides the base structure for + handling temporal graph datasets. + """ + + def __init__(self) -> None: + super().__init__() + + self._init_graph_data() + + def _init_graph_data(self) -> dict: + r"""Initialize graph meta data for a temporal dataset. 
+ + The ``num_nodes``, ``num_edges``, ``total_timestamps`` keys are set to value 0 + """ + self.gdata["num_nodes"] = 0 + self.gdata["num_edges"] = 0 + self.gdata["total_timestamps"] = 0 + + self._lags = 0 + self._cutoff_time = None diff --git a/stgraph/dataset/temporal/WikiMathDataLoader.py b/stgraph/dataset/temporal/WikiMathDataLoader.py new file mode 100644 index 00000000..ea972146 --- /dev/null +++ b/stgraph/dataset/temporal/WikiMathDataLoader.py @@ -0,0 +1,200 @@ +import numpy as np + +from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset + + +class WikiMathDataLoader(STGraphTemporalDataset): + def __init__( + self, + verbose: bool = False, + url: str = None, + lags: int = 8, + cutoff_time: int = None, + redownload: bool = False, + ) -> None: + r"""A dataset of vital mathematical articles sourced from Wikipedia. + + The graph dataset is static, with vertices representing Wikipedia pages and + edges representing links. The graph is both directed and weighted, where the weights + indicate the number of links originating from the source page connecting + to the target page. The target is the daily user visits to the Wikipedia pages + between March 16th 2019 and March 15th 2021, which results in 731 periods. + + This class provides functionality for loading, processing, and accessing the WikiMath + dataset for use in deep learning tasks such as daily visitor count prediction. + + .. list-table:: gdata + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 1068 + - 27079 + - 731 + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import WikiMathDataLoader + + wiki = WikiMathDataLoader(verbose=True) + num_nodes = wiki.gdata["num_nodes"] + num_edges = wiki.gdata["num_edges"] + total_timestamps = wiki.gdata["total_timestamps"] + + edge_list = wiki.get_edges() + edge_weights = wiki.get_edge_weights() + targets = wiki.get_all_targets() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + lags : int, optional + The number of time lags (default is 8) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info.
+ _lags : int + The number of time lags + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset + _edge_weights : numpy.ndarray + Numpy array of the edge weights + _all_targets : numpy.ndarray + Numpy array of the node target value + """ + super().__init__() + + if type(lags) != int: + raise TypeError("lags must be of type int") + if lags < 0: + raise ValueError("lags must be a positive integer") + + if cutoff_time != None and type(cutoff_time) != int: + raise TypeError("cutoff_time must be of type int") + if cutoff_time != None and cutoff_time < 0: + raise ValueError("cutoff_time must be a positive integer") + + # TODO: Add check for cutoff_time <= lags + + self.name = "WikiMath" + self._verbose = verbose + self._lags = lags + self._cutoff_time = cutoff_time + + if not url: + self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/wikivital_mathematics.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_num_nodes() + self._set_num_edges() + self._set_edges() + self._set_edge_weights() + self._set_targets() + self._set_features() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + choosen by the user and the total time periods present in the + original dataset. + """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + self._dataset["time_periods"], self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = self._dataset["time_periods"] + + def _set_num_nodes(self): + r"""Sets the total number of nodes present in the dataset""" + node_set = set() + max_node_id = 0 + for edge in self._dataset["edges"]: + node_set.add(edge[0]) + node_set.add(edge[1]) + max_node_id = max(max_node_id, edge[0], edge[1]) + + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + self.gdata["num_nodes"] = len(node_set) + + def _set_num_edges(self): + r"""Sets the total number of edges present in the dataset""" + self.gdata["num_edges"] = len(self._dataset["edges"]) + + def _set_edges(self): + r"""Sets the edge list of the dataset""" + self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] + + def _set_edge_weights(self): + r"""Sets the edge weights of the dataset""" + edges = self._dataset["edges"] + edge_weights = self._dataset["weights"] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] + comb_edge_list.sort(key=lambda x: (x[1], x[0])) + self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) + + def _set_targets(self): + r"""Calculates and sets the target attributes""" + targets = [] + for time in range(self.gdata["total_timestamps"]): + targets.append(np.array(self._dataset[str(time)]["y"])) + stacked_target = np.stack(targets) + standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( + np.std(stacked_target, axis=0) + 10**-10 + ) + self._all_targets = np.array( + [standardized_target[i, :].T for i in range(len(targets))] + ) + + def _set_features(self): + # TODO: + pass + + def get_edges(self): + 
r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets diff --git a/stgraph/dataset/temporal/WindmillOutputDataLoader.py b/stgraph/dataset/temporal/WindmillOutputDataLoader.py new file mode 100644 index 00000000..7c8efb32 --- /dev/null +++ b/stgraph/dataset/temporal/WindmillOutputDataLoader.py @@ -0,0 +1,227 @@ +import numpy as np + +from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset + + +class WindmillOutputDataLoader(STGraphTemporalDataset): + def __init__( + self, + verbose: bool = False, + url: str = None, + lags: int = 8, + cutoff_time: int = None, + size: str = "large", + redownload: bool = False, + ) -> None: + r"""Hourly energy output of windmills from a European country for more than 2 years. + + This class provides functionality for loading, processing, and accessing the Windmill + output dataset for use in deep learning such as regression tasks. + + .. list-table:: gdata for Windmill Output Small + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 11 + - 121 + - 17472 + + .. list-table:: gdata for Windmill Output Medium + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 26 + - 676 + - 17472 + + .. list-table:: gdata for Windmill Output Large + :widths: 33 33 33 + :header-rows: 1 + + * - num_nodes + - num_edges + - total_timestamps + * - 319 + - 101761 + - 17472 + + Example + ------- + + .. code-block:: python + + from stgraph.dataset import WindmillOutputDataLoader + + wind_small = WindmillOutputDataLoader(verbose=True, size="small") + num_nodes = wind_small.gdata["num_nodes"] + num_edges = wind_small.gdata["num_edges"] + total_timestamps = wind_small.gdata["total_timestamps"] + + edge_list = wind_small.get_edges() + edge_weights = wind_small.get_edge_weights() + targets = wind_small.get_all_targets() + + Parameters + ---------- + + verbose : bool, optional + Flag to control whether to display verbose info (default is False) + url : str, optional + The URL from where the dataset is downloaded online (default is None) + lags : int, optional + The number of time lags (default is 8) + cutoff_time : int, optional + The cutoff timestamp for the temporal dataset (default is None) + size : str, optional + The dataset size among large, medium and small (default is large) + redownload : bool, optional (default is False) + Redownload the dataset online and save to cache + + Attributes + ---------- + name : str + The name of the dataset. + _verbose : bool + Flag to control whether to display verbose info. 
+ _lags : int + The number of time lags + _cutoff_time : int + The cutoff timestamp for the temporal dataset + _edge_list : list + The edge list of the graph dataset + _edge_weights : numpy.ndarray + Numpy array of the edge weights + _all_targets : numpy.ndarray + Numpy array of the node target value + """ + super().__init__() + + if type(lags) != int: + raise TypeError("lags must be of type int") + if lags < 0: + raise ValueError("lags must be a positive integer") + + if cutoff_time != None and type(cutoff_time) != int: + raise TypeError("cutoff_time must be of type int") + if cutoff_time != None and cutoff_time < 0: + raise ValueError("cutoff_time must be a positive integer") + + # TODO: Add check for cutoff_time <= lags + + if type(size) != str: + raise TypeError("size must be of type string") + if size not in ["large", "medium", "small"]: + raise ValueError( + "size must be one of the following values: large, medium or small" + ) + + self.name = "WindMill_" + size + self._verbose = verbose + self._lags = lags + self._cutoff_time = cutoff_time + self._size = size + + if not url: + if size == "large": + self._url = ( + "https://graphmining.ai/temporal_datasets/windmill_output.json" + ) + elif size == "medium": + self._url = "https://graphmining.ai/temporal_datasets/windmill_output_medium.json" + elif size == "small": + self._url = "https://graphmining.ai/temporal_datasets/windmill_output_small.json" + else: + self._url = url + + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + + if self._has_dataset_cache(): + self._load_dataset() + else: + self._download_dataset() + self._save_dataset() + + self._process_dataset() + + def _process_dataset(self) -> None: + self._set_total_timestamps() + self._set_num_nodes() + self._set_num_edges() + self._set_edges() + self._set_edge_weights() + self._set_targets() + + def _set_total_timestamps(self) -> None: + r"""Sets the total timestamps present in the dataset + + It sets the total timestamps present in the dataset into the + gdata attribute dictionary. It is the minimum of the cutoff time + chosen by the user and the total time periods present in the + original dataset.
+ """ + if self._cutoff_time != None: + self.gdata["total_timestamps"] = min( + self._dataset["time_periods"], self._cutoff_time + ) + else: + self.gdata["total_timestamps"] = self._dataset["time_periods"] + + def _set_num_nodes(self): + r"""Sets the total number of nodes present in the dataset""" + node_set = set() + max_node_id = 0 + for edge in self._dataset["edges"]: + node_set.add(edge[0]) + node_set.add(edge[1]) + max_node_id = max(max_node_id, edge[0], edge[1]) + + assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + self.gdata["num_nodes"] = len(node_set) + + def _set_num_edges(self): + r"""Sets the total number of edges present in the dataset""" + self.gdata["num_edges"] = len(self._dataset["edges"]) + + def _set_edges(self): + r"""Sets the edge list of the dataset""" + self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] + + def _set_edge_weights(self): + r"""Sets the edge weights of the dataset""" + edges = self._dataset["edges"] + edge_weights = self._dataset["weights"] + comb_edge_list = [ + (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) + ] + comb_edge_list.sort(key=lambda x: (x[1], x[0])) + self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) + + def _set_targets(self): + r"""Calculates and sets the target attributes""" + stacked_target = np.stack(self._dataset["block"]) + standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( + np.std(stacked_target, axis=0) + 10**-10 + ) + self._all_targets = [ + standardized_target[i, :].T for i in range(self.gdata["total_timestamps"]) + ] + + def _set_features(self): + # TODO: + pass + + def get_edges(self): + r"""Returns the edge list""" + return self._edge_list + + def get_edge_weights(self): + r"""Returns the edge weights""" + return self._edge_weights + + def get_all_targets(self): + r"""Returns the targets for each timestamp""" + return self._all_targets diff --git a/stgraph/dataset/temporal/__init__.py b/stgraph/dataset/temporal/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dataset/dynamic/test_EnglandCovidDataLoader.py b/tests/dataset/dynamic/test_EnglandCovidDataLoader.py new file mode 100644 index 00000000..b683edc6 --- /dev/null +++ b/tests/dataset/dynamic/test_EnglandCovidDataLoader.py @@ -0,0 +1,64 @@ +import numpy as np +import urllib.request + +from stgraph.dataset import EnglandCovidDataLoader + + +def EnglandCovidDataCheck(eng_covid: EnglandCovidDataLoader): + assert eng_covid.gdata["total_timestamps"] == ( + 61 if not eng_covid._cutoff_time else eng_covid._cutoff_time + ) + assert ( + len(list(eng_covid.gdata["num_nodes"].values())) + == eng_covid.gdata["total_timestamps"] + ) + + for time, num_node in eng_covid.gdata["num_nodes"].items(): + assert num_node == 129 + + assert ( + len(list(eng_covid.gdata["num_edges"].values())) + == eng_covid.gdata["total_timestamps"] + ) + + edge_list = eng_covid.get_edges() + + assert len(edge_list) == eng_covid.gdata["total_timestamps"] + assert len(edge_list[0][0]) == 2 + + edge_weights = eng_covid.get_edge_weights() + + assert len(edge_weights) == eng_covid.gdata["total_timestamps"] + + for i in range(len(edge_list)): + assert len(edge_list[i]) == len(edge_weights[i]) + + all_features = eng_covid.get_all_features() + all_targets = eng_covid.get_all_targets() + + assert len(all_features) == eng_covid.gdata["total_timestamps"] - eng_covid._lags + + assert all_features[0].shape == ( + eng_covid.gdata["num_nodes"]["0"], + eng_covid._lags, + 
) + + assert len(all_targets) == eng_covid.gdata["total_timestamps"] - eng_covid._lags + + assert all_targets[0].shape == (eng_covid.gdata["num_nodes"]["0"],) + + +def test_EnglandCovidDataLoader(): + eng_covid = EnglandCovidDataLoader(verbose=True) + eng_covid_1 = EnglandCovidDataLoader(cutoff_time=30) + eng_covid_2 = EnglandCovidDataLoader( + url="https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json" + ) + eng_covid_3 = EnglandCovidDataLoader(lags=12) + # eng_covid_4 = EnglandCovidDataLoader(redownload=True) + + EnglandCovidDataCheck(eng_covid) + EnglandCovidDataCheck(eng_covid_1) + # EnglandCovidDataCheck(eng_covid_2) + EnglandCovidDataCheck(eng_covid_3) + # EnglandCovidDataCheck(eng_covid_4) diff --git a/tests/dataset/static/test_CoraDataLoader.py b/tests/dataset/static/test_CoraDataLoader.py new file mode 100644 index 00000000..ae4a27c5 --- /dev/null +++ b/tests/dataset/static/test_CoraDataLoader.py @@ -0,0 +1,29 @@ +from stgraph.dataset import CoraDataLoader + + +def CoraDataCheck(cora: CoraDataLoader): + assert len(cora._edge_list) == 10556 + assert cora._all_features.shape == (2708, 1433) + assert cora._all_targets.shape == (2708,) + + assert cora.gdata["num_nodes"] == 2708 + assert cora.gdata["num_edges"] == 10556 + assert cora.gdata["num_feats"] == 1433 + assert cora.gdata["num_classes"] == 7 + + edge_list = cora.get_edges() + + assert len(edge_list) == 10556 and len(edge_list[0]) == 2 + assert cora.get_all_features().shape == (2708, 1433) + assert cora.get_all_targets().shape == (2708,) + + +def test_CoraDataLoader(): + cora = CoraDataLoader() + + cora_1 = CoraDataLoader( + url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/cora.json", + ) + + CoraDataCheck(cora) + CoraDataCheck(cora_1) diff --git a/tests/dataset/temporal/test_HungaryCPDataLoader.py b/tests/dataset/temporal/test_HungaryCPDataLoader.py new file mode 100644 index 00000000..5fc11ad8 --- /dev/null +++ b/tests/dataset/temporal/test_HungaryCPDataLoader.py @@ -0,0 +1,54 @@ +import pytest + +from stgraph.dataset import HungaryCPDataLoader + + +def HungaryCPDataChecker(hungary: HungaryCPDataLoader): + assert hungary.gdata["total_timestamps"] == ( + 521 if not hungary._cutoff_time else hungary._cutoff_time + ) + + assert hungary.gdata["num_nodes"] == 20 + assert hungary.gdata["num_edges"] == 102 + + edges = hungary.get_edges() + edge_weights = hungary.get_edge_weights() + all_targets = hungary.get_all_targets() + + assert len(edges) == 102 + assert len(edges[0]) == 2 + + assert len(edge_weights) == 102 + + assert len(all_targets) == hungary.gdata["total_timestamps"] - hungary._lags + assert all_targets[0].shape == (hungary.gdata["num_nodes"],) + + +def test_HungaryCPDataLoader(): + hungary_1 = HungaryCPDataLoader(verbose=True) + hungary_2 = HungaryCPDataLoader(lags=6) + hungary_3 = HungaryCPDataLoader(cutoff_time=100) + hungary_4 = HungaryCPDataLoader( + url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/HungaryCP.json" + ) + + HungaryCPDataChecker(hungary_1) + HungaryCPDataChecker(hungary_2) + HungaryCPDataChecker(hungary_3) + # HungaryCPDataChecker(hungary_4) + + with pytest.raises(TypeError) as exec: + HungaryCPDataLoader(lags="lags") + assert str(exec.value) == "lags must be of type int" + + with pytest.raises(ValueError) as exec: + HungaryCPDataLoader(lags=-1) + assert str(exec.value) == "lags must be a positive integer" + + with pytest.raises(TypeError) as exec: + HungaryCPDataLoader(cutoff_time="time") + assert 
str(exec.value) == "cutoff_time must be of type int" + + with pytest.raises(ValueError) as exec: + HungaryCPDataLoader(cutoff_time=-1) + assert str(exec.value) == "cutoff_time must be a positive integer" diff --git a/tests/dataset/temporal/test_METRLADataLoader.py b/tests/dataset/temporal/test_METRLADataLoader.py new file mode 100644 index 00000000..de07ad3f --- /dev/null +++ b/tests/dataset/temporal/test_METRLADataLoader.py @@ -0,0 +1,62 @@ +import pytest +from stgraph.dataset import METRLADataLoader + + +def METRLADataCheck(metrla: METRLADataLoader): + assert metrla.gdata["total_timestamps"] == ( + 100 if not metrla._cutoff_time else metrla._cutoff_time + ) + + assert metrla.gdata["num_nodes"] == 207 + assert metrla.gdata["num_edges"] == 1722 + + edges = metrla.get_edges() + edge_weights = metrla.get_edge_weights() + # all_targets = metrla.get_all_targets() + # all_features = metrla.get_all_features() + + assert len(edges) == 1722 + assert len(edges[0]) == 2 + + assert len(edge_weights) == 1722 + + # TODO: Check targets and features list + + +def test_METRLADataLoader(): + metrla_1 = METRLADataLoader(verbose=True) + metrla_2 = METRLADataLoader( + url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/METRLA.json" + ) + metrla_3 = METRLADataLoader(num_timesteps_in=8, num_timesteps_out=8) + metrla_4 = METRLADataLoader(cutoff_time=50) + # metrla_5 = METRLADataLoader(redownload=True) + + METRLADataCheck(metrla_1) + METRLADataCheck(metrla_2) + METRLADataCheck(metrla_3) + METRLADataCheck(metrla_4) + + with pytest.raises(TypeError) as exec: + METRLADataLoader(num_timesteps_in="num_timesteps_in") + assert str(exec.value) == "num_timesteps_in must be of type int" + + with pytest.raises(ValueError) as exec: + METRLADataLoader(num_timesteps_in=-1) + assert str(exec.value) == "num_timesteps_in must be a positive integer" + + with pytest.raises(TypeError) as exec: + METRLADataLoader(num_timesteps_out="num_timesteps_out") + assert str(exec.value) == "num_timesteps_out must be of type int" + + with pytest.raises(ValueError) as exec: + METRLADataLoader(num_timesteps_out=-1) + assert str(exec.value) == "num_timesteps_out must be a positive integer" + + with pytest.raises(TypeError) as exec: + METRLADataLoader(cutoff_time="time") + assert str(exec.value) == "cutoff_time must be of type int" + + with pytest.raises(ValueError) as exec: + METRLADataLoader(cutoff_time=-1) + assert str(exec.value) == "cutoff_time must be a positive integer" diff --git a/tests/dataset/temporal/test_MontevideoBusDataLoader.py b/tests/dataset/temporal/test_MontevideoBusDataLoader.py new file mode 100644 index 00000000..052b94f0 --- /dev/null +++ b/tests/dataset/temporal/test_MontevideoBusDataLoader.py @@ -0,0 +1,69 @@ +import pytest +from stgraph.dataset import MontevideoBusDataLoader + + +def MontevideoBusDataCheck(monte: MontevideoBusDataLoader): + assert monte.gdata["total_timestamps"] == ( + 744 if not monte._cutoff_time else monte._cutoff_time + ) + + assert monte.gdata["num_nodes"] == 675 + assert monte.gdata["num_edges"] == 690 + + edges = monte.get_edges() + edge_weights = monte.get_edge_weights() + all_targets = monte.get_all_targets() + all_features = monte.get_all_features() + + assert len(edges) == 690 + + for edge in edges: + assert len(edge) == 2 + + assert len(edge_weights) == 690 + + assert all_features.shape == ( + monte.gdata["total_timestamps"] - monte._lags, + monte.gdata["num_nodes"], + monte._lags, + ) + + assert all_targets.shape == ( + monte.gdata["total_timestamps"] - monte._lags, + 
monte.gdata["num_nodes"], + ) + + +def test_MontevideoBusDataLoader(): + monte_1 = MontevideoBusDataLoader(verbose=True) + monte_2 = MontevideoBusDataLoader( + url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/montevideobus.json" + ) + monte_3 = MontevideoBusDataLoader(lags=6) + monte_4 = MontevideoBusDataLoader(cutoff_time=50) + # monte_5 = MontevideoBusDataLoader(redownload=True) + + MontevideoBusDataCheck(monte_1) + MontevideoBusDataCheck(monte_2) + MontevideoBusDataCheck(monte_3) + MontevideoBusDataCheck(monte_4) + + with pytest.raises(TypeError) as exec: + MontevideoBusDataLoader(lags="lags") + assert str(exec.value) == "lags must be of type int" + + with pytest.raises(ValueError) as exec: + MontevideoBusDataLoader(lags=-1) + assert str(exec.value) == "lags must be a positive integer" + + with pytest.raises(TypeError) as exec: + MontevideoBusDataLoader(cutoff_time="time") + assert str(exec.value) == "cutoff_time must be of type int" + + with pytest.raises(ValueError) as exec: + MontevideoBusDataLoader(cutoff_time=-1) + assert str(exec.value) == "cutoff_time must be a positive integer" + + with pytest.raises(ValueError) as exec: + MontevideoBusDataLoader(cutoff_time=4) + assert str(exec.value) == "cutoff_time must be greater than lags" diff --git a/tests/dataset/temporal/test_PedalMeDataLoader.py b/tests/dataset/temporal/test_PedalMeDataLoader.py new file mode 100644 index 00000000..dee56646 --- /dev/null +++ b/tests/dataset/temporal/test_PedalMeDataLoader.py @@ -0,0 +1,63 @@ +import pytest +from stgraph.dataset import PedalMeDataLoader + + +def PedalMeDataCheck(pedal: PedalMeDataLoader): + assert pedal.gdata["total_timestamps"] == ( + 36 if not pedal._cutoff_time else pedal._cutoff_time + ) + + assert pedal.gdata["num_nodes"] == 15 + assert pedal.gdata["num_edges"] == 225 + + edges = pedal.get_edges() + edge_weights = pedal.get_edge_weights() + all_targets = pedal.get_all_targets() + + assert len(edges) == 225 + + for edge in edges: + assert len(edge) == 2 + + assert len(edge_weights) == 225 + + assert all_targets.shape == ( + pedal.gdata["total_timestamps"] - pedal._lags, + pedal.gdata["num_nodes"], + ) + + +def test_PedalMeDataLoader(): + pedal_1 = PedalMeDataLoader(verbose=True) + pedal_2 = PedalMeDataLoader( + url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/pedalme.json" + ) + pedal_3 = PedalMeDataLoader(lags=6) + pedal_4 = PedalMeDataLoader(cutoff_time=20) + # pedal_5 = PedalMeDataLoader(redownload=True) + + PedalMeDataCheck(pedal_1) + PedalMeDataCheck(pedal_2) + PedalMeDataCheck(pedal_3) + PedalMeDataCheck(pedal_4) + # PedalMeDataCheck(pedal_5) + + with pytest.raises(TypeError) as exec: + PedalMeDataLoader(lags="lags") + assert str(exec.value) == "lags must be of type int" + + with pytest.raises(ValueError) as exec: + PedalMeDataLoader(lags=-1) + assert str(exec.value) == "lags must be a positive integer" + + with pytest.raises(TypeError) as exec: + PedalMeDataLoader(cutoff_time="time") + assert str(exec.value) == "cutoff_time must be of type int" + + with pytest.raises(ValueError) as exec: + PedalMeDataLoader(cutoff_time=-1) + assert str(exec.value) == "cutoff_time must be a positive integer" + + with pytest.raises(ValueError) as exec: + PedalMeDataLoader(cutoff_time=4) + assert str(exec.value) == "cutoff_time must be greater than lags" diff --git a/tests/dataset/temporal/test_WikiMathDataLoader.py b/tests/dataset/temporal/test_WikiMathDataLoader.py new file mode 100644 index 00000000..4e1f5f30 --- /dev/null +++ 
b/tests/dataset/temporal/test_WikiMathDataLoader.py @@ -0,0 +1,56 @@ +import pytest +from stgraph.dataset import WikiMathDataLoader + + +def WikiMathDataCheck(wiki: WikiMathDataLoader): + assert wiki.gdata["total_timestamps"] == ( + 731 if not wiki._cutoff_time else wiki._cutoff_time + ) + + assert wiki.gdata["num_nodes"] == 1068 + assert wiki.gdata["num_edges"] == 27079 + + edges = wiki.get_edges() + edge_weights = wiki.get_edge_weights() + all_targets = wiki.get_all_targets() + + assert len(edges) == 27079 + + for edge in edges: + assert len(edge) == 2 + + assert len(edge_weights) == 27079 + + # TODO: Add tests for features and targets arrays + + +def test_WikiMathDataLoader(): + wiki_1 = WikiMathDataLoader(verbose=True) + wiki_2 = WikiMathDataLoader( + url="https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/wikivital_mathematics.json" + ) + wiki_3 = WikiMathDataLoader(lags=4) + wiki_4 = WikiMathDataLoader(cutoff_time=500) + # wiki_5 = WikiMathDataLoader(redownload=True) + + WikiMathDataCheck(wiki_1) + WikiMathDataCheck(wiki_2) + WikiMathDataCheck(wiki_3) + WikiMathDataCheck(wiki_4) + # WikiMathDataCheck(wiki_5) + + with pytest.raises(TypeError) as exec: + WikiMathDataLoader(lags="lags") + assert str(exec.value) == "lags must be of type int" + + with pytest.raises(ValueError) as exec: + WikiMathDataLoader(lags=-1) + assert str(exec.value) == "lags must be a positive integer" + + with pytest.raises(TypeError) as exec: + WikiMathDataLoader(cutoff_time="time") + assert str(exec.value) == "cutoff_time must be of type int" + + with pytest.raises(ValueError) as exec: + WikiMathDataLoader(cutoff_time=-1) + assert str(exec.value) == "cutoff_time must be a positive integer" diff --git a/tests/dataset/temporal/test_WindmillOutputDataLoader.py b/tests/dataset/temporal/test_WindmillOutputDataLoader.py new file mode 100644 index 00000000..693486e5 --- /dev/null +++ b/tests/dataset/temporal/test_WindmillOutputDataLoader.py @@ -0,0 +1,74 @@ +import pytest +from stgraph.dataset import WindmillOutputDataLoader + + +def WindmillOutputDataCheck(wind: WindmillOutputDataLoader): + assert wind.gdata["total_timestamps"] == ( + 17472 if not wind._cutoff_time else wind._cutoff_time + ) + + if wind._size == "large": + assert wind.gdata["num_nodes"] == 319 + assert wind.gdata["num_edges"] == 101761 + elif wind._size == "medium": + assert wind.gdata["num_nodes"] == 26 + assert wind.gdata["num_edges"] == 676 + elif wind._size == "small": + assert wind.gdata["num_nodes"] == 11 + assert wind.gdata["num_edges"] == 121 + + edges = wind.get_edges() + edge_weights = wind.get_edge_weights() + all_targets = wind.get_all_targets() + + if wind._size == "large": + assert len(edges) == 101761 + assert len(edge_weights) == 101761 + elif wind._size == "medium": + assert len(edges) == 676 + assert len(edge_weights) == 676 + elif wind._size == "small": + assert len(edges) == 121 + assert len(edge_weights) == 121 + + for edge in edges: + assert len(edge) == 2 + + # TODO: Test for targets and features + + +def test_WindmillOutputDataLoader(): + urls = { + "large": "https://graphmining.ai/temporal_datasets/windmill_output.json", + "medium": "https://graphmining.ai/temporal_datasets/windmill_output_medium.json", + "small": "https://graphmining.ai/temporal_datasets/windmill_output_small.json", + } + + for size in ["large", "medium", "small"]: + wind_1 = WindmillOutputDataLoader(verbose=True, size=size) + wind_2 = WindmillOutputDataLoader(url=urls[size], size=size) + wind_3 = WindmillOutputDataLoader(lags=4, size=size) + wind_4 =
WindmillOutputDataLoader(cutoff_time=100, size=size) + # wind_5 = WindmillOutputDataLoader(redownload=True, size=size) + + WindmillOutputDataCheck(wind_1) + WindmillOutputDataCheck(wind_2) + WindmillOutputDataCheck(wind_3) + WindmillOutputDataCheck(wind_4) + # WindmillOutputDataCheck(wind_5) + + with pytest.raises(TypeError) as exec: + WindmillOutputDataLoader(lags="lags", size=size) + assert str(exec.value) == "lags must be of type int" + + with pytest.raises(ValueError) as exec: + WindmillOutputDataLoader(lags=-1, size=size) + assert str(exec.value) == "lags must be a positive integer" + + with pytest.raises(TypeError) as exec: + WindmillOutputDataLoader(cutoff_time="time", size=size) + assert str(exec.value) == "cutoff_time must be of type int" + + with pytest.raises(ValueError) as exec: + WindmillOutputDataLoader(cutoff_time=-1, size=size) + assert str(exec.value) == "cutoff_time must be a positive integer" diff --git a/tests/graph/dynamic/test_DynamicGraphPCSR.py b/tests/graph/dynamic/test_DynamicGraphPCSR.py deleted file mode 100644 index e0a824d5..00000000 --- a/tests/graph/dynamic/test_DynamicGraphPCSR.py +++ /dev/null @@ -1,278 +0,0 @@ -import json -from rich import inspect - -from stgraph.graph.dynamic.pcsr.PCSRGraph import PCSRGraph -from stgraph.graph.static.csr import get_dev_array - -class TestDynamicGraphPCSR: - # Stores the edge-list for each timestamp - # edge_index[t] has the edge list for timestamp = t - edge_index = [ - [[0, 1], [1, 2], [2, 3]], - [[1, 2], [2, 3], [3, 1]], - [[1, 2], [2, 3], [3, 1], [3, 0]], - [[2, 3], [3, 1], [3, 0], [0, 1]], - ] - - # sorted based on second-element, then first-element - sorted_edge_index = [ - [[0, 1], [1, 2], [2, 3]], - [[3, 1], [1, 2], [2, 3]], - [[3, 0], [3, 1], [1, 2], [2, 3]], - [[3, 0], [0, 1], [3, 1], [2, 3]], - ] - - # stores the edge-weights for each timestamp - # edge_weight[t] has the edge weight for timestamp = t - # corresponding to the edge in edge_index - edge_weight = [[4, 7, 9], [2, 11, 13], [3, 8, 1, 5], [15, 6, 10, 12]] - - # total timestamps for this dynamic graph dataset - time_periods = 4 - - # node features for each timestamp - y = [[2, 11, 15, 8], [9, 5, 7, 10], [1, 3, 17, 19], [4, 6, 12, 13]] - - def test_get_graph_attr(self): - pcsr_graph = PCSRGraph(edge_list=self.sorted_edge_index, max_num_nodes=4) - graph_attr = pcsr_graph._get_graph_attr(edge_list=self.sorted_edge_index) - - # checking if the size of graph_attr = 4, which is the - # total number of timestamps present - assert len(graph_attr) == 4 - - # checking the (num_nodes, num_edges) pair for each timestamp - assert graph_attr["0"] == (4, 3) - assert graph_attr["1"] == (4, 3) - assert graph_attr["2"] == (4, 4) - assert graph_attr["3"] == (4, 4) - - def test_preprocess_graph_structure(self): - pcsr_graph = PCSRGraph(edge_list=self.sorted_edge_index, max_num_nodes=4) - - # checking graph_updates for t = 0 - graph_updates_0 = pcsr_graph.graph_updates["0"] - assert len(graph_updates_0["add"]) == 3 - for edge in [(0, 1), (1, 2), (2, 3)]: - assert edge in graph_updates_0["add"] - - assert len(graph_updates_0["delete"]) == 0 - - assert graph_updates_0["num_nodes"] == 4 - assert graph_updates_0["num_edges"] == 3 - - # checking graph_updates for t = 1 - graph_updates_1 = pcsr_graph.graph_updates["1"] - assert len(graph_updates_1["add"]) == 1 - assert (3, 1) in graph_updates_1["add"] - - assert len(graph_updates_1["delete"]) == 1 - assert (0, 1) in graph_updates_1["delete"] - - assert graph_updates_1["num_nodes"] == 4 - assert graph_updates_1["num_edges"] == 3 
- - # checking graph_updates for t = 2 - graph_updates_2 = pcsr_graph.graph_updates["2"] - - assert len(graph_updates_2["add"]) == 1 - assert (3, 0) in graph_updates_2["add"] - - assert len(graph_updates_2["delete"]) == 0 - - assert graph_updates_2["num_nodes"] == 4 - assert graph_updates_2["num_edges"] == 4 - - # checking graph_updates for t = 3 - graph_updates_3 = pcsr_graph.graph_updates["3"] - - assert len(graph_updates_3["add"]) == 1 - assert (0, 1) in graph_updates_3["add"] - - assert len(graph_updates_3["delete"]) == 1 - assert (1, 2) in graph_updates_3["delete"] - - assert graph_updates_3["num_nodes"] == 4 - assert graph_updates_3["num_edges"] == 4 - - def test_get_graph(self): - pcsr_graph = PCSRGraph(edge_list=self.sorted_edge_index, max_num_nodes=4) - - # for time = 0 - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.fwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.fwd_column_indices_ptr, 3), - get_dev_array(pcsr_graph.fwd_eids_ptr, 3), - ) - - assert row_offset == [0, 0, 1, 2, 3] - assert column_indices == [0, 1, 2] - assert eids == [0, 1, 2] - - # for time = 1 - pcsr_graph.get_graph(1) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.fwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.fwd_column_indices_ptr, 3), - get_dev_array(pcsr_graph.fwd_eids_ptr, 3), - ) - - assert row_offset == [0, 0, 1, 2, 3] - assert column_indices == [3, 1, 2] - assert eids == [0, 1, 2] - - # for time = 2 - pcsr_graph.get_graph(2) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.fwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.fwd_column_indices_ptr, 4), - get_dev_array(pcsr_graph.fwd_eids_ptr, 4), - ) - - assert row_offset == [0, 1, 2, 3, 4] - assert column_indices == [3, 3, 1, 2] - assert eids == [0, 1, 2, 3] - - # for time = 3 - pcsr_graph.get_graph(3) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.fwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.fwd_column_indices_ptr, 4), - get_dev_array(pcsr_graph.fwd_eids_ptr, 4), - ) - - assert row_offset == [0, 1, 3, 3, 4] - assert column_indices == [3, 0, 3, 2] - assert eids == [0, 1, 2, 3] - - def test_get_backward_graph(self): - - pcsr_graph = PCSRGraph(edge_list=self.sorted_edge_index, max_num_nodes=4) - pcsr_graph.get_graph(3) - - # for time = 3 - pcsr_graph.get_backward_graph(3) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.bwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.bwd_column_indices_ptr, 4), - get_dev_array(pcsr_graph.bwd_eids_ptr, 4), - ) - - assert row_offset == [0,1,1,2,4] - assert column_indices == [1,3,0,1] - assert eids == [1,3,0,2] - - # for time = 2 - pcsr_graph.get_backward_graph(2) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.bwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.bwd_column_indices_ptr, 4), - get_dev_array(pcsr_graph.bwd_eids_ptr, 4), - ) - - assert row_offset == [0,0,1,2,4] - assert column_indices == [2,3,0,1] - assert eids == [2,3,0,1] - - # for time = 1 - pcsr_graph.get_backward_graph(1) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.bwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.bwd_column_indices_ptr, 3), - get_dev_array(pcsr_graph.bwd_eids_ptr, 3), - ) - - assert row_offset == [0,0,1,2,3] - assert column_indices == [2,3,1] - assert eids == [1,2,0] - - # for time = 0 - pcsr_graph.get_backward_graph(0) - - row_offset, column_indices, eids = ( - get_dev_array(pcsr_graph.bwd_row_offset_ptr, 5), - get_dev_array(pcsr_graph.bwd_column_indices_ptr, 3), - 
get_dev_array(pcsr_graph.bwd_eids_ptr, 3), - ) - - assert row_offset == [0,1,2,3,3] - assert column_indices == [1,2,3] - assert eids == [0,1,2] - - def test_get_num_nodes(self): - """ - Assert the number of nodes in the graph, then repeat - this assert step after calling DynamicGraph.get_graph() - and DynamicGraph.get_backward_graph() in sequential order - """ - - # base graph: t = 0 - pcsr_graph = PCSRGraph(edge_list=self.edge_index, max_num_nodes=4) - assert pcsr_graph.get_num_nodes() == 4 - - # graph: t = 1 - pcsr_graph.get_graph(1) - assert pcsr_graph.get_num_nodes() == 4 - - # graph: t = 2 - pcsr_graph.get_graph(2) - assert pcsr_graph.get_num_nodes() == 4 - - # graph: t = 3 - pcsr_graph.get_graph(3) - assert pcsr_graph.get_num_nodes() == 4 - - # Now moving the graph in the backward direction - # graph: t = 2 - pcsr_graph.get_backward_graph(2) - assert pcsr_graph.get_num_nodes() == 4 - - # graph: t = 1 - pcsr_graph.get_backward_graph(1) - assert pcsr_graph.get_num_nodes() == 4 - - # graph: t = 1 - pcsr_graph.get_backward_graph(0) - assert pcsr_graph.get_num_nodes() == 4 - - def test_get_num_edges(self): - """ - Assert the number of edges in the graph, then repeat - this assert step after calling DynamicGraph.get_graph() - and DynamicGraph.get_backward_graph() in sequential order - """ - - # base graph: t = 0 - pcsr_graph = PCSRGraph(edge_list=self.edge_index, max_num_nodes=4) - assert pcsr_graph.get_num_edges() == 3 - - # graph: t = 1 - pcsr_graph.get_graph(1) - assert pcsr_graph.get_num_edges() == 3 - - # graph: t = 2 - pcsr_graph.get_graph(2) - assert pcsr_graph.get_num_edges() == 4 - - # graph: t = 3 - pcsr_graph.get_graph(3) - assert pcsr_graph.get_num_edges() == 4 - - # Now moving the graph in the backward direction - # graph: t = 2 - pcsr_graph.get_backward_graph(2) - assert pcsr_graph.get_num_edges() == 4 - - # graph: t = 1 - pcsr_graph.get_backward_graph(1) - assert pcsr_graph.get_num_edges() == 3 - - # graph: t = 1 - pcsr_graph.get_backward_graph(0) - assert pcsr_graph.get_num_edges() == 3 \ No newline at end of file