Skip to content

Commit 83acd3d

Browse files
adding input parameter to _get_number_of_available_clients function
1 parent ec17a0a commit 83acd3d

File tree

1 file changed

+11
-1
lines changed

1 file changed

+11
-1
lines changed

fedn/network/api/v1/session_routes.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -484,6 +484,16 @@ def start_splitlearning_session():
484484
rounds: int = data.get("rounds", "")
485485
round_timeout: int = data.get("round_timeout", None)
486486
model_name_prefix: str = data.get("model_name_prefix", None)
487+
client_ids: str = data.get("client_ids", None)
488+
489+
if client_ids is not None and not isinstance(client_ids, str):
490+
return jsonify({"message": "client_ids must be a comma separated string"}), 400
491+
if client_ids is not None:
492+
client_ids: list[str] = client_ids.split(",")
493+
if len(client_ids) == 0:
494+
return jsonify({"message": "client_ids must be a comma separated string"}), 400
495+
if any(not isinstance(client_id, str) for client_id in client_ids):
496+
return jsonify({"message": "client_ids must be a comma separated string"}), 400
487497

488498
if model_name_prefix is None or not isinstance(model_name_prefix, str) or len(model_name_prefix) == 0:
489499
model_name_prefix = None
@@ -500,7 +510,7 @@ def start_splitlearning_session():
500510

501511
if not rounds or not isinstance(rounds, int):
502512
rounds = session_config.rounds
503-
nr_available_clients = _get_number_of_available_clients()
513+
nr_available_clients = _get_number_of_available_clients(client_ids=client_ids)
504514

505515
if nr_available_clients < min_clients:
506516
return jsonify({"message": f"Number of available clients is lower than the required minimum of {min_clients}"}), 400

0 commit comments

Comments
 (0)