Skip to content

Commit a6444b0

Browse files
committed
multiple collaborators per certificate
1 parent e6f3f5f commit a6444b0

File tree

5 files changed

+16
-4
lines changed

5 files changed

+16
-4
lines changed

openfl/component/aggregator/aggregator.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ def __init__(self,
4040
aggregator_uuid,
4141
federation_uuid,
4242
authorized_cols,
43+
cn_mapping,
4344

4445
init_state_path,
4546
best_state_path,
@@ -75,6 +76,7 @@ def __init__(self,
7576

7677
# if the collaborator requests a delta, this value is set to true
7778
self.authorized_cols = authorized_cols
79+
self.cn_mapping = cn_mapping
7880
self.uuid = aggregator_uuid
7981
self.federation_uuid = federation_uuid
8082
self.assigner = assigner
@@ -225,7 +227,7 @@ def valid_collaborator_cn_and_id(self, cert_common_name,
225227
# FIXME: '' instead of None is just for protobuf compatibility.
226228
# Cleaner solution?
227229
if self.single_col_cert_common_name == '':
228-
return (cert_common_name == collaborator_common_name
230+
return (cert_common_name == self.cn_mapping[collaborator_common_name]
229231
and collaborator_common_name in self.authorized_cols)
230232
# otherwise, common_name must be in whitelist and
231233
# collaborator_common_name must be in authorized_cols

openfl/federated/plan/plan.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,15 @@ def parse(plan_config_path: Path, cols_config_path: Path = None,
135135
gandlf_config['output_dir'] = gandlf_config.get('output_dir', '.')
136136
plan.config['task_runner']['settings']['gandlf_config'] = gandlf_config
137137

138-
plan.authorized_cols = Plan.load(cols_config_path).get(
138+
cols_info = Plan.load(cols_config_path).get(
139139
'collaborators', []
140140
)
141+
if isinstance(cols_info, list):
142+
plan.cn_mapping = {col: col for col in cols_info}
143+
plan.authorized_cols = cols_info
144+
else:
145+
plan.cn_mapping = cols_info
146+
plan.authorized_cols = list(cols_info.keys())
141147

142148
# TODO: Does this need to be a YAML file? Probably want to use key
143149
# value as the plan hash
@@ -223,6 +229,7 @@ def __init__(self):
223229
"""Initialize."""
224230
self.config = {} # dictionary containing patched plan definition
225231
self.authorized_cols = [] # authorized collaborator list
232+
self.cn_mapping = {} # expected cert common name for each collaborator
226233
self.cols_data_paths = {} # collaborator data paths dict
227234

228235
self.collaborator_ = None # collaborator object
@@ -338,6 +345,7 @@ def get_aggregator(self, tensor_dict=None):
338345
defaults[SETTINGS]['aggregator_uuid'] = self.aggregator_uuid
339346
defaults[SETTINGS]['federation_uuid'] = self.federation_uuid
340347
defaults[SETTINGS]['authorized_cols'] = self.authorized_cols
348+
defaults[SETTINGS]['cn_mapping'] = self.cn_mapping
341349
defaults[SETTINGS]['assigner'] = self.get_assigner()
342350
defaults[SETTINGS]['compression_pipeline'] = self.get_tensor_pipe()
343351
defaults[SETTINGS]['straggler_handling_policy'] = self.get_straggler_handling_policy()

openfl/interface/collaborator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def register_data_path(collaborator_name, data_path=None, silent=False):
137137
def generate_cert_request_(collaborator_name,
138138
silent, skip_package):
139139
"""Generate certificate request for the collaborator."""
140+
# TODO: this should take an extra argument: common_name
140141
generate_cert_request(collaborator_name, silent, skip_package)
141142

142143

@@ -304,7 +305,7 @@ def certify(collaborator_name, silent, request_pkg=None, import_=False):
304305
from openfl.utilities.utils import rmtree
305306

306307
common_name = f'{collaborator_name}'.lower()
307-
308+
# TODO: read and parse CSR and use the actual CN
308309
if not import_:
309310
if request_pkg:
310311
Path(f'{CERT_DIR}/client').mkdir(parents=True, exist_ok=True)

openfl/interface/interactive_api/experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,7 @@ def _prepare_plan(self, model_provider, data_loader,
377377
self.plan.authorized_cols = [
378378
name for name, info in shard_registry.items() if info['is_online']
379379
]
380+
self.plan.cn_mapping = {name:name for name in self.plan.authorized_cols}
380381
# Network part of the plan
381382
# We keep in mind that an aggregator FQND will be the same as the directors FQDN
382383
# We just choose a port randomly from plan hash

openfl/transport/grpc/aggregator_server.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def validate_collaborator(self, request, context):
8484
common_name, collaborator_common_name):
8585
# Random delay in authentication failures
8686
sleep(5 * random()) # nosec
87-
context.abort(
87+
context.abort( # TODO? add the expected CN in the error msg
8888
StatusCode.UNAUTHENTICATED,
8989
f'Invalid collaborator. CN: |{common_name}| '
9090
f'collaborator_common_name: |{collaborator_common_name}|')

0 commit comments

Comments
 (0)