diff --git a/examples/pna/README.md b/examples/pna/README.md new file mode 100644 index 00000000..5b7d4190 --- /dev/null +++ b/examples/pna/README.md @@ -0,0 +1,37 @@ +# Principal Neighbourhood Aggregation for Graph Nets (PNA) + +[Principal Neighbourhood Aggregation for Graph Nets \(PNA\)](https://arxiv.org/abs/2004.05718) is a graph learning model combining multiple aggregators with degree-scalers. + + +### Datasets + +We perform graph classification experiments to reproduce the paper results on [OGB](https://ogb.stanford.edu/). + +### Dependencies + +- paddlepaddle >= 2.2.0 +- pgl >= 2.2.4 + +### How to run + + +``` +python main.py --config config.yaml # train on ogbg-molhiv +python main.py --config config_pcba.yaml # train on ogbg-molpcba +``` + + +### Important Hyperparameters + +- aggregators: a list of aggregator names. ("mean", "sum", "max", "min", "var", "std") +- scalers: a list of scaler names. ("identity", "amplification", "attenuation", "linear", "inverse_linear") +- towers: The number of towers. +- divide_input: whether the input features should be split between towers or not. +- pre_layers: the number of MLP layers before the aggregation. +- post_layers: the number of MLP layers after the aggregation. 
+ +### Experiment results (ROC-AUC) +| | GIN | PNA(paper result) | PNA(ours)| +|-------------|----------|------------|-----------------| +|HIV | 0.7778 | 0.7905 | 0.7929 | +|PCBA | 0.2266 | 0.2838 | 0.2801 | diff --git a/examples/pna/config.yaml b/examples/pna/config.yaml new file mode 100644 index 00000000..39d357f5 --- /dev/null +++ b/examples/pna/config.yaml @@ -0,0 +1,43 @@ +task_name: train.hiv +dataset_name: ogbg-molhiv +metrics: rocauc + +hidden_size: 128 +out_size: 128 +dropout: 0.3 +num_layers: 4 +batch_norm: True +residual: True +aggregators: ["mean","max","min", "std"] +scalers: ["identity", "amplification", "attenuation"] +in_feat_dropout: 0 +post_layers: 1 +pre_layers: 1 +towers: 1 +edge_feat: True +optim: momentum + +seed: 41 +# data config +num_class: 1 + +# runconfig +epochs: 200 +batch_size: 128 +lr: 0.01 +lr_reduce_factor: 0.5 +lr_schedule_patience: 20 +min_lr: 0.0001 +weight_decay: 0.000003 +num_workers: 4 +shuffle: True +max_time: 48 +log_step: 100 + +# logger +stdout: True +log_dir: ./logs +log_filename: log.txt +save_dir: ./checkpoints +output_dir: ./outputs +files2saved: ["*.yaml", "*.py", "./utils"] diff --git a/examples/pna/config_pcba.yaml b/examples/pna/config_pcba.yaml new file mode 100644 index 00000000..32eb41a3 --- /dev/null +++ b/examples/pna/config_pcba.yaml @@ -0,0 +1,42 @@ +task_name: train.pcba +dataset_name: ogbg-molpcba +metrics: ap + +hidden_size: 510 +out_size: 510 +dropout: 0.2 +num_layers: 4 +batch_norm: True +residual: True +aggregators: ["mean", "sum", "max"] +scalers: ["identity"] +in_feat_dropout: 0.0 +post_layers: 1 +pre_layers: 1 +towers: 5 +edge_feat: True +seed: 41 +optim: adam +# data config +num_class: 128 + +# runconfig +epochs: 100 +batch_size: 512 +lr: 0.0005 +lr_reduce_factor: 0.8 +lr_schedule_patience: 4 +min_lr: 0.00002 +weight_decay: 0.000003 +num_workers: 4 +shuffle: True +max_time: 48 +log_step: 100 + +# logger +stdout: True +log_dir: ./logs +log_filename: log.txt +save_dir: ./checkpoints +output_dir: 
./outputs +files2saved: ["*.yaml", "*.py", "./utils"] diff --git a/examples/pna/dataset.py b/examples/pna/dataset.py new file mode 100644 index 00000000..39fabffa --- /dev/null +++ b/examples/pna/dataset.py @@ -0,0 +1,134 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch + +import os +import sys +import json +import numpy as np +import glob +import copy +import time +import argparse +from collections import OrderedDict, namedtuple +from scipy.sparse import csr_matrix +import pgl +import paddle +from pgl.utils.data.dataset import Dataset, StreamDataset, HadoopDataset +from pgl.utils.data import Dataloader +from pgl.utils.logger import log + +from utils.config import prepare_config, make_dir +from ogb.graphproppred import GraphPropPredDataset +from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims + + +class Subset(Dataset): + r""" + Subset of a dataset at specified indices. 
+ Arguments: + dataset (Dataset): The whole Dataset + indices (sequence): Indices in the whole set selected for subset + """ + + def __init__(self, dataset, indices, mode='train'): + self.dataset = dataset + if paddle.distributed.get_world_size() == 1 or mode != "train": + self.indices = indices + else: + self.indices = indices[int(paddle.distributed.get_rank())::int( + paddle.distributed.get_world_size())] + + self.mode = mode + + def __getitem__(self, idx): + return self.dataset[self.indices[idx]] + + def __len__(self): + return len(self.indices) + + +class ShardedDataset(Dataset): + """ + SharderDataset + """ + + def __init__(self, data, mode="train"): + if paddle.distributed.get_world_size() == 1 or mode != "train": + self.data = data + else: + self.data = data[int(paddle.distributed.get_rank())::int( + paddle.distributed.get_world_size())] + + def __getitem__(self, idx): + return self.data[idx] + + def __len__(self): + return len(self.data) + + +class MolDataset(Dataset): + """ + Transfer raw ogb dataset to pgl dataset + """ + + def __init__(self, config, raw_dataset, mode='train', transform=None): + self.config = config + self.raw_dataset = raw_dataset + self.mode = mode + + log.info("preprocess graph data in %s" % self.__class__.__name__) + self.graph_list = [] + self.label = [] + for i in range(len(self.raw_dataset)): + # num_nodes, edge_index, node_feat, edge_feat, label + graph, label = self.raw_dataset[i] + num_nodes = graph['num_nodes'] + node_feat = graph['node_feat'].copy() + edges = list(zip(graph["edge_index"][0], graph["edge_index"][1])) + edge_feat = graph['edge_feat'].copy() + main_graph = pgl.Graph( + num_nodes=num_nodes, + edges=edges, + node_feat={'feat': node_feat}, + edge_feat={'feat': edge_feat}) + self.graph_list.append(main_graph) + self.label.append(label) + + def __getitem__(self, idx): + return self.graph_list[idx], self.label[idx] + + def __len__(self): + return len(self.graph_list) + + +class CollateFn(object): + def __init__(self): 
+ pass + + def __call__(self, batch_data): + graph_list = [] + labels = [] + for g, label in batch_data: + if g is None: + continue + graph_list.append(g) + labels.append(label) + + labels = np.array(labels) + batch_valid = (labels == labels).astype("bool") + labels = np.nan_to_num(labels).astype("float32") + + g = pgl.Graph.batch(graph_list) + return g, labels, batch_valid diff --git a/examples/pna/main.py b/examples/pna/main.py new file mode 100644 index 00000000..54ed9f6a --- /dev/null +++ b/examples/pna/main.py @@ -0,0 +1,264 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch +import ssl +ssl._create_default_https_context = ssl._create_unverified_context +import os +import sys +import time +import argparse +import numpy as np +from datetime import datetime +import paddle +import paddle.nn as nn +from paddle.optimizer import Adam, Momentum +import paddle.distributed as dist +from tqdm import tqdm +import pgl +from pgl.utils.logger import log +from pgl.utils.data import Dataloader +from dataset import Subset, MolDataset, CollateFn +from utils.config import prepare_config, make_dir +from utils.logger import prepare_logger, log_to_file +import random +from ogb.graphproppred import GraphPropPredDataset, Evaluator +from tensorboardX import SummaryWriter +from pna_model import PNAModel + + +def main(config): + """ + main function + """ + # setting seeds + random.seed(config.seed) + np.random.seed(config.seed) + paddle.seed(config.seed) + if dist.get_world_size() > 1: + dist.init_parallel_env() + + if dist.get_rank() == 0: + timestamp = datetime.now().strftime("%Hh%Mm%Ss") + log_path = os.path.join(config.log_dir, + "tensorboard_log_%s" % timestamp) + writer = SummaryWriter(log_path) + log.info("loading data") + raw_dataset = GraphPropPredDataset(name=config.dataset_name) + config.num_class = raw_dataset.num_tasks + config.eval_metric = raw_dataset.eval_metric + config.task_type = raw_dataset.task_type + mol_dataset = MolDataset(config, raw_dataset) + splitted_index = raw_dataset.get_idx_split() + train_ds = Subset(mol_dataset, splitted_index['train'], mode='train') + valid_ds = Subset(mol_dataset, splitted_index['valid'], mode="valid") + test_ds = Subset(mol_dataset, splitted_index['test'], mode="test") + + log.info("Train Examples: %s" % len(train_ds)) + log.info("Val Examples: %s" % len(valid_ds)) + log.info("Test Examples: %s" % len(test_ds)) + fn = CollateFn() + train_loader = Dataloader( + train_ds, + batch_size=config.batch_size, + shuffle=True, + num_workers=config.num_workers, + collate_fn=fn) + + valid_loader = 
Dataloader( + valid_ds, + batch_size=config.batch_size, + num_workers=config.num_workers, + collate_fn=fn) + + test_loader = Dataloader( + test_ds, + batch_size=config.batch_size, + num_workers=config.num_workers, + collate_fn=fn) + + deg_hog = paddle.to_tensor(get_degree_histogram(train_loader)) + + model = PNAModel( + config.hidden_size, + config.out_size, + config.aggregators, + config.scalers, + deg_hog, + pre_layers=config.pre_layers, + post_layers=config.post_layers, + towers=config.towers, + residual=config.residual, + batch_norm=config.batch_norm, + L=config.num_layers, + dropout=config.dropout, + in_feat_dropout=config.in_feat_dropout, + edge_feat=config.edge_feat, + num_class=config.num_class) + model = paddle.DataParallel(model) + + criterion = nn.loss.BCEWithLogitsLoss() + scheduler = paddle.optimizer.lr.ReduceOnPlateau( + learning_rate=config.lr, + factor=config.lr_reduce_factor, + patience=config.lr_schedule_patience, + mode="min", + min_lr=config.min_lr, + verbose=True) + if config.optim == "momentum": + optim = Momentum( + learning_rate=scheduler, + parameters=model.parameters(), + weight_decay=config.weight_decay) + else: + optim = Adam( + learning_rate=scheduler, + parameters=model.parameters(), + weight_decay=config.weight_decay) + evaluator = Evaluator(config.dataset_name) + best_valid = 0 + global_step = 0 + for epoch in range(1, config.epochs + 1): + model.train() + for idx, batch_data in enumerate(tqdm(train_loader, desc="Iteration")): + g, labels, unmask = batch_data + g = g.tensor() + labels = paddle.to_tensor(labels) + unmask = paddle.to_tensor(unmask) + + pred = model(g) + pred = paddle.masked_select(pred, unmask) + labels = paddle.masked_select(labels, unmask) + train_loss = criterion(pred, labels) + train_loss.backward() + optim.step() + optim.clear_grad() + if global_step % config.log_step == 0: + message = "train: epoch %d | step %d | " % (epoch, global_step) + message += "loss %.6f" % (train_loss.numpy()) + # log.info(message) + if 
dist.get_rank() == 0: + writer.add_scalar("loss", train_loss.numpy(), global_step) + global_step += 1 + + valid_result = evaluate(model, valid_loader, criterion, evaluator) + message = "valid: epoch %d | step %d | " % (epoch, global_step) + for key, value in valid_result.items(): + message += " | %s %.6f" % (key, value) + if dist.get_rank() == 0: + writer.add_scalar("valid_%s" % key, value, global_step) + log.info(message) + + test_result = evaluate(model, test_loader, criterion, evaluator) + message = "test: epoch %d | step %d | " % (epoch, global_step) + for key, value in test_result.items(): + message += " | %s %.6f" % (key, value) + if dist.get_rank() == 0: + writer.add_scalar("test_%s" % key, value, global_step) + log.info(message) + scheduler.step(-test_result[config.metrics]) + if best_valid < valid_result[config.metrics]: + best_valid = valid_result[config.metrics] + best_valid_result = valid_result + best_test_result = test_result + + message = "best result: epoch %d | " % (epoch) + message += "valid %s: %.6f | " % (config.metrics, + best_valid_result[config.metrics]) + message += "test %s: %.6f | " % (config.metrics, + best_test_result[config.metrics]) + log.info(message) + + message = "final eval best result:%.6f" % best_valid_result[config.metrics] + log.info(message) + message = "final test best result:%.6f" % best_test_result[config.metrics] + log.info(message) + + +@paddle.no_grad() +def evaluate(model, loader, criterion, evaluator): + """ + eval function + """ + model.eval() + total_loss = [] + y_true = [] + y_pred = [] + is_valid = [] + + for idx, batch_data in enumerate(tqdm(loader, desc="Iteration")): + g, labels, unmask = batch_data + g = g.tensor() + labels = paddle.to_tensor(labels) + unmask = paddle.to_tensor(unmask) + + pred = model(g) + eval_loss = criterion( + paddle.masked_select(pred, unmask), + paddle.masked_select(labels, unmask)) + total_loss.append(eval_loss.numpy()) + + y_pred.append(pred.numpy()) + y_true.append(labels.numpy()) + 
is_valid.append(unmask.numpy()) + + y_pred = np.concatenate(y_pred) + y_true = np.concatenate(y_true) + is_valid = np.concatenate(is_valid) + is_valid = is_valid.astype("bool") + y_true = y_true.astype("float32") + y_true[~is_valid] = np.nan + input_dict = {'y_true': y_true, 'y_pred': y_pred} + result = evaluator.eval(input_dict) + + total_loss = np.mean(total_loss) + model.train() + print(result) + return {"loss": total_loss, config.metrics: result[config.metrics]} + + +def get_degree_histogram(loader): + """ + get the degree histogram of dataloader + """ + max_degree = 0 + for data in loader: + g, _, _ = data + d = g.indegree() + max_degree = max(max_degree, int(d.max().item())) + deg_hog = np.zeros(max_degree + 1, dtype="long") + for data in loader: + g, _, _ = data + d = g.indegree() + deg_hog += np.bincount(d, minlength=deg_hog.shape[0]) + return deg_hog + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description='gnn') + parser.add_argument("--config", type=str, default="./config.yaml") + parser.add_argument("--task_name", type=str, default="ogbg-hiv-pna") + parser.add_argument("--mode", type=str, default="train") + parser.add_argument("--log_id", type=str, default=None) + args = parser.parse_args() + + if dist.get_rank() == 0: + config = prepare_config(args.config, isCreate=True, isSave=True) + if args.log_id is not None: + config.log_filename = "%s_%s" % (args.log_id, config.log_filename) + log_to_file(log, config.log_dir, config.log_filename) + else: + config = prepare_config(args.config, isCreate=False, isSave=False) + + config.log_id = args.log_id + main(config) diff --git a/examples/pna/mol_encoder.py b/examples/pna/mol_encoder.py new file mode 100644 index 00000000..0eba4824 --- /dev/null +++ b/examples/pna/mol_encoder.py @@ -0,0 +1,68 @@ +# Copyright (c) 2020 PaddlePaddle Authors. 
All Rights Reserved +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims + +full_atom_feature_dims = get_atom_feature_dims() +full_bond_feature_dims = get_bond_feature_dims() + + +class AtomEncoder(nn.Layer): + def __init__(self, emb_dim): + super(AtomEncoder, self).__init__() + + self.atom_embedding_list = nn.LayerList() + + for i, dim in enumerate(full_atom_feature_dims): + weight_attr = nn.initializer.XavierUniform() + emb = paddle.nn.Embedding(dim, emb_dim, weight_attr=weight_attr) + self.atom_embedding_list.append(emb) + + def forward(self, x): + x_embedding = 0 + for i in range(x.shape[1]): + x_embedding += self.atom_embedding_list[i](x[:, i]) + + return x_embedding + + +class BondEncoder(nn.Layer): + def __init__(self, emb_dim): + super(BondEncoder, self).__init__() + + self.bond_embedding_list = nn.LayerList() + + for i, dim in enumerate(full_bond_feature_dims): + weight_attr = nn.initializer.XavierUniform() + emb = paddle.nn.Embedding(dim, emb_dim, weight_attr=weight_attr) + self.bond_embedding_list.append(emb) + + def forward(self, edge_attr): + bond_embedding = 0 + for i in range(edge_attr.shape[1]): + bond_embedding += self.bond_embedding_list[i](edge_attr[:, i]) + + return bond_embedding + + +# if __name__ == '__main__': +# from ogb.graphproppred import GraphPropPredDataset +# dataset = GraphPropPredDataset(name='ogbg-molpcba') +# atom_enc = 
AtomEncoder(100) +# bond_enc = BondEncoder(100) + +# print(atom_enc(dataset[0].x)) +# print(bond_enc(dataset[0].edge_attr)) diff --git a/examples/pna/pna_model.py b/examples/pna/pna_model.py new file mode 100644 index 00000000..505d233b --- /dev/null +++ b/examples/pna/pna_model.py @@ -0,0 +1,153 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F +import pgl +import pgl.nn as gnn +from pgl.nn import functional as GF +from mol_encoder import AtomEncoder, BondEncoder + + +class PNAModel(nn.Layer): + """ + Implementation of PNA Model + """ + + def __init__( + self, + hidden_size, + out_size, + aggregators, + scalers, + deg_hog, + pre_layers=1, + post_layers=1, + towers=1, + residual=True, + batch_norm=True, + L=3, + dropout=0.3, + in_feat_dropout=0.0, + edge_feat=False, + num_class=1, ): + super(PNAModel, self).__init__() + self.out_size = out_size + self.hidden_size = hidden_size + self.aggregators = aggregators + self.scalers = scalers + self.pre_layers = pre_layers + self.post_layers = post_layers + self.towers = towers + self.residual = residual + self.batch_norm = batch_norm + self.L = L + self.dropout = dropout + self.in_feat_dropout = in_feat_dropout + self.edge_feat = edge_feat + self.embedding_h = AtomEncoder(emb_dim=hidden_size) + if self.edge_feat: + self.embedding_e = BondEncoder(emb_dim=hidden_size) + self.layers = 
nn.LayerList() + self.in_feat_dropout = nn.Dropout(in_feat_dropout) + self.bns = nn.LayerList() + for i in range(self.L - 1): + self.layers.append( + gnn.PNAConv( + self.hidden_size, + self.hidden_size, + self.aggregators, + self.scalers, + deg_hog, + towers=self.towers, + pre_layers=self.pre_layers, + post_layers=self.post_layers, + divide_input=False, + use_edge=self.edge_feat)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1D(hidden_size)) + self.layers.append( + gnn.PNAConv( + self.hidden_size, + self.out_size, + self.aggregators, + self.scalers, + deg_hog, + towers=self.towers, + pre_layers=self.pre_layers, + post_layers=post_layers, + divide_input=False, + use_edge=self.edge_feat)) + if self.batch_norm: + self.bns.append(nn.BatchNorm1D(out_size)) + self.MLP_layer = MLPReadout(out_size, num_class) + self.pool = gnn.GraphPool("mean") + + def forward(self, graph): + """ + forward of PNAModel + """ + h = graph.node_feat['feat'] + h = self.embedding_h(h) + h = self.in_feat_dropout(h) + e = None + + if self.edge_feat: + e = self.embedding_e(graph.edge_feat['feat']) + e = self.in_feat_dropout(e) + for i, conv in enumerate(self.layers): + x = h + deg = graph.indegree() + h = conv(graph, h, deg, e) + if self.batch_norm: + h = self.bns[i](h) + h = F.relu(h) # + + if self.residual: + h = h + x + h = F.dropout(h, self.dropout, training=self.training) + + hg = self.pool(graph, h) + return self.MLP_layer(hg) + + +class MLPReadout(nn.Layer): + """ + An Implementation of MLP layer + """ + + def __init__(self, input_dim, output_dim, L=2): # L=nb_hidden_layers + super().__init__() + list_FC_layers = [ + nn.Linear( + input_dim // 2**l, input_dim // 2**(l + 1), bias_attr=True) + for l in range(L) + ] + list_FC_layers.append( + nn.Linear( + input_dim // 2**L, output_dim, bias_attr=True)) + self.FC_layers = nn.LayerList(list_FC_layers) + self.L = L + + def forward(self, x): + """ + forward function of MLPReadout + """ + y = x + for l in range(self.L): + y = 
self.FC_layers[l](y) + y = F.relu(y) + y = self.FC_layers[self.L](y) + return y diff --git a/examples/pna/utils/config.py b/examples/pna/utils/config.py new file mode 100644 index 00000000..a4f09aa8 --- /dev/null +++ b/examples/pna/utils/config.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""doc +""" + +import sys +import datetime +import os +import yaml +import random +import shutil +import six +import warnings +import glob +from utils.util import get_last_dir + + +class AttrDict(dict): + def __init__(self, d={}, **kwargs): + if kwargs: + d.update(**kwargs) + + for k, v in d.items(): + setattr(self, k, v) + + # Class attributes + # for k in self.__class__.__dict__.keys(): + # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): + # setattr(self, k, getattr(self, k)) + + def __setattr__(self, name, value): + if isinstance(value, (list, tuple)): + value = [ + self.__class__(x) if isinstance(x, dict) else x for x in value + ] + elif isinstance(value, dict) and not isinstance(value, self.__class__): + value = self.__class__(value) + super(AttrDict, self).__setattr__(name, value) + super(AttrDict, self).__setitem__(name, value) + + __setitem__ = __setattr__ + + def __getattr__(self, attr): + try: + value = super(AttrDict, self).__getitem__(attr) + except KeyError: + # log.warn("%s attribute is not existed, return None" % attr) + warnings.warn("%s attribute 
is not existed, return None" % attr) + value = None + return value + + def update(self, e=None, **f): + d = e or dict() + d.update(f) + for k in d: + setattr(self, k, d[k]) + + def pop(self, k, d=None): + delattr(self, k) + return super(AttrDict, self).pop(k, d) + + +def make_dir(path): + """Build directory""" + if not os.path.exists(path): + os.makedirs(path) + + +def load_config(config_file): + """Load config file""" + with open(config_file) as f: + if hasattr(yaml, 'FullLoader'): + config = yaml.load(f, Loader=yaml.FullLoader) + else: + config = yaml.load(f) + return config + + +def create_necessary_dirs(config): + """Create some necessary directories to save some important files. + """ + + config.log_dir = os.path.join(config.log_dir, config.task_name) + config.save_dir = os.path.join(config.save_dir, config.task_name) + config.output_dir = os.path.join(config.output_dir, config.task_name) + + make_dir(config.log_dir) + make_dir(config.save_dir) + make_dir(config.output_dir) + + +def save_files(config): + """Save config file so that we can know the config when we look back + """ + filelist = config.files2saved + targetpath = config.log_dir + + if filelist is not None: + for file_or_dir in filelist: + if os.path.isdir(file_or_dir): + last_name = get_last_dir(file_or_dir) + dst = os.path.join(targetpath, last_name) + try: + shutil.copytree(file_or_dir, dst) + except Exception as e: + print(e) + print("backup %s to %s" % (file_or_dir, targetpath)) + else: + for filename in files(files=file_or_dir): + if os.path.isfile(filename): + print("backup %s to %s" % (filename, targetpath)) + shutil.copy2(filename, targetpath) + else: + print("%s is not existed." 
% filename) + + +def files(curr_dir='./', files='*.py'): + for i in glob.glob(os.path.join(curr_dir, files)): + yield i + + +def prepare_config(config_file, isCreate=False, isSave=False): + if os.path.isfile(config_file): + config = load_config(config_file) + config = AttrDict(config) + else: + print("%s is not a yaml file" % config_file) + raise + + if isCreate: + create_necessary_dirs(config) + + if isSave: + save_files(config) + + config.model_dir = config.save_dir + + return config diff --git a/examples/pna/utils/logger.py b/examples/pna/utils/logger.py new file mode 100644 index 00000000..89e9d0dc --- /dev/null +++ b/examples/pna/utils/logger.py @@ -0,0 +1,46 @@ +#-*- coding: utf-8 -*- +import sys +import os +import logging + + +def prepare_logger(log_dir=None, log_filename=None, stdout=False): + logger = logging.getLogger("logger") + logger.setLevel(logging.DEBUG) + + formatter = logging.Formatter( + fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s' + ) + + if stdout or log_dir is None: + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + # handler.setLevel(logging.INFO) + logger.addHandler(handler) + + if log_dir is not None: + if os.path.isdir(log_dir): + handler = logging.FileHandler(os.path.join(log_dir, log_filename)) + else: + handler = logging.FileHandler(log_dir) + handler.setFormatter(formatter) + # handler.setLevel(logging.INFO) + logger.addHandler(handler) + + logger.propagate = False + + return logger + + +def log_to_file(logger, log_dir, log_filename="log.txt"): + if os.path.isdir(log_dir): + handler = logging.FileHandler(os.path.join(log_dir, log_filename)) + else: + handler = logging.FileHandler(log_dir) + + formatter = logging.Formatter( + fmt='[%(levelname)s] %(asctime)s [%(filename)12s:%(lineno)5d]:\t%(message)s' + ) + handler.setFormatter(formatter) + + logger.addHandler(handler) diff --git a/examples/pna/utils/util.py b/examples/pna/utils/util.py new file mode 100644 index 
00000000..c88952c2 --- /dev/null +++ b/examples/pna/utils/util.py @@ -0,0 +1,45 @@ +#-*- coding: utf-8 -*- +import os +import sys +import warnings +import numpy as np + + +def strarr2int8arr(str, sep='\n'): + bytes = sep.join(str).encode("utf-8") + arr = np.frombuffer(bytes, dtype="int8") + return arr + + +def int82strarr(arr, sep='\n'): + string = arr.tobytes().decode("utf-8").split(sep) + return string + + +def get_last_dir(path): + """Get the last directory of a path. + """ + if os.path.isfile(path): + # e.g: "../checkpoints/task_name/epoch0_step300/predict.txt" + # return "epoch0_step300" + last_dir = path.split("/")[-2] + + elif os.path.isdir(path): + if path[-1] == '/': + # e.g: "../checkpoints/task_name/epoch0_step300/" + last_dir = path.split('/')[-2] + else: + # e.g: "../checkpoints/task_name/epoch0_step300" + last_dir = path.split('/')[-1] + else: + # path or file is not existed + warnings.warn('%s is not a existed file or path' % path) + last_dir = "" + + return last_dir + + +def make_dir(path): + """Build directory""" + if not os.path.exists(path): + os.makedirs(path) diff --git a/examples/sag_pool/README.md b/examples/sag_pool/README.md index 92059996..dab2cf7c 100644 --- a/examples/sag_pool/README.md +++ b/examples/sag_pool/README.md @@ -26,7 +26,7 @@ python main.py --use_cuda --dataset_name PROTEINS --lr 0.005 --batch_size 128 -- - data\_path: the root path of your dataset - dataset\_name: the name of the dataset. ("MUTAG", "IMDBBINARY", "IMDBMULTI", "COLLAB", "PROTEINS", "NCI1", "PTC", "REDDITBINARY", "REDDITMULTI5K") -- fold\_idx: The $fold\_idx^{th}$ fold of dataset splited. Here we use 10 fold cross-validation +- fold\_idx: The $fold\_{idx}^{th}$ fold of dataset splited. Here we use 10 fold cross-validation - min\_score: parameter for SAGPool which indicates minimal node score. (When min\_score is not None, pool\_ratio is ignored) - pool\_ratio: parameter for SAGPool which decides how many nodes will be removed. 
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Principal Neighbourhood Aggregation (PNA) graph convolution layer."""

import math
import numpy as np
import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from pgl.utils.logger import log
import pgl

__all__ = ['PNAConv']


class PNAConv(nn.Layer):
    """Implementation of the Principal Neighbourhood Aggregation graph
    convolution operator.

    This is an implementation of the paper Principal Neighbourhood
    Aggregation for Graph Nets (https://arxiv.org/pdf/2004.05718).

    Args:
        input_size (int): the size of the input node features.
        hidden_size (int): the size of the output features.
        aggregators (list): list of aggregator names, chosen from
            ["mean", "sum", "max", "min", "var", "std"].
        scalers (list): list of scaler names, chosen from
            ["identity", "amplification", "attenuation", "linear",
            "inverse_linear"].
        deg (Tensor): histogram of in-degrees of nodes in the training set,
            used to compute the average-degree statistics needed by the
            degree scalers.
        towers (int, optional): number of towers. Default: 1.
        pre_layers (int, optional): number of transformation layers before
            aggregation. Default: 1.
        post_layers (int, optional): number of transformation layers after
            aggregation. Default: 1.
        divide_input (bool, optional): whether the input features should
            be split between towers. Default: False.
        use_edge (bool, optional): whether to use edge features.
            Default: False.
    """

    def __init__(self,
                 input_size,
                 hidden_size,
                 aggregators,
                 scalers,
                 deg,
                 towers=1,
                 pre_layers=1,
                 post_layers=1,
                 divide_input=False,
                 use_edge=False):
        super(PNAConv, self).__init__()
        if divide_input:
            assert input_size % towers == 0, \
                "input_size must be divisible by towers when divide_input=True"
        # F_out = hidden_size // towers is computed unconditionally below,
        # so this invariant must hold regardless of divide_input.
        assert hidden_size % towers == 0, \
            "hidden_size must be divisible by towers"
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.aggregators = [AGGREGATOR[aggr] for aggr in aggregators]
        self.scalers = [SCALERS[scaler] for scaler in scalers]
        self.use_edge = use_edge
        self.towers = towers
        self.divide_input = divide_input
        self.pre_layers = pre_layers
        self.post_layers = post_layers
        # Per-tower input/output feature sizes.
        self.F_in = input_size // towers if divide_input else input_size
        self.F_out = self.hidden_size // towers

        # Average-degree statistics (linear / log / exp space) over the
        # training-set degree histogram, consumed by the degree scalers.
        deg = deg.astype("float32")
        total_no_vertices = deg.sum()
        bin_degrees = paddle.arange(len(deg), dtype="float32")
        self.avg_deg = {
            'lin': ((bin_degrees * deg).sum() / total_no_vertices).item(),
            'log': (((bin_degrees + 1).log() * deg).sum() /
                    total_no_vertices).item(),
            'exp': ((bin_degrees.exp() * deg).sum() /
                    total_no_vertices).item(),
        }

        if use_edge:
            # Projects raw edge features down to the per-tower input size.
            self.edge_mlp = paddle.nn.Linear(input_size, self.F_in)

        self.pre_nns = nn.LayerList()
        self.post_nns = nn.LayerList()
        # The post-aggregation MLP consumes the concatenation of every
        # (aggregator, scaler) combination plus the skip connection
        # (hence the "+ 1").  Use a fresh name instead of clobbering the
        # `input_size` parameter as a loop temporary.
        post_in_size = (len(aggregators) * len(scalers) + 1) * self.F_in
        for _ in range(self.towers):
            # Pre-aggregation MLP: consumes [src, dst(, edge)] features.
            modules = [
                nn.Linear((3 if self.use_edge else 2) * self.F_in, self.F_in)
            ]
            for _ in range(self.pre_layers - 1):
                modules += [nn.ReLU()]
                modules += [nn.Linear(self.F_in, self.F_in)]
            self.pre_nns.append(nn.Sequential(*modules))

            modules = [nn.Linear(post_in_size, self.F_out)]
            for _ in range(post_layers - 1):
                modules += [nn.ReLU()]
                modules += [nn.Linear(self.F_out, self.F_out)]
            self.post_nns.append(nn.Sequential(*modules))

        # Final linear layer mixing the concatenated tower outputs
        # (towers * F_out == hidden_size).
        self.lin = nn.Linear(self.hidden_size, self.hidden_size)

    def _send_attention(self, src_feat, dst_feat, edge_feat):
        """Message function: build per-tower edge messages.

        Concatenates source/destination (and optionally projected edge)
        features along the last axis, then runs each tower's
        pre-aggregation MLP on its own slice.
        """
        if "edge_feat" in edge_feat:
            efeat = self.edge_mlp(edge_feat['edge_feat'])
            # Broadcast the projected edge feature to every tower.
            efeat = efeat.reshape([-1, 1, self.F_in])
            efeat = efeat.tile([1, self.towers, 1])
            h = paddle.concat(
                [src_feat['h'], dst_feat['h'], efeat], axis=-1)
        else:
            h = paddle.concat([src_feat['h'], dst_feat['h']], axis=-1)
        # `mlp` (not `nn`) so the paddle.nn module alias is not shadowed.
        hs = [mlp(h[:, i]) for i, mlp in enumerate(self.pre_nns)]
        return {"h": paddle.stack(hs, axis=1)}

    def _reduce_attention(self, msg):
        """Reduce function: apply every aggregator to incoming messages.

        Returns a tensor of shape
        (num_nodes, towers * len(aggregators) * F_in).
        """
        # NOTE(review): relies on the private `_segment_ids` attribute of
        # pgl's message object — confirm against the pgl version in use.
        outs = [aggr(msg['h'], msg._segment_ids) for aggr in self.aggregators]
        out = paddle.concat(outs, axis=-1)
        return out.reshape([out.shape[0], -1])

    def forward(self, graph, feature, deg, edge_feat=None):
        """Forward pass of PNAConv.

        Args:
            graph: a pgl.Graph instance.
            feature: node feature tensor of shape (num_nodes, input_size).
            deg: in-degree of the input nodes, shape (num_nodes,).
            edge_feat (optional): input edge features; only consumed when
                the layer was constructed with use_edge=True.

        Returns:
            A tensor of shape (num_nodes, hidden_size).
        """
        if self.divide_input:
            # Each tower sees its own slice of the input features.
            feature = feature.reshape([-1, self.towers, self.F_in])
        else:
            # Every tower sees the full input features.
            feature = feature.reshape([-1, 1, self.F_in]).tile(
                [1, self.towers, 1])
        msg = graph.send(
            self._send_attention,
            src_feat={"h": feature},
            dst_feat={"h": feature},
            edge_feat={"edge_feat": edge_feat} if self.use_edge else {})

        out = graph.recv(reduce_func=self._reduce_attention, msg=msg)
        out = out.reshape([out.shape[0], self.towers, -1])
        deg = deg.astype("float32").reshape([-1, 1, 1])
        # Apply every degree scaler, concatenate the results, then append
        # the skip connection to the raw per-tower input features.
        outs = [scaler(out, deg, self.avg_deg) for scaler in self.scalers]
        out = paddle.concat(outs, axis=-1)
        out = paddle.concat([feature, out], axis=-1)
        # One post-aggregation MLP per tower, concatenated feature-wise.
        outs = [mlp(out[:, i]) for i, mlp in enumerate(self.post_nns)]
        out = paddle.concat(outs, axis=1)
        return self.lin(out)


def scale_identity(src, deg, avg_deg):
    """Identity scaler: returns the aggregated features unchanged."""
    return src


def scale_amplification(src, deg, avg_deg):
    """Amplification scaler: multiplies by log(deg + 1) / avg(log(deg + 1))."""
    return src * (paddle.log(deg + 1) / avg_deg['log'])


def scale_attenuation(src, deg, avg_deg):
    """Attenuation scaler: multiplies by avg(log(deg + 1)) / log(deg + 1).

    Degree-0 nodes would divide by log(1) = 0, so their scale is forced
    to 1 instead.
    """
    scale = avg_deg['log'] / paddle.log(deg + 1)
    # ones_like keeps the dtype/shape consistent with `scale`.
    scale = paddle.where(deg == 0, paddle.ones_like(scale), scale)
    return src * scale


def scale_linear(src, deg, avg_deg):
    """Linear scaler: multiplies by deg / avg(deg)."""
    return src * (deg / avg_deg['lin'])


def scale_inverse_linear(src, deg, avg_deg):
    """Inverse-linear scaler: multiplies by avg(deg) / deg.

    Degree-0 nodes would divide by 0, so their scale is forced to 1.
    """
    scale = avg_deg['lin'] / deg
    scale = paddle.where(deg == 0, paddle.ones_like(scale), scale)
    return src * scale


# Dispatch table from scaler keyword to implementation.
SCALERS = {
    'identity': scale_identity,
    'amplification': scale_amplification,
    'attenuation': scale_attenuation,
    'linear': scale_linear,
    'inverse_linear': scale_inverse_linear
}


def aggregate_sum(src, segment_id):
    """Sum aggregator over each destination node's incoming messages."""
    return pgl.math.segment_sum(src, segment_id)


def aggregate_mean(src, segment_id):
    """Mean aggregator over each destination node's incoming messages."""
    return pgl.math.segment_mean(src, segment_id)


def aggregate_max(src, segment_id):
    """Max aggregator over each destination node's incoming messages."""
    return pgl.math.segment_max(src, segment_id)


def aggregate_min(src, segment_id):
    """Min aggregator over each destination node's incoming messages."""
    return pgl.math.segment_min(src, segment_id)


def aggregate_var(src, segment_id):
    """Variance aggregator: E[x^2] - E[x]^2 per destination node."""
    mean = aggregate_mean(src, segment_id)
    mean_squares = aggregate_mean(src * src, segment_id)
    return mean_squares - mean * mean


def aggregate_std(src, segment_id):
    """Std aggregator: sqrt(var + 1e-5).

    The relu clamps tiny negative variances caused by floating-point
    error in E[x^2] - E[x]^2; the epsilon keeps the sqrt gradient finite.
    """
    return paddle.sqrt(
        paddle.nn.functional.relu(aggregate_var(src, segment_id)) + 1e-5)


# Dispatch table from aggregator keyword to implementation.
AGGREGATOR = {
    "sum": aggregate_sum,
    "mean": aggregate_mean,
    "max": aggregate_max,
    "min": aggregate_min,
    "var": aggregate_var,
    "std": aggregate_std
}