diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index a1d5800c4..6f067f794 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -197,6 +197,8 @@ def get_model(model_config, local_data=None, backend='torch'): elif model_config.type.lower() in ['atc_model']: from federatedscope.nlp.hetero_tasks.model import ATCModel model = ATCModel(model_config) + elif model_config.type.lower() in ['nn']: + model = None else: raise ValueError('Model {} is not provided'.format(model_config.type)) diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index b32baf74e..24d538eae 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -176,6 +176,14 @@ def get_trainer(model=None, data=data, device=device, monitor=monitor) + elif config.trainer.type.lower() in ['nntrainer']: + from federatedscope.vertical_fl.nn_model.trainer.trainer \ + import nnTrainer + trainer = nnTrainer(config=config, + model=model, + data=data, + device=device, + monitor=monitor) else: # try to find user registered trainer trainer = None diff --git a/federatedscope/core/auxiliaries/worker_builder.py b/federatedscope/core/auxiliaries/worker_builder.py index 49fd30631..75a415032 100644 --- a/federatedscope/core/auxiliaries/worker_builder.py +++ b/federatedscope/core/auxiliaries/worker_builder.py @@ -65,6 +65,9 @@ def get_client_cls(cfg): from federatedscope.vertical_fl.tree_based_models.worker \ import TreeClient return TreeClient + elif cfg.vertical.algo == 'nn': + from federatedscope.vertical_fl.nn_model.worker import nnClient + return nnClient else: raise ValueError(f'No client class for {cfg.vertical.algo}') @@ -180,6 +183,9 @@ def get_server_cls(cfg): from federatedscope.vertical_fl.tree_based_models.worker \ import TreeServer return TreeServer + elif cfg.vertical.algo == 'nn': + from federatedscope.vertical_fl.nn_model.worker import nnServer + return nnServer else: raise ValueError(f'No server class for {cfg.vertical.algo}') diff --git a/federatedscope/core/configs/cfg_fl_setting.py b/federatedscope/core/configs/cfg_fl_setting.py index 7c3a62bf5..164791e00 100644 --- a/federatedscope/core/configs/cfg_fl_setting.py +++ b/federatedscope/core/configs/cfg_fl_setting.py @@ -102,6 +102,7 @@ def extend_fl_setting_cfg(cfg): cfg.vertical.data_size_for_debug = 0 # use a subset for debug in vfl, # 0 indicates using the entire dataset (disable debug mode) + cfg.vertical.output_layer = [2, 2] # --------------- register corresponding check function ---------- cfg.register_cfg_check_fun(assert_fl_setting_cfg) diff --git a/federatedscope/vertical_fl/dataloader/dataloader.py b/federatedscope/vertical_fl/dataloader/dataloader.py index 09ad26187..9effe3632 100644 --- a/federatedscope/vertical_fl/dataloader/dataloader.py +++ b/federatedscope/vertical_fl/dataloader/dataloader.py @@ -52,15 +52,27 @@ def load_vertical_data(config=None, generate=False): theta = np.random.uniform(low=-1.0, high=1.0, size=(total_dims, 1)) x = np.random.choice([-1.0, 1.0, -2.0, 2.0, -3.0, 3.0], size=(INSTANCE_NUM, total_dims)) - y = np.asarray([ - 1.0 if x >= 0 else -1.0 - for x in np.reshape(np.matmul(x, theta), -1) - ]) + if config.vertical.algo == 'nn': + y = np.asarray([ + 1.0 if x >= 0 else 0 + for x in np.reshape(np.matmul(x, theta), -1) + ]) + else: + y = np.asarray([ + 1.0 if x >= 0 else -1.0 + for x in 
np.reshape(np.matmul(x, theta), -1) + ]) train_num = int(TRAIN_SPLIT * INSTANCE_NUM) test_data = {'theta': theta, 'x': x[train_num:], 'y': y[train_num:]} data = dict() + test_data_1 = test_data.copy() + test_data_1['x'] = x[train_num:, :config.vertical.dims[0]] + test_data_1['y'] = None + test_data_2 = test_data.copy() + test_data_2['x'] = x[train_num:, config.vertical.dims[0]:] + # For Server data[0] = dict() data[0]['train'] = None @@ -71,7 +83,7 @@ def load_vertical_data(config=None, generate=False): data[1] = dict() data[1]['train'] = {'x': x[:train_num, :config.vertical.dims[0]]} data[1]['val'] = None - data[1]['test'] = test_data + data[1]['test'] = test_data_1 # For Client #2 data[2] = dict() @@ -80,6 +92,6 @@ def load_vertical_data(config=None, generate=False): 'y': y[:train_num] } data[2]['val'] = None - data[2]['test'] = test_data + data[2]['test'] = test_data_2 return data, config diff --git a/federatedscope/vertical_fl/linear_model/worker/vertical_server.py b/federatedscope/vertical_fl/linear_model/worker/vertical_server.py index 2fd34faf7..36a178535 100644 --- a/federatedscope/vertical_fl/linear_model/worker/vertical_server.py +++ b/federatedscope/vertical_fl/linear_model/worker/vertical_server.py @@ -5,6 +5,7 @@ from federatedscope.core.message import Message from federatedscope.vertical_fl.Paillier import abstract_paillier from federatedscope.core.auxiliaries.model_builder import get_model +from sklearn import metrics logger = logging.getLogger(__name__) @@ -123,6 +124,13 @@ def evaluate(self): test_y = self.data['test']['y'] loss = np.mean( np.log(1 + np.exp(-test_y * np.matmul(test_x, self.theta)))) - acc = np.mean((test_y * np.matmul(test_x, self.theta)) > 0) - - return {'test_loss': loss, 'test_acc': acc, 'test_total': len(test_y)} + y_hat = np.matmul(test_x, self.theta) + auc = metrics.roc_auc_score(test_y, y_hat) + acc = np.mean((test_y * y_hat) > 0) + + return { + 'test_loss': loss, + 'test_acc': acc, + 'test_auc': auc, + 'test_total': len(test_y) + } diff --git a/federatedscope/vertical_fl/nn_model/baseline/vertical_nn.yaml b/federatedscope/vertical_fl/nn_model/baseline/vertical_nn.yaml new file mode 100644 index 000000000..8402ff6eb --- /dev/null +++ b/federatedscope/vertical_fl/nn_model/baseline/vertical_nn.yaml @@ -0,0 +1,31 @@ +use_gpu: False +federate: + mode: standalone + client_num: 2 + total_round_num: 30 +model: + type: nn + use_bias: False +train: + optimizer: + lr: 0.7 +data: + root: data/ + type: synthetic_vfl_data +criterion: + type: CrossEntropyLoss +dataloader: + type: raw + batch_size: 50 +vertical: + use: True + dims: [5, 10] + output_layer: [2, 2] + algo: 'nn' + protect_method: 'dp' + protect_args: [{'algo': 'Bernoulli', 'para': 0.1}] # [{'algo': 'Laplace', 'para': 0.01}] +trainer: + type: nntrainer +eval: + freq: 5 + best_res_update_round_wise_key: test_loss diff --git a/federatedscope/vertical_fl/nn_model/trainer/__init__.py b/federatedscope/vertical_fl/nn_model/trainer/__init__.py new file mode 100644 index 000000000..d5b957a66 --- /dev/null +++ b/federatedscope/vertical_fl/nn_model/trainer/__init__.py @@ -0,0 +1,3 @@ +from federatedscope.vertical_fl.nn_model.trainer.trainer import nnTrainer + +__all__ = ['nnTrainer'] diff --git a/federatedscope/vertical_fl/nn_model/trainer/trainer.py b/federatedscope/vertical_fl/nn_model/trainer/trainer.py new file mode 100644 index 000000000..9f9a8537c --- /dev/null +++ b/federatedscope/vertical_fl/nn_model/trainer/trainer.py @@ -0,0 +1,145 @@ +import copy + +import torch +from torch import nn, optim +from 
torch.autograd import Variable +import numpy as np + +from federatedscope.vertical_fl.dataloader.utils import batch_iter + + +class nnTrainer(object): + def __init__(self, model, data, device, config, monitor): + self.model = model + self.data = data + self.device = device + self.cfg = config + self.monitor = monitor + + self.batch_x = None + self.batch_y = None + self.batch_y_hat = None + self.bottom_model = None + self.top_model = None + + self.criterion = nn.MSELoss() + + self.grad_partition = [0] + config.vertical.output_layer + + def sample_data(self, index=None): + if index is None: + return next(self.dataloader) + else: + return self.data['train']['x'][index] + + def _init_for_train(self, + bottom_input_layer, + bottom_output_layer, + top_input_layer=None): + self.lr = self.cfg.train.optimizer.lr + self.dataloader = batch_iter(self.data['train'], + self.cfg.dataloader.batch_size, + shuffled=True) + self._set_bottom_model(bottom_input_layer, bottom_output_layer) + self.bottom_model_opt = optim.SGD(self.bottom_model.parameters(), + lr=self.lr) + self.bottom_model_opt.zero_grad() + if top_input_layer: + self._set_top_model(top_input_layer) + self.top_model_opt = optim.SGD(self.top_model.parameters(), + lr=self.lr) + self.top_model_opt.zero_grad() + + def fetch_train_data(self, index=None): + # Fetch new data + self.bottom_model_opt.zero_grad() + if self.top_model: + self.top_model_opt.zero_grad() + if not index: + batch_index, self.batch_x, self.batch_y = self.sample_data(index) + # convert 'range' to 'list' + # to support gRPC protocols in distributed mode + batch_index = list(batch_index) + else: + self.batch_x = self.sample_data(index) + batch_index = 'None' + + self.batch_x = torch.Tensor(self.batch_x) + if self.batch_y is not None: + self.batch_y = np.vstack(self.batch_y).reshape(-1, 1) + self.batch_y = torch.Tensor(self.batch_y) + + return batch_index + + def train_bottom(self): + self.middle_result = self.bottom_model(self.batch_x) + middle_result = self.middle_result.data # detach().requires_grad_() + return middle_result + + def train_top(self, input_): + + train_loss, grad = self.protect_grad(input_) + self.top_model_opt.step() + + grad_list = [] + for i in range(len(self.grad_partition) - 1): + grad_list.append( + grad[:, self.grad_partition[i]:self.grad_partition[i] + + self.grad_partition[i + 1]]) + my_grad = grad_list[-1] + self.bottom_model_backward(my_grad) + + return train_loss, grad_list[:-1] + + def bottom_model_backward(self, grad=None): + self.middle_result.backward(grad) + self.bottom_model_opt.step() + + def _set_bottom_model(self, input_layer, out_put_layer): + self.bottom_input_layer = input_layer + self.bottom_output_layer = out_put_layer + self.bottom_model = nn.Sequential( + nn.Linear(input_layer, out_put_layer), nn.ReLU()) + + def _set_top_model(self, input_layer, out_put_layer=1): + self.top_input_layer = input_layer + self.top_model = nn.Sequential(nn.Linear(input_layer, out_put_layer), + nn.Sigmoid()) + + def protect_grad(self, input_): + + fake_grad = 0 + algo = None + para = None + # The following protect method is proposed in + # "Differentially Private Label Protection in Split Learning" + if self.cfg.vertical.protect_method == 'dp': + args = self.cfg.vertical.protect_args[0] if len( + self.cfg.vertical.protect_args) > 0 else {} + algo = args.get('algo', 'Laplace') + para = args.get('para', 1) + repeat_model = copy.deepcopy(self.top_model) + repeat_input = input_.detach().requires_grad_() + repeat_y_hat = repeat_model(repeat_input) + fake_y = 
torch.Tensor( + np.vstack(np.ones(len(self.batch_y))).reshape( + -1, 1)) - self.batch_y + fake_train_loss = self.criterion(repeat_y_hat, fake_y) + fake_train_loss.backward() + fake_grad = repeat_input.grad + + y_hat = self.top_model(input_) + train_loss = self.criterion(y_hat, self.batch_y) + train_loss.backward() + grad = input_.grad + + if algo == 'Laplace': + u = np.random.laplace(para) + elif algo == 'Bernoulli': + u = np.random.binomial(1, para) + else: + u = 0 + + grad = grad + u * (fake_grad - grad) + + return train_loss, grad diff --git a/federatedscope/vertical_fl/nn_model/worker/__init__.py b/federatedscope/vertical_fl/nn_model/worker/__init__.py new file mode 100644 index 000000000..7f77dfa42 --- /dev/null +++ b/federatedscope/vertical_fl/nn_model/worker/__init__.py @@ -0,0 +1,6 @@ +from federatedscope.vertical_fl.nn_model.worker.nn_client import\ + nnClient +from federatedscope.vertical_fl.nn_model.worker.nn_server import\ + nnServer + +__all__ = ['nnServer', 'nnClient'] diff --git a/federatedscope/vertical_fl/nn_model/worker/nn_client.py b/federatedscope/vertical_fl/nn_model/worker/nn_client.py new file mode 100644 index 000000000..51c479dcf --- /dev/null +++ b/federatedscope/vertical_fl/nn_model/worker/nn_client.py @@ -0,0 +1,233 @@ +import numpy as np +import logging + +import torch +from torch import nn +from torch.autograd import Variable +from sklearn import metrics + +from federatedscope.core.workers import Client +from federatedscope.core.message import Message + +logger = logging.getLogger(__name__) + + +class nnClient(Client): + def __init__(self, + ID=-1, + server_id=None, + state=-1, + config=None, + data=None, + model=None, + device='cpu', + strategy=None, + *args, + **kwargs): + + super(nnClient, + self).__init__(ID, server_id, state, config, data, model, device, + strategy, *args, **kwargs) + + self.middle_result_dict = dict() + self.data = data + self.client_num = config.federate.client_num + self.own_label = ('y' in data['train']) + self.dims = [0] + config.vertical.dims + self.bottom_input_layer = self.dims[self.ID] - self.dims[self.ID - 1] + self.bottom_output_layer = config.vertical.output_layer[self.ID - 1] + self.top_input_layer = None + if self.own_label: + self.top_input_layer = np.sum(config.vertical.output_layer) + self.eval_middle_result_dict = dict() + + self._init_data_related_var() + + self.register_handlers('model_para', + self.callback_funcs_for_model_para) + self.register_handlers('data_sample', + self.callback_func_for_data_sample) + self.register_handlers('middle_result', + self.callback_func_for_middle_result) + self.register_handlers('grad', self.callback_func_for_grad) + self.register_handlers('continue', self.callback_funcs_for_continue) + self.register_handlers('eval', self.callback_func_for_eval) + self.register_handlers('eval_middle_result', + self.callback_func_for_eval_middle_result) + + # def train(self): + # raise NotImplementedError + # + # def eval(self): + # raise NotImplementedError + + def _init_data_related_var(self): + self.trainer._init_for_train(self.bottom_input_layer, + self.bottom_output_layer, + self.top_input_layer) + self.test_x = torch.Tensor(self.data['test']['x']) + + if self.own_label: + self.test_y = np.vstack(self.data['test']['y']).reshape(-1, 1) + self.test_y = torch.Tensor(self.test_y) + + def callback_funcs_for_model_para(self, message: Message): + self.state = message.state + + if self.own_label: + self.start_a_new_training_round() + + def start_a_new_training_round(self): + logger.info(f'----------- 
Starting a new round (Round ' + f'#{self.state}) -------------') + batch_index = self.trainer.fetch_train_data() + + receiver = [ + each for each in list(self.comm_manager.neighbors.keys()) + if each != self.server_id + ] + send_message = Message(msg_type='data_sample', + sender=self.ID, + state=self.state, + receiver=receiver, + content=batch_index) + self.comm_manager.send(send_message) + self.train() + + def callback_func_for_data_sample(self, message: Message): + self.state = message.state + batch_index = message.content + _ = self.trainer.fetch_train_data(index=batch_index) + self.train() + + def train(self): + middle_result = self.trainer.train_bottom() + if self.own_label: + self.middle_result_dict[self.ID] = middle_result + else: + self.comm_manager.send( + Message(msg_type='middle_result', + sender=self.ID, + receiver=[self.client_num], + state=self.state, + content=middle_result)) + + def callback_func_for_middle_result(self, message: Message): + middle_result = message.content + self.middle_result_dict[message.sender] = middle_result + if len(self.middle_result_dict) == self.client_num: + client_ids = list(self.middle_result_dict.keys()) + client_ids = sorted(client_ids) + + middle_result = torch.cat( + [self.middle_result_dict[i] for i in client_ids], 1) + self.middle_result_dict = dict() + + middle_result = Variable(middle_result, requires_grad=True) + + train_loss, grad_list = self.trainer.train_top(middle_result) + # print(train_loss) + for i in range(self.client_num - 1): + self.comm_manager.send( + Message(msg_type='grad', + sender=self.ID, + receiver=[i + 1], + state=self.state, + content=grad_list[i])) + + def callback_func_for_grad(self, message: Message): + grad = message.content + self.trainer.bottom_model_backward(grad) + + self.comm_manager.send( + Message(msg_type='continue', + sender=self.ID, + receiver=[self.client_num], + state=self.state, + content='None')) + + def callback_funcs_for_continue(self, message: Message): + + if (self.state+1) % self._cfg.eval.freq == 0 and \ + (self.state+1) != self._cfg.federate.total_round_num: + self.eval_middle_result_dict[self.ID] = self.trainer.bottom_model( + self.test_x) + self.comm_manager.send( + Message(msg_type='eval', + sender=self.ID, + receiver=[ + each for each in self.comm_manager.neighbors + if each != self.server_id + ], + state=self.state, + content='None')) + + elif self.state + 1 < self._cfg.federate.total_round_num: + self.state += 1 + self.start_a_new_training_round() + else: + self.comm_manager.send( + Message(msg_type='eval', + sender=self.ID, + receiver=[ + each for each in self.comm_manager.neighbors + if each != self.server_id + ], + state=self.state, + content='None')) + self.eval_middle_result_dict[self.ID] = self.trainer.bottom_model( + self.test_x) + + def callback_func_for_eval(self, message: Message): + eval_middle_result = self.trainer.bottom_model(self.test_x) + self.comm_manager.send( + Message(msg_type='eval_middle_result', + sender=self.ID, + receiver=[message.sender], + state=self.state, + content=eval_middle_result)) + + def callback_func_for_eval_middle_result(self, message: Message): + eval_middle_result = message.content + self.eval_middle_result_dict[message.sender] = eval_middle_result + if len(self.eval_middle_result_dict) == self.client_num: + client_ids = list(self.eval_middle_result_dict.keys()) + client_ids = sorted(client_ids) + + eval_middle_result = torch.cat( + [self.eval_middle_result_dict[i] for i in client_ids], 1) + self.eval_middle_result_dict = dict() + + y_hat = 
self.trainer.top_model(eval_middle_result) + + test_loss = self.trainer.criterion(y_hat, self.test_y) + + auc = metrics.roc_auc_score( + self.test_y.reshape(-1).detach().numpy(), + y_hat.reshape(-1).detach().numpy()) + y_hat = (y_hat >= 0.5) + + acc = torch.sum(y_hat == self.test_y) / len(self.test_y) + + self.metrics = { + 'test_loss': test_loss.detach().numpy(), + "test_auc": auc, + "test_acc": acc.numpy(), + 'test_total': len(self.test_y) + } + + self._monitor.update_best_result(self.best_results, + self.metrics, + results_type='server_global_eval') + + formatted_logs = self._monitor.format_eval_res( + self.metrics, + rnd=self.state, + role='Server #', + forms=self._cfg.eval.report) + + logger.info(formatted_logs) + + if self.state + 1 < self._cfg.federate.total_round_num: + self.state += 1 + self.start_a_new_training_round() diff --git a/federatedscope/vertical_fl/nn_model/worker/nn_server.py b/federatedscope/vertical_fl/nn_model/worker/nn_server.py new file mode 100644 index 000000000..44d34f3d1 --- /dev/null +++ b/federatedscope/vertical_fl/nn_model/worker/nn_server.py @@ -0,0 +1,46 @@ +import numpy as np +import logging + +import torch + +from federatedscope.core.workers import Server +from federatedscope.core.message import Message +from federatedscope.vertical_fl.Paillier import abstract_paillier +from federatedscope.core.auxiliaries.model_builder import get_model + +logger = logging.getLogger(__name__) + + +class nnServer(Server): + def __init__(self, + ID=-1, + state=0, + config=None, + data=None, + model=None, + client_num=5, + total_round_num=10, + device='cpu', + strategy=None, + **kwargs): + super(nnServer, + self).__init__(ID, state, config, data, model, client_num, + total_round_num, device, strategy, **kwargs) + self.model_dict = dict() + cfg_key_size = config.vertical.key_size + self.public_key, self.private_key = \ + abstract_paillier.generate_paillier_keypair(n_length=cfg_key_size) + self.vertical_dims = config.vertical.dims + + def trigger_for_start(self): + if self.check_client_join_in(): + self.broadcast_client_address() + self.trigger_for_feat_engr(self.broadcast_model_para) + + def broadcast_model_para(self): + self.comm_manager.send( + Message(msg_type='model_para', + sender=self.ID, + receiver=[each for each in self.comm_manager.neighbors], + state=self.state, + content='None'))
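
Note on the training protocol implemented above: each client owns a bottom model (Linear + ReLU) over its own feature slice, and the label-owning client (the last one) additionally owns the top model (Linear + Sigmoid) whose input width is the sum of cfg.vertical.output_layer. In every round the label owner samples a batch index, the other clients send their bottom-model outputs, the label owner concatenates them and trains the top model, then returns to each client the slice of the gradient w.r.t. the concatenated embedding that corresponds to that client's output widths; each client backpropagates its slice through its own bottom model. Below is a minimal, self-contained sketch of one such loop for the two-party case of the example config (dims [5, 10], output_layer [2, 2], lr 0.7, MSE loss as hard-coded in nnTrainer); it is illustrative only and not code from the patch:

    import torch
    from torch import nn, optim

    torch.manual_seed(0)
    n, d_a, d_b, emb = 50, 5, 10, 2
    x_a, x_b = torch.randn(n, d_a), torch.randn(n, d_b)  # party A / party B features
    y = torch.randint(0, 2, (n, 1)).float()              # labels live at party B only

    bottom_a = nn.Sequential(nn.Linear(d_a, emb), nn.ReLU())
    bottom_b = nn.Sequential(nn.Linear(d_b, emb), nn.ReLU())
    top = nn.Sequential(nn.Linear(2 * emb, 1), nn.Sigmoid())     # held by party B
    opt_a = optim.SGD(bottom_a.parameters(), lr=0.7)
    opt_b = optim.SGD(list(bottom_b.parameters()) + list(top.parameters()), lr=0.7)
    criterion = nn.MSELoss()

    for _ in range(30):
        opt_a.zero_grad()
        opt_b.zero_grad()
        # 1. Each party runs its bottom model; A ships a detached copy of its
        #    embedding to B (the role of the 'middle_result' message).
        h_a, h_b = bottom_a(x_a), bottom_b(x_b)
        joined = torch.cat([h_a.detach(), h_b.detach()], dim=1).requires_grad_()
        # 2. B trains the top model on the concatenated embeddings.
        loss = criterion(top(joined), y)
        loss.backward()
        # 3. B splits the gradient w.r.t. the joined embedding by output_layer
        #    widths and returns A's slice (the 'grad' message); each party then
        #    backpropagates through its own bottom model.
        grad_a, grad_b = joined.grad[:, :emb], joined.grad[:, emb:]
        h_a.backward(grad_a)
        h_b.backward(grad_b)
        opt_a.step()
        opt_b.step()

    print(f'final train loss: {loss.item():.4f}')

With the builder hooks registered above (model type 'nn', trainer 'nntrainer', vertical.algo 'nn'), the end-to-end pipeline should be runnable through the usual entry point, e.g. python federatedscope/main.py --cfg federatedscope/vertical_fl/nn_model/baseline/vertical_nn.yaml (standalone mode, two clients, 30 rounds, evaluation every 5 rounds).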
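
The protect_grad routine guards the labels held by the top-model owner: before returning gradients to the other parties it also computes, on a throw-away copy of the top model, the gradient that flipped labels (1 - y) would have produced, and returns grad + u * (fake_grad - grad), where u is drawn once per batch from a Laplace or Bernoulli distribution (following "Differentially Private Label Protection in Split Learning"). A minimal sketch of that mixing step, with illustrative names only (protected_grad is not a function in the patch):

    import copy
    import numpy as np
    import torch
    from torch import nn

    def protected_grad(top_model, embedding, y, algo='Bernoulli', para=0.1):
        criterion = nn.MSELoss()

        # Gradient under flipped labels, computed on a deep copy so the fake
        # backward pass does not touch the real top model's parameter grads.
        fake_model = copy.deepcopy(top_model)
        fake_in = embedding.detach().requires_grad_()
        criterion(fake_model(fake_in), 1.0 - y).backward()
        fake_grad = fake_in.grad

        # True gradient; this backward also accumulates the real top model's
        # parameter grads, so the caller can step its optimizer afterwards.
        real_in = embedding.detach().requires_grad_()
        loss = criterion(top_model(real_in), y)
        loss.backward()
        grad = real_in.grad

        # Randomly move the returned gradient toward the fake one
        # (para is passed positionally, as in the patch).
        if algo == 'Laplace':
            u = np.random.laplace(para)
        elif algo == 'Bernoulli':
            u = np.random.binomial(1, para)
        else:
            u = 0.0
        return loss, grad + u * (fake_grad - grad)

    # Example use with a tiny top model and a batch of 8 four-dim embeddings:
    top = nn.Sequential(nn.Linear(4, 1), nn.Sigmoid())
    emb = torch.randn(8, 4)
    y = torch.randint(0, 2, (8, 1)).float()
    loss, g = protected_grad(top, emb, y)
    print(loss.item(), g.shape)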