Skip to content

Commit f0f44bd

Browse files
🔁 Merge pull request #120 from bfGraph/develop
🔁 Update official-docs with develop branch
2 parents 5fca3b7 + e2b3128 commit f0f44bd

35 files changed

+847
-64
lines changed

.github/workflows/pytest.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
name: Run Unit Test via Pytest
2+
3+
on: [push]
4+
5+
jobs:
6+
build:
7+
runs-on: ubuntu-latest
8+
strategy:
9+
matrix:
10+
python-version: ["3.8"]
11+
12+
steps:
13+
- uses: actions/checkout@v3
14+
- name: Set up Python ${{ matrix.python-version }}
15+
uses: actions/setup-python@v4
16+
with:
17+
python-version: ${{ matrix.python-version }}
18+
- name: Install dependencies
19+
run: |
20+
python -m pip install --upgrade pip
21+
pip install -e .[dev]
22+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
23+
- name: Test with pytest
24+
run: |
25+
coverage run -m pytest -v -s -p no:warnings
26+
- name: Generate Coverage Report
27+
run: |
28+
coverage report -m

.gitignore

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,7 @@ dist
3030
*.npy
3131
.coverage
3232
dev-stgraph/
33-
htmlconv/
33+
htmlconv/
34+
*.txt
35+
egl_kernel.cu
36+
egl_kernel.ptx

benchmarking/gcn/seastar/utils.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
11
import torch
22

3+
34
def accuracy(logits, labels):
45
_, indices = torch.max(logits, dim=1)
56
correct = torch.sum(indices == labels)
67
return correct.item() * 1.0 / len(labels)
78

9+
810
# GPU | CPU
911
def get_default_device():
10-
1112
if torch.cuda.is_available():
12-
return torch.device('cuda:0')
13+
return torch.device("cuda:0")
1314
else:
14-
return torch.device('cpu')
15+
return torch.device("cpu")
16+
1517

1618
def to_default_device(data):
17-
18-
if isinstance(data,(list,tuple)):
19-
return [to_default_device(x,get_default_device()) for x in data]
20-
21-
return data.to(get_default_device(),non_blocking = True)
19+
if isinstance(data, (list, tuple)):
20+
return [to_default_device(x, get_default_device()) for x in data]
21+
22+
return data.to(get_default_device(), non_blocking=True)
23+
24+
25+
def generate_train_mask(size: int, train_test_split: int) -> list:
26+
cutoff = size * train_test_split
27+
return [1 if i < cutoff else 0 for i in range(size)]
28+
29+
30+
def generate_test_mask(size: int, train_test_split: int) -> list:
31+
cutoff = size * train_test_split
32+
return [0 if i < cutoff else 1 for i in range(size)]

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ dependencies = [
4444
[project.optional-dependencies]
4545
dev = [
4646
"black",
47-
"pytest",
47+
"pytest >= 7.4.3",
48+
"pytest-cov >= 4.1.0",
4849
"tqdm >= 4.64.1",
4950
"build >= 0.10.0",
5051
"gdown >= 4.6.6",

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
cuda_python==12.1.0
2-
Jinja2==3.1.2
2+
Jinja2==3.1.3
33
networkx==3.1
44
numpy==1.23.4
55
pandas==1.3.5
@@ -10,7 +10,7 @@ snoop==0.4.2
1010
sphinx_rtd_theme==1.2.0
1111
sympy==1.11.1
1212
termcolor==2.3.0
13-
tqdm==4.64.1
13+
tqdm==4.66.3
1414
pybind11==2.10.4
1515
build==0.10.0
1616
pynvml==11.5.0

stgraph/benchmark_tools/table.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,30 @@
1+
from __future__ import annotations
2+
13
from rich.console import Console
24
from rich.table import Table
35

4-
console = Console()
56

67
class BenchmarkTable:
78
def __init__(self, title: str, col_name_list: list[str]):
8-
self.title = '\n' + title + '\n'
9+
self.title = "\n" + title + "\n"
910
self.col_name_list = col_name_list
1011
self._table = Table(title=self.title, show_edge=False, style="black bold")
1112
self._num_cols = len(col_name_list)
1213
self._num_rows = 0
13-
14+
1415
self._table_add_columns()
15-
16+
1617
def _table_add_columns(self):
1718
for col_name in self.col_name_list:
1819
self._table.add_column(col_name, justify="left")
19-
20+
2021
def add_row(self, values: list):
2122
values_str = tuple([str(val) for val in values])
2223
self._table.add_row(*values_str)
23-
24-
def display(self):
25-
console.print(self._table)
24+
25+
def display(self, output_file=None):
26+
if not output_file:
27+
console = Console()
28+
else:
29+
console = Console(file=output_file)
30+
console.print(self._table)

stgraph/dataset/dynamic/england_covid_dataloader.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ class EnglandCovidDataLoader(STGraphDynamicDataset):
5353
The name of the dataset.
5454
gdata : dict
5555
Graph meta data.
56+
5657
"""
5758

5859
def __init__(
@@ -65,6 +66,16 @@ def __init__(
6566
"""COVID-19 cases in England's NUTS3 regions."""
6667
super().__init__()
6768

69+
if not isinstance(lags, int):
70+
raise TypeError("lags must be of type int")
71+
if lags < 0:
72+
raise ValueError("lags must be a positive integer")
73+
74+
if cutoff_time is not None and not isinstance(cutoff_time, int):
75+
raise TypeError("cutoff_time must be of type int")
76+
if cutoff_time is not None and cutoff_time < 0:
77+
raise ValueError("cutoff_time must be a positive integer")
78+
6879
self.name = "England_COVID"
6980
self._url = "https://raw.githubusercontent.com/benedekrozemberczki/pytorch_geometric_temporal/master/dataset/england_covid.json"
7081
self._verbose = verbose

stgraph/dataset/static/cora_dataloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class CoraDataLoader(STGraphStaticDataset):
6161
The name of the dataset.
6262
gdata : dict
6363
Graph meta data.
64+
6465
"""
6566

6667
def __init__(

stgraph/dataset/stgraph_dataset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ def __init__(self: STGraphDataset) -> None:
6767
6868
_load_dataset()
6969
Loads the dataset from cache
70+
7071
"""
7172
self.name = ""
7273
self.gdata = {}
@@ -106,6 +107,7 @@ def _has_dataset_cache(self: STGraphDataset) -> bool:
106107
# The dataset is cached, continue cached operations
107108
else:
108109
# The dataset is not cached, continue load and save operations
110+
109111
"""
110112
user_home_dir = os.path.expanduser("~")
111113
stgraph_dir = user_home_dir + "/.stgraph"
@@ -128,6 +130,7 @@ def _get_cache_file_path(self: STGraphDataset) -> str:
128130
-------
129131
str
130132
The absolute path of the cached dataset file
133+
131134
"""
132135
user_home_dir = os.path.expanduser("~")
133136
stgraph_dir = user_home_dir + "/.stgraph"

stgraph/dataset/temporal/hungarycp_dataloader.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class HungaryCPDataLoader(STGraphTemporalDataset):
5858
The name of the dataset.
5959
gdata : dict
6060
Graph meta data.
61+
6162
"""
6263

6364
def __init__(

0 commit comments

Comments
 (0)