diff --git a/.github/workflows/ruff.yaml b/.github/workflows/ruff.yaml new file mode 100644 index 00000000..94f35cdc --- /dev/null +++ b/.github/workflows/ruff.yaml @@ -0,0 +1,25 @@ +name: Ruff Linting + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.8"] + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff + - name: Analysing the code with ruff + run: | + cd stgraph/dataset/ + ruff check . + cd ../../ \ No newline at end of file diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 00000000..f2f3fdbd --- /dev/null +++ b/.pylintrc @@ -0,0 +1,15 @@ +[FORMAT] + +# Set the maximum line length to a value that suits your project +max-line-length = 120 + +[MESSAGES CONTROL] + +# Disable specific messages by adding them to the "disable" option +disable= + line-too-long, + too-many-instance-attributes, + too-many-arguments, + import-error, + too-few-public-methods, + # Add more disabled messages here if needed diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 00000000..80e82b09 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,27 @@ +[lint] +select = ["ALL"] + +ignore = [ + "FBT002", + "FBT001", + "PLR0913", + "TRY003", + "EM101", + "ERA001", + "S607", + "S602", + "PTH111", + "PTH110", + "PTH107", + "PTH123", + "S605", + "S310", + "FIX002", + "D203", + "D211", + "D212", + "D213", +] + +[lint.per-file-ignores] +"__init__.py" = ["F401"] diff --git a/stgraph/dataset/EnglandCovidDataLoader.py b/stgraph/dataset/EnglandCovidDataLoader.py deleted file mode 100644 index 5586b895..00000000 --- a/stgraph/dataset/EnglandCovidDataLoader.py +++ /dev/null @@ -1,152 +0,0 @@ -import os -import json -import urllib.request -import time - -import numpy as np - -from rich import inspect -from rich.pretty import pprint -from rich.progress import track -from rich.console import Console - -console = Console() - -from rich.traceback import install - -install(show_locals=True) - - -class EnglandCovidDataLoader: - def __init__( - self, verbose: bool = False, lags: int = 8, split=0.75, for_stgraph=False - ) -> None: - self.name = "EnglandCOVID" - self.lags = lags - self.split = split - - self._graph_attr = {} - self._graph_updates = {} - self._max_num_nodes = 0 - - self._local_path = "england_covid.json" - self._url_path = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json" - self._verbose = verbose - - self._load_dataset() - self.total_timestamps = self._dataset["time_periods"] - self._get_targets_and_features() - - if for_stgraph: - self._get_edge_info_stgraph() - self._presort_edge_weights() - else: - self._get_edge_info_pygt() - - def get_graph_data(self): - return self._graph_updates, self._max_num_nodes - - def _load_dataset(self) -> None: - # loading the dataset by downloading them online - if self._verbose: - console.log(f"Downloading [cyan]{self.name}[/cyan] dataset") - self._dataset = json.loads(urllib.request.urlopen(self._url_path).read()) - - # THIS NEEDS TO BE EDITED - # with open('../../dataset/eng_covid/eng_covid.json', 'w') as f: - # json.dump(self._dataset,f) - - def _get_edge_info_stgraph(self): - # getting the edge_list and edge_weights - self._edge_list = [] - self._edge_weights = [] - - for time in range(self._dataset["time_periods"]): - time_edge_list = [] - time_edge_weights = [] - - for edge in self._dataset["edge_mapping"]["edge_index"][str(time)]: - time_edge_list.append((edge[0], edge[1])) - - for weight in self._dataset["edge_mapping"]["edge_weight"][str(time)]: - time_edge_weights.append(weight) - - self._edge_list.append(time_edge_list) - self._edge_weights.append(time_edge_weights) - - def _get_edge_info_pygt(self): - self._edge_list = [] - self._edge_weights = [] - - for time in range(self._dataset["time_periods"]): - self._edge_list.append( - np.array(self._dataset["edge_mapping"]["edge_index"][str(time)]).T - ) - self._edge_weights.append( - np.array(self._dataset["edge_mapping"]["edge_weight"][str(time)]) - ) - - def _get_targets_and_features(self): - stacked_target = np.array(self._dataset["y"]) - standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( - np.std(stacked_target, axis=0) + 10**-10 - ) - - self._all_features = [ - standardized_target[i : i + self.lags, :].T - for i in range(self._dataset["time_periods"] - self.lags) - ] - self._all_targets = [ - standardized_target[i + self.lags, :].T - for i in range(self._dataset["time_periods"] - self.lags) - ] - - def _presort_edge_weights(self): - """ - Presorting edges according to (dest,src) since that is how eids are formed - allowing forward and backward kernel to access edge weights - """ - final_edges_lst = [] - final_edge_weights_lst = [] - - for i in range(len(self._edge_list)): - src_list = [edge[0] for edge in self._edge_list[i]] - dst_list = [edge[1] for edge in self._edge_list[i]] - weights = self._edge_weights[i] - - edge_info_list = [] - sorted_edges_lst = [] - sorted_edge_weights_lst = [] - - for j in range(len(weights)): - edge_info = (src_list[j], dst_list[j], weights[j]) - edge_info_list.append(edge_info) - - # since it has to be sorted according to the reverse order - sorted_edge_info_list = sorted( - edge_info_list, key=lambda element: (element[1], element[0]) - ) - - time_edge = [] - - for edge in sorted_edge_info_list: - time_edge.append((edge[0], edge[1])) - sorted_edge_weights_lst.append(edge[2]) - - final_edges_lst.append(time_edge) - final_edge_weights_lst.append(np.array(sorted_edge_weights_lst)) - - self._edge_list = final_edges_lst - self._edge_weights = final_edge_weights_lst - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_features(self): - return self._all_features - - def get_all_targets(self): - return self._all_targets diff --git a/stgraph/dataset/HungaryCPDataLoader.py b/stgraph/dataset/HungaryCPDataLoader.py deleted file mode 100644 index 63824eb6..00000000 --- a/stgraph/dataset/HungaryCPDataLoader.py +++ /dev/null @@ -1,82 +0,0 @@ -import os -import json -from rich.console import Console -import numpy as np - -console = Console() - - -class HungaryCPDataLoader: - def __init__( - self, - folder_name, - dataset_name, - lags, - cutoff_time, - verbose: bool = False, - for_stgraph=False, - ) -> None: - self.name = dataset_name - self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json" - self._verbose = verbose - self.for_stgraph = for_stgraph - self.lags = lags - - self._load_dataset() - self.total_timestamps = min(len(self._dataset["FX"]), cutoff_time) - - self._get_num_nodes() - self._get_num_edges() - self._get_edges() - self._get_edge_weights() - self._get_targets_and_features() - - def _load_dataset(self): - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - if self._verbose: - console.log( - f"Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json" - ) - else: - console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset") - quit() - - def _get_num_nodes(self): - node_set = set() - max_node_id = 0 - for edge in self._dataset["edges"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.num_nodes = len(node_set) - - def _get_num_edges(self): - self.num_edges = len(self._dataset["edges"]) - - def _get_edges(self): - if self.for_stgraph: - self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - else: - self._edge_list = np.array(self._dataset["edges"]).T - - def _get_edge_weights(self): - self._edge_weights = np.ones(self.num_edges) - - def _get_targets_and_features(self): - stacked_target = np.array(self._dataset["FX"]) - self._all_targets = np.array( - [stacked_target[i, :].T for i in range(stacked_target.shape[0])] - ) - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_targets(self): - return self._all_targets diff --git a/stgraph/dataset/LinkPredDataLoader.py b/stgraph/dataset/LinkPredDataLoader.py deleted file mode 100644 index 0961713b..00000000 --- a/stgraph/dataset/LinkPredDataLoader.py +++ /dev/null @@ -1,79 +0,0 @@ -import os -import json -import numpy as np -from rich.console import Console - -console = Console() - -class LinkPredDataLoader: - def __init__(self, folder_name, dataset_name, cutoff_time, verbose: bool = False, for_stgraph= False): - self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' - self._verbose = verbose - self.for_stgraph = for_stgraph - - self._load_dataset() - self._get_max_num_nodes() - self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) - self._get_edge_info() - self._preprocess_pos_neg_edges() - - def _load_dataset(self) -> None: - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/') - else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') - quit() - - def _get_max_num_nodes(self): - node_set = set() - max_node_id = 0 - - for i in range(len(self._dataset["edge_mapping"]["edge_index"])): - for edge in self._dataset["edge_mapping"]["edge_index"][str(i)]["add"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.max_num_nodes = len(node_set) - - def _get_edge_info(self): - # getting the edge_list and edge_weights - edge_list = [] - updates = self._dataset["edge_mapping"]["edge_index"] - - working_set = set([(edge[0], edge[1]) for edge in updates["0"]["add"]]) - edge_list.append(list(working_set)) - for time in range(1, self.total_timestamps): - working_set = working_set.union(set([(edge[0], edge[1]) for edge in updates[str(time)]["add"]])).difference(set([(edge[0], edge[1]) for edge in updates[str(time)]["delete"]])) - edge_list.append(list(working_set)) - - if self.for_stgraph: - self._edge_list = edge_list - else: - self._edge_list = [np.array(edge_lst_t).T for edge_lst_t in edge_list] - - def _preprocess_pos_neg_edges(self): - updates = self._dataset["edge_mapping"]["edge_index"] - - pos_neg_edge_list = [] - pos_neg_edge_label_list = [] - - for i in range(1, self.total_timestamps): - pos_edges_tup = list(updates[str(i)]["add"]) - neg_edges_tup = list(updates[str(i)]["neg"]) - pos_neg_edge_list.append(pos_edges_tup + neg_edges_tup) - pos_neg_edge_label_list.append([(edge[0], edge[1], 1) for edge in pos_edges_tup] + [(edge[0], edge[1], 0) for edge in neg_edges_tup]) - - self._pos_neg_edge_list = [np.array(edge_list).T for edge_list in pos_neg_edge_list] - self._pos_neg_edge_label_list = [np.array([edge[2] for edge in edge_list]) for edge_list in pos_neg_edge_label_list] - - def get_edges(self): - return self._edge_list - - def get_pos_neg_edges(self): - return self._pos_neg_edge_list, self._pos_neg_edge_label_list \ No newline at end of file diff --git a/stgraph/dataset/METRLADataLoader.py b/stgraph/dataset/METRLADataLoader.py deleted file mode 100644 index 390b39c0..00000000 --- a/stgraph/dataset/METRLADataLoader.py +++ /dev/null @@ -1,137 +0,0 @@ -import os -import json -from rich.console import Console -import numpy as np - -console = Console() -import torch - -from rich import inspect - - -class METRLADataLoader: - def __init__( - self, - folder_name, - dataset_name, - num_timesteps_in, - num_timesteps_out, - cutoff_time, - verbose: bool = False, - for_stgraph: bool = False, - ): - self.name = dataset_name - self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json" - self._verbose = verbose - self.for_stgraph = for_stgraph - - self.num_timesteps_in = num_timesteps_in - self.num_timesteps_out = num_timesteps_out - - self._load_dataset() - self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) - - self._get_num_nodes() - self._get_num_edges() - self._get_edges() - self._get_edge_weights() - self._get_targets_and_features() - - def _load_dataset(self): - # loading the dataset locally - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - if self._verbose: - console.log(f"Loading [cyan]{self.name}[/cyan] dataset from dataset/") - else: - console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset") - quit() - - def _get_num_nodes(self): - node_set = set() - max_node_id = 0 - for edge in self._dataset["edges"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.num_nodes = len(node_set) - - def _get_num_edges(self): - self.num_edges = len(self._dataset["edges"]) - - def _get_edges(self): - if self.for_stgraph: - self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - else: - self._edge_list = np.array(self._dataset["edges"]).T - - # TODO: We are sorting the edge weights accordingly, but are we doing - # the same for edges in the edge list - def _get_edge_weights(self): - if self.for_stgraph: - edges = self._dataset["edges"] - edge_weights = self._dataset["weights"] - comb_edge_list = [ - (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) - ] - comb_edge_list.sort(key=lambda x: (x[1], x[0])) - self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - else: - self._edge_weights = np.array(self._dataset["weights"]) - - def _get_targets_and_features(self): - X = [] - - for timestamp in range(self._dataset["time_periods"]): - if timestamp < self.total_timestamps: - X.append(self._dataset[str(timestamp)]) - - X = np.array(X) - X = X.transpose((1, 2, 0)) - X = X.astype(np.float32) - - # Normalise as in DCRNN paper (via Z-Score Method) - means = np.mean(X, axis=(0, 2)) - X = X - means.reshape(1, -1, 1) - stds = np.std(X, axis=(0, 2)) - X = X / stds.reshape(1, -1, 1) - - X = torch.from_numpy(X) - - inspect(X) - - indices = [ - (i, i + (self.num_timesteps_in + self.num_timesteps_out)) - for i in range( - X.shape[2] - (self.num_timesteps_in + self.num_timesteps_out) + 1 - ) - ] - - # Generate observations - features, target = [], [] - for i, j in indices: - features.append((X[:, :, i : i + self.num_timesteps_in]).numpy()) - target.append((X[:, 0, i + self.num_timesteps_in : j]).numpy()) - - # inspect(indices) - # inspect(features) - # inspect(target) - # quit() - - self._all_features = np.array(features) - self._all_targets = np.array(target) - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_targets(self): - return self._all_targets - - def get_all_features(self): - return self._all_features diff --git a/stgraph/dataset/MontevideoBusDataLoader.py b/stgraph/dataset/MontevideoBusDataLoader.py deleted file mode 100644 index 68282db4..00000000 --- a/stgraph/dataset/MontevideoBusDataLoader.py +++ /dev/null @@ -1,85 +0,0 @@ -import os -import json -from rich.console import Console -import numpy as np -console = Console() - -class MontevideoBusDataLoader: - def __init__(self, folder_name, dataset_name, lags, cutoff_time, verbose: bool = False, for_stgraph = False) -> None: - self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' - self._verbose = verbose - self.for_stgraph = for_stgraph - self.lags = lags - - self._load_dataset() - self.total_timestamps = min(len(self._dataset["nodes"][0]['y']), cutoff_time) - - self._get_num_nodes() - self._get_num_edges() - self._get_edges() - self._get_edge_weights() - self._get_targets() - - def _load_dataset(self): - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json') - else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') - quit() - - def _get_num_nodes(self): - node_set = set() - max_node_id = 0 - for edge in self._dataset["edges"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.num_nodes = len(node_set) - - def _get_num_edges(self): - self.num_edges = len(self._dataset["edges"]) - - def _get_edges(self): - if self.for_stgraph: - self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - else: - self._edge_list = np.array(self._dataset["edges"]).T - - def _get_edge_weights(self): - if self.for_stgraph: - edges = self._dataset["edges"] - edge_weights = self._dataset["weights"] - comb_edge_list = [(edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))] - comb_edge_list.sort(key=lambda x: (x[1], x[0])) - self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - else: - self._edge_weights = np.array(self._dataset["weights"]) - - def _get_targets(self, target_var: str = "y"): - targets = [] - for node in self._dataset["nodes"]: - y = node.get(target_var) - targets.append(np.array(y)) - stacked_targets = np.stack(targets).T - standardized_targets = ( - stacked_targets - np.mean(stacked_targets, axis=0) - ) / np.std(stacked_targets, axis=0) - self._all_targets = np.array([ - standardized_targets[i, :].T - for i in range(len(standardized_targets)) - ]) - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_targets(self): - return self._all_targets \ No newline at end of file diff --git a/stgraph/dataset/PedalMeDataLoader.py b/stgraph/dataset/PedalMeDataLoader.py deleted file mode 100644 index 75bef120..00000000 --- a/stgraph/dataset/PedalMeDataLoader.py +++ /dev/null @@ -1,99 +0,0 @@ -import os -import json -from rich.console import Console -import numpy as np - -console = Console() - -from rich import inspect - - -class PedalMeDataLoader: - def __init__( - self, - folder_name, - dataset_name, - lags, - cutoff_time, - verbose: bool = False, - for_stgraph: bool = False, - ): - self.name = dataset_name - self._local_path = f"../../dataset/{folder_name}/{dataset_name}.json" - self._verbose = verbose - self.for_stgraph = for_stgraph - self.lags = lags - - self._load_dataset() - self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) - - self._get_num_nodes() - self._get_num_edges() - self._get_edges() - self._get_edge_weights() - self._get_targets_and_features() - - def _load_dataset(self): - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - - if self._verbose: - console.log( - f"Loading [cyan]{self.name}[/cyan] dataset from dataset/{self.name}.json" - ) - else: - console.log(f"Failed to find [cyan]{self.name}[/cyan] dataset from dataset") - quit() - - def _get_num_nodes(self): - node_set = set() - max_node_id = 0 - for edge in self._dataset["edges"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.num_nodes = len(node_set) - - def _get_num_edges(self): - self.num_edges = len(self._dataset["edges"]) - - def _get_edges(self): - if self.for_stgraph: - self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - else: - self._edge_list = np.array(self._dataset["edges"]).T - - def _get_edge_weights(self): - if self.for_stgraph: - edges = self._dataset["edges"] - edge_weights = self._dataset["weights"] - comb_edge_list = [ - (edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges)) - ] - comb_edge_list.sort(key=lambda x: (x[1], x[0])) - self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - else: - self._edge_weights = np.array(self._dataset["weights"]) - - def _get_targets_and_features(self): - targets = [] - for time in range(self.total_timestamps): - targets.append(np.array(self._dataset[str(time)])) - - stacked_target = np.stack(targets) - - self._all_targets = np.array( - [stacked_target[i, :].T for i in range(stacked_target.shape[0])] - ) - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_targets(self): - return self._all_targets diff --git a/stgraph/dataset/STGraphDataset.py b/stgraph/dataset/STGraphDataset.py deleted file mode 100644 index 319a464e..00000000 --- a/stgraph/dataset/STGraphDataset.py +++ /dev/null @@ -1,215 +0,0 @@ -"""Base class for all STGraph dataset loaders""" - -import os -import json -import ssl -import urllib.request - -from abc import ABC, abstractmethod -from rich.console import Console - -console = Console() - - -class STGraphDataset(ABC): - r"""Abstract base class for graph dataset loaders - - The dataset handling is done as follows - - 1. Checks whether the dataset is present in cache. - 2. If not present in the cache, it downloads it from the URL. - 3. It then saves the downloaded file inside the cache. - 4. Incase it is present inside the cache, it directly loads it from there - 5. Dataset specific graph processing is then done - - Attributes - ---------- - - name : str - The name of the dataset - gdata : dict - Meta data associated with the dataset - - _dataset : dict - The loaded graph dataset - _url : str - The URL from where the dataset is downloaded online - _verbose : bool - Flag to control whether to display verbose info - _cache_folder : str - Folder inside ~/.stgraph where the dataset cache is stored - _cache_file_type : str - The file type used for storing the cached dataset - - Methods - ------- - - _has_dataset_cache() - Checks if the dataset is stored in cache - - _get_cache_file_path() - Returns the absolute path of the cached dataset file - - _init_graph_data() - Initialises the ``gdata`` attribute with all necessary meta data - - _process_dataset() - Processes the dataset to be used by STGraph - - _download_dataset() - Downloads the dataset using the URL - - _save_dataset() - Saves the dataset to cache - - _load_dataset() - Loads the dataset from cache - """ - - def __init__(self) -> None: - self.name = "" - self.gdata = {} - - self._dataset = {} - self._url = "" - self._verbose = False - self._cache_folder = "/dataset_cache/" - self._cache_file_type = "json" - - def _has_dataset_cache(self) -> bool: - r"""Checks if the dataset is stored in cache - - This private method checks whether the graph dataset cache file exists - in the dataset cache folder. The cache .json file is found in the following - directory ``~/.stgraph/dataset_cache/. - - Returns - ------- - bool - ``True`` if the cache file exists, else ``False`` - - Notes - ----- - The cache file is usually stored as a json file named as ``dataset_name.json`` and is stored - inside the ``~/.stgraph/dataset_cache/``. Incase the directory does not exists, it - is created by this method. - - This private method is intended for internal use within the class and should not be - called directly from outside the class. - - Example - ------- - - .. code-block:: python - - if self._has_dataset_cache(): - # The dataset is cached, continue cached operations - else: - # The dataset is not cached, continue load and save operations - """ - - user_home_dir = os.path.expanduser("~") - stgraph_dir = user_home_dir + "/.stgraph" - cache_dir = stgraph_dir + self._cache_folder - - if os.path.exists(stgraph_dir) == False: - os.system("mkdir " + stgraph_dir) - - if os.path.exists(cache_dir) == False: - os.system("mkdir " + cache_dir) - - cache_file_name = self.name + "." + self._cache_file_type - - return os.path.exists(cache_dir + cache_file_name) - - def _get_cache_file_path(self) -> str: - r"""Returns the absolute path of the cached dataset file - - Returns - ------- - str - The absolute path of the cached dataset file - """ - - user_home_dir = os.path.expanduser("~") - stgraph_dir = user_home_dir + "/.stgraph" - cache_dir = stgraph_dir + self._cache_folder - cache_file_name = self.name + "." + self._cache_file_type - - return cache_dir + cache_file_name - - def _delete_cached_dataset(self) -> None: - r"""Deletes the cached dataset file""" - - os.remove(self._get_cache_file_path()) - - @abstractmethod - def _init_graph_data(self) -> None: - r"""Initialises the ``gdata`` attribute with all necessary meta data - - This is an abstract method that is implemented by ``STGraphStaticDataset``. - The meta data is initialised based on the type of the graph dataset. The values - are calculated as key-value pairs by the respective dataloaders when they - are initialised. - """ - pass - - @abstractmethod - def _process_dataset(self) -> None: - r"""Processes the dataset to be used by STGraph - - This is an abstract method that is to be implemented by each dataset loader. - The implementation in specific to the nature of the dataset itself. The dataset is - processed in such a way that it can be smoothly used within STGraph. - """ - pass - - def _download_dataset(self) -> None: - r"""Downloads the dataset using the URL - - Downloads the dataset files from the URL set by default for each data loader or - by one provided by the user. If verbose mode is enabled, it displays download status. - """ - if self._verbose: - console.log( - f"[cyan bold]{self.name}[/cyan bold] not present in cache. Downloading right now." - ) - - context = ssl._create_unverified_context() - self._dataset = json.loads( - urllib.request.urlopen(self._url, context=context).read() - ) - - if self._verbose: - console.log(f"[cyan bold]{self.name}[/cyan bold] download complete.") - - def _save_dataset(self) -> None: - r"""Saves the dataset to cache - - Saves the downloaded dataset file to the cache folder. If verbose mode is enabled, - it displays the save information. - """ - with open(self._get_cache_file_path(), "w") as cache_file: - json.dump(self._dataset, cache_file) - - if self._verbose: - console.log( - f"[cyan bold]{self.name}[/cyan bold] dataset saved to cache" - ) - - def _load_dataset(self) -> None: - r"""Loads the dataset from cache - - Loads the caches dataset json file as a python dictionary. If verbose mode is enabled, - it displays the loading status. - """ - if self._verbose: - console.log(f"Loading [cyan bold]{self.name}[/cyan bold] from cache") - - with open(self._get_cache_file_path()) as cache_file: - self._dataset = json.load(cache_file) - - if self._verbose: - console.log( - f"Successfully loaded [cyan bold]{self.name}[/cyan bold] from cache" - ) diff --git a/stgraph/dataset/WikiMathDataLoader.py b/stgraph/dataset/WikiMathDataLoader.py deleted file mode 100644 index c71d2d2a..00000000 --- a/stgraph/dataset/WikiMathDataLoader.py +++ /dev/null @@ -1,86 +0,0 @@ -import os -import json -from rich.console import Console -import numpy as np -console = Console() - -class WikiMathDataLoader: - def __init__(self, folder_name, dataset_name, lags, cutoff_time, verbose: bool = False, for_stgraph= False) -> None: - self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' - self._verbose = verbose - self.for_stgraph = for_stgraph - self.lags = lags - - self._load_dataset() - self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) - - self._get_num_nodes() - self._get_num_edges() - self._get_edges() - self._get_edge_weights() - self._get_targets_and_features() - - def _load_dataset(self): - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/') - else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') - quit() - - def _get_num_nodes(self): - node_set = set() - max_node_id = 0 - for edge in self._dataset["edges"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.num_nodes = len(node_set) - - def _get_num_edges(self): - self.num_edges = len(self._dataset["edges"]) - - def _get_edges(self): - if self.for_stgraph: - self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - else: - self._edge_list = np.array(self._dataset["edges"]).T - - def _get_edge_weights(self): - if self.for_stgraph: - edges = self._dataset["edges"] - edge_weights = self._dataset["weights"] - comb_edge_list = [(edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))] - comb_edge_list.sort(key=lambda x: (x[1], x[0])) - self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - else: - self._edge_weights = np.array(self._dataset["weights"]) - - def _get_targets_and_features(self): - targets = [] - for time in range(self.total_timestamps): - targets.append(np.array(self._dataset[str(time)]["y"])) - stacked_target = np.stack(targets) - standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( - np.std(stacked_target, axis=0) + 10 ** -10 - ) - self._all_targets = np.array( - [ - standardized_target[i, :].T - for i in range(len(targets)) - ] - ) - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_targets(self): - return self._all_targets diff --git a/stgraph/dataset/WindmillOutputDataLoader.py b/stgraph/dataset/WindmillOutputDataLoader.py deleted file mode 100644 index 58674c40..00000000 --- a/stgraph/dataset/WindmillOutputDataLoader.py +++ /dev/null @@ -1,81 +0,0 @@ -import os -import json -from rich.console import Console -import numpy as np -console = Console() - -class WindmillOutputDataLoader: - def __init__(self, folder_name, dataset_name, lags, cutoff_time, verbose: bool = False, for_stgraph= False) -> None: - self.name = dataset_name - self._local_path = f'../../dataset/{folder_name}/{dataset_name}.json' - self._verbose = verbose - self.for_stgraph = for_stgraph - self.lags = lags - - self._load_dataset() - self.total_timestamps = min(self._dataset["time_periods"], cutoff_time) - - self._get_num_nodes() - self._get_num_edges() - self._get_edges() - self._get_edge_weights() - self._get_targets_and_features() - - def _load_dataset(self): - if os.path.exists(self._local_path): - dataset_file = open(self._local_path) - self._dataset = json.load(dataset_file) - if self._verbose: - console.log(f'Loading [cyan]{self.name}[/cyan] dataset from dataset/') - else: - console.log(f'Failed to find [cyan]{self.name}[/cyan] dataset from dataset') - quit() - - def _get_num_nodes(self): - node_set = set() - max_node_id = 0 - for edge in self._dataset["edges"]: - node_set.add(edge[0]) - node_set.add(edge[1]) - max_node_id = max(max_node_id, edge[0], edge[1]) - - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" - self.num_nodes = len(node_set) - - def _get_num_edges(self): - self.num_edges = len(self._dataset["edges"]) - - def _get_edges(self): - if self.for_stgraph: - self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - else: - self._edge_list = np.array(self._dataset["edges"]).T - - def _get_edge_weights(self): - if self.for_stgraph: - edges = self._dataset["edges"] - edge_weights = self._dataset["weights"] - comb_edge_list = [(edges[i][0], edges[i][1], edge_weights[i]) for i in range(len(edges))] - comb_edge_list.sort(key=lambda x: (x[1], x[0])) - self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - else: - self._edge_weights = np.array(self._dataset["weights"]) - - def _get_targets_and_features(self): - stacked_target = np.stack(self._dataset["block"]) - standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( - np.std(stacked_target, axis=0) + 10 ** -10 - ) - self._all_targets = [ - standardized_target[i, :].T - for i in range(self.total_timestamps) - ] - - def get_edges(self): - return self._edge_list - - def get_edge_weights(self): - return self._edge_weights - - def get_all_targets(self): - return self._all_targets diff --git a/stgraph/dataset/__init__.py b/stgraph/dataset/__init__.py index 9ae14678..120aa777 100644 --- a/stgraph/dataset/__init__.py +++ b/stgraph/dataset/__init__.py @@ -1,13 +1 @@ """Dataset loader provided by STGraph""" - -# TODO: Change this accordingly if needed -from stgraph.dataset.static.CoraDataLoader import CoraDataLoader - -from stgraph.dataset.temporal.HungaryCPDataLoader import HungaryCPDataLoader -from stgraph.dataset.temporal.METRLADataLoader import METRLADataLoader -from stgraph.dataset.temporal.MontevideoBusDataLoader import MontevideoBusDataLoader -from stgraph.dataset.temporal.PedalMeDataLoader import PedalMeDataLoader -from stgraph.dataset.temporal.WikiMathDataLoader import WikiMathDataLoader -from stgraph.dataset.temporal.WindmillOutputDataLoader import WindmillOutputDataLoader - -from stgraph.dataset.dynamic.EnglandCovidDataLoader import EnglandCovidDataLoader diff --git a/stgraph/dataset/dynamic/__init__.py b/stgraph/dataset/dynamic/__init__.py index e69de29b..4c0ddc36 100644 --- a/stgraph/dataset/dynamic/__init__.py +++ b/stgraph/dataset/dynamic/__init__.py @@ -0,0 +1 @@ +"""Dataset loaders for real-world dynamic datasets.""" diff --git a/stgraph/dataset/dynamic/EnglandCovidDataLoader.py b/stgraph/dataset/dynamic/england_covid_dataloader.py similarity index 71% rename from stgraph/dataset/dynamic/EnglandCovidDataLoader.py rename to stgraph/dataset/dynamic/england_covid_dataloader.py index 85130ffb..850e85ba 100644 --- a/stgraph/dataset/dynamic/EnglandCovidDataLoader.py +++ b/stgraph/dataset/dynamic/england_covid_dataloader.py @@ -1,27 +1,33 @@ +"""Dynamic dataset tracking COVID-19 cases in England's NUTS3 regions.""" + +from __future__ import annotations + import numpy as np -from stgraph.dataset.dynamic.STGraphDynamicDataset import STGraphDynamicDataset +from stgraph.dataset.dynamic.stgraph_dynamic_dataset import STGraphDynamicDataset class EnglandCovidDataLoader(STGraphDynamicDataset): + """Dynamic dataset tracking COVID-19 cases in England's NUTS3 regions.""" + def __init__( - self, + self: EnglandCovidDataLoader, verbose: bool = False, - url: str = None, + url: str | None = None, lags: int = 8, - cutoff_time: int = None, + cutoff_time: int | None = None, redownload: bool = False, ) -> None: - r"""Dynamic dataset tracking COVID-19 cases in England's NUTS3 regions + r"""Dynamic dataset tracking COVID-19 cases in England's NUTS3 regions. This dataset captures the interplay between COVID-19 cases and mobility in England's NUTS3 regions from March 3rd to May 12th. It is a directed and weighted graph that offers daily case count and movement of people between each region through node and edge features respectively. - This class provides functionality for loading, processing, and accessing the England - Covid dataset for use in deep learning tasks such as predicting the COVID cases - in a region. + This class provides functionality for loading, processing, and accessing + the England Covid dataset for use in deep learning tasks such as predicting + the COVID cases in a region. Example ------- @@ -42,7 +48,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -56,7 +61,6 @@ def __init__( Attributes ---------- - name : str The name of the dataset. _verbose : bool @@ -80,6 +84,10 @@ def __init__( self._verbose = verbose self._lags = lags self._cutoff_time = cutoff_time + self._all_features = None + self._all_targets = None + self._edge_list = None + self._edge_weights = None if not url: self._url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json" @@ -97,29 +105,30 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: EnglandCovidDataLoader) -> None: self._set_total_timestamps() self._set_targets_and_features() self._set_edge_info() self._presort_edge_weights() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: EnglandCovidDataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time: self.gdata["total_timestamps"] = min( - self._dataset["time_periods"], self._cutoff_time + self._dataset["time_periods"], + self._cutoff_time, ) else: self.gdata["total_timestamps"] = self._dataset["time_periods"] - def _set_targets_and_features(self): - r"""Calculates and sets the target and feature attributes""" + def _set_targets_and_features(self: EnglandCovidDataLoader) -> None: + r"""Calculate and set the target and feature attributes.""" stacked_target = np.array(self._dataset["y"]) standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( np.std(stacked_target, axis=0) + 10**-10 @@ -134,32 +143,35 @@ def _set_targets_and_features(self): for i in range(self.gdata["total_timestamps"] - self._lags) ] - def _set_edge_info(self): - r"""Sets edge info such as edge list and edge weights""" + def _set_edge_info(self: EnglandCovidDataLoader) -> None: + r"""Set edge info such as edge list and edge weights.""" self._edge_list = [] self._edge_weights = [] for time in range(self.gdata["total_timestamps"]): - time_edge_list = [] time_edge_weights = [] - for edge in self._dataset["edge_mapping"]["edge_index"][str(time)]: - time_edge_list.append((edge[0], edge[1])) + time_edge_list = [ + (edge[0], edge[1]) + for edge in self._dataset["edge_mapping"]["edge_index"][str(time)] + ] - for weight in self._dataset["edge_mapping"]["edge_weight"][str(time)]: - time_edge_weights.append(weight) + time_edge_weights = list( + self._dataset["edge_mapping"]["edge_weight"][str(time)], + ) self._edge_list.append(time_edge_list) self._edge_weights.append(time_edge_weights) self.gdata["num_edges"][str(time)] = len(time_edge_list) self.gdata["num_nodes"][str(time)] = len( - {node for edge in time_edge_list for node in edge} + {node for edge in time_edge_list for node in edge}, ) - def _presort_edge_weights(self): - r""" + def _presort_edge_weights(self: EnglandCovidDataLoader) -> None: + r"""Presorting the edges. + Presorting edges according to (dest,src) since that is how eids are formed - allowing forward and backward kernel to access edge weights + allowing forward and backward kernel to access edge weights. """ final_edges_lst = [] final_edge_weights_lst = [] @@ -172,13 +184,13 @@ def _presort_edge_weights(self): edge_info_list = [] sorted_edge_weights_lst = [] - for j in range(len(weights)): - edge_info = (src_list[j], dst_list[j], weights[j]) - edge_info_list.append(edge_info) + for src, dst, weight in zip(src_list, dst_list, weights): + edge_info_list.append((src, dst, weight)) # since it has to be sorted according to the reverse order sorted_edge_info_list = sorted( - edge_info_list, key=lambda element: (element[1], element[0]) + edge_info_list, + key=lambda element: (element[1], element[0]), ) time_edge = [] @@ -193,18 +205,18 @@ def _presort_edge_weights(self): self._edge_list = final_edges_lst self._edge_weights = final_edge_weights_lst - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: EnglandCovidDataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: EnglandCovidDataLoader) -> list: + r"""Return the edge weights.""" return self._edge_weights - def get_all_features(self): - r"""Returns the features for each timestamp""" + def get_all_features(self: EnglandCovidDataLoader) -> list: + r"""Return the features for each timestamp.""" return self._all_features - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: EnglandCovidDataLoader) -> list: + r"""Return the targets for each timestamp.""" return self._all_targets diff --git a/stgraph/dataset/dynamic/STGraphDynamicDataset.py b/stgraph/dataset/dynamic/stgraph_dynamic_dataset.py similarity index 52% rename from stgraph/dataset/dynamic/STGraphDynamicDataset.py rename to stgraph/dataset/dynamic/stgraph_dynamic_dataset.py index de1dc9f4..a50442e4 100644 --- a/stgraph/dataset/dynamic/STGraphDynamicDataset.py +++ b/stgraph/dataset/dynamic/stgraph_dynamic_dataset.py @@ -1,20 +1,20 @@ -"""Base class for all STGraph dynamic graph datasets""" +"""Base class for all STGraph dynamic graph datasets.""" -from stgraph.dataset.STGraphDataset import STGraphDataset +from __future__ import annotations +from stgraph.dataset.stgraph_dataset import STGraphDataset -class STGraphDynamicDataset(STGraphDataset): - r"""Base class for dynamic graph datasets - This class is a subclass of ``STGraphDataset`` and provides the base structure for - handling dynamic graph datasets.""" +class STGraphDynamicDataset(STGraphDataset): + r"""Base class for dynamic graph datasets.""" - def __init__(self) -> None: + def __init__(self: STGraphDynamicDataset) -> None: + r"""Provide the base structure for handling dynamic graph datasets.""" super().__init__() self._init_graph_data() - def _init_graph_data(self) -> dict: + def _init_graph_data(self: STGraphDynamicDataset) -> dict: r"""Initialize graph meta data for a dynamic dataset. The ``num_nodes``, ``num_edges``, ``total_timestamps`` keys are set to value 0 diff --git a/stgraph/dataset/static/STGraphStaticDataset.py b/stgraph/dataset/static/STGraphStaticDataset.py deleted file mode 100644 index 94edbd65..00000000 --- a/stgraph/dataset/static/STGraphStaticDataset.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Base class for all STGraph static graph datasets""" - -from rich.console import Console - -from stgraph.dataset.STGraphDataset import STGraphDataset - - -console = Console() - - -class STGraphStaticDataset(STGraphDataset): - r"""Base class for static graph datasets - - This class is a subclass of ``STGraphDataset`` and provides the base structure for - handling static graph datasets. - """ - - def __init__(self) -> None: - super().__init__() - - self._init_graph_data() - - def _init_graph_data(self) -> dict: - r"""Initialize graph meta data for a static dataset. - - The ``num_nodes`` and ``num_edges`` keys are set to value 0 - """ - self.gdata["num_nodes"] = 0 - self.gdata["num_edges"] = 0 diff --git a/stgraph/dataset/static/__init__.py b/stgraph/dataset/static/__init__.py index e69de29b..ff774c01 100644 --- a/stgraph/dataset/static/__init__.py +++ b/stgraph/dataset/static/__init__.py @@ -0,0 +1 @@ +"""Collection of dataset loaders for Static real-world datasets.""" diff --git a/stgraph/dataset/static/CoraDataLoader.py b/stgraph/dataset/static/cora_dataloader.py similarity index 69% rename from stgraph/dataset/static/CoraDataLoader.py rename to stgraph/dataset/static/cora_dataloader.py index 0e5b310f..60a39fad 100644 --- a/stgraph/dataset/static/CoraDataLoader.py +++ b/stgraph/dataset/static/cora_dataloader.py @@ -1,25 +1,36 @@ -import random +"""Citation network consisting of scientific publications.""" + + +from __future__ import annotations import numpy as np from rich.console import Console -from stgraph.dataset.static.STGraphStaticDataset import STGraphStaticDataset - +from stgraph.dataset.static.stgraph_static_dataset import STGraphStaticDataset console = Console() class CoraDataLoader(STGraphStaticDataset): - def __init__(self, verbose=False, url=None, redownload=False) -> None: - r"""Citation network consisting of scientific publications - - The Cora dataset consists of 2708 scientific publications classified into one of seven classes. - The citation network consists of 5429 links. Each publication in the dataset is described by a 0/1-valued - word vector indicating the absence/presence of the corresponding word from the dictionary. + """Citation network consisting of scientific publications.""" + + def __init__( + self: CoraDataLoader, + verbose: bool = False, + url: str | None = None, + redownload: bool = False, + ) -> None: + r"""Citation network consisting of scientific publications. + + The Cora dataset consists of 2708 scientific publications classified into + one of seven classes. The citation network consists of 5429 links. Each + publication in the dataset is described by a 0/1-valued word vector + indicating the absence/presence of the corresponding word from the dictionary. The dictionary consists of 1433 unique words. - This class provides functionality for loading, processing, and accessing the Cora dataset - for use in deep learning tasks such as graph-based node classification. + This class provides functionality for loading, processing, and accessing the + Cora dataset for use in deep learning tasks such as graph-based + node classification. .. list-table:: gdata :widths: 25 25 25 25 @@ -47,7 +58,6 @@ def __init__(self, verbose=False, url=None, redownload=False) -> None: Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -72,6 +82,9 @@ def __init__(self, verbose=False, url=None, redownload=False) -> None: self.name = "Cora" self._verbose = verbose + self._edge_list = None + self._all_features = None + self._all_targets = None if not url: self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/cora.json" @@ -89,7 +102,7 @@ def __init__(self, verbose=False, url=None, redownload=False) -> None: self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: CoraDataLoader) -> None: r"""Process the Cora dataset. Calls private methods to extract edge list, node features, target classes @@ -99,23 +112,23 @@ def _process_dataset(self) -> None: self._set_targets_and_features() self._set_graph_attributes() - def _set_edge_info(self) -> None: - r"""Extract edge information from the dataset""" + def _set_edge_info(self: CoraDataLoader) -> None: + r"""Extract edge information from the dataset.""" edges = np.array(self._dataset["edges"]) edge_list = [] - for i in range(len(edges)): - edge = edges[i] - edge_list.append((edge[0], edge[1])) + + for src, dst in edges: + edge_list.append((src, dst)) self._edge_list = edge_list - def _set_targets_and_features(self): + def _set_targets_and_features(self: CoraDataLoader) -> None: r"""Extract targets and features from the dataset.""" self._all_features = np.array(self._dataset["features"]) self._all_targets = np.array(self._dataset["labels"]).T - def _set_graph_attributes(self): - r"""Calculates and stores graph meta data inside ``gdata``""" + def _set_graph_attributes(self: CoraDataLoader) -> None: + r"""Calculate and stores graph meta data inside ``gdata``.""" node_set = set() for edge in self._edge_list: node_set.add(edge[0]) @@ -126,14 +139,14 @@ def _set_graph_attributes(self): self.gdata["num_feats"] = len(self._all_features[0]) self.gdata["num_classes"] = len(set(self._all_targets)) - def get_edges(self) -> list: + def get_edges(self: CoraDataLoader) -> list: r"""Get the edge list.""" return self._edge_list - def get_all_features(self) -> np.ndarray: + def get_all_features(self: CoraDataLoader) -> np.ndarray: r"""Get all features.""" return self._all_features - def get_all_targets(self) -> np.ndarray: + def get_all_targets(self: CoraDataLoader) -> np.ndarray: r"""Get all targets.""" return self._all_targets diff --git a/stgraph/dataset/static/stgraph_static_dataset.py b/stgraph/dataset/static/stgraph_static_dataset.py new file mode 100644 index 00000000..d4225700 --- /dev/null +++ b/stgraph/dataset/static/stgraph_static_dataset.py @@ -0,0 +1,23 @@ +"""Base class for all STGraph static graph datasets.""" + +from __future__ import annotations + +from stgraph.dataset.stgraph_dataset import STGraphDataset + + +class STGraphStaticDataset(STGraphDataset): + r"""Base class for static graph datasets.""" + + def __init__(self: STGraphStaticDataset) -> None: + r"""Provide the base structure for handling static graph datasets.""" + super().__init__() + + self._init_graph_data() + + def _init_graph_data(self: STGraphStaticDataset) -> dict: + r"""Initialize graph meta data for a static dataset. + + The ``num_nodes`` and ``num_edges`` keys are set to value 0 + """ + self.gdata["num_nodes"] = 0 + self.gdata["num_edges"] = 0 diff --git a/stgraph/dataset/stgraph_dataset.py b/stgraph/dataset/stgraph_dataset.py new file mode 100644 index 00000000..8e958e46 --- /dev/null +++ b/stgraph/dataset/stgraph_dataset.py @@ -0,0 +1,219 @@ +"""Base class for all STGraph dataset loaders.""" + +from __future__ import annotations + +import json +import os +import ssl +import urllib.request +from abc import ABC, abstractmethod + +from rich.console import Console + +console = Console() + + +class STGraphDataset(ABC): + r"""Abstract base class for graph dataset loaders.""" + + def __init__(self: STGraphDataset) -> None: + r"""Abstract base class for graph dataset loaders. + + The dataset handling is done as follows + + 1. Checks whether the dataset is present in cache. + 2. If not present in the cache, it downloads it from the URL. + 3. It then saves the downloaded file inside the cache. + 4. Incase it is present inside the cache, it directly loads it from there + 5. Dataset specific graph processing is then done + + Attributes + ---------- + name : str + The name of the dataset + gdata : dict + Meta data associated with the dataset + + _dataset : dict + The loaded graph dataset + _url : str + The URL from where the dataset is downloaded online + _verbose : bool + Flag to control whether to display verbose info + _cache_folder : str + Folder inside ~/.stgraph where the dataset cache is stored + _cache_file_type : str + The file type used for storing the cached dataset + + Methods + ------- + _has_dataset_cache() + Checks if the dataset is stored in cache + + _get_cache_file_path() + Returns the absolute path of the cached dataset file + + _init_graph_data() + Initialises the ``gdata`` attribute with all necessary meta data + + _process_dataset() + Processes the dataset to be used by STGraph + + _download_dataset() + Downloads the dataset using the URL + + _save_dataset() + Saves the dataset to cache + + _load_dataset() + Loads the dataset from cache + """ + self.name = "" + self.gdata = {} + + self._dataset = {} + self._url = "" + self._verbose = False + self._cache_folder = "/dataset_cache/" + self._cache_file_type = "json" + + def _has_dataset_cache(self: STGraphDataset) -> bool: + r"""Check if the dataset is stored in cache. + + This private method checks whether the graph dataset cache file exists + in the dataset cache folder. The cache .json file is found in the following + directory ``~/.stgraph/dataset_cache/. + + Returns: + ------- + bool + ``True`` if the cache file exists, else ``False`` + + Notes: + ----- + The cache file is usually stored as a json file named as + ``dataset_name.json`` and is stored inside the ``~/.stgraph/dataset_cache/``. + Incase the directory does not exists, it is created by this method. + + This private method is intended for internal use within the class and + should not be called directly from outside the class. + + Example: + ------- + .. code-block:: python + + if self._has_dataset_cache(): + # The dataset is cached, continue cached operations + else: + # The dataset is not cached, continue load and save operations + """ + user_home_dir = os.path.expanduser("~") + stgraph_dir = user_home_dir + "/.stgraph" + cache_dir = stgraph_dir + self._cache_folder + + if os.path.exists(stgraph_dir) is False: + os.system("mkdir " + stgraph_dir) + + if os.path.exists(cache_dir) is False: + os.system("mkdir " + cache_dir) + + cache_file_name = self.name + "." + self._cache_file_type + + return os.path.exists(cache_dir + cache_file_name) + + def _get_cache_file_path(self: STGraphDataset) -> str: + r"""Return the absolute path of the cached dataset file. + + Returns + ------- + str + The absolute path of the cached dataset file + """ + user_home_dir = os.path.expanduser("~") + stgraph_dir = user_home_dir + "/.stgraph" + cache_dir = stgraph_dir + self._cache_folder + cache_file_name = self.name + "." + self._cache_file_type + + return cache_dir + cache_file_name + + def _delete_cached_dataset(self: STGraphDataset) -> None: + r"""Delete the cached dataset file.""" + os.remove(self._get_cache_file_path()) + + @abstractmethod + def _init_graph_data(self: STGraphDataset) -> None: + r"""Initialise the ``gdata`` attribute with all necessary meta data. + + This is an abstract method that is implemented by ``STGraphStaticDataset``. + The meta data is initialised based on the type of the graph dataset. The + values are calculated as key-value pairs by the respective dataloaders + when they are initialised. + """ + + @abstractmethod + def _process_dataset(self: STGraphDataset) -> None: + r"""Process the dataset to be used by STGraph. + + This is an abstract method that is to be implemented by each dataset loader. + The implementation in specific to the nature of the dataset itself. + The dataset is processed in such a way that it can be smoothly used + within STGraph. + """ + + def _download_dataset(self: STGraphDataset) -> None: + r"""Download the dataset using the URL. + + Downloads the dataset files from the URL set by default for each data loader + or by one provided by the user. If verbose mode is enabled, it displays + download status. + """ + if self._verbose: + console.log( + f"[cyan bold]{self.name}[/cyan bold] not present in cache." + "Downloading right now.", + ) + + if not self._url.startswith(("http:", "https:")): + raise ValueError("URL must start with 'http:' or 'https:'") + + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + self._dataset = json.loads( + urllib.request.urlopen(self._url, context=context).read(), + ) + + if self._verbose: + console.log(f"[cyan bold]{self.name}[/cyan bold] download complete.") + + def _save_dataset(self: STGraphDataset) -> None: + r"""Save the dataset to cache. + + Saves the downloaded dataset file to the cache folder. If verbose mode + is enabled, it displays the save information. + """ + with open(self._get_cache_file_path(), "w") as cache_file: + json.dump(self._dataset, cache_file) + + if self._verbose: + console.log( + f"[cyan bold]{self.name}[/cyan bold] dataset saved to cache", + ) + + def _load_dataset(self: STGraphDataset) -> None: + r"""Load the dataset from cache. + + Loads the caches dataset json file as a python dictionary. If verbose mode + is enabled, it displays the loading status. + """ + if self._verbose: + console.log(f"Loading [cyan bold]{self.name}[/cyan bold] from cache") + + with open(self._get_cache_file_path()) as cache_file: + self._dataset = json.load(cache_file) + + if self._verbose: + console.log( + f"Successfully loaded [cyan bold]{self.name}[/cyan bold] from cache", + ) diff --git a/stgraph/dataset/temporal/__init__.py b/stgraph/dataset/temporal/__init__.py index e69de29b..2d8ab191 100644 --- a/stgraph/dataset/temporal/__init__.py +++ b/stgraph/dataset/temporal/__init__.py @@ -0,0 +1 @@ +"""Collection of dataset loaders for Temporal real-world datasets.""" diff --git a/stgraph/dataset/temporal/HungaryCPDataLoader.py b/stgraph/dataset/temporal/hungarycp_dataloader.py similarity index 67% rename from stgraph/dataset/temporal/HungaryCPDataLoader.py rename to stgraph/dataset/temporal/hungarycp_dataloader.py index 3df40635..688d3198 100644 --- a/stgraph/dataset/temporal/HungaryCPDataLoader.py +++ b/stgraph/dataset/temporal/hungarycp_dataloader.py @@ -1,26 +1,33 @@ +"""Temporal dataset for County level chicken pox cases in Hungary.""" + +from __future__ import annotations + import numpy as np -from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset +from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset class HungaryCPDataLoader(STGraphTemporalDataset): + """Temporal dataset provided for County level chicken pox cases in Hungary.""" + def __init__( - self, + self: HungaryCPDataLoader, verbose: bool = False, - url: str = None, + url: str | None = None, lags: int = 4, - cutoff_time: int = None, + cutoff_time: int | None = None, redownload: bool = False, ) -> None: - r"""County level chicken pox cases in Hungary + r"""County level chicken pox cases in Hungary. This dataset comprises information on weekly occurrences of chickenpox in Hungary from 2005 to 2015. The graph structure is static with nodes representing the counties and edges are neighbourhoods between them. Vertex features are lagged weekly counts of the chickenpox cases. - This class provides functionality for loading, processing, and accessing the Hungary - Chickenpox dataset for use in deep learning tasks such as County level case count prediction. + This class provides functionality for loading, processing, and accessing + the Hungary Chickenpox dataset for use in deep learning tasks such as + County level case count prediction. .. list-table:: gdata :widths: 33 33 33 @@ -46,7 +53,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -75,23 +81,25 @@ def __init__( _all_targets : numpy.ndarray Numpy array of the node target value """ - super().__init__() - if type(lags) != int: + if not isinstance(lags, int): raise TypeError("lags must be of type int") if lags < 0: raise ValueError("lags must be a positive integer") - if cutoff_time != None and type(cutoff_time) != int: + if cutoff_time is not None and not isinstance(cutoff_time, int): raise TypeError("cutoff_time must be of type int") - if cutoff_time != None and cutoff_time < 0: + if cutoff_time is not None and cutoff_time < 0: raise ValueError("cutoff_time must be a positive integer") self.name = "Hungary_Chickenpox" self._verbose = verbose self._lags = lags self._cutoff_time = cutoff_time + self._edge_list = None + self._edge_weights = None + self._all_targets = None if not url: self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/HungaryCP.json" @@ -109,7 +117,7 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: HungaryCPDataLoader) -> None: self._set_total_timestamps() self._set_num_nodes() self._set_num_edges() @@ -117,23 +125,24 @@ def _process_dataset(self) -> None: self._set_edge_weights() self._set_targets_and_features() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: HungaryCPDataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time is not None: self.gdata["total_timestamps"] = min( - len(self._dataset["FX"]), self._cutoff_time + len(self._dataset["FX"]), + self._cutoff_time, ) else: self.gdata["total_timestamps"] = len(self._dataset["FX"]) - def _set_num_nodes(self): - r"""Sets the total number of nodes present in the dataset""" + def _set_num_nodes(self: HungaryCPDataLoader) -> None: + r"""Set the total number of nodes present in the dataset.""" node_set = set() max_node_id = 0 for edge in self._dataset["edges"]: @@ -141,23 +150,25 @@ def _set_num_nodes(self): node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + if max_node_id != len(node_set) - 1: + raise RuntimeError("Node ID labelling is not continuous") + self.gdata["num_nodes"] = len(node_set) - def _set_num_edges(self): - r"""Sets the total number of edges present in the dataset""" + def _set_num_edges(self: HungaryCPDataLoader) -> None: + r"""Set the total number of edges present in the dataset.""" self.gdata["num_edges"] = len(self._dataset["edges"]) - def _set_edges(self): - r"""Sets the edge list of the dataset""" + def _set_edges(self: HungaryCPDataLoader) -> None: + r"""Set the edge list of the dataset.""" self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - def _set_edge_weights(self): - r"""Sets the edge weights of the dataset""" + def _set_edge_weights(self: HungaryCPDataLoader) -> None: + r"""Set the edge weights of the dataset.""" self._edge_weights = np.ones(self.gdata["num_edges"]) - def _set_targets_and_features(self): - r"""Calculates and sets the target and feature attributes""" + def _set_targets_and_features(self: HungaryCPDataLoader) -> None: + r"""Calculate and set the target and feature attributes.""" stacked_target = np.array(self._dataset["FX"]) self._all_targets = [ @@ -165,14 +176,14 @@ def _set_targets_and_features(self): for i in range(self.gdata["total_timestamps"] - self._lags) ] - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: HungaryCPDataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: HungaryCPDataLoader) -> np.ndarray: + r"""Return the edge weights.""" return self._edge_weights - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: HungaryCPDataLoader) -> np.ndarray: + r"""Return the targets for each timestamp.""" return self._all_targets diff --git a/stgraph/dataset/temporal/METRLADataLoader.py b/stgraph/dataset/temporal/metrla_dataloader.py similarity index 65% rename from stgraph/dataset/temporal/METRLADataLoader.py rename to stgraph/dataset/temporal/metrla_dataloader.py index 734202f1..c810d4dd 100644 --- a/stgraph/dataset/temporal/METRLADataLoader.py +++ b/stgraph/dataset/temporal/metrla_dataloader.py @@ -1,28 +1,35 @@ -import torch +"""Temporal dataset for traffic forecasting based on Los Angeles city.""" + +from __future__ import annotations + import numpy as np +import torch -from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset +from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset class METRLADataLoader(STGraphTemporalDataset): + """Temporal dataset for traffic forecasting based on the Los Angeles city.""" + def __init__( - self, + self: METRLADataLoader, verbose: bool = True, - url: str = None, + url: str | None = None, num_timesteps_in: int = 12, num_timesteps_out: int = 12, - cutoff_time: int = None, + cutoff_time: int | None = None, redownload: bool = False, - ): - r"""A traffic forecasting dataset based on Los Angeles Metropolitan traffic conditions. + ) -> None: + r"""Traffic forecasting dataset based on the Los Angeles city.. A dataset for predicting traffic patterns in the Los Angeles Metropolitan area, - comprising traffic data obtained from 207 loop detectors on highways in Los Angeles County. - The dataset includes aggregated 5-minute interval readings spanning a four-month - period from March 2012 to June 2012. + comprising traffic data obtained from 207 loop detectors on highways in Los + Angeles County. The dataset includes aggregated 5-minute interval readings + spanning a four-month period from March 2012 to June 2012. - This class provides functionality for loading, processing, and accessing the METRLA - dataset for use in deep learning tasks such as traffic forecasting. + This class provides functionality for loading, processing, and accessing + the METRLA dataset for use in deep learning tasks such as traffic + forecasting. .. list-table:: gdata :widths: 33 33 33 @@ -54,7 +61,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -89,22 +95,21 @@ def __init__( _all_targets : numpy.ndarray Numpy array of the node target value """ - super().__init__() - if type(num_timesteps_in) != int: + if not isinstance(num_timesteps_in, int): raise TypeError("num_timesteps_in must be of type int") if num_timesteps_in < 0: raise ValueError("num_timesteps_in must be a positive integer") - if type(num_timesteps_out) != int: + if not isinstance(num_timesteps_out, int): raise TypeError("num_timesteps_out must be of type int") if num_timesteps_out < 0: raise ValueError("num_timesteps_out must be a positive integer") - if cutoff_time != None and type(cutoff_time) != int: + if cutoff_time is not None and not isinstance(cutoff_time, int): raise TypeError("cutoff_time must be of type int") - if cutoff_time != None and cutoff_time < 0: + if cutoff_time is not None and cutoff_time < 0: raise ValueError("cutoff_time must be a positive integer") self.name = "METRLA" @@ -112,6 +117,10 @@ def __init__( self._num_timesteps_in = num_timesteps_in self._num_timesteps_out = num_timesteps_out self._cutoff_time = cutoff_time + self._edge_list = None + self._edge_weights = None + self._all_features = None + self._all_targets = None if not url: self._url = "https://raw.githubusercontent.com/bfGraph/STGraph-Datasets/main/METRLA.json" @@ -129,7 +138,7 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: METRLADataLoader) -> None: self._set_total_timestamps() self._set_num_nodes() self._set_num_edges() @@ -137,23 +146,24 @@ def _process_dataset(self) -> None: self._set_edge_weights() self._set_targets_and_features() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: METRLADataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time is not None: self.gdata["total_timestamps"] = min( - self._dataset["time_periods"], self._cutoff_time + self._dataset["time_periods"], + self._cutoff_time, ) else: self.gdata["total_timestamps"] = self._dataset["time_periods"] - def _set_num_nodes(self): - r"""Sets the total number of nodes present in the dataset""" + def _set_num_nodes(self: METRLADataLoader) -> None: + r"""Set the total number of nodes present in the dataset.""" node_set = set() max_node_id = 0 for edge in self._dataset["edges"]: @@ -161,19 +171,21 @@ def _set_num_nodes(self): node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + if max_node_id != len(node_set) - 1: + raise RuntimeError("Node ID labelling is not continuous") + self.gdata["num_nodes"] = len(node_set) - def _set_num_edges(self): - r"""Sets the total number of edges present in the dataset""" + def _set_num_edges(self: METRLADataLoader) -> None: + r"""Set the total number of edges present in the dataset.""" self.gdata["num_edges"] = len(self._dataset["edges"]) - def _set_edges(self): - r"""Sets the edge list of the dataset""" + def _set_edges(self: METRLADataLoader) -> None: + r"""Set the edge list of the dataset.""" self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - def _set_edge_weights(self): - r"""Sets the edge weights of the dataset""" + def _set_edge_weights(self: METRLADataLoader) -> None: + r"""Set the edge weights of the dataset.""" edges = self._dataset["edges"] edge_weights = self._dataset["weights"] comb_edge_list = [ @@ -182,53 +194,53 @@ def _set_edge_weights(self): comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - def _set_targets_and_features(self): - r"""Calculates and sets the target and feature attributes""" - X = [] - - for timestamp in range(self.gdata["total_timestamps"]): - X.append(self._dataset[str(timestamp)]) + def _set_targets_and_features(self: METRLADataLoader) -> None: + r"""Calculate and set the target and feature attributes.""" + x = [ + self._dataset[str(timestamp)] + for timestamp in range(self.gdata["total_timestamps"]) + ] - X = np.array(X) - X = X.transpose((1, 2, 0)) - X = X.astype(np.float32) + x = np.array(x).transpose(1, 2, 0).astype(np.float32) + # x = x.transpose((1, 2, 0)) + # x = x.astype(np.float32) # Normalise as in DCRNN paper (via Z-Score Method) - means = np.mean(X, axis=(0, 2)) - X = X - means.reshape(1, -1, 1) - stds = np.std(X, axis=(0, 2)) - X = X / stds.reshape(1, -1, 1) + means = np.mean(x, axis=(0, 2)) + x = x - means.reshape(1, -1, 1) + stds = np.std(x, axis=(0, 2)) + x = x / stds.reshape(1, -1, 1) - X = torch.from_numpy(X) + x = torch.from_numpy(x) indices = [ (i, i + (self._num_timesteps_in + self._num_timesteps_out)) for i in range( - X.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1 + x.shape[2] - (self._num_timesteps_in + self._num_timesteps_out) + 1, ) ] # Generate observations features, target = [], [] for i, j in indices: - features.append((X[:, :, i : i + self._num_timesteps_in]).numpy()) - target.append((X[:, 0, i + self._num_timesteps_in : j]).numpy()) + features.append((x[:, :, i : i + self._num_timesteps_in]).numpy()) + target.append((x[:, 0, i + self._num_timesteps_in : j]).numpy()) self._all_features = np.array(features) self._all_targets = np.array(target) - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: METRLADataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: METRLADataLoader) -> np.ndarray: + r"""Return the edge weights.""" return self._edge_weights - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: METRLADataLoader) -> np.ndarray: + r"""Return the targets for each timestamp.""" return self._all_targets - def get_all_features(self): - r"""Returns the features for each timestamp""" + def get_all_features(self: METRLADataLoader) -> np.ndarray: + r"""Return the features for each timestamp.""" return self._all_features diff --git a/stgraph/dataset/temporal/MontevideoBusDataLoader.py b/stgraph/dataset/temporal/montevideobus_dataloader.py similarity index 70% rename from stgraph/dataset/temporal/MontevideoBusDataLoader.py rename to stgraph/dataset/temporal/montevideobus_dataloader.py index dd6c2e9f..fee3b54c 100644 --- a/stgraph/dataset/temporal/MontevideoBusDataLoader.py +++ b/stgraph/dataset/temporal/montevideobus_dataloader.py @@ -1,18 +1,24 @@ +"""Temporal dataset of inflow passenger at bus stop level from Montevideo city.""" + +from __future__ import annotations + import numpy as np -from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset +from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset class MontevideoBusDataLoader(STGraphTemporalDataset): + r"""Temporal dataset of inflow passenger at bus stop level from Montevideo city.""" + def __init__( - self, + self: MontevideoBusDataLoader, verbose: bool = False, - url: str = None, + url: str | None = None, lags: int = 4, - cutoff_time: int = None, + cutoff_time: int | None = None, redownload: bool = False, ) -> None: - r"""A dataset of inflow passenger at bus stop level from Montevideo city. + r"""Temporal dataset of inflow passenger at bus stop level from Montevideo city. This dataset compiles hourly passenger inflow data for 11 key bus lines in Montevideo, Uruguay, during October 2020. Focused on routes to the city @@ -22,7 +28,8 @@ def __init__( Metropolitan Transportation System (STM). This class provides functionality for loading, processing, and accessing the - Montevideo Bus dataset for use in deep learning tasks such as passenger inflow prediction. + Montevideo Bus dataset for use in deep learning tasks such as + passenger inflow prediction. .. list-table:: gdata :widths: 33 33 33 @@ -54,7 +61,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -85,19 +91,18 @@ def __init__( _all_features : numpy.ndarray Numpy array of the node feature value """ - super().__init__() - if type(lags) != int: + if not isinstance(lags, int): raise TypeError("lags must be of type int") if lags < 0: raise ValueError("lags must be a positive integer") - if cutoff_time != None and type(cutoff_time) != int: + if cutoff_time is not None and not isinstance(cutoff_time, int): raise TypeError("cutoff_time must be of type int") - if cutoff_time != None and cutoff_time < 0: + if cutoff_time is not None and cutoff_time < 0: raise ValueError("cutoff_time must be a positive integer") - if cutoff_time != None and cutoff_time <= lags: + if cutoff_time is not None and cutoff_time <= lags: raise ValueError("cutoff_time must be greater than lags") self.name = "Montevideo_Bus" @@ -121,7 +126,7 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: MontevideoBusDataLoader) -> None: self._set_total_timestamps() self._set_num_nodes() self._set_num_edges() @@ -130,23 +135,24 @@ def _process_dataset(self) -> None: self._set_features() self._set_targets() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: MontevideoBusDataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time is not None: self.gdata["total_timestamps"] = min( - len(self._dataset["nodes"][0]["y"]), self._cutoff_time + len(self._dataset["nodes"][0]["y"]), + self._cutoff_time, ) else: self.gdata["total_timestamps"] = len(self._dataset["nodes"][0]["y"]) - def _set_num_nodes(self): - r"""Sets the total number of nodes present in the dataset""" + def _set_num_nodes(self: MontevideoBusDataLoader) -> None: + r"""Set the total number of nodes present in the dataset.""" node_set = set() max_node_id = 0 for edge in self._dataset["edges"]: @@ -154,19 +160,21 @@ def _set_num_nodes(self): node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + if max_node_id != len(node_set) - 1: + raise ValueError("Node ID labelling is not continuous") + self.gdata["num_nodes"] = len(node_set) - def _set_num_edges(self): - r"""Sets the total number of edges present in the dataset""" + def _set_num_edges(self: MontevideoBusDataLoader) -> None: + r"""Set the total number of edges present in the dataset.""" self.gdata["num_edges"] = len(self._dataset["edges"]) - def _set_edges(self): - r"""Sets the edge list of the dataset""" + def _set_edges(self: MontevideoBusDataLoader) -> None: + r"""Set the edge list of the dataset.""" self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - def _set_edge_weights(self): - r"""Sets the edge weights of the dataset""" + def _set_edge_weights(self: MontevideoBusDataLoader) -> None: + r"""Set the edge weights of the dataset.""" edges = self._dataset["edges"] edge_weights = self._dataset["weights"] comb_edge_list = [ @@ -175,16 +183,13 @@ def _set_edge_weights(self): comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - def _set_features(self): - r"""Calculates and sets the feature attributes""" - features = [] - - for node in self._dataset["nodes"]: - X = node.get("X") - for feature_var in ["y"]: - features.append( - np.array(X.get(feature_var)[: self.gdata["total_timestamps"]]) - ) + def _set_features(self: MontevideoBusDataLoader) -> None: + r"""Calculate and set the feature attributes.""" + features = [ + np.array(node.get("X").get(feature_var)[: self.gdata["total_timestamps"]]) + for node in self._dataset["nodes"] + for feature_var in ["y"] + ] stacked_features = np.stack(features).T standardized_features = ( @@ -195,11 +200,11 @@ def _set_features(self): [ standardized_features[i : i + self._lags, :].T for i in range(len(standardized_features) - self._lags) - ] + ], ) - def _set_targets(self): - r"""Calculates and sets the target attributes""" + def _set_targets(self: MontevideoBusDataLoader) -> None: + r"""Calculate and set the target attributes.""" targets = [] for node in self._dataset["nodes"]: y = node.get("y")[: self.gdata["total_timestamps"]] @@ -214,21 +219,21 @@ def _set_targets(self): [ standardized_targets[i + self._lags, :].T for i in range(len(standardized_targets) - self._lags) - ] + ], ) - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: MontevideoBusDataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: MontevideoBusDataLoader) -> np.ndarray: + r"""Return the edge weights.""" return self._edge_weights - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: MontevideoBusDataLoader) -> np.ndarray: + r"""Return the targets for each timestamp.""" return self._all_targets - def get_all_features(self): - r"""Returns the features for each timestamp""" + def get_all_features(self: MontevideoBusDataLoader) -> np.ndarray: + r"""Return the features for each timestamp.""" return self._all_features diff --git a/stgraph/dataset/temporal/PedalMeDataLoader.py b/stgraph/dataset/temporal/pedalme_dataloader.py similarity index 67% rename from stgraph/dataset/temporal/PedalMeDataLoader.py rename to stgraph/dataset/temporal/pedalme_dataloader.py index c149a244..56a70fdb 100644 --- a/stgraph/dataset/temporal/PedalMeDataLoader.py +++ b/stgraph/dataset/temporal/pedalme_dataloader.py @@ -1,21 +1,28 @@ +"""Temporal dataset of PedalMe Bicycle deliver orders in London.""" + +from __future__ import annotations + import numpy as np -from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset +from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset class PedalMeDataLoader(STGraphTemporalDataset): + """Temporal dataset of PedalMe Bicycle deliver orders in London.""" + def __init__( - self, + self: PedalMeDataLoader, verbose: bool = False, - url: str = None, + url: str | None = None, lags: int = 4, - cutoff_time: int = None, + cutoff_time: int | None = None, redownload: bool = False, ) -> None: - r"""A dataset of PedalMe Bicycle deliver orders in London. + r"""Temporal dataset of PedalMe Bicycle deliver orders in London. - This class provides functionality for loading, processing, and accessing the PedalMe - dataset for use in deep learning tasks such as node classification. + This class provides functionality for loading, processing, and + accessing the PedalMe dataset for use in deep learning tasks + such as node classification. .. list-table:: gdata :widths: 33 33 33 @@ -46,7 +53,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -75,19 +81,18 @@ def __init__( _all_targets : numpy.ndarray Numpy array of the node target value """ - super().__init__() - if type(lags) != int: + if not isinstance(lags, int): raise TypeError("lags must be of type int") if lags < 0: raise ValueError("lags must be a positive integer") - if cutoff_time != None and type(cutoff_time) != int: + if cutoff_time is not None and not isinstance(cutoff_time, int): raise TypeError("cutoff_time must be of type int") - if cutoff_time != None and cutoff_time < 0: + if cutoff_time is not None and cutoff_time < 0: raise ValueError("cutoff_time must be a positive integer") - if cutoff_time != None and cutoff_time <= lags: + if cutoff_time is not None and cutoff_time <= lags: raise ValueError("cutoff_time must be greater than lags") self.name = "PedalMe" @@ -111,7 +116,7 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: PedalMeDataLoader) -> None: self._set_total_timestamps() self._set_num_nodes() self._set_num_edges() @@ -120,23 +125,24 @@ def _process_dataset(self) -> None: self._set_targets() self._set_features() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: PedalMeDataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time is not None: self.gdata["total_timestamps"] = min( - self._dataset["time_periods"], self._cutoff_time + self._dataset["time_periods"], + self._cutoff_time, ) else: self.gdata["total_timestamps"] = self._dataset["time_periods"] - def _set_num_nodes(self): - r"""Sets the total number of nodes present in the dataset""" + def _set_num_nodes(self: PedalMeDataLoader) -> None: + r"""Set the total number of nodes present in the dataset.""" node_set = set() max_node_id = 0 for edge in self._dataset["edges"]: @@ -144,19 +150,21 @@ def _set_num_nodes(self): node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + if max_node_id != len(node_set) - 1: + raise ValueError("Node ID labelling is not continuous") + self.gdata["num_nodes"] = len(node_set) - def _set_num_edges(self): - r"""Sets the total number of edges present in the dataset""" + def _set_num_edges(self: PedalMeDataLoader) -> None: + r"""Set the total number of edges present in the dataset.""" self.gdata["num_edges"] = len(self._dataset["edges"]) - def _set_edges(self): - r"""Sets the edge list of the dataset""" + def _set_edges(self: PedalMeDataLoader) -> None: + r"""Set the edge list of the dataset.""" self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - def _set_edge_weights(self): - r"""Sets the edge weights of the dataset""" + def _set_edge_weights(self: PedalMeDataLoader) -> None: + r"""Set the edge weights of the dataset.""" edges = self._dataset["edges"] edge_weights = self._dataset["weights"] comb_edge_list = [ @@ -165,11 +173,12 @@ def _set_edge_weights(self): comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - def _set_targets(self): - r"""Calculates and sets the target attributes""" - targets = [] - for time in range(self.gdata["total_timestamps"]): - targets.append(np.array(self._dataset[str(time)])) + def _set_targets(self: PedalMeDataLoader) -> None: + r"""Calculate and set the target attributes.""" + targets = [ + np.array(self._dataset[str(time)]) + for time in range(self.gdata["total_timestamps"]) + ] stacked_target = np.stack(targets) @@ -177,21 +186,20 @@ def _set_targets(self): [ stacked_target[i + self._lags, :].T for i in range(stacked_target.shape[0] - self._lags) - ] + ], ) - def _set_features(self): - # TODO: + def _set_features(self: PedalMeDataLoader) -> None: pass - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: PedalMeDataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: PedalMeDataLoader) -> np.ndarray: + r"""Return the edge weights.""" return self._edge_weights - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: PedalMeDataLoader) -> np.ndarray: + r"""Return the targets for each timestamp.""" return self._all_targets diff --git a/stgraph/dataset/temporal/STGraphTemporalDataset.py b/stgraph/dataset/temporal/stgraph_temporal_dataset.py similarity index 61% rename from stgraph/dataset/temporal/STGraphTemporalDataset.py rename to stgraph/dataset/temporal/stgraph_temporal_dataset.py index 287e4659..e7712461 100644 --- a/stgraph/dataset/temporal/STGraphTemporalDataset.py +++ b/stgraph/dataset/temporal/stgraph_temporal_dataset.py @@ -1,24 +1,28 @@ -"""Base class for all STGraph temporal graph datasets""" +"""Base class for all STGraph temporal graph datasets.""" + +from __future__ import annotations from rich.console import Console -from stgraph.dataset.STGraphDataset import STGraphDataset + +from stgraph.dataset.stgraph_dataset import STGraphDataset console = Console() class STGraphTemporalDataset(STGraphDataset): - r"""Base class for temporal graph datasets + r"""Base class for temporal graph datasets. This class is a subclass of ``STGraphDataset`` and provides the base structure for handling temporal graph datasets. """ - def __init__(self) -> None: + def __init__(self: STGraphTemporalDataset) -> None: + r"""Provide the base structure for handling temporal graph datasets.""" super().__init__() self._init_graph_data() - def _init_graph_data(self) -> dict: + def _init_graph_data(self: STGraphTemporalDataset) -> dict: r"""Initialize graph meta data for a temporal dataset. The ``num_nodes``, ``num_edges``, ``total_timestamps`` keys are set to value 0 diff --git a/stgraph/dataset/temporal/WikiMathDataLoader.py b/stgraph/dataset/temporal/wikimath_dataloader.py similarity index 65% rename from stgraph/dataset/temporal/WikiMathDataLoader.py rename to stgraph/dataset/temporal/wikimath_dataloader.py index ea972146..d1a2034e 100644 --- a/stgraph/dataset/temporal/WikiMathDataLoader.py +++ b/stgraph/dataset/temporal/wikimath_dataloader.py @@ -1,27 +1,35 @@ +r"""Temporal dataset of vital mathematical articles sourced from Wikipedia.""" + +from __future__ import annotations + import numpy as np -from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset +from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset class WikiMathDataLoader(STGraphTemporalDataset): + r"""Temporal dataset of vital mathematical articles sourced from Wikipedia.""" + def __init__( - self, + self: WikiMathDataLoader, verbose: bool = False, - url: str = None, + url: str | None = None, lags: int = 8, - cutoff_time: int = None, + cutoff_time: int | None = None, redownload: bool = False, ) -> None: - r"""A dataset of vital mathematical articles sourced from Wikipedia. + r"""Temporal dataset of vital mathematical articles sourced from Wikipedia. The graph dataset is static, with vertices representing Wikipedia pages and - edges representing links. The graph is both directed and weighted, where the weights - indicate the number of links originating from the source page connecting - to the target page. The target is the daily user visits to the Wikipedia pages - between March 16th 2019 and March 15th 2021 which results in 731 periods. + edges representing links. The graph is both directed and weighted, where + the weights indicate the number of links originating from the source page + connecting to the target page. The target is the daily user visits to the + Wikipedia pages between March 16th 2019 and March 15th 2021 which results + in 731 periods. - This class provides functionality for loading, processing, and accessing the Hungary - Chickenpox dataset for use in deep learning tasks such as County level case count prediction. + This class provides functionality for loading, processing, and accessing + the Hungary Chickenpox dataset for use in deep learning tasks such as County + level case count prediction. .. list-table:: gdata :widths: 33 33 33 @@ -52,7 +60,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -83,18 +90,16 @@ def __init__( """ super().__init__() - if type(lags) != int: + if not isinstance(lags, int): raise TypeError("lags must be of type int") if lags < 0: raise ValueError("lags must be a positive integer") - if cutoff_time != None and type(cutoff_time) != int: + if cutoff_time is not None and not isinstance(cutoff_time, int): raise TypeError("cutoff_time must be of type int") - if cutoff_time != None and cutoff_time < 0: + if cutoff_time is not None and cutoff_time < 0: raise ValueError("cutoff_time must be a positive integer") - # TODO: Add check for cutoff_time <= lags - self.name = "WikiMath" self._verbose = verbose self._lags = lags @@ -116,7 +121,7 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: WikiMathDataLoader) -> None: self._set_total_timestamps() self._set_num_nodes() self._set_num_edges() @@ -125,23 +130,24 @@ def _process_dataset(self) -> None: self._set_targets() self._set_features() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: WikiMathDataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time is not None: self.gdata["total_timestamps"] = min( - self._dataset["time_periods"], self._cutoff_time + self._dataset["time_periods"], + self._cutoff_time, ) else: self.gdata["total_timestamps"] = self._dataset["time_periods"] - def _set_num_nodes(self): - r"""Sets the total number of nodes present in the dataset""" + def _set_num_nodes(self: WikiMathDataLoader) -> None: + r"""Set the total number of nodes present in the dataset.""" node_set = set() max_node_id = 0 for edge in self._dataset["edges"]: @@ -149,19 +155,21 @@ def _set_num_nodes(self): node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + if max_node_id != len(node_set) - 1: + raise ValueError("Node ID labelling is not continuous") + self.gdata["num_nodes"] = len(node_set) - def _set_num_edges(self): - r"""Sets the total number of edges present in the dataset""" + def _set_num_edges(self: WikiMathDataLoader) -> None: + r"""Set the total number of edges present in the dataset.""" self.gdata["num_edges"] = len(self._dataset["edges"]) - def _set_edges(self): - r"""Sets the edge list of the dataset""" + def _set_edges(self: WikiMathDataLoader) -> None: + r"""Set the edge list of the dataset.""" self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - def _set_edge_weights(self): - r"""Sets the edge weights of the dataset""" + def _set_edge_weights(self: WikiMathDataLoader) -> None: + r"""Set the edge weights of the dataset.""" edges = self._dataset["edges"] edge_weights = self._dataset["weights"] comb_edge_list = [ @@ -170,31 +178,32 @@ def _set_edge_weights(self): comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - def _set_targets(self): - r"""Calculates and sets the target attributes""" - targets = [] - for time in range(self.gdata["total_timestamps"]): - targets.append(np.array(self._dataset[str(time)]["y"])) + def _set_targets(self: WikiMathDataLoader) -> None: + r"""Calculate and set the target attributes.""" + targets = [ + np.array(self._dataset[str(time)]["y"]) + for time in range(self.gdata["total_timestamps"]) + ] + stacked_target = np.stack(targets) standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( np.std(stacked_target, axis=0) + 10**-10 ) self._all_targets = np.array( - [standardized_target[i, :].T for i in range(len(targets))] + [standardized_target[i, :].T for i in range(len(targets))], ) - def _set_features(self): - # TODO: + def _set_features(self: WikiMathDataLoader) -> None: pass - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: WikiMathDataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: WikiMathDataLoader) -> np.ndarray: + r"""Return the edge weights.""" return self._edge_weights - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: WikiMathDataLoader) -> np.ndarray: + r"""Return the targets for each timestamp.""" return self._all_targets diff --git a/stgraph/dataset/temporal/WindmillOutputDataLoader.py b/stgraph/dataset/temporal/windmilloutput_dataloader.py similarity index 67% rename from stgraph/dataset/temporal/WindmillOutputDataLoader.py rename to stgraph/dataset/temporal/windmilloutput_dataloader.py index 7c8efb32..326649ff 100644 --- a/stgraph/dataset/temporal/WindmillOutputDataLoader.py +++ b/stgraph/dataset/temporal/windmilloutput_dataloader.py @@ -1,22 +1,29 @@ +r"""Temporal dataset of hourly energy output of windmills.""" + +from __future__ import annotations + import numpy as np -from stgraph.dataset.temporal.STGraphTemporalDataset import STGraphTemporalDataset +from stgraph.dataset.temporal.stgraph_temporal_dataset import STGraphTemporalDataset class WindmillOutputDataLoader(STGraphTemporalDataset): + r"""Temporal dataset of hourly energy output of windmills.""" + def __init__( - self, + self: WindmillOutputDataLoader, verbose: bool = False, - url: str = None, + url: str | None = None, lags: int = 8, - cutoff_time: int = None, + cutoff_time: int | None = None, size: str = "large", redownload: bool = False, ) -> None: - r"""Hourly energy output of windmills from a European country for more than 2 years. + r"""Temporal dataset of hourly energy output of windmills. - This class provides functionality for loading, processing, and accessing the Windmill - output dataset for use in deep learning such as regression tasks. + This class provides functionality for loading, processing, and accessing + the Windmill output dataset for use in deep learning such as + regression tasks. .. list-table:: gdata for Windmill Output Small :widths: 33 33 33 @@ -69,7 +76,6 @@ def __init__( Parameters ---------- - verbose : bool, optional Flag to control whether to display verbose info (default is False) url : str, optional @@ -102,23 +108,22 @@ def __init__( """ super().__init__() - if type(lags) != int: + if not isinstance(lags, int): raise TypeError("lags must be of type int") if lags < 0: raise ValueError("lags must be a positive integer") - if cutoff_time != None and type(cutoff_time) != int: + if cutoff_time is not None and not isinstance(cutoff_time, int): raise TypeError("cutoff_time must be of type int") - if cutoff_time != None and cutoff_time < 0: + if cutoff_time is not None and cutoff_time < 0: raise ValueError("cutoff_time must be a positive integer") - # TODO: Added check for cutoff <= lags - - if type(size) != str: + if not isinstance(size, str): raise TypeError("size must be of type string") if size not in ["large", "medium", "small"]: raise ValueError( - "size must take either of the following values : large, medium or small" + "size must take either of the following values : " + "large, medium or small", ) self.name = "WindMill_" + size @@ -127,18 +132,20 @@ def __init__( self._cutoff_time = cutoff_time self._size = size + size_urls = { + "large": "https://graphmining.ai/temporal_datasets/windmill_output.json", + "medium": "https://graphmining.ai/temporal_datasets/windmill_output_medium.json", + "small": "https://graphmining.ai/temporal_datasets/windmill_output_small.json", + } + if not url: - if size == "large": - self._url = ( - "https://graphmining.ai/temporal_datasets/windmill_output.json" - ) - elif size == "medium": - self._url = "https://graphmining.ai/temporal_datasets/windmill_output_medium.json" - elif size == "small": - self._url = "https://graphmining.ai/temporal_datasets/windmill_output_small.json" + self._url = size_urls[self._size] else: self._url = url + if redownload and self._has_dataset_cache(): + self._delete_cached_dataset() + if self._has_dataset_cache(): self._load_dataset() else: @@ -147,7 +154,7 @@ def __init__( self._process_dataset() - def _process_dataset(self) -> None: + def _process_dataset(self: WindmillOutputDataLoader) -> None: self._set_total_timestamps() self._set_num_nodes() self._set_num_edges() @@ -155,23 +162,24 @@ def _process_dataset(self) -> None: self._set_edge_weights() self._set_targets() - def _set_total_timestamps(self) -> None: - r"""Sets the total timestamps present in the dataset + def _set_total_timestamps(self: WindmillOutputDataLoader) -> None: + r"""Set the total timestamps present in the dataset. It sets the total timestamps present in the dataset into the gdata attribute dictionary. It is the minimum of the cutoff time choosen by the user and the total time periods present in the original dataset. """ - if self._cutoff_time != None: + if self._cutoff_time is not None: self.gdata["total_timestamps"] = min( - self._dataset["time_periods"], self._cutoff_time + self._dataset["time_periods"], + self._cutoff_time, ) else: self.gdata["total_timestamps"] = self._dataset["time_periods"] - def _set_num_nodes(self): - r"""Sets the total number of nodes present in the dataset""" + def _set_num_nodes(self: WindmillOutputDataLoader) -> None: + r"""Set the total number of nodes present in the dataset.""" node_set = set() max_node_id = 0 for edge in self._dataset["edges"]: @@ -179,19 +187,21 @@ def _set_num_nodes(self): node_set.add(edge[1]) max_node_id = max(max_node_id, edge[0], edge[1]) - assert max_node_id == len(node_set) - 1, "Node ID labelling is not continuous" + if max_node_id != len(node_set) - 1: + raise ValueError("Node ID labelling is not continuous") + self.gdata["num_nodes"] = len(node_set) - def _set_num_edges(self): - r"""Sets the total number of edges present in the dataset""" + def _set_num_edges(self: WindmillOutputDataLoader) -> None: + r"""Set the total number of edges present in the dataset.""" self.gdata["num_edges"] = len(self._dataset["edges"]) - def _set_edges(self): - r"""Sets the edge list of the dataset""" + def _set_edges(self: WindmillOutputDataLoader) -> None: + r"""Set the edge list of the dataset.""" self._edge_list = [(edge[0], edge[1]) for edge in self._dataset["edges"]] - def _set_edge_weights(self): - r"""Sets the edge weights of the dataset""" + def _set_edge_weights(self: WindmillOutputDataLoader) -> None: + r"""Set the edge weights of the dataset.""" edges = self._dataset["edges"] edge_weights = self._dataset["weights"] comb_edge_list = [ @@ -200,8 +210,8 @@ def _set_edge_weights(self): comb_edge_list.sort(key=lambda x: (x[1], x[0])) self._edge_weights = np.array([edge_det[2] for edge_det in comb_edge_list]) - def _set_targets(self): - r"""Calculates and sets the target attributes""" + def _set_targets(self: WindmillOutputDataLoader) -> None: + r"""Calculate and sets the target attributes.""" stacked_target = np.stack(self._dataset["block"]) standardized_target = (stacked_target - np.mean(stacked_target, axis=0)) / ( np.std(stacked_target, axis=0) + 10**-10 @@ -210,18 +220,17 @@ def _set_targets(self): standardized_target[i, :].T for i in range(self.gdata["total_timestamps"]) ] - def _set_features(self): - # TODO: + def _set_features(self: WindmillOutputDataLoader) -> None: pass - def get_edges(self): - r"""Returns the edge list""" + def get_edges(self: WindmillOutputDataLoader) -> list: + r"""Return the edge list.""" return self._edge_list - def get_edge_weights(self): - r"""Returns the edge weights""" + def get_edge_weights(self: WindmillOutputDataLoader) -> np.ndarray: + r"""Return the edge weight.""" return self._edge_weights - def get_all_targets(self): - r"""Returns the targets for each timestamp""" + def get_all_targets(self: WindmillOutputDataLoader) -> list: + r"""Return the targets for each timestamp.""" return self._all_targets