2 changes: 2 additions & 0 deletions federatedscope/core/auxiliaries/model_builder.py
@@ -197,6 +197,8 @@ def get_model(model_config, local_data=None, backend='torch'):
elif model_config.type.lower() in ['atc_model']:
from federatedscope.nlp.hetero_tasks.model import ATCModel
model = ATCModel(model_config)
elif model_config.type.lower() in ['nn']:
model = None
else:
raise ValueError('Model {} is not provided'.format(model_config.type))

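For the new 'nn' type, get_model deliberately returns None: the vertical NN trainer introduced in this PR builds its own bottom/top torch modules inside each worker (see _set_bottom_model/_set_top_model in nn_model/trainer/trainer.py), so there is no shared model object to construct here. A minimal stand-in for the dispatch above (hypothetical, simplified sketch, not the real function):

    def get_model_sketch(model_type):
        if model_type.lower() in ['nn']:
            return None  # nnTrainer builds bottom/top modules itself
        raise ValueError('Model {} is not provided'.format(model_type))

    assert get_model_sketch('nn') is None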
8 changes: 8 additions & 0 deletions federatedscope/core/auxiliaries/trainer_builder.py
@@ -176,6 +176,14 @@ def get_trainer(model=None,
data=data,
device=device,
monitor=monitor)
elif config.trainer.type.lower() in ['nntrainer']:
from federatedscope.vertical_fl.nn_model.trainer.trainer \
import nnTrainer
trainer = nnTrainer(config=config,
model=model,
data=data,
device=device,
monitor=monitor)
else:
# try to find user registered trainer
trainer = None
6 changes: 6 additions & 0 deletions federatedscope/core/auxiliaries/worker_builder.py
@@ -65,6 +65,9 @@ def get_client_cls(cfg):
from federatedscope.vertical_fl.tree_based_models.worker \
import TreeClient
return TreeClient
elif cfg.vertical.algo == 'nn':
from federatedscope.vertical_fl.nn_model.worker import nnClient
return nnClient
else:
raise ValueError(f'No client class for {cfg.vertical.algo}')

@@ -180,6 +183,9 @@ def get_server_cls(cfg):
from federatedscope.vertical_fl.tree_based_models.worker \
import TreeServer
return TreeServer
elif cfg.vertical.algo == 'nn':
from federatedscope.vertical_fl.nn_model.worker import nnServer
return nnServer
else:
raise ValueError(f'No server class for {cfg.vertical.algo}')

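Together with the trainer_builder change above, these dispatch points mean the whole vertical NN pipeline is selected purely through the config. A sketch of the four settings involved (mirroring the baseline yaml added later in this PR):

    cfg.model.type = 'nn'           # get_model returns None
    cfg.trainer.type = 'nntrainer'  # get_trainer -> nnTrainer
    cfg.vertical.use = True
    cfg.vertical.algo = 'nn'        # worker_builder -> nnServer / nnClient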
1 change: 1 addition & 0 deletions federatedscope/core/configs/cfg_fl_setting.py
@@ -102,6 +102,7 @@ def extend_fl_setting_cfg(cfg):
cfg.vertical.data_size_for_debug = 0 # use a subset for debug in vfl,
# 0 indicates using the entire dataset (disable debug mode)

cfg.vertical.output_layer = [2, 2]  # output width of each client's bottom model in vertical NN
# --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_fl_setting_cfg)

24 changes: 18 additions & 6 deletions federatedscope/vertical_fl/dataloader/dataloader.py
@@ -52,15 +52,27 @@ def load_vertical_data(config=None, generate=False):
theta = np.random.uniform(low=-1.0, high=1.0, size=(total_dims, 1))
x = np.random.choice([-1.0, 1.0, -2.0, 2.0, -3.0, 3.0],
size=(INSTANCE_NUM, total_dims))
y = np.asarray([
1.0 if x >= 0 else -1.0
for x in np.reshape(np.matmul(x, theta), -1)
])
if config.vertical.algo == 'nn':
y = np.asarray([
1.0 if x >= 0 else 0
for x in np.reshape(np.matmul(x, theta), -1)
])
else:
y = np.asarray([
1.0 if x >= 0 else -1.0
for x in np.reshape(np.matmul(x, theta), -1)
])

train_num = int(TRAIN_SPLIT * INSTANCE_NUM)
test_data = {'theta': theta, 'x': x[train_num:], 'y': y[train_num:]}
data = dict()

test_data_1 = test_data.copy()
test_data_1['x'] = x[train_num:, :config.vertical.dims[0]]
test_data_1['y'] = None
test_data_2 = test_data.copy()
test_data_2['x'] = x[train_num:, config.vertical.dims[0]:]

# For Server
data[0] = dict()
data[0]['train'] = None
@@ -71,7 +83,7 @@
data[1] = dict()
data[1]['train'] = {'x': x[:train_num, :config.vertical.dims[0]]}
data[1]['val'] = None
data[1]['test'] = test_data
data[1]['test'] = test_data_1

# For Client #2
data[2] = dict()
@@ -80,6 +92,6 @@
'y': y[:train_num]
}
data[2]['val'] = None
data[2]['test'] = test_data
data[2]['test'] = test_data_2

return data, config
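The new branch above switches the synthetic label encoding per algorithm: 'nn' gets {0, 1} targets (matching the sigmoid head and MSE loss of the new trainer), while the other vertical algorithms keep {-1, +1}. A tiny numpy sketch of the two conventions:

    import numpy as np

    scores = np.array([-0.3, 1.2, 0.0, -2.1])    # stand-ins for x @ theta
    y_nn = np.where(scores >= 0, 1.0, 0.0)       # [0., 1., 1., 0.]
    y_linear = np.where(scores >= 0, 1.0, -1.0)  # [-1., 1., 1., -1.]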
14 changes: 11 additions & 3 deletions federatedscope/vertical_fl/linear_model/worker/vertical_server.py
@@ -5,6 +5,7 @@
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import abstract_paillier
from federatedscope.core.auxiliaries.model_builder import get_model
from sklearn import metrics

logger = logging.getLogger(__name__)

@@ -123,6 +124,13 @@ def evaluate(self):
test_y = self.data['test']['y']
loss = np.mean(
np.log(1 + np.exp(-test_y * np.matmul(test_x, self.theta))))
acc = np.mean((test_y * np.matmul(test_x, self.theta)) > 0)

return {'test_loss': loss, 'test_acc': acc, 'test_total': len(test_y)}
y_hat = np.matmul(test_x, self.theta)
auc = metrics.roc_auc_score(test_y, y_hat)
acc = np.mean((test_y * y_hat) > 0)

return {
'test_loss': loss,
'test_acc': acc,
'test_auc': auc,
'test_total': len(test_y)
}
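The reworked evaluate adds AUC next to loss and accuracy; with labels in {-1, +1}, accuracy is sign agreement between label and raw score, and sklearn's roc_auc_score consumes the raw scores directly. A toy check of both formulas (values invented for illustration):

    import numpy as np
    from sklearn import metrics

    test_y = np.array([1.0, -1.0, 1.0, -1.0])   # labels in {-1, +1}
    y_hat = np.array([0.8, -0.5, -0.1, -1.2])   # raw scores x @ theta

    acc = np.mean((test_y * y_hat) > 0)         # 0.75: one positive scored below 0
    auc = metrics.roc_auc_score(test_y, y_hat)  # 1.0: the ranking is still perfect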
31 changes: 31 additions & 0 deletions federatedscope/vertical_fl/nn_model/baseline/vertical_nn.yaml
@@ -0,0 +1,31 @@
use_gpu: False
federate:
  mode: standalone
  client_num: 2
  total_round_num: 30
model:
  type: nn
  use_bias: False
train:
  optimizer:
    lr: 0.7
data:
  root: data/
  type: synthetic_vfl_data
criterion:
  type: CrossEntropyLoss
dataloader:
  type: raw
  batch_size: 50
vertical:
  use: True
  dims: [5, 10]
  output_layer: [2, 2]
  algo: 'nn'
  protect_method: 'dp'
  protect_args: [{'algo': 'Bernoulli', 'para': 0.1}] # [{'algo': 'Laplace', 'para': 0.01}]
trainer:
  type: nntrainer
eval:
  freq: 5
  best_res_update_round_wise_key: test_loss
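Read together with the trainer code below: dims: [5, 10] are cumulative feature boundaries (client 1 holds columns 0:5, client 2 columns 5:10), and output_layer: [2, 2] sets each client's bottom-model width, so the label owner's top model sees sum(output_layer) = 4 inputs. A sketch of the resulting split architecture, assuming the Linear+ReLU / Linear+Sigmoid blocks from _set_bottom_model/_set_top_model:

    from torch import nn

    dims = [5, 10]          # cumulative split: client 1 -> x[:, :5], client 2 -> x[:, 5:10]
    output_layer = [2, 2]   # bottom-model output width per client

    bottom_1 = nn.Sequential(nn.Linear(dims[0], output_layer[0]), nn.ReLU())
    bottom_2 = nn.Sequential(nn.Linear(dims[1] - dims[0], output_layer[1]), nn.ReLU())
    top = nn.Sequential(nn.Linear(sum(output_layer), 1), nn.Sigmoid())

The example can presumably be launched with python federatedscope/main.py --cfg federatedscope/vertical_fl/nn_model/baseline/vertical_nn.yaml, following the repository's usual entry point.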
3 changes: 3 additions & 0 deletions federatedscope/vertical_fl/nn_model/trainer/__init__.py
@@ -0,0 +1,3 @@
from federatedscope.vertical_fl.nn_model.trainer.trainer import nnTrainer

__all__ = ['nnTrainer']
145 changes: 145 additions & 0 deletions federatedscope/vertical_fl/nn_model/trainer/trainer.py
@@ -0,0 +1,145 @@
import copy

import torch
from torch import nn, optim
import numpy as np

from federatedscope.vertical_fl.dataloader.utils import batch_iter


class nnTrainer(object):
def __init__(self, model, data, device, config, monitor):
self.model = model
self.data = data
self.device = device
self.cfg = config
self.monitor = monitor

self.batch_x = None
self.batch_y = None
self.batch_y_hat = None
self.bottom_model = None
self.top_model = None

        self.criterion = nn.MSELoss()  # hard-coded; the cfg 'criterion' setting is not read here

        # offsets of each client's chunk within the top model's input gradient
        self.grad_partition = [0] + config.vertical.output_layer

def sample_data(self, index=None):
if index is None:
return next(self.dataloader)
else:
return self.data['train']['x'][index]

def _init_for_train(self,
bottom_input_layer,
bottom_output_layer,
top_input_layer=None):
self.lr = self.cfg.train.optimizer.lr
self.dataloader = batch_iter(self.data['train'],
self.cfg.dataloader.batch_size,
shuffled=True)
self._set_bottom_model(bottom_input_layer, bottom_output_layer)
self.bottom_model_opt = optim.SGD(self.bottom_model.parameters(),
lr=self.lr)
self.bottom_model_opt.zero_grad()
if top_input_layer:
self._set_top_model(top_input_layer)
self.top_model_opt = optim.SGD(self.top_model.parameters(),
lr=self.lr)
self.top_model_opt.zero_grad()

def fetch_train_data(self, index=None):
# Fetch new data
self.bottom_model_opt.zero_grad()
if self.top_model:
self.top_model_opt.zero_grad()
        if index is None:
batch_index, self.batch_x, self.batch_y = self.sample_data(index)
# convert 'range' to 'list'
# to support gRPC protocols in distributed mode
batch_index = list(batch_index)
else:
self.batch_x = self.sample_data(index)
batch_index = 'None'

self.batch_x = torch.Tensor(self.batch_x)
if self.batch_y is not None:
self.batch_y = np.vstack(self.batch_y).reshape(-1, 1)
self.batch_y = torch.Tensor(self.batch_y)

return batch_index

def train_bottom(self):
self.middle_result = self.bottom_model(self.batch_x)
        # send a detached copy; gradients flow back via bottom_model_backward()
        middle_result = self.middle_result.data
return middle_result

def train_top(self, input_):

train_loss, grad = self.protect_grad(input_)
self.top_model_opt.step()

        grad_list = []
        # split the gradient w.r.t. the top model's input into per-client
        # chunks at the cumulative bottom-model output widths
        boundaries = np.cumsum(self.grad_partition)
        for i in range(len(boundaries) - 1):
            grad_list.append(grad[:, boundaries[i]:boundaries[i + 1]])
my_grad = grad_list[-1]
self.bottom_model_backward(my_grad)

return train_loss, grad_list[:-1]

def bottom_model_backward(self, grad=None):
self.middle_result.backward(grad)
self.bottom_model_opt.step()

    def _set_bottom_model(self, input_layer, output_layer):
        self.bottom_input_layer = input_layer
        self.bottom_output_layer = output_layer
        self.bottom_model = nn.Sequential(
            nn.Linear(input_layer, output_layer), nn.ReLU())

    def _set_top_model(self, input_layer, output_layer=1):
        self.top_input_layer = input_layer
        self.top_model = nn.Sequential(nn.Linear(input_layer, output_layer),
                                       nn.Sigmoid())

def protect_grad(self, input_):

fake_grad = 0
algo = None
para = None
# The following protect method is proposed in
# "Differentially Private Label Protection in Split Learning"
if self.cfg.vertical.protect_method == 'dp':
args = self.cfg.vertical.protect_args[0] if len(
self.cfg.vertical.protect_args) > 0 else {}
algo = args.get('algo', 'Laplace')
para = args.get('para', 1)
repeat_model = copy.deepcopy(self.top_model)
repeat_input = input_.detach().requires_grad_()
repeat_y_hat = repeat_model(repeat_input)
fake_y = torch.Tensor(
np.vstack(np.ones(len(self.batch_y))).reshape(
-1, 1)) - self.batch_y
fake_train_loss = self.criterion(repeat_y_hat, fake_y)
fake_train_loss.backward()
fake_grad = repeat_input.grad

y_hat = self.top_model(input_)
train_loss = self.criterion(y_hat, self.batch_y)
train_loss.backward()
grad = input_.grad

        if algo == 'Laplace':
            # 'para' is the noise scale; np.random.laplace's first positional
            # argument is loc, so pass it by keyword
            u = np.random.laplace(scale=para)
elif algo == 'Bernoulli':
u = np.random.binomial(1, para)
else:
u = 0

grad = grad + u * (fake_grad - grad)

return train_loss, grad
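protect_grad mixes the true gradient with one computed from flipped labels, grad' = grad + u * (fake_grad - grad): under the Bernoulli mechanism, the bottom-model owner receives the flipped-label gradient with probability para, which is what hides the true label. A toy numpy illustration (values invented):

    import numpy as np

    rng = np.random.default_rng(0)
    true_grad = np.array([0.2, -0.4])
    fake_grad = np.array([-0.2, 0.4])   # gradient under flipped labels

    u = rng.binomial(1, 0.1)            # 1 with probability para = 0.1
    protected = true_grad + u * (fake_grad - true_grad)
    # u == 0 -> true gradient is sent; u == 1 -> flipped-label gradient is sent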
6 changes: 6 additions & 0 deletions federatedscope/vertical_fl/nn_model/worker/__init__.py
@@ -0,0 +1,6 @@
from federatedscope.vertical_fl.nn_model.worker.nn_client import nnClient
from federatedscope.vertical_fl.nn_model.worker.nn_server import nnServer

__all__ = ['nnServer', 'nnClient']