diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4a5268ed5..48d882884 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,14 @@ repos: args: - --license-filepath - LICENSE.md - + - repo: local + hooks: + - id: pyright + name: pyright + entry: poetry run pyright + language: system + types: [python] + pass_filenames: false # Deactivating this for now. # - repo: https://github.com/pycqa/pylint # rev: v2.17.0 diff --git a/nemoguardrails/actions/llm/generation.py b/nemoguardrails/actions/llm/generation.py index 2a57e1c26..cd11e70a7 100644 --- a/nemoguardrails/actions/llm/generation.py +++ b/nemoguardrails/actions/llm/generation.py @@ -137,7 +137,7 @@ async def init(self): self._init_flows_index(), ) - def _extract_user_message_example(self, flow: Flow): + def _extract_user_message_example(self, flow: Flow) -> None: """Heuristic to extract user message examples from a flow.""" elements = [ item diff --git a/nemoguardrails/actions_server/actions_server.py b/nemoguardrails/actions_server/actions_server.py index 58d49437b..e45131a46 100644 --- a/nemoguardrails/actions_server/actions_server.py +++ b/nemoguardrails/actions_server/actions_server.py @@ -16,7 +16,7 @@ import logging from typing import Dict, Optional -from fastapi import FastAPI +from fastapi import Depends, FastAPI from pydantic import BaseModel, Field from nemoguardrails.actions.action_dispatcher import ActionDispatcher @@ -34,7 +34,12 @@ # Create action dispatcher object to communicate with actions -app.action_dispatcher = ActionDispatcher(load_all_actions=True) +_action_dispatcher = ActionDispatcher(load_all_actions=True) + + +def get_action_dispatcher() -> ActionDispatcher: + """Dependency to provide the action dispatcher instance.""" + return _action_dispatcher class RequestBody(BaseModel): @@ -58,22 +63,26 @@ class ResponseBody(BaseModel): summary="Execute action", response_model=ResponseBody, ) -async def run_action(body: RequestBody): +async def run_action( + body: RequestBody, + action_dispatcher: ActionDispatcher = Depends(get_action_dispatcher), +): """Execute the specified action and return the result. Args: body (RequestBody): The request body containing action_name and action_parameters. + action_dispatcher (ActionDispatcher): The action dispatcher dependency. Returns: dict: The response containing the execution status and result. """ - log.info(f"Request body: {body}") - result, status = await app.action_dispatcher.execute_action( + log.info("Request body: %s", body) + result, status = await action_dispatcher.execute_action( body.action_name, body.action_parameters ) resp = {"status": status, "result": result} - log.info(f"Response: {resp}") + log.info("Response: %s", resp) return resp @@ -81,7 +90,9 @@ async def run_action(body: RequestBody): "/v1/actions/list", summary="List available actions", ) -async def get_actions_list(): +async def get_actions_list( + action_dispatcher: ActionDispatcher = Depends(get_action_dispatcher), +): """Returns the list of available actions.""" - return app.action_dispatcher.get_registered_actions() + return action_dispatcher.get_registered_actions() diff --git a/poetry.lock b/poetry.lock index 6942217f3..0ac9c487b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -4050,6 +4050,26 @@ files = [ [package.extras] dev = ["build", "flake8", "mypy", "pytest", "twine"] +[[package]] +name = "pyright" +version = "1.1.405" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.405-py3-none-any.whl", hash = "sha256:a2cb13700b5508ce8e5d4546034cb7ea4aedb60215c6c33f56cec7f53996035a"}, + {file = "pyright-1.1.405.tar.gz", hash = "sha256:5c2a30e1037af27eb463a1cc0b9f6d65fec48478ccf092c1ac28385a15c55763"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" +typing-extensions = ">=4.1" + +[package.extras] +all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] +nodejs = ["nodejs-wheel-binaries"] + [[package]] name = "pytest" version = "8.3.4" @@ -6190,4 +6210,4 @@ tracing = ["aiofiles", "opentelemetry-api"] [metadata] lock-version = "2.0" python-versions = ">=3.9,!=3.9.7,<3.14" -content-hash = "6654d6115d5142024695ff1a736cc3d133842421b1282f5c3ba413b6a0250118" +content-hash = "313705d475a9cb177efa633c193da9315388aa99832b9c5b429fafb5b3da44b0" diff --git a/pyproject.toml b/pyproject.toml index 6200d0ca3..4be691eff 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -151,8 +151,13 @@ pytest-profiling = "^1.7.0" yara-python = "^4.5.1" opentelemetry-api = "^1.34.1" opentelemetry-sdk = "^1.34.1" +pyright = "^1.1.405" +# Directories in which to run Pyright type-checking +[tool.pyright] +include = [] + [tool.poetry.group.docs] optional = true