diff --git a/.ci/tests/examples/print_logs.sh b/.ci/tests/examples/print_logs.sh index e2b079026..1c6383c99 100755 --- a/.ci/tests/examples/print_logs.sh +++ b/.ci/tests/examples/print_logs.sh @@ -22,11 +22,23 @@ if [ "$service" == "api-server" ]; then fi if [ "$service" == "combiner" ]; then - echo "Reducer logs" + echo "Combiner logs" docker logs "$(basename $PWD)-combiner-1" exit 0 fi +if [ "$service" == "controller" ]; then + echo "Controller logs" + docker logs "$(basename $PWD)-controller-1" + exit 0 +fi + +if [ "$service" == "hooks" ]; then + echo "Hooks logs" + docker logs "hook" + exit 0 +fi + if [ "$service" == "client" ]; then echo "Client 0 logs" if [ "$example" == "mnist-keras" ]; then diff --git a/.ci/tests/examples/run.sh b/.ci/tests/examples/run.sh index f77a80778..bfa0a1262 100755 --- a/.ci/tests/examples/run.sh +++ b/.ci/tests/examples/run.sh @@ -28,7 +28,7 @@ else docker compose \ -f ../../docker-compose.yaml \ -f docker-compose.override.yaml \ - up -d --build combiner api-server mongo minio client1 + up -d --build combiner controller api-server hooks mongo minio client1 fi # add server functions to python path to import server functions code @@ -40,6 +40,9 @@ python ../../.ci/tests/examples/wait_for.py reducer >&2 echo "Wait for combiners to connect" python ../../.ci/tests/examples/wait_for.py combiners +>&2 echo "Wait for controller to connect" +python ../../.ci/tests/examples/wait_for.py controller + >&2 echo "Upload compute package" python ../../.ci/tests/examples/api_test.py set_package --path package.tgz --helper "$helper" --name test @@ -61,20 +64,6 @@ fi >&2 echo "Checking rounds success" python ../../.ci/tests/examples/wait_for.py rounds ->&2 echo "Test client connection with dowloaded settings" -# Get config -python ../../.ci/tests/examples/api_test.py get_client_config --output ../../client.yaml - -# Redeploy clients with config -docker compose \ - -f ../../docker-compose.yaml \ - -f docker-compose.override.yaml \ - -f 
../../.ci/tests/examples/compose-client-settings.override.yaml \ - up -d - ->&2 echo "Wait for clients to reconnect" -python ../../.ci/tests/examples/wait_for.py clients - >&2 echo "Test API GET requests" python ../../.ci/tests/examples/api_test.py test_api_get_methods diff --git a/.ci/tests/examples/wait_for.py b/.ci/tests/examples/wait_for.py index 3c4fa9383..e61b12b3e 100644 --- a/.ci/tests/examples/wait_for.py +++ b/.ci/tests/examples/wait_for.py @@ -63,6 +63,20 @@ def _test_nodes(n_nodes, node_type, reducer_host='localhost', reducer_port='8092 _eprint(f'Request exception enconuntered: {e}.') return False +def _test_controller(reducer_host='localhost', reducer_port='8092'): + try: + response = requests.get( + f'http://{reducer_host}:{reducer_port}/get_controller_status', verify=False) + + if response.status_code == 200: + data = json.loads(response.content) + _eprint(f'Controller is running: {data}') + return True + + except Exception as e: + _eprint(f'Request exception encountered: {e}.') + return False + def rounds(n_rounds=3): assert (_retry(_test_rounds, n_rounds=n_rounds)) @@ -79,6 +93,9 @@ def combiners(n_combiners=1): def reducer(): assert (_retry(_test_nodes, n_nodes=1, node_type='reducer')) +def controller(): + assert (_retry(_test_controller)) + if __name__ == '__main__': fire.Fire() diff --git a/.github/workflows/build-containers.yaml b/.github/workflows/build-containers.yaml index 31705cb6e..872b425e9 100644 --- a/.github/workflows/build-containers.yaml +++ b/.github/workflows/build-containers.yaml @@ -1,22 +1,26 @@ -name: "build containers" +name: build containers on: workflow_dispatch: + inputs: + custom_tag: + description: "Extra image tag to add (e.g. torchimage)" + required: false + default: "" + extra_pip_packages: + description: "Space-separated pip packages to add (e.g. 
numpy==1.26.4 uvicorn)" + required: false + default: "" push: - branches: - - master - - develop + branches: [ master, develop ] pull_request: - branches: - - develop - - master + branches: [ develop, master ] release: types: [published] jobs: build-containers: runs-on: ubuntu-latest - permissions: packages: write contents: read @@ -43,6 +47,7 @@ jobs: type=semver,pattern={{version}} type=semver,pattern={{major}}.{{minor}} type=sha + type=raw,value=${{ inputs.custom_tag }},enable=${{ inputs.custom_tag != '' }} - name: Log in to GitHub Container Registry uses: docker/login-action@v3 @@ -50,7 +55,6 @@ jobs: registry: ghcr.io username: ${{ github.actor }} password: ${{ secrets.GITHUB_TOKEN }} - - name: Build and push uses: docker/build-push-action@v6 @@ -60,3 +64,5 @@ jobs: tags: ${{ steps.meta1.outputs.tags }} labels: ${{ steps.meta1.outputs.labels }} file: Dockerfile + build-args: | + EXTRA_PIP_PACKAGES=${{ inputs.extra_pip_packages }} diff --git a/.github/workflows/code-checks.yaml b/.github/workflows/code-checks.yaml index b0e51d408..101df0be9 100644 --- a/.github/workflows/code-checks.yaml +++ b/.github/workflows/code-checks.yaml @@ -26,7 +26,9 @@ jobs: --exclude-dir='flower-client' --exclude='tests.py' --exclude='controller_cmd.py' + --exclude='api_server_cmd.py' --exclude='fedn_pb2_grpc.py' + --exclude='fedn_pb2.pyi' --exclude='combiner_cmd.py' --exclude='run_cmd.py' --exclude='README.rst' diff --git a/.github/workflows/integration-tests.yaml b/.github/workflows/integration-tests.yaml index b0274da67..a7d08b1d6 100644 --- a/.github/workflows/integration-tests.yaml +++ b/.github/workflows/integration-tests.yaml @@ -15,6 +15,7 @@ on: - '.github/**' branches: - "**" + workflow_dispatch: jobs: integration-tests: @@ -50,7 +51,15 @@ jobs: - name: print logs combiner if: failure() run: .ci/tests/examples/print_logs.sh combiner ${{ matrix.to_test }} + + - name: print logs controller + if: failure() + run: .ci/tests/examples/print_logs.sh controller ${{ matrix.to_test 
}} + - name: print logs hooks + if: failure() + run: .ci/tests/examples/print_logs.sh hooks ${{ matrix.to_test }} + - name: print logs client if: failure() run: .ci/tests/examples/print_logs.sh client ${{ matrix.to_test }} @@ -62,3 +71,5 @@ jobs: - name: print logs minio if: failure() run: .ci/tests/examples/print_logs.sh minio ${{ matrix.to_test }} + + diff --git a/.gitignore b/.gitignore index e75a5b9e9..7a15ef130 100644 --- a/.gitignore +++ b/.gitignore @@ -175,8 +175,8 @@ config/settings-combiner.yaml config/extra-hosts-client.yaml config/extra-hosts-reducer.yaml config/settings-client.yaml -config/settings-reducer.yaml -config/settings-combiner.yaml +config/settings-api-server.yaml +config/settings-controller.yaml config/settings-hooks.yaml ./tmp/* diff --git a/.readthedocs.yaml b/.readthedocs.yaml index f73590822..1ea29168a 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -7,11 +7,16 @@ build: jobs: pre_build: - sphinx-apidoc --ext-autodoc --module-first -o docs fedn ./*tests* ./fedn/cli* ./fedn/common* ./fedn/network/api/v1* ./fedn/network/grpc/fedn_pb2.py ./fedn/network/grpc/fedn_pb2_grpc.py ./fedn/network/api/server.py ./fedn/network/controller/controlbase.py + sphinx: configuration: docs/conf.py + fail_on_warning: false python: install: - method: pip path: . 
- requirements: docs/requirements.txt + +# Ensure documentation is publicly accessible +# Make sure this is not set to private in RTD settings diff --git a/Dockerfile b/Dockerfile index 7f031cad3..55cebaca0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,7 +5,7 @@ FROM $BASE_IMG AS builder ARG GRPC_HEALTH_PROBE_VERSION="" ARG REQUIREMENTS="" -ARG INSTALL_TORCH=0 +ARG EXTRA_PIP_PACKAGES="" WORKDIR /build @@ -36,9 +36,12 @@ RUN python -m venv /venv \ fi \ && rm -rf /build/requirements.txt -# only install torch when asked -RUN if [ "$INSTALL_TORCH" = "1" ]; then /venv/bin/pip install torch; fi - +RUN if [ -n "$EXTRA_PIP_PACKAGES" ]; then \ + echo "Installing extra pip packages: $EXTRA_PIP_PACKAGES" && \ + /venv/bin/pip install --no-cache-dir $EXTRA_PIP_PACKAGES; \ + else \ + echo "No EXTRA_PIP_PACKAGES provided"; \ + fi # Install grpc health probe RUN if [ ! -z "$GRPC_HEALTH_PROBE_VERSION" ]; then \ @@ -79,4 +82,3 @@ RUN set -ex \ USER appuser ENTRYPOINT [ "/venv/bin/fedn" ] - diff --git a/Dockerfile.dev b/Dockerfile.dev index b651dbea4..49182a908 100644 --- a/Dockerfile.dev +++ b/Dockerfile.dev @@ -12,7 +12,8 @@ COPY . 
/app COPY config/settings-client.yaml.template /app/config/settings-client.yaml COPY config/settings-combiner.yaml.template /app/config/settings-combiner.yaml COPY config/settings-hooks.yaml.template /app/config/settings-hooks.yaml -COPY config/settings-reducer.yaml.template /app/config/settings-reducer.yaml +COPY config/settings-api-server.yaml.template /app/config/settings-api-server.yaml +COPY config/settings-controller.yaml.template /app/config/settings-controller.yaml COPY $REQUIREMENTS /app/config/requirements.txt # Install developer tools (needed for psutil) diff --git a/config/reducer-settings.override.yaml b/config/reducer-settings.override.yaml index 18e499f73..d5124c125 100644 --- a/config/reducer-settings.override.yaml +++ b/config/reducer-settings.override.yaml @@ -6,4 +6,4 @@ services: reducer: volumes: - ${HOST_REPO_DIR:-.}:/app - - ${HOST_REPO_DIR:-.}/config/settings-reducer.yaml:/app/config/settings-reducer.yaml + - ${HOST_REPO_DIR:-.}/config/settings-api-server.yaml:/app/config/settings-api-server.yaml diff --git a/config/settings-reducer.yaml.template b/config/settings-api-server.yaml.template similarity index 94% rename from config/settings-reducer.yaml.template rename to config/settings-api-server.yaml.template index 62adb5aa0..e3aa0328c 100644 --- a/config/settings-reducer.yaml.template +++ b/config/settings-api-server.yaml.template @@ -1,9 +1,13 @@ network_id: fedn-network -controller: +api: host: api-server port: 8092 debug: True +controller: + host: controller + port: 12090 + statestore: # Available DB types are MongoDB, PostgreSQL, SQLite type: MongoDB diff --git a/config/settings-controller.yaml.local.template b/config/settings-controller.yaml.local.template index 6ed47c01b..9cdd8992e 100644 --- a/config/settings-controller.yaml.local.template +++ b/config/settings-controller.yaml.local.template @@ -1,5 +1,5 @@ network_id: fedn-network -controller: +api: host: localhost port: 8092 debug: True diff --git 
a/config/settings-controller.yaml.template b/config/settings-controller.yaml.template new file mode 100644 index 000000000..d9d1d2df0 --- /dev/null +++ b/config/settings-controller.yaml.template @@ -0,0 +1,35 @@ +network_id: fedn-network + +api: + host: api-server + port: 8092 + +controller: + host: controller + port: 12090 + debug: True + +statestore: + # Available DB types are MongoDB, PostgreSQL, SQLite + type: MongoDB + mongo_config: + username: fedn_admin + password: password + host: mongo + port: 6534 + postgres_config: + username: fedn_admin + password: password + host: fedn_postgres + port: 5432 + +storage: + storage_type: BOTO3 + storage_config: + storage_endpoint_url: http://minio:9000 + storage_access_key: fedn_admin + storage_secret_key: password + storage_bucket: fedn-models + context_bucket: fedn-context + storage_secure_mode: False + storage_verify_ssl: False diff --git a/docker-compose.dev.yaml b/docker-compose.dev.yaml index 0a8425ce8..da0ea6de5 100644 --- a/docker-compose.dev.yaml +++ b/docker-compose.dev.yaml @@ -67,8 +67,8 @@ services: - USER=test - PROJECT=project - FLASK_DEBUG=1 - - STATESTORE_CONFIG=/app/config/settings-reducer.yaml.template - - MODELSTORAGE_CONFIG=/app/config/settings-reducer.yaml.template + - STATESTORE_CONFIG=/app/config/settings-api-server.yaml.template + - MODELSTORAGE_CONFIG=/app/config/settings-api-server.yaml.template - FEDN_COMPUTE_PACKAGE_DIR=/app - TMPDIR=/app/tmp build: @@ -84,11 +84,41 @@ services: - mongo - fedn_postgres command: - - controller + - api-server - start ports: - 8092:8092 + controller: + environment: + - PYTHONUNBUFFERED=0 + - GET_HOSTS_FROM=dns + - STATESTORE_CONFIG=/app/config/settings-controller.yaml.template + - MODELSTORAGE_CONFIG=/app/config/settings-controller.yaml.template + - TMPDIR=/app/tmp + build: + context: . 
+ args: + BASE_IMG: ${BASE_IMG:-python:3.12-slim} + working_dir: /app + volumes: + - ${HOST_REPO_DIR:-.}/fedn:/app/fedn + healthcheck: + test: [ "CMD", "/app/grpc_health_probe", "-addr=localhost:12090" ] + interval: 20s + timeout: 10s + retries: 5 + depends_on: + - minio + - mongo + command: + - controller + - start + - --init + - config/settings-controller.yaml.template + ports: + - 12090:12090 + # Combiner combiner: environment: diff --git a/docker-compose.yaml b/docker-compose.yaml index 598aad0cb..392e605bc 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -58,8 +58,8 @@ services: - USER=test - PROJECT=project - FLASK_DEBUG=1 - - STATESTORE_CONFIG=/app/config/settings-reducer.yaml.template - - MODELSTORAGE_CONFIG=/app/config/settings-reducer.yaml.template + - STATESTORE_CONFIG=/app/config/settings-api-server.yaml.template + - MODELSTORAGE_CONFIG=/app/config/settings-api-server.yaml.template - FEDN_COMPUTE_PACKAGE_DIR=/app - TMPDIR=/app/tmp build: @@ -73,11 +73,43 @@ services: - minio - mongo command: - - controller + - api-server - start ports: - 8092:8092 + controller: + environment: + - PYTHONUNBUFFERED=0 + - GET_HOSTS_FROM=dns + - STATESTORE_CONFIG=/app/config/settings-controller.yaml.template + - MODELSTORAGE_CONFIG=/app/config/settings-controller.yaml.template + - TMPDIR=/app/tmp + build: + context: . 
+ args: + BASE_IMG: ${BASE_IMG:-python:3.12-slim} + working_dir: /app + volumes: + - ${HOST_REPO_DIR:-.}/fedn:/app/fedn + healthcheck: + test: [ "CMD", "/app/grpc_health_probe", "-addr=localhost:12090" ] + interval: 20s + timeout: 10s + retries: 5 + depends_on: + - minio + - mongo + command: + - controller + - start + - --init + - config/settings-controller.yaml.template + ports: + - 12090:12090 + + + # Combiner combiner: environment: @@ -110,6 +142,8 @@ services: depends_on: - api-server - hooks + - controller + # Hooks hooks: container_name: hook diff --git a/docs/conf.py b/docs/conf.py index 7e4347029..f791a79e3 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -11,7 +11,7 @@ author = "Scaleout Systems AB" # The full version, including alpha/beta/rc tags -release = "0.30.0" +release = "0.33.0" # Add any Sphinx extension module names here, as strings extensions = [ @@ -28,6 +28,10 @@ "sphinx_copybutton", ] +# SEO configuration +html_title = "FEDn Documentation - Scalable Federated Learning Framework" +html_short_title = "FEDn Docs" + # The master toctree document. master_doc = "index" @@ -47,6 +51,14 @@ "logo_only": True, } +# SEO improvements +html_use_index = True +html_split_index = False + +# Allow search engines to index the documentation +# Remove any robots restrictions +html_extra_path = ["robots.txt"] + # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". 
@@ -67,6 +79,21 @@ "css/text.css", ] +html_js_files = [ + ( + "https://scripts.simpleanalyticscdn.com/sri/v11.js", + { + "async": "async", + "crossorigin": "anonymous", + "integrity": ( + "sha256-hkUzQr3zWmSDnmhw95ZmQSZ949upqD+ML9ejiN0UIIE= " + "sha384-rfv15RJy1bBYZ1Mf4xizO26jorXb2myipCvHXy4rkG0SuEET96S+m0sTzu5vfbSI " + "sha512-lQzjzTbOxHLwkZGDVMf4V0sm8v2Mrqm73IvKcXBftJ/MSZKQC4/jwKFToxT+3IVAVWQzLplSNHH8gM5d7b1BSg==" + ), + }, + ), +] + # LaTeX elements latex_elements = { # The paper size ('letterpaper' or 'a4paper'). diff --git a/docs/index.rst b/docs/index.rst index da273e2ba..397634145 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -51,5 +51,13 @@ Indices and tables .. meta:: :description lang=en: - FEDn is a federated learning platform that is secure, scalable, and easy-to-use. - :keywords: Federated Learning, Machine Learning, Federated Learning Framework, Federated Learning Platform, FEDn, Scaleout Systems + FEDn is a framework for scalable federated learning. Deploy secure, distributed machine learning models efficiently in production environments with comprehensive documentation, tutorials, and API references. + :keywords: Federated Learning, Machine Learning, Federated Learning Framework, Federated Learning Platform, FEDn, Scaleout Systems, Distributed Learning, Privacy-Preserving ML + :robots: index, follow + :author: Scaleout Systems AB + :og:title: FEDn Documentation - Scalable Federated Learning Framework + :og:description: Complete documentation for FEDn, the federated learning platform. Learn architecture, setup, deployment, and API usage. + :og:type: website + :twitter:card: summary_large_image + :twitter:title: FEDn Documentation - Scalable Federated Learning + :twitter:description: Framework for scalable federated learning with comprehensive docs and examples. 
diff --git a/docs/robots.txt b/docs/robots.txt new file mode 100644 index 000000000..0f337b6a2 --- /dev/null +++ b/docs/robots.txt @@ -0,0 +1,5 @@ +User-agent: * +Allow: / + +# Allow search engines to index all documentation +Sitemap: https://docs.scaleoutsystems.com/sitemap.xml diff --git a/examples/importer-client/.gitignore b/examples/importer-client/.gitignore new file mode 100644 index 000000000..ed9c07203 --- /dev/null +++ b/examples/importer-client/.gitignore @@ -0,0 +1,6 @@ +data +*.npz +*.tgz +*.tar.gz +.mnist-pytorch +client*.yaml \ No newline at end of file diff --git a/examples/importer-client/README.rst b/examples/importer-client/README.rst new file mode 100644 index 000000000..03f62d7a7 --- /dev/null +++ b/examples/importer-client/README.rst @@ -0,0 +1,47 @@ +FEDn Project: Importer Client +----------------------------- + +This is an example FEDn Project on how to design a client that imports client training code rather than running it in a separate process. +This enables the user to have access to the grpc channel to send information to the controller during training. + + **Note: We recommend that all new users start by taking the Quickstart Tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html** + +Prerequisites +------------- + +- `Python >=3.9, <=3.12 `__ +- `A project in FEDn Studio `__ +
Creating the compute package and seed model +------------------------------------------- +
Install fedn: +
.. code-block:: +
   pip install fedn +
Clone this repository, then locate into this directory: +
.. code-block:: +
   git clone https://github.com/scaleoutsystems/fedn.git +   cd fedn/examples/importer-client +
Create the compute package: +
.. code-block:: +
   fedn package create --path client +
This creates a file 'package.tgz' in the project folder. +
Running the project on FEDn +---------------------------- +
..
code-block:: + + fedn client start --importer --init client.yaml + + +To learn how to set up your FEDn Studio project and connect clients, take the quickstart tutorial: https://fedn.readthedocs.io/en/stable/quickstart.html. diff --git a/examples/importer-client/client/build.py b/examples/importer-client/client/build.py new file mode 100644 index 000000000..82f02a9d6 --- /dev/null +++ b/examples/importer-client/client/build.py @@ -0,0 +1,7 @@ +def main(): + print("Hello World!") + # Do the build stuff usually creating the seed.npz + + +if __name__ == "__main__": + main() diff --git a/examples/importer-client/client/fedn.yaml b/examples/importer-client/client/fedn.yaml new file mode 100644 index 000000000..5e56039dc --- /dev/null +++ b/examples/importer-client/client/fedn.yaml @@ -0,0 +1,7 @@ +# Remove the python_env tag below to handle the environment manually +python_env: python_env.yaml + +entry_points: + build: + command: python build.py + startup: startup.py \ No newline at end of file diff --git a/examples/importer-client/client/python_env.yaml b/examples/importer-client/client/python_env.yaml new file mode 100644 index 000000000..0218913bb --- /dev/null +++ b/examples/importer-client/client/python_env.yaml @@ -0,0 +1,14 @@ +name: .mnist-pytorch +build_dependencies: + - pip + - setuptools + - wheel +dependencies: + - fedn + - torch==2.4.1; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "win64" or sys_platform == "linux") + # PyTorch macOS x86 builds deprecation + - torch==2.2.2; sys_platform == "darwin" and platform_machine == "x86_64" + - torchvision==0.19.1; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == "win64" or sys_platform == "linux") + - torchvision==0.17.2; sys_platform == "darwin" and platform_machine == "x86_64" + - numpy==2.0.2; (sys_platform == "darwin" and platform_machine == "arm64") or (sys_platform == "win32" or sys_platform == 
"win64" or sys_platform == "linux") + - numpy==1.26.4; (sys_platform == "darwin" and platform_machine == "x86_64") diff --git a/examples/importer-client/client/startup.py b/examples/importer-client/client/startup.py new file mode 100644 index 000000000..f75cfdc62 --- /dev/null +++ b/examples/importer-client/client/startup.py @@ -0,0 +1,38 @@ +from fedn.network.clients.fedn_client import FednClient + + +def startup(client: FednClient): + MyClient(client) + + +class MyClient: + def __init__(self, client: FednClient): + self.client = client + client.set_train_callback(self.train) + client.set_validate_callback(self.validate) + client.set_predict_callback(self.predict) + + def train(self, model_params, settings): + """Train the model with the given parameters and settings.""" + # Implement training logic here + print("Training with model parameters:", model_params) + iterations = 100 + for i in iterations: + # Do training + if i % 10 == 0: + self.client.log_metric({"training_loss": 0.1, "training_accuracy": 0.9}) + # Regularly check if the task has been aborted + self.client.check_task_abort() # Throws an exception if the task has been aborted + return model_params, {"training_metadata": {"num_examples": 1}} + + def validate(self, model_params): + """Validate the model with the given parameters.""" + # Implement validation logic here + print("Validating with model parameters:", model_params) + return {"validation_accuracy": 0.95} + + def predict(self, model_params, data): + """Make predictions with the model using the given parameters and data.""" + # Implement prediction logic here + print("Predicting with model parameters:", model_params, "and data:", data) + return {"predictions": [1, 0, 1]} # Example predictions diff --git a/examples/mnist-pytorch/.gitignore b/examples/mnist-pytorch/.gitignore index a9f01054b..ed9c07203 100644 --- a/examples/mnist-pytorch/.gitignore +++ b/examples/mnist-pytorch/.gitignore @@ -3,4 +3,4 @@ data *.tgz *.tar.gz .mnist-pytorch 
-client.yaml \ No newline at end of file +client*.yaml \ No newline at end of file diff --git a/examples/pytorch-keyworddetection-api/data.py b/examples/pytorch-keyworddetection-api/data.py index ee3f2124e..3e5fbb6d5 100644 --- a/examples/pytorch-keyworddetection-api/data.py +++ b/examples/pytorch-keyworddetection-api/data.py @@ -55,6 +55,11 @@ def __len__(self) -> int: return self._end_idx - self._start_idx +def sc_collate_fn(batch: tuple[int, str, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: + ys, _, spectrogram, _ = zip(*batch) + return torch.tensor(ys, dtype=torch.long), torch.stack(spectrogram) + + class FedSCDataset(Dataset): """Dataset for the Federated Speech Commands dataset.""" @@ -241,11 +246,7 @@ def get_stats(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return torch.tensor(label_mean), spectrogram_mean[:, None], spectrogram_std[:, None] def get_collate_fn(self) -> callable: - def collate_fn(batch: tuple[int, str, torch.Tensor, torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor]: - ys, _, spectrogram, _ = zip(*batch) - return torch.tensor(ys, dtype=torch.long), torch.stack(spectrogram) - - return collate_fn + return sc_collate_fn def _get_spectogram_transform(self, n_mels: int, hop_length: int, sr: int, data_augmentation: bool = False) -> torch.nn.Sequential: if data_augmentation: @@ -266,7 +267,9 @@ def get_dataloaders( ) -> tuple[DataLoader, DataLoader, DataLoader]: """Get the dataloaders for the training, validation, and testing datasets.""" dataset_train = FedSCDataset(path, keywords, "training", dataset_split_idx, dataset_total_splits, data_augmentation=True) - dataloader_train = DataLoader(dataset=dataset_train, batch_size=batchsize_train, collate_fn=dataset_train.get_collate_fn(), shuffle=True, drop_last=True) + dataloader_train = DataLoader( + dataset=dataset_train, batch_size=batchsize_train, num_workers=2, collate_fn=dataset_train.get_collate_fn(), shuffle=True, drop_last=True + ) dataset_valid = 
FedSCDataset(path, keywords, "validation", dataset_split_idx, dataset_total_splits) dataloader_valid = DataLoader(dataset=dataset_valid, batch_size=batchsize_valid, collate_fn=dataset_valid.get_collate_fn(), shuffle=False, drop_last=False) diff --git a/examples/pytorch-keyworddetection-api/fedn_api.py b/examples/pytorch-keyworddetection-api/fedn_api.py index d21b1aae2..a06d18e85 100644 --- a/examples/pytorch-keyworddetection-api/fedn_api.py +++ b/examples/pytorch-keyworddetection-api/fedn_api.py @@ -40,7 +40,8 @@ def main(): if args.upload_seed: init_seedmodel() - api_client.set_active_model("seed.npz") + response = api_client.set_active_model("seed.npz") + print(response) elif args.start_session: # Depending on the computer hosting the clients this round_timeout might need to increase response = api_client.start_session(name="Training", round_timeout=1200) diff --git a/examples/pytorch-keyworddetection-api/sc_client.py b/examples/pytorch-keyworddetection-api/sc_client.py index 88f82a417..29bc2e2e0 100644 --- a/examples/pytorch-keyworddetection-api/sc_client.py +++ b/examples/pytorch-keyworddetection-api/sc_client.py @@ -44,6 +44,8 @@ def train(self, model_params, settings): for epoch in range(n_epochs): sc_model.train() for idx, (y_labels, x_spectrograms) in enumerate(dataloader_train): + self.check_task_abort() + optimizer.zero_grad() _, logits = sc_model(x_spectrograms) diff --git a/examples/server-functions/server_functions.py b/examples/server-functions/server_functions.py index aabfa9a25..cb2484f1a 100644 --- a/examples/server-functions/server_functions.py +++ b/examples/server-functions/server_functions.py @@ -3,6 +3,7 @@ # See allowed_imports for what packages you can use in this class. +# Class must be named ServerFunctions class ServerFunctions(ServerFunctionsBase): # toy example to highlight functionality of ServerFunctions. 
def __init__(self) -> None: diff --git a/examples/server-functions/sf_incremental_aggregation.py b/examples/server-functions/sf_incremental_aggregation.py index 56d37e78d..ced24abdd 100644 --- a/examples/server-functions/sf_incremental_aggregation.py +++ b/examples/server-functions/sf_incremental_aggregation.py @@ -5,6 +5,7 @@ # Example of fedavg using memory secure running aggregation with server functions. +# Class must be named ServerFunctions class ServerFunctions(ServerFunctionsBase): def __init__(self) -> None: self.global_model = None diff --git a/examples/splitlearning_diabetes/docker-compose.override.dev.yaml b/examples/splitlearning_diabetes/docker-compose.override.dev.yaml index bb1184b69..5f1849882 100644 --- a/examples/splitlearning_diabetes/docker-compose.override.dev.yaml +++ b/examples/splitlearning_diabetes/docker-compose.override.dev.yaml @@ -13,7 +13,7 @@ services: service: combiner build: args: - INSTALL_TORCH: "1" + EXTRA_PIP_PACKAGES: "torch" environment: <<: *defaults FEDN_LABELS_PATH: /app/data/clients/labels.pt diff --git a/examples/splitlearning_diabetes/docker-compose.override.yaml b/examples/splitlearning_diabetes/docker-compose.override.yaml index 0726df600..941e68f91 100644 --- a/examples/splitlearning_diabetes/docker-compose.override.yaml +++ b/examples/splitlearning_diabetes/docker-compose.override.yaml @@ -13,7 +13,7 @@ services: service: combiner build: args: - INSTALL_TORCH: "1" + EXTRA_PIP_PACKAGES: "torch" environment: <<: *defaults FEDN_LABELS_PATH: /app/data/clients/labels.pt diff --git a/fedn/__main__.py b/fedn/__main__.py new file mode 100644 index 000000000..4900740b2 --- /dev/null +++ b/fedn/__main__.py @@ -0,0 +1,4 @@ +from fedn.cli import main + +if __name__ == "__main__": + main() diff --git a/fedn/cli/__init__.py b/fedn/cli/__init__.py index f00bb351b..3281963ec 100644 --- a/fedn/cli/__init__.py +++ b/fedn/cli/__init__.py @@ -1,3 +1,4 @@ +from .api_server_cmd import api_server_cmd # noqa: F401 from .client_cmd import 
client_cmd # noqa: F401 from .combiner_cmd import combiner_cmd # noqa: F401 from .config_cmd import config_cmd # noqa: F401 diff --git a/fedn/cli/api_server_cmd.py b/fedn/cli/api_server_cmd.py new file mode 100644 index 000000000..a57afec70 --- /dev/null +++ b/fedn/cli/api_server_cmd.py @@ -0,0 +1,18 @@ +import click + +from fedn.cli.main import main + + +@main.group("api-server") +@click.pass_context +def api_server_cmd(ctx): + """:param ctx:""" + pass + + +@api_server_cmd.command("start") +@click.pass_context +def api_server_cmd(ctx): + from fedn.network.api.server import start_api_server # noqa: PLC0415 + + start_api_server() diff --git a/fedn/cli/client_cmd.py b/fedn/cli/client_cmd.py index c896b4b6e..6994099e4 100644 --- a/fedn/cli/client_cmd.py +++ b/fedn/cli/client_cmd.py @@ -10,8 +10,8 @@ from fedn.cli.shared import CONTROLLER_DEFAULTS, STUDIO_DEFAULTS, apply_config, get_context, get_response, print_response from fedn.common.exceptions import InvalidClientConfig from fedn.common.log_config import set_log_level_from_string -from fedn.network.clients.client_v2 import Client as ClientV2 -from fedn.network.clients.client_v2 import ClientOptions +from fedn.network.clients.dispatcher_client import ClientOptions, DispatcherClient +from fedn.network.clients.importer_client import ImporterClient home_dir = os.path.expanduser("~") @@ -200,6 +200,8 @@ def _complement_client_params(config: dict) -> None: @click.option("-tr", "--trainer", required=False, default=None) @click.option("-hp", "--helper_type", required=False, default=None) @click.option("-in", "--init", required=False, default=None, help="Set to a filename to (re)init client from file state.") +@click.option("--importer", is_flag=True, help="Use the importer client instead of the dispatcher client.") +@click.option("--manual_env", is_flag=True, help="Use the manual environment for the client. 
This will not use the managed environment.") @click.pass_context def client_start_v2_cmd( ctx, @@ -217,6 +219,8 @@ def client_start_v2_cmd( trainer: bool, helper_type: str, init: str, + importer: bool, + manual_env: bool = False, ): """Start client.""" package = "local" if local_package else "remote" @@ -329,15 +333,30 @@ def client_start_v2_cmd( preferred_combiner=config["preferred_combiner"], id=config["client_id"], ) - client = ClientV2( - api_url=config["api_url"], - api_port=config["api_port"], - client_obj=client_options, - combiner_host=config["combiner"], - combiner_port=config["combiner_port"], - token=config["token"], - package_checksum=config["package_checksum"], - helper_type=config["helper_type"], - ) + if importer: + click.echo("Using ImporterClient") + client = ImporterClient( + api_url=config["api_url"], + api_port=config["api_port"], + client_obj=client_options, + combiner_host=config["combiner"], + combiner_port=config["combiner_port"], + token=config["token"], + package_checksum=config["package_checksum"], + helper_type=config["helper_type"], + manual_env=manual_env, + ) + else: + click.echo("Using DispatcherClient") + client = DispatcherClient( + api_url=config["api_url"], + api_port=config["api_port"], + client_obj=client_options, + combiner_host=config["combiner"], + combiner_port=config["combiner_port"], + token=config["token"], + package_checksum=config["package_checksum"], + helper_type=config["helper_type"], + ) client.start() diff --git a/fedn/cli/controller_cmd.py b/fedn/cli/controller_cmd.py index 33cd08bca..1c89064b7 100644 --- a/fedn/cli/controller_cmd.py +++ b/fedn/cli/controller_cmd.py @@ -1,6 +1,11 @@ import click from fedn.cli.main import main +from fedn.cli.shared import apply_config +from fedn.common.config import get_modelstorage_config, get_network_config, get_statestore_config +from fedn.network.controller.control import Control +from fedn.network.storage.dbconnection import DatabaseConnection +from 
fedn.network.storage.s3.repository import Repository @main.group("controller") @@ -11,8 +16,27 @@ def controller_cmd(ctx): @controller_cmd.command("start") +@click.option("-h", "--host", required=False, default="controller", help="Set hostname.") +@click.option("-i", "--port", required=False, default=12090, help="Set port.") +@click.option("-s", "--secure", is_flag=True, help="Enable SSL/TLS encrypted gRPC channels.") +@click.option("-in", "--init", required=False, default=None, help="Path to configuration file.") @click.pass_context -def controller_cmd(ctx): - from fedn.network.api.server import start_server_api # noqa: PLC0415 +def controller_cmd(ctx, host, port, secure, init): + config = { + "host": host, + "port": port, + "secure": secure, + } + + if init: + apply_config(init, config) + click.echo(f"\nController configuration loaded from file: {init}") + + network_id = get_network_config() + modelstorage_config = get_modelstorage_config() + statestore_config = get_statestore_config() - start_server_api() + db = DatabaseConnection(statestore_config, network_id) + repository = Repository(modelstorage_config["storage_config"], storage_type=modelstorage_config["storage_type"]) + controller = Control(config, network_id, repository, db) + controller.run() diff --git a/fedn/cli/run_cmd.py b/fedn/cli/run_cmd.py index 0c31721ca..0a8ef601b 100644 --- a/fedn/cli/run_cmd.py +++ b/fedn/cli/run_cmd.py @@ -11,7 +11,8 @@ from fedn.common.log_config import logger from fedn.network.storage.dbconnection import DatabaseConnection from fedn.network.storage.s3.repository import Repository -from fedn.utils.dispatcher import Dispatcher, _read_yaml_file +from fedn.utils.dispatcher import Dispatcher +from fedn.utils.yaml import read_yaml_file def get_statestore_config_from_file(init): @@ -46,8 +47,8 @@ def check_yaml_exists(path): return yaml_file -def delete_virtual_environment(dispatcher): - if dispatcher.python_env_path: +def delete_virtual_environment(dispatcher: Dispatcher): + if 
dispatcher.python_env_path and os.path.exists(dispatcher.python_env_path): logger.info(f"Removing virtualenv {dispatcher.python_env_path}") shutil.rmtree(dispatcher.python_env_path) else: @@ -77,14 +78,14 @@ def validate_cmd(ctx, path, input, output, keep_venv): path = os.path.abspath(path) yaml_file = check_yaml_exists(path) - config = _read_yaml_file(yaml_file) + config = read_yaml_file(yaml_file) # Check that validate is defined in fedn.yaml under entry_points if "validate" not in config["entry_points"]: logger.error("No validate command defined in fedn.yaml") exit(-1) dispatcher = Dispatcher(config, path) - _ = dispatcher._get_or_create_python_env() + _ = dispatcher.get_or_create_python_env() dispatcher.run_cmd("validate {} {}".format(input, output)) if not keep_venv: delete_virtual_environment(dispatcher) @@ -106,14 +107,14 @@ def train_cmd(ctx, path, input, output, keep_venv): path = os.path.abspath(path) yaml_file = check_yaml_exists(path) - config = _read_yaml_file(yaml_file) + config = read_yaml_file(yaml_file) # Check that train is defined in fedn.yaml under entry_points if "train" not in config["entry_points"]: logger.error("No train command defined in fedn.yaml") exit(-1) dispatcher = Dispatcher(config, path) - _ = dispatcher._get_or_create_python_env() + _ = dispatcher.get_or_create_python_env() dispatcher.run_cmd("train {} {}".format(input, output)) if not keep_venv: delete_virtual_environment(dispatcher) @@ -133,13 +134,13 @@ def startup_cmd(ctx, path, keep_venv): path = os.path.abspath(path) yaml_file = check_yaml_exists(path) - config = _read_yaml_file(yaml_file) + config = read_yaml_file(yaml_file) # Check that startup is defined in fedn.yaml under entry_points if "startup" not in config["entry_points"]: logger.error("No startup command defined in fedn.yaml") exit(-1) dispatcher = Dispatcher(config, path) - _ = dispatcher._get_or_create_python_env() + _ = dispatcher.get_or_create_python_env() dispatcher.run_cmd("startup") if not keep_venv: 
delete_virtual_environment(dispatcher) @@ -159,14 +160,14 @@ def build_cmd(ctx, path, keep_venv): path = os.path.abspath(path) yaml_file = check_yaml_exists(path) - config = _read_yaml_file(yaml_file) + config = read_yaml_file(yaml_file) # Check that build is defined in fedn.yaml under entry_points if "build" not in config["entry_points"]: logger.error("No build command defined in fedn.yaml") exit(-1) dispatcher = Dispatcher(config, path) - _ = dispatcher._get_or_create_python_env() + _ = dispatcher.get_or_create_python_env() dispatcher.run_cmd("build") if not keep_venv: delete_virtual_environment(dispatcher) diff --git a/fedn/cli/tests/tests.py b/fedn/cli/tests/tests.py index 7f348678b..10ca55a5c 100644 --- a/fedn/cli/tests/tests.py +++ b/fedn/cli/tests/tests.py @@ -8,7 +8,7 @@ from run_cmd import run_cmd,check_yaml_exists,logger import click from main import main -from fedn.network.api.server import start_server_api +from fedn.network.api.server import start_api_server from controller_cmd import main, controller_cmd import tarfile from package_cmd import create_tar_with_ignore, create_cmd, package_cmd diff --git a/fedn/common/config.py b/fedn/common/config.py index 8bb347910..f74ebc4cc 100644 --- a/fedn/common/config.py +++ b/fedn/common/config.py @@ -18,6 +18,7 @@ FEDN_PACKAGE_EXTRACT_DIR = os.environ.get("FEDN_PACKAGE_EXTRACT_DIR", "package") FEDN_COMPUTE_PACKAGE_DIR = os.environ.get("FEDN_COMPUTE_PACKAGE_DIR", "/app/client/package/") +FEDN_ARCHIVE_DIR = os.environ.get("FEDN_ARCHIVE_DIR", ".fedn") FEDN_OBJECT_STORAGE_TYPE = os.environ.get("FEDN_OBJECT_STORAGE_TYPE", "BOTO3").upper() FEDN_OBJECT_MODEL_BUCKET = os.environ.get("FEDN_OBJECT_MODEL_BUCKET", "fedn-model") @@ -108,6 +109,25 @@ def get_network_config(file=None): return settings["network_id"] +def get_api_config(file=None): + """Get the api configuration from file. + + :param file: The api configuration file (yaml) path (optional). + :type file: str + :return: The api configuration as a dict. 
+ :rtype: dict + """ + if file is None: + get_environment_config() + file = STATESTORE_CONFIG + with open(file, "r") as config_file: + try: + settings = dict(yaml.safe_load(config_file)) + except yaml.YAMLError as e: + raise (e) + return settings["api"] + + def get_controller_config(file=None): """Get the controller configuration from file. diff --git a/fedn/common/settings-controller.yaml.template b/fedn/common/settings-controller.yaml.template index 2a0c3a405..84f00eae7 100644 --- a/fedn/common/settings-controller.yaml.template +++ b/fedn/common/settings-controller.yaml.template @@ -1,9 +1,13 @@ network_id: fedn-network -controller: +api: host: localhost port: 8092 debug: True +controller: + host: localhost + port: 12090 + statestore: type: MongoDB mongo_config: diff --git a/fedn/genprot.sh b/fedn/genprot.sh index def170de1..37a53155c 100755 --- a/fedn/genprot.sh +++ b/fedn/genprot.sh @@ -1,4 +1,4 @@ #!/bin/bash echo "Generating protocol" -python3 -m grpc_tools.protoc -I=. --python_out=. --grpc_python_out=. network/grpc/*.proto +python3 -m grpc_tools.protoc -I=. --python_out=. --grpc_python_out=. --mypy_out=. 
network/grpc/*.proto echo "DONE" diff --git a/fedn/network/api/client.py b/fedn/network/api/client.py index 951a45c24..2938ae773 100644 --- a/fedn/network/api/client.py +++ b/fedn/network/api/client.py @@ -951,13 +951,13 @@ def get_validations( _params = {} if session_id: - _params["sessionId"] = session_id + _params["session_id"] = session_id if model_id: - _params["modelId"] = model_id + _params["model_id"] = model_id if correlation_id: - _params["correlationId"] = correlation_id + _params["correlation_id"] = correlation_id if sender_name: _params["sender.name"] = sender_name @@ -1028,10 +1028,10 @@ def get_predictions( _params = {} if model_id: - _params["modelId"] = model_id + _params["model_id"] = model_id if correlation_id: - _params["correlationId"] = correlation_id + _params["correlation_id"] = correlation_id if sender_name: _params["sender.name"] = sender_name @@ -1118,3 +1118,40 @@ def add_attributes(self, attribute: dict) -> dict: response = requests.post(url, json=attribute, headers=self.headers, verify=self.verify) response.raise_for_status() return response.json() + + ### Control Functions ### + def step_current_session(self): + """Continue a session control. + + :param session_id: The id of the session to continue. + :type session_id: str + :return: A dict with success or failure message. + :rtype: dict + """ + response = requests.post( + self._get_url_api_v1("control/continue"), + verify=self.verify, + headers=self.headers, + ) + + _json = response.json() + + return _json + + def stop_current_session(self): + """Stop a session control. + + :param session_id: The id of the session to stop. + :type session_id: str + :return: A dict with success or failure message. 
+ :rtype: dict + """ + response = requests.post( + self._get_url_api_v1("control/stop"), + verify=self.verify, + headers=self.headers, + ) + + _json = response.json() + + return _json diff --git a/fedn/network/api/gunicorn_app.py b/fedn/network/api/gunicorn_app.py index bd9ce2d16..e484ffb2b 100644 --- a/fedn/network/api/gunicorn_app.py +++ b/fedn/network/api/gunicorn_app.py @@ -1,7 +1,5 @@ from gunicorn.app.base import BaseApplication -from fedn.network.controller.control import Control - class GunicornApp(BaseApplication): def __init__(self, app, options=None): @@ -18,21 +16,12 @@ def load(self): return self.application -def post_fork(server, worker): - """Hook to be called after the worker has forked. - - This is where we can initialize the database connection for each worker. - """ - # Initialize the database connection - Control.instance().db.initialize_connection() - - -def run_gunicorn(app, host, port, workers=4): +def run_gunicorn(app, host, port, workers=4, post_fork_func=None): bind_address = f"{host}:{port}" options = { "bind": bind_address, # Specify the bind address and port here "workers": workers, - # After forking, initialize the database connection - "post_fork": post_fork, } + if post_fork_func is not None: + options["post_fork"] = post_fork_func GunicornApp(app, options).run() diff --git a/fedn/network/api/network.py b/fedn/network/api/network.py deleted file mode 100644 index 59884af10..000000000 --- a/fedn/network/api/network.py +++ /dev/null @@ -1,87 +0,0 @@ -import os -from typing import List - -from fedn.common.log_config import logger -from fedn.network.combiner.interfaces import CombinerInterface -from fedn.network.loadbalancer.leastpacked import LeastPacked -from fedn.network.storage.dbconnection import DatabaseConnection - -__all__ = ("Network",) - - -class Network: - """FEDn network interface. This class is used to interact with the network. - Note: This class contain redundant code, which is not used in the current version of FEDn. 
- Some methods has been moved to :class:`fedn.network.api.interface.API`. - """ - - def __init__(self, control, network_id: str, dbconn: DatabaseConnection, load_balancer=None): - """ """ - self.control = control - self.id = network_id - self.db = dbconn - - if not load_balancer: - self.load_balancer = LeastPacked(self) - else: - self.load_balancer = load_balancer - - def get_combiner(self, name): - """Get combiner by name. - - :param name: name of combiner - :type name: str - :return: The combiner instance object - :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` - """ - combiners = self.get_combiners() - for combiner in combiners: - if name == combiner.name: - return combiner - return None - - def get_combiners(self) -> List[CombinerInterface]: - """Get all combiners in the network. - - :return: list of combiners objects - :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) - """ - result = self.db.combiner_store.list(limit=0, skip=0, sort_key=None) - combiners = [] - for combiner in result: - name = combiner.name.upper() - # General certificate handling, same for all combiners. - if os.environ.get("FEDN_GRPC_CERT_PATH"): - with open(os.environ.get("FEDN_GRPC_CERT_PATH"), "rb") as f: - cert = f.read() - # Specific certificate handling for each combiner. - elif os.environ.get(f"FEDN_GRPC_CERT_PATH_{name}"): - cert_path = os.environ.get(f"FEDN_GRPC_CERT_PATH_{name}") - with open(cert_path, "rb") as f: - cert = f.read() - else: - cert = None - combiners.append( - CombinerInterface(combiner.parent, combiner.name, combiner.address, combiner.fqdn, combiner.port, certificate=cert, ip=combiner.ip) - ) - - return combiners - - def find_available_combiner(self): - """Find an available combiner in the network. 
- - :return: The combiner instance object - :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` - """ - combiner = self.load_balancer.find_combiner() - return combiner - - def handle_unavailable_combiner(self, combiner): - """This callback is triggered if a combiner is found to be unresponsive. - - :param combiner: The combiner instance object - :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` - :return: None - """ - # TODO: Implement strategy to handle an unavailable combiner. - logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format(combiner.name)) diff --git a/fedn/network/api/server.py b/fedn/network/api/server.py index 1f2bb634f..a0c18e917 100644 --- a/fedn/network/api/server.py +++ b/fedn/network/api/server.py @@ -2,13 +2,14 @@ from flask import Flask, jsonify, request -from fedn.common.config import get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config +from fedn.common.config import get_api_config, get_controller_config, get_modelstorage_config, get_network_config, get_statestore_config from fedn.network.api import gunicorn_app from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import ApplicationState, get_network from fedn.network.api.v1 import _routes from fedn.network.api.v1.graphql.schema import schema -from fedn.network.controller.control import Control -from fedn.network.state import ReducerStateToString +from fedn.network.common.network import Network +from fedn.network.common.state import ReducerStateToString from fedn.network.storage.dbconnection import DatabaseConnection from fedn.network.storage.s3.repository import Repository @@ -62,7 +63,7 @@ def get_controller_status(): return: The status as a json object. 
rtype: json """ - return jsonify({"state": ReducerStateToString(Control.instance().state())}), 200 + return jsonify({"state": ReducerStateToString(get_network().get_control_state())}), 200 if custom_url_prefix: @@ -481,8 +482,8 @@ def delete_model_trail(): app.add_url_rule(f"{custom_url_prefix}/delete_model_trail", view_func=delete_model_trail, methods=["GET", "POST"]) -def start_server_api(): - config = get_controller_config() +def start_api_server(): + config = get_api_config() port = config["port"] host = "0.0.0.0" debug = config["debug"] @@ -491,19 +492,27 @@ def start_server_api(): modelstorage_config = get_modelstorage_config() statestore_config = get_statestore_config() - # TODO: Initialize database with config instead of reading it under the hood - db = DatabaseConnection(statestore_config, network_id, connect=False) - repository = Repository(modelstorage_config["storage_config"], storage_type=modelstorage_config["storage_type"]) - Control.create_instance(network_id, repository, db) + controller = get_controller_config() + + def init_globals(): + """Initialize the database connection and repository""" + state = ApplicationState() + state.db = DatabaseConnection(statestore_config, network_id) + state.repository = Repository(modelstorage_config["storage_config"], storage_type=modelstorage_config["storage_type"]) + state.network = Network(state.db, state.repository, controller_host=controller["host"], controller_port=controller["port"]) if debug: # Without gunicorn, we can initialize the database connection here - db.initialize_connection() + init_globals() app.run(debug=debug, port=port, host=host) else: + + def post_fork(server, worker): + init_globals() + workers = os.cpu_count() - gunicorn_app.run_gunicorn(app, host, port, workers) + gunicorn_app.run_gunicorn(app, host, port, workers, post_fork_func=post_fork) if __name__ == "__main__": - start_server_api() + start_api_server() diff --git a/fedn/network/api/shared.py b/fedn/network/api/shared.py index 
1a7afc156..3e540d5e5 100644 --- a/fedn/network/api/shared.py +++ b/fedn/network/api/shared.py @@ -5,10 +5,44 @@ from werkzeug.security import safe_join -from fedn.network.controller.control import Control +from fedn.network.common.network import Network +from fedn.network.storage.dbconnection import DatabaseConnection +from fedn.network.storage.s3.repository import Repository from fedn.utils.checksum import sha +class ApplicationState: + """Global state to hold shared objects for the network API.""" + + _instance = None + + def _prepare(self): + self.db = None + self.repository = None + self.network = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super(ApplicationState, cls).__new__(cls) + cls._instance._prepare() + return cls._instance + + +def get_db() -> DatabaseConnection: + """Get the database connection.""" + return ApplicationState().db + + +def get_repository() -> Repository: + """Get the repository.""" + return ApplicationState().repository + + +def get_network() -> Network: + """Get the network interface.""" + return ApplicationState().network + + def get_checksum(name: str = None) -> Tuple[bool, str, str]: """Generate a checksum for a given file.""" message = None @@ -16,7 +50,7 @@ def get_checksum(name: str = None) -> Tuple[bool, str, str]: success = False if name is None: - db = Control.instance().db + db = get_db() active_package = db.package_store.get_active() if active_package is None: message = "No compute package uploaded" diff --git a/fedn/network/api/tests.py b/fedn/network/api/tests.py index 94df557d1..2529144ee 100644 --- a/fedn/network/api/tests.py +++ b/fedn/network/api/tests.py @@ -7,7 +7,8 @@ import unittest from unittest.mock import patch, MagicMock -from fedn.network.controller.control import Control +from fedn.network.api.shared import ApplicationState +from fedn.network.common.network import Network from fedn.network.storage.statestore.stores.dto.attribute import AttributeDTO from 
fedn.network.storage.statestore.stores.dto.metric import MetricDTO @@ -72,8 +73,10 @@ def setUp(self, mock_control): import fedn.network.api.server self.app = fedn.network.api.server.app.test_client() self.db = MockDB() - - Control.create_instance("test_network", None, self.db) + state = ApplicationState() + state.db = self.db # Set the global db variable to the mock db + state.network = Network(self.db, None) # Mock Network object + def test_health(self): @@ -90,12 +93,6 @@ def test_add_combiner(self): self.assertEqual(response.status_code, 410) - def test_get_controller_status(self): - """ Test get_models endpoint. """ - response = self.app.get('/get_controller_status') - # Assert response - self.assertEqual(response.status_code, 200) - def test_get_single_endpoints(self): """ Test get single endpoints. """ expected_return_id = "test" diff --git a/fedn/network/api/v1/__init__.py b/fedn/network/api/v1/__init__.py index dc4613df8..1ddf9e1de 100644 --- a/fedn/network/api/v1/__init__.py +++ b/fedn/network/api/v1/__init__.py @@ -1,6 +1,7 @@ from fedn.network.api.v1.attribute_routes import bp as attribute_bp from fedn.network.api.v1.client_routes import bp as client_bp from fedn.network.api.v1.combiner_routes import bp as combiner_bp +from fedn.network.api.v1.control_routes import bp as control_bp from fedn.network.api.v1.helper_routes import bp as helper_bp from fedn.network.api.v1.metric_routes import bp as metric_bp from fedn.network.api.v1.model_routes import bp as model_bp @@ -28,4 +29,5 @@ run_bp, telemetry_bp, attribute_bp, + control_bp, ] diff --git a/fedn/network/api/v1/attribute_routes.py b/fedn/network/api/v1/attribute_routes.py index f77095b53..e167628c9 100644 --- a/fedn/network/api/v1/attribute_routes.py +++ b/fedn/network/api/v1/attribute_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version, 
get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control from fedn.network.storage.statestore.stores.dto.attribute import AttributeDTO from fedn.network.storage.statestore.stores.shared import MissingFieldError, ValidationError @@ -14,7 +14,7 @@ @jwt_auth_required(role="admin") def get_attributes(): try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -32,7 +32,7 @@ def get_attributes(): @jwt_auth_required(role="admin") def list_attributes(): try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -50,7 +50,7 @@ def list_attributes(): @jwt_auth_required(role="admin") def get_attributes_count(): try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.attribute_store.count(**kwargs) response = count @@ -64,7 +64,7 @@ def get_attributes_count(): @jwt_auth_required(role="admin") def attributes_count(): try: - db = Control.instance().db + db = get_db() kwargs = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict() count = db.attribute_store.count(**kwargs) response = count @@ -78,7 +78,7 @@ def attributes_count(): @jwt_auth_required(role="admin") def get_attribute(id: str): try: - db = Control.instance().db + db = get_db() attribute = db.attribute_store.get(id) if attribute is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -94,7 +94,7 @@ def get_attribute(id: str): @jwt_auth_required(role="admin") def add_attributes(): try: - db = Control.instance().db + db = get_db() data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict() attribute = AttributeDTO().patch_with(data) @@ -158,7 +158,7 @@ def get_client_current_attributes(): type: string """ try: - db = 
Control.instance().db + db = get_db() json_data = request.get_json() client_ids = json_data.get("client_ids") if not client_ids: diff --git a/fedn/network/api/v1/client_routes.py b/fedn/network/api/v1/client_routes.py index 8dcc149f7..db7bf260c 100644 --- a/fedn/network/api/v1/client_routes.py +++ b/fedn/network/api/v1/client_routes.py @@ -2,12 +2,11 @@ from flask import Blueprint, jsonify, request -from fedn.common.config import get_controller_config, get_network_config +from fedn.common.config import get_api_config, get_network_config from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required -from fedn.network.api.shared import get_checksum +from fedn.network.api.shared import get_checksum, get_db, get_network from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control from fedn.network.storage.statestore.stores.dto import ClientDTO from fedn.network.storage.statestore.stores.shared import MissingFieldError, ValidationError @@ -115,7 +114,7 @@ def get_clients(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -201,7 +200,7 @@ def list_clients(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -268,7 +267,7 @@ def get_clients_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.client_store.count(**kwargs) response = count @@ -323,7 +322,7 @@ def clients_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.client_store.count(**kwargs) response = count @@ -368,7 +367,7 @@ def get_client(id: str): type: string """ try: - db = 
Control.instance().db + db = get_db() client = db.client_store.get(id) if client is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -412,7 +411,7 @@ def delete_client(id: str): type: string """ try: - db = Control.instance().db + db = get_db() result: bool = db.client_store.delete(id) if result is False: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -464,8 +463,8 @@ def add_client(): type: string """ try: - db = Control.instance().db - network = Control.instance().network + db = get_db() + network = get_network() json_data = request.get_json() remote_addr = request.remote_addr @@ -504,10 +503,11 @@ def add_client(): if combiner is None: return jsonify({"success": False, "message": "No combiner available."}), 400 - if db.client_store.get(client_id) is None: - logger.info("Adding client {}".format(client_id)) + existing_client = db.client_store.get(client_id) + last_seen = datetime.now() - last_seen = datetime.now() + if existing_client is None: + logger.info("Adding client {}".format(client_id)) new_client = ClientDTO( client_id=client_id, @@ -520,8 +520,27 @@ def add_client(): last_seen=last_seen, ) - added_client = db.client_store.add(new_client) - client_id = added_client.client_id + try: + added_client = db.client_store.add(new_client) + client_id = added_client.client_id + except Exception as e: + logger.error(f"Failed to add new client: {e}") + return jsonify({"success": False, "message": "Failed to add new client"}), 500 + else: + logger.info("Client {} already exists, updating client object".format(client_id)) + existing_client.name = name + existing_client.combiner = combiner.name + existing_client.combiner_preferred = preferred_combiner + existing_client.ip = remote_addr + existing_client.status = "available" + existing_client.package = package + existing_client.last_seen = last_seen + + try: + db.client_store.update(existing_client) + except Exception as e: + logger.error(f"Failed to update existing 
client: {e}") + return jsonify({"success": False, "message": "Failed to update existing client"}), 500 payload = { "status": "assigned", @@ -580,7 +599,7 @@ def get_client_config(): checksum_arg = request.args.get("checksum", "true") include_checksum = checksum_arg.lower() == "true" - config = get_controller_config() + config = get_api_config() network_id = get_network_config() port = config["port"] host = config["host"] @@ -643,7 +662,7 @@ def get_client_attributes(id): type: string """ try: - db = Control.instance().db + db = get_db() client = db.client_store.get(id) if client is None: diff --git a/fedn/network/api/v1/combiner_routes.py b/fedn/network/api/v1/combiner_routes.py index be3631ed4..6bad802aa 100644 --- a/fedn/network/api/v1/combiner_routes.py +++ b/fedn/network/api/v1/combiner_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control bp = Blueprint("combiner", __name__, url_prefix=f"/api/{api_version}/combiners") @@ -100,7 +100,7 @@ def get_combiners(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -183,7 +183,7 @@ def list_combiners(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -237,7 +237,7 @@ def get_combiners_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.combiner_store.count(**kwargs) response = count @@ -288,7 +288,7 @@ def combiners_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) 
count = db.combiner_store.count(**kwargs) response = count @@ -333,7 +333,7 @@ def get_combiner(id: str): type: string """ try: - db = Control.instance().db + db = get_db() combiner = db.combiner_store.get(id) if combiner is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -376,7 +376,7 @@ def delete_combiner(id: str): type: string """ try: - db = Control.instance().db + db = get_db() result: bool = db.combiner_store.delete(id) if not result: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -420,7 +420,7 @@ def number_of_clients_connected(): type: string """ try: - db = Control.instance().db + db = get_db() data = request.get_json() combiners = data.get("combiners", "") combiners = combiners.split(",") if combiners else [] diff --git a/fedn/network/api/v1/control_routes.py b/fedn/network/api/v1/control_routes.py new file mode 100644 index 000000000..be59d0931 --- /dev/null +++ b/fedn/network/api/v1/control_routes.py @@ -0,0 +1,41 @@ +from flask import Blueprint, jsonify + +from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_network +from fedn.network.api.v1.session_routes import start_session +from fedn.network.api.v1.shared import api_version +from fedn.network.grpc.fedn_pb2 import Command + +bp = Blueprint("control", __name__, url_prefix=f"/api/{api_version}/control") + + +@bp.route("/start_session", methods=["POST"]) +@jwt_auth_required(role="admin") +def control_start_session(): + """Start a new session. + + This endpoint is identical to the one in the `session_routes.py` file. 
+ """ + start_session() + + +@bp.route("/continue", methods=["POST"]) +@jwt_auth_required(role="admin") +def send_continue_signal(): + """Send a continue signal to the controller.""" + network = get_network() + + control = network.get_control() + control.send_command(Command.CONTINUE) + + return jsonify({"message": "Sent continue signal"}), 200 + + +@bp.route("/stop", methods=["POST"]) +@jwt_auth_required(role="admin") +def send_stop_signal(): + network = get_network() + control = network.get_control() + control.send_command(Command.STOP) + + return jsonify({"message": "Sent stop signal"}), 200 diff --git a/fedn/network/api/v1/graphql/schema.py b/fedn/network/api/v1/graphql/schema.py index 0a0bd3329..92d8de2ff 100644 --- a/fedn/network/api/v1/graphql/schema.py +++ b/fedn/network/api/v1/graphql/schema.py @@ -1,6 +1,6 @@ import graphene -from fedn.network.controller.control import Control +from fedn.network.api.shared import get_db from fedn.network.storage.statestore.stores.shared import SortOrder @@ -31,7 +31,7 @@ class StatusType(graphene.ObjectType): session = graphene.Field(lambda: SessionType) def resolve_session(self, info): - db = Control.instance().db + db = get_db() session = db.session_store.get(self["session_id"]) if session: return session.to_dict() @@ -56,7 +56,7 @@ def resolve_sender(self, info): return self["sender"] def resolve_session(self, info): - db = Control.instance().db + db = get_db() session = db.session_store.get(self["session_id"]) if session: return session.to_dict() @@ -79,7 +79,7 @@ class ModelType(graphene.ObjectType): session = graphene.Field(lambda: SessionType) def resolve_validations(self, info, limit=0, skip=0, sort_key="committed_at", sort_order="desc"): - db = Control.instance().db + db = get_db() kwargs = {"model_id": self["model_id"]} sort_order = get_sort_order_from_string(sort_order) @@ -88,7 +88,7 @@ def resolve_validations(self, info, limit=0, skip=0, sort_key="committed_at", so return result def resolve_session(self, 
info): - db = Control.instance().db + db = get_db() session = db.session_store.get(self["session_id"]) if session: return session.to_dict() @@ -139,7 +139,7 @@ def resolve_session_config(self, info): return self["session_config"] def resolve_models(self, info, limit=0, skip=0, sort_key="committed_at", sort_order="desc"): - db = Control.instance().db + db = get_db() kwargs = {"session_id": self["session_id"]} sort_order = get_sort_order_from_string(sort_order) @@ -150,7 +150,7 @@ def resolve_models(self, info, limit=0, skip=0, sort_key="committed_at", sort_or return result def resolve_validations(self, info, limit=0, skip=0, sort_key="committed_at", sort_order="desc"): - db = Control.instance().db + db = get_db() kwargs = {"session_id": self["session_id"]} sort_order = get_sort_order_from_string(sort_order) @@ -161,7 +161,7 @@ def resolve_validations(self, info, limit=0, skip=0, sort_key="committed_at", so return result def resolve_statuses(self, info, limit=0, skip=0, sort_key="committed_at", sort_order="desc"): - db = Control.instance().db + db = get_db() kwargs = {"session_id": self["session_id"]} sort_order = get_sort_order_from_string(sort_order) @@ -229,13 +229,13 @@ class Query(graphene.ObjectType): ) def resolve_session(root, info, id: str = None): - db = Control.instance().db + db = get_db() result = db.session_store.get(id) return result.to_dict() def resolve_sessions(root, info, name: str = None, limit: int = 25, skip: int = 0, sort_key: str = "committed_at", sort_order: str = "desc"): - db = Control.instance().db + db = get_db() if name: kwargs = {"name": name} else: @@ -249,13 +249,13 @@ def resolve_sessions(root, info, name: str = None, limit: int = 25, skip: int = return result def resolve_model(root, info, id: str = None): - db = Control.instance().db + db = get_db() result = db.model_store.get(id).to_dict() return result def resolve_models(root, info, session_id: str = None, limit: int = 25, skip: int = 0, sort_key: str = "committed_at", sort_order: 
str = "desc"): - db = Control.instance().db + db = get_db() if session_id: kwargs = {"session_id": session_id} else: @@ -269,13 +269,13 @@ def resolve_models(root, info, session_id: str = None, limit: int = 25, skip: in return result def resolve_validation(root, info, id: str = None): - db = Control.instance().db + db = get_db() result = db.validation_store.get(id).to_dict() return result def resolve_validations(root, info, session_id: str = None, limit: int = 25, skip: int = 0, sort_key: str = "committed_at", sort_order: str = "desc"): - db = Control.instance().db + db = get_db() if session_id: kwargs = {"session_id": session_id} @@ -290,13 +290,13 @@ def resolve_validations(root, info, session_id: str = None, limit: int = 25, ski return result def resolve_status(root, info, id: str = None): - db = Control.instance().db + db = get_db() result = db.status_store.get(id).to_dict() return result def resolve_statuses(root, info, session_id: str = None, limit: int = 25, skip: int = 0, sort_key: str = "committed_at", sort_order: str = "desc"): - db = Control.instance().db + db = get_db() if session_id: kwargs = {"session_id": session_id} diff --git a/fedn/network/api/v1/helper_routes.py b/fedn/network/api/v1/helper_routes.py index a000c08be..ac92ac64f 100644 --- a/fedn/network/api/v1/helper_routes.py +++ b/fedn/network/api/v1/helper_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version -from fedn.network.controller.control import Control bp = Blueprint("helper", __name__, url_prefix=f"/api/{api_version}/helpers") @@ -25,7 +25,7 @@ def get_active_helper(): description: An unexpected error occurred """ try: - db = Control.instance().db + db = get_db() active_package = db.package_store.get_active() if active_package is None: return jsonify({"message": "No active helper"}), 404 @@ -53,7 +53,7 @@ def 
set_active_helper(): description: An unexpected error occurred """ try: - db = Control.instance().db + db = get_db() data = request.get_json() helper = data["helper"] db.package_store.set_active_helper(helper) diff --git a/fedn/network/api/v1/metric_routes.py b/fedn/network/api/v1/metric_routes.py index 5854f41cc..2eaa1d3a6 100644 --- a/fedn/network/api/v1/metric_routes.py +++ b/fedn/network/api/v1/metric_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control bp = Blueprint("metric", __name__, url_prefix=f"/api/{api_version}/metrics") @@ -118,7 +118,7 @@ def get_metrics(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -220,7 +220,7 @@ def list_metrics(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -269,7 +269,7 @@ def get_metric(id: str): type: string """ try: - db = Control.instance().db + db = get_db() response = db.metric_store.get(id) if response is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -354,7 +354,7 @@ def get_metrics_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.metric_store.count(**kwargs) response = count @@ -446,7 +446,7 @@ def metrics_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.metric_store.count(**kwargs) response = count diff --git a/fedn/network/api/v1/model_routes.py b/fedn/network/api/v1/model_routes.py index 29b3a2b29..348a34608 
100644 --- a/fedn/network/api/v1/model_routes.py +++ b/fedn/network/api/v1/model_routes.py @@ -1,14 +1,12 @@ -import io -from io import BytesIO - import numpy as np from flask import Blueprint, jsonify, request, send_file from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db, get_network, get_repository from fedn.network.api.v1.shared import api_version, get_limit, get_post_data_to_kwargs, get_reverse, get_typed_list_headers -from fedn.network.controller.control import Control from fedn.network.storage.statestore.stores.shared import EntityNotFound, MissingFieldError, ValidationError +from fedn.utils.model import FednModel bp = Blueprint("model", __name__, url_prefix=f"/api/{api_version}/models") @@ -100,7 +98,7 @@ def get_models(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -186,7 +184,7 @@ def list_models(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -241,7 +239,7 @@ def get_models_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.model_store.count(**kwargs) response = count @@ -295,7 +293,7 @@ def models_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.model_store.count(**kwargs) response = count @@ -340,7 +338,7 @@ def get_model(id: str): type: string """ try: - db = Control.instance().db + db = get_db() model = db.model_store.get(id) if model is None: @@ -393,7 +391,7 @@ def patch_model(id: str): type: string """ try: - db = Control.instance().db + db = get_db() existing_model = db.model_store.get(id) if existing_model is None: return jsonify({"message": 
f"Entity with id: {id} not found"}), 404 @@ -466,7 +464,7 @@ def put_model(id: str): type: string """ try: - db = Control.instance().db + db = get_db() model = db.model_store.get(id) if model is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -537,7 +535,7 @@ def get_descendants(id: str): type: string """ try: - db = Control.instance().db + db = get_db() limit = get_limit(request.headers) descendants = db.model_store.list_descendants(id, limit or 10) @@ -605,7 +603,7 @@ def get_ancestors(id: str): type: string """ try: - db = Control.instance().db + db = get_db() limit = get_limit(request.headers) reverse = get_reverse(request.headers) include_self_param: str = request.args.get("include_self") @@ -648,7 +646,7 @@ def get_leaf_nodes(): type: string """ try: - db = Control.instance().db + db = get_db() leaf_nodes = db.model_store.get_leaf_nodes() response = [model.to_dict() for model in leaf_nodes] return jsonify(response), 200 @@ -691,16 +689,14 @@ def download(id: str): type: string """ try: - db = Control.instance().db - repository = Control.instance().repository + db = get_db() + repository = get_repository() if repository is not None: model = db.model_store.get(id) if model is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 - - file = repository.get_model_stream(model.model_id) - - return send_file(file, as_attachment=True, download_name=model.model_id) + fedn_model = repository.get_model(model.model_id) + return send_file(fedn_model.get_stream_unsafe(), as_attachment=True, download_name=model.model_id) else: return jsonify({"message": "No model storage configured"}), 500 except Exception as e: @@ -747,21 +743,15 @@ def get_parameters(id: str): type: string """ try: - db = Control.instance().db - repository = Control.instance().repository + db = get_db() + repository = get_repository() if repository is not None: model = db.model_store.get(id) if model is None: return jsonify({"message": f"Entity with id: 
{id} not found"}), 404 + fedn_model = repository.get_model(model.model_id) - file = repository.get_model_stream(model.model_id) - - file_bytes = io.BytesIO() - for chunk in file.stream(32 * 1024): - file_bytes.write(chunk) - file_bytes.seek(0) # Reset the pointer to the beginning of the byte array - - a = np.load(file_bytes) + a = np.load(fedn_model.get_stream_unsafe()) weights = [] for i in range(len(a.files)): @@ -803,21 +793,29 @@ def upload_model(): type: string """ try: - control = Control.instance() + network = get_network() data = request.form.to_dict() file = request.files["file"] name: str = data.get("name", None) try: - object = BytesIO() - object.seek(0, 0) - file.seek(0) - object.write(file.read()) - helper = control.get_helper() - logger.info(f"Loading model from file using helper {helper.name}") - object.seek(0) - model = helper.load(object) - control.commit(model=model, name=name) + fedn_model = FednModel.from_stream(file) + fedn_model.helper = network.get_helper() + try: + _ = fedn_model.get_model_params() + except Exception as e: + logger.error(f"Failed to extract model parameters: {e}") + status_code = 400 + return ( + jsonify( + { + "success": False, + "message": "Failed to extract model parameters. 
Ensure that the model is compatible with the selected helper.", + } + ), + status_code, + ) + network.commit_model(model=fedn_model, name=name) except Exception as e: logger.error(f"An unexpected error occurred: {e}") status_code = 400 diff --git a/fedn/network/api/v1/package_routes.py b/fedn/network/api/v1/package_routes.py index c1c22fa75..ae41be3df 100644 --- a/fedn/network/api/v1/package_routes.py +++ b/fedn/network/api/v1/package_routes.py @@ -8,12 +8,10 @@ from fedn.common.log_config import logger, tracer from fedn.network.api.auth import jwt_auth_required from fedn.network.api.shared import get_checksum as _get_checksum -from fedn.network.api.v1.shared import (api_version, get_post_data_to_kwargs, - get_typed_list_headers) -from fedn.network.controller.control import Control +from fedn.network.api.shared import get_db, get_network, get_repository +from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers from fedn.network.storage.statestore.stores.dto.package import PackageDTO -from fedn.network.storage.statestore.stores.shared import (MissingFieldError, - ValidationError) +from fedn.network.storage.statestore.stores.shared import MissingFieldError, ValidationError bp = Blueprint("package", __name__, url_prefix=f"/api/{api_version}/packages") @@ -123,7 +121,7 @@ def get_packages(): """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -211,7 +209,7 @@ def list_packages(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -278,7 +276,7 @@ def get_packages_count(): """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.package_store.count(**kwargs) response = count @@ -342,7 +340,7 @@ def packages_count(): type: string """ try: - db 
= Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.package_store.count(**kwargs) response = count @@ -387,7 +385,7 @@ def get_package(id: str): type: string """ try: - db = Control.instance().db + db = get_db() response = db.package_store.get(id) if response is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -427,7 +425,7 @@ def get_active_package(): type: string """ try: - db = Control.instance().db + db = get_db() response = db.package_store.get_active() if response is None: return jsonify({"message": "Entity not found"}), 404 @@ -463,7 +461,7 @@ def set_active_package(): type: string """ try: - db = Control.instance().db + db = get_db() data = request.json package_id = data["id"] response = db.package_store.set_active(package_id) @@ -509,7 +507,7 @@ def delete_active_package(): type: string """ try: - db = Control.instance().db + db = get_db() result = db.package_store.delete_active() if result is False: return jsonify({"message": "Entity not found"}), 404 @@ -567,8 +565,8 @@ def upload_package(): type: string """ try: - db = Control.instance().db - repository = Control.instance().repository + db = get_db() + repository = get_repository() data = request.form.to_dict() file = request.files["file"] @@ -648,8 +646,8 @@ def download_package(): message: type: string """ - db = Control.instance().db - control = Control.instance() + db = get_db() + network = get_network() name = request.args.get("name", None) if name is None: @@ -665,7 +663,7 @@ def download_package(): return send_from_directory(FEDN_COMPUTE_PACKAGE_DIR, name, as_attachment=True) except Exception: try: - data = control.get_compute_package(name) + data = network.get_compute_package(name) # TODO: make configurable, perhaps in config.py or package.py file_path = safe_join(FEDN_COMPUTE_PACKAGE_DIR, name) with open(file_path, "wb") as fh: diff --git a/fedn/network/api/v1/prediction_routes.py b/fedn/network/api/v1/prediction_routes.py 
index 893aba786..c245f5dea 100644 --- a/fedn/network/api/v1/prediction_routes.py +++ b/fedn/network/api/v1/prediction_routes.py @@ -1,11 +1,12 @@ -import threading - from flask import Blueprint, jsonify, request +from grpc import RpcError +import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db, get_network from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control +from fedn.network.common.command import CommandType bp = Blueprint("prediction", __name__, url_prefix=f"/api/{api_version}/predictions") @@ -20,8 +21,8 @@ def start_session(): type: rounds: int """ try: - db = Control.instance().db - control = Control.instance() + db = get_db() + control = get_network().get_control() data = request.get_json(silent=True) if request.is_json else request.form.to_dict() prediction_id: str = data.get("prediction_id") @@ -42,7 +43,11 @@ def start_session(): return jsonify({"message": f"Model {model_id} not found"}), 404 session_config["model_id"] = model_id - threading.Thread(target=control.prediction_session, kwargs={"config": session_config}).start() + try: + control.send_command(fedn.Command.START, CommandType.PredictionSession.value, session_config) + except RpcError as e: + logger.error(f"Failed to start prediction session: {e}") + return jsonify({"message": "Failed to start prediction session"}), 500 return jsonify({"message": "Prediction session started"}), 200 except Exception as e: @@ -169,7 +174,7 @@ def get_predictions(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -267,7 +272,7 @@ def list_predictions(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = 
get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -311,7 +316,7 @@ def get_predictions_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.prediction_store.count(**kwargs) response = count @@ -380,7 +385,7 @@ def predictions_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.prediction_store.count(**kwargs) response = count @@ -427,7 +432,7 @@ def get_prediction(id: str): type: string """ try: - db = Control.instance().db + db = get_db() prediction = db.prediction_store.get(id) if prediction is None: diff --git a/fedn/network/api/v1/round_routes.py b/fedn/network/api/v1/round_routes.py index bb14496d2..8af98bceb 100644 --- a/fedn/network/api/v1/round_routes.py +++ b/fedn/network/api/v1/round_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control bp = Blueprint("round", __name__, url_prefix=f"/api/{api_version}/rounds") @@ -88,7 +88,7 @@ def get_rounds(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -167,7 +167,7 @@ def list_rounds(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -215,7 +215,7 @@ def get_rounds_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.round_store.count(**kwargs) response = count @@ -262,7 +262,7 @@ def rounds_count(): type: string """ try: - db = Control.instance().db + 
db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.round_store.count(**kwargs) response = count @@ -307,7 +307,7 @@ def get_round(id: str): type: string """ try: - db = Control.instance().db + db = get_db() round = db.round_store.get(id) if round is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 diff --git a/fedn/network/api/v1/run_routes.py b/fedn/network/api/v1/run_routes.py index f861d5709..cceb222b0 100644 --- a/fedn/network/api/v1/run_routes.py +++ b/fedn/network/api/v1/run_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control bp = Blueprint("run", __name__, url_prefix=f"/api/{api_version}/runs") @@ -12,7 +12,7 @@ @jwt_auth_required(role="admin") def get_runs(): try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -30,7 +30,7 @@ def get_runs(): @jwt_auth_required(role="admin") def list_runs(): try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -48,7 +48,7 @@ def list_runs(): @jwt_auth_required(role="admin") def get_runs_count(): try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.run_store.count(**kwargs) response = count @@ -62,7 +62,7 @@ def get_runs_count(): @jwt_auth_required(role="admin") def runs_count(): try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.run_store.count(**kwargs) response = count @@ -76,7 +76,7 @@ def runs_count(): @jwt_auth_required(role="admin") def get_run(id: str): try: - db = Control.instance().db + db = get_db() 
response = db.run_store.get(id) if response is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 diff --git a/fedn/network/api/v1/session_routes.py b/fedn/network/api/v1/session_routes.py index 3160564a7..b88f9b0d5 100644 --- a/fedn/network/api/v1/session_routes.py +++ b/fedn/network/api/v1/session_routes.py @@ -1,13 +1,13 @@ -import threading - from flask import Blueprint, jsonify, request +from grpc import RpcError +import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db, get_network from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.combiner.interfaces import CombinerUnavailableError -from fedn.network.controller.control import Control -from fedn.network.state import ReducerState +from fedn.network.common.command import CommandType +from fedn.network.common.state import ControllerState from fedn.network.storage.statestore.stores.dto.session import SessionConfigDTO, SessionDTO from fedn.network.storage.statestore.stores.shared import EntityNotFound, MissingFieldError, ValidationError @@ -87,7 +87,7 @@ def get_sessions(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -167,7 +167,7 @@ def list_sessions(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -217,7 +217,7 @@ def get_sessions_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.session_store.count(**kwargs) response = count @@ -264,7 +264,7 @@ def sessions_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = 
get_post_data_to_kwargs(request) count = db.session_store.count(**kwargs) response = count @@ -309,7 +309,7 @@ def get_session(id: str): type: string """ try: - db = Control.instance().db + db = get_db() result = db.session_store.get(id) if result is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -353,7 +353,7 @@ def post(): type: string """ try: - db = Control.instance().db + db = get_db() data = request.get_json(silent=True) if request.is_json else request.form.to_dict() session_config = SessionConfigDTO() @@ -385,26 +385,6 @@ def post(): return jsonify({"message": "An unexpected error occurred"}), 500 -def _get_number_of_available_clients(client_ids: list[str]): - control = Control.instance() - - result = 0 - active_clients = None - for combiner in control.network.get_combiners(): - try: - active_clients = combiner.list_active_clients() - if active_clients is not None: - if client_ids is not None: - filtered = [item for item in active_clients if item.client_id in client_ids] - result += len(filtered) - else: - result += len(active_clients) - except CombinerUnavailableError: - return 0 - - return result - - @bp.route("/start", methods=["POST"]) @jwt_auth_required(role="admin") def start_session(): @@ -415,8 +395,9 @@ def start_session(): type: rounds: int """ try: - db = Control.instance().db - control = Control.instance() + db = get_db() + network = get_network() + control = network.get_control() data = request.get_json(silent=True) if request.is_json else request.form.to_dict() @@ -450,12 +431,12 @@ def start_session(): model_id = session_config.model_id min_clients = session_config.clients_required - if control.state() == ReducerState.monitoring: + if control.get_state() != ControllerState.idle: return jsonify({"message": "A session is already running!"}), 400 if not rounds or not isinstance(rounds, int): rounds = session_config.rounds - nr_available_clients = _get_number_of_available_clients(client_ids=client_ids) + 
nr_available_clients = network.get_number_of_available_clients(client_ids=client_ids) if nr_available_clients < min_clients: return jsonify({"message": f"Number of available clients is lower than the required minimum of {min_clients}"}), 400 @@ -464,7 +445,19 @@ def start_session(): if model is None: return jsonify({"message": "Session seed model not found"}), 400 - threading.Thread(target=control.start_session, args=(session_id, rounds, round_timeout, model_name_prefix, client_ids)).start() + parameters = { + "session_id": session_id, + "rounds": rounds, + "round_timeout": round_timeout, + "model_name_prefix": model_name_prefix, + "client_ids": client_ids, + } + + try: + control.send_command(fedn.Command.START, CommandType.StandardSession.value, parameters) + except RpcError as e: + logger.error(f"Failed to send command to control: {e}") + return jsonify({"message": "Failed to start session"}), 500 return jsonify({"message": "Session started"}), 200 except Exception as e: @@ -477,8 +470,9 @@ def start_session(): def start_splitlearning_session(): """Starts a new split learning session.""" try: - db = Control.instance().db - control = Control.instance() + db = get_db() + network = get_network() + control = network.get_control() data = request.json if request.headers["Content-Type"] == "application/json" else request.form.to_dict() session_id: str = data.get("session_id") rounds: int = data.get("rounds", "") @@ -495,17 +489,22 @@ def start_splitlearning_session(): session_config = session.session_config min_clients = session_config.clients_required - if control.state() == ReducerState.monitoring: + if control.get_state() != ControllerState.idle: return jsonify({"message": "A session is already running!"}), 400 if not rounds or not isinstance(rounds, int): rounds = session_config.rounds - nr_available_clients = _get_number_of_available_clients() + nr_available_clients = network.get_number_of_available_clients() if nr_available_clients < min_clients: return 
jsonify({"message": f"Number of available clients is lower than the required minimum of {min_clients}"}), 400 - threading.Thread(target=control.splitlearning_session, args=(session_id, rounds, round_timeout)).start() + parameters = {"session_id": session_id, "rounds": rounds, "round_timeout": round_timeout} + try: + control.send_command(fedn.Command.START, CommandType.SplitLearningSession.value, parameters) + except RpcError as e: + logger.error(f"Failed to send command to control: {e}") + return jsonify({"message": "Failed to start split learning session"}), 500 return jsonify({"message": "Splitlearning session started"}), 200 except Exception as e: @@ -553,7 +552,7 @@ def patch_session(id: str): type: string """ try: - db = Control.instance().db + db = get_db() existing_session = db.session_store.get(id) if existing_session is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -626,7 +625,7 @@ def put_session(id: str): type: string """ try: - db = Control.instance().db + db = get_db() session = db.session_store.get(id) if session is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 diff --git a/fedn/network/api/v1/shared.py b/fedn/network/api/v1/shared.py index 7b1128b64..737ac4fe0 100644 --- a/fedn/network/api/v1/shared.py +++ b/fedn/network/api/v1/shared.py @@ -61,6 +61,10 @@ def get_post_data_to_kwargs(request: object) -> dict: except Exception: request_data = {} + # Ensure request_data is a dictionary + if not isinstance(request_data, dict): + request_data = {} + kwargs = {} for key, value in request_data.items(): if isinstance(value, str) and "," in value: diff --git a/fedn/network/api/v1/status_routes.py b/fedn/network/api/v1/status_routes.py index 3f27350df..a7100002c 100644 --- a/fedn/network/api/v1/status_routes.py +++ b/fedn/network/api/v1/status_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import 
get_db from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control bp = Blueprint("status", __name__, url_prefix=f"/api/{api_version}/statuses") @@ -119,7 +119,7 @@ def get_statuses(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -214,7 +214,7 @@ def list_statuses(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -282,7 +282,7 @@ def get_statuses_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.status_store.count(**kwargs) response = count @@ -346,7 +346,7 @@ def statuses_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.status_store.count(**kwargs) response = count @@ -391,7 +391,7 @@ def get_status(id: str): type: string """ try: - db = Control.instance().db + db = get_db() status = db.status_store.get(id) if status is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 diff --git a/fedn/network/api/v1/telemetry_routes.py b/fedn/network/api/v1/telemetry_routes.py index 2a2993ba7..9e2ef5494 100644 --- a/fedn/network/api/v1/telemetry_routes.py +++ b/fedn/network/api/v1/telemetry_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control from fedn.network.storage.statestore.stores.dto.telemetry import TelemetryDTO from fedn.network.storage.statestore.stores.shared import 
MissingFieldError, ValidationError @@ -14,7 +14,7 @@ @jwt_auth_required(role="admin") def get_telemetries(): try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -32,7 +32,7 @@ def get_telemetries(): @jwt_auth_required(role="admin") def list_telemetries(): try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -50,7 +50,7 @@ def list_telemetries(): @jwt_auth_required(role="admin") def get_telemetries_count(): try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.telemetry_store.count(**kwargs) response = count @@ -64,7 +64,7 @@ def get_telemetries_count(): @jwt_auth_required(role="admin") def telemetries_count(): try: - db = Control.instance().db + db = get_db() kwargs = request.get_json(silent=True) if request.is_json else request.form.to_dict() @@ -80,7 +80,7 @@ def telemetries_count(): @jwt_auth_required(role="admin") def get_telemetry(id: str): try: - db = Control.instance().db + db = get_db() telemetry = db.telemetry_store.get(id) if telemetry is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 @@ -96,7 +96,7 @@ def get_telemetry(id: str): @jwt_auth_required(role="admin") def add_telemetries(): try: - db = Control.instance().db + db = get_db() data = request.get_json(silent=True) if request.is_json else request.form.to_dict() telemetry = TelemetryDTO().patch_with(data) diff --git a/fedn/network/api/v1/validation_routes.py b/fedn/network/api/v1/validation_routes.py index f275dfb52..4d5f7dd52 100644 --- a/fedn/network/api/v1/validation_routes.py +++ b/fedn/network/api/v1/validation_routes.py @@ -2,8 +2,8 @@ from fedn.common.log_config import logger from fedn.network.api.auth import jwt_auth_required +from fedn.network.api.shared import get_db from fedn.network.api.v1.shared 
import api_version, get_post_data_to_kwargs, get_typed_list_headers -from fedn.network.controller.control import Control bp = Blueprint("validation", __name__, url_prefix=f"/api/{api_version}/validations") @@ -126,7 +126,7 @@ def get_validations(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = request.args.to_dict() @@ -224,7 +224,7 @@ def list_validations(): type: string """ try: - db = Control.instance().db + db = get_db() limit, skip, sort_key, sort_order = get_typed_list_headers(request.headers) kwargs = get_post_data_to_kwargs(request) @@ -296,7 +296,7 @@ def get_validations_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = request.args.to_dict() count = db.validation_store.count(**kwargs) response = count @@ -363,7 +363,7 @@ def validations_count(): type: string """ try: - db = Control.instance().db + db = get_db() kwargs = get_post_data_to_kwargs(request) count = db.validation_store.count(**kwargs) response = count @@ -408,7 +408,7 @@ def get_validation(id: str): type: string """ try: - db = Control.instance().db + db = get_db() response = db.validation_store.get(id) if response is None: return jsonify({"message": f"Entity with id: {id} not found"}), 404 diff --git a/fedn/network/clients/connect.py b/fedn/network/clients/connect.py index d6b747b9a..8875fd54d 100644 --- a/fedn/network/clients/connect.py +++ b/fedn/network/clients/connect.py @@ -17,12 +17,7 @@ FEDN_CUSTOM_URL_PREFIX, ) from fedn.common.log_config import logger - -# Constants for HTTP status codes -HTTP_STATUS_OK = 200 -HTTP_STATUS_NO_CONTENT = 204 -HTTP_STATUS_BAD_REQUEST = 400 -HTTP_STATUS_UNAUTHORIZED = 401 +from fedn.network.clients.http_status_codes import HTTP_STATUS_BAD_REQUEST, HTTP_STATUS_NO_CONTENT, HTTP_STATUS_OK, HTTP_STATUS_UNAUTHORIZED # Default timeout for requests REQUEST_TIMEOUT = 10 # seconds diff --git a/fedn/network/clients/client_v2.py 
b/fedn/network/clients/dispatcher_client.py similarity index 90% rename from fedn/network/clients/client_v2.py rename to fedn/network/clients/dispatcher_client.py index 8deb0f62b..1d084919c 100644 --- a/fedn/network/clients/client_v2.py +++ b/fedn/network/clients/dispatcher_client.py @@ -10,7 +10,9 @@ from fedn.common.config import FEDN_CUSTOM_URL_PREFIX from fedn.common.log_config import logger +from fedn.network.clients.dispatcher_package_runtime import DispatcherPackageRuntime from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient, GrpcConnectionOptions +from fedn.network.clients.package_runtime import get_compute_package_dir_path from fedn.network.combiner.modelservice import get_tmp_path from fedn.utils.helpers.helpers import get_helper, save_metadata @@ -48,7 +50,7 @@ def to_json(self) -> Dict[str, Optional[str]]: } -class Client: +class DispatcherClient: """Client for interacting with the FEDn network.""" def __init__( @@ -72,6 +74,9 @@ def __init__( self.package_checksum = package_checksum self.helper_type = helper_type + package_path, archive_path = get_compute_package_dir_path() + self._package_runtime = DispatcherPackageRuntime(package_path, archive_path) + self.fedn_api_url = get_url(self.api_url, self.api_port) self.fedn_client: FednClient = FednClient() self.helper = None @@ -101,14 +106,18 @@ def start(self) -> None: if not result: return if self.client_obj.package == "remote": - result = self.fedn_client.init_remote_compute_package(url=self.fedn_api_url, token=self.token, package_checksum=self.package_checksum) + result = self._package_runtime.load_remote_compute_package(url=self.fedn_api_url, token=self.token) if not result: return else: - result = self.fedn_client.init_local_compute_package() + result = self._package_runtime.load_local_compute_package(os.path.join(os.getcwd(), "client")) if not result: return + result = self._package_runtime.run_startup() + if not result: + return + self.set_helper(combiner_config) result = 
self.fedn_client.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) @@ -164,7 +173,7 @@ def _process_training_request(self, in_model: BytesIO, client_settings: dict) -> outpath = self.helper.get_tmp_path() tic = time.time() - self.fedn_client.dispatcher.run_cmd(f"train {inpath} {outpath}") + self._package_runtime.dispatcher.run_cmd(f"train {inpath} {outpath}") meta["exec_training"] = time.time() - tic with open(outpath, "rb") as fr: @@ -196,7 +205,7 @@ def _process_validation_request(self, in_model: BytesIO) -> Optional[dict]: fh.write(in_model.getbuffer()) outpath = get_tmp_path() - self.fedn_client.dispatcher.run_cmd(f"validate {inpath} {outpath}") + self._package_runtime.dispatcher.run_cmd(f"validate {inpath} {outpath}") with open(outpath, "r") as fh: metrics = json.loads(fh.read()) @@ -219,7 +228,7 @@ def _process_prediction_request(self, in_model: BytesIO) -> Optional[dict]: fh.write(in_model.getbuffer()) outpath = get_tmp_path() - self.fedn_client.dispatcher.run_cmd(f"predict {inpath} {outpath}") + self._package_runtime.dispatcher.run_cmd(f"predict {inpath} {outpath}") with open(outpath, "r") as fh: metrics = json.load(fh) @@ -247,7 +256,7 @@ def _process_forward_request(self, client_id, is_sl_inference) -> Tuple[BytesIO, out_embedding_path = get_tmp_path() tic = time.time() - self.fedn_client.dispatcher.run_cmd(f"forward {client_id} {out_embedding_path} {is_sl_inference}") + self._package_runtime.dispatcher.run_cmd(f"forward {client_id} {out_embedding_path} {is_sl_inference}") meta = {} embeddings = None @@ -292,7 +301,7 @@ def _process_backward_request(self, in_gradients: BytesIO, client_id: str) -> di tic = time.time() - self.fedn_client.dispatcher.run_cmd(f"backward {inpath} {client_id}") + self._package_runtime.dispatcher.run_cmd(f"backward {inpath} {client_id}") meta["exec_training"] = time.time() - tic os.unlink(inpath) diff --git a/fedn/network/clients/dispatcher_package_runtime.py 
b/fedn/network/clients/dispatcher_package_runtime.py new file mode 100644 index 000000000..2dea6ca54 --- /dev/null +++ b/fedn/network/clients/dispatcher_package_runtime.py @@ -0,0 +1,67 @@ +"""Contains the PackageRuntime class, used to download, validate, and unpack compute packages.""" + +from typing import Optional + +from fedn.common.log_config import logger +from fedn.network.clients.package_runtime import PackageRuntime +from fedn.utils.dispatcher import Dispatcher + +# Default timeout for requests +REQUEST_TIMEOUT = 10 # seconds + + +class DispatcherPackageRuntime(PackageRuntime): + """PackageRuntime is used to download, validate, and unpack compute packages. + + :param package_path: Path to compute package. + :type package_path: str + """ + + def __init__(self, package_path: str, archive_path: str) -> None: + """Initialize the PackageRuntime.""" + super().__init__(package_path, archive_path) + + self.dispatcher: Optional[Dispatcher] = None + + def run_startup(self): + if self.config is None: + logger.error("Package runtime is not initialized.") + return False + + result = self.set_dispatcher() + if not result: + return False + + return self.init_dispatcher() + + def set_dispatcher(self) -> bool: + """Dispatch the compute package. + + :param run_path: Path to dispatch the compute package. + :type run_path: str + :return: Dispatcher object or None if an error occurred. + :rtype: Optional[Dispatcher] + """ + try: + self.dispatcher = Dispatcher(self.config, self._target_path) + except Exception as e: + logger.error(f"Error setting dispatcher: {e}") + return False + return True + + def init_dispatcher(self) -> bool: + """Get or set the environment.""" + try: + logger.info("Initiating Dispatcher with entrypoint set to: startup") + activate_cmd = self.dispatcher.get_or_create_python_env() + self.dispatcher.run_cmd("startup") + except KeyError: + logger.info("No startup command found in package. 
Continuing.") + except Exception as e: + logger.error(f"Caught exception: {type(e).__name__}") + return False + + if activate_cmd: + logger.info(f"To activate the virtual environment, run: {activate_cmd}") + + return True diff --git a/fedn/network/clients/fedn_client.py b/fedn/network/clients/fedn_client.py index 298693503..4ef0a6bd6 100644 --- a/fedn/network/clients/fedn_client.py +++ b/fedn/network/clients/fedn_client.py @@ -2,63 +2,34 @@ import enum import json -import os import threading import time import uuid from contextlib import contextmanager from io import BytesIO from typing import Any, Optional, Tuple, Union +from urllib.parse import urljoin import psutil import requests import fedn.network.grpc.fedn_pb2 as fedn -from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CONNECT_API_SECURE, FEDN_PACKAGE_EXTRACT_DIR +from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CONNECT_API_SECURE from fedn.common.log_config import logger -from fedn.network.clients.grpc_handler import GrpcHandler, RetryException -from fedn.network.clients.package_runtime import PackageRuntime -from fedn.utils.dispatcher import Dispatcher - -# Constants for HTTP status codes -HTTP_STATUS_OK = 200 -HTTP_STATUS_NO_CONTENT = 204 -HTTP_STATUS_BAD_REQUEST = 400 -HTTP_STATUS_UNAUTHORIZED = 401 -HTTP_STATUS_NOT_FOUND = 404 -HTTP_STATUS_PACKAGE_MISSING = 203 +from fedn.network.clients.grpc_handler import GrpcConnectionOptions, GrpcHandler, RetryException +from fedn.network.clients.http_status_codes import ( + HTTP_STATUS_BAD_REQUEST, + HTTP_STATUS_NOT_FOUND, + HTTP_STATUS_OK, + HTTP_STATUS_PACKAGE_MISSING, + HTTP_STATUS_UNAUTHORIZED, +) +from fedn.network.clients.logging_context import LoggingContext # Default timeout for requests REQUEST_TIMEOUT = 10 # seconds -class GrpcConnectionOptions: - """Options for configuring the GRPC connection.""" - - def __init__(self, host: str, port: int, status: str = "", fqdn: str = "", package: str = "", ip: str = "", helper_type: str = "") -> None: - 
"""Initialize GrpcConnectionOptions.""" - self.status = status - self.host = host - self.fqdn = fqdn - self.package = package - self.ip = ip - self.port = port - self.helper_type = helper_type - - @classmethod - def from_dict(cls, config: dict) -> "GrpcConnectionOptions": - """Create a GrpcConnectionOptions instance from a JSON string.""" - return cls( - status=config.get("status", ""), - host=config.get("host", ""), - fqdn=config.get("fqdn", ""), - package=config.get("package", ""), - ip=config.get("ip", ""), - port=config.get("port", 0), - helper_type=config.get("helper_type", ""), - ) - - class ConnectToApiResult(enum.Enum): """Enum for representing the result of connecting to the FEDn API.""" @@ -70,43 +41,6 @@ class ConnectToApiResult(enum.Enum): UnknownError = 5 -def get_compute_package_dir_path() -> str: - """Get the directory path for the compute package.""" - if FEDN_PACKAGE_EXTRACT_DIR: - result = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) - else: - dirname = "compute-package-" + time.strftime("%Y%m%d-%H%M%S") - result = os.path.join(os.getcwd(), dirname) - - if not os.path.exists(result): - os.mkdir(result) - - return result - - -class LoggingContext: - """Context for keeping track of the session, model and round IDs during a dispatched call from a request.""" - - def __init__( - self, *, step: int = 0, model_id: str = None, round_id: str = None, session_id: str = None, request: Optional[fedn.TaskRequest] = None - ) -> None: - if request is not None: - if model_id is None: - model_id = request.model_id - if round_id is None: - if request.type == fedn.StatusType.MODEL_UPDATE: - config = json.loads(request.data) - round_id = config["round_id"] - if session_id is None: - session_id = request.session_id - - self.model_id = model_id - self.round_id = round_id - self.session_id = session_id - self.request = request - self.step = step - - class FednClient: """Client for interacting with the FEDn network.""" @@ -117,14 +51,12 @@ def __init__( 
self.train_callback = train_callback self.validate_callback = validate_callback self.predict_callback = predict_callback + self.forward_callback: Optional[callable] = None + self.backward_callback: Optional[callable] = None - path = get_compute_package_dir_path() - self._package_runtime = PackageRuntime(path) - - self.dispatcher: Optional[Dispatcher] = None self.grpc_handler: Optional[GrpcHandler] = None - self._current_context: Optional[LoggingContext] = None + self._current_logging_context: Optional[LoggingContext] = None def set_train_callback(self, callback: callable) -> None: """Set the train callback.""" @@ -146,7 +78,7 @@ def set_backward_callback(self, callback: callable): def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApiResult, Any]: """Connect to the FEDn API.""" - url_endpoint = f"{url}api/v1/clients/add" + url_endpoint = urljoin(url, "api/v1/clients/add") logger.info(f"Connecting to API endpoint: {url_endpoint}") try: @@ -190,55 +122,6 @@ def connect_to_api(self, url: str, token: str, json: dict) -> Tuple[ConnectToApi logger.warning(f"Connect to FEDn Api - Error occurred: {str(e)}") return ConnectToApiResult.UnknownError, str(e) - def download_compute_package(self, url: str, token: str, name: Optional[str] = None) -> bool: - """Download compute package from controller.""" - return self._package_runtime.download_compute_package(url, token, name) - - def set_compute_package_checksum(self, url: str, token: str, name: Optional[str] = None) -> bool: - """Get checksum of compute package from controller.""" - return self._package_runtime.set_checksum(url, token, name) - - def unpack_compute_package(self) -> Tuple[bool, str]: - """Unpack the compute package.""" - result, path = self._package_runtime.unpack_compute_package() - if result: - logger.info(f"Compute package unpacked to: {path}") - else: - logger.error("Error: Could not unpack compute package") - - return result, path - - def validate_compute_package(self, checksum: str) 
-> bool: - """Validate the compute package.""" - return self._package_runtime.validate(checksum) - - def set_dispatcher(self, path: str) -> bool: - """Set the dispatcher.""" - result = self._package_runtime.get_dispatcher(path) - if result: - self.dispatcher = result - return True - - logger.error("Error: Could not set dispatcher") - return False - - def get_or_set_environment(self) -> bool: - """Get or set the environment.""" - try: - logger.info("Initiating Dispatcher with entrypoint set to: startup") - activate_cmd = self.dispatcher._get_or_create_python_env() - self.dispatcher.run_cmd("startup") - except KeyError: - logger.info("No startup command found in package. Continuing.") - except Exception as e: - logger.error(f"Caught exception: {type(e).__name__}") - return False - - if activate_cmd: - logger.info(f"To activate the virtual environment, run: {activate_cmd}") - - return True - def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, token: str) -> bool: """Initialize the GRPC handler.""" try: @@ -258,11 +141,11 @@ def init_grpchandler(self, config: GrpcConnectionOptions, client_name: str, toke logger.error(f"Could not initialize GRPC connection: {e}") return False - def send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0) -> None: + def _send_heartbeats(self, client_name: str, client_id: str, update_frequency: float = 2.0) -> None: """Send heartbeats to the server.""" self.grpc_handler.send_heartbeats(client_name=client_name, client_id=client_id, update_frequency=update_frequency) - def listen_to_task_stream(self, client_name: str, client_id: str) -> None: + def _listen_to_task_stream(self, client_name: str, client_id: str) -> None: """Listen to the task stream.""" self.grpc_handler.listen_to_task_stream(client_name=client_name, client_id=client_id, callback=self._task_stream_callback) @@ -284,12 +167,12 @@ def default_telemetry_loop(self, update_frequency: float = 5.0) -> None: @contextmanager def 
logging_context(self, context: LoggingContext): """Set the logging context.""" - prev_context = self._current_context - self._current_context = context + prev_context = self._current_logging_context + self._current_logging_context = context try: yield finally: - self._current_context = prev_context + self._current_logging_context = prev_context def _task_stream_callback(self, request: fedn.TaskRequest) -> None: """Handle task stream callbacks.""" @@ -311,7 +194,7 @@ def update_local_model(self, request: fedn.TaskRequest) -> None: model_update_id = str(uuid.uuid4()) tic = time.time() - in_model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) + in_model = self.get_model_from_combiner(model_id=model_id, client_id=self.client_id) if in_model is None: logger.error("Could not retrieve model from combiner. Aborting training request.") @@ -335,20 +218,31 @@ def update_local_model(self, request: fedn.TaskRequest) -> None: logger.info(f"Running train callback with model ID: {model_id}") client_settings = json.loads(request.data).get("client_settings", {}) tic = time.time() - out_model, meta = self.train_callback(in_model, client_settings) + try: + out_model, meta = self.train_callback(in_model, client_settings) + except Exception as e: + logger.error(f"Train callback failed with expection: {e}") + return meta["processing_time"] = time.time() - tic tic = time.time() - self.send_model_to_combiner(model=out_model, id=model_update_id) + self.send_model_to_combiner(model=out_model, model_id=model_update_id) meta["upload_model"] = time.time() - tic logger.info("UPLOAD_MODEL: {0}".format(meta["upload_model"])) meta["fetch_model"] = fetch_model_time meta["config"] = request.data - update = self.create_update_message(model_id=model_id, model_update_id=model_update_id, meta=meta, request=request) + update = self.grpc_handler.create_update_message( + sender_name=self.name, + model_id=model_id, + model_update_id=model_update_id, + receiver_name=request.sender.name, 
+ receiver_role=request.sender.role, + meta=meta, + ) - self.send_model_update(update) + self.grpc_handler.send_model_update(update) self.send_status( "Model update completed.", @@ -359,6 +253,12 @@ def update_local_model(self, request: fedn.TaskRequest) -> None: sender_name=self.name, ) + def check_task_abort(self) -> None: + # Check if the current task has been aborted + # To be implemented + """Raises an exception if the current task has been aborted. Does nothing for now.""" + pass + def validate_global_model(self, request: fedn.TaskRequest) -> None: """Validate the global model.""" with self.logging_context(LoggingContext(request=request)): @@ -372,7 +272,7 @@ def validate_global_model(self, request: fedn.TaskRequest) -> None: type=fedn.StatusType.MODEL_VALIDATION, ) - in_model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) + in_model = self.get_model_from_combiner(model_id=model_id, client_id=self.client_id) if in_model is None: logger.error("Could not retrieve model from combiner. 
Aborting validation request.") @@ -383,13 +283,27 @@ def validate_global_model(self, request: fedn.TaskRequest) -> None: return logger.debug(f"Running validate callback with model ID: {model_id}") - metrics = self.validate_callback(in_model) + try: + metrics = self.validate_callback(in_model) + except Exception as e: + logger.error(f"Validation callback failed with expection: {e}") + return if metrics is not None: # Send validation - validation = self.create_validation_message(metrics=metrics, request=request) - result: bool = self.send_model_validation(validation) + validation = self.grpc_handler.create_validation_message( + sender_name=self.name, + sender_client_id=self.client_id, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=request.model_id, + metrics=json.dumps(metrics), + correlation_id=request.correlation_id, + session_id=request.session_id, + ) + + result: bool = self.grpc_handler.send_model_validation(validation) if result: self.send_status( @@ -413,7 +327,7 @@ def predict_global_model(self, request: fedn.TaskRequest) -> None: """Predict using the global model.""" with self.logging_context(LoggingContext(request=request)): model_id = request.model_id - model = self.get_model_from_combiner(id=model_id, client_id=self.client_id) + model = self.get_model_from_combiner(model_id=model_id, client_id=self.client_id) if model is None: logger.error("Could not retrieve model from combiner. 
Aborting prediction request.") @@ -424,11 +338,23 @@ def predict_global_model(self, request: fedn.TaskRequest) -> None: return logger.info(f"Running predict callback with model ID: {model_id}") - prediction = self.predict_callback(model) + try: + prediction = self.predict_callback(model) + except Exception as e: + logger.error(f"Predict callback failed with expection: {e}") + return - prediction_message = self.create_prediction_message(prediction=prediction, request=request) + prediction_message = self.grpc_handler.create_prediction_message( + sender_name=self.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + model_id=request.model_id, + prediction_output=json.dumps(prediction), + correlation_id=request.correlation_id, + session_id=request.session_id, + ) - self.send_model_prediction(prediction_message) + self.grpc_handler.send_model_prediction(prediction_message) def log_metric(self, metrics: dict, step: int = None, commit: bool = True) -> bool: """Log the metrics to the server. @@ -444,7 +370,7 @@ def log_metric(self, metrics: dict, step: int = None, commit: bool = True) -> bo bool: True if the metrics were logged successfully, False otherwise. 
""" - context = self._current_context + context = self._current_logging_context if context is None: logger.error("Missing context for logging metric.") @@ -489,14 +415,21 @@ def forward_embeddings(self, request): meta["processing_time"] = time.time() - tic tic = time.time() - self.send_model_to_combiner(model=out_embeddings, id=embedding_update_id) + self.send_model_to_combiner(model=out_embeddings, model_id=embedding_update_id) meta["upload_model"] = time.time() - tic meta["config"] = request.data - update = self.create_update_message(model_id=model_id, model_update_id=embedding_update_id, meta=meta, request=request) + update = self.grpc_handler.create_update_message( + sender_name=self.name, + model_id=model_id, + model_update_id=embedding_update_id, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + meta=meta, + ) - self.send_model_update(update) + self.grpc_handler.send_model_update(update) self.send_status( "Forward pass completed.", @@ -513,7 +446,7 @@ def backward_gradients(self, request): try: tic = time.time() - in_gradients = self.get_model_from_combiner(id=model_id, client_id=self.client_id) # gets gradients + in_gradients = self.get_model_from_combiner(model_id=model_id, client_id=self.client_id) # gets gradients if in_gradients is None: logger.error("Could not retrieve gradients from combiner. 
Aborting backward request.") @@ -537,7 +470,16 @@ def backward_gradients(self, request): meta["status"] = "success" logger.info("Creating and sending backward completion to combiner.") - completion = self.create_backward_completion_message(gradient_id=model_id, meta=meta, request=request) + + completion = self.grpc_handler.create_backward_completion_message( + sender_name=self.name, + receiver_name=request.sender.name, + receiver_role=request.sender.role, + gradient_id=model_id, + session_id=request.session_id, + meta=meta, + ) + self.grpc_handler.send_backward_completion(completion) self.send_status( @@ -550,17 +492,6 @@ def backward_gradients(self, request): except Exception as e: logger.error(f"Error in backward pass: {str(e)}") - def create_backward_completion_message(self, gradient_id: str, meta: dict, request: fedn.TaskRequest): - """Create a backward completion message.""" - return self.grpc_handler.create_backward_completion_message( - sender_name=self.name, - receiver_name=request.sender.name, - receiver_role=request.sender.role, - gradient_id=gradient_id, - session_id=request.session_id, - meta=meta, - ) - def log_attributes(self, attributes: dict) -> bool: """Log the attributes to the server. 
@@ -603,42 +534,6 @@ def log_telemetry(self, telemetry: dict) -> bool: return self.grpc_handler.send_telemetry(message) - def create_update_message(self, model_id: str, model_update_id: str, meta: dict, request: fedn.TaskRequest) -> fedn.ModelUpdate: - """Create an update message.""" - return self.grpc_handler.create_update_message( - sender_name=self.name, - model_id=model_id, - model_update_id=model_update_id, - receiver_name=request.sender.name, - receiver_role=request.sender.role, - meta=meta, - ) - - def create_validation_message(self, metrics: dict, request: fedn.TaskRequest) -> fedn.ModelValidation: - """Create a validation message.""" - return self.grpc_handler.create_validation_message( - sender_name=self.name, - sender_client_id=self.client_id, - receiver_name=request.sender.name, - receiver_role=request.sender.role, - model_id=request.model_id, - metrics=json.dumps(metrics), - correlation_id=request.correlation_id, - session_id=request.session_id, - ) - - def create_prediction_message(self, prediction: dict, request: fedn.TaskRequest) -> fedn.ModelPrediction: - """Create a prediction message.""" - return self.grpc_handler.create_prediction_message( - sender_name=self.name, - receiver_name=request.sender.name, - receiver_role=request.sender.role, - model_id=request.model_id, - prediction_output=json.dumps(prediction), - correlation_id=request.correlation_id, - session_id=request.session_id, - ) - def set_name(self, name: str) -> None: """Set the client name.""" logger.info(f"Setting client name to: {name}") @@ -652,21 +547,21 @@ def set_client_id(self, client_id: str) -> None: def run(self, with_telemetry=True, with_heartbeat=True) -> None: """Run the client.""" if with_heartbeat: - threading.Thread(target=self.send_heartbeats, args=(self.name, self.client_id), daemon=True).start() + threading.Thread(target=self._send_heartbeats, args=(self.name, self.client_id), daemon=True).start() if with_telemetry: threading.Thread(target=self.default_telemetry_loop, 
daemon=True).start() try: - self.listen_to_task_stream(client_name=self.name, client_id=self.client_id) + self._listen_to_task_stream(client_name=self.name, client_id=self.client_id) except KeyboardInterrupt: logger.info("Client stopped by user.") - def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> BytesIO: + def get_model_from_combiner(self, model_id: str, client_id: str, timeout: int = 20) -> BytesIO: """Get the model from the combiner.""" - return self.grpc_handler.get_model_from_combiner(id=id, client_id=client_id, timeout=timeout) + return self.grpc_handler.get_model_from_combiner(model_id=model_id, client_id=client_id, timeout=timeout) - def send_model_to_combiner(self, model: BytesIO, id: str) -> None: + def send_model_to_combiner(self, model: BytesIO, model_id: str) -> None: """Send the model to the combiner.""" - self.grpc_handler.send_model_to_combiner(model, id) + self.grpc_handler.send_model_to_combiner(model, model_id) def send_status( self, @@ -679,67 +574,3 @@ def send_status( ) -> None: """Send the status.""" self.grpc_handler.send_status(msg, log_level, type, request, session_id, sender_name) - - def send_model_update(self, update: fedn.ModelUpdate) -> bool: - """Send the model update.""" - return self.grpc_handler.send_model_update(update) - - def send_model_validation(self, validation: fedn.ModelValidation) -> bool: - """Send the model validation.""" - return self.grpc_handler.send_model_validation(validation) - - def send_model_prediction(self, prediction: fedn.ModelPrediction) -> bool: - """Send the model prediction.""" - return self.grpc_handler.send_model_prediction(prediction) - - def init_remote_compute_package(self, url: str, token: str, package_checksum: Optional[str] = None) -> bool: - """Initialize the remote compute package.""" - result = self.download_compute_package(url, token) - if not result: - logger.error("Could not download compute package") - return False - result = 
self.set_compute_package_checksum(url, token) - if not result: - logger.error("Could not set checksum") - return False - - if package_checksum: - result = self.validate_compute_package(package_checksum) - if not result: - logger.error("Could not validate compute package") - return False - - result, path = self.unpack_compute_package() - - if not result: - logger.error("Could not unpack compute package") - return False - - logger.info(f"Compute package unpacked to: {path}") - - result = self.set_dispatcher(path) - - if not result: - logger.error("Could not set dispatcher") - return False - - logger.info("Dispatcher set") - - result = self.get_or_set_environment() - - return True - - def init_local_compute_package(self) -> bool: - """Initialize the local compute package.""" - path = os.path.join(os.getcwd(), "client") - result = self.set_dispatcher(path) - - if not result: - logger.error("Could not set dispatcher") - return False - - result = self.get_or_set_environment() - - logger.info("Dispatcher set") - - return True diff --git a/fedn/network/clients/grpc_handler.py b/fedn/network/clients/grpc_handler.py index 250c81cd3..0a23be1a0 100644 --- a/fedn/network/clients/grpc_handler.py +++ b/fedn/network/clients/grpc_handler.py @@ -132,6 +132,8 @@ class GrpcHandler: def __init__(self, host: str, port: int, name: str, token: str, combiner_name: str) -> None: """Initialize the GrpcHandler.""" + os.environ["GRPC_ENABLE_FORK_SUPPORT"] = "false" # Actively disable fork support in GRPC + self.metadata = [ ("client", name), ("grpc-server", combiner_name), @@ -297,7 +299,7 @@ def send_telemetry(self, telemetry: fedn.TelemetryMessage) -> bool: return True @grpc_retry(max_retries=-1, retry_interval=5) - def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> Optional[BytesIO]: + def get_model_from_combiner(self, model_id: str, client_id: str, timeout: int = 20) -> Optional[BytesIO]: """Fetch a model from the assigned combiner. 
Downloads the model update object via a gRPC streaming channel. @@ -313,7 +315,7 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> """ data = BytesIO() time_start = time.time() - request = fedn.ModelRequest(id=id) + request = fedn.ModelRequest(id=model_id) request.sender.client_id = client_id request.sender.role = fedn.CLIENT @@ -323,6 +325,7 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> data.write(part.data) if part.status == fedn.ModelStatus.OK: + data.seek(0) return data if part.status == fedn.ModelStatus.FAILED: @@ -332,10 +335,11 @@ def get_model_from_combiner(self, id: str, client_id: str, timeout: int = 20) -> if time.time() - time_start >= timeout: return None continue + data.seek(0) return data @grpc_retry(max_retries=-1, retry_interval=5) - def send_model_to_combiner(self, model: BytesIO, id: str) -> Optional[BytesIO]: + def send_model_to_combiner(self, model: BytesIO, model_id: str) -> Optional[BytesIO]: """Send a model update to the assigned combiner. Uploads the model updated object via a gRPC streaming channel, Upload. 
@@ -358,7 +362,7 @@ def send_model_to_combiner(self, model: BytesIO, id: str) -> Optional[BytesIO]: bt.seek(0, 0) logger.info("Uploading model to combiner.") - result = self.modelStub.Upload(upload_request_generator(bt, id), metadata=self.metadata) + result = self.modelStub.Upload(upload_request_generator(bt, model_id), metadata=self.metadata) return result def create_update_message( @@ -517,3 +521,30 @@ def _reconnect(self) -> None: self._init_channel(self.host, self.port, self.token) self._init_stubs() logger.debug("GRPC channel reconnected.") + + +class GrpcConnectionOptions: + """Options for configuring the GRPC connection.""" + + def __init__(self, host: str, port: int, status: str = "", fqdn: str = "", package: str = "", ip: str = "", helper_type: str = "") -> None: + """Initialize GrpcConnectionOptions.""" + self.status = status + self.host = host + self.fqdn = fqdn + self.package = package + self.ip = ip + self.port = port + self.helper_type = helper_type + + @classmethod + def from_dict(cls, config: dict) -> "GrpcConnectionOptions": + """Create a GrpcConnectionOptions instance from a JSON string.""" + return cls( + status=config.get("status", ""), + host=config.get("host", ""), + fqdn=config.get("fqdn", ""), + package=config.get("package", ""), + ip=config.get("ip", ""), + port=config.get("port", 0), + helper_type=config.get("helper_type", ""), + ) diff --git a/fedn/network/clients/http_status_codes.py b/fedn/network/clients/http_status_codes.py new file mode 100644 index 000000000..88fb30ea0 --- /dev/null +++ b/fedn/network/clients/http_status_codes.py @@ -0,0 +1,7 @@ +# Constants for HTTP status codes +HTTP_STATUS_OK = 200 +HTTP_STATUS_NO_CONTENT = 204 +HTTP_STATUS_BAD_REQUEST = 400 +HTTP_STATUS_UNAUTHORIZED = 401 +HTTP_STATUS_NOT_FOUND = 404 +HTTP_STATUS_PACKAGE_MISSING = 203 diff --git a/fedn/network/clients/importer_client.py b/fedn/network/clients/importer_client.py new file mode 100644 index 000000000..ae4649be8 --- /dev/null +++ 
b/fedn/network/clients/importer_client.py @@ -0,0 +1,179 @@ +"""Client module for handling client operations in the FEDn network.""" + +import os +import sys +import time +import uuid +from pathlib import Path +from typing import Dict, Optional, Tuple + +from fedn.common.config import FEDN_CUSTOM_URL_PREFIX +from fedn.common.log_config import logger +from fedn.network.clients.fedn_client import ConnectToApiResult, FednClient, GrpcConnectionOptions +from fedn.network.clients.importer_package_runtime import ImporterPackageRuntime, get_compute_package_dir_path +from fedn.utils.process import _IS_UNIX, _join_commands + + +def get_url(api_url: str, api_port: int) -> str: + """Construct the URL for the API.""" + return f"{api_url}:{api_port}/{FEDN_CUSTOM_URL_PREFIX}" if api_port else f"{api_url}/{FEDN_CUSTOM_URL_PREFIX}" + + +class ClientOptions: + """Options for configuring the client.""" + + def __init__(self, name: str, package: str, preferred_combiner: Optional[str] = None, client_id: Optional[str] = None) -> None: + """Initialize ClientOptions with validation.""" + self._validate(name, package) + self.name = name + self.package = package + self.preferred_combiner = preferred_combiner + self.client_id = client_id if client_id else str(uuid.uuid4()) + + def _validate(self, name: str, package: str) -> None: + """Validate the name and package.""" + if not isinstance(name, str) or len(name) == 0: + raise ValueError("Name must be a string") + if not isinstance(package, str) or len(package) == 0 or package not in ["local", "remote"]: + raise ValueError("Package must be either 'local' or 'remote'") + + def to_json(self) -> Dict[str, Optional[str]]: + """Convert ClientOptions to JSON.""" + return { + "name": self.name, + "client_id": self.client_id, + "preferred_combiner": self.preferred_combiner, + "package": self.package, + } + + +class ImporterClient: + """Client for interacting with the FEDn network.""" + + def __init__( + self, + api_url: str, + api_port: int, + 
client_obj: ClientOptions, + combiner_host: Optional[str] = None, + combiner_port: Optional[int] = None, + token: Optional[str] = None, + package_checksum: Optional[str] = None, + helper_type: Optional[str] = None, + startup_path: Optional[str] = None, + manual_env: Optional[bool] = True, + ) -> None: + """Initialize the Client.""" + self.api_url = api_url + self.api_port = api_port + self.combiner_host = combiner_host + self.combiner_port = combiner_port + self.token = token + self.client_obj = client_obj + self.package_checksum = package_checksum + self.helper_type = helper_type + + package_path, archive_path = get_compute_package_dir_path() + self.package_runtime = ImporterPackageRuntime(package_path, archive_path) + + self.manual_env = manual_env + + self.fedn_api_url = get_url(self.api_url, self.api_port) + self.fedn_client: FednClient = FednClient() + self.helper = None + self.startup_path = startup_path + + def _connect_to_api(self) -> Tuple[bool, Optional[dict]]: + """Connect to the API and handle retries.""" + result = None + response = None + + while not result or result == ConnectToApiResult.ComputePackageMissing: + if result == ConnectToApiResult.ComputePackageMissing: + logger.info("Retrying in 3 seconds") + time.sleep(3) + result, response = self.fedn_client.connect_to_api(self.fedn_api_url, self.token, self.client_obj.to_json()) + + if result == ConnectToApiResult.Assigned: + return True, response + + return False, None + + def start(self) -> None: + """Start the client.""" + if self.combiner_host and self.combiner_port: + combiner_config = GrpcConnectionOptions(host=self.combiner_host, port=self.combiner_port) + else: + result, combiner_config = self._connect_to_api() + if not result: + return + if self.client_obj.package == "remote": + result = self.package_runtime.load_remote_compute_package(url=self.fedn_api_url, token=self.token) + if not result: + return + else: + result = 
self.package_runtime.load_local_compute_package(os.path.join(os.getcwd(), "client")) + if not result: + return + + result = self.fedn_client.init_grpchandler(config=combiner_config, client_name=self.client_obj.client_id, token=self.token) + if not result: + return + + self.fedn_client.set_name(self.client_obj.name) + self.fedn_client.set_client_id(self.client_obj.client_id) + + if self.manual_env: + if "python_env" not in self.package_runtime.config: + logger.error("Python environment is not specified in the package configuration. Running without managed environment.") + + logger.info("Initializing managed environment.") + self.package_runtime.init_env_runtime() + + if not self.verify_active_environment(): + self.__restart_client_with_env() + else: + logger.info("Managed environment is active and verified.") + + self.package_runtime.run_startup(self.fedn_client) + + self.fedn_client.run() + + def verify_active_environment(self) -> None: + """Verify the Python environment.""" + if self.package_runtime is None or self.package_runtime.python_env is None: + logger.error("Package runtime or Python environment is not initialized.") + return False + if self.manual_env: + venv_path = os.environ.get("VIRTUAL_ENV") + if not venv_path: + logger.warning("No virtual environment detected.") + return False + logger.info(f"Virtual environment detected at: {venv_path}") + if Path(venv_path) != Path(self.package_runtime.python_env.path): + logger.warning(f"Virtual environment path {venv_path} does not match the expected path {self.package_runtime.python_env.path}.") + return False + return True + else: + logger.info("Managed environment is disabled, skipping verification.") + return True + + def __restart_client_with_env(self) -> None: + """Restart the client.""" + # This method could be replace by letting a process manager handle the restart, i.e. a watchdog or supervisor. 
+ # The watchdog would monitor the client process and restart it if it exits unexpectedly + # and start the client with the correct environment activated. + logger.info("Restarting client with managed environment.") + args = " ".join(sys.argv) + logger.info(f"Current command line arguments: {args}") + args_after_start = args.split("client start", 1)[1].strip() if "client start" in args else "" + + # TODO: Maybe we need to close open connections or clean up resources before restarting. + + activate_env_cmd = self.package_runtime.python_env.get_activate_cmd() + cmd = _join_commands(activate_env_cmd, "python -m fedn client start " + args_after_start) + logger.info(f"Restarting with cmd: {cmd}") + time.sleep(2) + entry_point = "/bin/bash" if _IS_UNIX else "C:\\Windows\\System32\\cmd.exe" + os.execv(entry_point, cmd) # noqa: S606 + # This line will never be reached, as os.execv replaces the current process with a new one. diff --git a/fedn/network/clients/importer_package_runtime.py b/fedn/network/clients/importer_package_runtime.py new file mode 100644 index 000000000..5a9934ab7 --- /dev/null +++ b/fedn/network/clients/importer_package_runtime.py @@ -0,0 +1,130 @@ +"""Contains the PackageRuntime class, used to download, validate, and unpack compute packages.""" + +import os +import sys +from pathlib import Path +from typing import Optional + +from fedn.common.config import FEDN_ARCHIVE_DIR, FEDN_PACKAGE_EXTRACT_DIR +from fedn.common.log_config import logger +from fedn.network.clients.package_runtime import PackageRuntime +from fedn.utils.environment import _PythonEnv + +# Default timeout for requests +REQUEST_TIMEOUT = 10 # seconds + + +def get_compute_package_dir_path() -> str: + """Get the directory path for the compute package.""" + full_package_path = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) + full_archive_path = os.path.join(os.getcwd(), FEDN_ARCHIVE_DIR) + + os.makedirs(full_package_path, exist_ok=True) + os.makedirs(full_archive_path, exist_ok=True) + 
+ return full_package_path, full_archive_path + + +class ImporterPackageRuntime(PackageRuntime): + """ImporterPackageRuntime is used to download, validate, and unpack compute packages. + + :param package_path: Path to compute package. + :type package_path: str + """ + + def __init__(self, package_path: str, archive_path: str) -> None: + """Initialize the PackageRuntime.""" + super().__init__(package_path, archive_path) + self.python_env: Optional[_PythonEnv] = None + + def init_env_runtime(self): + if self.config is None: + logger.error("Package runtime is not loaded.") + return False + try: + python_env_yaml_path = self.config.get("python_env") + if python_env_yaml_path: + logger.info(f"Initializing Python environment from configuration: {python_env_yaml_path}") + python_env_yaml_path = Path(self._target_path).joinpath(python_env_yaml_path) + self.python_env = _PythonEnv.from_yaml(python_env_yaml_path) + self.python_env.remove_fedndependency() + self.python_env.set_base_path(self._target_path) + if not self.python_env.path.exists(): + self.python_env.create_virtualenv(capture_output=True, use_system_site_packages=True) + if not self.python_env.verify_installed_env(): + logger.error(f"Python environment at {self.python_env.path} is not valid.") + raise RuntimeError(f"Invalid Python environment at {self.python_env.path}.") + else: + logger.info("No Python environment specified in the configuration, using the system Python.") + self.python_env = None + except Exception as e: + logger.error(f"Error initializing Python environment from configuration: {e}") + self.python_env = None + + def run_startup(self, fedn_client): + """Run the client startup script.""" + if self.config is None: + logger.error("Package runtime is not initialized.") + return False + + original_sys_path = sys.path.copy() + try: + # Add the package path to sys.path + sys.path.insert(0, self._target_path) + entrypoint = self.config.get("entry_points") + if entrypoint: + startup_py = 
entrypoint.get("startup") + if not startup_py: + logger.info("No startup entrypoint defined in the configuration, using default 'startup.py'.") + startup_py = "startup.py" + + if not Path(self._target_path).joinpath(startup_py).exists(): + logger.error(f"Startup script {startup_py} not found in the package directory.") + raise FileNotFoundError(f"Startup script {startup_py} not found.") + + startup_module = Path(self._target_path).joinpath(startup_py).stem + logger.info(f"Running startup script from: {startup_module}") + startup = __import__(startup_module) + + startup.startup(fedn_client) + except Exception as e: + logger.error(f"Error during client startup: {e}") + return False + finally: + # Restore the original sys.path + sys.path = original_sys_path + + return True + + def run_build(self): + """Run the build script.""" + if self.config is None: + logger.error("Package runtime is not initialized.") + return False + + original_sys_path = sys.path.copy() + try: + sys.path.insert(0, self._target_path) + entrypoint = self.config.get("entry_points") + if entrypoint: + build_py = entrypoint.get("build") + if not build_py: + logger.info("No build entrypoint defined in the configuration, using default 'build.py'.") + build_py = "build.py" + + if not Path(self._target_path).joinpath(build_py).exists(): + logger.error(f"Build script {build_py} not found in the package directory.") + raise FileNotFoundError(f"Build script {build_py} not found.") + + build_module = Path(self._target_path).joinpath(build_py).stem + build = __import__(build_module) + + build.build() + except Exception as e: + logger.error(f"Error during build: {e}") + return False + finally: + # Restore the original sys.path + sys.path = original_sys_path + + return True diff --git a/fedn/network/clients/logging_context.py b/fedn/network/clients/logging_context.py new file mode 100644 index 000000000..48a8b3715 --- /dev/null +++ b/fedn/network/clients/logging_context.py @@ -0,0 +1,27 @@ +import json +from 
typing import Optional + +import fedn.network.grpc.fedn_pb2 as fedn + + +class LoggingContext: + """Context for keeping track of the session, model and round IDs during a dispatched call from a request.""" + + def __init__( + self, *, step: int = 0, model_id: str = None, round_id: str = None, session_id: str = None, request: Optional[fedn.TaskRequest] = None + ) -> None: + if request is not None: + if model_id is None: + model_id = request.model_id + if round_id is None: + if request.type == fedn.StatusType.MODEL_UPDATE: + config = json.loads(request.data) + round_id = config["round_id"] + if session_id is None: + session_id = request.session_id + + self.model_id = model_id + self.round_id = round_id + self.session_id = session_id + self.request = request + self.step = step diff --git a/fedn/network/clients/package_runtime.py b/fedn/network/clients/package_runtime.py index a8d8f96d8..86f546b3d 100644 --- a/fedn/network/clients/package_runtime.py +++ b/fedn/network/clients/package_runtime.py @@ -1,47 +1,46 @@ -"""Contains the PackageRuntime class, used to download, validate, and unpack compute packages.""" - import cgi import os import tarfile -from typing import Optional, Tuple +from typing import Optional import requests -from fedn.common.config import FEDN_AUTH_SCHEME, FEDN_CONNECT_API_SECURE +from fedn.common.config import FEDN_ARCHIVE_DIR, FEDN_AUTH_SCHEME, FEDN_CONNECT_API_SECURE, FEDN_PACKAGE_EXTRACT_DIR from fedn.common.log_config import logger +from fedn.network.clients.http_status_codes import HTTP_STATUS_NO_CONTENT, HTTP_STATUS_OK from fedn.utils.checksum import sha -from fedn.utils.dispatcher import Dispatcher, _read_yaml_file +from fedn.utils.yaml import read_yaml_file -# Constants for HTTP status codes -HTTP_STATUS_OK = 200 -HTTP_STATUS_NO_CONTENT = 204 +REQUEST_TIMEOUT = 10 # seconds # Default timeout for requests -# Default timeout for requests -REQUEST_TIMEOUT = 10 # seconds +def get_compute_package_dir_path() -> str: + """Get the directory path for 
the compute package.""" + full_package_path = os.path.join(os.getcwd(), FEDN_PACKAGE_EXTRACT_DIR) + full_archive_path = os.path.join(os.getcwd(), FEDN_ARCHIVE_DIR) -class PackageRuntime: - """PackageRuntime is used to download, validate, and unpack compute packages. + os.makedirs(full_package_path, exist_ok=True) + os.makedirs(full_archive_path, exist_ok=True) - :param package_path: Path to compute package. - :type package_path: str - """ + return full_package_path, full_archive_path - def __init__(self, package_path: str) -> None: - """Initialize the PackageRuntime.""" - self.dispatch_config = { - "entry_points": { - "predict": {"command": "python3 predict.py"}, - "train": {"command": "python3 train.py"}, - "validate": {"command": "python3 validate.py"}, - } - } +class PackageRuntime: + def __init__(self, package_path: str, archive_path: str) -> None: + """Initialize the PackageRuntime.""" self.pkg_path = package_path + self.tar_path = os.path.join(archive_path, "packages") + os.makedirs(self.tar_path, exist_ok=True) + self.pkg_name: Optional[str] = None - self.checksum: Optional[str] = None + self._checksum: Optional[str] = None + + self._target_path: Optional[str] = None + self._target_name = "fedn.yaml" - def download_compute_package(self, url: str, token: str, name: Optional[str] = None) -> bool: + self.config = None + + def _download_compute_package(self, url: str, token: str, name: Optional[str] = None) -> bool: """Download compute package from controller. :param url: URL of the controller. 
@@ -51,29 +50,31 @@ def download_compute_package(self, url: str, token: str, name: Optional[str] = N :rtype: bool """ try: - path = f"{url}/api/v1/packages/download?name={name}" if name else f"{url}/api/v1/packages/download" - with requests.get(path, - stream=True, - timeout=REQUEST_TIMEOUT, - headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}, - verify=FEDN_CONNECT_API_SECURE) as r: + url = f"{url}/api/v1/packages/download?name={name}" if name else f"{url}/api/v1/packages/download" + with requests.get( + url, stream=True, timeout=REQUEST_TIMEOUT, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}, verify=FEDN_CONNECT_API_SECURE + ) as r: if HTTP_STATUS_OK <= r.status_code < HTTP_STATUS_NO_CONTENT: params = cgi.parse_header(r.headers.get("Content-Disposition", ""))[-1] try: self.pkg_name = params["filename"] + r.raise_for_status() + with open(os.path.join(self.tar_path, self.pkg_name), "wb") as f: + for chunk in r.iter_content(chunk_size=8192): + f.write(chunk) + return True + except KeyError: logger.error("No package returned.") return False - r.raise_for_status() - with open(os.path.join(self.pkg_path, self.pkg_name), "wb") as f: - for chunk in r.iter_content(chunk_size=8192): - f.write(chunk) - - return True - except Exception: + else: + logger.error(f"Failed to download package: {r.status_code} {r.reason}") + return False + except Exception as e: + logger.error(f"Unknown error downloading package: {e}") return False - def set_checksum(self, url: str, token: str, name: Optional[str] = None) -> bool: + def _fetch_package_checksum(self, url: str, token: str) -> bool: """Get checksum of compute package from controller. :param url: URL of the controller. 
@@ -83,37 +84,42 @@ def set_checksum(self, url: str, token: str, name: Optional[str] = None) -> bool :rtype: bool """ try: - path = f"{url}/api/v1/packages/checksum?name={name}" if name else f"{url}/api/v1/packages/checksum" - with requests.get(path, - timeout=REQUEST_TIMEOUT, - headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}, - verify=FEDN_CONNECT_API_SECURE) as r: + path = f"{url}/api/v1/packages/checksum?name={self.pkg_name}" + with requests.get(path, timeout=REQUEST_TIMEOUT, headers={"Authorization": f"{FEDN_AUTH_SCHEME} {token}"}, verify=FEDN_CONNECT_API_SECURE) as r: if HTTP_STATUS_OK <= r.status_code < HTTP_STATUS_NO_CONTENT: data = r.json() try: - self.checksum = data["checksum"] + self._checksum = data["checksum"] except KeyError: logger.error("Could not extract checksum.") - return True except Exception: return False - def validate(self, expected_checksum: str) -> bool: + def validate_compute_package(self, url: str, token: str) -> bool: """Validate the package against the checksum provided by the controller. :param expected_checksum: Checksum provided by the controller. :return: True if checksums match, False otherwise. :rtype: bool """ - file_checksum = str(sha(os.path.join(self.pkg_path, self.pkg_name))) + try: + file_checksum = str(sha(os.path.join(self.tar_path, self.pkg_name))) + except FileNotFoundError: + logger.error(f"Package file {self.pkg_name} not found in {self.tar_path}.") + return False - if self.checksum == expected_checksum == file_checksum: - logger.info(f"Package validated {self.checksum}") + success = self._fetch_package_checksum(url, token) + if not success: + logger.error("Failed to fetch package checksum from controller.") + return False + + if self._checksum == file_checksum: + logger.info(f"Package validated {self._checksum}") return True return False - def unpack_compute_package(self) -> Tuple[bool, str]: + def _unpack_compute_package(self) -> Optional[str]: """Unpack the compute package. 
:return: Tuple containing a boolean indicating success and the path to the unpacked package. @@ -125,36 +131,93 @@ def unpack_compute_package(self) -> Tuple[bool, str]: try: if self.pkg_name.endswith(("tar.gz", ".tgz", "tar.bz2")): - with tarfile.open(os.path.join(self.pkg_path, self.pkg_name), "r:*") as f: + tar_path = os.path.join(self.tar_path, self.pkg_name) + with tarfile.open(tar_path, "r:*") as f: for member in f.getmembers(): f.extract(member, self.pkg_path) logger.info(f"Successfully extracted compute package content in {self.pkg_path}") - logger.info("Deleting temporary package tarball file.") - os.remove(os.path.join(self.pkg_path, self.pkg_name)) - - for root, _, files in os.walk(os.path.join(self.pkg_path, "")): - if "fedn.yaml" in files: - logger.info(f"Found fedn.yaml file in {root}") - return True, root - - logger.error("No fedn.yaml file found in extracted package!") - return False, "" + return self.find_target_path(self.pkg_path) + else: + return None except Exception as e: logger.error(f"Error extracting files: {e}") - os.remove(os.path.join(self.pkg_path, self.pkg_name)) - return False, "" + return None - def get_dispatcher(self, run_path: str) -> Optional[Dispatcher]: - """Dispatch the compute package. + def find_target_path(self, path) -> Optional[str]: + for root, _, files in os.walk(os.path.join(path, "")): + if self._target_name in files: + logger.info(f"Found {self._target_name} file in {root}") + return root + logger.error(f"No {self._target_name} file found in {path}!") + return None + + def load_local_compute_package(self, pkg_path) -> bool: + """Initialize the local compute package.""" + path = self.find_target_path(pkg_path) + if not path: + logger.error(f"Could not find {self._target_name} in the provided package path.") + return False - :param run_path: Path to dispatch the compute package. - :type run_path: str - :return: Dispatcher object or None if an error occurred. 
- :rtype: Optional[Dispatcher] - """ - try: - self.dispatch_config = _read_yaml_file(os.path.join(run_path, "fedn.yaml")) - return Dispatcher(self.dispatch_config, run_path) - except Exception as e: - logger.error(f"Error getting dispatcher: {e}") - return None + logger.info(f"Using compute package at: {path}") + self._target_path = path + if not self._load_fednyaml(): + logger.error("Failed to load fedn.yaml configuration file.") + self._target_path = None + return False + return True + + def load_remote_compute_package(self, url: str, token: str, pkg_name: Optional[str] = None, validate: bool = True) -> bool: + """Initialize the remote compute package.""" + do_download = True + if pkg_name and os.path.exists(os.path.join(self.tar_path, pkg_name)): + # Package already exists + logger.info(f"Compute package {pkg_name} already exists in {self.tar_path}.") + self.pkg_name = pkg_name + if validate: + result = self.validate_compute_package(url, token, self.pkg_name) + if not result: + logger.warning("Already downloaded compute package failed validation.") + else: + logger.info("Already downloaded compute package passed validation.") + do_download = False + else: + logger.info("Skipping validation of already downloaded compute package.") + do_download = False + + if do_download: + result = self._download_compute_package(url, token, pkg_name) + if not result: + logger.error("Could not download compute package") + return False + + if validate: + result = self.validate_compute_package(url, token) + if not result: + logger.error("Could not validate compute package") + return False + + path = self._unpack_compute_package() + + if not path: + logger.error("Could not unpack compute package") + return False + + logger.info(f"Compute package unpacked to: {path}") + self._target_path = path + if not self._load_fednyaml(): + logger.error("Failed to load fedn.yaml configuration file.") + self._target_path = None + return False + return True + + def _load_fednyaml(self): + """Load 
the target configuration file.""" + logger.info(f"Reading {self._target_name} configuration file.") + self.config = read_yaml_file(os.path.join(self._target_path, self._target_name)) + if not self.config: + logger.error(f"Configuration file {os.path.join(self._target_path, self._target_name)} not found or is empty.") + return False + return True + + def run_startup(self, *args, **kwargs): + raise NotImplementedError("The start method should be implemented in subclasses.") diff --git a/fedn/network/clients/task_receiver.py b/fedn/network/clients/task_receiver.py new file mode 100644 index 000000000..634898413 --- /dev/null +++ b/fedn/network/clients/task_receiver.py @@ -0,0 +1,160 @@ +import threading +import time +from typing import TYPE_CHECKING + +import fedn.network.grpc.fedn_pb2 as fedn +from fedn.common.log_config import logger + +if TYPE_CHECKING: + from fedn.network.clients.fedn_client import FednClient # not-floating-import + + +class StoppedException(Exception): + pass + + +class Task: + def __init__(self, request: fedn.TaskRequest): + self.request = request + self.runner_thread = None + self.lock = threading.Lock() + self.status = fedn.TaskStatus.TASK_PENDING + self.interrupted = False + self.interrupted_reason = None + self.response = None + self.correlation_id = request.correlation_id + self.done = False + + +class TaskReceiver: + def __init__(self, client: "FednClient", task_callback: callable, polling_interval: int = 5): + self.client = client + self.task_callback = task_callback + + self.polling_interval = polling_interval + + self.current_task: Task = None + + def start(self): + self._task_manager_thread = threading.Thread( + target=self.run_task_polling, + name="TaskReceiver", + daemon=True, + ) + self._task_manager_thread.start() + self._task_manager_stop_event = threading.Event() + + def check_abort(self): + """Check if the current task has been aborted. 
+ + This function should be called periodically from the task callback to ensure + that the task can be interrupted if needed. + If called from another thread, this function is a no-op. + """ + if self.current_task is not None and self.current_task.runner_thread == threading.current_thread(): + with self.current_task.lock: + if self.current_task.interrupted: + raise StoppedException(self.current_task.interrupted_reason) + + def abort_current_task(self): + if self.current_task is not None: + with self.current_task.lock: + if not self.current_task.interrupted: + self.current_task.interrupted = True + self.current_task.interrupted_reason = "Aborted by client" + logger.info("TaskReceiver: Aborting current task... ") + + def run_task_polling(self): + while True: + try: + tic = time.time() + + report = fedn.ActivityReport() + + report.sender.client_id = self.client.client_id + report.sender.name = self.client.name + report.sender.role = fedn.Role.CLIENT + if self.current_task is None: + report.status = fedn.TaskStatus.TASK_NONE + else: + with self.current_task.lock: + report.status = self.current_task.status + if self.current_task.response: + report.response = self.current_task.response + report.correlation_id = self.current_task.correlation_id + report.done = self.current_task.done + if self.current_task.done: + self.current_task = None + + if report.status == fedn.TaskStatus.TASK_NONE: + logger.debug("TaskReceiver: Reporting: Polling for task") + else: + logger.debug("TaskReceiver: Reporting: Task status %s", fedn.TaskStatus.Name(report.status)) + + task_request: fedn.TaskRequest = self.client.grpc_handler.PollAndReport(report) + + if task_request.correlation_id: + if self.current_task is not None: + if self.current_task.correlation_id == task_request.correlation_id: + # Received update to current task + if task_request.task_status == fedn.TaskStatus.TASK_INTERRUPTED: + if not self.current_task.interrupted: + with self.current_task.lock: + self.current_task.interrupted 
= True + self.current_task.interrupted_reason = "Aborted by server" + logger.info("TaskReceiver: Received interrupt message for task %s.", self.current_task.correlation_id) + elif task_request.task_status == fedn.TaskStatus.TASK_TIMEOUT: + if not self.current_task.interrupted: + with self.current_task.lock: + self.current_task.interrupted = True + self.current_task.interrupted_reason = "Timeout by server" + logger.info("TaskReceiver: Received timeout message for task %s.", self.current_task.correlation_id) + else: + logger.warning( + "TaskReceiver: Received new task %s while processing task %s. Ignoring new task.", + task_request.correlation_id, + self.current_task.correlation_id, + ) + else: + # New task + logger.info("TaskReceiver: Got task %s", task_request.correlation_id) + self.current_task = Task(task_request) + + # Run the task in a separate thread + threading.Thread(target=self._run_task, args=(self.current_task,)).start() + + # Wait for next polling interval + toc = time.time() + if toc - tic < self.polling_interval: + time.sleep(self.polling_interval - (toc - tic)) + except Exception as e: + logger.error("TaskReceiver: Error in task polling: %s", e) + break + self._task_manager_stop_event.set() + + def _run_task(self, task: Task): + with task.lock: + task.runner_thread = threading.current_thread() + task.status = fedn.TaskStatus.TASK_RUNNING + try: + response = self.task_callback(task.request) + with task.lock: + task.response = response + task.status = fedn.TaskStatus.TASK_COMPLETED + except StoppedException as e: + with task.lock: + logger.info("TaskReceiver: Task interrupted: %s", e) + task.status = fedn.TaskStatus.TASK_INTERRUPTED + task.response = {"msg": str(e)} + except Exception as e: + with task.lock: + logger.error("TaskReceiver: Task failed: %s", e) + task.status = fedn.TaskStatus.TASK_FAILED + task.response = {"error": str(e)} + finally: + with task.lock: + task.done = True + + def wait_on_manager_thread(self): + if self._task_manager_thread 
is not None: + self._task_manager_stop_event.wait() diff --git a/fedn/network/combiner/aggregators/aggregatorbase.py b/fedn/network/combiner/aggregators/aggregatorbase.py index 6866c6260..6536e182b 100644 --- a/fedn/network/combiner/aggregators/aggregatorbase.py +++ b/fedn/network/combiner/aggregators/aggregatorbase.py @@ -20,9 +20,11 @@ def __init__(self, update_handler: UpdateHandler): self.update_handler = update_handler @abstractmethod - def combine_models(self, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180, delete_models=True, parameters=None): + def combine_models(self, session_id, nr_expected_models=None, nr_required_models=1, helper=None, timeout=180, delete_models=True, parameters=None): """Routine for combining model updates. Implemented in subclass. + :param session_id: The id of the session. + :type session_id: str :param nr_expected_models: Number of expected models. If None, wait for all models. :type nr_expected_models: int :param nr_required_models: Number of required models to combine. diff --git a/fedn/network/combiner/aggregators/fedavg.py b/fedn/network/combiner/aggregators/fedavg.py index 82a64b5fa..0ca8fdb1f 100644 --- a/fedn/network/combiner/aggregators/fedavg.py +++ b/fedn/network/combiner/aggregators/fedavg.py @@ -1,8 +1,10 @@ +import queue import time import traceback from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase +from fedn.utils.model import FednModel class Aggregator(AggregatorBase): @@ -19,7 +21,7 @@ def __init__(self, update_handler): self.name = "fedavg" - def combine_models(self, helper=None, delete_models=True, parameters=None): + def combine_models(self, session_id, helper=None, delete_models=True, parameters=None): """Aggregate all model updates in the queue by computing an incremental weighted average of model parameters. 
@@ -44,11 +46,13 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): logger.info("AGGREGATOR({}): Aggregating model updates... ".format(self.name)) - while not self.update_handler.model_updates.empty(): + while True: try: - logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name)) - model_update = self.update_handler.next_model_update() - + try: + model_update = self.update_handler.next_model_update(session_id) + logger.info("AGGREGATOR({}): Getting next model update from queue.".format(self.name)) + except queue.Empty: + break # Load model parameters and metadata logger.info("AGGREGATOR({}): Loading model metadata {}.".format(self.name, model_update.model_update_id)) @@ -78,6 +82,6 @@ def combine_models(self, helper=None, delete_models=True, parameters=None): logger.error(tb) data["nr_aggregated_models"] = nr_aggregated_models - + fedn_model = FednModel.from_model_params(model) logger.info("AGGREGATOR({}): Aggregation completed, aggregated {} models.".format(self.name, nr_aggregated_models)) - return model, data + return fedn_model, data diff --git a/fedn/network/combiner/aggregators/fedopt.py b/fedn/network/combiner/aggregators/fedopt.py index aba0a552a..4a3ff074d 100644 --- a/fedn/network/combiner/aggregators/fedopt.py +++ b/fedn/network/combiner/aggregators/fedopt.py @@ -1,4 +1,5 @@ import math +import queue import time import traceback from typing import Any, Dict, Optional, Tuple @@ -7,6 +8,7 @@ from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import AggregatorBase from fedn.utils.helpers.helperbase import HelperBase +from fedn.utils.model import FednModel from fedn.utils.parameters import Parameters @@ -38,7 +40,7 @@ def __init__(self, update_handler): self.m = None def combine_models( - self, helper: Optional[HelperBase] = None, delete_models: bool = True, parameters: Optional[Parameters] = None + self, session_id, helper: Optional[HelperBase] = None, 
delete_models: bool = True, parameters: Optional[Parameters] = None ) -> Tuple[Optional[Any], Dict[str, float]]: """Compute pseudo gradients using model updates in the queue. @@ -71,10 +73,13 @@ def combine_models( pseudo_gradient, model_old = None, None nr_aggregated_models, total_examples = 0, 0 - while not self.update_handler.model_updates.empty(): + while True: try: - logger.info(f"Aggregator {self.name}: Fetching next model update.") - model_update = self.update_handler.next_model_update() + try: + model_update = self.update_handler.next_model_update(session_id) + logger.info(f"Aggregator {self.name}: Fetching next model update.") + except queue.Empty: + break tic = time.time() model_next, metadata = self.update_handler.load_model_update(model_update, helper) @@ -87,7 +92,7 @@ def combine_models( tic = time.time() if nr_aggregated_models == 0: - model_old = self.update_handler.load_model(helper, model_update.model_id) + model_old = self.update_handler.load_model_params(helper, model_update.model_id) pseudo_gradient = helper.subtract(model_next, model_old) else: pseudo_gradient_next = helper.subtract(model_next, model_old) @@ -118,7 +123,10 @@ def combine_models( return None, data logger.info(f"Aggregator {self.name} completed. 
Aggregated {nr_aggregated_models} models.") - return model, data + if model is None: + return None, data + fedn_model = FednModel.from_model_params(model, helper) + return fedn_model, data def _validate_and_merge_parameters(self, parameters: Optional[Parameters], default_parameters: Dict[str, Any]) -> Dict[str, Any]: """Validate and merge default parameters.""" diff --git a/fedn/network/combiner/aggregators/splitlearningagg.py b/fedn/network/combiner/aggregators/splitlearningagg.py index 3e2d67ce9..bce400218 100644 --- a/fedn/network/combiner/aggregators/splitlearningagg.py +++ b/fedn/network/combiner/aggregators/splitlearningagg.py @@ -1,4 +1,5 @@ import os +import queue import traceback import torch @@ -53,7 +54,7 @@ def __init__(self, update_handler): self.model = None self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - def combine_models(self, helper=None, delete_models=True, is_sl_inference=False): + def combine_models(self, session_id, helper=None, delete_models=True, is_sl_inference=False): """Concatenates client embeddings in the queue by aggregating them. After all embeddings are received, the embeddings need to be sorted @@ -77,10 +78,13 @@ def combine_models(self, helper=None, delete_models=True, is_sl_inference=False) logger.info("AGGREGATOR({}): Aggregating client embeddings... 
".format(self.name)) - while not self.update_handler.model_updates.empty(): + while True: try: - logger.info("AGGREGATOR({}): Getting next embedding from queue.".format(self.name)) - new_embedding = self.update_handler.next_model_update() # returns in format {client_id: embedding} + try: + new_embedding = self.update_handler.next_model_update(session_id) # returns in format {client_id: embedding} + logger.info("AGGREGATOR({}): Getting next embedding from queue.".format(self.name)) + except queue.Empty: + break # Load model parameters and metadata logger.info("AGGREGATOR({}): Loading embedding metadata.".format(self.name)) diff --git a/fedn/network/combiner/aggregators/tests/test_fedavg.py b/fedn/network/combiner/aggregators/tests/test_fedavg.py index 583a7f51a..1c1863b28 100644 --- a/fedn/network/combiner/aggregators/tests/test_fedavg.py +++ b/fedn/network/combiner/aggregators/tests/test_fedavg.py @@ -26,7 +26,7 @@ def test_fedavg_combine_models(self, *args, **kwargs): data['time_model_aggregation'] = 0.0 data['nr_aggregated_models'] = 0 - self.assertEqual(aggregator.combine_models(), (None, data)) + self.assertEqual(aggregator.combine_models(""), (None, data)) if __name__ == '__main__': diff --git a/fedn/network/combiner/clientmanager.py b/fedn/network/combiner/clientmanager.py new file mode 100644 index 000000000..871e7d42e --- /dev/null +++ b/fedn/network/combiner/clientmanager.py @@ -0,0 +1,170 @@ +import queue +from datetime import datetime +from typing import TYPE_CHECKING + +import fedn.network.grpc.fedn_pb2 as fedn +from fedn.common.log_config import logger + +# This if is needed to avoid circular imports but is crucial for type hints. 
+if TYPE_CHECKING: + from fedn.network.combiner.combiner import Combiner # not-floating-import + + +class ClientManager: + def __init__(self, combiner: "Combiner"): + self.combiner = combiner + + self.client_interfaces: "dict[str,ClientInterface]" = {} + + def _init_client(self, client_id: str): + if client_id not in self.client_interfaces: + self.client_interfaces[client_id] = ClientInterface(client_id) + + def update_client(self, client_id: str): + self._init_client(client_id) + client = self.client_interfaces[client_id] + client.last_seen = datetime.now() + + def get_clients(self) -> list["ClientInterface"]: + return list(self.client_interfaces.values()) + + def get_client(self, client_id: str) -> "ClientInterface": + self._init_client(client_id) + return self.client_interfaces[client_id] + + def add_tasks(self, requests: list[fedn.TaskRequest]) -> list[str]: + updated_clients = set() + for request in requests: + try: + self._init_client(request.receiver.client_id) + self.client_interfaces[request.receiver.client_id].task_queue.put(request) + except Exception as e: + logger.error("ClientManager: add_tasks: Error adding task %s to client %s: %s", request.correlation_id, request.receiver.client_id, str(e)) + continue + updated_clients.add(request.receiver.client_id) + return list(updated_clients) + + def cancel_tasks(self, correlation_ids: list[str]): + if not correlation_ids: + return + for client in self.client_interfaces.values(): + if client.current_task and client.current_task.correlation_id in correlation_ids: + logger.debug("ClientManager: cancel_tasks: Cancelling task %s for client %s", client.current_task.correlation_id, client.client_id) + client.current_task.abort_requested = True + # Remove from task queue + new_queue = queue.Queue() + while not client.task_queue.empty(): + task = client.task_queue.get() + if task.correlation_id not in correlation_ids: + new_queue.put(task) + else: + logger.debug("ClientManager: cancel_tasks: Removing task %s from 
queue for client %s", task.correlation_id, client.client_id) + client.task_queue = new_queue + + def timeout_tasks(self, correlation_ids: list[str]): + if not correlation_ids: + return + for client in self.client_interfaces.values(): + if client.current_task and client.current_task.correlation_id in correlation_ids: + logger.debug("ClientManager: timeout_tasks: Timing out task %s for client %s", client.current_task.correlation_id, client.client_id) + client.current_task.timeout = True + # Remove from task queue + new_queue = queue.Queue() + while not client.task_queue.empty(): + task = client.task_queue.get() + if task.correlation_id not in correlation_ids: + new_queue.put(task) + else: + logger.debug("ClientManager: timeout_tasks: Removing task %s from queue for client %s", task.correlation_id, client.client_id) + client.task_queue = new_queue + + def PollAndReport(self, report: fedn.ActivityReport) -> fedn.TaskRequest: + if report.done: + self._task_finished(report) + elif report.correlation_id: + request = self._task_update(report) + return request + + if report.status == fedn.TaskStatus.TASK_NONE or report.done: + try: + return self._poll_task(report.sender.client_id) + except queue.Empty: + pass + return fedn.TaskRequest() + + def _poll_task(self, client_id: str) -> fedn.TaskRequest: + client = self.client_interfaces[client_id] + + request: fedn.TaskRequest = self.pop_task(client_id) + if request is not None: + if client.current_task: + logger.warning( + "ClientManager: _poll_task: Client %s already has a task %s assigned. 
Overwriting with new task %s", + client_id, + client.current_task.correlation_id, + request.correlation_id, + ) + client.current_task = ClientTask(client_id, request.correlation_id) + logger.debug("ClientManager: PollAndReport: Sending %s to %s", request.correlation_id, client_id) + else: + request = fedn.TaskRequest() + return request + + def pop_task(self, client_id: str) -> fedn.TaskRequest: + client = self.client_interfaces[client_id] + try: + request: fedn.TaskRequest = client.task_queue.get_nowait() + return request + except queue.Empty: + return None + + def _task_finished(self, report: fedn.ActivityReport): + if report.sender.client_id in self.client_interfaces: + client = self.client_interfaces[report.sender.client_id] + if client.current_task and client.current_task.correlation_id == report.correlation_id: + client.current_task = None + logger.debug(f"ClientManager: _task_finished: {report.sender.client_id} finished task {report.correlation_id}") + else: + logger.warning( + "ClientManager: _task_finished: Received finished report for unknown task %s from %s", report.correlation_id, report.sender.client_id + ) + else: + logger.warning("ClientManager: _task_finished: Received finished report from unknown client %s", report.sender.client_id) + + def _task_update(self, report: fedn.ActivityReport) -> fedn.TaskRequest: + request = fedn.TaskRequest() + request.correlation_id = report.correlation_id + if report.sender.client_id in self.client_interfaces: + client = self.client_interfaces[report.sender.client_id] + if client.current_task and client.current_task.correlation_id == report.correlation_id: + logger.debug("ClientManager: PollAndReport: %s processing task %s", report.sender.client_id, report.correlation_id) + client.current_task.status = report.status + if client.current_task.abort_requested: + request.task_status = fedn.TaskStatus.TASK_INTERRUPTED + if client.current_task.timeout: + request.task_status = fedn.TaskStatus.TASK_TIMEOUT + else: + 
logger.warning( + "ClientManager: _task_update: Received status update for unknown task %s from %s", report.correlation_id, report.sender.client_id + ) + else: + logger.warning("ClientManager: _task_update: Received status update from unknown client %s", report.sender.client_id) + return request + + +class ClientInterface: + def __init__(self, client_id: str): + self.client_id = client_id + self.status = "offline" + self.last_seen = datetime.now() + self.current_task: ClientTask = None + self.task_queue: "queue.Queue[fedn.TaskRequest]" = queue.Queue() + + +class ClientTask: + def __init__(self, client_id: str, correlation_id: str): + self.client_id = client_id + self.correlation_id = correlation_id + self.status = fedn.TaskStatus.TASK_PENDING + self.abort_requested = False + self.timeout = False diff --git a/fedn/network/combiner/combiner.py b/fedn/network/combiner/combiner.py index 4ba56bf0d..90b1f25bc 100644 --- a/fedn/network/combiner/combiner.py +++ b/fedn/network/combiner/combiner.py @@ -7,7 +7,7 @@ import uuid from datetime import datetime, timedelta from enum import Enum -from typing import TypedDict +from typing import List, Tuple, TypedDict from google.protobuf.json_format import MessageToDict @@ -15,6 +15,7 @@ import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.certificate.certificate import Certificate from fedn.common.log_config import logger, set_log_level_from_string, set_log_stream +from fedn.network.combiner.clientmanager import ClientInterface, ClientManager from fedn.network.combiner.modelservice import ModelService from fedn.network.combiner.roundhandler import RoundConfig, RoundHandler from fedn.network.grpc.server import Server, ServerConfig @@ -32,6 +33,8 @@ VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" +OFFLINE_CLIENT_TIMEOUT = 30 # seconds + class Role(Enum): """Enum for combiner roles.""" @@ -154,7 +157,8 @@ def __init__(self, config, repository: Repository, db: DatabaseConnection): grpc_server_config = ServerConfig(port=config["port"], 
secure=False) # Set up model service - modelservice = ModelService() + modelservice = ModelService(repository) + self.client_manager = ClientManager(self) # Create gRPC server self.server = Server(self, modelservice, grpc_server_config) @@ -171,25 +175,6 @@ def __init__(self, config, repository: Repository, db: DatabaseConnection): # Start the gRPC server self.server.start() - @classmethod - def create_instance(cls, config: CombinerConfig, repository: Repository, db: DatabaseConnection): - """Create a new singleton instance of the combiner. - - :param config: configuration for the combiner - :type config: dict - :return: the instance of the combiner - :rtype: :class:`fedn.network.combiner.server.Combiner` - """ - cls._instance = cls(config, repository, db) - return cls._instance - - @classmethod - def instance(cls): - """Get the singleton instance of the combiner.""" - if cls._instance is None: - raise Exception("Combiner instance not created yet.") - return cls._instance - def __whoami(self, client, instance): """Set the client id and role in a proto message. @@ -204,21 +189,16 @@ def __whoami(self, client, instance): client.role = role_to_proto_role(instance.role) return client - def request_model_update(self, session_id, model_id, config, clients=[]): - """Ask clients to update the current global model. - - :param config: the model configuration to send to clients - :type config: dict - :param clients: the clients to send the request to - :type clients: list + def send_requests(self, requests: List[fedn.TaskRequest]) -> List[str]: + """Send requests to clients. 
+ :param requests: the requests to send + :type requests: list + :param queue_name: the name of the queue to send the requests to + :type queue_name: str """ - clients = self._send_request_type(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients) - - if len(clients) < 20: - logger.info("Sent model update request for model {} to clients {}".format(model_id, clients)) - else: - logger.info("Sent model update request for model {} to {} clients".format(model_id, len(clients))) + clients = self.client_manager.add_tasks(requests) + return clients def request_model_validation(self, session_id, model_id, clients=[]): """Ask clients to validate the current global model. @@ -231,7 +211,8 @@ def request_model_validation(self, session_id, model_id, clients=[]): :type clients: list """ - clients = self._send_request_type(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients) + requests = self.create_requests(fedn.StatusType.MODEL_VALIDATION, session_id, model_id, clients) + self.send_requests(requests) if len(clients) < 20: logger.info("Sent model validation request for model {} to clients {}".format(model_id, clients)) @@ -249,29 +230,14 @@ def request_model_prediction(self, prediction_id: str, model_id: str, clients: l :type clients: list """ - clients = self._send_request_type(fedn.StatusType.MODEL_PREDICTION, prediction_id, model_id, {}, clients) + requests = self.create_requests(fedn.StatusType.MODEL_PREDICTION, prediction_id, model_id, {}, clients) + self.send_requests(requests) if len(clients) < 20: logger.info("Sent model prediction request for model {} to clients {}".format(model_id, clients)) else: logger.info("Sent model prediction request for model {} to {} clients".format(model_id, len(clients))) - def request_forward_pass(self, session_id: str, model_id: str, config: dict, clients=[]) -> None: - """Ask clients to perform forward pass. 
- - :param config: the model configuration to send to clients - :type config: dict - :param clients: the clients to send the request to - :type clients: list - - """ - clients = self._send_request_type(fedn.StatusType.FORWARD, session_id, model_id, config, clients) - - if len(clients) < 20: - logger.info("Sent forward request to clients {}".format(clients)) - else: - logger.info("Sent forward request to {} clients".format(len(clients))) - def request_backward_pass(self, session_id: str, gradient_id: str, config: dict, clients=[]) -> None: """Ask clients to perform backward pass. @@ -280,19 +246,20 @@ def request_backward_pass(self, session_id: str, gradient_id: str, config: dict, :param clients: the clients to send the request to :type clients: list """ - clients = self._send_request_type(fedn.StatusType.BACKWARD, session_id, gradient_id, config, clients) + requests = self.create_requests(fedn.StatusType.BACKWARD, session_id, gradient_id, config, clients) + self.send_requests(requests) if len(clients) < 20: logger.info("Sent backward request for gradients {} to clients {}".format(gradient_id, clients)) else: logger.info("Sent backward request for gradients {} to {} clients".format(gradient_id, len(clients))) - def _send_request_type(self, request_type, session_id, model_id=None, config=None, clients=[]): - """Send a request of a specific type to clients. + def create_requests(self, request_type, session_id, model_id=None, config=None, clients=[]) -> List[fedn.TaskRequest]: + """Create requests of a specific type to clients. :param request_type: the type of request :type request_type: :class:`fedn.network.grpc.fedn_pb2.StatusType` - :param session_id: the session id to send in the request. Obs that for prediction, this is the prediction id. + :param session_id: the session id to send in the request. 
Obs that for prediction, this is the prediction id.q :type session_id: str :param model_id: the model id to send in the request :type model_id: str @@ -312,6 +279,7 @@ def _send_request_type(self, request_type, session_id, model_id=None, config=Non # TODO: add prediction clients type clients = self.get_active_validators() + requests: List[Tuple[str, fedn.TaskRequest]] = [] for client in clients: request = fedn.TaskRequest() request.model_id = model_id @@ -319,11 +287,12 @@ def _send_request_type(self, request_type, session_id, model_id=None, config=Non request.timestamp = str(datetime.now()) request.type = request_type request.session_id = session_id - request.sender.name = self.id request.sender.role = fedn.COMBINER request.receiver.client_id = client request.receiver.role = fedn.CLIENT + request.task_type = fedn.StatusType.Name(request_type) + request.task_status = fedn.TaskStatus.TASK_NEW # Set the request data, not used in validation if request_type == fedn.StatusType.MODEL_PREDICTION: @@ -332,8 +301,9 @@ def _send_request_type(self, request_type, session_id, model_id=None, config=Non request.data = json.dumps({"presigned_url": presigned_url}) elif request_type in [fedn.StatusType.MODEL_UPDATE, fedn.StatusType.FORWARD, fedn.StatusType.BACKWARD]: request.data = json.dumps(config) - self._put_request_to_client_queue(request, fedn.Queue.TASK_QUEUE) - return clients + request.round_id = config.get("round_id", None) + requests.append(request) + return requests def get_active_trainers(self): """Get a list of active trainers. 
@@ -341,7 +311,7 @@ def get_active_trainers(self): :return: the list of active trainers :rtype: list """ - trainers = self._list_active_clients(fedn.Queue.TASK_QUEUE) + trainers = self._list_active_clients() return trainers def get_active_validators(self): @@ -350,7 +320,7 @@ def get_active_validators(self): :return: the list of active validators :rtype: list """ - validators = self._list_active_clients(fedn.Queue.TASK_QUEUE) + validators = self._list_active_clients() return validators def nr_active_trainers(self): @@ -363,60 +333,7 @@ def nr_active_trainers(self): #################################################################################################################### - def __join_client(self, client): - """Add a client to the list of active clients. - - :param client: the client to add - :type client: :class:`fedn.network.grpc.fedn_pb2.Client` - """ - if client.client_id not in self.clients.keys(): - # The status is set to offline by default, and will be updated once _list_active_clients is called. - self.clients[client.client_id] = {"last_seen": datetime.now(), "status": "offline"} - - def _subscribe_client_to_queue(self, client, queue_name): - """Subscribe a client to the queue. - - :param client: the client to subscribe - :type client: :class:`fedn.network.grpc.fedn_pb2.Client` - :param queue_name: the name of the queue to subscribe to - :type queue_name: str - """ - self.__join_client(client) - if queue_name not in self.clients[client.client_id].keys(): - self.clients[client.client_id][queue_name] = queue.Queue() - - def __get_queue(self, client, queue_name): - """Get the queue for a client. 
- - :param client: the client to get the queue for - :type client: :class:`fedn.network.grpc.fedn_pb2.Client` - :param queue_name: the name of the queue to get - :type queue_name: str - :return: the queue - :rtype: :class:`queue.Queue` - - :raises KeyError: if the queue does not exist - """ - try: - return self.clients[client.client_id][queue_name] - except KeyError: - raise - - def _list_subscribed_clients(self, queue_name): - """List all clients subscribed to a queue. - - :param queue_name: the name of the queue - :type queue_name: str - :return: a list of client names - :rtype: list - """ - subscribed_clients = [] - for name, client in self.clients.items(): - if queue_name in client.keys(): - subscribed_clients.append(name) - return subscribed_clients - - def _list_active_clients(self, channel): + def _list_active_clients(self): """List all clients that have sent a status message in the last 10 seconds. :param channel: the name of the channel @@ -424,61 +341,34 @@ def _list_active_clients(self, channel): :return: a list of client names :rtype: list """ - # Temporary dict to store client status - clients = { - "active_clients": [], - "update_active_clients": [], - "update_offline_clients": [], - } - for client in self._list_subscribed_clients(channel): - status = self.clients[client]["status"] + clients_to_update: List[ClientInterface] = [] + active_clients: List[ClientInterface] = [] + for client in self.client_manager.get_clients(): now = datetime.now() - then = self.clients[client]["last_seen"] - if (now - then) < timedelta(seconds=10): - clients["active_clients"].append(client) + if (now - client.last_seen) < timedelta(seconds=OFFLINE_CLIENT_TIMEOUT): + active_clients.append(client) # If client has changed status, update client queue - if status != "online": - self.clients[client]["status"] = "online" - clients["update_active_clients"].append(client) - elif status != "offline": - self.clients[client]["status"] = "offline" - 
clients["update_offline_clients"].append(client) + if client.status != "online": + client.status = "online" + clients_to_update.append(client) + elif client.status != "offline": + client.status = "offline" + clients_to_update.append(client) # Update statestore with client status - if len(clients["update_active_clients"]) > 0: - for client in clients["update_active_clients"]: - client_to_update = self.db.client_store.get(client) - client_to_update.status = "online" - self.db.client_store.update(client_to_update) - if len(clients["update_offline_clients"]) > 0: - for client in clients["update_offline_clients"]: - client_to_update = self.db.client_store.get(client) - client_to_update.status = "offline" - self.db.client_store.update(client_to_update) + for client in clients_to_update: + client_to_update = self.db.client_store.get(client.client_id) + client_to_update.last_seen = client.last_seen + client_to_update.status = client.status + self.db.client_store.update(client_to_update) - return clients["active_clients"] + return [client.client_id for client in active_clients] def _deamon_thread_client_status(self, timeout=5): """Deamon thread that checks for inactive clients and updates statestore.""" while True: time.sleep(timeout) # TODO: Also update validation clients - self._list_active_clients(fedn.Queue.TASK_QUEUE) - - def _put_request_to_client_queue(self, request, queue_name): - """Get a client specific queue and add a request to it. - The client is identified by the request.receiver. 
- - :param request: the request to send - :type request: :class:`fedn.network.grpc.fedn_pb2.Request` - :param queue_name: the name of the queue to send the request to - :type queue_name: str - """ - try: - q = self.__get_queue(request.receiver, queue_name) - q.put(request) - except Exception as e: - logger.error("Failed to put request to client queue {} for client {}: {}".format(queue_name, request.receiver.name, str(e))) - raise + self._list_active_clients() def _send_status(self, status): """Report a status to backend db. @@ -490,17 +380,13 @@ def _send_status(self, status): status = StatusDTO().populate_with(data) self.db.status_store.add(status) - def _flush_model_update_queue(self): + def _flush_model_update_queue(self, session_id: str): """Clear the model update queue (aggregator). :return: True if successful, else False """ - q = self.round_handler.aggregator.model_updates try: - with q.mutex: - q.queue.clear() - q.all_tasks_done.notify_all() - q.unfinished_tasks = 0 + self.round_handler.update_handler.flush_session(session_id) return True except Exception as e: logger.error("Failed to flush model update queue: %s", str(e)) @@ -509,34 +395,46 @@ def _flush_model_update_queue(self): ##################################################################################################################### # Controller Service + def SendCommand(self, request: fedn.CommandRequest, context): + """Send a command to the combiner. 
- def Start(self, control: fedn.ControlRequest, context): - """Start a round of federated learning" - - :param control: the control request - :type control: :class:`fedn.network.grpc.fedn_pb2.ControlRequest` + :param request: the command request + :type request: :class:`fedn.network.grpc.fedn_pb2.CommandRequest` :param context: the context (unused) :type context: :class:`grpc._server._Context` - :return: the control response - :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse` - """ - logger.info("grpc.Combiner.Start: Starting round") - - config = RoundConfig() - for parameter in control.parameter: - config.update({parameter.key: parameter.value}) - - logger.debug("grpc.Combiner.Start: Round config {}".format(config)) - - job_id = self.round_handler.push_round_config(config) - logger.info("grcp.Combiner.Start: Pushed round config (job_id): {}".format(job_id)) - - response = fedn.ControlResponse() - p = response.parameter.add() - p.key = "job_id" - p.value = job_id - - return response + :return: the response + :rtype: :class:`fedn.network.grpc.fedn_pb2.CommandResponse` + """ + if request.command == fedn.Command.START: + logger.info("grpc.Combiner.SendCommand: Starting round") + parameters = json.loads(request.parameters) if request.parameters else {} + + logger.info("grpc.Combiner.SendCommand: Received parameters: {}".format(parameters)) + + parameters["_job_id"] = request.correlation_id or str(uuid.uuid4()) + config = RoundConfig() + config.update(parameters) + + logger.info("grpc.Combiner.SendCommand: Round config {}".format(config)) + self.round_handler.push_round_config(config) + + response = fedn.ControlResponse() + p = response.parameter.add() + p.key = "job_id" + p.value = config["_job_id"] + return response + elif request.command == fedn.Command.STOP: + logger.info("grpc.Combiner.SendCommand: Stopping current round") + self.round_handler.flow_controller.stop_event.set() + response = fedn.ControlResponse() + response.message = "Success" + return 
response + elif request.command == fedn.Command.CONTINUE: + logger.info("grpc.Combiner.SendCommand: Continuing current round") + self.round_handler.flow_controller.continue_event.set() + response = fedn.ControlResponse() + response.message = "Success" + return response def SetAggregator(self, control: fedn.ControlRequest, context): """Set the active aggregator. @@ -592,7 +490,17 @@ def FlushAggregationQueue(self, control: fedn.ControlRequest, context): :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse` """ logger.debug("grpc.Combiner.FlushAggregationQueue: Called") - status = self._flush_model_update_queue() + session_id = None + for parameter in control.parameter: + if parameter.key == "session_id": + session_id = parameter.value + if session_id is None: + logger.error("grpc.Combiner.FlushAggregationQueue: session_id not provided") + response = fedn.ControlResponse() + response.message = "Failed: session_id not provided" + return response + + status = self._flush_model_update_queue(session_id) response = fedn.ControlResponse() if status: @@ -650,7 +558,7 @@ def ListActiveClients(self, request: fedn.ListClientsRequest, context): :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientList` """ clients = fedn.ClientList() - active_clients = self._list_active_clients(request.channel) + active_clients = self._list_active_clients() nr_active_clients = len(active_clients) if nr_active_clients < 20: logger.info("grpc.Combiner.ListActiveClients: Active clients: {}".format(active_clients)) @@ -673,7 +581,7 @@ def AcceptingClients(self, request: fedn.ConnectionRequest, context): :rtype: :class:`fedn.network.grpc.fedn_pb2.ConnectionResponse` """ response = fedn.ConnectionResponse() - active_clients = self._list_active_clients(fedn.Queue.TASK_QUEUE) + active_clients = self._list_active_clients() try: requested = int(self.max_clients) @@ -705,8 +613,7 @@ def SendHeartbeat(self, heartbeat: fedn.Heartbeat, context): logger.debug("GRPC: Received heartbeat from 
{}".format(heartbeat.sender.name)) # Update the clients dict with the last seen timestamp. client = heartbeat.sender - self.__join_client(client) - self.clients[client.client_id]["last_seen"] = datetime.now() + self.client_manager.update_client(client.client_id) response = fedn.Response() response.sender.name = heartbeat.sender.name @@ -736,14 +643,12 @@ def TaskStream(self, response, context): self.__whoami(status.sender, self) - # Subscribe client, this also adds the client to self.clients - self._subscribe_client_to_queue(client, fedn.Queue.TASK_QUEUE) - q = self.__get_queue(client, fedn.Queue.TASK_QUEUE) - self._send_status(status) # Set client status to online - self.clients[client.client_id]["status"] = "online" + self.client_manager.update_client(client.client_id) + client = self.client_manager.get_client(client.client_id) + client.status = "online" try: # If the client is already in the client store, update the status client_to_update = self.db.client_store.get(client.client_id) @@ -763,15 +668,20 @@ def TaskStream(self, response, context): while context.is_active(): # Check if the context has been active for more than 10 seconds if time.time() - start_time > 10: - self.clients[client.client_id]["last_seen"] = datetime.now() + self.client_manager.update_client(client.client_id) # Reset the start time start_time = time.time() try: - yield q.get(timeout=1.0) + request = self.client_manager.pop_task(client.client_id) + if request is not None: + yield request + else: + time.sleep(1.0) except queue.Empty: pass except Exception as e: logger.error("Error in ModelUpdateRequestStream: {}".format(e)) + logger.warning("Client {} disconnected from TaskStream".format(client.name)) status = fedn.Status(status="Client {} disconnected from TaskStream.".format(client.name)) status.log_level = fedn.LogLevel.INFO @@ -780,6 +690,13 @@ def TaskStream(self, response, context): self.__whoami(status.sender, self) self._send_status(status) + def PollAndReport(self, report: 
fedn.ActivityReport, context): + # Subscribe client, this also adds the client to self.clients + client = report.sender + # Update last_seen + self.client_manager.update_client(client.client_id) + return self.client_manager.PollAndReport(report) + def SendModelUpdate(self, request, context): """Send a model update response. @@ -864,7 +781,7 @@ def SendBackwardCompletion(self, request, context): logger.info("Received BackwardCompletion from {}".format(request.sender.name)) # Add completion to the queue - self.round_handler.update_handler.backward_completions.put(request) + self.round_handler.backward_handler.backward_completions.put(request) # Create and send status message for backward completion status = fedn.Status() diff --git a/fedn/network/combiner/hooks/grpc_wrappers.py b/fedn/network/combiner/hooks/grpc_wrappers.py new file mode 100644 index 000000000..0aff6b371 --- /dev/null +++ b/fedn/network/combiner/hooks/grpc_wrappers.py @@ -0,0 +1,58 @@ +import time + +import grpc + +from fedn.common.log_config import logger + + +def safe_unary(func_name, default_resp_factory): + def decorator(fn): + def wrapper(self, request, context): + try: + return fn(self, request, context) + except Exception as e: + self._retire_and_log(func_name, e) + # Option A: return a valid default payload (keeps channel healthy) + return default_resp_factory() + + return wrapper + + return decorator + + +def safe_streaming(func_name): + def decorator(fn): + def wrapper(self, request, context): + try: + yield from fn(self, request, context) + except Exception as e: + self._retire_and_log(func_name, e) + # Option B for streaming: signal an RPC error the client understands + context.set_code(grpc.StatusCode.FAILED_PRECONDITION) + context.set_details(f"{func_name} failed; sender should use local fallback.") + return + + return wrapper + + return decorator + + +def call_with_fallback(name, fn, *, retries=2, base_sleep=0.25, fallback_fn=None): + for i in range(retries + 1): + try: + return fn() + 
except grpc.RpcError as e: + code = e.code() + if code in (grpc.StatusCode.FAILED_PRECONDITION, grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.DEADLINE_EXCEEDED): + logger.warning(f"{name} rpc failed with {code.name}: {e.details()}; attempt {i + 1}/{retries}") + if i < retries: + time.sleep(base_sleep * (2**i)) + continue + break + except Exception as e: + logger.exception(f"{name} unexpected error: {e}") + break + if fallback_fn: + logger.info(f"{name}: using local fallback") + return fallback_fn() + raise RuntimeError(f"{name} failed and no fallback provided") diff --git a/fedn/network/combiner/hooks/hook_client.py b/fedn/network/combiner/hooks/hook_client.py index 16fde9eb8..f040dc51c 100644 --- a/fedn/network/combiner/hooks/hook_client.py +++ b/fedn/network/combiner/hooks/hook_client.py @@ -1,15 +1,20 @@ import json import os +import queue import grpc import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.log_config import logger -from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model from fedn.network.combiner.updatehandler import UpdateHandler +from fedn.utils.model import FednModel CHUNK_SIZE = 1024 * 1024 +# for quick functions +TIMEOUT_SHORT = 120 +# for functions which might take longer such as aggregation +TIMEOUT_LONG = 600 class CombinerHookInterface: @@ -42,7 +47,7 @@ def provided_functions(self, server_functions: str): try: request = fedn.ProvidedFunctionsRequest(function_code=server_functions) - response = self.stub.HandleProvidedFunctions(request) + response = self.stub.HandleProvidedFunctions(request, timeout=TIMEOUT_SHORT) return response.available_functions except grpc.RpcError as rpc_error: if rpc_error.code() == grpc.StatusCode.UNAVAILABLE: @@ -54,7 +59,7 @@ def provided_functions(self, server_functions: str): logger.error(f"Unexpected error communicating with hooks container: {e}") return {} - def client_settings(self, global_model) -> dict: + 
def client_settings(self, global_model: FednModel) -> dict: """Communicates to hook container to get a client config. :param global_model: The global model that will be distributed to clients. @@ -62,10 +67,7 @@ def client_settings(self, global_model) -> dict: :return: config that will be distributed to clients. :rtype: dict """ - request_function = fedn.ClientConfigRequest - args = {} - model = model_as_bytesIO(global_model) - response = self.stub.HandleClientConfig(bytesIO_request_generator(mdl=model, request_function=request_function, args=args)) + response = self.stub.HandleClientConfig(global_model.get_filechunk_stream(), timeout=TIMEOUT_SHORT) return json.loads(response.client_settings) def client_selection(self, clients: list) -> list: @@ -73,7 +75,7 @@ def client_selection(self, clients: list) -> list: response = self.stub.HandleClientSelection(request) return json.loads(response.client_ids) - def aggregate(self, previous_global, update_handler: UpdateHandler, helper, delete_models: bool): + def aggregate(self, session_id, previous_global: FednModel, update_handler: UpdateHandler, helper, delete_models: bool): """Aggregation call to the hook functions. Sends models in chunks, then asks for aggregation. :param global_model: The global model that will be distributed to clients. 
@@ -85,25 +87,33 @@ def aggregate(self, previous_global, update_handler: UpdateHandler, helper, dele data["time_model_load"] = 0.0 data["time_model_aggregation"] = 0.0 # send previous global - request_function = fedn.StoreModelRequest - args = {"id": "global_model"} - response = self.stub.HandleStoreModel(bytesIO_request_generator(mdl=previous_global, request_function=request_function, args=args)) + + response = self.stub.HandleStoreModel( + previous_global.get_filechunk_stream(), + timeout=TIMEOUT_SHORT, + metadata=[("client-id", "global_model")], + ) logger.info(f"Store model response: {response.status}") # send client models and metadata nr_updates = 0 - while not update_handler.model_updates.empty(): - logger.info("Getting next model update from queue.") - update = update_handler.next_model_update() + while True: + try: + update = update_handler.next_model_update(session_id) + logger.info("Getting next model update from queue.") + except queue.Empty: + break metadata = json.loads(update.meta)["training_metadata"] - model = update_handler.load_model_update_bytesIO(update.model_update_id) + model = update_handler.get_model(update.model_update_id) # send metadata client_id = update.sender.client_id request = fedn.ClientMetaRequest(metadata=json.dumps(metadata), client_id=client_id) - response = self.stub.HandleMetadata(request) + response = self.stub.HandleMetadata(request, timeout=TIMEOUT_SHORT) # send client model - args = {"id": client_id} - request_function = fedn.StoreModelRequest - response = self.stub.HandleStoreModel(bytesIO_request_generator(mdl=model, request_function=request_function, args=args)) + response = self.stub.HandleStoreModel( + model.get_filechunk_stream(), + timeout=TIMEOUT_SHORT, + metadata=[("client-id", client_id)], + ) logger.info(f"Store model response: {response.status}") nr_updates += 1 if delete_models: @@ -111,7 +121,7 @@ def aggregate(self, previous_global, update_handler: UpdateHandler, helper, dele 
update_handler.delete_model(model_update=update) # ask for aggregation request = fedn.AggregationRequest(aggregate="aggregate") - response_generator = self.stub.HandleAggregation(request) + response_generator = self.stub.HandleAggregation(request, timeout=TIMEOUT_LONG) data["nr_aggregated_models"] = nr_updates - model, _ = unpack_model(response_generator, helper) - return model, data + fedn_model = FednModel.from_filechunk_stream(response_generator) + return fedn_model, data diff --git a/fedn/network/combiner/hooks/hooks.py b/fedn/network/combiner/hooks/hooks.py index 14e3c3568..164e28c58 100644 --- a/fedn/network/combiner/hooks/hooks.py +++ b/fedn/network/combiner/hooks/hooks.py @@ -1,5 +1,8 @@ import ast import json +import linecache +import linecache as _lc +import traceback from concurrent import futures import grpc @@ -11,8 +14,9 @@ # imports for user defined code from fedn.network.combiner.hooks.allowed_import import * # noqa: F403 from fedn.network.combiner.hooks.allowed_import import ServerFunctionsBase -from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO, unpack_model +from fedn.network.combiner.hooks.grpc_wrappers import safe_streaming, safe_unary from fedn.utils.helpers.plugins.numpyhelper import Helper +from fedn.utils.model import FednModel CHUNK_SIZE = 1024 * 1024 VALID_NAME_REGEX = "^[a-zA-Z0-9_-]*$" @@ -35,25 +39,25 @@ def __init__(self) -> None: self.implemented_functions = {} logger.info("Server Functions initialized.") - def HandleClientConfig(self, request_iterator: fedn.ClientConfigRequest, context): + @safe_unary("client_settings", lambda: fedn.ClientConfigResponse(client_settings=json.dumps({}))) + def HandleClientConfig(self, request_iterator: fedn.FileChunk, context): """Distribute client configs to clients from user defined code. 
- :param request_iterator: the client config request - :type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ClientConfigRequest` + :param request_iterator: the global model + :type request_iterator: :class:`fedn.network.grpc.fedn_pb2.FileChunk` :param context: the context (unused) :type context: :class:`grpc._server._Context` :return: the client config response :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientConfigResponse` """ - try: - logger.info("Received client config request.") - model, _ = unpack_model(request_iterator, self.helper) - client_settings = self.server_functions.client_settings(global_model=model) - logger.info(f"Client config response: {client_settings}") - return fedn.ClientConfigResponse(client_settings=json.dumps(client_settings)) - except Exception as e: - logger.error(f"Error handling client config request: {e}") + logger.info("Received client config request.") + fedn_model = FednModel.from_filechunk_stream(request_iterator) + model = fedn_model.get_model_params(self.helper) + client_settings = self.server_functions.client_settings(global_model=model) + logger.info(f"Client config response: {client_settings}") + return fedn.ClientConfigResponse(client_settings=json.dumps(client_settings)) + @safe_unary("client_selection", lambda: fedn.ClientSelectionResponse(client_ids=json.dumps([]))) def HandleClientSelection(self, request: fedn.ClientSelectionRequest, context): """Handle client selection from user defined code. 
@@ -64,15 +68,13 @@ def HandleClientSelection(self, request: fedn.ClientSelectionRequest, context): :return: the client selection response :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientSelectionResponse` """ - try: - logger.info("Received client selection request.") - client_ids = json.loads(request.client_ids) - client_ids = self.server_functions.client_selection(client_ids) - logger.info(f"Clients selected: {client_ids}") - return fedn.ClientSelectionResponse(client_ids=json.dumps(client_ids)) - except Exception as e: - logger.error(f"Error handling client selection request: {e}") + logger.info("Received client selection request.") + client_ids = json.loads(request.client_ids) + client_ids = self.server_functions.client_selection(client_ids) + logger.info(f"Clients selected: {client_ids}") + return fedn.ClientSelectionResponse(client_ids=json.dumps(client_ids)) + @safe_unary("store_metadata", lambda: fedn.ClientMetaResponse(status="ERROR")) def HandleMetadata(self, request: fedn.ClientMetaRequest, context): """Store client metadata from a request. 
@@ -83,32 +85,33 @@ def HandleMetadata(self, request: fedn.ClientMetaRequest, context): :return: the client meta response :rtype: :class:`fedn.network.grpc.fedn_pb2.ClientMetaResponse` """ - try: - logger.info("Received metadata") - client_id = request.client_id - metadata = json.loads(request.metadata) - # dictionary contains: [model, client_metadata] in that order for each key - self.client_updates[client_id] = self.client_updates.get(client_id, []) + [metadata] - self.check_incremental_aggregate(client_id) - return fedn.ClientMetaResponse(status="Metadata stored") - except Exception as e: - logger.error(f"Error handling store metadata request: {e}") + logger.info("Received metadata") + client_id = request.client_id + metadata = json.loads(request.metadata) + # dictionary contains: [model, client_metadata] in that order for each key + self.client_updates[client_id] = self.client_updates.get(client_id, []) + [metadata] + self.check_incremental_aggregate(client_id) + return fedn.ClientMetaResponse(status="Metadata stored") + @safe_unary("store_model", lambda: fedn.StoreModelResponse(status="ERROR")) def HandleStoreModel(self, request_iterator, context): - try: - model, final_request = unpack_model(request_iterator, self.helper) - client_id = final_request.id - if client_id == "global_model": - logger.info("Received previous global model") - self.previous_global = model - else: - logger.info(f"Received client model from client {client_id}") - # dictionary contains: [model, client_metadata] in that order for each key - self.client_updates[client_id] = [model] + self.client_updates.get(client_id, []) - self.check_incremental_aggregate(client_id) - return fedn.StoreModelResponse(status=f"Received model originating from {client_id}") - except Exception as e: - logger.error(f"Error handling store model request: {e}") + metadata = dict(context.invocation_metadata()) + client_id = metadata.get("client-id") + if client_id is None: + logger.error("No client-id provided in 
metadata.") + context.abort(grpc.StatusCode.INVALID_ARGUMENT, "No client-id provided in metadata.") + return fedn.StoreModelResponse(status="Error: No client-id provided in metadata.") + fedn_model = FednModel.from_filechunk_stream(request_iterator) + model = fedn_model.get_model_params(self.helper) + if client_id == "global_model": + logger.info("Received previous global model") + self.previous_global = model + else: + logger.info(f"Received client model from client {client_id}") + # dictionary contains: [model, client_metadata] in that order for each key + self.client_updates[client_id] = [model] + self.client_updates.get(client_id, []) + self.check_incremental_aggregate(client_id) + return fedn.StoreModelResponse(status=f"Received model originating from {client_id}") def check_incremental_aggregate(self, client_id): # incremental aggregation (memory secure) @@ -121,6 +124,7 @@ def check_incremental_aggregate(self, client_id): self.server_functions.incremental_aggregate(client_id, client_model, client_metadata, self.previous_global) del self.client_updates[client_id] + @safe_streaming("aggregate") def HandleAggregation(self, request, context): """Receive and store models and aggregate based on user-defined code when specified in the request. 
@@ -131,22 +135,18 @@ def HandleAggregation(self, request, context): :return: the aggregation response (aggregated model or None) :rtype: :class:`fedn.network.grpc.fedn_pb2.AggregationResponse` """ - try: - logger.info(f"Receieved aggregation request: {request.aggregate}") - if self.implemented_functions["incremental_aggregate"]: - aggregated_model = self.server_functions.get_incremental_aggregate_model() - else: - aggregated_model = self.server_functions.aggregate(self.previous_global, self.client_updates) - - model_bytesIO = model_as_bytesIO(aggregated_model, self.helper) - request_function = fedn.AggregationResponse - self.client_updates = {} - logger.info("Returning aggregate model.") - response_generator = bytesIO_request_generator(mdl=model_bytesIO, request_function=request_function, args={}) - for response in response_generator: - yield response - except Exception as e: - logger.error(f"Error handling aggregation request: {e}") + logger.info(f"Receieved aggregation request: {request.aggregate}") + if self.implemented_functions["incremental_aggregate"]: + aggregated_model = self.server_functions.get_incremental_aggregate_model() + else: + aggregated_model = self.server_functions.aggregate(self.previous_global, self.client_updates) + + fedn_model = FednModel.from_model_params(aggregated_model, self.helper) + self.client_updates = {} + logger.info("Returning aggregate model.") + response_generator = fedn_model.get_filechunk_stream() + for response in response_generator: + yield response def HandleProvidedFunctions(self, request: fedn.ProvidedFunctionsResponse, context): """Handles the 'provided_functions' request. Sends back which functions are available. 
@@ -158,18 +158,19 @@ def HandleProvidedFunctions(self, request: fedn.ProvidedFunctionsResponse, conte :return: dict with str -> bool for which functions are available :rtype: :class:`fedn.network.grpc.fedn_pb2.ProvidedFunctionsResponse` """ - try: - logger.info("Receieved provided functions request.") - server_functions_code = request.function_code - # if no new code return previous - if server_functions_code == self.server_functions_code: - logger.info("No new server function code provided.") - logger.info(f"Provided functions: {self.implemented_functions}") - return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions) - - self.server_functions_code = server_functions_code - self.implemented_functions = {} - self._instansiate_server_functions_code() + logger.info("Receieved provided functions request.") + server_functions_code = request.function_code + # if no new code return previous + if server_functions_code == self.server_functions_code: + logger.info("No new server function code provided.") + logger.info(f"Provided functions: {self.implemented_functions}") + return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions) + + self.server_functions_code = server_functions_code + self.implemented_functions = {} + self._instansiate_server_functions_code() + + if self.implemented_functions == {}: # not defaultet due to error functions = ["client_selection", "client_settings", "aggregate", "incremental_aggregate"] # parse the entire code string into an AST tree = ast.parse(server_functions_code) @@ -185,20 +186,56 @@ def HandleProvidedFunctions(self, request: fedn.ProvidedFunctionsResponse, conte else: print(f"Function '{func}' not found.") self.implemented_functions[func] = False - logger.info(f"Provided function: {self.implemented_functions}") - return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions) - except Exception as e: - logger.error(f"Error handling provided functions request: 
{e}") + + logger.info(f"Provided function: {self.implemented_functions}") + return fedn.ProvidedFunctionsResponse(available_functions=self.implemented_functions) def _instansiate_server_functions_code(self): - # this will create a new user defined instance of the ServerFunctions class. try: namespace = {} - exec(self.server_functions_code, globals(), namespace) # noqa: S102 + # create a stable synthetic filename to appear in tracebacks + self._server_code_filename = f"server_functions:{hash(self.server_functions_code)}" + code_obj = compile(self.server_functions_code, self._server_code_filename, "exec") + + # prime linecache so traceback can show source lines + linecache.cache[self._server_code_filename] = ( + len(self.server_functions_code), + None, + [ln if ln.endswith("\n") else ln + "\n" for ln in self.server_functions_code.splitlines()], + self._server_code_filename, + ) + + exec(code_obj, globals(), namespace) # noqa: S102 exec("server_functions = ServerFunctions()", globals(), namespace) # noqa: S102 self.server_functions = namespace.get("server_functions") except Exception as e: - logger.error(f"Exec failed with error: {str(e)}") + logger.error(f"Exec failed: {e}") + self.server_functions = None + self.implemented_functions = dict.fromkeys(["client_selection", "client_settings", "aggregate", "incremental_aggregate"], False) + + def _retire_and_log(self, func_name: str, err: Exception): + # retire the function immediately + if func_name in self.implemented_functions: + self.implemented_functions[func_name] = False + + # try to find frames that originate from the compiled user code + tb = traceback.extract_tb(err.__traceback__) + user_frames = [] + filename = getattr(self, "_server_code_filename", None) + for frame in tb: + if filename and frame.filename == filename: + user_frames.append(frame) + + if user_frames: + # deepest frame in user code (where it actually failed) + f = user_frames[-1] + # fetch the source line from linecache (primed earlier) + + 
src_line = (_lc.getline(f.filename, f.lineno) or "").rstrip("\n") + logger.error(f"User function '{func_name}' crashed at {f.filename}:{f.lineno} in {f.name}()\n> {src_line}\nException: {repr(err)}") + else: + # fallback: full traceback (server + user frames) if we didn't match a user frame + logger.exception(f"{func_name} failed, retiring until next code update: {err}") def serve(): diff --git a/fedn/network/combiner/hooks/serverfunctionstest.py b/fedn/network/combiner/hooks/serverfunctionstest.py index 10ceb5690..ae3dcf9de 100644 --- a/fedn/network/combiner/hooks/serverfunctionstest.py +++ b/fedn/network/combiner/hooks/serverfunctionstest.py @@ -9,7 +9,7 @@ import fedn.network.grpc.fedn_pb2 as fedn from fedn.network.combiner.hooks.hooks import FunctionServiceServicer from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctionsBase -from fedn.network.combiner.modelservice import bytesIO_request_generator, model_as_bytesIO +from fedn.utils.model import FednModel def test_server_functions(server_functions: ServerFunctionsBase, parameters_np: List[np.ndarray], client_metadata: Dict, rounds, num_clients): @@ -32,27 +32,24 @@ def test_server_functions(server_functions: ServerFunctionsBase, parameters_np: response = function_service.HandleClientSelection(request, "") selected_clients = json.loads(response.client_ids) # see output from client config request - bytesio_model = model_as_bytesIO(parameters_np) - request_function = fedn.ClientConfigRequest - args = {} - gen = bytesIO_request_generator(mdl=bytesio_model, request_function=request_function, args=args) + fedn_model = FednModel.from_model_params(parameters_np) + gen = fedn_model.get_filechunk_stream() function_service.HandleClientConfig(gen, "") # see output from aggregate request - request_function = fedn.StoreModelRequest - args = {"id": "global_model"} - bytesio_model = model_as_bytesIO(parameters_np) - gen = bytesIO_request_generator(mdl=bytesio_model, request_function=request_function, 
args=args) - function_service.HandleStoreModel(gen, "") + fedn_model = FednModel.from_model_params(parameters_np) + gen = fedn_model.get_filechunk_stream() + context = object() + context.invocation_metadata = lambda: [("client-id", "global_model")] + function_service.HandleStoreModel(gen, context) for k in range(len(selected_clients)): # send metadata client_id = selected_clients[k] request = fedn.ClientMetaRequest(metadata=json.dumps(client_metadata), client_id=client_id) function_service.HandleMetadata(request, "") - request_function = fedn.StoreModelRequest - args = {"id": client_id} - bytesio_model = model_as_bytesIO(parameters_np) - gen = bytesIO_request_generator(mdl=bytesio_model, request_function=request_function, args=args) - function_service.HandleStoreModel(gen, "") + fedn_model = FednModel.from_model_params(parameters_np) + gen = fedn_model.get_filechunk_stream() + context.invocation_metadata = lambda client_id=client_id: [("client-id", client_id)] + function_service.HandleStoreModel(gen, context) request = fedn.AggregationRequest(aggregate="aggregate") response_generator = function_service.HandleAggregation(request, "") for response in response_generator: diff --git a/fedn/network/combiner/modelservice.py b/fedn/network/combiner/modelservice.py index 4388f28c0..da4fd5d9e 100644 --- a/fedn/network/combiner/modelservice.py +++ b/fedn/network/combiner/modelservice.py @@ -1,105 +1,37 @@ import os import tempfile +import threading from io import BytesIO +from typing import Generator -import numpy as np +import grpc import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.log_config import logger from fedn.network.storage.models.tempmodelstorage import TempModelStorage +from fedn.network.storage.s3.repository import Repository +from fedn.utils.model import FednModel CHUNK_SIZE = 1 * 1024 * 1024 -def upload_request_generator(mdl, id): - """Generator function for model upload requests. 
+def upload_request_generator(model_stream: BytesIO): + """Generator function for model upload requests for the client :param mdl: The model update object. :type mdl: BytesIO :return: A model update request. - :rtype: fedn.ModelRequest + :rtype: fedn.FileChunk """ while True: - b = mdl.read(CHUNK_SIZE) + b = model_stream.read(CHUNK_SIZE) if b: - result = fedn.ModelRequest(data=b, id=id, status=fedn.ModelStatus.IN_PROGRESS) + yield fedn.FileChunk(data=b) else: - result = fedn.ModelRequest(id=id, data=None, status=fedn.ModelStatus.OK) - yield result - if not b: break -def bytesIO_request_generator(mdl, request_function, args): - """Generator function for model upload requests. - - :param mdl: The model update object. - :type mdl: BytesIO - :param request_function: Function for sending requests. - :type request_function: Function - :param args: request arguments, excluding data argument. - :type args: dict - :return: Yields grpc request for streaming. - :rtype: grpc request generator. - """ - while True: - b = mdl.read(CHUNK_SIZE) - if b: - result = request_function(data=b, **args) - else: - result = request_function(data=None, **args) - yield result - if not b: - break - - -def model_as_bytesIO(model, helper=None): - if isinstance(model, list): - bt = BytesIO() - model_dict = {str(i): w for i, w in enumerate(model)} - np.savez_compressed(bt, **model_dict) - bt.seek(0) - return bt - if not isinstance(model, BytesIO): - bt = BytesIO() - - written_total = 0 - for d in model.stream(32 * 1024): - written = bt.write(d) - written_total += written - else: - bt = model - - bt.seek(0, 0) - return bt - - -def unpack_model(request_iterator, helper): - """Unpack an incoming model sent in chunks from a request iterator. - - :param request_iterator: A streaming iterator from an gRPC service. - :return: The reconstructed model parameters. 
- """ - model_buffer = BytesIO() - try: - for request in request_iterator: - if request.data: - model_buffer.write(request.data) - except MemoryError as e: - logger.error(f"Memory error occured when loading model, reach out to the FEDn team if you need a solution to this. {e}") - raise - except Exception as e: - logger.error(f"Exception occured during model loading: {e}") - raise - - model_buffer.seek(0) - - model_bytes = model_buffer.getvalue() - - return load_model_from_bytes(model_bytes, helper), request - - def get_tmp_path(): """Return a temporary output path compatible with save_model, load_model.""" fd, path = tempfile.mkstemp() @@ -107,51 +39,13 @@ def get_tmp_path(): return path -def load_model_from_bytes(model_bytes, helper): - """Load a model from a bytes object. - :param model_bytesio: A bytes object containing the model. - :type model_bytes: :class:`bytes` - :param helper: The helper object for the model. - :type helper: :class:`fedn.utils.helperbase.HelperBase` - :return: The model object. - :rtype: return type of helper.load - """ - path = get_tmp_path() - with open(path, "wb") as fh: - fh.write(model_bytes) - fh.flush() - model = helper.load(path) - os.unlink(path) - return model - - -def serialize_model_to_BytesIO(model, helper): - """Serialize a model to a BytesIO object. - - :param model: The model object. - :type model: return type of helper.load - :param helper: The helper object for the model. - :type helper: :class:`fedn.utils.helperbase.HelperBase` - :return: A BytesIO object containing the model. 
- :rtype: :class:`io.BytesIO` - """ - outfile_name = helper.save(model) - - a = BytesIO() - a.seek(0, 0) - with open(outfile_name, "rb") as f: - a.write(f.read()) - a.seek(0) - os.unlink(outfile_name) - return a - - class ModelService(rpc.ModelServiceServicer): """Service for handling download and upload of models to the server.""" - def __init__(self): + def __init__(self, repository: Repository): """Initialize the temporary model storage.""" self.temp_model_storage = TempModelStorage() + self.repository = repository def exist(self, model_id): """Check if a model exists on the server. @@ -161,66 +55,83 @@ def exist(self, model_id): """ return self.temp_model_storage.exist(model_id) - def get_model(self, id): - """Download model with id 'id' from server. + def get_model(self, model_id): + """Get a model from the server. - :param id: The model id. - :type id: str - :return: A BytesIO object containing the model. - :rtype: :class:`io.BytesIO`, None if model does not exist. + :param model_id: The model id. + :return: The model object. + :rtype: :class:`fedn.network.storage.models.tempmodelstorage.FednModel` """ - data = BytesIO() - data.seek(0, 0) + if not self.temp_model_storage.exist(model_id): + logger.error(f"ModelServicer: Model {model_id} does not exist.") + raise ValueError(f"Model {model_id} does not exist in temporary storage.") + model = self.temp_model_storage.get(model_id) + if model is None: + # This should only occur if the model was deleted between the exist and get calls + logger.error(f"ModelServicer: Model {model_id} could not be retrieved.") + raise ValueError(f"Model {model_id} could not be retrieved.") + return model + + def model_ready(self, model_id): + """Check if a model is ready on the server. 
- parts = self.Download(fedn.ModelRequest(id=id), self) - for part in parts: - if part.status == fedn.ModelStatus.IN_PROGRESS: - data.write(part.data) - - if part.status == fedn.ModelStatus.OK: - return data - if part.status == fedn.ModelStatus.FAILED: - return None + :param model_id: The model id. + :return: True if the model is ready, else False. + """ + return self.temp_model_storage.is_ready(model_id) - def set_model(self, model, id): - """Upload model to server. + def fetch_model_from_repository(self, model_id, blocking: bool = False): + """Fetch model from the repository and store it in the temporary model storage. - :param model: A model object (BytesIO) - :type model: :class:`io.BytesIO` - :param id: The model id. - :type id: str + :param model_id: The model id to fetch. + :type model_id: str """ - bt = model_as_bytesIO(model) - # TODO: Check result - _ = self.Upload(upload_request_generator(bt, id), self) + logger.info(f"Fetching model {model_id} from repository.") + try: + model = self.repository.get_model_stream(model_id) + if model: + logger.info(f"Model {model_id} fetched and stored successfully.") + if blocking: + return self.temp_model_storage.set_model_from_stream(model_id, model, auto_managed=True) + else: + threading.Thread(target=lambda: self.temp_model_storage.set_model_from_stream(model_id, model, auto_managed=True)).start() + return True + else: + logger.error(f"Model {model_id} not found in repository.") + return False + except Exception as e: + logger.error(f"Error fetching model {model_id} from repository: {e}") + return False # Model Service - def Upload(self, request_iterator, context): + def Upload(self, filechunk_iterator: Generator[fedn.FileChunk, None, None], context: grpc.ServicerContext): """RPC endpoints for uploading a model. - :param request_iterator: The model request iterator. 
- :type request_iterator: :class:`fedn.network.grpc.fedn_pb2.ModelRequest` - :param context: The context object (unused) + :param filechunk_iterator: The model request iterator. + :type filechunk_iterator: :class:`fedn.network.grpc.fedn_pb2.FileChunk` + :param context: The context object :type context: :class:`grpc._server._Context` :return: A model response object. :rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse` """ logger.debug("grpc.ModelService.Upload: Called") - result = None - for request in request_iterator: - if request.status == fedn.ModelStatus.IN_PROGRESS: - self.temp_model_storage.get_ptr(request.id).write(request.data) - self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.IN_PROGRESS) - if request.status == fedn.ModelStatus.OK and not request.data: - result = fedn.ModelResponse(id=request.id, status=fedn.ModelStatus.OK, message="Got model successfully.") - # self.temp_model_storage_metadata.update({request.id: fedn.ModelStatus.OK}) - self.temp_model_storage.set_model_metadata(request.id, fedn.ModelStatus.OK) - self.temp_model_storage.get_ptr(request.id).flush() - self.temp_model_storage.get_ptr(request.id).close() - return result + # Note: Do not use underscore "_" in metadata keys, use dash "-" instead. 
+ metadata = dict(context.invocation_metadata()) + model_id = metadata.get("model-id") + checksum = metadata.get("checksum") + + if not model_id: + logger.error("ModelServicer: Model ID not provided.") + context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Model ID not provided.") - def Download(self, request, context): + result = self.temp_model_storage.set_model_with_filechunk_stream(model_id, filechunk_iterator, checksum) + if result: + return fedn.ModelResponse(status=fedn.ModelStatus.OK, message="Got model successfully.") + else: + return fedn.ModelResponse(status=fedn.ModelStatus.FAILED, message="Failed to upload model.") + + def Download(self, request: fedn.ModelRequest, context: grpc.ServicerContext): """RPC endpoints for downloading a model. :param request: The model request object. @@ -228,29 +139,38 @@ def Download(self, request, context): :param context: The context object (unused) :type context: :class:`grpc._server._Context` :return: A model response iterator. - :rtype: :class:`fedn.network.grpc.fedn_pb2.ModelResponse` + :rtype: :class:`fedn.network.grpc.fedn_pb2.FileChunk` """ - logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.client_id} requested model {request.id}") - try: - status = self.temp_model_storage.get_model_metadata(request.id) - if status != fedn.ModelStatus.OK: - logger.error(f"model file is not ready: {request.id}, status: {status}") - yield fedn.ModelResponse(id=request.id, data=None, status=status) - except Exception: - logger.error("Error file does not exist: {}".format(request.id)) - yield fedn.ModelResponse(id=request.id, data=None, status=fedn.ModelStatus.FAILED) + logger.info(f"grpc.ModelService.Download: {request.sender.role}:{request.sender.client_id} requested model {request.model_id}") + if not request.model_id: + logger.error("ModelServicer: Model ID not provided.") + context.abort(grpc.StatusCode.UNAVAILABLE, "Model ID not provided.") + + if not 
self.temp_model_storage.is_ready(request.model_id): + if self.temp_model_storage.exist(request.model_id): + logger.error(f"ModelServicer: Model file is not ready: {request.model_id}") + context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Model file is not ready.") + else: + logger.error(f"ModelServicer: Model file does not exist: {request.model_id}. Trying to start automatic caching") + file_is_downloading = self.fetch_model_from_repository(request.model_id) + if not file_is_downloading: + logger.error(f"ModelServicer: Model file does not exist: {request.model_id}.") + context.abort(grpc.StatusCode.UNAVAILABLE, "Model file does not exist. ") + else: + logger.info(f"ModelServicer: Caching started: {request.model_id}.") + context.abort(grpc.StatusCode.FAILED_PRECONDITION, "Model file is not ready. Starting automatic caching.") try: - obj = self.temp_model_storage.get(request.id) - if obj is None: - raise Exception(f"File not found: {request.id}") - with obj as f: - while True: - piece = f.read(CHUNK_SIZE) - if len(piece) == 0: - yield fedn.ModelResponse(id=request.id, data=None, status=fedn.ModelStatus.OK) - return - yield fedn.ModelResponse(id=request.id, data=piece, status=fedn.ModelStatus.IN_PROGRESS) + model: FednModel = self.temp_model_storage.get(request.model_id) + stream = model.get_stream() + while True: + chunk = stream.read(CHUNK_SIZE) + if chunk: + yield fedn.FileChunk(data=chunk) + else: + break except Exception as e: - logger.error("Downloading went wrong: {} {}".format(request.id, e)) - yield fedn.ModelResponse(id=request.id, data=None, status=fedn.ModelStatus.FAILED) + logger.error("Downloading went wrong: {} {}".format(request.model_id, e)) + context.abort(grpc.StatusCode.UNKNOWN, "Download failed.") + + context.set_trailing_metadata((("checksum", self.temp_model_storage.get_checksum(request.model_id)),)) diff --git a/fedn/network/combiner/roundhandler.py b/fedn/network/combiner/roundhandler.py index eddad6539..3d7ce6ea8 100644 --- 
a/fedn/network/combiner/roundhandler.py +++ b/fedn/network/combiner/roundhandler.py @@ -1,19 +1,22 @@ -import ast import inspect import queue import random import time -import uuid +import traceback from typing import TYPE_CHECKING, TypedDict +import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.log_config import logger from fedn.network.combiner.aggregators.aggregatorbase import get_aggregator +from fedn.network.combiner.hooks.grpc_wrappers import call_with_fallback from fedn.network.combiner.hooks.hook_client import CombinerHookInterface from fedn.network.combiner.hooks.serverfunctionsbase import ServerFunctions -from fedn.network.combiner.modelservice import ModelService, serialize_model_to_BytesIO -from fedn.network.combiner.updatehandler import UpdateHandler +from fedn.network.combiner.modelservice import ModelService +from fedn.network.combiner.updatehandler import BackwardHandler, UpdateHandler +from fedn.network.common.flow_controller import FlowController from fedn.network.storage.s3.repository import Repository from fedn.utils.helpers.helpers import get_helper +from fedn.utils.model import FednModel from fedn.utils.parameters import Parameters # This if is needed to avoid circular imports but is crucial for type hints. 
@@ -103,16 +106,19 @@ def __init__(self, server: "Combiner", repository: Repository, modelservice: Mod self.server = server self.modelservice = modelservice self.server_functions = inspect.getsource(ServerFunctions) - self.update_handler = UpdateHandler(modelservice=modelservice) + self.update_handler = UpdateHandler(modelservice=modelservice, client_manager=server.client_manager) + self.backward_handler = BackwardHandler() self.hook_interface = CombinerHookInterface() + self.flow_controller = FlowController() + def set_aggregator(self, aggregator): self.aggregator = get_aggregator(aggregator, self.update_handler) def set_server_functions(self, server_functions: str): self.server_functions = server_functions - def push_round_config(self, round_config: RoundConfig) -> str: + def push_round_config(self, round_config: RoundConfig): """Add a round_config (job description) to the inbox. :param round_config: A dict containing the round configuration (from global controller). @@ -121,12 +127,10 @@ def push_round_config(self, round_config: RoundConfig) -> str: :rtype: str """ try: - round_config["_job_id"] = str(uuid.uuid4()) self.round_configs.put(round_config) except Exception: logger.error("Failed to push round config.") raise - return round_config["_job_id"] def _training_round(self, config: dict, clients: list, provided_functions: dict): """Send model update requests to clients and aggregate results. 
@@ -147,14 +151,30 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict) session_id = config["session_id"] model_id = config["model_id"] + round_id = config["round_id"] if provided_functions.get("client_settings", False): - global_model_bytes = self.modelservice.temp_model_storage.get(model_id) - client_settings = self.hook_interface.client_settings(global_model_bytes) - config["client_settings"] = client_settings + fedn_model = self.modelservice.get_model(model_id) + + def _rpc(): + return self.hook_interface.client_settings(fedn_model) + + def _fallback(): + return {} + + client_settings = call_with_fallback("client_settings", _rpc, fallback_fn=_fallback) or {} + config["client_settings"] = {**config.get("client_settings", {}), **client_settings} # Request model updates from all active clients. - self.server.request_model_update(session_id=session_id, model_id=model_id, config=config, clients=clients) + requests = self.server.create_requests(fedn.StatusType.MODEL_UPDATE, session_id, model_id, config, clients) + session_queue = self.update_handler.get_session_queue(session_id) + session_queue.start_round_queue(round_id, [r.correlation_id for r in requests], config["accept_stragglers"]) + clients_with_requests = self.server.send_requests(requests) + + if len(clients_with_requests) < 20: + logger.info("Sent model update request for model {} to clients {}".format(model_id, clients_with_requests)) + else: + logger.info("Sent model update request for model {} to {} clients".format(model_id, len(clients_with_requests))) # If buffer_size is -1 (default), the round terminates when/if all clients have completed if int(config["buffer_size"]) == -1: @@ -163,35 +183,53 @@ def _training_round(self, config: dict, clients: list, provided_functions: dict) buffer_size = int(config["buffer_size"]) # Wait / block until the round termination policy has been met. 
- self.update_handler.waitforit(config, buffer_size=buffer_size) + reason = self.flow_controller.wait_until_true(lambda: session_queue.aggregation_condition(buffer_size), timeout=float(config["round_timeout"])) + tic = time.time() - model = None + fedn_model = None data = None - try: - helper = get_helper(config["helper_type"]) - logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) - if config["delete_models_storage"] == "True": - delete_models = True - else: - delete_models = False - - if "aggregator_kwargs" in config.keys(): - dict_parameters = ast.literal_eval(config["aggregator_kwargs"]) - parameters = Parameters(dict_parameters) - else: - parameters = None - if provided_functions.get("aggregate", False) or provided_functions.get("incremental_aggregate", False): - previous_model_bytes = self.modelservice.temp_model_storage.get(model_id) - model, data = self.hook_interface.aggregate(previous_model_bytes, self.update_handler, helper, delete_models=delete_models) - else: - model, data = self.aggregator.combine_models(helper=helper, delete_models=delete_models, parameters=parameters) - except Exception as e: - logger.warning("AGGREGATION FAILED AT COMBINER! 
{}".format(e)) - raise + if reason != FlowController.Reason.STOP: + try: + helper = get_helper(config["helper_type"]) + logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) + if config["delete_models_storage"] == "True": + delete_models = True + else: + delete_models = False + + if "aggregator_kwargs" in config.keys(): + logger.info("Using aggregator kwargs from config: {}".format(config["aggregator_kwargs"])) + dict_parameters = config["aggregator_kwargs"] + parameters = Parameters(dict_parameters) + else: + parameters = None + if provided_functions.get("aggregate", False) or provided_functions.get("incremental_aggregate", False): + previous_model = self.modelservice.get_model(model_id) + + def _rpc(): + return self.hook_interface.aggregate(session_id, previous_model, self.update_handler, helper, delete_models=delete_models) + + def _fallback(): + return self.aggregator.combine_models(session_id=session_id, helper=helper, delete_models=delete_models, parameters=parameters) + + fedn_model, data = call_with_fallback("aggregate", _rpc, fallback_fn=_fallback) + else: + fedn_model, data = self.aggregator.combine_models(session_id=session_id, helper=helper, delete_models=delete_models, parameters=parameters) + + if not config["accept_stragglers"]: + self.server.client_manager.timeout_tasks(session_queue.get_all_outstanding_correlation_ids()) + + except Exception as e: + logger.warning("AGGREGATION FAILED AT COMBINER! {}".format(e)) + fedn_model = None + data = None + else: + self.server.client_manager.cancel_tasks(session_queue.get_all_outstanding_correlation_ids()) + logger.warning("ROUNDHANDLER: Training round terminated early, no model aggregation performed.") meta["time_combination"] = time.time() - tic meta["aggregation_time"] = data - return model, meta + return fedn_model, meta def _validation_round(self, session_id, model_id, clients): """Send model validation requests to clients. 
@@ -236,32 +274,41 @@ def _forward_pass(self, config: dict, clients: list): session_id = config["session_id"] model_id = config["model_id"] + round_id = config["round_id"] is_sl_inference = config[ "is_sl_inference" ] # determines whether forward pass calculates gradients ("training"), or is used for inference (e.g., for validation) # Request forward pass from all active clients. - self.server.request_forward_pass(session_id=session_id, model_id=model_id, config=config, clients=clients) + requests = self.server.create_requests(fedn.StatusType.FORWARD, session_id, model_id, config, clients) + session_queue = self.update_handler.get_session_queue(session_id) + session_queue.start_round_queue(round_id, [r.correlation_id for r in requests], config["accept_stragglers"]) + clients_with_requests = self.server.send_requests(requests) + if len(clients_with_requests) < 20: + logger.info("Sent forward pass request for model {} to clients {}".format(model_id, clients_with_requests)) + else: + logger.info("Sent forward pass request for model {} to {} clients".format(model_id, len(clients_with_requests))) # the round should terminate when all clients have completed buffer_size = len(clients) # Wait / block until the round termination policy has been met. 
- self.update_handler.waitforit(config, buffer_size=buffer_size) + reason = self.flow_controller.wait_until_true(lambda: session_queue.aggregation_condition(buffer_size), timeout=float(config["round_timeout"])) tic = time.time() output = None - try: - helper = get_helper(config["helper_type"]) - logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) - if config["delete_models_storage"] == "True": - delete_models = True - else: - delete_models = False + if reason != FlowController.Reason.STOP: + try: + helper = get_helper(config["helper_type"]) + logger.info("Config delete_models_storage: {}".format(config["delete_models_storage"])) + if config["delete_models_storage"] == "True": + delete_models = True + else: + delete_models = False - output = self.aggregator.combine_models(helper=helper, delete_models=delete_models, is_sl_inference=is_sl_inference) + output = self.aggregator.combine_models(session_id=session_id, helper=helper, delete_models=delete_models, is_sl_inference=is_sl_inference) - except Exception as e: - logger.warning("EMBEDDING CONCATENATION in FORWARD PASS FAILED AT COMBINER! {}".format(e)) + except Exception as e: + logger.warning("EMBEDDING CONCATENATION in FORWARD PASS FAILED AT COMBINER! {}".format(e)) meta["time_combination"] = time.time() - tic meta["aggregation_time"] = output["data"] @@ -285,7 +332,7 @@ def _backward_pass(self, config: dict, clients: list): meta["timeout"] = float(config["round_timeout"]) # Clear previous backward completions queue - self.update_handler.clear_backward_completions() + self.backward_handler.clear_backward_completions() # Request backward pass from all active clients. 
logger.info("ROUNDHANDLER: Requesting backward pass, gradient_id: {}".format(config["model_id"])) @@ -295,7 +342,7 @@ def _backward_pass(self, config: dict, clients: list): # the round should terminate when all clients have completed buffer_size = len(clients) - self.update_handler.waitforbackwardcompletion(config, required_backward_completions=buffer_size) + self.backward_handler.waitforbackwardcompletion(config, required_backward_completions=buffer_size) return meta @@ -310,16 +357,18 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): :type retry: int, optional """ # If the model is already in memory at the server we do not need to do anything. + logger.info("Model Staging, fetching model from storage...") + if self.modelservice.temp_model_storage.exist(model_id): logger.info("Model already exists in memory, skipping model staging.") return - logger.info("Model Staging, fetching model from storage...") + # If not, download it and stage it in memory at the combiner. tries = 0 while True: try: - model = self.storage.get_model_stream(model_id) - if model: + success = self.modelservice.fetch_model_from_repository(model_id, blocking=True) + if success: break except Exception: logger.warning("Could not fetch model from storage backend, retrying.") @@ -329,8 +378,6 @@ def stage_model(self, model_id, timeout_retry=3, retry=2): logger.error("Failed to stage model {} from storage backend!".format(model_id)) raise - self.modelservice.set_model(model, model_id) - def _assign_round_clients(self, n: int, type: str = "trainers", selected_clients: list = None): """Obtain a list of clients(trainers or validators) to ask for updates in this round. 
@@ -418,32 +465,37 @@ def execute_training_round(self, config): self.stage_model(config["model_id"]) # dictionary to which functions are provided - provided_functions = self.hook_interface.provided_functions(self.server_functions) + try: + provided_functions = self.hook_interface.provided_functions(self.server_functions) + except Exception: + provided_functions = {"client_selection": False, "client_settings": False, "aggregate": False, "incremental_aggregate": False} if provided_functions.get("client_selection", False): - selected = 0 - while not selected: - clients = self.hook_interface.client_selection(clients=self.server.get_active_trainers()) - selected = len(clients) - if not selected: - logger.info("No clients selected based on custom client selection implementation. Trying again in 15 seconds.") - time.sleep(15) + def _rpc(): + return self.hook_interface.client_selection(clients=self.server.get_active_trainers()) + + def _fallback(): + selected_clients = config["selected_clients"] if "selected_clients" in config and len(config["selected_clients"]) > 0 else None + + return self._assign_round_clients(n=self.server.max_clients, type="trainers", selected_clients=selected_clients) + + clients = call_with_fallback("client_selection", _rpc, fallback_fn=_fallback) + if not clients: + # Empty selection => fallback immediately (don't spin forever) + clients = _fallback() else: selected_clients = config["selected_clients"] if "selected_clients" in config and len(config["selected_clients"]) > 0 else None clients = self._assign_round_clients(n=self.server.max_clients, type="trainers", selected_clients=selected_clients) - model, meta = self._training_round(config, clients, provided_functions) + fedn_model, meta = self._training_round(config, clients, provided_functions) data["data"] = meta - if model is None: + if fedn_model is None: logger.warning("\t Failed to update global model in round {0}!".format(config["round_id"])) - if model is not None: - helper = 
get_helper(config["helper_type"]) - a = serialize_model_to_BytesIO(model, helper) - model_id = self.storage.set_model(a.read(), is_file=False) - a.close() + if fedn_model is not None: + model_id = self.storage.set_model(fedn_model.get_stream(), is_file=False) data["model_id"] = model_id logger.info("TRAINING ROUND COMPLETED. Aggregated model id: {}, Job id: {}".format(model_id, config["_job_id"])) @@ -481,9 +533,8 @@ def execute_forward_pass(self, config): elif output["gradients"] is not None: gradients = output["gradients"] helper = get_helper(config["helper_type"]) - a = serialize_model_to_BytesIO(gradients, helper) - gradient_id = self.storage.set_model(a.read(), is_file=False) # uploads gradients to storage - a.close() + fedn_model = FednModel.from_model_params(gradients, helper=helper) + gradient_id = self.storage.set_model(fedn_model.get_stream_unsafe(), is_file=False) # uploads gradients to storage data["model_id"] = gradient_id # intended logger.info("FORWARD PASS COMPLETED. Aggregated model id: {}, Job id: {}".format(gradient_id, config["_job_id"])) @@ -530,7 +581,11 @@ def run(self, polling_interval=1.0): while True: try: round_config = self.round_configs.get(block=False) - + except queue.Empty: + time.sleep(polling_interval) + continue + try: + self.flow_controller.stop_event.clear() # Check that the minimum allowed number of clients are connected ready = self._check_nr_round_clients(round_config) round_meta = {} @@ -605,8 +660,14 @@ def run(self, polling_interval=1.0): logger.warning("{0}".format(round_meta["reason"])) self.round_configs.task_done() - except queue.Empty: - time.sleep(polling_interval) + except Exception as e: + tb = traceback.format_exc() + logger.error("Uncought exception: {}".format(e)) + logger.error("Traceback: {}".format(tb)) + round_meta = {} + round_meta["status"] = "Failed" + round_meta["reason"] = str(e) + self.round_configs.task_done() except (KeyboardInterrupt, SystemExit): pass diff --git 
a/fedn/network/combiner/updatehandler.py b/fedn/network/combiner/updatehandler.py index 493991018..ac02d8612 100644 --- a/fedn/network/combiner/updatehandler.py +++ b/fedn/network/combiner/updatehandler.py @@ -1,11 +1,15 @@ import json import queue -import sys +import threading import time import traceback +from typing import Dict, List +import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.log_config import logger -from fedn.network.combiner.modelservice import ModelService, load_model_from_bytes +from fedn.network.combiner.clientmanager import ClientManager +from fedn.network.combiner.modelservice import ModelService +from fedn.utils.model import FednModel class ModelUpdateError(Exception): @@ -21,29 +25,45 @@ class UpdateHandler: :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` """ - def __init__(self, modelservice: ModelService) -> None: - self.model_updates = queue.Queue() - self.backward_completions = queue.Queue() + def __init__(self, modelservice: ModelService, client_manager: ClientManager) -> None: self.modelservice = modelservice + self.client_manager = client_manager + + self.session_queue: Dict[str, SessionQueue] = {} + + def get_session_queue(self, session_id): + """Get the session queue for the given session ID. - self.model_id_to_model_data = {} + If the session queue does not exist, create a new one. + :param session_id: The session ID + :type session_id: str + :return: The group of model updates. 
+ :rtype: SessionQueue + """ + if session_id not in self.session_queue: + logger.info("UPDATE HANDLER: Creating new update queue for session {}".format(session_id)) + self.session_queue[session_id] = SessionQueue(self, session_id=session_id) + return self.session_queue[session_id] - def delete_model(self, model_update): + def delete_model(self, model_update: fedn.ModelUpdate): self.modelservice.temp_model_storage.delete(model_update.model_update_id) logger.info("UPDATE HANDLER: Deleted model update {} from storage.".format(model_update.model_update_id)) - def next_model_update(self): + def next_model_update(self, session_id): """Get the next model update from the queue. - :param helper: A helper object. - :type helper: object + :param session_id: The session ID + :type session_id: str :return: The model update. :rtype: fedn.network.grpc.fedn.proto.ModelUpdate + :raises: queue.Empty """ - model_update = self.model_updates.get(block=False) - return model_update + if session_id in self.session_queue: + return self.session_queue[session_id].next_model_update() + else: + raise RuntimeError("No update queue set. Please create an update queue before calling this method.") - def on_model_update(self, model_update): + def on_model_update(self, model_update: fedn.ModelUpdate): """Callback when a new client model update is recieved. 
Performs (optional) validation and pre-processing, @@ -60,7 +80,10 @@ def on_model_update(self, model_update): valid_update = self._validate_model_update(model_update) if valid_update: # Push the model update to the processing queue - self.model_updates.put(model_update) + if model_update.session_id in self.session_queue: + self.session_queue[model_update.session_id].add_model_update(model_update) + else: + logger.warning("UPDATE HANDLER: No session queue found for session {}, skipping.".format(model_update.session_id)) else: logger.warning("UPDATE HANDLER: Invalid model update, skipping.") except Exception as e: @@ -69,7 +92,7 @@ def on_model_update(self, model_update): logger.error(tb) pass - def _validate_model_update(self, model_update): + def _validate_model_update(self, model_update: fedn.ModelUpdate): """Validate the model update. :param model_update: A ModelUpdate message. @@ -85,9 +108,14 @@ def _validate_model_update(self, model_update): logger.error("UPDATE HANDLER: Invalid model update, missing metadata.") logger.error(tb) return False + + if not self.modelservice.exist(model_update.model_update_id): + logger.error("UPDATE HANDLER: Model update {} not found.".format(model_update.model_update_id)) + return False + return True - def load_model_update(self, model_update, helper): + def load_model_update(self, model_update: fedn.ModelUpdate, helper): """Load the memory representation of the model update. 
Load the model update paramters and the @@ -101,34 +129,7 @@ def load_model_update(self, model_update, helper): :rtype: tuple """ model_id = model_update.model_update_id - model = self.load_model(helper, model_id) - # Get relevant metadata - metadata = json.loads(model_update.meta) - if "config" in metadata.keys(): - # Used in Python client - config = json.loads(metadata["config"]) - else: - # Used in C++ client - config = json.loads(model_update.config) - training_metadata = metadata["training_metadata"] - if "round_id" in config: - training_metadata["round_id"] = config["round_id"] - - return model, training_metadata - - def load_model_update_byte(self, model_update): - """Load the memory representation of the model update. - - Load the model update paramters and the - associate metadata into memory. - - :param model_update: The model update. - :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate - :return: A tuple of parameters(bytes), metadata - :rtype: tuple - """ - model_id = model_update.model_update_id - model = self.load_model_update_bytesIO(model_id).getbuffer() + model_params = self.load_model_params(helper, model_id) # Get relevant metadata metadata = json.loads(model_update.meta) if "config" in metadata.keys(): @@ -141,9 +142,9 @@ def load_model_update_byte(self, model_update): if "round_id" in config: training_metadata["round_id"] = config["round_id"] - return model, training_metadata + return model_params, training_metadata - def load_model(self, helper, model_id): + def load_model_params(self, helper, model_id): """Load model update with id model_id into its memory representation. 
:param helper: An instance of :class: `fedn.utils.helpers.helpers.HelperBase` @@ -151,10 +152,10 @@ def load_model(self, helper, model_id): :param model_id: The ID of the model update, UUID in str format :type model_id: str """ - model_bytesIO = self.load_model_update_bytesIO(model_id) - if model_bytesIO: + fedn_model = self.get_model(model_id) + if fedn_model: try: - model = load_model_from_bytes(model_bytesIO.getbuffer(), helper) + model = fedn_model.get_model_params(helper) except IOError: logger.warning("UPDATE HANDLER: Failed to load model!") else: @@ -162,8 +163,8 @@ def load_model(self, helper, model_id): return model - def load_model_update_bytesIO(self, model_id, retry=3): - """Load model update object and return it as BytesIO. + def get_model(self, model_id) -> FednModel: + """Load model update object and return it as FednModel. :param model_id: The ID of the model :type model_id: str @@ -172,45 +173,30 @@ def load_model_update_bytesIO(self, model_id, retry=3): :return: Updated model :rtype: class: `io.BytesIO` """ - # Try reading model update from local disk/combiner memory - model_str = self.modelservice.temp_model_storage.get(model_id) - # And if we cannot access that, try downloading from the server - if model_str is None: - model_str = self.modelservice.get_model(model_id) - # TODO: use retrying library - tries = 0 - while tries < retry: - tries += 1 - if not model_str or sys.getsizeof(model_str) == 80: - logger.warning("Model download failed. retrying") - time.sleep(3) # sleep longer - model_str = self.modelservice.get_model(model_id) - - return model_str - - def waitforit(self, config, buffer_size=100, polling_interval=0.1): - """Defines the policy for how long the server should wait before starting to aggregate models. - - The policy is as follows: - 1. Wait a maximum of time_window time until the round times out. - 2. Terminate if a preset number of model updates (buffer_size) are in the queue. 
+ return self.modelservice.get_model(model_id) - :param config: The round config object - :type config: dict - :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100 - :type buffer_size: int, optional - :param polling_interval: The polling interval, defaults to 0.1 - :type polling_interval: float, optional + def flush_session(self, session_id): + """Flush the session queue for the given session ID. + + :param session_id: The session ID + :type session_id: str """ - time_window = float(config["round_timeout"]) + if session_id in self.session_queue: + logger.info("UPDATE HANDLER: Flushing update queue for session {}".format(session_id)) + self.session_queue[session_id].flush_session() - tt = 0.0 - while tt < time_window: - if self.model_updates.qsize() >= buffer_size: - break - time.sleep(polling_interval) - tt += polling_interval +class BackwardHandler: + """Backward handler. + + Handles the backward completion messages during split learning backward passes. + + :param modelservice: A handle to the model service :class: `fedn.network.combiner.modelservice.ModelService` + :type modelservice: class: `fedn.network.combiner.modelservice.ModelService` + """ + + def __init__(self) -> None: + self.backward_completions = queue.Queue() def waitforbackwardcompletion(self, config, required_backward_completions=-1, polling_interval=0.1): """Wait for backward completion messages. 
@@ -235,3 +221,156 @@ def clear_backward_completions(self): self.backward_completions.get_nowait() except queue.Empty: break + + +class SessionQueue: + def __init__( + self, + update_handler: UpdateHandler, + session_id: str, + accept_stragglers: bool = False, + ): + self.session_id = session_id + self.round_id: str = None + self.update_handler = update_handler + + self.model_update: queue.Queue[fedn.ModelUpdate] = queue.Queue() + self.model_update_stragglers: queue.Queue[fedn.ModelUpdate] = queue.Queue() + + self.expected_correlation_ids = [] + self.straggler_correlation_ids: List[str] = [] + + self._accept_stragglers = accept_stragglers + + self.lock = threading.RLock() + + def add_model_update(self, model_update: fedn.ModelUpdate) -> bool: + if model_update.session_id != self.session_id: + # This indicates an error in the implementation + logger.error(f"UPDATE HANDLER: Model update {model_update.model_update_id} is ignored due to wrong session id.") + self.handle_invalid_model_update(model_update) + return False + + with self.lock: + if model_update.correlation_id in self.expected_correlation_ids: + # Expected model update + self.expected_correlation_ids.remove(model_update.correlation_id) + self.model_update.put(model_update) + return True + elif model_update.correlation_id in self.straggler_correlation_ids: + # Straggler model update + self.straggler_correlation_ids.remove(model_update.correlation_id) + if self._accept_stragglers: + self.model_update_stragglers.put(model_update) + return True + else: + logger.warning(f"UPDATE HANDLER: Model update {model_update.model_update_id} is ignored due to late arrival.") + self.handle_ignored_model_update(model_update) + else: + # Unknown model update + logger.error(f"UPDATE HANDLER: Model update {model_update.model_update_id} is ignored due to invalid correlation id.") + self.handle_invalid_model_update(model_update) + return False + + def get_all_outstanding_correlation_ids(self) -> List[str]: + """Get all outstanding 
correlation IDs. + + :return: List of outstanding correlation IDs. + :rtype: List[str] + """ + with self.lock: + return self.expected_correlation_ids + self.straggler_correlation_ids + + def handle_invalid_model_update(self, model_update: fedn.ModelUpdate): + """Handle invalid model update. + + :param model_update: The model update. + :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate + """ + # TODO: Maybe want to properly track invalid model updates somehow + # For now, just delete them + self.update_handler.delete_model(model_update) + + def handle_ignored_model_update(self, model_update: fedn.ModelUpdate): + """Handle invalid model update. + + :param model_update: The model update. + :type model_update: fedn.network.grpc.fedn.proto.ModelUpdate + """ + # TODO: Maybe want to properly track ignored model updates somehow + # For now, just delete them + self.update_handler.delete_model(model_update) + + def start_round_queue(self, round_id, expected_correlation_ids: List[str], accept_stragglers: bool = False): + """Progress to the next round transfering stragglers to the next round.""" + with self.lock: + self.round_id = round_id + self._accept_stragglers = accept_stragglers + + # Transfer stragglers to the next round + self.straggler_correlation_ids.extend(self.expected_correlation_ids) + self.expected_correlation_ids = expected_correlation_ids + + # Transfer model updates to the next round + # Model updates might contain some stragglers that was sent after the round + # was finished, so we need to transfer them to the next round + while not self.model_update.empty(): + model_update = self.model_update.get() + if self._accept_stragglers: + self.model_update_stragglers.put(model_update) + else: + logger.warning(f"UPDATE HANDLER: Model update {model_update.model_update_id} is ignored due to late arrival.") + self.handle_ignored_model_update(model_update) + + def finish_session(self): + """Finish the session""" + with self.lock: + 
self.expected_correlation_ids = [] + self.straggler_correlation_ids = [] + while not self.model_update.empty(): + model_update = self.model_update.get() + logger.warning(f"UPDATE HANDLER: Model update {model_update.model_update_id} is ignored due to session end.") + self.handle_ignored_model_update(model_update) + while not self.model_update_stragglers.empty(): + model_update = self.model_update_stragglers.get() + logger.warning(f"UPDATE HANDLER: Model update {model_update.model_update_id} is ignored due to session end.") + self.handle_ignored_model_update(model_update) + + def flush_session(self): + """Flush the session queue.""" + with self.lock: + corr_ids = self.get_all_outstanding_correlation_ids() + self.update_handler.client_manager.cancel_tasks(corr_ids) + + self.expected_correlation_ids = [] + self.straggler_correlation_ids = [] + while not self.model_update.empty(): + model_update = self.model_update.get() + self.handle_ignored_model_update(model_update) + while not self.model_update_stragglers.empty(): + model_update = self.model_update_stragglers.get() + self.handle_ignored_model_update(model_update) + + def next_model_update(self): + """Get the next model update from the queue. + + :return: The model update. + :rtype: fedn.network.grpc.fedn.proto.ModelUpdate + """ + try: + return self.model_update.get(block=False) + except queue.Empty: + if self._accept_stragglers: + return self.model_update_stragglers.get(block=False) + else: + raise queue.Empty + + def aggregation_condition(self, buffer_size=100): + """Check if the round has enough updates to continue to aggregate. + + :param buffer_size: The number of model updates to wait for before starting aggregation, defaults to 100 + :type buffer_size: int, optional + :return: True if the round is complete, False otherwise. 
+ :rtype: bool + """ + return self.model_update.qsize() >= buffer_size diff --git a/fedn/network/common/command.py b/fedn/network/common/command.py new file mode 100644 index 000000000..c712b1eb1 --- /dev/null +++ b/fedn/network/common/command.py @@ -0,0 +1,20 @@ +from enum import Enum + + +class CommandType(Enum): + StandardSession = "Fedn_StandardSession" + SplitLearningSession = "Fedn_SplitLearningSession" + PredictionSession = "Fedn_Prediction" + + +def validate_custom_command(command: str) -> bool: + """Validate that the command is a valid custom command. + + :param command: The command to validate. + :return: True if the command is valid, False otherwise. + """ + if not isinstance(command, str): + return False + if not command.startswith("Fedn_"): + return False + return True diff --git a/fedn/network/common/flow_controller.py b/fedn/network/common/flow_controller.py new file mode 100644 index 000000000..a9736dedb --- /dev/null +++ b/fedn/network/common/flow_controller.py @@ -0,0 +1,43 @@ +import threading +import time +from enum import Enum + + +class FlowController: + class Reason(Enum): + """Reason for the flow controller resuming.""" + + STOP = "stop" + CONTINUE = "continue" + TIMEOUT = "timeout" + CONDITION = "condition" + + def __init__(self): + self.stop_event = threading.Event() + self.continue_event = threading.Event() + + def wait_until_true(self, callback, timeout=0.0, polling_rate=1.0) -> Reason: + """Wait until the callback returns True or the timeout is reached. + + :param callback: The callback function to call. + :type callback: function + :param timeout: The timeout in seconds, defaults to 0.0 which means no timeout. + :type timeout: float, optional + :param polling_rate: The polling rate in seconds, defaults to 1.0. + :type polling_rate: float, optional + :return: The reason for the flow controller resuming. 
+ :rtype: Reason + """ + self.continue_event.clear() + start = time.time() + + while True: + if callback(): + return self.Reason.CONDITION + if self.continue_event.is_set(): + return self.Reason.CONTINUE + if self.stop_event.is_set(): + return self.Reason.STOP + if timeout > 0.0 and time.time() - start > timeout: + return self.Reason.TIMEOUT + time.sleep(polling_rate) diff --git a/fedn/network/combiner/interfaces.py b/fedn/network/common/interfaces.py similarity index 71% rename from fedn/network/combiner/interfaces.py rename to fedn/network/common/interfaces.py index 2dc485754..92a4b67aa 100644 --- a/fedn/network/combiner/interfaces.py +++ b/fedn/network/common/interfaces.py @@ -1,12 +1,14 @@ import base64 import copy import json +from typing import Dict import grpc import fedn.network.grpc.fedn_pb2 as fedn import fedn.network.grpc.fedn_pb2_grpc as rpc -from fedn.network.combiner.roundhandler import RoundConfig +from fedn.common.log_config import logger +from fedn.network.common.state import ControllerState class CombinerUnavailableError(Exception): @@ -79,8 +81,9 @@ class CombinerInterface: :type config: dict """ - def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None, ip=None, config=None): + def __init__(self, combiner_id, parent, name, address, fqdn, port, certificate=None, key=None, ip=None, config=None): """Initialize the combiner interface.""" + self.combiner_id = combiner_id self.parent = parent self.name = name self.address = address @@ -95,44 +98,6 @@ def __init__(self, parent, name, address, fqdn, port, certificate=None, key=None else: self.config = config - @classmethod - def from_json(combiner_config): - """Initialize the combiner config from a json document. - - :parameter combiner_config: The combiner configuration. - :type combiner_config: dict - :return: An instance of the combiner interface. 
- :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` - """ - return CombinerInterface(**combiner_config) - - def to_dict(self): - """Export combiner configuration to a dictionary. - - :return: A dictionary with the combiner configuration. - :rtype: dict - """ - data = { - "parent": self.parent, - "name": self.name, - "address": self.address, - "fqdn": self.fqdn, - "port": self.port, - "ip": self.ip, - "certificate": None, - "key": None, - "config": self.config, - } - return data - - def to_json(self): - """Export combiner configuration to json. - - :return: A json document with the combiner configuration. - :rtype: str - """ - return json.dumps(self.to_dict()) - def get_certificate(self): """Get combiner certificate. @@ -216,33 +181,35 @@ def set_server_functions(self, server_functions): else: raise - def submit(self, config: RoundConfig): - """Submit a compute plan to the combiner. + def submit(self, command: fedn.Command, parameters: Dict = None) -> fedn.ControlResponse: + """Send a command to the combiner. - :param config: The job configuration. - :type config: dict - :return: Server ControlResponse object. + :param command: The command to send. + :type command: :class:`fedn.network.grpc.fedn_pb2.Command` + :param parameters: The parameters for the command (optional). + :type parameters: dict + :return: The response from the combiner. 
:rtype: :class:`fedn.network.grpc.fedn_pb2.ControlResponse` """ channel = Channel(self.address, self.port, self.certificate).get_channel() control = rpc.ControlStub(channel) - request = fedn.ControlRequest() - request.command = fedn.Command.START - for k, v in config.items(): - p = request.parameter.add() - p.key = str(k) - p.value = str(v) + + request = fedn.CommandRequest() + + request.command = command + + if parameters: + request.parameters = json.dumps(parameters) try: - response = control.Start(request) + response = control.SendCommand(request) + return response except grpc.RpcError as e: if e.code() == grpc.StatusCode.UNAVAILABLE: - raise CombinerUnavailableError + raise CombinerUnavailableError(f"Combiner {self.name} unavailable: {e}") else: raise - return response - def allowing_clients(self): """Check if the combiner is allowing additional client connections. @@ -290,3 +257,60 @@ def list_active_clients(self, queue=1): else: raise return response.client + + +class ControlInterface: + def __init__(self, address, port, certificate=None): + """Initialize the control interface.""" + self.address = address + self.port = port + self.certificate = certificate + + def send_command(self, command: fedn.Command, command_type: str = None, parameters: Dict = None) -> fedn.ControlRequest: + """Send a command to the control interface. + + :param command_type: The type of command to send. + :type command_type: str + :param parameters: The parameters for the command. + :type parameters: dict + :return: The response from the control interface. 
+ :rtype: dict + """ + logger.info(f"Sending command {command} of type {command_type} to controller") + channel = Channel(self.address, self.port, self.certificate).get_channel() + control = rpc.ControlStub(channel) + + request = fedn.CommandRequest() + request.command = command + + if command_type: + request.command_type = command_type + + if parameters: + request.parameters = json.dumps(parameters) + + try: + response = control.SendCommand(request) + return response + except grpc.RpcError as e: + raise CombinerUnavailableError(f"Control interface unavailable: {e}") + + def get_state(self) -> ControllerState: + """Get the current state of the control interface. + + :return: The current state. + :rtype: :class:`fedn.network.grpc.fedn_pb2.ControlState` + """ + logger.info(f"Getting control state from {self.address}:{self.port}") + channel = Channel(self.address, self.port, self.certificate).get_channel() + control = rpc.ControlStub(channel) + + request = fedn.ControlRequest() + + try: + response = control.GetState(request) + logger.info(f"Control state response: {response.state}") + return ControllerState[response.state] + + except grpc.RpcError as e: + raise CombinerUnavailableError(f"Control interface unavailable: {e}") diff --git a/fedn/network/common/network.py b/fedn/network/common/network.py new file mode 100644 index 000000000..5009a8822 --- /dev/null +++ b/fedn/network/common/network.py @@ -0,0 +1,227 @@ +import os +from typing import List + +import fedn +from fedn.common.log_config import logger +from fedn.network.common.interfaces import CombinerInterface, CombinerUnavailableError, ControlInterface +from fedn.network.controller.shared import MisconfiguredHelper +from fedn.network.loadbalancer.leastpacked import LeastPacked +from fedn.network.storage.dbconnection import DatabaseConnection +from fedn.network.storage.s3.repository import Repository +from fedn.network.storage.statestore.stores.dto.model import ModelDTO +from 
fedn.network.storage.statestore.stores.shared import SortOrder +from fedn.utils.model import FednModel + +__all__ = ("Network",) + + +class Network: + """FEDn network interface. This class is used to interact with the database in a consistent way accross different containers.""" + + def __init__(self, dbconn: DatabaseConnection, repository: Repository, load_balancer=None, controller_host: str = None, controller_port: int = None): + """ """ + self.db = dbconn + self.repository = repository + + if not load_balancer: + self.load_balancer = LeastPacked(self) + else: + self.load_balancer = load_balancer + + self.controller = self._init_controller_interface(controller_host, controller_port) + + def _init_controller_interface(self, host, port) -> ControlInterface: + """Get a control instance from global config. + + :return: ControlInterface instance. + :rtype: :class:`fedn.network.common.interfaces.ControlInterface` + """ + cert = None + name = "CONTROL".upper() + # General certificate handling, same for all combiners. + if os.environ.get("FEDN_GRPC_CERT_PATH"): + with open(os.environ.get("FEDN_GRPC_CERT_PATH"), "rb") as f: + cert = f.read() + # Specific certificate handling for each combiner. + elif os.environ.get(f"FEDN_GRPC_CERT_PATH_{name}"): + cert_path = os.environ.get(f"FEDN_GRPC_CERT_PATH_{name}") + with open(cert_path, "rb") as f: + cert = f.read() + + # TODO: Remove hardcoded values + return ControlInterface(host, port, cert) + + def get_combiner(self, name): + """Get combiner by name. + + :param name: name of combiner + :type name: str + :return: The combiner instance object + :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` + """ + combiners = self.get_combiners() + for combiner in combiners: + if name == combiner.name: + return combiner + return None + + def get_combiners(self) -> List[CombinerInterface]: + """Get all combiners in the network. 
+ + :return: list of combiners objects + :rtype: list(:class:`fedn.network.combiner.interfaces.CombinerInterface`) + """ + result = self.db.combiner_store.list(limit=0, skip=0, sort_key=None) + combiners = [] + for combiner in result: + name = combiner.name.upper() + # General certificate handling, same for all combiners. + if os.environ.get("FEDN_GRPC_CERT_PATH"): + with open(os.environ.get("FEDN_GRPC_CERT_PATH"), "rb") as f: + cert = f.read() + # Specific certificate handling for each combiner. + elif os.environ.get(f"FEDN_GRPC_CERT_PATH_{name}"): + cert_path = os.environ.get(f"FEDN_GRPC_CERT_PATH_{name}") + with open(cert_path, "rb") as f: + cert = f.read() + else: + cert = None + combiners.append( + CombinerInterface( + combiner.combiner_id, combiner.parent, combiner.name, combiner.address, combiner.fqdn, combiner.port, certificate=cert, ip=combiner.ip + ) + ) + + return combiners + + def find_available_combiner(self): + """Find an available combiner in the network. + + :return: The combiner instance object + :rtype: :class:`fedn.network.combiner.interfaces.CombinerInterface` + """ + combiner = self.load_balancer.find_combiner() + return combiner + + def handle_unavailable_combiner(self, combiner): + """This callback is triggered if a combiner is found to be unresponsive. + + :param combiner: The combiner instance object + :type combiner: :class:`fedn.network.combiner.interfaces.CombinerInterface` + :return: None + """ + # TODO: Implement strategy to handle an unavailable combiner. + logger.warning("REDUCER CONTROL: Combiner {} unavailable.".format(combiner.name)) + + def get_control(self) -> ControlInterface: + """Get a control instance from global config. + + :return: ControlInterface instance. + :rtype: :class:`fedn.network.common.interfaces.ControlInterface` + """ + return self.controller + + def get_helper(self): + """Get a helper instance from global config. + + :return: Helper instance. 
+ :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` + """ + helper_type: str = None + + try: + active_package = self.db.package_store.get_active() + helper_type = active_package.helper + except Exception: + logger.error("Failed to get active helper") + + helper = fedn.utils.helpers.helpers.get_helper(helper_type) + if not helper: + raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) + return helper + + def commit_model(self, model: FednModel = None, session_id: str = None, name: str = None) -> str: + """Commit a model to the global model trail. The model commited becomes the lastest consensus model. + + :param model_id: Unique identifier for the model to commit. + :type model_id: str (uuid) + :param model: The model object to commit + :type model: BytesIO + :param session_id: Unique identifier for the session + :type session_id: str + """ + if model is not None: + model_id = self.repository.set_model(model.get_stream(), is_file=False) + model.model_id = model_id + + logger.info("Committing model {} to global model trail in statestore...".format(model_id)) + + parent_model = None + if session_id: + last_model_of_session = self.db.model_store.list(1, 0, "committed_at", SortOrder.DESCENDING, session_id=session_id) + if len(last_model_of_session) == 1: + parent_model = last_model_of_session[0].model_id + else: + session = self.db.session_store.get(session_id) + parent_model = session.seed_model_id + + new_model = ModelDTO() + new_model.model_id = model_id + new_model.parent_model = parent_model + new_model.session_id = session_id + new_model.name = name + + try: + self.db.model_store.add(new_model) + except Exception as e: + logger.error("Failed to commit model to global model trail: {}".format(e)) + raise Exception("Failed to commit model to global model trail") + + return model_id + + def get_control_state(self): + """Get the current state of the control. + + :return: The current state. 
+ :rtype: :class:`fedn.network.state.ReducerState` + """ + return self.get_control().get_state() + + def get_number_of_available_clients(self, client_ids: list[str]): + result = 0 + for combiner in self.get_combiners(): + try: + active_clients = combiner.list_active_clients() + if active_clients is not None: + if client_ids is not None: + filtered = [item for item in active_clients if item.client_id in client_ids] + result += len(filtered) + else: + result += len(active_clients) + except CombinerUnavailableError: + return 0 + return result + + def get_compute_package(self, compute_package=""): + """:param compute_package: + :return: + """ + if compute_package == "": + compute_package = self.get_compute_package_name() + if compute_package: + return self.repository.get_compute_package(compute_package) + else: + return None + + def get_compute_package_name(self): + """:return:""" + definition = self.db.package_store.get_active() + if definition: + try: + package_name = definition.storage_file_name + return package_name + except (IndexError, KeyError): + logger.error("No context filename set for compute context definition") + return None + else: + return None diff --git a/fedn/network/state.py b/fedn/network/common/state.py similarity index 70% rename from fedn/network/state.py rename to fedn/network/common/state.py index fcbd391eb..297dfbeb1 100644 --- a/fedn/network/state.py +++ b/fedn/network/common/state.py @@ -1,7 +1,7 @@ from enum import Enum -class ReducerState(Enum): +class ControllerState(Enum): """Enum for representing the state of a reducer.""" setup = 1 @@ -18,13 +18,13 @@ def ReducerStateToString(state): :return: The state as string. 
:rtype: str """ - if state == ReducerState.setup: + if state == ControllerState.setup: return "setup" - if state == ReducerState.idle: + if state == ControllerState.idle: return "idle" - if state == ReducerState.instructing: + if state == ControllerState.instructing: return "instructing" - if state == ReducerState.monitoring: + if state == ControllerState.monitoring: return "monitoring" return "UNKNOWN" @@ -39,10 +39,10 @@ def StringToReducerState(state): :rtype: :class:`fedn.network.state.ReducerState` """ if state == "setup": - return ReducerState.setup + return ControllerState.setup if state == "idle": - return ReducerState.idle + return ControllerState.idle elif state == "instructing": - return ReducerState.instructing + return ControllerState.instructing elif state == "monitoring": - return ReducerState.monitoring + return ControllerState.monitoring diff --git a/fedn/network/controller/command_runner.py b/fedn/network/controller/command_runner.py new file mode 100644 index 000000000..20782e016 --- /dev/null +++ b/fedn/network/controller/command_runner.py @@ -0,0 +1,44 @@ +import threading +from typing import TYPE_CHECKING, Callable, Dict + +from fedn.common.log_config import logger +from fedn.network.common.flow_controller import FlowController +from fedn.network.common.state import ControllerState + +if TYPE_CHECKING: + from fedn.network.controller.control import Control # not-floating-import + + +class CommandRunner: + """CommandRunner is responsible for executing commands on the controller.""" + + def __init__(self, control: "Control"): + self.flow_controller = FlowController() + self._state = ControllerState.idle + self.control = control + self.lock = threading.Lock() + + @property + def state(self) -> ControllerState: + return self._state + + def start_command(self, callback: Callable, parameters: Dict = None): + with self.lock: + if self._state != ControllerState.idle: + raise RuntimeError("CommandRunner is already running a command.") + self._state = 
ControllerState.instructing + threading.Thread(target=self._run_command, args=(callback, parameters)).start() + + def _run_command(self, callback, parameters: Dict = None): + """Run the command in a separate thread.""" + self.flow_controller.continue_event.clear() + self.flow_controller.stop_event.clear() + + try: + logger.info("CommandRunner: Starting command...") + callback(**parameters) + except Exception as e: + logger.error(f"CommandRunner: Failed command with error: {e}") + finally: + self._state = ControllerState.idle + logger.info("CommandRunner: Command finished.") diff --git a/fedn/network/controller/control.py b/fedn/network/controller/control.py index 765914ea1..9a237d375 100644 --- a/fedn/network/controller/control.py +++ b/fedn/network/controller/control.py @@ -1,21 +1,29 @@ import copy import datetime +import json +import signal import time -from typing import Optional +from typing import Dict, Optional, Tuple -from tenacity import retry, retry_if_exception_type, stop_after_delay, wait_random +import grpc +import fedn.network.grpc.fedn_pb2 as fedn +import fedn.network.grpc.fedn_pb2_grpc as rpc from fedn.common.log_config import logger -from fedn.network.combiner.interfaces import CombinerUnavailableError -from fedn.network.combiner.modelservice import load_model_from_bytes from fedn.network.combiner.roundhandler import RoundConfig +from fedn.network.common.command import CommandType +from fedn.network.common.flow_controller import FlowController +from fedn.network.common.interfaces import CombinerUnavailableError +from fedn.network.common.state import ControllerState +from fedn.network.controller.command_runner import CommandRunner from fedn.network.controller.controlbase import ControlBase -from fedn.network.state import ReducerState +from fedn.network.grpc.server import Server from fedn.network.storage.dbconnection import DatabaseConnection from fedn.network.storage.s3.repository import Repository from fedn.network.storage.statestore.stores.dto.run 
import RunDTO from fedn.network.storage.statestore.stores.dto.session import SessionConfigDTO from fedn.network.storage.statestore.stores.shared import SortOrder +from fedn.utils.model import FednModel class UnsupportedStorageBackend(Exception): @@ -86,17 +94,16 @@ def __init__(self, message): super().__init__(self.message) -class Control(ControlBase): +class Control(ControlBase, rpc.ControlServicer): """Controller, implementing the overall global training, validation and prediction logic. :param statestore: A StateStorage instance. :type statestore: class: `fedn.network.statestorebase.StateStorageBase` """ - _instance: "Control" - def __init__( self, + config: Dict, network_id: str, repository: Repository, db: DatabaseConnection, @@ -104,29 +111,9 @@ def __init__( """Constructor method.""" super().__init__(network_id, repository, db) self.name = "DefaultControl" + self.server = Server(self, None, config) - @classmethod - def instance(cls) -> "Control": - """Get the singleton instance of the Control class.""" - if Control._instance is None: - raise Exception("Control instance not initialized") - return Control._instance - - @classmethod - def create_instance(cls, network_id: str, repository: Repository, db: DatabaseConnection) -> "Control": - """Create a singleton instance of the Control class. - - :param network_id: The network ID. - :type network_id: str - :param repository: The repository instance. - :type repository: Repository - :param db: The database connection instance. - :type db: DatabaseConnection - :return: The Control instance. - :rtype: Control - """ - cls._instance = cls(network_id, repository, db) - return cls._instance + self.command_runner = CommandRunner(self) def _get_active_model_id(self, session_id: str) -> Optional[str]: """Get the active model for a session. 
@@ -147,24 +134,81 @@ def _get_active_model_id(self, session_id: str) -> Optional[str]: return None + def run(self) -> None: + # Start the gRPC server + self.server.start() + try: + while True: + signal.pause() + except (KeyboardInterrupt, SystemExit): + pass + finally: + logger.info("Shutting Controller...") + self.server.stop() + + def SendCommand(self, command_request: fedn.CommandRequest, context: grpc.ServicerContext) -> fedn.ControlResponse: + parameters = json.loads(command_request.parameters) if command_request.parameters else {} + + if command_request.command == fedn.Command.START: + logger.info("grpc.Controller.SendCommand start command") + try: + self.start_command(command_request.command_type, parameters) + except Exception as e: + logger.error(f"Failed to start command: {e}") + context.abort(grpc.StatusCode.UNKNOWN, f"Failed to start command: {e}") + + response = fedn.ControlResponse() + response.message = "Success" + return response + elif command_request.command == fedn.Command.STOP: + logger.info("grpc.Controller.SendCommand: Stopping current round") + self.command_runner.flow_controller.stop_event.set() + response = fedn.ControlResponse() + response.message = "Success" + return response + elif command_request.command == fedn.Command.CONTINUE: + logger.info("grpc.Controller.SendCommand: Continuing current round") + self.command_runner.flow_controller.continue_event.set() + response = fedn.ControlResponse() + response.message = "Success" + return response + + def GetState(self, request: fedn.ControlRequest, context: grpc.ServicerContext) -> fedn.ControlStateResponse: + """Get the current state of the control. + + :param context: The gRPC context. + :type context: grpc.ServicerContext + :return: The current state of the control. 
+ :rtype: fedn.ControlStateResponse + """ + logger.info("grpc.Control.GetState: Getting current state of the control") + response = fedn.ControlStateResponse() + response.state = self.command_runner.state.name + return response + + def start_command(self, command_type: str, parameters: Dict) -> None: + if command_type == CommandType.StandardSession.value: + self.command_runner.start_command(self.start_session, parameters) + elif command_type == CommandType.PredictionSession.value: + self.command_runner.start_command(self.prediction_session, parameters) + elif command_type == CommandType.SplitLearningSession.value: + self.command_runner.start_command(self.splitlearning_session, parameters) + else: + # TODO: Handle custom commands + raise RuntimeError(f"Unsupported command type: {command_type}") + def start_session( self, session_id: str, rounds: int, round_timeout: int, model_name_prefix: Optional[str] = None, client_ids: Optional[list[str]] = None ) -> None: - if self._state == ReducerState.instructing: - logger.info("Controller already in INSTRUCTING state. 
A session is in progress.") - return - try: active_model_id = self._get_active_model_id(session_id) if not active_model_id or active_model_id in ["", " "]: logger.warning("No model in model chain, please provide a seed model!") return - except Exception: - logger.error("Failed to get latest model of session and model chain.") + except Exception as e: + logger.error("Failed to get latest model of session and model chain with error: {}".format(e)) return - self._state = ReducerState.instructing - session = self.db.session_store.get(session_id) if not session: @@ -180,8 +224,6 @@ def start_session( if round_timeout is not None: session_config.round_timeout = round_timeout - self._state = ReducerState.monitoring - last_round = self.get_latest_round_id() aggregator = session_config.aggregator @@ -240,7 +282,6 @@ def start_session( training_run_obj.completed_at_model_id = self._get_active_model_id(session_id) self.db.run_store.update(training_run_obj) logger.info("Session finished.") - self._state = ReducerState.idle self.set_session_config(session_id, session_config.to_dict()) @@ -251,10 +292,6 @@ def prediction_session(self, config: RoundConfig) -> None: :type config: PredictionConfig :return: None """ - if self._state == ReducerState.instructing: - logger.info("Controller already in INSTRUCTING state. 
A session is in progress.") - return - if len(self.network.get_combiners()) < 1: logger.warning("Prediction round cannot start, no combiners connected!") return @@ -275,7 +312,7 @@ def prediction_session(self, config: RoundConfig) -> None: if round_start: logger.info("Prediction round start policy met, {} participating combiners.".format(len(participating_combiners))) for combiner, _ in participating_combiners: - combiner.submit(config) + combiner.submit(fedn.Command.START, config) logger.info("Prediction round submitted to combiner {}".format(combiner)) def splitlearning_session(self, session_id: str, rounds: int, round_timeout: int) -> None: @@ -290,12 +327,6 @@ def splitlearning_session(self, session_id: str, rounds: int, round_timeout: int """ logger.info("Starting split learning session.") - if self._state == ReducerState.instructing: - logger.info("Controller already in INSTRUCTING state. A session is in progress.") - return - - self._state = ReducerState.instructing - session = self.db.session_store.get(session_id) if not session: @@ -311,8 +342,6 @@ def splitlearning_session(self, session_id: str, rounds: int, round_timeout: int if round_timeout is not None: session_config.round_timeout = round_timeout - self._state = ReducerState.monitoring - last_round = self.get_latest_round_id() for combiner in self.network.get_combiners(): @@ -343,7 +372,6 @@ def splitlearning_session(self, session_id: str, rounds: int, round_timeout: int if self.get_session_status(session_config.session_id) == "Started": self.set_session_status(session_config.session_id, "Finished") - self._state = ReducerState.idle self.set_session_config(session_id, session_config.to_dict()) @@ -394,48 +422,48 @@ def round( _ = self.request_model_updates(participating_combiners) # TODO: Check response - # Wait until participating combiners have produced an updated global model, - # or round times out. 
- def do_if_round_times_out(result): - logger.warning("Round timed out!") - return True - - @retry( - wait=wait_random(min=1.0, max=2.0), - stop=stop_after_delay(session_config.round_timeout), - retry_error_callback=do_if_round_times_out, - retry=retry_if_exception_type(CombinersNotDoneException), - ) - def combiners_done(): + def check_round_reported(): round = self.db.round_store.get(round_id) - session_status = self.get_session_status(session_id) - if session_status == "Terminated": - self.set_round_status(round_id, "Terminated") - return False - if len(round.combiners) < 1: - logger.info("Waiting for combiners to update model...") - raise CombinersNotDoneException("Combiners have not yet reported.") - if len(round.combiners) < len(participating_combiners): logger.info("Waiting for combiners to update model...") - raise CombinersNotDoneException("All combiners have not yet reported.") - + return False return True - combiners_are_done = combiners_done() - if not combiners_are_done: + reason = self.command_runner.flow_controller.wait_until_true(check_round_reported, session_config.round_timeout, polling_rate=2.0) + if reason == FlowController.Reason.TIMEOUT: + logger.warning("Round timed out!") + elif reason == FlowController.Reason.STOP: + self.set_session_status(session_id, "Terminated") + for combiner, _ in participating_combiners: + combiner.submit(fedn.Command.STOP) + self.set_round_status(round_id, "Terminated") return None, self.db.round_store.get(round_id) + elif reason == FlowController.Reason.CONTINUE: + logger.info("Sending continue signal to combiners.") + for combiner, _ in participating_combiners: + combiner.submit(fedn.Command.CONTINUE) - # Due to the distributed nature of the computation, there might be a - # delay before combiners have reported the round data to the db, - # so we need some robustness here. 
- @retry(wait=wait_random(min=0.1, max=1.0), retry=retry_if_exception_type(KeyError)) - def check_combiners_done_reporting(): + # Wait for the combiners to finish aggreation. The timeout might just been reached and + # the combiners need time to finish aggregate the model updates and report back to the + # controller as the timeout + + # Update method with new print + def check_round_reported(): round = self.db.round_store.get(round_id) if len(round.combiners) != len(participating_combiners): - raise KeyError("Combiners have not yet reported.") + logger.info("Waiting for combiners to finish aggregation...") + return False + return True + + # Infite loop until all combiners have reported back + reason = self.command_runner.flow_controller.wait_until_true(check_round_reported, polling_rate=2.0) - check_combiners_done_reporting() + if reason == FlowController.Reason.STOP: + self.set_session_status(session_id, "Terminated") + for combiner, _ in participating_combiners: + combiner.submit(fedn.Command.STOP) + self.set_round_status(round_id, "Terminated") + return None, self.db.round_store.get(round_id) round = self.db.round_store.get(round_id) round_valid = self.evaluate_round_validity_policy(round.to_dict()) @@ -449,7 +477,7 @@ def check_combiners_done_reporting(): round_data = {} try: round = self.db.round_store.get(round_id) - model, data = self.reduce(round.combiners.to_dict()) + fedn_model, data = self.reduce(round.combiners.to_dict()) round_data["reduce"] = data logger.info("Done reducing models from combiners!") except Exception as e: @@ -459,10 +487,10 @@ def check_combiners_done_reporting(): # Commit the new global model to the model trail model_id: Optional[str] = None - if model is not None: + if fedn_model is not None: logger.info("Committing global model to model trail...") tic = time.time() - model_id = self.commit(model=model, session_id=session_id, name=model_name) + model_id = self.network.commit_model(model=fedn_model, session_id=session_id, 
name=model_name) round_data["time_commit"] = time.time() - tic logger.info("Done committing global model to model trail.") else: @@ -495,7 +523,7 @@ def check_combiners_done_reporting(): for combiner, combiner_config in validating_combiners: try: logger.info("Submitting validation round to combiner {}".format(combiner)) - combiner.submit(combiner_config) + combiner.submit(fedn.Command.START, combiner_config) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) pass @@ -514,7 +542,6 @@ def splitlearning_round(self, session_config: SessionConfigDTO, round_id: str, s :param session_id: The session id :type session_id: str """ - # session_id = session_config.session_id self.create_round({"round_id": round_id, "status": "Pending"}) if len(self.network.get_combiners()) < 1: @@ -541,46 +568,44 @@ def splitlearning_round(self, session_config: SessionConfigDTO, round_id: str, s # Wait until participating combiners have produced an updated global model, # or round times out. 
- def do_if_round_times_out(result): - logger.warning("Round timed out!") - return True - - @retry( - wait=wait_random(min=1.0, max=2.0), - stop=stop_after_delay(session_config.round_timeout), - retry_error_callback=do_if_round_times_out, - retry=retry_if_exception_type(CombinersNotDoneException), - ) - def combiners_done(): + def check_round_reported(): round = self.db.round_store.get(round_id) - session_status = self.get_session_status(session_id) - if session_status == "Terminated": - self.set_round_status(round_id, "Terminated") - return False - if len(round.combiners) < 1: - logger.info("Waiting for combiners to update model...") - raise CombinersNotDoneException("Combiners have not yet reported.") - if len(round.combiners) < len(participating_combiners): logger.info("Waiting for combiners to update model...") - raise CombinersNotDoneException("All combiners have not yet reported.") - + return False return True - combiners_are_done = combiners_done() - if not combiners_are_done: + reason = self.command_runner.flow_controller.wait_until_true(check_round_reported, session_config.round_timeout, polling_rate=2.0) + if reason == FlowController.Reason.TIMEOUT: + logger.warning("Round timed out!") + elif reason == FlowController.Reason.STOP: + self.set_session_status(session_id, "Terminated") + for combiner, _ in participating_combiners: + combiner.submit(fedn.Command.STOP) + self.set_round_status(round_id, "Terminated") return None, self.db.round_store.get(round_id) + elif reason == FlowController.Reason.CONTINUE: + logger.info("Sending continue signal to combiners.") + for combiner, _ in participating_combiners: + combiner.submit(fedn.Command.CONTINUE) - # Due to the distributed nature of the computation, there might be a - # delay before combiners have reported the round data to the db, - # so we need some robustness here. 
- @retry(wait=wait_random(min=0.1, max=1.0), retry=retry_if_exception_type(KeyError)) - def check_combiners_done_reporting(): + # Update method with new print + def check_round_reported(): round = self.db.round_store.get(round_id) if len(round.combiners) != len(participating_combiners): - raise KeyError("Combiners have not yet reported.") + logger.info("Waiting for combiner to finish aggregation...") + return False + return True + + # Infite loop until all combiners have reported back + reason = self.command_runner.flow_controller.wait_until_true(check_round_reported, polling_rate=2.0) - check_combiners_done_reporting() + if reason == FlowController.Reason.STOP: + self.set_session_status(session_id, "Terminated") + for combiner, _ in participating_combiners: + combiner.submit(fedn.Command.STOP) + self.set_round_status(round_id, "Terminated") + return None, self.db.round_store.get(round_id) logger.info("CONTROLLER: Forward pass completed.") @@ -636,16 +661,16 @@ def check_combiners_done_reporting(): for combiner, config in validating_combiners: try: logger.info("Submitting validation for split learning to combiner {}".format(combiner)) - combiner.submit(config) + combiner.submit(fedn.Command.START, config) except CombinerUnavailableError: self._handle_unavailable_combiner(combiner) pass logger.info("Controller: Split Learning Validation completed") - self.set_round_status(round_id, "Finished") - return None, self.db.round_store.get(round_id) + self.set_round_status(round_id, "Success") + return model_id, self.db.round_store.get(round_id) - def reduce(self, combiners): + def reduce(self, combiners) -> Tuple[FednModel, dict]: """Combine updated models from Combiner nodes into one global model. 
: param combiners: dict of combiner names(key) and model IDs(value) to reduce @@ -657,7 +682,9 @@ def reduce(self, combiners): meta["time_aggregate_model"] = 0.0 i = 1 - model = None + model_params_agg = None + + helper = self.network.get_helper() for combiner in combiners: name = combiner["name"] @@ -667,30 +694,34 @@ def reduce(self, combiners): try: tic = time.time() - data = self.repository.get_model(model_id) + fedn_model = self.repository.get_model(model_id) meta["time_fetch_model"] += time.time() - tic except Exception as e: logger.error("Failed to fetch model from model repository {}: {}".format(name, e)) - data = None + fedn_model = None - if data is not None: + if fedn_model is not None: try: tic = time.time() - helper = self.get_helper() - model_next = load_model_from_bytes(data, helper) + + model_params_next = fedn_model.get_model_params(helper) meta["time_load_model"] += time.time() - tic tic = time.time() - model = helper.increment_average(model, model_next, 1.0, i) + model_params_agg = helper.increment_average(model_params_agg, model_params_next, 1.0, i) meta["time_aggregate_model"] += time.time() - tic except Exception: tic = time.time() - model = load_model_from_bytes(data, helper) + model_params_agg = fedn_model.get_model_params(helper) meta["time_aggregate_model"] += time.time() - tic i = i + 1 self.repository.delete_model(model_id) - return model, meta + if model_params_agg is not None: + fedn_model = FednModel.from_model_params(model_params_agg, helper) + return fedn_model, meta + else: + return None, meta def predict_instruct(self, config): """Main entrypoint for executing the prediction compute plan. @@ -700,17 +731,17 @@ def predict_instruct(self, config): # TODO: DEAD CODE? 
# Check/set instucting state - if self.__state == ReducerState.instructing: + if self.__state == ControllerState.instructing: logger.info("Already set in INSTRUCTING state") return - self.__state = ReducerState.instructing + self.__state = ControllerState.instructing # Check for a model chain if not self.statestore.latest_model(): logger.warning("No model in model chain, please set seed model.") # Set reducer in monitoring state - self.__state = ReducerState.monitoring + self.__state = ControllerState.monitoring # Start prediction round try: @@ -719,7 +750,7 @@ def predict_instruct(self, config): logger.error("Round failed.") # Set reducer in idle state - self.__state = ReducerState.idle + self.__state = ControllerState.idle def prediction_round(self, config): """Execute a prediction round. @@ -755,7 +786,7 @@ def prediction_round(self, config): # Synch combiners with latest model and trigger prediction for combiner, combiner_config in validating_combiners: try: - combiner.submit(combiner_config) + combiner.submit(fedn.Command.START, combiner_config) except CombinerUnavailableError: # It is OK if prediction fails for a combiner self._handle_unavailable_combiner(combiner) diff --git a/fedn/network/controller/controlbase.py b/fedn/network/controller/controlbase.py index bfba5c02c..3db4bcc80 100644 --- a/fedn/network/controller/controlbase.py +++ b/fedn/network/controller/controlbase.py @@ -1,33 +1,13 @@ -import os from abc import ABC, abstractmethod from typing import Any, Dict, List, Tuple -import fedn.utils.helpers.helpers -from fedn.common.log_config import logger -from fedn.network.api.network import Network -from fedn.network.combiner.interfaces import CombinerInterface, CombinerUnavailableError +import fedn.network.grpc.fedn_pb2 as fedn_proto from fedn.network.combiner.roundhandler import RoundConfig -from fedn.network.state import ReducerState +from fedn.network.common.interfaces import CombinerInterface, CombinerUnavailableError +from 
fedn.network.common.network import Network from fedn.network.storage.dbconnection import DatabaseConnection from fedn.network.storage.s3.repository import Repository -from fedn.network.storage.statestore.stores.dto import ModelDTO from fedn.network.storage.statestore.stores.dto.round import RoundDTO -from fedn.network.storage.statestore.stores.shared import SortOrder - -# Maximum number of tries to connect to statestore and retrieve storage configuration -MAX_TRIES_BACKEND = os.getenv("MAX_TRIES_BACKEND", 10) - - -class UnsupportedStorageBackend(Exception): - pass - - -class MisconfiguredStorageBackend(Exception): - pass - - -class MisconfiguredHelper(Exception): - pass class ControlBase(ABC): @@ -48,16 +28,11 @@ def __init__( db: DatabaseConnection, ): """Constructor.""" - self._state = ReducerState.setup - - self.network = Network(self, network_id, db) + self.network = Network(db, repository) self.repository = repository - self.db = db - self._state = ReducerState.idle - @abstractmethod def round(self, config, round_number): pass @@ -66,75 +41,13 @@ def round(self, config, round_number): def reduce(self, combiners): pass - def get_helper(self): - """Get a helper instance from global config. - - :return: Helper instance. - :rtype: :class:`fedn.utils.plugins.helperbase.HelperBase` - """ - helper_type: str = None - - try: - active_package = self.db.package_store.get_active() - helper_type = active_package.helper - except Exception: - logger.error("Failed to get active helper") - - helper = fedn.utils.helpers.helpers.get_helper(helper_type) - if not helper: - raise MisconfiguredHelper("Unsupported helper type {}, please configure compute_package.helper !".format(helper_type)) - return helper - - def get_state(self): - """Get the current state of the controller. - - :return: The current state. - :rtype: :class:`fedn.network.state.ReducerState` - """ - return self._state - - def idle(self): - """Check if the controller is idle. 
- - :return: True if idle, False otherwise. - :rtype: bool - """ - if self._state == ReducerState.idle: - return True - else: - return False - def get_latest_round_id(self) -> int: return self.db.round_store.get_latest_round_id() - def get_compute_package_name(self): - """:return:""" - definition = self.db.package_store.get_active() - if definition: - try: - package_name = definition.storage_file_name - return package_name - except (IndexError, KeyError): - logger.error("No context filename set for compute context definition") - return None - else: - return None - def set_compute_package(self, filename, path): """Persist the configuration for the compute package.""" self.repository.set_compute_package(filename, path) - def get_compute_package(self, compute_package=""): - """:param compute_package: - :return: - """ - if compute_package == "": - compute_package = self.get_compute_package_name() - if compute_package: - return self.repository.get_compute_package(compute_package) - else: - return None - def set_session_status(self, session_id: str, status: str) -> Tuple[bool, Any]: """Set the round round stats. @@ -212,7 +125,7 @@ def set_round_config(self, round_id: str, round_config: RoundConfig): round.round_config = round_config self.db.round_store.update(round) - def request_model_updates(self, combiners): + def request_model_updates(self, combiners: List[Tuple[CombinerInterface, Dict]]): """Ask Combiner server to produce a model update. :param combiners: A list of combiners @@ -220,55 +133,10 @@ def request_model_updates(self, combiners): """ cl = [] for combiner, combiner_round_config in combiners: - response = combiner.submit(combiner_round_config) + response = combiner.submit(fedn_proto.Command.START, combiner_round_config) cl.append((combiner, response)) return cl - def commit(self, model: dict = None, session_id: str = None, name: str = None) -> str: - """Commit a model to the global model trail. The model commited becomes the lastest consensus model. 
- - :param model_id: Unique identifier for the model to commit. - :type model_id: str (uuid) - :param model: The model object to commit - :type model: BytesIO - :param session_id: Unique identifier for the session - :type session_id: str - """ - helper = self.get_helper() - if model is not None: - outfile_name = helper.save(model) - logger.info("Saving model file temporarily to {}".format(outfile_name)) - logger.info("CONTROL: Uploading model to object store...") - model_id = self.repository.set_model(outfile_name, is_file=True) - - logger.info("CONTROL: Deleting temporary model file...") - os.unlink(outfile_name) - - logger.info("Committing model {} to global model trail in statestore...".format(model_id)) - - parent_model = None - if session_id: - last_model_of_session = self.db.model_store.list(1, 0, "committed_at", SortOrder.DESCENDING, session_id=session_id) - if len(last_model_of_session) == 1: - parent_model = last_model_of_session[0].model_id - else: - session = self.db.session_store.get(session_id) - parent_model = session.seed_model_id - - new_model = ModelDTO() - new_model.model_id = model_id - new_model.parent_model = parent_model - new_model.session_id = session_id - new_model.name = name - - try: - self.db.model_store.add(new_model) - except Exception as e: - logger.error("Failed to commit model to global model trail: {}".format(e)) - raise Exception("Failed to commit model to global model trail") - - return model_id - def get_combiner(self, name): for combiner in self.network.get_combiners(): if combiner.name == name: @@ -341,11 +209,3 @@ def evaluate_round_validity_policy(self, round): return False return True - - def state(self): - """Get the current state of the controller. 
- - :return: The state - :rype: str - """ - return self._state diff --git a/fedn/network/controller/shared.py b/fedn/network/controller/shared.py new file mode 100644 index 000000000..bbcf2f88f --- /dev/null +++ b/fedn/network/controller/shared.py @@ -0,0 +1,10 @@ +class UnsupportedStorageBackend(Exception): + pass + + +class MisconfiguredStorageBackend(Exception): + pass + + +class MisconfiguredHelper(Exception): + pass diff --git a/fedn/network/grpc/auth.py b/fedn/network/grpc/auth.py index 67bf8390e..5d09c4ece 100644 --- a/fedn/network/grpc/auth.py +++ b/fedn/network/grpc/auth.py @@ -25,6 +25,8 @@ "/fedn.Connector/ListActiveClients", "/fedn.Control/Start", "/fedn.Control/Stop", + "/fedn.Control/SendCommand", + "/fedn.Control/GetState", "/fedn.Control/FlushAggregationQueue", "/fedn.Control/SetAggregator", "/fedn.Control/SetServerFunctions", diff --git a/fedn/network/grpc/fedn.proto b/fedn/network/grpc/fedn.proto index 25a8c7490..8e68849b6 100644 --- a/fedn/network/grpc/fedn.proto +++ b/fedn/network/grpc/fedn.proto @@ -4,6 +4,7 @@ package fedn; import "google/protobuf/timestamp.proto"; import "google/protobuf/wrappers.proto"; +import "google/protobuf/struct.proto"; message Response { Client sender = 1; @@ -51,6 +52,8 @@ enum Queue { TASK_QUEUE = 1; } + + message TaskRequest { Client sender = 1; Client receiver = 2; @@ -62,6 +65,9 @@ message TaskRequest { string meta = 7; string session_id = 8; StatusType type = 9; + string round_id = 10; + string task_type = 11; + TaskStatus task_status = 12; } message ModelUpdate { @@ -73,6 +79,8 @@ message ModelUpdate { string timestamp = 6; string meta = 7; string config = 8; + string round_id = 9; + string session_id = 10; } message ModelValidation { @@ -144,32 +152,51 @@ message TelemetryElem { float value = 2; } +enum TaskStatus { + TASK_NONE = 0; + TASK_PENDING = 1; + TASK_RUNNING = 2; + TASK_COMPLETED = 3; + TASK_FAILED = 4; + TASK_INTERRUPTED = 5; + TASK_NEW = 6; + TASK_TIMEOUT = 7; +} + +message ActivityReport { + Client 
sender = 1; + string correlation_id = 2; + TaskStatus status = 3; + bool done = 4; + google.protobuf.Struct response = 5; +} + + enum ModelStatus { - OK = 0; + UNKNOWN = 0; IN_PROGRESS = 1; - IN_PROGRESS_OK = 2; + OK = 2; FAILED = 3; - UNKNOWN = 4; } message ModelRequest { Client sender = 1; Client receiver = 2; - bytes data = 3; - string id = 4; - ModelStatus status = 5; + string model_id = 3; } -message ModelResponse { +message FileChunk { bytes data = 1; - string id = 2; +} + +message ModelResponse { ModelStatus status = 3; string message = 4; } service ModelService { - rpc Upload(stream ModelRequest) returns (ModelResponse); - rpc Download(ModelRequest) returns (stream ModelResponse); + rpc Upload(stream FileChunk) returns (ModelResponse); + rpc Download(ModelRequest) returns (stream FileChunk); } @@ -239,6 +266,7 @@ enum Command { STOP = 3; RESET = 4; REPORT = 5; + CONTINUE = 6; } message Parameter { @@ -249,6 +277,14 @@ message Parameter { message ControlRequest { Command command = 1; repeated Parameter parameter = 2; + string command_type = 3; +} + +message CommandRequest{ + Command command = 1; + string correlation_id = 2; + string command_type = 3; + string parameters = 4; } message ControlResponse { @@ -256,12 +292,18 @@ message ControlResponse { repeated Parameter parameter = 2; } +message ControlStateResponse { + string state = 1; +} + service Control { rpc Start(ControlRequest) returns (ControlResponse); rpc Stop(ControlRequest) returns (ControlResponse); + rpc SendCommand(CommandRequest) returns (ControlResponse); rpc FlushAggregationQueue(ControlRequest) returns (ControlResponse); rpc SetAggregator(ControlRequest) returns (ControlResponse); rpc SetServerFunctions(ControlRequest) returns (ControlResponse); + rpc GetState(ControlRequest) returns (ControlStateResponse); } service Reducer { @@ -312,8 +354,11 @@ service Combiner { rpc SendModelMetric(ModelMetric) returns (Response); rpc SendAttributeMessage(AttributeMessage) returns (Response); rpc 
SendTelemetryMessage(TelemetryMessage) returns (Response); + + rpc PollAndReport(ActivityReport) returns (TaskRequest); } + message ProvidedFunctionsRequest { string function_code = 1; } @@ -322,10 +367,6 @@ message ProvidedFunctionsResponse { map available_functions = 1; } -message ClientConfigRequest { - bytes data = 1; -} - message ClientConfigResponse { string client_settings = 1; } @@ -347,10 +388,6 @@ message ClientMetaResponse { string status = 1; } -message StoreModelRequest { - bytes data = 1; - string id = 2; -} message StoreModelResponse { string status = 1; @@ -360,15 +397,11 @@ message AggregationRequest { string aggregate = 1; } -message AggregationResponse { - bytes data = 1; -} - service FunctionService { rpc HandleProvidedFunctions(ProvidedFunctionsRequest) returns (ProvidedFunctionsResponse); - rpc HandleClientConfig (stream ClientConfigRequest) returns (ClientConfigResponse); + rpc HandleClientConfig (stream FileChunk) returns (ClientConfigResponse); rpc HandleClientSelection (ClientSelectionRequest) returns (ClientSelectionResponse); rpc HandleMetadata (ClientMetaRequest) returns (ClientMetaResponse); - rpc HandleStoreModel (stream StoreModelRequest) returns (StoreModelResponse); - rpc HandleAggregation (AggregationRequest) returns (stream AggregationResponse); + rpc HandleStoreModel (stream FileChunk) returns (StoreModelResponse); + rpc HandleAggregation (AggregationRequest) returns (stream FileChunk); } \ No newline at end of file diff --git a/fedn/network/grpc/fedn_pb2.py b/fedn/network/grpc/fedn_pb2.py index f3fb3b3ef..e822bac54 100644 --- a/fedn/network/grpc/fedn_pb2.py +++ b/fedn/network/grpc/fedn_pb2.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! 
# NO CHECKED-IN PROTOBUF GENCODE -# source: fedn.proto +# source: network/grpc/fedn.proto # Protobuf Python Version: 5.28.1 """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor @@ -15,7 +15,7 @@ 28, 1, '', - 'fedn.proto' + 'network/grpc/fedn.proto' ) # @@protoc_insertion_point(imports) @@ -24,125 +24,130 @@ from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2 from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2 +from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\nfedn.proto\x12\x04\x66\x65\x64n\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/wrappers.proto\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\xf1\x01\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12!\n\tlog_level\x18\x03 \x01(\x0e\x32\x0e.fedn.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.fedn.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\x12\x12\n\nsession_id\x18\t \x01(\t\"\xd8\x01\n\x0bTaskRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12\x1e\n\x04type\x18\t \x01(\x0e\x32\x10.fedn.StatusType\"\xbf\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 
\x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x08 \x01(\t\"\xd8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\"\xdb\x01\n\x0fModelPrediction\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x15\n\rprediction_id\x18\x08 \x01(\t\"\xd0\x01\n\x12\x42\x61\x63kwardCompletion\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x13\n\x0bgradient_id\x18\x03 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x04 \x01(\t\x12\x12\n\nsession_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xe1\x01\n\x0bModelMetric\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12!\n\x07metrics\x18\x02 \x03(\x0b\x32\x10.fedn.MetricElem\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12*\n\x04step\x18\x04 \x01(\x0b\x32\x1c.google.protobuf.UInt32Value\x12\x10\n\x08model_id\x18\x05 \x01(\t\x12\x10\n\x08round_id\x18\x06 \x01(\t\x12\x12\n\nsession_id\x18\x07 \x01(\t\"(\n\nMetricElem\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"\x88\x01\n\x10\x41ttributeMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\'\n\nattributes\x18\x02 
\x03(\x0b\x32\x13.fedn.AttributeElem\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"+\n\rAttributeElem\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x89\x01\n\x10TelemetryMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12(\n\x0btelemetries\x18\x02 \x03(\x0b\x32\x13.fedn.TelemetryElem\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"+\n\rTelemetryElem\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"\x89\x01\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x03 \x01(\x0c\x12\n\n\x02id\x18\x04 \x01(\t\x12!\n\x06status\x18\x05 \x01(\x0e\x32\x11.fedn.ModelStatus\"]\n\rModelResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.fedn.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\"^\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1a\n\x12memory_utilisation\x18\x02 \x01(\x02\x12\x17\n\x0f\x63pu_utilisation\x18\x03 \x01(\x02\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"P\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1c\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\x0b.fedn.Queue\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.fedn.Client\"C\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.fedn.Role\x12\x0c\n\x04name\x18\x02 
\x01(\t\x12\x11\n\tclient_id\x18\x03 \x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"T\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.fedn.ConnectionStatus\"1\n\x18ProvidedFunctionsRequest\x12\x15\n\rfunction_code\x18\x01 \x01(\t\"\xac\x01\n\x19ProvidedFunctionsResponse\x12T\n\x13\x61vailable_functions\x18\x01 \x03(\x0b\x32\x37.fedn.ProvidedFunctionsResponse.AvailableFunctionsEntry\x1a\x39\n\x17\x41vailableFunctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x08:\x02\x38\x01\"#\n\x13\x43lientConfigRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"/\n\x14\x43lientConfigResponse\x12\x17\n\x0f\x63lient_settings\x18\x01 \x01(\t\",\n\x16\x43lientSelectionRequest\x12\x12\n\nclient_ids\x18\x01 \x01(\t\"-\n\x17\x43lientSelectionResponse\x12\x12\n\nclient_ids\x18\x01 \x01(\t\"8\n\x11\x43lientMetaRequest\x12\x10\n\x08metadata\x18\x01 \x01(\t\x12\x11\n\tclient_id\x18\x02 \x01(\t\"$\n\x12\x43lientMetaResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\"-\n\x11StoreModelRequest\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\x12\n\n\x02id\x18\x02 \x01(\t\"$\n\x12StoreModelResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\"\'\n\x12\x41ggregationRequest\x12\x11\n\taggregate\x18\x01 
\x01(\t\"#\n\x13\x41ggregationResponse\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c*\xde\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\x14\n\x10MODEL_PREDICTION\x10\x05\x12\x0b\n\x07NETWORK\x10\x06\x12\x13\n\x0f\x46ORWARD_REQUEST\x10\x07\x12\x0b\n\x07\x46ORWARD\x10\x08\x12\x14\n\x10\x42\x41\x43KWARD_REQUEST\x10\t\x12\x0c\n\x08\x42\x41\x43KWARD\x10\n*L\n\x08LogLevel\x12\x08\n\x04NONE\x10\x00\x12\x08\n\x04INFO\x10\x01\x12\t\n\x05\x44\x45\x42UG\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\t\n\x05\x41UDIT\x10\x05*$\n\x05Queue\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0e\n\nTASK_QUEUE\x10\x01*S\n\x0bModelStatus\x12\x06\n\x02OK\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x12\n\x0eIN_PROGRESS_OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03\x12\x0b\n\x07UNKNOWN\x10\x04*8\n\x04Role\x12\t\n\x05OTHER\x10\x00\x12\n\n\x06\x43LIENT\x10\x01\x12\x0c\n\x08\x43OMBINER\x10\x02\x12\x0b\n\x07REDUCER\x10\x03*J\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32z\n\x0cModelService\x12\x33\n\x06Upload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse(\x01\x12\x35\n\x08\x44ownload\x12\x12.fedn.ModelRequest\x1a\x13.fedn.ModelResponse0\x01\x32\xbb\x02\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x33\n\x04Stop\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\rSetAggregator\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x41\n\x12SetServerFunctions\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlRespon
se2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.fedn.GetGlobalModelRequest\x1a\x1c.fedn.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x0c.fedn.Status0\x01\x12*\n\nSendStatus\x12\x0c.fedn.Status\x1a\x0e.fedn.Response\x12?\n\x11ListActiveClients\x12\x18.fedn.ListClientsRequest\x1a\x10.fedn.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.fedn.ConnectionRequest\x1a\x18.fedn.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.fedn.Heartbeat\x1a\x0e.fedn.Response\x12\x37\n\x0eReassignClient\x12\x15.fedn.ReassignRequest\x1a\x0e.fedn.Response\x12\x39\n\x0fReconnectClient\x12\x16.fedn.ReconnectRequest\x1a\x0e.fedn.Response2\xf7\x03\n\x08\x43ombiner\x12?\n\nTaskStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x11.fedn.TaskRequest0\x01\x12\x34\n\x0fSendModelUpdate\x12\x11.fedn.ModelUpdate\x1a\x0e.fedn.Response\x12<\n\x13SendModelValidation\x12\x15.fedn.ModelValidation\x1a\x0e.fedn.Response\x12<\n\x13SendModelPrediction\x12\x15.fedn.ModelPrediction\x1a\x0e.fedn.Response\x12\x42\n\x16SendBackwardCompletion\x12\x18.fedn.BackwardCompletion\x1a\x0e.fedn.Response\x12\x34\n\x0fSendModelMetric\x12\x11.fedn.ModelMetric\x1a\x0e.fedn.Response\x12>\n\x14SendAttributeMessage\x12\x16.fedn.AttributeMessage\x1a\x0e.fedn.Response\x12>\n\x14SendTelemetryMessage\x12\x16.fedn.TelemetryMessage\x1a\x0e.fedn.Response2\xec\x03\n\x0f\x46unctionService\x12Z\n\x17HandleProvidedFunctions\x12\x1e.fedn.ProvidedFunctionsRequest\x1a\x1f.fedn.ProvidedFunctionsResponse\x12M\n\x12HandleClientConfig\x12\x19.fedn.ClientConfigRequest\x1a\x1a.fedn.ClientConfigResponse(\x01\x12T\n\x15HandleClientSelection\x12\x1c.fedn.ClientSelectionRequest\x1a\x1d.fedn.ClientSelectionResponse\x12\x43\n\x0eHandleMetadata\x12\x17.fedn.ClientMetaRequest\x1a\x18.fedn.ClientMetaResponse\x12G\n\x10HandleStoreModel\x12\x17.fedn.StoreModelRequest\x1a\x18.fedn.StoreModelResponse(\x01\x12J\n\x11HandleAggregation\x12\x18.fedn.AggregationRequest\x1a\x19.f
edn.AggregationResponse0\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x17network/grpc/fedn.proto\x12\x04\x66\x65\x64n\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/wrappers.proto\x1a\x1cgoogle/protobuf/struct.proto\":\n\x08Response\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08response\x18\x02 \x01(\t\"\xf1\x01\n\x06Status\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06status\x18\x02 \x01(\t\x12!\n\tlog_level\x18\x03 \x01(\x0e\x32\x0e.fedn.LogLevel\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x1e\n\x04type\x18\x07 \x01(\x0e\x32\x10.fedn.StatusType\x12\r\n\x05\x65xtra\x18\x08 \x01(\t\x12\x12\n\nsession_id\x18\t \x01(\t\"\xa4\x02\n\x0bTaskRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\x12\x1e\n\x04type\x18\t \x01(\x0e\x32\x10.fedn.StatusType\x12\x10\n\x08round_id\x18\n \x01(\t\x12\x11\n\ttask_type\x18\x0b \x01(\t\x12%\n\x0btask_status\x18\x0c \x01(\x0e\x32\x10.fedn.TaskStatus\"\xe5\x01\n\x0bModelUpdate\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x17\n\x0fmodel_update_id\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12\x11\n\ttimestamp\x18\x06 \x01(\t\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x0e\n\x06\x63onfig\x18\x08 \x01(\t\x12\x10\n\x08round_id\x18\t \x01(\t\x12\x12\n\nsession_id\x18\n \x01(\t\"\xd8\x01\n\x0fModelValidation\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 
\x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x12\n\nsession_id\x18\x08 \x01(\t\"\xdb\x01\n\x0fModelPrediction\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\x12\x15\n\rprediction_id\x18\x08 \x01(\t\"\xd0\x01\n\x12\x42\x61\x63kwardCompletion\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x13\n\x0bgradient_id\x18\x03 \x01(\t\x12\x16\n\x0e\x63orrelation_id\x18\x04 \x01(\t\x12\x12\n\nsession_id\x18\x05 \x01(\t\x12-\n\ttimestamp\x18\x06 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x0c\n\x04meta\x18\x07 \x01(\t\"\xe1\x01\n\x0bModelMetric\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12!\n\x07metrics\x18\x02 \x03(\x0b\x32\x10.fedn.MetricElem\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12*\n\x04step\x18\x04 \x01(\x0b\x32\x1c.google.protobuf.UInt32Value\x12\x10\n\x08model_id\x18\x05 \x01(\t\x12\x10\n\x08round_id\x18\x06 \x01(\t\x12\x12\n\nsession_id\x18\x07 \x01(\t\"(\n\nMetricElem\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"\x88\x01\n\x10\x41ttributeMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\'\n\nattributes\x18\x02 \x03(\x0b\x32\x13.fedn.AttributeElem\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"+\n\rAttributeElem\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x89\x01\n\x10TelemetryMessage\x12\x1c\n\x06sender\x18\x01 
\x01(\x0b\x32\x0c.fedn.Client\x12(\n\x0btelemetries\x18\x02 \x03(\x0b\x32\x13.fedn.TelemetryElem\x12-\n\ttimestamp\x18\x03 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\"+\n\rTelemetryElem\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x02\"\xa1\x01\n\x0e\x41\x63tivityReport\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x16\n\x0e\x63orrelation_id\x18\x02 \x01(\t\x12 \n\x06status\x18\x03 \x01(\x0e\x32\x10.fedn.TaskStatus\x12\x0c\n\x04\x64one\x18\x04 \x01(\x08\x12)\n\x08response\x18\x05 \x01(\x0b\x32\x17.google.protobuf.Struct\"^\n\x0cModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\"\x19\n\tFileChunk\x12\x0c\n\x04\x64\x61ta\x18\x01 \x01(\x0c\"C\n\rModelResponse\x12!\n\x06status\x18\x03 \x01(\x0e\x32\x11.fedn.ModelStatus\x12\x0f\n\x07message\x18\x04 \x01(\t\"U\n\x15GetGlobalModelRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\"h\n\x16GetGlobalModelResponse\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x10\n\x08model_id\x18\x03 \x01(\t\"^\n\tHeartbeat\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1a\n\x12memory_utilisation\x18\x02 \x01(\x02\x12\x17\n\x0f\x63pu_utilisation\x18\x03 \x01(\x02\"W\n\x16\x43lientAvailableMessage\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\t\x12\x11\n\ttimestamp\x18\x03 \x01(\t\"P\n\x12ListClientsRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1c\n\x07\x63hannel\x18\x02 \x01(\x0e\x32\x0b.fedn.Queue\"*\n\nClientList\x12\x1c\n\x06\x63lient\x18\x01 \x03(\x0b\x32\x0c.fedn.Client\"C\n\x06\x43lient\x12\x18\n\x04role\x18\x01 \x01(\x0e\x32\n.fedn.Role\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x11\n\tclient_id\x18\x03 
\x01(\t\"m\n\x0fReassignRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x0e\n\x06server\x18\x03 \x01(\t\x12\x0c\n\x04port\x18\x04 \x01(\r\"c\n\x10ReconnectRequest\x12\x1c\n\x06sender\x18\x01 \x01(\x0b\x32\x0c.fedn.Client\x12\x1e\n\x08receiver\x18\x02 \x01(\x0b\x32\x0c.fedn.Client\x12\x11\n\treconnect\x18\x03 \x01(\r\"\'\n\tParameter\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"j\n\x0e\x43ontrolRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\x12\x14\n\x0c\x63ommand_type\x18\x03 \x01(\t\"r\n\x0e\x43ommandRequest\x12\x1e\n\x07\x63ommand\x18\x01 \x01(\x0e\x32\r.fedn.Command\x12\x16\n\x0e\x63orrelation_id\x18\x02 \x01(\t\x12\x14\n\x0c\x63ommand_type\x18\x03 \x01(\t\x12\x12\n\nparameters\x18\x04 \x01(\t\"F\n\x0f\x43ontrolResponse\x12\x0f\n\x07message\x18\x01 \x01(\t\x12\"\n\tparameter\x18\x02 \x03(\x0b\x32\x0f.fedn.Parameter\"%\n\x14\x43ontrolStateResponse\x12\r\n\x05state\x18\x01 \x01(\t\"\x13\n\x11\x43onnectionRequest\"<\n\x12\x43onnectionResponse\x12&\n\x06status\x18\x01 \x01(\x0e\x32\x16.fedn.ConnectionStatus\"1\n\x18ProvidedFunctionsRequest\x12\x15\n\rfunction_code\x18\x01 \x01(\t\"\xac\x01\n\x19ProvidedFunctionsResponse\x12T\n\x13\x61vailable_functions\x18\x01 \x03(\x0b\x32\x37.fedn.ProvidedFunctionsResponse.AvailableFunctionsEntry\x1a\x39\n\x17\x41vailableFunctionsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x08:\x02\x38\x01\"/\n\x14\x43lientConfigResponse\x12\x17\n\x0f\x63lient_settings\x18\x01 \x01(\t\",\n\x16\x43lientSelectionRequest\x12\x12\n\nclient_ids\x18\x01 \x01(\t\"-\n\x17\x43lientSelectionResponse\x12\x12\n\nclient_ids\x18\x01 \x01(\t\"8\n\x11\x43lientMetaRequest\x12\x10\n\x08metadata\x18\x01 \x01(\t\x12\x11\n\tclient_id\x18\x02 \x01(\t\"$\n\x12\x43lientMetaResponse\x12\x0e\n\x06status\x18\x01 
\x01(\t\"$\n\x12StoreModelResponse\x12\x0e\n\x06status\x18\x01 \x01(\t\"\'\n\x12\x41ggregationRequest\x12\x11\n\taggregate\x18\x01 \x01(\t*\xde\x01\n\nStatusType\x12\x07\n\x03LOG\x10\x00\x12\x18\n\x14MODEL_UPDATE_REQUEST\x10\x01\x12\x10\n\x0cMODEL_UPDATE\x10\x02\x12\x1c\n\x18MODEL_VALIDATION_REQUEST\x10\x03\x12\x14\n\x10MODEL_VALIDATION\x10\x04\x12\x14\n\x10MODEL_PREDICTION\x10\x05\x12\x0b\n\x07NETWORK\x10\x06\x12\x13\n\x0f\x46ORWARD_REQUEST\x10\x07\x12\x0b\n\x07\x46ORWARD\x10\x08\x12\x14\n\x10\x42\x41\x43KWARD_REQUEST\x10\t\x12\x0c\n\x08\x42\x41\x43KWARD\x10\n*L\n\x08LogLevel\x12\x08\n\x04NONE\x10\x00\x12\x08\n\x04INFO\x10\x01\x12\t\n\x05\x44\x45\x42UG\x10\x02\x12\x0b\n\x07WARNING\x10\x03\x12\t\n\x05\x45RROR\x10\x04\x12\t\n\x05\x41UDIT\x10\x05*$\n\x05Queue\x12\x0b\n\x07\x44\x45\x46\x41ULT\x10\x00\x12\x0e\n\nTASK_QUEUE\x10\x01*\x9a\x01\n\nTaskStatus\x12\r\n\tTASK_NONE\x10\x00\x12\x10\n\x0cTASK_PENDING\x10\x01\x12\x10\n\x0cTASK_RUNNING\x10\x02\x12\x12\n\x0eTASK_COMPLETED\x10\x03\x12\x0f\n\x0bTASK_FAILED\x10\x04\x12\x14\n\x10TASK_INTERRUPTED\x10\x05\x12\x0c\n\x08TASK_NEW\x10\x06\x12\x10\n\x0cTASK_TIMEOUT\x10\x07*?\n\x0bModelStatus\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0f\n\x0bIN_PROGRESS\x10\x01\x12\x06\n\x02OK\x10\x02\x12\n\n\x06\x46\x41ILED\x10\x03*8\n\x04Role\x12\t\n\x05OTHER\x10\x00\x12\n\n\x06\x43LIENT\x10\x01\x12\x0c\n\x08\x43OMBINER\x10\x02\x12\x0b\n\x07REDUCER\x10\x03*X\n\x07\x43ommand\x12\x08\n\x04IDLE\x10\x00\x12\t\n\x05START\x10\x01\x12\t\n\x05PAUSE\x10\x02\x12\x08\n\x04STOP\x10\x03\x12\t\n\x05RESET\x10\x04\x12\n\n\x06REPORT\x10\x05\x12\x0c\n\x08\x43ONTINUE\x10\x06*I\n\x10\x43onnectionStatus\x12\x11\n\rNOT_ACCEPTING\x10\x00\x12\r\n\tACCEPTING\x10\x01\x12\x13\n\x0fTRY_AGAIN_LATER\x10\x02\x32s\n\x0cModelService\x12\x30\n\x06Upload\x12\x0f.fedn.FileChunk\x1a\x13.fedn.ModelResponse(\x01\x12\x31\n\x08\x44ownload\x12\x12.fedn.ModelRequest\x1a\x0f.fedn.FileChunk0\x01\x32\xb5\x03\n\x07\x43ontrol\x12\x34\n\x05Start\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlRespons
e\x12\x33\n\x04Stop\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12:\n\x0bSendCommand\x12\x14.fedn.CommandRequest\x1a\x15.fedn.ControlResponse\x12\x44\n\x15\x46lushAggregationQueue\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\rSetAggregator\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12\x41\n\x12SetServerFunctions\x12\x14.fedn.ControlRequest\x1a\x15.fedn.ControlResponse\x12<\n\x08GetState\x12\x14.fedn.ControlRequest\x1a\x1a.fedn.ControlStateResponse2V\n\x07Reducer\x12K\n\x0eGetGlobalModel\x12\x1b.fedn.GetGlobalModelRequest\x1a\x1c.fedn.GetGlobalModelResponse2\xab\x03\n\tConnector\x12\x44\n\x14\x41llianceStatusStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x0c.fedn.Status0\x01\x12*\n\nSendStatus\x12\x0c.fedn.Status\x1a\x0e.fedn.Response\x12?\n\x11ListActiveClients\x12\x18.fedn.ListClientsRequest\x1a\x10.fedn.ClientList\x12\x45\n\x10\x41\x63\x63\x65ptingClients\x12\x17.fedn.ConnectionRequest\x1a\x18.fedn.ConnectionResponse\x12\x30\n\rSendHeartbeat\x12\x0f.fedn.Heartbeat\x1a\x0e.fedn.Response\x12\x37\n\x0eReassignClient\x12\x15.fedn.ReassignRequest\x1a\x0e.fedn.Response\x12\x39\n\x0fReconnectClient\x12\x16.fedn.ReconnectRequest\x1a\x0e.fedn.Response2\xb1\x04\n\x08\x43ombiner\x12?\n\nTaskStream\x12\x1c.fedn.ClientAvailableMessage\x1a\x11.fedn.TaskRequest0\x01\x12\x34\n\x0fSendModelUpdate\x12\x11.fedn.ModelUpdate\x1a\x0e.fedn.Response\x12<\n\x13SendModelValidation\x12\x15.fedn.ModelValidation\x1a\x0e.fedn.Response\x12<\n\x13SendModelPrediction\x12\x15.fedn.ModelPrediction\x1a\x0e.fedn.Response\x12\x42\n\x16SendBackwardCompletion\x12\x18.fedn.BackwardCompletion\x1a\x0e.fedn.Response\x12\x34\n\x0fSendModelMetric\x12\x11.fedn.ModelMetric\x1a\x0e.fedn.Response\x12>\n\x14SendAttributeMessage\x12\x16.fedn.AttributeMessage\x1a\x0e.fedn.Response\x12>\n\x14SendTelemetryMessage\x12\x16.fedn.TelemetryMessage\x1a\x0e.fedn.Response\x12\x38\n\rPollAndReport\x12\x14.fedn.ActivityReport\x1a\x11.fedn.TaskRequest2\xd0\x03\n\x0f\x46unctionServi
ce\x12Z\n\x17HandleProvidedFunctions\x12\x1e.fedn.ProvidedFunctionsRequest\x1a\x1f.fedn.ProvidedFunctionsResponse\x12\x43\n\x12HandleClientConfig\x12\x0f.fedn.FileChunk\x1a\x1a.fedn.ClientConfigResponse(\x01\x12T\n\x15HandleClientSelection\x12\x1c.fedn.ClientSelectionRequest\x1a\x1d.fedn.ClientSelectionResponse\x12\x43\n\x0eHandleMetadata\x12\x17.fedn.ClientMetaRequest\x1a\x18.fedn.ClientMetaResponse\x12?\n\x10HandleStoreModel\x12\x0f.fedn.FileChunk\x1a\x18.fedn.StoreModelResponse(\x01\x12@\n\x11HandleAggregation\x12\x18.fedn.AggregationRequest\x1a\x0f.fedn.FileChunk0\x01\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'fedn_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'network.grpc.fedn_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._loaded_options = None _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_options = b'8\001' - _globals['_STATUSTYPE']._serialized_start=4060 - _globals['_STATUSTYPE']._serialized_end=4282 - _globals['_LOGLEVEL']._serialized_start=4284 - _globals['_LOGLEVEL']._serialized_end=4360 - _globals['_QUEUE']._serialized_start=4362 - _globals['_QUEUE']._serialized_end=4398 - _globals['_MODELSTATUS']._serialized_start=4400 - _globals['_MODELSTATUS']._serialized_end=4483 - _globals['_ROLE']._serialized_start=4485 - _globals['_ROLE']._serialized_end=4541 - _globals['_COMMAND']._serialized_start=4543 - _globals['_COMMAND']._serialized_end=4617 - _globals['_CONNECTIONSTATUS']._serialized_start=4619 - _globals['_CONNECTIONSTATUS']._serialized_end=4692 - _globals['_RESPONSE']._serialized_start=85 - _globals['_RESPONSE']._serialized_end=143 - _globals['_STATUS']._serialized_start=146 - _globals['_STATUS']._serialized_end=387 - _globals['_TASKREQUEST']._serialized_start=390 - 
_globals['_TASKREQUEST']._serialized_end=606 - _globals['_MODELUPDATE']._serialized_start=609 - _globals['_MODELUPDATE']._serialized_end=800 - _globals['_MODELVALIDATION']._serialized_start=803 - _globals['_MODELVALIDATION']._serialized_end=1019 - _globals['_MODELPREDICTION']._serialized_start=1022 - _globals['_MODELPREDICTION']._serialized_end=1241 - _globals['_BACKWARDCOMPLETION']._serialized_start=1244 - _globals['_BACKWARDCOMPLETION']._serialized_end=1452 - _globals['_MODELMETRIC']._serialized_start=1455 - _globals['_MODELMETRIC']._serialized_end=1680 - _globals['_METRICELEM']._serialized_start=1682 - _globals['_METRICELEM']._serialized_end=1722 - _globals['_ATTRIBUTEMESSAGE']._serialized_start=1725 - _globals['_ATTRIBUTEMESSAGE']._serialized_end=1861 - _globals['_ATTRIBUTEELEM']._serialized_start=1863 - _globals['_ATTRIBUTEELEM']._serialized_end=1906 - _globals['_TELEMETRYMESSAGE']._serialized_start=1909 - _globals['_TELEMETRYMESSAGE']._serialized_end=2046 - _globals['_TELEMETRYELEM']._serialized_start=2048 - _globals['_TELEMETRYELEM']._serialized_end=2091 - _globals['_MODELREQUEST']._serialized_start=2094 - _globals['_MODELREQUEST']._serialized_end=2231 - _globals['_MODELRESPONSE']._serialized_start=2233 - _globals['_MODELRESPONSE']._serialized_end=2326 - _globals['_GETGLOBALMODELREQUEST']._serialized_start=2328 - _globals['_GETGLOBALMODELREQUEST']._serialized_end=2413 - _globals['_GETGLOBALMODELRESPONSE']._serialized_start=2415 - _globals['_GETGLOBALMODELRESPONSE']._serialized_end=2519 - _globals['_HEARTBEAT']._serialized_start=2521 - _globals['_HEARTBEAT']._serialized_end=2615 - _globals['_CLIENTAVAILABLEMESSAGE']._serialized_start=2617 - _globals['_CLIENTAVAILABLEMESSAGE']._serialized_end=2704 - _globals['_LISTCLIENTSREQUEST']._serialized_start=2706 - _globals['_LISTCLIENTSREQUEST']._serialized_end=2786 - _globals['_CLIENTLIST']._serialized_start=2788 - _globals['_CLIENTLIST']._serialized_end=2830 - _globals['_CLIENT']._serialized_start=2832 - 
_globals['_CLIENT']._serialized_end=2899 - _globals['_REASSIGNREQUEST']._serialized_start=2901 - _globals['_REASSIGNREQUEST']._serialized_end=3010 - _globals['_RECONNECTREQUEST']._serialized_start=3012 - _globals['_RECONNECTREQUEST']._serialized_end=3111 - _globals['_PARAMETER']._serialized_start=3113 - _globals['_PARAMETER']._serialized_end=3152 - _globals['_CONTROLREQUEST']._serialized_start=3154 - _globals['_CONTROLREQUEST']._serialized_end=3238 - _globals['_CONTROLRESPONSE']._serialized_start=3240 - _globals['_CONTROLRESPONSE']._serialized_end=3310 - _globals['_CONNECTIONREQUEST']._serialized_start=3312 - _globals['_CONNECTIONREQUEST']._serialized_end=3331 - _globals['_CONNECTIONRESPONSE']._serialized_start=3333 - _globals['_CONNECTIONRESPONSE']._serialized_end=3393 - _globals['_PROVIDEDFUNCTIONSREQUEST']._serialized_start=3395 - _globals['_PROVIDEDFUNCTIONSREQUEST']._serialized_end=3444 - _globals['_PROVIDEDFUNCTIONSRESPONSE']._serialized_start=3447 - _globals['_PROVIDEDFUNCTIONSRESPONSE']._serialized_end=3619 - _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_start=3562 - _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_end=3619 - _globals['_CLIENTCONFIGREQUEST']._serialized_start=3621 - _globals['_CLIENTCONFIGREQUEST']._serialized_end=3656 - _globals['_CLIENTCONFIGRESPONSE']._serialized_start=3658 - _globals['_CLIENTCONFIGRESPONSE']._serialized_end=3705 - _globals['_CLIENTSELECTIONREQUEST']._serialized_start=3707 - _globals['_CLIENTSELECTIONREQUEST']._serialized_end=3751 - _globals['_CLIENTSELECTIONRESPONSE']._serialized_start=3753 - _globals['_CLIENTSELECTIONRESPONSE']._serialized_end=3798 - _globals['_CLIENTMETAREQUEST']._serialized_start=3800 - _globals['_CLIENTMETAREQUEST']._serialized_end=3856 - _globals['_CLIENTMETARESPONSE']._serialized_start=3858 - _globals['_CLIENTMETARESPONSE']._serialized_end=3894 - _globals['_STOREMODELREQUEST']._serialized_start=3896 - 
_globals['_STOREMODELREQUEST']._serialized_end=3941 - _globals['_STOREMODELRESPONSE']._serialized_start=3943 - _globals['_STOREMODELRESPONSE']._serialized_end=3979 - _globals['_AGGREGATIONREQUEST']._serialized_start=3981 - _globals['_AGGREGATIONREQUEST']._serialized_end=4020 - _globals['_AGGREGATIONRESPONSE']._serialized_start=4022 - _globals['_AGGREGATIONRESPONSE']._serialized_end=4057 - _globals['_MODELSERVICE']._serialized_start=4694 - _globals['_MODELSERVICE']._serialized_end=4816 - _globals['_CONTROL']._serialized_start=4819 - _globals['_CONTROL']._serialized_end=5134 - _globals['_REDUCER']._serialized_start=5136 - _globals['_REDUCER']._serialized_end=5222 - _globals['_CONNECTOR']._serialized_start=5225 - _globals['_CONNECTOR']._serialized_end=5652 - _globals['_COMBINER']._serialized_start=5655 - _globals['_COMBINER']._serialized_end=6158 - _globals['_FUNCTIONSERVICE']._serialized_start=6161 - _globals['_FUNCTIONSERVICE']._serialized_end=6653 + _globals['_STATUSTYPE']._serialized_start=4394 + _globals['_STATUSTYPE']._serialized_end=4616 + _globals['_LOGLEVEL']._serialized_start=4618 + _globals['_LOGLEVEL']._serialized_end=4694 + _globals['_QUEUE']._serialized_start=4696 + _globals['_QUEUE']._serialized_end=4732 + _globals['_TASKSTATUS']._serialized_start=4735 + _globals['_TASKSTATUS']._serialized_end=4889 + _globals['_MODELSTATUS']._serialized_start=4891 + _globals['_MODELSTATUS']._serialized_end=4954 + _globals['_ROLE']._serialized_start=4956 + _globals['_ROLE']._serialized_end=5012 + _globals['_COMMAND']._serialized_start=5014 + _globals['_COMMAND']._serialized_end=5102 + _globals['_CONNECTIONSTATUS']._serialized_start=5104 + _globals['_CONNECTIONSTATUS']._serialized_end=5177 + _globals['_RESPONSE']._serialized_start=128 + _globals['_RESPONSE']._serialized_end=186 + _globals['_STATUS']._serialized_start=189 + _globals['_STATUS']._serialized_end=430 + _globals['_TASKREQUEST']._serialized_start=433 + _globals['_TASKREQUEST']._serialized_end=725 + 
_globals['_MODELUPDATE']._serialized_start=728 + _globals['_MODELUPDATE']._serialized_end=957 + _globals['_MODELVALIDATION']._serialized_start=960 + _globals['_MODELVALIDATION']._serialized_end=1176 + _globals['_MODELPREDICTION']._serialized_start=1179 + _globals['_MODELPREDICTION']._serialized_end=1398 + _globals['_BACKWARDCOMPLETION']._serialized_start=1401 + _globals['_BACKWARDCOMPLETION']._serialized_end=1609 + _globals['_MODELMETRIC']._serialized_start=1612 + _globals['_MODELMETRIC']._serialized_end=1837 + _globals['_METRICELEM']._serialized_start=1839 + _globals['_METRICELEM']._serialized_end=1879 + _globals['_ATTRIBUTEMESSAGE']._serialized_start=1882 + _globals['_ATTRIBUTEMESSAGE']._serialized_end=2018 + _globals['_ATTRIBUTEELEM']._serialized_start=2020 + _globals['_ATTRIBUTEELEM']._serialized_end=2063 + _globals['_TELEMETRYMESSAGE']._serialized_start=2066 + _globals['_TELEMETRYMESSAGE']._serialized_end=2203 + _globals['_TELEMETRYELEM']._serialized_start=2205 + _globals['_TELEMETRYELEM']._serialized_end=2248 + _globals['_ACTIVITYREPORT']._serialized_start=2251 + _globals['_ACTIVITYREPORT']._serialized_end=2412 + _globals['_MODELREQUEST']._serialized_start=2414 + _globals['_MODELREQUEST']._serialized_end=2508 + _globals['_FILECHUNK']._serialized_start=2510 + _globals['_FILECHUNK']._serialized_end=2535 + _globals['_MODELRESPONSE']._serialized_start=2537 + _globals['_MODELRESPONSE']._serialized_end=2604 + _globals['_GETGLOBALMODELREQUEST']._serialized_start=2606 + _globals['_GETGLOBALMODELREQUEST']._serialized_end=2691 + _globals['_GETGLOBALMODELRESPONSE']._serialized_start=2693 + _globals['_GETGLOBALMODELRESPONSE']._serialized_end=2797 + _globals['_HEARTBEAT']._serialized_start=2799 + _globals['_HEARTBEAT']._serialized_end=2893 + _globals['_CLIENTAVAILABLEMESSAGE']._serialized_start=2895 + _globals['_CLIENTAVAILABLEMESSAGE']._serialized_end=2982 + _globals['_LISTCLIENTSREQUEST']._serialized_start=2984 + _globals['_LISTCLIENTSREQUEST']._serialized_end=3064 + 
_globals['_CLIENTLIST']._serialized_start=3066 + _globals['_CLIENTLIST']._serialized_end=3108 + _globals['_CLIENT']._serialized_start=3110 + _globals['_CLIENT']._serialized_end=3177 + _globals['_REASSIGNREQUEST']._serialized_start=3179 + _globals['_REASSIGNREQUEST']._serialized_end=3288 + _globals['_RECONNECTREQUEST']._serialized_start=3290 + _globals['_RECONNECTREQUEST']._serialized_end=3389 + _globals['_PARAMETER']._serialized_start=3391 + _globals['_PARAMETER']._serialized_end=3430 + _globals['_CONTROLREQUEST']._serialized_start=3432 + _globals['_CONTROLREQUEST']._serialized_end=3538 + _globals['_COMMANDREQUEST']._serialized_start=3540 + _globals['_COMMANDREQUEST']._serialized_end=3654 + _globals['_CONTROLRESPONSE']._serialized_start=3656 + _globals['_CONTROLRESPONSE']._serialized_end=3726 + _globals['_CONTROLSTATERESPONSE']._serialized_start=3728 + _globals['_CONTROLSTATERESPONSE']._serialized_end=3765 + _globals['_CONNECTIONREQUEST']._serialized_start=3767 + _globals['_CONNECTIONREQUEST']._serialized_end=3786 + _globals['_CONNECTIONRESPONSE']._serialized_start=3788 + _globals['_CONNECTIONRESPONSE']._serialized_end=3848 + _globals['_PROVIDEDFUNCTIONSREQUEST']._serialized_start=3850 + _globals['_PROVIDEDFUNCTIONSREQUEST']._serialized_end=3899 + _globals['_PROVIDEDFUNCTIONSRESPONSE']._serialized_start=3902 + _globals['_PROVIDEDFUNCTIONSRESPONSE']._serialized_end=4074 + _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_start=4017 + _globals['_PROVIDEDFUNCTIONSRESPONSE_AVAILABLEFUNCTIONSENTRY']._serialized_end=4074 + _globals['_CLIENTCONFIGRESPONSE']._serialized_start=4076 + _globals['_CLIENTCONFIGRESPONSE']._serialized_end=4123 + _globals['_CLIENTSELECTIONREQUEST']._serialized_start=4125 + _globals['_CLIENTSELECTIONREQUEST']._serialized_end=4169 + _globals['_CLIENTSELECTIONRESPONSE']._serialized_start=4171 + _globals['_CLIENTSELECTIONRESPONSE']._serialized_end=4216 + _globals['_CLIENTMETAREQUEST']._serialized_start=4218 + 
_globals['_CLIENTMETAREQUEST']._serialized_end=4274 + _globals['_CLIENTMETARESPONSE']._serialized_start=4276 + _globals['_CLIENTMETARESPONSE']._serialized_end=4312 + _globals['_STOREMODELRESPONSE']._serialized_start=4314 + _globals['_STOREMODELRESPONSE']._serialized_end=4350 + _globals['_AGGREGATIONREQUEST']._serialized_start=4352 + _globals['_AGGREGATIONREQUEST']._serialized_end=4391 + _globals['_MODELSERVICE']._serialized_start=5179 + _globals['_MODELSERVICE']._serialized_end=5294 + _globals['_CONTROL']._serialized_start=5297 + _globals['_CONTROL']._serialized_end=5734 + _globals['_REDUCER']._serialized_start=5736 + _globals['_REDUCER']._serialized_end=5822 + _globals['_CONNECTOR']._serialized_start=5825 + _globals['_CONNECTOR']._serialized_end=6252 + _globals['_COMBINER']._serialized_start=6255 + _globals['_COMBINER']._serialized_end=6816 + _globals['_FUNCTIONSERVICE']._serialized_start=6819 + _globals['_FUNCTIONSERVICE']._serialized_end=7283 # @@protoc_insertion_point(module_scope) diff --git a/fedn/network/grpc/fedn_pb2.pyi b/fedn/network/grpc/fedn_pb2.pyi new file mode 100644 index 000000000..d78cee9ba --- /dev/null +++ b/fedn/network/grpc/fedn_pb2.pyi @@ -0,0 +1,1183 @@ +""" +@generated by mypy-protobuf. Do not edit manually! 
+isort:skip_file +""" + +import builtins +import collections.abc +import google.protobuf.descriptor +import google.protobuf.internal.containers +import google.protobuf.internal.enum_type_wrapper +import google.protobuf.message +import google.protobuf.struct_pb2 +import google.protobuf.timestamp_pb2 +import google.protobuf.wrappers_pb2 +import sys +import typing + +if sys.version_info >= (3, 10): + import typing as typing_extensions +else: + import typing_extensions + +DESCRIPTOR: google.protobuf.descriptor.FileDescriptor + +class _StatusType: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _StatusTypeEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_StatusType.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + LOG: _StatusType.ValueType # 0 + MODEL_UPDATE_REQUEST: _StatusType.ValueType # 1 + MODEL_UPDATE: _StatusType.ValueType # 2 + MODEL_VALIDATION_REQUEST: _StatusType.ValueType # 3 + MODEL_VALIDATION: _StatusType.ValueType # 4 + MODEL_PREDICTION: _StatusType.ValueType # 5 + NETWORK: _StatusType.ValueType # 6 + FORWARD_REQUEST: _StatusType.ValueType # 7 + FORWARD: _StatusType.ValueType # 8 + BACKWARD_REQUEST: _StatusType.ValueType # 9 + BACKWARD: _StatusType.ValueType # 10 + +class StatusType(_StatusType, metaclass=_StatusTypeEnumTypeWrapper): ... 
+ +LOG: StatusType.ValueType # 0 +MODEL_UPDATE_REQUEST: StatusType.ValueType # 1 +MODEL_UPDATE: StatusType.ValueType # 2 +MODEL_VALIDATION_REQUEST: StatusType.ValueType # 3 +MODEL_VALIDATION: StatusType.ValueType # 4 +MODEL_PREDICTION: StatusType.ValueType # 5 +NETWORK: StatusType.ValueType # 6 +FORWARD_REQUEST: StatusType.ValueType # 7 +FORWARD: StatusType.ValueType # 8 +BACKWARD_REQUEST: StatusType.ValueType # 9 +BACKWARD: StatusType.ValueType # 10 +global___StatusType = StatusType + +class _LogLevel: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _LogLevelEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_LogLevel.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + NONE: _LogLevel.ValueType # 0 + INFO: _LogLevel.ValueType # 1 + DEBUG: _LogLevel.ValueType # 2 + WARNING: _LogLevel.ValueType # 3 + ERROR: _LogLevel.ValueType # 4 + AUDIT: _LogLevel.ValueType # 5 + +class LogLevel(_LogLevel, metaclass=_LogLevelEnumTypeWrapper): ... + +NONE: LogLevel.ValueType # 0 +INFO: LogLevel.ValueType # 1 +DEBUG: LogLevel.ValueType # 2 +WARNING: LogLevel.ValueType # 3 +ERROR: LogLevel.ValueType # 4 +AUDIT: LogLevel.ValueType # 5 +global___LogLevel = LogLevel + +class _Queue: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _QueueEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Queue.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + DEFAULT: _Queue.ValueType # 0 + TASK_QUEUE: _Queue.ValueType # 1 + +class Queue(_Queue, metaclass=_QueueEnumTypeWrapper): ... 
+ +DEFAULT: Queue.ValueType # 0 +TASK_QUEUE: Queue.ValueType # 1 +global___Queue = Queue + +class _TaskStatus: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _TaskStatusEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_TaskStatus.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + TASK_NONE: _TaskStatus.ValueType # 0 + TASK_PENDING: _TaskStatus.ValueType # 1 + TASK_RUNNING: _TaskStatus.ValueType # 2 + TASK_COMPLETED: _TaskStatus.ValueType # 3 + TASK_FAILED: _TaskStatus.ValueType # 4 + TASK_INTERRUPTED: _TaskStatus.ValueType # 5 + TASK_NEW: _TaskStatus.ValueType # 6 + TASK_TIMEOUT: _TaskStatus.ValueType # 7 + +class TaskStatus(_TaskStatus, metaclass=_TaskStatusEnumTypeWrapper): ... + +TASK_NONE: TaskStatus.ValueType # 0 +TASK_PENDING: TaskStatus.ValueType # 1 +TASK_RUNNING: TaskStatus.ValueType # 2 +TASK_COMPLETED: TaskStatus.ValueType # 3 +TASK_FAILED: TaskStatus.ValueType # 4 +TASK_INTERRUPTED: TaskStatus.ValueType # 5 +TASK_NEW: TaskStatus.ValueType # 6 +TASK_TIMEOUT: TaskStatus.ValueType # 7 +global___TaskStatus = TaskStatus + +class _ModelStatus: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _ModelStatusEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_ModelStatus.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + UNKNOWN: _ModelStatus.ValueType # 0 + IN_PROGRESS: _ModelStatus.ValueType # 1 + OK: _ModelStatus.ValueType # 2 + FAILED: _ModelStatus.ValueType # 3 + +class ModelStatus(_ModelStatus, metaclass=_ModelStatusEnumTypeWrapper): ... 
+ +UNKNOWN: ModelStatus.ValueType # 0 +IN_PROGRESS: ModelStatus.ValueType # 1 +OK: ModelStatus.ValueType # 2 +FAILED: ModelStatus.ValueType # 3 +global___ModelStatus = ModelStatus + +class _Role: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _RoleEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Role.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + OTHER: _Role.ValueType # 0 + CLIENT: _Role.ValueType # 1 + COMBINER: _Role.ValueType # 2 + REDUCER: _Role.ValueType # 3 + +class Role(_Role, metaclass=_RoleEnumTypeWrapper): ... + +OTHER: Role.ValueType # 0 +CLIENT: Role.ValueType # 1 +COMBINER: Role.ValueType # 2 +REDUCER: Role.ValueType # 3 +global___Role = Role + +class _Command: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _CommandEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_Command.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + IDLE: _Command.ValueType # 0 + START: _Command.ValueType # 1 + PAUSE: _Command.ValueType # 2 + STOP: _Command.ValueType # 3 + RESET: _Command.ValueType # 4 + REPORT: _Command.ValueType # 5 + CONTINUE: _Command.ValueType # 6 + +class Command(_Command, metaclass=_CommandEnumTypeWrapper): ... 
+ +IDLE: Command.ValueType # 0 +START: Command.ValueType # 1 +PAUSE: Command.ValueType # 2 +STOP: Command.ValueType # 3 +RESET: Command.ValueType # 4 +REPORT: Command.ValueType # 5 +CONTINUE: Command.ValueType # 6 +global___Command = Command + +class _ConnectionStatus: + ValueType = typing.NewType("ValueType", builtins.int) + V: typing_extensions.TypeAlias = ValueType + +class _ConnectionStatusEnumTypeWrapper(google.protobuf.internal.enum_type_wrapper._EnumTypeWrapper[_ConnectionStatus.ValueType], builtins.type): + DESCRIPTOR: google.protobuf.descriptor.EnumDescriptor + NOT_ACCEPTING: _ConnectionStatus.ValueType # 0 + ACCEPTING: _ConnectionStatus.ValueType # 1 + TRY_AGAIN_LATER: _ConnectionStatus.ValueType # 2 + +class ConnectionStatus(_ConnectionStatus, metaclass=_ConnectionStatusEnumTypeWrapper): ... + +NOT_ACCEPTING: ConnectionStatus.ValueType # 0 +ACCEPTING: ConnectionStatus.ValueType # 1 +TRY_AGAIN_LATER: ConnectionStatus.ValueType # 2 +global___ConnectionStatus = ConnectionStatus + +@typing.final +class Response(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RESPONSE_FIELD_NUMBER: builtins.int + response: builtins.str + @property + def sender(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + response: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["response", b"response", "sender", b"sender"]) -> None: ... 
+ +global___Response = Response + +@typing.final +class Status(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + STATUS_FIELD_NUMBER: builtins.int + LOG_LEVEL_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + EXTRA_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int + status: builtins.str + log_level: global___LogLevel.ValueType + data: builtins.str + correlation_id: builtins.str + type: global___StatusType.ValueType + extra: builtins.str + session_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + status: builtins.str = ..., + log_level: global___LogLevel.ValueType = ..., + data: builtins.str = ..., + correlation_id: builtins.str = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + type: global___StatusType.ValueType = ..., + extra: builtins.str = ..., + session_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["correlation_id", b"correlation_id", "data", b"data", "extra", b"extra", "log_level", b"log_level", "sender", b"sender", "session_id", b"session_id", "status", b"status", "timestamp", b"timestamp", "type", b"type"]) -> None: ... 
+ +global___Status = Status + +@typing.final +class TaskRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + META_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int + TYPE_FIELD_NUMBER: builtins.int + ROUND_ID_FIELD_NUMBER: builtins.int + TASK_TYPE_FIELD_NUMBER: builtins.int + TASK_STATUS_FIELD_NUMBER: builtins.int + model_id: builtins.str + data: builtins.str + """data is round_config when type is MODEL_UPDATE""" + correlation_id: builtins.str + timestamp: builtins.str + meta: builtins.str + session_id: builtins.str + type: global___StatusType.ValueType + round_id: builtins.str + task_type: builtins.str + task_status: global___TaskStatus.ValueType + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + model_id: builtins.str = ..., + data: builtins.str = ..., + correlation_id: builtins.str = ..., + timestamp: builtins.str = ..., + meta: builtins.str = ..., + session_id: builtins.str = ..., + type: global___StatusType.ValueType = ..., + round_id: builtins.str = ..., + task_type: builtins.str = ..., + task_status: global___TaskStatus.ValueType = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... 
+ def ClearField(self, field_name: typing.Literal["correlation_id", b"correlation_id", "data", b"data", "meta", b"meta", "model_id", b"model_id", "receiver", b"receiver", "round_id", b"round_id", "sender", b"sender", "session_id", b"session_id", "task_status", b"task_status", "task_type", b"task_type", "timestamp", b"timestamp", "type", b"type"]) -> None: ... + +global___TaskRequest = TaskRequest + +@typing.final +class ModelUpdate(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + MODEL_UPDATE_ID_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + META_FIELD_NUMBER: builtins.int + CONFIG_FIELD_NUMBER: builtins.int + ROUND_ID_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int + model_id: builtins.str + model_update_id: builtins.str + correlation_id: builtins.str + timestamp: builtins.str + meta: builtins.str + config: builtins.str + round_id: builtins.str + session_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + model_id: builtins.str = ..., + model_update_id: builtins.str = ..., + correlation_id: builtins.str = ..., + timestamp: builtins.str = ..., + meta: builtins.str = ..., + config: builtins.str = ..., + round_id: builtins.str = ..., + session_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... 
+ def ClearField(self, field_name: typing.Literal["config", b"config", "correlation_id", b"correlation_id", "meta", b"meta", "model_id", b"model_id", "model_update_id", b"model_update_id", "receiver", b"receiver", "round_id", b"round_id", "sender", b"sender", "session_id", b"session_id", "timestamp", b"timestamp"]) -> None: ... + +global___ModelUpdate = ModelUpdate + +@typing.final +class ModelValidation(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + META_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int + model_id: builtins.str + data: builtins.str + correlation_id: builtins.str + meta: builtins.str + session_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + model_id: builtins.str = ..., + data: builtins.str = ..., + correlation_id: builtins.str = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + meta: builtins.str = ..., + session_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["correlation_id", b"correlation_id", "data", b"data", "meta", b"meta", "model_id", b"model_id", "receiver", b"receiver", "sender", b"sender", "session_id", b"session_id", "timestamp", b"timestamp"]) -> None: ... 
+ +global___ModelValidation = ModelValidation + +@typing.final +class ModelPrediction(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + META_FIELD_NUMBER: builtins.int + PREDICTION_ID_FIELD_NUMBER: builtins.int + model_id: builtins.str + data: builtins.str + correlation_id: builtins.str + meta: builtins.str + prediction_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + model_id: builtins.str = ..., + data: builtins.str = ..., + correlation_id: builtins.str = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + meta: builtins.str = ..., + prediction_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["correlation_id", b"correlation_id", "data", b"data", "meta", b"meta", "model_id", b"model_id", "prediction_id", b"prediction_id", "receiver", b"receiver", "sender", b"sender", "timestamp", b"timestamp"]) -> None: ... 
+ +global___ModelPrediction = ModelPrediction + +@typing.final +class BackwardCompletion(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + GRADIENT_ID_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + META_FIELD_NUMBER: builtins.int + gradient_id: builtins.str + correlation_id: builtins.str + session_id: builtins.str + meta: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + gradient_id: builtins.str = ..., + correlation_id: builtins.str = ..., + session_id: builtins.str = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + meta: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["correlation_id", b"correlation_id", "gradient_id", b"gradient_id", "meta", b"meta", "receiver", b"receiver", "sender", b"sender", "session_id", b"session_id", "timestamp", b"timestamp"]) -> None: ... 
+ +global___BackwardCompletion = BackwardCompletion + +@typing.final +class ModelMetric(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + METRICS_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + STEP_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + ROUND_ID_FIELD_NUMBER: builtins.int + SESSION_ID_FIELD_NUMBER: builtins.int + model_id: builtins.str + round_id: builtins.str + session_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def metrics(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___MetricElem]: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + @property + def step(self) -> google.protobuf.wrappers_pb2.UInt32Value: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + metrics: collections.abc.Iterable[global___MetricElem] | None = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + step: google.protobuf.wrappers_pb2.UInt32Value | None = ..., + model_id: builtins.str = ..., + round_id: builtins.str = ..., + session_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender", "step", b"step", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["metrics", b"metrics", "model_id", b"model_id", "round_id", b"round_id", "sender", b"sender", "session_id", b"session_id", "step", b"step", "timestamp", b"timestamp"]) -> None: ... + +global___ModelMetric = ModelMetric + +@typing.final +class MetricElem(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.float + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.float = ..., + ) -> None: ... 
+ def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + +global___MetricElem = MetricElem + +@typing.final +class AttributeMessage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + ATTRIBUTES_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + @property + def sender(self) -> global___Client: ... + @property + def attributes(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___AttributeElem]: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + attributes: collections.abc.Iterable[global___AttributeElem] | None = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["attributes", b"attributes", "sender", b"sender", "timestamp", b"timestamp"]) -> None: ... + +global___AttributeMessage = AttributeMessage + +@typing.final +class AttributeElem(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + +global___AttributeElem = AttributeElem + +@typing.final +class TelemetryMessage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + TELEMETRIES_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + @property + def sender(self) -> global___Client: ... 
+ @property + def telemetries(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TelemetryElem]: ... + @property + def timestamp(self) -> google.protobuf.timestamp_pb2.Timestamp: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + telemetries: collections.abc.Iterable[global___TelemetryElem] | None = ..., + timestamp: google.protobuf.timestamp_pb2.Timestamp | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender", "timestamp", b"timestamp"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["sender", b"sender", "telemetries", b"telemetries", "timestamp", b"timestamp"]) -> None: ... + +global___TelemetryMessage = TelemetryMessage + +@typing.final +class TelemetryElem(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.float + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.float = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + +global___TelemetryElem = TelemetryElem + +@typing.final +class ActivityReport(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + STATUS_FIELD_NUMBER: builtins.int + DONE_FIELD_NUMBER: builtins.int + RESPONSE_FIELD_NUMBER: builtins.int + correlation_id: builtins.str + status: global___TaskStatus.ValueType + done: builtins.bool + @property + def sender(self) -> global___Client: ... + @property + def response(self) -> google.protobuf.struct_pb2.Struct: ... 
+ def __init__( + self, + *, + sender: global___Client | None = ..., + correlation_id: builtins.str = ..., + status: global___TaskStatus.ValueType = ..., + done: builtins.bool = ..., + response: google.protobuf.struct_pb2.Struct | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["response", b"response", "sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["correlation_id", b"correlation_id", "done", b"done", "response", b"response", "sender", b"sender", "status", b"status"]) -> None: ... + +global___ActivityReport = ActivityReport + +@typing.final +class ModelRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + model_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + model_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["model_id", b"model_id", "receiver", b"receiver", "sender", b"sender"]) -> None: ... + +global___ModelRequest = ModelRequest + +@typing.final +class FileChunk(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DATA_FIELD_NUMBER: builtins.int + data: builtins.bytes + def __init__( + self, + *, + data: builtins.bytes = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["data", b"data"]) -> None: ... 
+ +global___FileChunk = FileChunk + +@typing.final +class ModelResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUS_FIELD_NUMBER: builtins.int + MESSAGE_FIELD_NUMBER: builtins.int + status: global___ModelStatus.ValueType + message: builtins.str + def __init__( + self, + *, + status: global___ModelStatus.ValueType = ..., + message: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["message", b"message", "status", b"status"]) -> None: ... + +global___ModelResponse = ModelResponse + +@typing.final +class GetGlobalModelRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> None: ... + +global___GetGlobalModelRequest = GetGlobalModelRequest + +@typing.final +class GetGlobalModelResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + MODEL_ID_FIELD_NUMBER: builtins.int + model_id: builtins.str + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + model_id: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... 
+ def ClearField(self, field_name: typing.Literal["model_id", b"model_id", "receiver", b"receiver", "sender", b"sender"]) -> None: ... + +global___GetGlobalModelResponse = GetGlobalModelResponse + +@typing.final +class Heartbeat(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + MEMORY_UTILISATION_FIELD_NUMBER: builtins.int + CPU_UTILISATION_FIELD_NUMBER: builtins.int + memory_utilisation: builtins.float + cpu_utilisation: builtins.float + @property + def sender(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + memory_utilisation: builtins.float = ..., + cpu_utilisation: builtins.float = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["cpu_utilisation", b"cpu_utilisation", "memory_utilisation", b"memory_utilisation", "sender", b"sender"]) -> None: ... + +global___Heartbeat = Heartbeat + +@typing.final +class ClientAvailableMessage(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + DATA_FIELD_NUMBER: builtins.int + TIMESTAMP_FIELD_NUMBER: builtins.int + data: builtins.str + timestamp: builtins.str + @property + def sender(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + data: builtins.str = ..., + timestamp: builtins.str = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["data", b"data", "sender", b"sender", "timestamp", b"timestamp"]) -> None: ... 
+ +global___ClientAvailableMessage = ClientAvailableMessage + +@typing.final +class ListClientsRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + CHANNEL_FIELD_NUMBER: builtins.int + channel: global___Queue.ValueType + @property + def sender(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + channel: global___Queue.ValueType = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["channel", b"channel", "sender", b"sender"]) -> None: ... + +global___ListClientsRequest = ListClientsRequest + +@typing.final +class ClientList(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_FIELD_NUMBER: builtins.int + @property + def client(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Client]: ... + def __init__( + self, + *, + client: collections.abc.Iterable[global___Client] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["client", b"client"]) -> None: ... + +global___ClientList = ClientList + +@typing.final +class Client(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ROLE_FIELD_NUMBER: builtins.int + NAME_FIELD_NUMBER: builtins.int + CLIENT_ID_FIELD_NUMBER: builtins.int + role: global___Role.ValueType + name: builtins.str + client_id: builtins.str + def __init__( + self, + *, + role: global___Role.ValueType = ..., + name: builtins.str = ..., + client_id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["client_id", b"client_id", "name", b"name", "role", b"role"]) -> None: ... 
+ +global___Client = Client + +@typing.final +class ReassignRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + SERVER_FIELD_NUMBER: builtins.int + PORT_FIELD_NUMBER: builtins.int + server: builtins.str + port: builtins.int + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + server: builtins.str = ..., + port: builtins.int = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["port", b"port", "receiver", b"receiver", "sender", b"sender", "server", b"server"]) -> None: ... + +global___ReassignRequest = ReassignRequest + +@typing.final +class ReconnectRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SENDER_FIELD_NUMBER: builtins.int + RECEIVER_FIELD_NUMBER: builtins.int + RECONNECT_FIELD_NUMBER: builtins.int + reconnect: builtins.int + @property + def sender(self) -> global___Client: ... + @property + def receiver(self) -> global___Client: ... + def __init__( + self, + *, + sender: global___Client | None = ..., + receiver: global___Client | None = ..., + reconnect: builtins.int = ..., + ) -> None: ... + def HasField(self, field_name: typing.Literal["receiver", b"receiver", "sender", b"sender"]) -> builtins.bool: ... + def ClearField(self, field_name: typing.Literal["receiver", b"receiver", "reconnect", b"reconnect", "sender", b"sender"]) -> None: ... 
+ +global___ReconnectRequest = ReconnectRequest + +@typing.final +class Parameter(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + +global___Parameter = Parameter + +@typing.final +class ControlRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + COMMAND_FIELD_NUMBER: builtins.int + PARAMETER_FIELD_NUMBER: builtins.int + COMMAND_TYPE_FIELD_NUMBER: builtins.int + command: global___Command.ValueType + command_type: builtins.str + @property + def parameter(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Parameter]: ... + def __init__( + self, + *, + command: global___Command.ValueType = ..., + parameter: collections.abc.Iterable[global___Parameter] | None = ..., + command_type: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["command", b"command", "command_type", b"command_type", "parameter", b"parameter"]) -> None: ... + +global___ControlRequest = ControlRequest + +@typing.final +class CommandRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + COMMAND_FIELD_NUMBER: builtins.int + CORRELATION_ID_FIELD_NUMBER: builtins.int + COMMAND_TYPE_FIELD_NUMBER: builtins.int + PARAMETERS_FIELD_NUMBER: builtins.int + command: global___Command.ValueType + correlation_id: builtins.str + command_type: builtins.str + parameters: builtins.str + def __init__( + self, + *, + command: global___Command.ValueType = ..., + correlation_id: builtins.str = ..., + command_type: builtins.str = ..., + parameters: builtins.str = ..., + ) -> None: ... 
+ def ClearField(self, field_name: typing.Literal["command", b"command", "command_type", b"command_type", "correlation_id", b"correlation_id", "parameters", b"parameters"]) -> None: ... + +global___CommandRequest = CommandRequest + +@typing.final +class ControlResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + MESSAGE_FIELD_NUMBER: builtins.int + PARAMETER_FIELD_NUMBER: builtins.int + message: builtins.str + @property + def parameter(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___Parameter]: ... + def __init__( + self, + *, + message: builtins.str = ..., + parameter: collections.abc.Iterable[global___Parameter] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["message", b"message", "parameter", b"parameter"]) -> None: ... + +global___ControlResponse = ControlResponse + +@typing.final +class ControlStateResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATE_FIELD_NUMBER: builtins.int + state: builtins.str + def __init__( + self, + *, + state: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["state", b"state"]) -> None: ... + +global___ControlStateResponse = ControlStateResponse + +@typing.final +class ConnectionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___ConnectionRequest = ConnectionRequest + +@typing.final +class ConnectionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUS_FIELD_NUMBER: builtins.int + status: global___ConnectionStatus.ValueType + def __init__( + self, + *, + status: global___ConnectionStatus.ValueType = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["status", b"status"]) -> None: ... 
+ +global___ConnectionResponse = ConnectionResponse + +@typing.final +class ProvidedFunctionsRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + FUNCTION_CODE_FIELD_NUMBER: builtins.int + function_code: builtins.str + def __init__( + self, + *, + function_code: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["function_code", b"function_code"]) -> None: ... + +global___ProvidedFunctionsRequest = ProvidedFunctionsRequest + +@typing.final +class ProvidedFunctionsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + @typing.final + class AvailableFunctionsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.bool + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.bool = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["key", b"key", "value", b"value"]) -> None: ... + + AVAILABLE_FUNCTIONS_FIELD_NUMBER: builtins.int + @property + def available_functions(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.bool]: ... + def __init__( + self, + *, + available_functions: collections.abc.Mapping[builtins.str, builtins.bool] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["available_functions", b"available_functions"]) -> None: ... + +global___ProvidedFunctionsResponse = ProvidedFunctionsResponse + +@typing.final +class ClientConfigResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_SETTINGS_FIELD_NUMBER: builtins.int + client_settings: builtins.str + def __init__( + self, + *, + client_settings: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["client_settings", b"client_settings"]) -> None: ... 
+ +global___ClientConfigResponse = ClientConfigResponse + +@typing.final +class ClientSelectionRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_IDS_FIELD_NUMBER: builtins.int + client_ids: builtins.str + def __init__( + self, + *, + client_ids: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["client_ids", b"client_ids"]) -> None: ... + +global___ClientSelectionRequest = ClientSelectionRequest + +@typing.final +class ClientSelectionResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + CLIENT_IDS_FIELD_NUMBER: builtins.int + client_ids: builtins.str + def __init__( + self, + *, + client_ids: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["client_ids", b"client_ids"]) -> None: ... + +global___ClientSelectionResponse = ClientSelectionResponse + +@typing.final +class ClientMetaRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + METADATA_FIELD_NUMBER: builtins.int + CLIENT_ID_FIELD_NUMBER: builtins.int + metadata: builtins.str + client_id: builtins.str + def __init__( + self, + *, + metadata: builtins.str = ..., + client_id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["client_id", b"client_id", "metadata", b"metadata"]) -> None: ... + +global___ClientMetaRequest = ClientMetaRequest + +@typing.final +class ClientMetaResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUS_FIELD_NUMBER: builtins.int + status: builtins.str + def __init__( + self, + *, + status: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["status", b"status"]) -> None: ... 
+ +global___ClientMetaResponse = ClientMetaResponse + +@typing.final +class StoreModelResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + STATUS_FIELD_NUMBER: builtins.int + status: builtins.str + def __init__( + self, + *, + status: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["status", b"status"]) -> None: ... + +global___StoreModelResponse = StoreModelResponse + +@typing.final +class AggregationRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + AGGREGATE_FIELD_NUMBER: builtins.int + aggregate: builtins.str + def __init__( + self, + *, + aggregate: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["aggregate", b"aggregate"]) -> None: ... + +global___AggregationRequest = AggregationRequest diff --git a/fedn/network/grpc/fedn_pb2_grpc.py b/fedn/network/grpc/fedn_pb2_grpc.py index c7d20089c..0512d13a2 100644 --- a/fedn/network/grpc/fedn_pb2_grpc.py +++ b/fedn/network/grpc/fedn_pb2_grpc.py @@ -18,7 +18,7 @@ if _version_not_supported: raise RuntimeError( f'The grpc package installed is at version {GRPC_VERSION},' - + f' but the generated code in fedn_pb2_grpc.py depends on' + + f' but the generated code in network/grpc/fedn_pb2_grpc.py depends on' + f' grpcio>={GRPC_GENERATED_VERSION}.' + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' 
@@ -36,13 +36,13 @@ def __init__(self, channel): """ self.Upload = channel.stream_unary( '/fedn.ModelService/Upload', - request_serializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + request_serializer=network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, response_deserializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, _registered_method=True) self.Download = channel.unary_stream( '/fedn.ModelService/Download', request_serializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - response_deserializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + response_deserializer=network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, _registered_method=True) @@ -66,13 +66,13 @@ def add_ModelServiceServicer_to_server(servicer, server): rpc_method_handlers = { 'Upload': grpc.stream_unary_rpc_method_handler( servicer.Upload, - request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, + request_deserializer=network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, ), 'Download': grpc.unary_stream_rpc_method_handler( servicer.Download, request_deserializer=network_dot_grpc_dot_fedn__pb2.ModelRequest.FromString, - response_serializer=network_dot_grpc_dot_fedn__pb2.ModelResponse.SerializeToString, + response_serializer=network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -100,7 +100,7 @@ def Upload(request_iterator, request_iterator, target, '/fedn.ModelService/Upload', - network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, + network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, options, channel_credentials, @@ -128,7 +128,7 @@ def Download(request, target, '/fedn.ModelService/Download', network_dot_grpc_dot_fedn__pb2.ModelRequest.SerializeToString, - 
network_dot_grpc_dot_fedn__pb2.ModelResponse.FromString, + network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, options, channel_credentials, insecure, @@ -159,6 +159,11 @@ def __init__(self, channel): request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, _registered_method=True) + self.SendCommand = channel.unary_unary( + '/fedn.Control/SendCommand', + request_serializer=network_dot_grpc_dot_fedn__pb2.CommandRequest.SerializeToString, + response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + _registered_method=True) self.FlushAggregationQueue = channel.unary_unary( '/fedn.Control/FlushAggregationQueue', request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, @@ -174,6 +179,11 @@ def __init__(self, channel): request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, _registered_method=True) + self.GetState = channel.unary_unary( + '/fedn.Control/GetState', + request_serializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + response_deserializer=network_dot_grpc_dot_fedn__pb2.ControlStateResponse.FromString, + _registered_method=True) class ControlServicer(object): @@ -191,6 +201,12 @@ def Stop(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def SendCommand(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def FlushAggregationQueue(self, request, context): """Missing associated documentation comment in .proto file.""" context.set_code(grpc.StatusCode.UNIMPLEMENTED) @@ -209,6 +225,12 @@ def 
SetServerFunctions(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def GetState(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_ControlServicer_to_server(servicer, server): rpc_method_handlers = { @@ -222,6 +244,11 @@ def add_ControlServicer_to_server(servicer, server): request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, ), + 'SendCommand': grpc.unary_unary_rpc_method_handler( + servicer.SendCommand, + request_deserializer=network_dot_grpc_dot_fedn__pb2.CommandRequest.FromString, + response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, + ), 'FlushAggregationQueue': grpc.unary_unary_rpc_method_handler( servicer.FlushAggregationQueue, request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, @@ -237,6 +264,11 @@ def add_ControlServicer_to_server(servicer, server): request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.ControlResponse.SerializeToString, ), + 'GetState': grpc.unary_unary_rpc_method_handler( + servicer.GetState, + request_deserializer=network_dot_grpc_dot_fedn__pb2.ControlRequest.FromString, + response_serializer=network_dot_grpc_dot_fedn__pb2.ControlStateResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'fedn.Control', rpc_method_handlers) @@ -302,6 +334,33 @@ def Stop(request, metadata, _registered_method=True) + @staticmethod + def SendCommand(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + 
timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fedn.Control/SendCommand', + network_dot_grpc_dot_fedn__pb2.CommandRequest.SerializeToString, + network_dot_grpc_dot_fedn__pb2.ControlResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + @staticmethod def FlushAggregationQueue(request, target, @@ -383,6 +442,33 @@ def SetServerFunctions(request, metadata, _registered_method=True) + @staticmethod + def GetState(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fedn.Control/GetState', + network_dot_grpc_dot_fedn__pb2.ControlRequest.SerializeToString, + network_dot_grpc_dot_fedn__pb2.ControlStateResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + class ReducerStub(object): """Missing associated documentation comment in .proto file.""" @@ -840,6 +926,11 @@ def __init__(self, channel): request_serializer=network_dot_grpc_dot_fedn__pb2.TelemetryMessage.SerializeToString, response_deserializer=network_dot_grpc_dot_fedn__pb2.Response.FromString, _registered_method=True) + self.PollAndReport = channel.unary_unary( + '/fedn.Combiner/PollAndReport', + request_serializer=network_dot_grpc_dot_fedn__pb2.ActivityReport.SerializeToString, + response_deserializer=network_dot_grpc_dot_fedn__pb2.TaskRequest.FromString, + _registered_method=True) class CombinerServicer(object): @@ -894,6 +985,12 @@ def SendTelemetryMessage(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def PollAndReport(self, request, context): + """Missing associated 
documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_CombinerServicer_to_server(servicer, server): rpc_method_handlers = { @@ -937,6 +1034,11 @@ def add_CombinerServicer_to_server(servicer, server): request_deserializer=network_dot_grpc_dot_fedn__pb2.TelemetryMessage.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.Response.SerializeToString, ), + 'PollAndReport': grpc.unary_unary_rpc_method_handler( + servicer.PollAndReport, + request_deserializer=network_dot_grpc_dot_fedn__pb2.ActivityReport.FromString, + response_serializer=network_dot_grpc_dot_fedn__pb2.TaskRequest.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'fedn.Combiner', rpc_method_handlers) @@ -1164,6 +1266,33 @@ def SendTelemetryMessage(request, metadata, _registered_method=True) + @staticmethod + def PollAndReport(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/fedn.Combiner/PollAndReport', + network_dot_grpc_dot_fedn__pb2.ActivityReport.SerializeToString, + network_dot_grpc_dot_fedn__pb2.TaskRequest.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + class FunctionServiceStub(object): """Missing associated documentation comment in .proto file.""" @@ -1181,7 +1310,7 @@ def __init__(self, channel): _registered_method=True) self.HandleClientConfig = channel.stream_unary( '/fedn.FunctionService/HandleClientConfig', - request_serializer=network_dot_grpc_dot_fedn__pb2.ClientConfigRequest.SerializeToString, + request_serializer=network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, 
response_deserializer=network_dot_grpc_dot_fedn__pb2.ClientConfigResponse.FromString, _registered_method=True) self.HandleClientSelection = channel.unary_unary( @@ -1196,13 +1325,13 @@ def __init__(self, channel): _registered_method=True) self.HandleStoreModel = channel.stream_unary( '/fedn.FunctionService/HandleStoreModel', - request_serializer=network_dot_grpc_dot_fedn__pb2.StoreModelRequest.SerializeToString, + request_serializer=network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, response_deserializer=network_dot_grpc_dot_fedn__pb2.StoreModelResponse.FromString, _registered_method=True) self.HandleAggregation = channel.unary_stream( '/fedn.FunctionService/HandleAggregation', request_serializer=network_dot_grpc_dot_fedn__pb2.AggregationRequest.SerializeToString, - response_deserializer=network_dot_grpc_dot_fedn__pb2.AggregationResponse.FromString, + response_deserializer=network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, _registered_method=True) @@ -1255,7 +1384,7 @@ def add_FunctionServiceServicer_to_server(servicer, server): ), 'HandleClientConfig': grpc.stream_unary_rpc_method_handler( servicer.HandleClientConfig, - request_deserializer=network_dot_grpc_dot_fedn__pb2.ClientConfigRequest.FromString, + request_deserializer=network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.ClientConfigResponse.SerializeToString, ), 'HandleClientSelection': grpc.unary_unary_rpc_method_handler( @@ -1270,13 +1399,13 @@ def add_FunctionServiceServicer_to_server(servicer, server): ), 'HandleStoreModel': grpc.stream_unary_rpc_method_handler( servicer.HandleStoreModel, - request_deserializer=network_dot_grpc_dot_fedn__pb2.StoreModelRequest.FromString, + request_deserializer=network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, response_serializer=network_dot_grpc_dot_fedn__pb2.StoreModelResponse.SerializeToString, ), 'HandleAggregation': grpc.unary_stream_rpc_method_handler( servicer.HandleAggregation, 
request_deserializer=network_dot_grpc_dot_fedn__pb2.AggregationRequest.FromString, - response_serializer=network_dot_grpc_dot_fedn__pb2.AggregationResponse.SerializeToString, + response_serializer=network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, ), } generic_handler = grpc.method_handlers_generic_handler( @@ -1331,7 +1460,7 @@ def HandleClientConfig(request_iterator, request_iterator, target, '/fedn.FunctionService/HandleClientConfig', - network_dot_grpc_dot_fedn__pb2.ClientConfigRequest.SerializeToString, + network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, network_dot_grpc_dot_fedn__pb2.ClientConfigResponse.FromString, options, channel_credentials, @@ -1412,7 +1541,7 @@ def HandleStoreModel(request_iterator, request_iterator, target, '/fedn.FunctionService/HandleStoreModel', - network_dot_grpc_dot_fedn__pb2.StoreModelRequest.SerializeToString, + network_dot_grpc_dot_fedn__pb2.FileChunk.SerializeToString, network_dot_grpc_dot_fedn__pb2.StoreModelResponse.FromString, options, channel_credentials, @@ -1440,7 +1569,7 @@ def HandleAggregation(request, target, '/fedn.FunctionService/HandleAggregation', network_dot_grpc_dot_fedn__pb2.AggregationRequest.SerializeToString, - network_dot_grpc_dot_fedn__pb2.AggregationResponse.FromString, + network_dot_grpc_dot_fedn__pb2.FileChunk.FromString, options, channel_credentials, insecure, diff --git a/fedn/network/grpc/server.py b/fedn/network/grpc/server.py index b12db1182..36174905e 100644 --- a/fedn/network/grpc/server.py +++ b/fedn/network/grpc/server.py @@ -56,13 +56,13 @@ def __init__(self, servicer, modelservicer, config: ServerConfig): rpc.add_ReducerServicer_to_server(servicer, self.server) if isinstance(modelservicer, rpc.ModelServiceServicer): rpc.add_ModelServiceServicer_to_server(modelservicer, self.server) - if isinstance(servicer, rpc.CombinerServicer): + if isinstance(servicer, (rpc.CombinerServicer, rpc.ControlServicer)): rpc.add_ControlServicer_to_server(servicer, self.server) 
health_pb2_grpc.add_HealthServicer_to_server(self.health_servicer, self.server) if config["secure"]: - logger.info("Creating secure gRPCS server using certificate") + logger.info(f"Creating secure gRPCS server using certificate at {config['port']}") server_credentials = grpc.ssl_server_credentials( ( ( @@ -73,7 +73,7 @@ def __init__(self, servicer, modelservicer, config: ServerConfig): ) self.server.add_secure_port("[::]:" + str(config["port"]), server_credentials) else: - logger.info("Creating gRPC server") + logger.info(f"Creating gRPC server at {config['port']}") self.server.add_insecure_port("[::]:" + str(config["port"])) def start(self): diff --git a/fedn/network/loadbalancer/leastpacked.py b/fedn/network/loadbalancer/leastpacked.py index 8e793e95a..cd618b108 100644 --- a/fedn/network/loadbalancer/leastpacked.py +++ b/fedn/network/loadbalancer/leastpacked.py @@ -1,4 +1,4 @@ -from fedn.network.combiner.interfaces import CombinerUnavailableError +from fedn.network.common.interfaces import CombinerUnavailableError from fedn.network.loadbalancer.loadbalancerbase import LoadBalancerBase @@ -13,9 +13,7 @@ def __init__(self, network): super().__init__(network) def find_combiner(self): - """Find the combiner with the least number of attached clients. - - """ + """Find the combiner with the least number of attached clients.""" min_clients = -1 selected_combiner = None for combiner in self.network.get_combiners(): diff --git a/fedn/network/storage/models/memorymodelstorage.py b/fedn/network/storage/models/memorymodelstorage.py deleted file mode 100644 index 54599fb28..000000000 --- a/fedn/network/storage/models/memorymodelstorage.py +++ /dev/null @@ -1,43 +0,0 @@ -import io -from collections import defaultdict -from io import BytesIO - -from fedn.network.storage.models.modelstorage import ModelStorage - -CHUNK_SIZE = 1024 * 1024 - - -class MemoryModelStorage(ModelStorage): - """Class for in-memory storage of model artifacts. 
- - Models are stored as BytesIO objects in a dictionary. - - """ - - def __init__(self): - self.models = defaultdict(io.BytesIO) - self.models_metadata = {} - - def exist(self, model_id): - if model_id in self.models.keys(): - return True - return False - - def get(self, model_id): - obj = self.models[model_id] - obj.seek(0, 0) - # Have to copy object to not mix up the file pointers when sending... fix in better way. - obj = BytesIO(obj.read()) - return obj - - def get_ptr(self, model_id): - """:param model_id: - :return: - """ - return self.models[model_id] - - def get_model_metadata(self, model_id): - return self.models_metadata[model_id] - - def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) diff --git a/fedn/network/storage/models/modelstorage.py b/fedn/network/storage/models/modelstorage.py deleted file mode 100644 index 3062db36e..000000000 --- a/fedn/network/storage/models/modelstorage.py +++ /dev/null @@ -1,69 +0,0 @@ -from abc import ABC, abstractmethod - - -class ModelStorage(ABC): - @abstractmethod - def exist(self, model_id): - """Check if model exists in storage - - :param model_id: The model id - :type model_id: str - :return: True if model exists, False otherwise - :rtype: bool - """ - pass - - @abstractmethod - def get(self, model_id): - """Get model from storage - - :param model_id: The model id - :type model_id: str - :return: The model - :rtype: object - """ - pass - - @abstractmethod - def get_model_metadata(self, model_id): - """Get model metadata from storage - - :param model_id: The model id - :type model_id: str - :return: The model metadata - :rtype: dict - """ - pass - - @abstractmethod - def set_model_metadata(self, model_id, model_metadata): - """Set model metadata in storage - - :param model_id: The model id - :type model_id: str - :param model_metadata: The model metadata - :type model_metadata: dict - :return: True if successful, False otherwise - :rtype: bool - """ - 
pass - - @abstractmethod - def delete(self, model_id): - """Delete model from storage - - :param model_id: The model id - :type model_id: str - :return: True if successful, False otherwise - :rtype: bool - """ - pass - - @abstractmethod - def delete_all(self): - """Delete all models from storage - - :return: True if successful, False otherwise - :rtype: bool - """ - pass diff --git a/fedn/network/storage/models/tempmodelstorage.py b/fedn/network/storage/models/tempmodelstorage.py index 891f6ea07..7f0128263 100644 --- a/fedn/network/storage/models/tempmodelstorage.py +++ b/fedn/network/storage/models/tempmodelstorage.py @@ -1,23 +1,28 @@ -import os +import threading +import time from io import BytesIO +from typing import Iterator import fedn.network.grpc.fedn_pb2 as fedn from fedn.common.log_config import logger -from fedn.network.storage.models.modelstorage import ModelStorage +from fedn.utils.model import FednModel CHUNK_SIZE = 1024 * 1024 +CACHEOUT_TIME = 3600 # 1 hour -class TempModelStorage(ModelStorage): - """Class for managing local temporary models on file on combiners.""" - def __init__(self): - self.default_dir = os.environ.get("FEDN_MODEL_DIR", "/tmp/models") # set default to tmp - if not os.path.exists(self.default_dir): - os.makedirs(self.default_dir) +class TempModelStorage: + """Class for managing local temporary models on file on combiners. + + This class provides methods to store, retrieve, and manage models in a temporary directory. + Cached models are kept for one hour after they were last accessed. + Manually added models will be kept until they are deleted explicitly. 
+ """ + def __init__(self): self.models = {} - self.models_metadata = {} + self.access_lock = threading.RLock() def exist(self, model_id): if model_id in self.models.keys(): @@ -25,69 +30,145 @@ def exist(self, model_id): return False def get(self, model_id): - try: - if self.models_metadata[model_id] != fedn.ModelStatus.OK: + with self.access_lock: + if not self.exist(model_id): + logger.error("TEMPMODELSTORAGE: model_id {} does not exist".format(model_id)) + return None + if self.models[model_id]["state"] != fedn.ModelStatus.OK: logger.warning("File not ready! Try again") return None - except KeyError: - logger.error("No such model has been made available yet!") - return None + self.models[model_id]["accessed_at"] = time.time() + return self.models[model_id]["model"] + + def _make_entry(self, model_id, model): + with self.access_lock: + now = time.time() + if model_id in self.models: + raise ValueError("Model with id {} already exists.".format(model_id)) + else: + self.models[model_id] = { + "model": model, + "state": fedn.ModelStatus.IN_PROGRESS, + "auto_managed": False, + "accessed_at": now, + } + + self._invalidate_old_models() + return model + + def _set_model(self, model_id: str, model_lambda, checksum: str = None, auto_managed: bool = False): + with self.access_lock: + try: + self._make_entry(model_id, None) + except Exception as e: + logger.error("TEMPMODELSTORAGE: Error writing model {} to disk: {}".format(model_id, e)) + return False + + # Create the model using the provided lambda function + # Do this outside the lock to avoid blocking other threads + model = model_lambda() + + with self.access_lock: + self.models[model_id]["model"] = model + if self._finalize(model_id, checksum): + self.models[model_id]["auto_managed"] = auto_managed + logger.info("TEMPMODELSTORAGE: Model {} added.".format(model_id)) + return True + else: + logger.error("TEMPMODELSTORAGE: Model {} failed.".format(model_id)) + return False - obj = BytesIO() - with 
open(os.path.join(self.default_dir, str(model_id)), "rb") as f: - obj.write(f.read()) + def set_model(self, model_id: str, model: FednModel, checksum: str = None, auto_managed: bool = False): + """Set model in temp storage. - obj.seek(0, 0) - return obj + :param model_id: The id of the model. + :type model_id: str + :param model: The model object. + :type model: FednModel + """ + return self._set_model(model_id, lambda: model, checksum, auto_managed) + + def set_model_from_stream(self, model_id: str, model_stream: BytesIO, checksum: str = None, auto_managed: bool = False): + """Set model in temp storage. + + :param model_id: The id of the model. + :type model_id: str + :param model_stream: The model stream. + :type model_stream: BytesIO + """ + return self._set_model(model_id, lambda: FednModel.from_stream(model_stream), checksum, auto_managed) + + def set_model_with_filechunk_stream(self, model_id: str, filechunk_stream: Iterator[fedn.FileChunk], checksum: str = None, auto_managed: bool = False): + """Set model in temp storage using a generator. - def get_ptr(self, model_id): - """:param model_id: - :return: + :param model_id: The id of the model. + :type model_id: str + :param filechunk_stream: An grpc stream of fedn.FileChunk + :type filechunk_stream: Generator[bytes, None, None] + """ + return self._set_model(model_id, lambda: FednModel.from_filechunk_stream(filechunk_stream), checksum, auto_managed) + + def _finalize(self, model_id, checksum): + """Commit the model to disk. + + :param model_id: The id of the model. + :type model_id: str + :param checksum: The checksum of the model. + :type checksum: str + """ + model: FednModel = self.models[model_id]["model"] + if not model.verify_checksum(checksum): + logger.error("TEMPMODELSTORAGE: Checksum failed! File is corrupted!") + self.delete(model_id) + return False + self.models[model_id]["state"] = fedn.ModelStatus.OK + return True + + def is_ready(self, model_id): + """Check if model is ready. 
+ + :param model_id: The id of the model. + :type model_id: str + :return: True if model is ready, else False. + :rtype: bool """ try: - f = self.models[model_id]["file"] + return self.models[model_id]["state"] == fedn.ModelStatus.OK except KeyError: - f = open(os.path.join(self.default_dir, str(model_id)), "wb") - - self.models[model_id] = {"file": f} - return self.models[model_id]["file"] + logger.error("TEMPMODELSTORAGE: model_id {} does not exist".format(model_id)) + return False - def get_model_metadata(self, model_id): + def get_checksum(self, model_id): try: - status = self.models_metadata[model_id] + model: FednModel = self.models[model_id]["model"] except KeyError: - status = fedn.ModelStatus.UNKNOWN - return status + logger.error("TEMPMODELSTORAGE: model_id {} does not exist".format(model_id)) + return None + with self.access_lock: + return model.checksum - def set_model_metadata(self, model_id, model_metadata): - self.models_metadata.update({model_id: model_metadata}) + def _invalidate_old_models(self): + """Remove cached models that have not been accessed for more than 1 hour.""" + now = time.time() + for model_id, model_info in list(self.models.items()): + if now - model_info["accessed_at"] > CACHEOUT_TIME and model_info["auto_managed"]: + logger.info("TEMPMODELSTORAGE: Invalidating model {} due to inactivity.".format(model_id)) + self.delete(model_id) # Delete model from disk def delete(self, model_id): - try: - os.remove(os.path.join(self.default_dir, str(model_id))) - logger.info("TEMPMODELSTORAGE: Deleted model with id: {}".format(model_id)) - # Delete id from metadata and models dict - del self.models_metadata[model_id] - del self.models[model_id] - except FileNotFoundError: - logger.error("Could not delete model from disk. 
File not found!") - return False - return True - - # Delete all models from disk - def delete_all(self): - ids_pop = [] - for model_id in self.models.keys(): + with self.access_lock: try: - os.remove(os.path.join(self.default_dir, str(model_id))) logger.info("TEMPMODELSTORAGE: Deleted model with id: {}".format(model_id)) - # Add id to list of ids to pop/delete from metadata and models dict - ids_pop.append(model_id) + # Delete id from metadata and models dict + del self.models[model_id] except FileNotFoundError: logger.error("TEMPMODELSTORAGE: Could not delete model {} from disk. File not found!".format(model_id)) - # Remove id from metadata and models dict - for model_id in ids_pop: - del self.models_metadata[model_id] - del self.models[model_id] + return False + return True + + # Delete all models from disk + def delete_all(self): + with self.access_lock: + self.models.clear() return True diff --git a/fedn/network/storage/s3/boto3repository.py b/fedn/network/storage/s3/boto3repository.py index 88c0210f9..d133a1f83 100644 --- a/fedn/network/storage/s3/boto3repository.py +++ b/fedn/network/storage/s3/boto3repository.py @@ -1,6 +1,7 @@ """Module implementing Repository for Amazon S3 using boto3.""" import io +import os from typing import IO, List import boto3 @@ -14,31 +15,34 @@ class Boto3Repository(RepositoryBase): """Class implementing Repository for Amazon S3 using boto3.""" def __init__(self, config: dict) -> None: - """Initialize object. - - :param config: Dictionary containing configuration for credentials and bucket names. 
- :type config: dict - """ + """Initialize object.""" super().__init__() self.name = "Boto3Repository" + self.region = os.environ.get("AWS_REGION") or os.environ.get("AWS_DEFAULT_REGION") or config.get("storage_region") or "eu-west-1" + common_config = { - "region_name": config.get("storage_region", "eu-west-1"), - "endpoint_url": config.get("storage_endpoint", "http://minio:9000"), "use_ssl": config.get("storage_secure_mode", True), "verify": config.get("storage_verify_ssl", True), } - if "storage_access_key" in config and "storage_secret_key" in config: + access_key = config.get("storage_access_key") + secret_key = config.get("storage_secret_key") + + if access_key and secret_key: self.s3_client = boto3.client( "s3", - aws_access_key_id=config["storage_access_key"], - aws_secret_access_key=config["storage_secret_key"], + aws_access_key_id=access_key, + aws_secret_access_key=secret_key, + region_name=self.region, + endpoint_url=config.get("storage_endpoint", "http://minio:9000"), **common_config, ) else: - # Use default credentials (e.g., from a service account or environment variables) - self.s3_client = boto3.client("s3", **common_config) + # Use default credentials (IAM role via service account, environment variables, etc.) + + self.s3_client = boto3.client("s3", region_name=self.region, **common_config) + logger.info(f"Using {self.name} for S3 storage.") def set_artifact(self, instance_name: str, instance: IO, bucket: str, is_file: bool = False) -> bool: @@ -96,7 +100,7 @@ def get_artifact_stream(self, instance_name: str, bucket: str) -> io.BytesIO: """ try: response = self.s3_client.get_object(Bucket=bucket, Key=instance_name) - return io.BytesIO(response["Body"].read()) + return response["Body"] except (BotoCoreError, ClientError) as e: logger.error(f"Failed to fetch artifact stream: {instance_name} from bucket: {bucket}. 
Error: {e}") raise Exception(f"Could not fetch artifact stream: {e}") from e @@ -137,7 +141,10 @@ def create_bucket(self, bucket_name: str) -> None: :type bucket_name: str """ try: - self.s3_client.create_bucket(Bucket=bucket_name) + if self.region == "us-east-1": + self.s3_client.create_bucket(Bucket=bucket_name) + else: + self.s3_client.create_bucket(Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": self.region}) logger.info(f"Bucket {bucket_name} created successfully.") except self.s3_client.exceptions.BucketAlreadyExists: logger.info(f"Bucket {bucket_name} already exists. No new bucket was created.") diff --git a/fedn/network/storage/s3/miniorepository.py b/fedn/network/storage/s3/miniorepository.py index 6ac763010..523e9574a 100644 --- a/fedn/network/storage/s3/miniorepository.py +++ b/fedn/network/storage/s3/miniorepository.py @@ -60,7 +60,7 @@ def set_artifact(self, instance_name: str, instance: IO, bucket: str, is_file: b if is_file: self.client.fput_object(bucket, instance_name, instance) else: - self.client.put_object(bucket, instance_name, io.BytesIO(instance), len(instance)) + self.client.put_object(bucket, instance_name, instance, part_size=10 * 1024 * 1024) except Exception as e: logger.error(f"Failed to upload artifact: {instance_name} to bucket: {bucket}. 
Error: {e}") raise Exception(f"Could not load data into bytes: {e}") from e diff --git a/fedn/network/storage/s3/repository.py b/fedn/network/storage/s3/repository.py index dc83907de..058e3f7d7 100644 --- a/fedn/network/storage/s3/repository.py +++ b/fedn/network/storage/s3/repository.py @@ -3,10 +3,12 @@ import datetime import importlib import uuid -from typing import Union +from typing import IO, Union from fedn.common.config import FEDN_OBJECT_STORAGE_BUCKETS, FEDN_OBJECT_STORAGE_TYPE from fedn.common.log_config import logger +from fedn.network.storage.s3.base import RepositoryBase +from fedn.utils.model import FednModel class Repository: @@ -31,7 +33,7 @@ def __init__(self, config: dict, init_buckets: bool = True, storage_type: str = # Dynamically import the repository class based on storage_type storage_type = (storage_type or FEDN_OBJECT_STORAGE_TYPE).upper() - self.client = self._load_repository(storage_type, config) + self.client: RepositoryBase = self._load_repository(storage_type, config) if init_buckets: self.client.create_bucket(self.context_bucket) @@ -70,7 +72,7 @@ def _load_repository(self, storage_type: str, config: dict): logger.error(f"Failed to load class {class_name} from module {module_path}. Error: {e}") raise AttributeError(f"Could not load repository class {class_name}.") from e - def get_model(self, model_id: str) -> bytes: + def get_model(self, model_id: str) -> FednModel: """Retrieve a model with id model_id. :param model_id: Unique identifier for model to retrieve. 
@@ -79,7 +81,9 @@ def get_model(self, model_id: str) -> bytes: :rtype: bytes """ logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) - return self.client.get_artifact(model_id, self.model_bucket) + model_stream = self.client.get_artifact_stream(model_id, self.model_bucket) + fedn_model = FednModel.from_stream(model_stream) + return fedn_model def get_model_stream(self, model_id: str) -> bytes: """Retrieve a stream handle to model with id model_id. @@ -92,7 +96,7 @@ def get_model_stream(self, model_id: str) -> bytes: logger.info("Client {} trying to get model with id: {}".format(self.client.name, model_id)) return self.client.get_artifact_stream(model_id, self.model_bucket) - def set_model(self, model: Union[bytes, str], is_file: bool = True) -> str: + def set_model(self, model: Union[IO, str], is_file: bool = True) -> str: """Upload model object. :param model: The model object diff --git a/fedn/network/storage/s3/saasrepository.py b/fedn/network/storage/s3/saasrepository.py index 3c3663438..2b847939a 100644 --- a/fedn/network/storage/s3/saasrepository.py +++ b/fedn/network/storage/s3/saasrepository.py @@ -133,7 +133,7 @@ def get_artifact_stream(self, instance_name: str, bucket: str) -> io.BytesIO: try: response = self.s3_client.get_object(Bucket=bucket, Key=instance_name) - return io.BytesIO(response["Body"].read()) + return response["Body"] except (BotoCoreError, ClientError) as e: logger.error(f"Failed to fetch artifact stream: {instance_name} from bucket: {bucket}. 
Error: {e}") raise Exception(f"Could not fetch artifact stream: {e}") from e diff --git a/fedn/network/storage/statestore/stores/dto/session.py b/fedn/network/storage/statestore/stores/dto/session.py index ba46f2d42..73b08af8d 100644 --- a/fedn/network/storage/statestore/stores/dto/session.py +++ b/fedn/network/storage/statestore/stores/dto/session.py @@ -10,6 +10,7 @@ class SessionConfigDTO(DictDTO): aggregator: str = Field(None) aggregator_kwargs: Optional[str] = Field(None) round_timeout: int = Field(None) + accept_stragglers: Optional[bool] = Field(False) buffer_size: int = Field(None) rounds: Optional[int] = Field(None) delete_models_storage: bool = Field(None) diff --git a/fedn/network/storage/statestore/stores/dto/task.py b/fedn/network/storage/statestore/stores/dto/task.py new file mode 100644 index 000000000..024f3fec5 --- /dev/null +++ b/fedn/network/storage/statestore/stores/dto/task.py @@ -0,0 +1,43 @@ +from typing import Optional + +from google.protobuf.json_format import MessageToDict + +import fedn.network.grpc.fedn_pb2 as fedn +from fedn.network.storage.statestore.stores.dto.shared import BaseDTO, Field, PrimaryID + + +class Task(BaseDTO): + task_id: str = PrimaryID(None) + client_id: str = Field(None) + combiner_id: str = Field(None) + + type: str = Field(None) + parameters: dict = Field(None) + + model_id: Optional[str] = Field(None) + round_id: Optional[str] = Field(None) + session_id: Optional[str] = Field(None) + + +class TaskState(BaseDTO): + # Not best practice to use the same primary key for two different models + # but since it is unknown if TaskState will be moved to another store in the future this will do + task_id: str = PrimaryID(None) + + status: fedn.TaskStatus = Field(None) + response: dict = Field(None) + + def to_proto(self) -> fedn.TaskState: + task_state = fedn.TaskState() + task_state.task_id = self.task_id + task_state.status = self.status + task_state.response.update(self.response) + return task_state + + @classmethod + def 
from_proto(cls, proto: fedn.TaskState) -> "TaskState": + task_state = TaskState() + task_state.task_id = proto.task_id + task_state.status = proto.status + task_state.response = MessageToDict(proto.response, preserving_proto_field_name=True) + return task_state diff --git a/fedn/network/storage/statestore/stores/sql/shared.py b/fedn/network/storage/statestore/stores/sql/shared.py index d45ca5517..e4c415694 100644 --- a/fedn/network/storage/statestore/stores/sql/shared.py +++ b/fedn/network/storage/statestore/stores/sql/shared.py @@ -40,6 +40,7 @@ class SessionConfigModel(MyAbstractBase): aggregator: Mapped[str] = mapped_column(String(255)) aggregator_kwargs: Mapped[Optional[str]] round_timeout: Mapped[int] + accept_stragglers: Mapped[Optional[bool]] = mapped_column(default=False) buffer_size: Mapped[int] delete_models_storage: Mapped[bool] clients_required: Mapped[int] @@ -86,6 +87,7 @@ class RoundConfigModel(MyAbstractBase): aggregator: Mapped[str] = mapped_column(String(255)) aggregator_kwargs: Mapped[Optional[str]] round_timeout: Mapped[int] + accept_stragglers: Mapped[Optional[bool]] = mapped_column(default=False) buffer_size: Mapped[int] delete_models_storage: Mapped[bool] clients_required: Mapped[int] diff --git a/fedn/network/storage/statestore/stores/validation_store.py b/fedn/network/storage/statestore/stores/validation_store.py index b72a5e2f1..3a0586e6d 100644 --- a/fedn/network/storage/statestore/stores/validation_store.py +++ b/fedn/network/storage/statestore/stores/validation_store.py @@ -77,9 +77,9 @@ def _dto_from_orm_model(self, item: ValidationModel) -> ValidationDTO: sender_role = orm_dict.pop("sender_role") if sender_name is not None and sender_role is not None: orm_dict["sender"] = {"name": sender_name, "role": sender_role} - reciever_name = orm_dict.pop("receiver_name") + receiver_name = orm_dict.pop("receiver_name") receiver_role = orm_dict.pop("receiver_role") - if reciever_name is not None and receiver_role is not None: - 
orm_dict["receiver"] = {"name": reciever_name, "role": receiver_role} + if receiver_name is not None and receiver_role is not None: + orm_dict["receiver"] = {"name": receiver_name, "role": receiver_role} return ValidationDTO().populate_with(orm_dict) diff --git a/fedn/utils/checksum.py b/fedn/utils/checksum.py index 8ca678597..aa2b5bc97 100644 --- a/fedn/utils/checksum.py +++ b/fedn/utils/checksum.py @@ -14,3 +14,29 @@ def sha(fname): for chunk in iter(lambda: f.read(4096), b""): hash.update(chunk) return hash.hexdigest() + + +def compute_checksum_from_stream(stream): + """Compute the SHA256 checksum from a stream. + + :param stream: The stream to compute the checksum from. + :type stream: io.BytesIO or similar + :return: The SHA256 checksum as a string. + :rtype: str + """ + hash = hashlib.sha256() + for chunk in iter(lambda: stream.read(4096), b""): + hash.update(chunk) + return hash.hexdigest() + + +def compute_checksum_from_file(file_path): + """Compute the SHA256 checksum from a file. + + :param file_path: The path to the file. + :type file_path: str + :return: The SHA256 checksum as a string. + :rtype: str + """ + with open(file_path, "rb") as f: + return compute_checksum_from_stream(f) diff --git a/fedn/utils/dispatcher.py b/fedn/utils/dispatcher.py index d551b8053..766e40ce9 100644 --- a/fedn/utils/dispatcher.py +++ b/fedn/utils/dispatcher.py @@ -17,16 +17,9 @@ import os import shutil -import sys -import tempfile -import uuid from contextlib import contextmanager -from pathlib import Path - -import yaml from fedn.common.log_config import logger -from fedn.utils import PYTHON_VERSION from fedn.utils.environment import _PythonEnv from fedn.utils.process import _exec_cmd, _join_commands @@ -59,82 +52,16 @@ def _install_python(version, pyenv_root=None, capture_output=False): def _is_virtualenv_available(): - """Returns True if virtualenv is available, otherwise False. 
- """ + """Returns True if virtualenv is available, otherwise False.""" return shutil.which("virtualenv") is not None -def _validate_virtualenv_is_available(): - """Validates virtualenv is available. If not, throws an `Exception` with a brief instruction - on how to install virtualenv. - """ - if not _is_virtualenv_available(): - raise Exception("Could not find the virtualenv binary. Run `pip install virtualenv` to install " "virtualenv.") - - -def _get_virtualenv_extra_env_vars(env_root_dir=None): - extra_env = { - # PIP_NO_INPUT=1 makes pip run in non-interactive mode, - # otherwise pip might prompt "yes or no" and ask stdin input - "PIP_NO_INPUT": "1", - } - return extra_env - - -def _get_python_env(python_env_file): - """Parses a python environment file and returns a dictionary with the parsed content. - """ +def _get_python_env(python_env_file) -> _PythonEnv: + """Parses a python environment file and returns a dictionary with the parsed content.""" if os.path.exists(python_env_file): return _PythonEnv.from_yaml(python_env_file) -def _create_virtualenv(python_bin_path, env_dir, python_env, extra_env=None, capture_output=False): - # Created a command to activate the environment - paths = ("bin", "activate") if _IS_UNIX else ("Scripts", "activate.bat") - activate_cmd = env_dir.joinpath(*paths) - activate_cmd = f"source {activate_cmd}" if _IS_UNIX else str(activate_cmd) - - if env_dir.exists(): - logger.info("Environment %s already exists", env_dir) - return activate_cmd - - with remove_on_error( - env_dir, - onerror=lambda e: logger.warning( - "Encountered an unexpected error: %s while creating a virtualenv environment in %s, " "removing the environment directory...", - repr(e), - env_dir, - ), - ): - logger.info("Creating a new environment in %s with %s", env_dir, python_bin_path) - _exec_cmd( - [sys.executable, "-m", "virtualenv", "--python", python_bin_path, env_dir], - capture_output=capture_output, - ) - - logger.info("Installing dependencies") - for deps in 
filter(None, [python_env.build_dependencies, python_env.dependencies]): - with tempfile.TemporaryDirectory() as tmpdir: - tmp_req_file = f"requirements.{uuid.uuid4().hex}.txt" - Path(tmpdir).joinpath(tmp_req_file).write_text("\n".join(deps)) - cmd = _join_commands(activate_cmd, f"python -m pip install -r {tmp_req_file}") - _exec_cmd(cmd, capture_output=capture_output, cwd=tmpdir, extra_env=extra_env) - - return activate_cmd - - -def _read_yaml_file(file_path): - try: - cfg = None - with open(file_path, "rb") as config_file: - cfg = yaml.safe_load(config_file.read()) - - except Exception as e: - logger.error(f"Error trying to read yaml file: {file_path}") - raise e - return cfg - - class Dispatcher: """Dispatcher class for compute packages. @@ -151,55 +78,25 @@ def __init__(self, config, project_dir): self.activate_cmd = "" self.python_env_path = "" - def _get_or_create_python_env(self, capture_output=False, pip_requirements_override=None): + def get_or_create_python_env(self, capture_output=False): python_env = self.config.get("python_env", "") if not python_env: logger.info("No python_env specified in the configuration, using the system Python.") - return python_env + self.activate_cmd = "" + return self.activate_cmd else: - python_env_path = os.path.join(self.project_dir, python_env) - if not os.path.exists(python_env_path): - raise Exception("Compute package specified python_env file %s, but no such " "file was found." 
% python_env_path) - python_env = _get_python_env(python_env_path) - - extra_env = _get_virtualenv_extra_env_vars() - env_dir = Path(self.project_dir) / Path(python_env.name) - self.python_env_path = env_dir - try: - python_bin_path = _install_python(python_env.python, capture_output=True) - except NotImplementedError: - logger.warning("Failed to install Python: %s", python_env.python) - logger.warning("Python version installation is not implemented yet.") - logger.info(f"Using the system Python version: {PYTHON_VERSION}") - python_bin_path = Path(sys.executable) - - try: - activate_cmd = _create_virtualenv( - python_bin_path, - env_dir, - python_env, - extra_env=extra_env, - capture_output=capture_output, - ) - # Install additional dependencies specified by `requirements_override` - if pip_requirements_override: - logger.info("Installing additional dependencies specified by " f"pip_requirements_override: {pip_requirements_override}") - cmd = _join_commands( - activate_cmd, - f"python -m pip install --quiet -U {' '.join(pip_requirements_override)}", - ) - _exec_cmd(cmd, capture_output=capture_output, extra_env=extra_env) - self.activate_cmd = activate_cmd - return activate_cmd - except Exception: - logger.critical("Encountered unexpected error while creating %s", env_dir) - if env_dir.exists(): - logger.warning("Attempting to remove %s", env_dir) - shutil.rmtree(env_dir, ignore_errors=True) - msg = "Failed to remove %s" if env_dir.exists() else "Successfully removed %s" - logger.warning(msg, env_dir) - - raise + python_env_yaml_path = os.path.join(self.project_dir, python_env) + if not os.path.exists(python_env_yaml_path): + raise Exception("Compute package specified python_env file %s, but no such file was found." 
% python_env_yaml_path) + python_env = _get_python_env(python_env_yaml_path) + + python_env.set_base_path(self.project_dir) + if not python_env.path.exists(): + python_env.create_virtualenv(capture_output=capture_output, use_system_site_packages=True) + + self.activate_cmd = python_env.get_activate_cmd() + self.python_env_path = python_env.path + return self.activate_cmd def run_cmd(self, cmd_type, capture_output=False, extra_env=None, synchronous=True, stream_output=False): """Run a command. diff --git a/fedn/utils/environment.py b/fedn/utils/environment.py index 03d93eae7..f05caeaeb 100644 --- a/fedn/utils/environment.py +++ b/fedn/utils/environment.py @@ -15,17 +15,28 @@ limitations under the License. """ +import hashlib +import os +import shutil +import sys +import tempfile +import uuid +from contextlib import contextmanager +from pathlib import Path + import yaml +from fedn.common.log_config import logger from fedn.utils import PYTHON_VERSION +from fedn.utils.process import _exec_cmd, _join_commands _REQUIREMENTS_FILE_NAME = "requirements.txt" _PYTHON_ENV_FILE_NAME = "python_env.yaml" +_PYTHON_ENV_METADATA_FILE_NAME = "env_metadata.txt" +_IS_UNIX = os.name != "nt" class _PythonEnv: - BUILD_PACKAGES = ("pip", "setuptools", "wheel") - def __init__(self, name=None, python=None, build_dependencies=None, dependencies=None): """Represents environment information for FEDn compute packages. 
@@ -49,37 +60,39 @@ def __init__(self, name=None, python=None, build_dependencies=None, dependencies raise TypeError(f"`build_dependencies` must be a list but got {type(build_dependencies)}") if dependencies is not None and not isinstance(dependencies, list): raise TypeError(f"`dependencies` must be a list but got {type(dependencies)}") - self.name = name or "fedn_env" + self._name = name or "fedn_env" self.python = python or PYTHON_VERSION self.build_dependencies = build_dependencies or [] self.dependencies = dependencies or [] + self._base_path = None + + @property + def name(self): + """Name of the environment.""" + return os.path.join(self._name, self.get_sha()) + + def set_base_path(self, path): + self._base_path = path + + @property + def path(self) -> Path: + """Get the full path to the environment.""" + if not self._base_path: + raise ValueError("Base path is not set. Use `set_base_path` to set it.") + return Path(self._base_path).joinpath(self.name) def __str__(self): return str(self.to_dict()) - @classmethod - def current(cls): - return cls( - python=PYTHON_VERSION, - build_dependencies=cls.get_current_build_dependencies(), - dependencies=[f"-r {_REQUIREMENTS_FILE_NAME}"], - ) + def get_sha(self): + """Returns a SHA256 hash of the environment configuration.""" + env_str = str(self.to_dict()).encode("utf-8") + return hashlib.sha256(env_str).hexdigest() - @staticmethod - def _get_package_version(package_name): - try: - return __import__(package_name).__version__ - except (ImportError, AttributeError, AssertionError): - return None - - @staticmethod - def get_current_build_dependencies(): - build_dependencies = [] - for package in _PythonEnv.BUILD_PACKAGES: - version = _PythonEnv._get_package_version(package) - dep = (package + "==" + version) if version else package - build_dependencies.append(dep) - return build_dependencies + def remove_fedndependency(self): + """Remove 'fedn' from dependencies if it exists.""" + self.dependencies = [dep for dep in 
self.dependencies if dep != "fedn"] + self.build_dependencies = [dep for dep in self.build_dependencies if dep != "fedn"] def to_dict(self): return self.__dict__.copy() @@ -106,3 +119,90 @@ def get_dependencies_from_conda_yaml(path): @classmethod def from_conda_yaml(cls, path): return cls.from_dict(cls.get_dependencies_from_conda_yaml(path)) + + def get_activate_cmd(self): + """Get the command to activate the environment.""" + paths = ("bin", "activate") if _IS_UNIX else ("Scripts", "activate.bat") + activate_cmd = self.path.joinpath(*paths) + activate_cmd = f"source {activate_cmd}" if _IS_UNIX else str(activate_cmd) + return activate_cmd + + def create_virtualenv(self, capture_output=False, use_system_site_packages=False): + # Created a command to activate the environment + env_dir = self.path + + activate_cmd = self.get_activate_cmd() + + if env_dir.exists(): + logger.info("Environment %s already exists", env_dir) + return True + + with remove_on_error( + env_dir, + onerror=lambda e: logger.warning( + "Encountered an unexpected error: %s while creating a virtualenv environment in %s, removing the environment directory...", + repr(e), + env_dir, + ), + ): + os.makedirs(env_dir, exist_ok=True) + logger.info("Creating a new environment in %s with %s", env_dir, sys.executable) + _exec_cmd( + [sys.executable, "-m", "virtualenv", "--python", sys.executable] + ["--system-site-packages" if use_system_site_packages else ""] + [env_dir], + capture_output=capture_output, + ) + + extra_env = { + # PIP_NO_INPUT=1 makes pip run in non-interactive mode, + # otherwise pip might prompt "yes or no" and ask stdin input + "PIP_NO_INPUT": "1", + } + + logger.info("Installing dependencies") + for deps in filter(None, [self.build_dependencies, self.dependencies]): + with tempfile.TemporaryDirectory() as tmpdir: + tmp_req_file = f"requirements.{uuid.uuid4().hex}.txt" + Path(tmpdir).joinpath(tmp_req_file).write_text("\n".join(deps)) + cmd = _join_commands(activate_cmd, f"python -m pip 
install -r {tmp_req_file}") + _exec_cmd(cmd, capture_output=capture_output, cwd=tmpdir, extra_env=extra_env) + + build_deps = "\n".join(self.build_dependencies or []) + deps = "\n".join(self.dependencies or []) + Path(env_dir).joinpath(_PYTHON_ENV_METADATA_FILE_NAME).write_text(f"{self.get_sha()}\n{self.python}\n{build_deps}\n{deps}") + + return True + + def verify_installed_env(self): + """Check if the environment metadata file exists and matches the current environment.""" + metadata_file = self.path.joinpath(_PYTHON_ENV_METADATA_FILE_NAME) + if not metadata_file.exists(): + return False + + with open(metadata_file) as f: + sha, python_version, *deps = f.read().splitlines() + if sha != self.get_sha() or python_version != self.python: + return False + + # Check if dependencies match + if set(deps) != set(self.build_dependencies + self.dependencies): + return False + + return True + + +@contextmanager +def remove_on_error(path: os.PathLike, onerror=None): + """A context manager that removes a file or directory if an exception is raised during + execution. 
+ """ + try: + yield + except Exception as e: + if onerror: + onerror(e) + if os.path.exists(path): + if os.path.isfile(path): + os.remove(path) + elif os.path.isdir(path): + shutil.rmtree(path) + raise diff --git a/fedn/utils/helpers/plugins/androidhelper.py b/fedn/utils/helpers/plugins/androidhelper.py index 119801b8c..711e81d70 100644 --- a/fedn/utils/helpers/plugins/androidhelper.py +++ b/fedn/utils/helpers/plugins/androidhelper.py @@ -70,8 +70,12 @@ def save(self, weights, path=None): path = self.get_tmp_path() byte_array = struct.pack("f" * len(weights), *weights) - with open(path, "wb") as file: - file.write(byte_array) + if isinstance(path, str): + with open(path, "wb") as file: + file.write(byte_array) + else: + # If path is a file-like object, write to it directly + path.write(byte_array) return path diff --git a/fedn/utils/helpers/plugins/splitlearninghelper.py b/fedn/utils/helpers/plugins/splitlearninghelper.py index 2140ae85f..1e7650029 100644 --- a/fedn/utils/helpers/plugins/splitlearninghelper.py +++ b/fedn/utils/helpers/plugins/splitlearninghelper.py @@ -34,8 +34,7 @@ def save(self, data_dict, path=None, file_type="npz"): # Ensure all values are numpy arrays processed_dict = {str(k): np.array(v) for k, v in data_dict.items()} - with open(path, "wb") as f: - np.savez_compressed(f, **processed_dict) + np.savez_compressed(path, **processed_dict) return path diff --git a/fedn/utils/model.py b/fedn/utils/model.py new file mode 100644 index 000000000..cb8c54a1d --- /dev/null +++ b/fedn/utils/model.py @@ -0,0 +1,123 @@ +import tempfile +import threading +from typing import BinaryIO, Iterable + +import fedn.network.grpc.fedn_pb2 as fedn +from fedn.utils.checksum import compute_checksum_from_stream +from fedn.utils.helpers.plugins.numpyhelper import Helper + +CHUNK_SIZE = 1 * 1024 * 1024 # 1 MB chunk size for reading/writing files +SPOOLED_MAX_SIZE = 10 * 1024 * 1024 # 10 MB max size for spooled temporary files + + +class FednModel: + """The FednModel class 
is the primary model representation in the FEDn framework. + A FednModel object contains a data object (tempfile.SpooledTemporaryFile) that holds the model parameters. + The model parameters dict can be extracted from the data object or be used to create a model object. + Unpacking of the model parameters is done by the helper which needs to be provided either to the class or + to the method + """ + + def __init__(self): + """Initializes a FednModel object.""" + # Using SpooledTemporaryFile to handle large model data efficiently + # It will automatically store on disk if the data exceeds the specified size + self._data = tempfile.SpooledTemporaryFile(SPOOLED_MAX_SIZE) + self._data_lock = threading.RLock() + self.model_id = None + self.helper = None + self._checksum = None + + @property + def checksum(self) -> str: + """Returns the checksum of the model data.""" + if self._checksum is None: + self._checksum = compute_checksum_from_stream(self.get_stream()) + return self._checksum + + def verify_checksum(self, checksum: str) -> bool: + """Verifies the checksum of the model data. + + If no checksum is provided, it returns True. + """ + return checksum is None or self.checksum == checksum + + def get_stream(self): + """Returns a stream of the model data. + + To avoid concurrency issues, a new stream is created each time this method is called. + """ + with self._data_lock: + self._data.seek(0) + new_stream = tempfile.SpooledTemporaryFile(SPOOLED_MAX_SIZE) + while chunk := self._data.read(CHUNK_SIZE): + new_stream.write(chunk) + new_stream.seek(0) + self._data.seek(0) + return new_stream + + def get_stream_unsafe(self): + """Returns the internal stream of the model data. + + This method is not thread-safe and should be used with caution. 
+ """ + self._data.seek(0) + return self._data + + def get_model_params(self, helper=None): + """Returns the model parameters as a dictionary.""" + stream = self.get_stream() + self.helper = helper or self.helper + if self.helper is None: + raise ValueError("No helper provided to unpack model parameters.") + return self.helper.load(stream) + + def save_to_file(self, file_path: str): + """Saves the model data to a file.""" + with open(file_path, "wb") as file: + stream = self.get_stream() + while chunk := stream.read(CHUNK_SIZE): + file.write(chunk) + + def get_filechunk_stream(self, chunk_size=CHUNK_SIZE): + """Returns a generator that yields chunks of the model data.""" + stream = self.get_stream() + while chunk := stream.read(chunk_size): + yield fedn.FileChunk(data=chunk) + + @staticmethod + def from_model_params(model_params: dict, helper=None) -> "FednModel": + """Creates a FednModel from model parameters.""" + model_reference = FednModel() + model_reference.helper = helper + if helper is None: + # No helper provided, using numpy helper as default + helper = Helper() + helper.save(model_params, model_reference._data) + model_reference._data.seek(0) + return model_reference + + @staticmethod + def from_file(file_path: str) -> "FednModel": + """Creates a FednModel from a file.""" + with open(file_path, "rb") as file: + return FednModel.from_stream(file) + + @staticmethod + def from_stream(stream: BinaryIO) -> "FednModel": + """Creates a FednModel from a stream.""" + model_reference = FednModel() + while chunk := stream.read(CHUNK_SIZE): + model_reference._data.write(chunk) + model_reference._data.seek(0) + return model_reference + + @staticmethod + def from_filechunk_stream(filechunk_stream: Iterable[fedn.FileChunk]) -> "FednModel": + """Creates a FednModel from a filechunk stream.""" + model_reference = FednModel() + for chunk in filechunk_stream: + if chunk.data: + model_reference._data.write(chunk.data) + model_reference._data.seek(0) + return model_reference 
diff --git a/fedn/utils/yaml.py b/fedn/utils/yaml.py new file mode 100644 index 000000000..be5dd2fcc --- /dev/null +++ b/fedn/utils/yaml.py @@ -0,0 +1,15 @@ +import yaml + +from fedn.common.log_config import logger + + +def read_yaml_file(file_path): + try: + cfg = None + with open(file_path, "rb") as config_file: + cfg = yaml.safe_load(config_file.read()) + + except Exception as e: + logger.error(f"Error trying to read yaml file: {file_path}") + raise e + return cfg diff --git a/pyproject.toml b/pyproject.toml index 47b008b05..557b07050 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" [project] name = "fedn" -version = "0.30.0" +version = "0.33.0" description = "Scaleout Federated Learning" authors = [{ name = "Scaleout Systems AB", email = "contact@scaleoutsystems.com" }] readme = "README.rst" @@ -110,6 +110,7 @@ exclude = [ ".mnist-pytorch", "fedn_pb2.py", "fedn_pb2_grpc.py", + "fedn_pb2.pyi", ".ci", "test*", "**/*.ipynb"