From 86284016b852f43a8f0afd093158b21625e94f9c Mon Sep 17 00:00:00 2001 From: Yuanhao Geng <1801214626@qq.com> Date: Sun, 2 Jan 2022 08:29:05 +0000 Subject: [PATCH 1/3] implement bfs algo --- graphormer/data/algos.pyx | 168 +++++++++++++++++++- graphormer/data/algos_numba.py | 96 +++++++++++ graphormer/data/dgl_datasets/dgl_dataset.py | 20 ++- graphormer/data/wrapper.py | 21 ++- 4 files changed, 284 insertions(+), 21 deletions(-) create mode 100644 graphormer/data/algos_numba.py diff --git a/graphormer/data/algos.pyx b/graphormer/data/algos.pyx index ec90118..3dfbbc5 100644 --- a/graphormer/data/algos.pyx +++ b/graphormer/data/algos.pyx @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +# distutils: language = c++ import cython from cython.parallel cimport prange, parallel cimport numpy import numpy -def floyd_warshall(adjacency_matrix): +cdef floyd_warshall(adjacency_matrix): (nrows, ncols) = adjacency_matrix.shape assert nrows == ncols @@ -15,7 +16,7 @@ def floyd_warshall(adjacency_matrix): adj_mat_copy = adjacency_matrix.astype(long, order='C', casting='safe', copy=True) assert adj_mat_copy.flags['C_CONTIGUOUS'] cdef numpy.ndarray[long, ndim=2, mode='c'] M = adj_mat_copy - cdef numpy.ndarray[long, ndim=2, mode='c'] path = numpy.zeros([n, n], dtype=numpy.int64) + cdef numpy.ndarray[long, ndim=2, mode='c'] path = - numpy.ones([n, n], dtype=numpy.int64) cdef unsigned int i, j, k cdef long M_ij, M_ik, cost_ikkj @@ -54,15 +55,16 @@ def floyd_warshall(adjacency_matrix): return M, path -def get_all_edges(path, i, j): - cdef unsigned int k = path[i][j] - if k == 0: +cdef get_all_edges(path, i, j): + cdef int k = path[i][j] + if k == -1: return [] else: return get_all_edges(path, i, k) + [k] + get_all_edges(path, k, j) +def fw_spatial_pos_and_edge_input(adj, edge_feat, max_dist=5): -def gen_edge_input(max_dist, path, edge_feat): + shortest_path_result, path = floyd_warshall(adj) (nrows, ncols) = path.shape assert nrows == ncols @@ -84,8 +86,158 @@ def gen_edge_input(max_dist, path, edge_feat): if path_copy[i][j] == 510: continue path = [i] + get_all_edges(path_copy, i, j) + [j] - num_path = len(path) - 1 + num_path = min(len(path) - 1, max_dist_copy) for k in range(num_path): edge_fea_all[i, j, k, :] = edge_feat_copy[path[k], path[k+1], :] - return edge_fea_all + return shortest_path_result, edge_fea_all + +from libcpp.vector cimport vector +from libcpp.queue cimport queue +import numpy as np +cimport numpy as np + +cdef inline reverse_path(vector[int] path_): + cdef: + int i + int n = path_.size() + for i in range(n//2): + path_[i], path_[n-i-1] = path_[n-i-1], path_[i] + return path_ + +cdef bfs_shortest_path(vector[vector[int]] adj_list, int startVertex): + + cdef: + int n = adj_list.size() + unsigned size, vertex, adjVertex, idx + np.int64_t[:] path = np.full(n, -1, dtype="int64") + np.int64_t[:] dist = np.full(n, -1, dtype="int64") + queue[int] q = queue[int]() + vector[int] adjVertices + + dist[startVertex] = 0 + path[startVertex] = startVertex + q.push(startVertex) + + while not q.empty(): + size = q.size() + while size > 0: + size -= 1 + vertex = q.front() + q.pop() + adjVertices = adj_list[vertex] + for idx in range(adjVertices.size()): + if dist[adjVertices[idx]] == -1: + dist[adjVertices[idx]] = dist[vertex] + 1 + path[adjVertices[idx]] = vertex + q.push(adjVertices[idx]) + + return dist, path + +cdef get_full_path( + np.int64_t[:] path, + np.int64_t[:, :, :] edge_type, + int max_dist, + int cur_node +): + + cdef: + unsigned i, j, cur + int n = 
path.shape[0] + int size = edge_type.shape[2] + np.int64_t[:, :, :] edge_input = np.full( + shape=(n, max_dist, size), + fill_value=-1, + dtype="int64" + ) + vector[int] path_ + + for i in range(n): + if i == cur_node: + continue + path_ = vector[int]() + if path[i] == -1: + continue + path_.push_back(i) + cur = i + while path[cur] != cur_node: + path_.push_back(path[cur]) + cur = path[cur] + path_.push_back(cur_node) + path_ = reverse_path(path_) + for j in range(min(max_dist, path_.size() - 1)): + edge_input[i, j, :] = edge_type[path_[j], path_[j+1], :] + + return edge_input + +def bfs_spatial_pos_and_edge_input( + np.int64_t[:, :] adj_matrix, + np.int64_t[:, :, :] edge_type, + int max_dist=5 +): + + cdef: + int i, j + int n = adj_matrix.shape[0] + int edge_type_shape = edge_type.shape[2] + np.ndarray[np.int64_t, ndim=4, mode='c'] edge_input = np.full( + shape=(n, n, max_dist, edge_type_shape), + fill_value=-1, + dtype="int64" + ) + np.int64_t[:, :] spatial_pos = np.full((n ,n), 510, dtype="int64") + cdef vector[vector[int]] adj_list + + for i in range(n): + adj_list.push_back(vector[int]()) + for j in range(n): + if adj_matrix[i][j] == 1: + adj_list[i].push_back(j) + + for i in range(n): + dist, path = bfs_shortest_path(adj_list, i) + edge_input[i] = np.asarray( + get_full_path( + path, edge_type, max_dist, i + ) + ) + for j in range(n): + if dist[j] != -1: + spatial_pos[i, j] = dist[j] + + return np.asarray(spatial_pos), np.asarray(edge_input) + +def bfs_target_spatial_pos_and_edge_input( + np.int64_t[:, :] adj_matrix, + np.int64_t[:, :, :] edge_type, + int max_dist=5, +): + + cdef: + int i, j + int n = adj_matrix.shape[0] + int edge_type_shape = edge_type.shape[2] + np.ndarray[np.int64_t, ndim=4, mode='c'] edge_input = np.full( + shape=(n, n, max_dist, edge_type_shape), + fill_value=-1, + dtype="int64" + ) + np.int64_t[:, :] spatial_pos = np.full((n ,n), 510, dtype="int64") + cdef vector[vector[int]] adj_list + + for i in range(n): + adj_list.push_back(vector[int]()) + for j in range(i+1): + if adj_matrix[i][j] == 1: + adj_list[i].push_back(j) + for j in range(i): + if adj_matrix[j][i] == 1: + adj_list[j].push_back(i) + dist, path = bfs_shortest_path(adj_list, i) + edge_input[i, :i+1, :, :] = np.asarray(get_full_path( + path, edge_type[:i+1, :i+1, :], max_dist, i)) + for j in range(i+1): + if dist[j] != -1: + spatial_pos[i, j] = dist[j] + + return np.asarray(spatial_pos), np.asarray(edge_input) diff --git a/graphormer/data/algos_numba.py b/graphormer/data/algos_numba.py new file mode 100644 index 0000000..eda25d8 --- /dev/null +++ b/graphormer/data/algos_numba.py @@ -0,0 +1,96 @@ +from numba import njit, prange +import numpy as np + +@njit +def bfs_shortest_path(adj_list, startVertex): + + n = adj_list.shape[0] + path = np.full(n, -1) + dist = np.full(n, -1) + q = np.full(n, -1) + q_start, q_end = -1, 0 + + dist[startVertex] = 0 + path[startVertex] = startVertex + q[0] = startVertex + + while q_start != q_end: + q_start += 1 + vertex = q[q_start] + adjVertices = adj_list[vertex] + cur = 0 + while adjVertices[cur] != -1: + val = adjVertices[cur] + cur += 1 + if dist[val] == -1: + q_end += 1 + q[q_end] = val + dist[val] = dist[vertex] + 1 + path[val] = vertex + + return dist, path + +@njit +def get_full_path(path, edge_type, max_dist, cur_node): + + n = path.shape[0] + size = edge_type.shape[2] + edge_input = np.full( + shape=(n, max_dist, size), + fill_value=-1, + dtype="int64" + ) + path_ = np.full(n, -1) + + for i in range(n): + cur_path = 0 + if i == cur_node: + continue + if 
path[i] == -1: + continue + path_[cur_path] = i + cur_path += 1 + cur = i + while path[cur] != cur_node: + path_[cur_path] = path[cur] + cur_path += 1 + cur = path[cur] + path_[cur_path] = cur_node + for j in range(min(max_dist, cur_path)): + edge_input[i, j, :] = edge_type[ + path_[cur_path-j], path_[cur_path-j-1], :] + + return edge_input + +@njit(parallel=True) +def bfs_numba_spatial_pos_and_edge_input( + adj_matrix, + edge_type, + max_dist=5, +): + n = adj_matrix.shape[0] + edge_type_shape = edge_type.shape[2] + edge_input = np.full( + shape=(n, n, max_dist, edge_type_shape), + fill_value=-1, + ) + spatial_pos = np.full((n ,n), 510) + adj_list = np.full((n ,n), -1) + + for i in range(n): + cur = 0 + for j in range(n): + if adj_matrix[i, j] == 1: + adj_list[i, cur] = j + cur += 1 + + for i in prange(n): + dist, path = bfs_shortest_path(adj_list, i) + edge_input[i] = get_full_path( + path, edge_type, max_dist, i + ) + for j in range(n): + if dist[j] != -1: + spatial_pos[i, j] = dist[j] + + return spatial_pos, edge_input diff --git a/graphormer/data/dgl_datasets/dgl_dataset.py b/graphormer/data/dgl_datasets/dgl_dataset.py index 0bf0d25..91d3f7c 100644 --- a/graphormer/data/dgl_datasets/dgl_dataset.py +++ b/graphormer/data/dgl_datasets/dgl_dataset.py @@ -14,6 +14,7 @@ from ..wrapper import convert_to_single_emb from .. import algos +from ..algos_numba import bfs_numba_spatial_pos_and_edge_input from copy import copy @@ -104,7 +105,8 @@ def extract_tensor_from_dict(feature: torch.Tensor): ) def __preprocess_dgl_graph( - self, graph_data: DGLGraph, y: torch.Tensor, idx: int + self, graph_data: DGLGraph, y: torch.Tensor, idx: int, + algo_name="bfs_numba", max_dist=5 ) -> PYGGraph: if not graph_data.is_homogeneous: raise ValueError( @@ -112,6 +114,13 @@ def __preprocess_dgl_graph( ) N = graph_data.num_nodes() + if algo_name == "bfs_numba": + process_algo = bfs_numba_spatial_pos_and_edge_input + elif algo_name == "bfs_cython": + process_algo = algos.bfs_spatial_pos_and_edge_input + elif algo_name == "floyd": + process_algo = algos.fw_spatial_pos_and_edge_input + ( node_int_feature, node_float_feature, @@ -126,10 +135,9 @@ def __preprocess_dgl_graph( edge_index[0].long(), edge_index[1].long() ] = convert_to_single_emb(edge_int_feature) dense_adj = graph_data.adj().to_dense().type(torch.int) - shortest_path_result, path = algos.floyd_warshall(dense_adj.numpy()) - max_dist = np.amax(shortest_path_result) - edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy()) - spatial_pos = torch.from_numpy((shortest_path_result)).long() + + spatial_pos, edge_input = process_algo( + dense_adj.numpy().astype("long"), attn_edge_type, max_dist) attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token pyg_graph = PYGGraph() @@ -137,7 +145,7 @@ def __preprocess_dgl_graph( pyg_graph.adj = dense_adj pyg_graph.attn_bias = attn_bias pyg_graph.attn_edge_type = attn_edge_type - pyg_graph.spatial_pos = spatial_pos + pyg_graph.spatial_pos = torch.from_numpy(spatial_pos).long() pyg_graph.in_degree = dense_adj.long().sum(dim=1).view(-1) pyg_graph.out_degree = pyg_graph.in_degree pyg_graph.edge_input = torch.from_numpy(edge_input).long() diff --git a/graphormer/data/wrapper.py b/graphormer/data/wrapper.py index 587bf34..d223ca8 100644 --- a/graphormer/data/wrapper.py +++ b/graphormer/data/wrapper.py @@ -11,6 +11,7 @@ pyximport.install(setup_args={"include_dirs": np.get_include()}) from . 
import algos +from .algos_numba import bfs_numba_spatial_pos_and_edge_input @torch.jit.script @@ -21,7 +22,15 @@ def convert_to_single_emb(x, offset: int = 512): return x -def preprocess_item(item): +def preprocess_item(item, algo_name="bfs_numba", max_dist=5): + + if algo_name == "bfs_numba": + process_algo = bfs_numba_spatial_pos_and_edge_input + elif algo_name == "bfs_cython": + process_algo = algos.bfs_spatial_pos_and_edge_input + elif algo_name == "floyd": + process_algo = algos.fw_spatial_pos_and_edge_input + edge_attr, edge_index, x = item.edge_attr, item.edge_index, item.x N = x.size(0) x = convert_to_single_emb(x) @@ -29,6 +38,7 @@ def preprocess_item(item): # node adj matrix [N, N] bool adj = torch.zeros([N, N], dtype=torch.bool) adj[edge_index[0, :], edge_index[1, :]] = True + adj = adj.long() # edge feature here if len(edge_attr.size()) == 1: @@ -38,18 +48,15 @@ def preprocess_item(item): convert_to_single_emb(edge_attr) + 1 ) - shortest_path_result, path = algos.floyd_warshall(adj.numpy()) - max_dist = np.amax(shortest_path_result) - edge_input = algos.gen_edge_input(max_dist, path, attn_edge_type.numpy()) - spatial_pos = torch.from_numpy((shortest_path_result)).long() + spatial_pos, edge_input = process_algo(adj, attn_edge_type, max_dist) attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token # combine item.x = x item.attn_bias = attn_bias item.attn_edge_type = attn_edge_type - item.spatial_pos = spatial_pos - item.in_degree = adj.long().sum(dim=1).view(-1) + item.spatial_pos = torch.from_numpy(spatial_pos).long() + item.in_degree = adj.sum(dim=1).view(-1) item.out_degree = item.in_degree # for undirected graph item.edge_input = torch.from_numpy(edge_input).long() From c90865c1825691aafb8fc9878f6e7ac181461f67 Mon Sep 17 00:00:00 2001 From: Yuanhao Geng <1801214626@qq.com> Date: Sun, 2 Jan 2022 09:33:03 +0000 Subject: [PATCH 2/3] update dataset config --- graphormer/data/dataset.py | 21 +++++++++++++++++++-- graphormer/data/dgl_datasets/dgl_dataset.py | 12 ++++++++---- graphormer/data/pyg_datasets/pyg_dataset.py | 10 +++++++++- graphormer/tasks/graph_prediction.py | 12 +++++++++++- 4 files changed, 47 insertions(+), 8 deletions(-) diff --git a/graphormer/data/dataset.py b/graphormer/data/dataset.py index a939976..c80da17 100644 --- a/graphormer/data/dataset.py +++ b/graphormer/data/dataset.py @@ -69,13 +69,30 @@ def __init__( train_idx = None, valid_idx = None, test_idx = None, + max_dist = 5 + algo_name = "bfs_numba" ): super().__init__() if dataset is not None: if dataset_source == "dgl": - self.dataset = GraphormerDGLDataset(dataset, seed=seed, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx) + self.dataset = GraphormerDGLDataset( + dataset, + seed=seed, + train_idx=train_idx, + valid_idx=valid_idx, + test_idx=test_idx, + max_dist=max_dist, + algo_name=algo_name + ) elif dataset_source == "pyg": - self.dataset = GraphormerPYGDataset(dataset, train_idx=train_idx, valid_idx=valid_idx, test_idx=test_idx) + self.dataset = GraphormerPYGDataset( + dataset, + train_idx=train_idx, + valid_idx=valid_idx, + test_idx=test_idx, + max_dist=max_dist + algo_name=algo_name + ) else: raise ValueError("customized dataset can only have source pyg or dgl") elif dataset_source == "dgl": diff --git a/graphormer/data/dgl_datasets/dgl_dataset.py b/graphormer/data/dgl_datasets/dgl_dataset.py index 91d3f7c..de3c759 100644 --- a/graphormer/data/dgl_datasets/dgl_dataset.py +++ b/graphormer/data/dgl_datasets/dgl_dataset.py @@ -25,8 +25,12 @@ def __init__(self, 
train_idx=None, valid_idx=None, test_idx=None, + max_dist=5, + algo_name="bfs_numba" ): self.dataset = dataset + self.max_dist = max_dist + num_data = len(self.dataset) self.seed = seed if train_idx is None: @@ -114,11 +118,11 @@ def __preprocess_dgl_graph( ) N = graph_data.num_nodes() - if algo_name == "bfs_numba": + if self.algo_name == "bfs_numba": process_algo = bfs_numba_spatial_pos_and_edge_input - elif algo_name == "bfs_cython": + elif self.algo_name == "bfs_cython": process_algo = algos.bfs_spatial_pos_and_edge_input - elif algo_name == "floyd": + elif self.algo_name == "floyd": process_algo = algos.fw_spatial_pos_and_edge_input ( @@ -137,7 +141,7 @@ def __preprocess_dgl_graph( dense_adj = graph_data.adj().to_dense().type(torch.int) spatial_pos, edge_input = process_algo( - dense_adj.numpy().astype("long"), attn_edge_type, max_dist) + dense_adj.numpy().astype("long"), attn_edge_type, self.max_dist) attn_bias = torch.zeros([N + 1, N + 1], dtype=torch.float) # with graph token pyg_graph = PYGGraph() diff --git a/graphormer/data/pyg_datasets/pyg_dataset.py b/graphormer/data/pyg_datasets/pyg_dataset.py index 9091878..ef33f8d 100644 --- a/graphormer/data/pyg_datasets/pyg_dataset.py +++ b/graphormer/data/pyg_datasets/pyg_dataset.py @@ -25,8 +25,12 @@ def __init__( train_set=None, valid_set=None, test_set=None, + max_dist=5, + algo_name="bfs_numba" ): self.dataset = dataset + self.algo_name = algo_name + self.max_dist = max_dist if self.dataset is not None: self.num_data = len(self.dataset) self.seed = seed @@ -98,7 +102,11 @@ def __getitem__(self, idx): item = self.dataset[idx] item.idx = idx item.y = item.y.reshape(-1) - return preprocess_item(item) + return preprocess_item( + item, + algo_name=self.algo_name, + max_dist=self.max_dist + ) else: raise TypeError("index to a GraphormerPYGDataset can only be an integer.") diff --git a/graphormer/tasks/graph_prediction.py b/graphormer/tasks/graph_prediction.py index cc0334a..fd2d556 100644 --- a/graphormer/tasks/graph_prediction.py +++ b/graphormer/tasks/graph_prediction.py @@ -107,6 +107,11 @@ class GraphPredictionConfig(FairseqDataclass): metadata={"help": "edge type in the graph"}, ) + algo_name: str = field( + default="bfs_numba", + metadata={"help": "algo for getting path and edge input"}, + ) + seed: int = II("common.seed") pretrained_model_name: str = field( @@ -148,7 +153,10 @@ def __init__(self, cfg): train_idx=dataset_dict["train_idx"], valid_idx=dataset_dict["valid_idx"], test_idx=dataset_dict["test_idx"], - seed=cfg.seed) + seed=cfg.seed, + algo_name=cfg.algo_name, + max_dist=cfg.multi_hop_max_dist + ) else: raise ValueError(f"dataset {cfg.dataset_name} is not found in customized dataset module {cfg.user_data_dir}") else: @@ -156,6 +164,8 @@ def __init__(self, cfg): dataset_spec=cfg.dataset_name, dataset_source=cfg.dataset_source, seed=cfg.seed, + algo_name=cfg.algo_name, + max_dist=cfg.multi_hop_max_dist ) def __import_user_defined_datasets(self, dataset_dir): From 82beb08a78a9ce61e969799f7e7e4d249a94ac1c Mon Sep 17 00:00:00 2001 From: Yuanhao Geng <1801214626@qq.com> Date: Mon, 3 Jan 2022 05:56:37 +0000 Subject: [PATCH 3/3] dgl_dateset.py fix --- graphormer/data/algos.pyx | 35 --------------------- graphormer/data/dgl_datasets/dgl_dataset.py | 2 +- 2 files changed, 1 insertion(+), 36 deletions(-) diff --git a/graphormer/data/algos.pyx b/graphormer/data/algos.pyx index 3dfbbc5..e16a23f 100644 --- a/graphormer/data/algos.pyx +++ b/graphormer/data/algos.pyx @@ -206,38 +206,3 @@ def bfs_spatial_pos_and_edge_input( spatial_pos[i, j] 
= dist[j] return np.asarray(spatial_pos), np.asarray(edge_input) - -def bfs_target_spatial_pos_and_edge_input( - np.int64_t[:, :] adj_matrix, - np.int64_t[:, :, :] edge_type, - int max_dist=5, -): - - cdef: - int i, j - int n = adj_matrix.shape[0] - int edge_type_shape = edge_type.shape[2] - np.ndarray[np.int64_t, ndim=4, mode='c'] edge_input = np.full( - shape=(n, n, max_dist, edge_type_shape), - fill_value=-1, - dtype="int64" - ) - np.int64_t[:, :] spatial_pos = np.full((n ,n), 510, dtype="int64") - cdef vector[vector[int]] adj_list - - for i in range(n): - adj_list.push_back(vector[int]()) - for j in range(i+1): - if adj_matrix[i][j] == 1: - adj_list[i].push_back(j) - for j in range(i): - if adj_matrix[j][i] == 1: - adj_list[j].push_back(i) - dist, path = bfs_shortest_path(adj_list, i) - edge_input[i, :i+1, :, :] = np.asarray(get_full_path( - path, edge_type[:i+1, :i+1, :], max_dist, i)) - for j in range(i+1): - if dist[j] != -1: - spatial_pos[i, j] = dist[j] - - return np.asarray(spatial_pos), np.asarray(edge_input) diff --git a/graphormer/data/dgl_datasets/dgl_dataset.py b/graphormer/data/dgl_datasets/dgl_dataset.py index de3c759..53dcce7 100644 --- a/graphormer/data/dgl_datasets/dgl_dataset.py +++ b/graphormer/data/dgl_datasets/dgl_dataset.py @@ -30,7 +30,7 @@ def __init__(self, ): self.dataset = dataset self.max_dist = max_dist - + self.algo_name = algo_name num_data = len(self.dataset) self.seed = seed if train_idx is None:
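
Usage sketch for the new entry point (illustrative only, not part of the diffs above): the snippet below drives bfs_numba_spatial_pos_and_edge_input from graphormer/data/algos_numba.py directly on plain NumPy arrays. The 4-node cycle, the single feature channel, and the feature id 7 are made-up inputs for illustration; the zero fill for non-edges mirrors how preprocess_item builds attn_edge_type (zeros everywhere, convert_to_single_emb(edge_attr) + 1 on edges). Real callers go through preprocess_item / __preprocess_dgl_graph with algo_name and max_dist, as added in these patches.

    import numpy as np
    from graphormer.data.algos_numba import bfs_numba_spatial_pos_and_edge_input

    # 4-node undirected cycle: 0-1-2-3-0
    adj = np.zeros((4, 4), dtype=np.int64)
    for u, v in [(0, 1), (1, 2), (2, 3), (3, 0)]:
        adj[u, v] = adj[v, u] = 1

    # one integer feature channel per directed edge; 0 where no edge exists
    edge_type = np.zeros((4, 4, 1), dtype=np.int64)
    edge_type[adj == 1] = 7  # arbitrary feature id for this example

    spatial_pos, edge_input = bfs_numba_spatial_pos_and_edge_input(
        adj, edge_type, max_dist=5
    )

    # spatial_pos has shape (n, n): BFS hop count from i to j, with 510
    # marking unreachable pairs. edge_input has shape (n, n, max_dist, 1):
    # edge_input[i, j, k, :] holds the features of the k-th edge on the
    # shortest path from i to j, padded with -1 beyond the path length.
    # Expected here: spatial_pos[0, 2] == 2 (two hops around the cycle)
    # and edge_input[0, 2, 0] == [7] (first edge on that path).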