Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 125 additions & 8 deletions graphormer/data/algos.pyx
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# distutils: language = c++

import cython
from cython.parallel cimport prange, parallel
cimport numpy
import numpy

def floyd_warshall(adjacency_matrix):
cdef floyd_warshall(adjacency_matrix):

(nrows, ncols) = adjacency_matrix.shape
assert nrows == ncols
Expand All @@ -15,7 +16,7 @@ def floyd_warshall(adjacency_matrix):
adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True)
assert adj_mat_copy.flags['C_CONTIGUOUS']
cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy
cdef numpy.ndarray[long, ndim=2, mode='c'] path = numpy.zeros([n, n], dtype=numpy.int64)
cdef numpy.ndarray[long, ndim=2, mode='c'] path = - numpy.ones([n, n], dtype=numpy.int64)

cdef unsigned int i, j, k
cdef long M_ij, M_ik, cost_ikkj
Expand Down Expand Up @@ -54,15 +55,16 @@ def floyd_warshall(adjacency_matrix):
return M, path


def get_all_edges(path, i, j):
cdef unsigned int k = path[i][j]
if k == 0:
cdef get_all_edges(path, i, j):
cdef int k = path[i][j]
if k == -1:
return []
else:
return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j)

def fw_spatial_pos_and_edge_input(adj, edge_feat, max_dist=5):

def gen_edge_input(max_dist, path, edge_feat):
shortest_path_result, path = floyd_warshall(adj)

(nrows, ncols) = path.shape
assert nrows == ncols
Expand All @@ -84,8 +86,123 @@ def gen_edge_input(max_dist, path, edge_feat):
if path_copy[i][j] == 510:
continue
path = [i] + get_all_edges(path_copy, i, j) + [j]
num_path = len(path) - 1
num_path = min(len(path) - 1, max_dist_copy)
for k in range(num_path):
edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :]

return edge_fea_all
return shortest_path_result, edge_fea_all

from libcpp.vector cimport vector
from libcpp.queue cimport queue
import numpy as np
cimport numpy as np

cdef inline reverse_path(vector[int] path_):
cdef:
int i
int n = path_.size()
for i in range(n//2):
path_[i], path_[n-i-1] = path_[n-i-1], path_[i]
return path_

cdef bfs_shortest_path(vector[vector[int]] adj_list, int startVertex):

cdef:
int n = adj_list.size()
unsigned size, vertex, adjVertex, idx
np.int64_t[:] path = np.full(n, -1, dtype="int64")
np.int64_t[:] dist = np.full(n, -1, dtype="int64")
queue[int] q = queue[int]()
vector[int] adjVertices

dist[startVertex] = 0
path[startVertex] = startVertex
q.push(startVertex)

while not q.empty():
size = q.size()
while size > 0:
size -= 1
vertex = q.front()
q.pop()
adjVertices = adj_list[vertex]
for idx in range(adjVertices.size()):
if dist[adjVertices[idx]] == -1:
dist[adjVertices[idx]] = dist[vertex] + 1
path[adjVertices[idx]] = vertex
q.push(adjVertices[idx])

return dist, path

cdef get_full_path(
np.int64_t[:] path,
np.int64_t[:, :, :] edge_type,
int max_dist,
int cur_node
):

cdef:
unsigned i, j, cur
int n = path.shape[0]
int size = edge_type.shape[2]
np.int64_t[:, :, :] edge_input = np.full(
shape=(n, max_dist, size),
fill_value=-1,
dtype="int64"
)
vector[int] path_

for i in range(n):
if i == cur_node:
continue
path_ = vector[int]()
if path[i] == -1:
continue
path_.push_back(i)
cur = i
while path[cur] != cur_node:
path_.push_back(path[cur])
cur = path[cur]
path_.push_back(cur_node)
path_ = reverse_path(path_)
for j in range(min(max_dist, path_.size() - 1)):
edge_input[i, j, :] = edge_type[path_[j], path_[j+1], :]

return edge_input

def bfs_spatial_pos_and_edge_input(
np.int64_t[:, :] adj_matrix,
np.int64_t[:, :, :] edge_type,
int max_dist=5
):

cdef:
int i, j
int n = adj_matrix.shape[0]
int edge_type_shape = edge_type.shape[2]
np.ndarray[np.int64_t, ndim=4, mode='c'] edge_input = np.full(
shape=(n, n, max_dist, edge_type_shape),
fill_value=-1,
dtype="int64"
)
np.int64_t[:, :] spatial_pos = np.full((n ,n), 510, dtype="int64")
cdef vector[vector[int]] adj_list

for i in range(n):
adj_list.push_back(vector[int]())
for j in range(n):
if adj_matrix[i][j] == 1:
adj_list[i].push_back(j)

for i in range(n):
dist, path = bfs_shortest_path(adj_list, i)
edge_input[i] = np.asarray(
get_full_path(
path, edge_type, max_dist, i
)
)
for j in range(n):
if dist[j] != -1:
spatial_pos[i, j] = dist[j]

return np.asarray(spatial_pos), np.asarray(edge_input)
96 changes: 96 additions & 0 deletions graphormer/data/algos_numba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
from numba import njit, prange
import numpy as np

@njit
def bfs_shortest_path(adj_list, startVertex):

n = adj_list.shape[0]
path = np.full(n, -1)
dist = np.full(n, -1)
q = np.full(n, -1)
q_start, q_end = -1, 0

dist[startVertex] = 0
path[startVertex] = startVertex
q[0] = startVertex

while q_start != q_end:
q_start += 1
vertex = q[q_start]
adjVertices = adj_list[vertex]
cur = 0
while adjVertices[cur] != -1:
val = adjVertices[cur]
cur += 1
if dist[val] == -1:
q_end += 1
q[q_end] = val
dist[val] = dist[vertex] + 1
path[val] = vertex

return dist, path

@njit
def get_full_path(path, edge_type, max_dist, cur_node):

n = path.shape[0]
size = edge_type.shape[2]
edge_input = np.full(
shape=(n, max_dist, size),
fill_value=-1,
dtype="int64"
)
path_ = np.full(n, -1)

for i in range(n):
cur_path = 0
if i == cur_node:
continue
if path[i] == -1:
continue
path_[cur_path] = i
cur_path += 1
cur = i
while path[cur] != cur_node:
path_[cur_path] = path[cur]
cur_path += 1
cur = path[cur]
path_[cur_path] = cur_node
for j in range(min(max_dist, cur_path)):
edge_input[i, j, :] = edge_type[
path_[cur_path-j], path_[cur_path-j-1], :]

return edge_input

@njit(parallel=True)
def bfs_numba_spatial_pos_and_edge_input(
adj_matrix,
edge_type,
max_dist=5,
):
n = adj_matrix.shape[0]
edge_type_shape = edge_type.shape[2]
edge_input = np.full(
shape=(n, n, max_dist, edge_type_shape),
fill_value=-1,
)
spatial_pos = np.full((n ,n), 510)
adj_list = np.full((n ,n), -1)

for i in range(n):
cur = 0
for j in range(n):
if adj_matrix[i, j] == 1:
adj_list[i, cur] = j
cur += 1

for i in prange(n):
dist, path = bfs_shortest_path(adj_list, i)
edge_input[i] = get_full_path(
path, edge_type, max_dist, i
)
for j in range(n):
if dist[j] != -1:
spatial_pos[i, j] = dist[j]

return spatial_pos, edge_input
21 changes: 19 additions & 2 deletions graphormer/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,13 +69,30 @@ def __init__(
train_idx = None,
valid_idx = None,
test_idx = None,
max_dist = 5
algo_name = "bfs_numba"
):
super().__init__()
if dataset is not None:
if dataset_source == "dgl":
self.dataset = GraphormerDGLDataset(dataset, seed=seed, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
self.dataset = GraphormerDGLDataset(
dataset,
seed=seed,
train_idx=train_idx,
valid_idx=valid_idx,
test_idx=test_idx,
max_dist=max_dist,
algo_name=algo_name
)
elif dataset_source == "pyg":
self.dataset = GraphormerPYGDataset(dataset, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx)
self.dataset = GraphormerPYGDataset(
dataset,
train_idx=train_idx,
valid_idx=valid_idx,
test_idx=test_idx,
max_dist=max_dist
algo_name=algo_name
)
else:
raise ValueError("customized dataset can only have source pyg or dgl")
elif dataset_source == "dgl":
Expand Down
24 changes: 18 additions & 6 deletions graphormer/data/dgl_datasets/dgl_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from ..wrapper import convert_to_single_emb
from .. import algos
from ..algos_numba import bfs_numba_spatial_pos_and_edge_input
from copy import copy


Expand All @@ -24,8 +25,12 @@ def __init__(self,
train_idx=None,
valid_idx=None,
test_idx=None,
max_dist=5,
algo_name="bfs_numba"
):
self.dataset = dataset
self.max_dist = max_dist
self.algo_name = algo_name
num_data = len(self.dataset)
self.seed = seed
if train_idx is None:
Expand Down Expand Up @@ -104,14 +109,22 @@ def extract_tensor_from_dict(feature: torch.Tensor):
)

def __preprocess_dgl_graph(
self, graph_data: DGLGraph, y: torch.Tensor, idx: int
self, graph_data: DGLGraph, y: torch.Tensor, idx: int,
algo_name="bfs_numba", max_dist=5
) -> PYGGraph:
if not graph_data.is_homogeneous:
raise ValueError(
"Heterogeneous DGLGraph is found. Only homogeneous graph is supported."
)
N = graph_data.num_nodes()

if self.algo_name == "bfs_numba":
process_algo = bfs_numba_spatial_pos_and_edge_input
elif self.algo_name == "bfs_cython":
process_algo = algos.bfs_spatial_pos_and_edge_input
elif self.algo_name == "floyd":
process_algo = algos.fw_spatial_pos_and_edge_input

(
node_int_feature,
node_float_feature,
Expand All @@ -126,18 +139,17 @@ def __preprocess_dgl_graph(
edge_index[0].long(), edge_index[1].long()
] = convert_to_single_emb(edge_int_feature)
dense_adj = graph_data.adj().to_dense().type(torch.int)
shortest_path_result, path = algos.floyd_warshall(dense_adj.numpy())
max_dist = np.amax(shortest_path_result)
edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy())
spatial_pos = torch.from_numpy((shortest_path_result)).long()

spatial_pos, edge_input = process_algo(
dense_adj.numpy().astype("long"), attn_edge_type, self.max_dist)
attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token

pyg_graph = PYGGraph()
pyg_graph.x = convert_to_single_emb(node_int_feature)
pyg_graph.adj = dense_adj
pyg_graph.attn_bias = attn_bias
pyg_graph.attn_edge_type = attn_edge_type
pyg_graph.spatial_pos = spatial_pos
pyg_graph.spatial_pos = torch.from_numpy(spatial_pos).long()
pyg_graph.in_degree = dense_adj.long().sum(dim=1).view(-1)
pyg_graph.out_degree = pyg_graph.in_degree
pyg_graph.edge_input = torch.from_numpy(edge_input).long()
Expand Down
10 changes: 9 additions & 1 deletion graphormer/data/pyg_datasets/pyg_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,12 @@ def __init__(
train_set=None,
valid_set=None,
test_set=None,
max_dist=5,
algo_name="bfs_numba"
):
self.dataset = dataset
self.algo_name = algo_name
self.max_dist = max_dist
if self.dataset is not None:
self.num_data = len(self.dataset)
self.seed = seed
Expand Down Expand Up @@ -98,7 +102,11 @@ def __getitem__(self, idx):
item = self.dataset[idx]
item.idx = idx
item.y = item.y.reshape(-1)
return preprocess_item(item)
return preprocess_item(
item,
algo_name=self.algo_name,
max_dist=self.max_dist
)
else:
raise TypeError("index to a GraphormerPYGDataset can only be an integer.")

Expand Down
Loading