@@ -484,6 +484,16 @@ def start_splitlearning_session():
484484 rounds : int = data .get ("rounds" , "" )
485485 round_timeout : int = data .get ("round_timeout" , None )
486486 model_name_prefix : str = data .get ("model_name_prefix" , None )
487+ client_ids : str = data .get ("client_ids" , None )
488+
489+ if client_ids is not None and not isinstance (client_ids , str ):
490+ return jsonify ({"message" : "client_ids must be a comma separated string" }), 400
491+ if client_ids is not None :
492+ client_ids : list [str ] = client_ids .split ("," )
493+ if len (client_ids ) == 0 :
494+ return jsonify ({"message" : "client_ids must be a comma separated string" }), 400
495+ if any (not isinstance (client_id , str ) for client_id in client_ids ):
496+ return jsonify ({"message" : "client_ids must be a comma separated string" }), 400
487497
488498 if model_name_prefix is None or not isinstance (model_name_prefix , str ) or len (model_name_prefix ) == 0 :
489499 model_name_prefix = None
@@ -500,7 +510,7 @@ def start_splitlearning_session():
500510
501511 if not rounds or not isinstance (rounds , int ):
502512 rounds = session_config .rounds
503- nr_available_clients = _get_number_of_available_clients ()
513+ nr_available_clients = _get_number_of_available_clients (client_ids = client_ids )
504514
505515 if nr_available_clients < min_clients :
506516 return jsonify ({"message" : f"Number of available clients is lower than the required minimum of { min_clients } " }), 400
0 commit comments