From 0019374708cc6ac5d0752bc67fddd38d8d63dec4 Mon Sep 17 00:00:00 2001 From: Boburmirzo Date: Wed, 16 Jul 2025 22:07:38 +0300 Subject: [PATCH 1/3] add gibsonai tool --- toolkits/gibsonai/.pre-commit-config.yaml | 18 ++ toolkits/gibsonai/.ruff.toml | 47 ++++ toolkits/gibsonai/Makefile | 55 +++++ toolkits/gibsonai/arcade_gibsonai/__init__.py | 5 + .../gibsonai/arcade_gibsonai/api_client.py | 73 +++++++ .../gibsonai/arcade_gibsonai/constants.py | 8 + .../arcade_gibsonai/tools/__init__.py | 5 + .../gibsonai/arcade_gibsonai/tools/query.py | 48 ++++ toolkits/gibsonai/evals/eval_gibsonai.py | 206 ++++++++++++++++++ toolkits/gibsonai/pyproject.toml | 59 +++++ toolkits/gibsonai/tests/__init__.py | 0 toolkits/gibsonai/tests/test_gibsonai.py | 191 ++++++++++++++++ 12 files changed, 715 insertions(+) create mode 100644 toolkits/gibsonai/.pre-commit-config.yaml create mode 100644 toolkits/gibsonai/.ruff.toml create mode 100644 toolkits/gibsonai/Makefile create mode 100644 toolkits/gibsonai/arcade_gibsonai/__init__.py create mode 100644 toolkits/gibsonai/arcade_gibsonai/api_client.py create mode 100644 toolkits/gibsonai/arcade_gibsonai/constants.py create mode 100644 toolkits/gibsonai/arcade_gibsonai/tools/__init__.py create mode 100644 toolkits/gibsonai/arcade_gibsonai/tools/query.py create mode 100644 toolkits/gibsonai/evals/eval_gibsonai.py create mode 100644 toolkits/gibsonai/pyproject.toml create mode 100644 toolkits/gibsonai/tests/__init__.py create mode 100644 toolkits/gibsonai/tests/test_gibsonai.py diff --git a/toolkits/gibsonai/.pre-commit-config.yaml b/toolkits/gibsonai/.pre-commit-config.yaml new file mode 100644 index 000000000..f63210621 --- /dev/null +++ b/toolkits/gibsonai/.pre-commit-config.yaml @@ -0,0 +1,18 @@ +files: ^.*/gibsonai/.* +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: "v4.4.0" + hooks: + - id: check-case-conflict + - id: check-merge-conflict + - id: check-toml + - id: check-yaml + - id: end-of-file-fixer + - id: 
trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.6.7 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format \ No newline at end of file diff --git a/toolkits/gibsonai/.ruff.toml b/toolkits/gibsonai/.ruff.toml new file mode 100644 index 000000000..fc0d98a2f --- /dev/null +++ b/toolkits/gibsonai/.ruff.toml @@ -0,0 +1,47 @@ +target-version = "py310" +line-length = 100 +fix = true + +[lint] +select = [ + # flake8-2020 + "YTT", + # flake8-bandit + "S", + # flake8-bugbear + "B", + # flake8-builtins + "A", + # flake8-comprehensions + "C4", + # flake8-debugger + "T10", + # flake8-simplify + "SIM", + # isort + "I", + # mccabe + "C90", + # pycodestyle + "E", "W", + # pyflakes + "F", + # pygrep-hooks + "PGH", + # pyupgrade + "UP", + # ruff + "RUF", + # tryceratops + "TRY", +] + +[lint.per-file-ignores] +"*" = ["TRY003", "B904"] +"**/tests/*" = ["S101", "E501"] +"**/evals/*" = ["S101", "E501"] + + +[format] +preview = true +skip-magic-trailing-comma = false \ No newline at end of file diff --git a/toolkits/gibsonai/Makefile b/toolkits/gibsonai/Makefile new file mode 100644 index 000000000..95f08428f --- /dev/null +++ b/toolkits/gibsonai/Makefile @@ -0,0 +1,55 @@ +.PHONY: help + +help: + @echo "🛠️ github Commands:\n" + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' + +.PHONY: install +install: ## Install the uv environment and install all packages with dependencies + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv sync --active --all-extras --no-sources + @if [ -f .pre-commit-config.yaml ]; then uv run --no-sources pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: install-local +install-local: ## Install the uv environment and install all packages with dependencies with local Arcade sources + @echo "🚀 Creating virtual environment and installing all packages using uv" + @uv 
sync --active --all-extras + @if [ -f .pre-commit-config.yaml ]; then uv run pre-commit install; fi + @echo "✅ All packages and dependencies installed via uv" + +.PHONY: build +build: clean-build ## Build wheel file using poetry + @echo "🚀 Creating wheel file" + uv build + +.PHONY: clean-build +clean-build: ## clean build artifacts + @echo "🗑️ Cleaning dist directory" + rm -rf dist + +.PHONY: test +test: ## Test the code with pytest + @echo "🚀 Testing code: Running pytest" + @uv run --no-sources pytest -W ignore -v --cov --cov-config=pyproject.toml --cov-report=xml + +.PHONY: coverage +coverage: ## Generate coverage report + @echo "coverage report" + @uv run --no-sources coverage report + @echo "Generating coverage report" + @uv run --no-sources coverage html + +.PHONY: bump-version +bump-version: ## Bump the version in the pyproject.toml file by a patch version + @echo "🚀 Bumping version in pyproject.toml" + uv version --no-sources --bump patch + +.PHONY: check +check: ## Run code quality tools. 
+ @if [ -f .pre-commit-config.yaml ]; then\ + echo "🚀 Linting code: Running pre-commit";\ + uv run --no-sources pre-commit run -a;\ + fi + @echo "🚀 Static type checking: Running mypy" + @uv run --no-sources mypy --config-file=pyproject.toml \ No newline at end of file diff --git a/toolkits/gibsonai/arcade_gibsonai/__init__.py b/toolkits/gibsonai/arcade_gibsonai/__init__.py new file mode 100644 index 000000000..6415fc414 --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/__init__.py @@ -0,0 +1,5 @@ +"""GibsonAI Database Tools for Arcade.""" + +from arcade_gibsonai.tools.query import execute_query + +__all__ = ["execute_query"] diff --git a/toolkits/gibsonai/arcade_gibsonai/api_client.py b/toolkits/gibsonai/arcade_gibsonai/api_client.py new file mode 100644 index 000000000..89aa8ac20 --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/api_client.py @@ -0,0 +1,73 @@ +from typing import Any, Dict, List, Optional +import httpx +from pydantic import BaseModel + +from .constants import API_BASE_URL, API_VERSION, MAX_ROWS_RETURNED + + +class GibsonAIResponse(BaseModel): + """Response model for GibsonAI API.""" + + data: List[Dict[str, Any]] + success: bool + error: Optional[str] = None + + +class GibsonAIClient: + """Client for interacting with GibsonAI Data API.""" + + def __init__(self, api_key: str): + self.api_key = api_key + self.base_url = f"{API_BASE_URL}/{API_VERSION}" + self.headers = {"Content-Type": "application/json", "X-Gibson-API-Key": api_key} + + async def execute_query( + self, query: str, params: Optional[List[Any]] = None + ) -> List[str]: + """Execute a query against GibsonAI database.""" + if params is None: + params = [] + + payload = {"array_mode": False, "params": params, "query": query} + + try: + async with httpx.AsyncClient() as client: + response = await client.post( + f"{self.base_url}/-/query", + headers=self.headers, + json=payload, + timeout=30.0, + ) + + if response.status_code != 200: + error_msg = f"HTTP {response.status_code}: 
{response.text}" + raise Exception(f"GibsonAI API error: {error_msg}") + + result = response.json() + + # Handle different response formats + if isinstance(result, dict): + if "error" in result and result["error"]: + raise Exception(f"GibsonAI query error: {result['error']}") + elif "data" in result: + results = [str(row) for row in result["data"]] + else: + results = [str(result)] + elif isinstance(result, list): + results = [str(row) for row in result] + else: + results = [str(result)] + + # Limit results to avoid memory issues + return results[:MAX_ROWS_RETURNED] + + except httpx.TimeoutException: + raise Exception("Request timeout - GibsonAI API took too long to respond") + except httpx.RequestError as e: + raise Exception(f"Network error connecting to GibsonAI API: {e}") + except Exception as e: + # Re-raise if it's already our custom exception + if "GibsonAI" in str(e): + raise + else: + raise Exception(f"Unexpected error: {e}") diff --git a/toolkits/gibsonai/arcade_gibsonai/constants.py b/toolkits/gibsonai/arcade_gibsonai/constants.py new file mode 100644 index 000000000..f5ed625dd --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/constants.py @@ -0,0 +1,8 @@ +"""Constants for GibsonAI API configuration.""" + +# API configuration +API_BASE_URL = "https://api.gibsonai.com" +API_VERSION = "v1" + +# Maximum number of rows to return from queries +MAX_ROWS_RETURNED = 1000 diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py b/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py new file mode 100644 index 000000000..73c7a6cf8 --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py @@ -0,0 +1,5 @@ +"""GibsonAI database query tools.""" + +from arcade_gibsonai.tools.query import execute_query + +__all__ = ["execute_query"] diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/query.py b/toolkits/gibsonai/arcade_gibsonai/tools/query.py new file mode 100644 index 000000000..768970764 --- /dev/null +++ 
b/toolkits/gibsonai/arcade_gibsonai/tools/query.py @@ -0,0 +1,48 @@ +from typing import Annotated + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import RetryableToolError + +from ..api_client import GibsonAIClient + + +@tool(requires_secrets=["GIBSONAI_API_KEY"]) +async def execute_query( + context: ToolContext, + query: Annotated[ + str, + "The SQL query to execute against GibsonAI project database. Supports all SQL operations including SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, etc.", + ], +) -> list[str]: + """ + Execute a SQL query and return the results from the GibsonAI project relational database. + + This tool supports all SQL operations including: + * SELECT queries for data retrieval + * INSERT, UPDATE, DELETE for data manipulation + * CREATE, ALTER, DROP for schema management + * Any other valid SQL statements + + When running queries, follow these rules which will help avoid errors: + * First discover the database schema in the GibsonAI project database when schema is not known. + * Discover all the tables in the database when the list of tables is not known. + * Always use case-insensitive queries to match strings in the query. + * Always trim strings in the query. + * Prefer LIKE queries over direct string matches or regex queries. + * Only join on columns that are indexed or the primary key. + + For SELECT queries, unless otherwise specified, ensure that query has a LIMIT of 100 for all results. 
+ """ + api_key = context.get_secret("GIBSONAI_API_KEY") + client = GibsonAIClient(api_key) + + try: + results = await client.execute_query(query) + return results + except Exception as e: + raise RetryableToolError( + f"Query failed: {e}", + developer_message=f"Query '{query}' failed against GibsonAI database.", + additional_prompt_content="Please check your query syntax and try again.", + retry_after_ms=10, + ) from e diff --git a/toolkits/gibsonai/evals/eval_gibsonai.py b/toolkits/gibsonai/evals/eval_gibsonai.py new file mode 100644 index 000000000..c3e756dac --- /dev/null +++ b/toolkits/gibsonai/evals/eval_gibsonai.py @@ -0,0 +1,206 @@ +from arcade_tdk import ToolCatalog +from arcade_evals import ( + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_evals.critic import SimilarityCritic + +import arcade_gibsonai +from arcade_gibsonai.tools.query import execute_query + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.85, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_module(arcade_gibsonai) + + +@tool_eval() +def gibsonai_eval_suite() -> EvalSuite: + suite = EvalSuite( + name="GibsonAI Database Tools Evaluation", + system_message=( + "You are an AI assistant with access to GibsonAI database tools. " + "Use them to help the user execute queries against the GibsonAI database. " + "You can perform all SQL operations including SELECT, INSERT, UPDATE, DELETE, " + "CREATE, ALTER, DROP, and other schema management operations." 
+ ), + catalog=catalog, + rubric=rubric, + ) + + # SELECT query test + suite.add_case( + name="Execute SELECT query", + user_message="Can you run a simple SELECT query to get the current timestamp?", + expected_tool_calls=[ + ExpectedToolCall(func=execute_query, args={"query": "SELECT NOW()"}) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + {"role": "user", "content": "I need to test the database connection."}, + { + "role": "assistant", + "content": "I'll help you test the database connection by running a simple query.", + }, + ], + ) + + # INSERT query test + suite.add_case( + name="Execute INSERT query", + user_message="Insert a new user with name 'John Doe' and email 'john@example.com' into the users table.", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, + args={ + "query": "INSERT INTO users (name, email) VALUES ('John Doe', 'john@example.com')" + }, + ) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + {"role": "user", "content": "I need to add a new user to the database."}, + { + "role": "assistant", + "content": "I'll help you insert a new user into the users table.", + }, + ], + ) + + # UPDATE query test + suite.add_case( + name="Execute UPDATE query", + user_message="Update the user with ID 1 to change their email to 'newemail@example.com'.", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, + args={ + "query": "UPDATE users SET email = 'newemail@example.com' WHERE id = 1" + }, + ) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + {"role": "user", "content": "I need to update a user's email address."}, + { + "role": "assistant", + "content": "I'll help you update the user's email in the database.", + }, + ], + ) + + # DELETE query test + suite.add_case( + name="Execute DELETE query", + user_message="Delete the user 
with ID 5 from the users table.", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, args={"query": "DELETE FROM users WHERE id = 5"} + ) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + {"role": "user", "content": "I need to remove a user from the database."}, + { + "role": "assistant", + "content": "I'll help you delete the user from the users table.", + }, + ], + ) + + # CREATE TABLE query test + suite.add_case( + name="Execute CREATE TABLE query", + user_message="Create a new table called 'products' with columns: id (integer primary key), name (varchar), price (decimal).", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, + args={ + "query": "CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR(255), price DECIMAL(10,2))" + }, + ) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + { + "role": "user", + "content": "I need to create a new table for storing product information.", + }, + { + "role": "assistant", + "content": "I'll help you create a products table with the specified columns.", + }, + ], + ) + + # ALTER TABLE query test + suite.add_case( + name="Execute ALTER TABLE query", + user_message="Add a new column 'description' of type TEXT to the products table.", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, + args={"query": "ALTER TABLE products ADD COLUMN description TEXT"}, + ) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + { + "role": "user", + "content": "I need to add a description column to the products table.", + }, + { + "role": "assistant", + "content": "I'll help you alter the products table to add a description column.", + }, + ], + ) + + # DROP TABLE query test + suite.add_case( + name="Execute DROP TABLE query", + user_message="Drop the temporary_data table as it's no longer 
needed.", + expected_tool_calls=[ + ExpectedToolCall( + func=execute_query, args={"query": "DROP TABLE temporary_data"} + ) + ], + rubric=rubric, + critics=[ + SimilarityCritic(critic_field="query", weight=0.8), + ], + additional_messages=[ + {"role": "user", "content": "I need to remove the temporary_data table."}, + { + "role": "assistant", + "content": "I'll help you drop the temporary_data table.", + }, + ], + ) + + return suite diff --git a/toolkits/gibsonai/pyproject.toml b/toolkits/gibsonai/pyproject.toml new file mode 100644 index 000000000..e8249eb4a --- /dev/null +++ b/toolkits/gibsonai/pyproject.toml @@ -0,0 +1,59 @@ +[build-system] +requires = [ "hatchling",] +build-backend = "hatchling.build" + +[project] +name = "arcade_gibsonai" +version = "0.1.0" +description = "Enable agents evolve database schemas and run SQL queries against multiple relational databases." +requires-python = ">=3.10" +dependencies = [ + "arcade-tdk>=2.0.0,<3.0.0", + "httpx>=0.24.0,<1.0.0", + "pydantic>=2.0.0,<3.0.0", +] +[[project.authors]] +name = "Boburmirzo" +email = "boburmirzo.umurzokov@gmail.com" + + +[project.optional-dependencies] +dev = [ + "arcade-ai[evals]>=2.0.6,<3.0.0", + "arcade-serve>=2.0.0,<3.0.0", + "pytest>=8.3.0,<8.4.0", + "pytest-cov>=4.0.0,<4.1.0", + "pytest-mock>=3.11.1,<3.12.0", + "pytest-asyncio>=0.24.0,<0.25.0", + "mypy>=1.5.1,<1.6.0", + "pre-commit>=3.4.0,<3.5.0", + "tox>=4.11.1,<4.12.0", + "ruff>=0.7.4,<0.8.0", +] + +# Use local path sources for arcade libs when working locally +[tool.uv.sources] +arcade-ai = { path = "../../", editable = true } +arcade-serve = { path = "../../libs/arcade-serve/", editable = true } +arcade-tdk = { path = "../../libs/arcade-tdk/", editable = true } + +[tool.mypy] +files = [ "arcade_gibsonai/**/*.py",] +python_version = "3.10" +disallow_untyped_defs = "True" +disallow_any_unimported = "True" +no_implicit_optional = "True" +check_untyped_defs = "True" +warn_return_any = "True" +warn_unused_ignores = "True" +show_error_codes 
= "True" +ignore_missing_imports = "True" + +[tool.pytest.ini_options] +testpaths = [ "tests",] + +[tool.coverage.report] +skip_empty = true + +[tool.hatch.build.targets.wheel] +packages = [ "arcade_gibsonai",] diff --git a/toolkits/gibsonai/tests/__init__.py b/toolkits/gibsonai/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/toolkits/gibsonai/tests/test_gibsonai.py b/toolkits/gibsonai/tests/test_gibsonai.py new file mode 100644 index 000000000..1797e3367 --- /dev/null +++ b/toolkits/gibsonai/tests/test_gibsonai.py @@ -0,0 +1,191 @@ +import pytest +from unittest.mock import AsyncMock, patch +from arcade_tdk import ToolContext +from arcade_tdk.errors import RetryableToolError + +from arcade_gibsonai.tools.query import execute_query + + +@pytest.mark.asyncio +async def test_execute_select_query(): + """Test successful SELECT query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["result1", "result2"] + mock_client_class.return_value = mock_client + + result = await execute_query(mock_context, "SELECT * FROM users LIMIT 10") + + assert result == ["result1", "result2"] + mock_client.execute_query.assert_called_once_with( + "SELECT * FROM users LIMIT 10" + ) + + +@pytest.mark.asyncio +async def test_execute_insert_query(): + """Test successful INSERT query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["1 row inserted"] + mock_client_class.return_value = mock_client + + result = await execute_query( + mock_context, + "INSERT INTO users (name, email) VALUES ('John', 'john@example.com')", + ) + + assert 
result == ["1 row inserted"] + mock_client.execute_query.assert_called_once_with( + "INSERT INTO users (name, email) VALUES ('John', 'john@example.com')" + ) + + +@pytest.mark.asyncio +async def test_execute_update_query(): + """Test successful UPDATE query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["1 row updated"] + mock_client_class.return_value = mock_client + + result = await execute_query( + mock_context, "UPDATE users SET email = 'new@example.com' WHERE id = 1" + ) + + assert result == ["1 row updated"] + mock_client.execute_query.assert_called_once_with( + "UPDATE users SET email = 'new@example.com' WHERE id = 1" + ) + + +@pytest.mark.asyncio +async def test_execute_delete_query(): + """Test successful DELETE query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["1 row deleted"] + mock_client_class.return_value = mock_client + + result = await execute_query(mock_context, "DELETE FROM users WHERE id = 5") + + assert result == ["1 row deleted"] + mock_client.execute_query.assert_called_once_with( + "DELETE FROM users WHERE id = 5" + ) + + +@pytest.mark.asyncio +async def test_execute_create_table_query(): + """Test successful CREATE TABLE query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["Table created successfully"] + mock_client_class.return_value = mock_client + + result = await execute_query( 
+ mock_context, + "CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR(255))", + ) + + assert result == ["Table created successfully"] + mock_client.execute_query.assert_called_once_with( + "CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR(255))" + ) + + +@pytest.mark.asyncio +async def test_execute_alter_table_query(): + """Test successful ALTER TABLE query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["Table altered successfully"] + mock_client_class.return_value = mock_client + + result = await execute_query( + mock_context, "ALTER TABLE products ADD COLUMN description TEXT" + ) + + assert result == ["Table altered successfully"] + mock_client.execute_query.assert_called_once_with( + "ALTER TABLE products ADD COLUMN description TEXT" + ) + + +@pytest.mark.asyncio +async def test_execute_drop_table_query(): + """Test successful DROP TABLE query execution.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["Table dropped successfully"] + mock_client_class.return_value = mock_client + + result = await execute_query(mock_context, "DROP TABLE temporary_data") + + assert result == ["Table dropped successfully"] + mock_client.execute_query.assert_called_once_with("DROP TABLE temporary_data") + + +@pytest.mark.asyncio +async def test_execute_query_failure(): + """Test query execution failure.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + 
mock_client.execute_query.side_effect = Exception("Database error") + mock_client_class.return_value = mock_client + + with pytest.raises(RetryableToolError) as exc_info: + await execute_query(mock_context, "SELECT * FROM users") + + assert "Query failed: Database error" in str(exc_info.value) + + +@pytest.mark.asyncio +async def test_execute_complex_query(): + """Test complex query with joins and conditions.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + mock_client = AsyncMock() + mock_client.execute_query.return_value = ["complex_result"] + mock_client_class.return_value = mock_client + + complex_query = """ + SELECT u.name, p.title, o.total + FROM users u + JOIN orders o ON u.id = o.user_id + JOIN products p ON o.product_id = p.id + WHERE u.active = true + ORDER BY o.created_at DESC + LIMIT 50 + """ + + result = await execute_query(mock_context, complex_query) + + assert result == ["complex_result"] + mock_client.execute_query.assert_called_once_with(complex_query) From 891a4257bd936557808f6e265f3eb8ae28184cda Mon Sep 17 00:00:00 2001 From: Boburmirzo Date: Mon, 21 Jul 2025 00:48:29 +0300 Subject: [PATCH 2/3] make query read only, enable add, delete and insert. 
No schem update --- toolkits/gibsonai/.pre-commit-config.yaml | 2 +- toolkits/gibsonai/.ruff.toml | 2 +- toolkits/gibsonai/Makefile | 2 +- toolkits/gibsonai/arcade_gibsonai/__init__.py | 12 +- .../gibsonai/arcade_gibsonai/api_client.py | 112 ++++++-- .../arcade_gibsonai/tools/__init__.py | 12 +- .../gibsonai/arcade_gibsonai/tools/delete.py | 165 ++++++++++++ .../gibsonai/arcade_gibsonai/tools/insert.py | 147 +++++++++++ .../gibsonai/arcade_gibsonai/tools/query.py | 60 ++++- .../gibsonai/arcade_gibsonai/tools/update.py | 176 +++++++++++++ toolkits/gibsonai/evals/eval_gibsonai.py | 122 +++++---- toolkits/gibsonai/tests/test_gibsonai.py | 248 +++++++++++------- 12 files changed, 864 insertions(+), 196 deletions(-) create mode 100644 toolkits/gibsonai/arcade_gibsonai/tools/delete.py create mode 100644 toolkits/gibsonai/arcade_gibsonai/tools/insert.py create mode 100644 toolkits/gibsonai/arcade_gibsonai/tools/update.py diff --git a/toolkits/gibsonai/.pre-commit-config.yaml b/toolkits/gibsonai/.pre-commit-config.yaml index f63210621..66b637c05 100644 --- a/toolkits/gibsonai/.pre-commit-config.yaml +++ b/toolkits/gibsonai/.pre-commit-config.yaml @@ -15,4 +15,4 @@ repos: hooks: - id: ruff args: [--fix] - - id: ruff-format \ No newline at end of file + - id: ruff-format diff --git a/toolkits/gibsonai/.ruff.toml b/toolkits/gibsonai/.ruff.toml index fc0d98a2f..19364180c 100644 --- a/toolkits/gibsonai/.ruff.toml +++ b/toolkits/gibsonai/.ruff.toml @@ -44,4 +44,4 @@ select = [ [format] preview = true -skip-magic-trailing-comma = false \ No newline at end of file +skip-magic-trailing-comma = false diff --git a/toolkits/gibsonai/Makefile b/toolkits/gibsonai/Makefile index 95f08428f..0a8969beb 100644 --- a/toolkits/gibsonai/Makefile +++ b/toolkits/gibsonai/Makefile @@ -52,4 +52,4 @@ check: ## Run code quality tools. 
uv run --no-sources pre-commit run -a;\ fi @echo "🚀 Static type checking: Running mypy" - @uv run --no-sources mypy --config-file=pyproject.toml \ No newline at end of file + @uv run --no-sources mypy --config-file=pyproject.toml diff --git a/toolkits/gibsonai/arcade_gibsonai/__init__.py b/toolkits/gibsonai/arcade_gibsonai/__init__.py index 6415fc414..4ed44d730 100644 --- a/toolkits/gibsonai/arcade_gibsonai/__init__.py +++ b/toolkits/gibsonai/arcade_gibsonai/__init__.py @@ -1,5 +1,13 @@ """GibsonAI Database Tools for Arcade.""" -from arcade_gibsonai.tools.query import execute_query +from arcade_gibsonai.tools.delete import delete_records +from arcade_gibsonai.tools.insert import insert_records +from arcade_gibsonai.tools.query import execute_read_query +from arcade_gibsonai.tools.update import update_records -__all__ = ["execute_query"] +__all__ = [ + "delete_records", + "execute_read_query", + "insert_records", + "update_records", +] diff --git a/toolkits/gibsonai/arcade_gibsonai/api_client.py b/toolkits/gibsonai/arcade_gibsonai/api_client.py index 89aa8ac20..d0f702c15 100644 --- a/toolkits/gibsonai/arcade_gibsonai/api_client.py +++ b/toolkits/gibsonai/arcade_gibsonai/api_client.py @@ -1,16 +1,88 @@ -from typing import Any, Dict, List, Optional +from typing import Any, NoReturn + import httpx from pydantic import BaseModel from .constants import API_BASE_URL, API_VERSION, MAX_ROWS_RETURNED +class GibsonAIError(Exception): + """Base exception for GibsonAI API errors.""" + + pass + + +class GibsonAIHTTPError(GibsonAIError): + """HTTP-related errors from GibsonAI API.""" + + pass + + +class GibsonAIQueryError(GibsonAIError): + """Query execution errors from GibsonAI API.""" + + pass + + +class GibsonAITimeoutError(GibsonAIError): + """Timeout errors from GibsonAI API.""" + + pass + + +class GibsonAINetworkError(GibsonAIError): + """Network-related errors when connecting to GibsonAI API.""" + + pass + + class GibsonAIResponse(BaseModel): """Response model for GibsonAI 
API.""" - data: List[Dict[str, Any]] + data: list[dict[str, Any]] success: bool - error: Optional[str] = None + error: str | None = None + + +def _raise_http_error(status_code: int, response_text: str) -> NoReturn: + """Raise an HTTP error with formatted message.""" + error_msg = f"HTTP {status_code}: {response_text}" + raise GibsonAIHTTPError(f"GibsonAI API error: {error_msg}") + + +def _raise_query_error(error_message: str) -> NoReturn: + """Raise a query error with formatted message.""" + raise GibsonAIQueryError(f"GibsonAI query error: {error_message}") + + +def _raise_timeout_error() -> NoReturn: + """Raise a timeout error.""" + raise GibsonAITimeoutError("Request timeout - GibsonAI API took too long to respond") + + +def _raise_network_error(error: Exception) -> NoReturn: + """Raise a network error with original exception details.""" + raise GibsonAINetworkError(f"Network error connecting to GibsonAI API: {error}") + + +def _raise_unexpected_error(error: Exception) -> NoReturn: + """Raise an unexpected error.""" + raise GibsonAIError(f"Unexpected error: {error}") + + +def _process_response_data(result: Any) -> list[str]: + """Process the API response data into a list of strings.""" + if isinstance(result, dict): + if result.get("error"): + _raise_query_error(result["error"]) + elif "data" in result: + return [str(row) for row in result["data"]] + else: + return [str(result)] + elif isinstance(result, list): + return [str(row) for row in result] + else: + return [str(result)] class GibsonAIClient: @@ -21,9 +93,7 @@ def __init__(self, api_key: str): self.base_url = f"{API_BASE_URL}/{API_VERSION}" self.headers = {"Content-Type": "application/json", "X-Gibson-API-Key": api_key} - async def execute_query( - self, query: str, params: Optional[List[Any]] = None - ) -> List[str]: + async def execute_query(self, query: str, params: list[Any] | None = None) -> list[str]: """Execute a query against GibsonAI database.""" if params is None: params = [] @@ -40,34 +110,20 
@@ async def execute_query( ) if response.status_code != 200: - error_msg = f"HTTP {response.status_code}: {response.text}" - raise Exception(f"GibsonAI API error: {error_msg}") + _raise_http_error(response.status_code, response.text) result = response.json() - - # Handle different response formats - if isinstance(result, dict): - if "error" in result and result["error"]: - raise Exception(f"GibsonAI query error: {result['error']}") - elif "data" in result: - results = [str(row) for row in result["data"]] - else: - results = [str(result)] - elif isinstance(result, list): - results = [str(row) for row in result] - else: - results = [str(result)] + results = _process_response_data(result) # Limit results to avoid memory issues return results[:MAX_ROWS_RETURNED] except httpx.TimeoutException: - raise Exception("Request timeout - GibsonAI API took too long to respond") + _raise_timeout_error() except httpx.RequestError as e: - raise Exception(f"Network error connecting to GibsonAI API: {e}") + _raise_network_error(e) + except GibsonAIError: + # Re-raise our custom exceptions as-is + raise except Exception as e: - # Re-raise if it's already our custom exception - if "GibsonAI" in str(e): - raise - else: - raise Exception(f"Unexpected error: {e}") + _raise_unexpected_error(e) diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py b/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py index 73c7a6cf8..8db789f87 100644 --- a/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py +++ b/toolkits/gibsonai/arcade_gibsonai/tools/__init__.py @@ -1,5 +1,13 @@ """GibsonAI database query tools.""" -from arcade_gibsonai.tools.query import execute_query +from arcade_gibsonai.tools.delete import delete_records +from arcade_gibsonai.tools.insert import insert_records +from arcade_gibsonai.tools.query import execute_read_query +from arcade_gibsonai.tools.update import update_records -__all__ = ["execute_query"] +__all__ = [ + "delete_records", + "execute_read_query", + 
"insert_records", + "update_records", +] diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/delete.py b/toolkits/gibsonai/arcade_gibsonai/tools/delete.py new file mode 100644 index 000000000..c66b948e6 --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/tools/delete.py @@ -0,0 +1,165 @@ +import json +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import RetryableToolError + +from ..api_client import GibsonAIClient + + +def _validate_delete_conditions(conditions: list[dict[str, Any]]) -> None: + """Validate that all delete conditions have required keys.""" + if not conditions: + raise ValueError("Delete operations require at least one WHERE condition for safety") + + required_keys = {"column", "operator", "value"} + valid_operators = { + "=", + "!=", + "<>", + "<", + "<=", + ">", + ">=", + "LIKE", + "NOT LIKE", + "IN", + "NOT IN", + "IS NULL", + "IS NOT NULL", + } + + for i, condition in enumerate(conditions): + if not isinstance(condition, dict): + raise TypeError(f"Condition {i} must be a dictionary") + + missing_keys = required_keys - set(condition.keys()) + if missing_keys: + raise ValueError(f"Condition {i} missing required keys: {missing_keys}") + + if condition["operator"] not in valid_operators: + raise ValueError( + f"Condition {i} has invalid operator '{condition['operator']}'. 
" + f"Valid operators: {', '.join(sorted(valid_operators))}" + ) + + +def _build_delete_query( + table_name: str, conditions: list[dict[str, Any]], limit: int +) -> tuple[str, list[Any]]: + """Build DELETE query with parameterized values.""" + # Build WHERE clause + where_parts = [] + values: list[Any] = [] + + for condition in conditions: + column = condition["column"] + operator = condition["operator"] + value = condition["value"] + + if operator in ("IS NULL", "IS NOT NULL"): + where_parts.append(f"{column} {operator}") + elif operator in ("IN", "NOT IN"): + if isinstance(value, list | tuple): + placeholders = ", ".join("?" * len(value)) + where_parts.append(f"{column} {operator} ({placeholders})") + values.extend(value) + else: + raise ValueError(f"Value for {operator} must be a list or tuple") + else: + where_parts.append(f"{column} {operator} ?") + values.append(value) + + where_clause = "WHERE " + " AND ".join(where_parts) + + # Build complete query - use parameterized query for safety + # Note: table_name is validated above, not user-controlled + query_parts = ["DELETE FROM", table_name, where_clause] + if limit > 0: + query_parts.extend(["LIMIT", str(limit)]) + + query = " ".join(query_parts) + return query, values + + +def _validate_delete_inputs( + table_name: str, + parsed_conditions: list, + limit: int, + confirm_deletion: bool, +) -> None: + """Validate delete inputs and raise appropriate errors.""" + if not table_name or not isinstance(table_name, str): + raise ValueError("table_name must be a non-empty string") + + if not parsed_conditions: + raise TypeError("conditions must be a non-empty list") + + if limit < 0: + raise ValueError("limit must be non-negative (0 = no limit)") + + if not confirm_deletion: + raise ValueError("confirm_deletion must be explicitly set to True to proceed") + + +@tool(requires_secrets=["GIBSONAI_API_KEY"]) +async def delete_records( + context: ToolContext, + table_name: Annotated[str, "Name of the table to delete records 
from"], + conditions: Annotated[ + str, + "JSON string containing list of WHERE conditions. Each condition should have " + "'column', 'operator', and 'value' keys. " + 'Example: \'[{"column": "id", "operator": "=", "value": 1}]\'', + ], + limit: Annotated[int, "Optional LIMIT for safety. Set to 0 for no limit"] = 0, + confirm_deletion: Annotated[ + bool, "Explicit confirmation required (must be True to proceed)" + ] = False, +) -> str: + """Delete records from a table with specified conditions. + + This tool safely deletes records from the specified table. It requires at least one + WHERE condition and explicit confirmation to prevent accidental deletions. + + Args: + table_name: Name of the table to delete records from + conditions: List of WHERE conditions for safety + limit: Optional LIMIT clause for additional safety (0 = no limit) + confirm_deletion: Must be set to True to proceed with deletion + + Returns: + A message indicating the number of records deleted + + Raises: + ValueError: If no conditions provided, invalid conditions, or confirmation not given + RetryableToolError: If the database operation fails + """ + try: + # Parse JSON conditions + try: + parsed_conditions = json.loads(conditions) + if not isinstance(parsed_conditions, list): + raise TypeError("Conditions must be a JSON array") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format for conditions: {e}") from e + + # Validate inputs + _validate_delete_inputs(table_name, parsed_conditions, limit, confirm_deletion) + + _validate_delete_conditions(parsed_conditions) + + # Build query with parameterized values + query, values = _build_delete_query(table_name, parsed_conditions, limit) + + # Execute delete + client = GibsonAIClient(context.get_secret("GIBSONAI_API_KEY")) + await client.execute_query(query, values) + + except ValueError as e: + raise ValueError(f"Delete validation error: {e!s}") + except Exception as e: + raise RetryableToolError(f"Failed to delete records from 
table '{table_name}': {e!s}") + else: + # If we reach here, the delete was successful + return f"Successfully deleted records from table '{table_name}'" diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/insert.py b/toolkits/gibsonai/arcade_gibsonai/tools/insert.py new file mode 100644 index 000000000..207b5cc21 --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/tools/insert.py @@ -0,0 +1,147 @@ +import json +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import RetryableToolError + +from ..api_client import GibsonAIClient + + +def _validate_record_columns_simple(records: list[dict[str, Any]]) -> list[str]: + """Validate that all records have the same columns and return column names.""" + if not records: + raise ValueError("At least one record is required") + + columns = list(records[0].keys()) + for i, record in enumerate(records[1:], 1): + if set(record.keys()) != set(columns): + raise ValueError(f"Record {i + 1} has different columns than the first record") + + return columns + + +def _build_insert_query_simple( + table_name: str, records: list[dict[str, Any]], on_conflict: str, columns: list[str] +) -> str: + """Build the INSERT SQL query from simple parameters.""" + columns_str = ", ".join(f"`{col}`" for col in columns) + + conflict_clause = "" + if on_conflict.strip(): + if on_conflict.upper() == "IGNORE": + conflict_clause = " ON DUPLICATE KEY UPDATE id=id" + elif on_conflict.upper() == "REPLACE": + conflict_clause = " ON DUPLICATE KEY UPDATE " + ", ".join( + f"`{col}`=VALUES(`{col}`)" for col in columns + ) + + # Build value groups + value_groups = [] + for record in records: + values = [] + for col in columns: + val = record[col] + if val is None: + values.append("NULL") + elif isinstance(val, str): + # Escape single quotes for SQL safety + escaped_val = val.replace("'", "''") + values.append(f"'{escaped_val}'") + else: + values.append(str(val)) + value_groups.append(f"({', '.join(values)})") + + 
# Build complete query using proper SQL construction + # Note: table_name is validated above, not user-controlled + query_parts = ["INSERT INTO", f"`{table_name}`", f"({columns_str})", "VALUES"] + query_parts.append(", ".join(value_groups)) + if conflict_clause: + query_parts.append(conflict_clause) + + query = " ".join(query_parts) + return query + + +def _validate_insert_inputs(table_name: str, parsed_records: list) -> None: + """Validate insert inputs and raise appropriate errors.""" + # Validate table name + if not table_name.strip(): + raise ValueError("Table name cannot be empty") + dangerous_keywords = [";", "--", "/*", "*/", "drop", "delete", "truncate"] + if any(keyword in table_name.lower() for keyword in dangerous_keywords): + raise ValueError("Invalid characters in table name") + + # Validate records + if not parsed_records: + raise ValueError("At least one record is required") + + +@tool(requires_secrets=["GIBSONAI_API_KEY"]) +async def insert_records( + context: ToolContext, + table_name: Annotated[str, "Name of the table to insert data into"], + records: Annotated[ + str, + "JSON string containing a list of records to insert. Each record should be an object " + 'with column names as keys. Example: \'[{"name": "John", "age": 30}]\'', + ], + on_conflict: Annotated[ + str, "How to handle conflicts (e.g., 'IGNORE', 'REPLACE', 'UPDATE'). Leave empty for none" + ] = "", +) -> list[str]: + """ + Insert records into a GibsonAI database table with type validation and safety checks. 
@tool(requires_secrets=["GIBSONAI_API_KEY"])
async def insert_records(
    context: ToolContext,
    table_name: Annotated[str, "Name of the table to insert data into"],
    records: Annotated[
        str,
        "JSON string containing a list of records to insert. Each record should be an object "
        'with column names as keys. Example: \'[{"name": "John", "age": 30}]\'',
    ],
    on_conflict: Annotated[
        # Fix: the old description advertised an 'UPDATE' mode that the query
        # builder never implemented, steering the calling model toward a
        # silently ignored option.
        str, "How to handle conflicts: 'IGNORE' or 'REPLACE'. Leave empty for none"
    ] = "",
) -> list[str]:
    """Insert records into a GibsonAI database table with validation and safety checks.

    The records JSON is parsed, the table/column names are validated, and a
    single INSERT statement is generated and executed.

    Returns:
        The raw result rows reported by the GibsonAI API for the insert.

    Raises:
        RetryableToolError: On validation failure (retry_after_ms=0) or on a
            failed database operation (retry_after_ms=10).
    """
    api_key = context.get_secret("GIBSONAI_API_KEY")
    client = GibsonAIClient(api_key)

    try:
        # Decode the JSON record list before any validation.
        try:
            parsed_records = json.loads(records)
            if not isinstance(parsed_records, list):
                raise TypeError("Records must be a JSON array")
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON format: {e}") from e

        # Validate table name and records, then column consistency.
        _validate_insert_inputs(table_name, parsed_records)
        columns = _validate_record_columns_simple(parsed_records)

        # Build and execute the INSERT query.
        query = _build_insert_query_simple(table_name, parsed_records, on_conflict, columns)
        results = await client.execute_query(query)

    except ValueError as e:
        raise RetryableToolError(
            f"Validation error: {e}",
            developer_message=f"Invalid data provided for insert: {e}",
            additional_prompt_content="Please check your data format and try again.",
            retry_after_ms=0,
        ) from e
    except Exception as e:
        raise RetryableToolError(
            f"Insert failed: {e}",
            developer_message=f"Insert operation failed for table '{table_name}': {e}",
            additional_prompt_content="Please check your table name and data format.",
            retry_after_ms=10,
        ) from e
    else:
        return results
b/toolkits/gibsonai/arcade_gibsonai/tools/query.py @@ -1,3 +1,4 @@ +import re from typing import Annotated from arcade_tdk import ToolContext, tool @@ -6,39 +7,74 @@ from ..api_client import GibsonAIClient +def _is_read_only_query(query: str) -> bool: + """Check if a query is read-only (only SELECT, SHOW, DESCRIBE, EXPLAIN operations).""" + # Remove comments and normalize whitespace + normalized_query = re.sub(r"--.*?$|/\*.*?\*/", "", query, flags=re.MULTILINE | re.DOTALL) + normalized_query = " ".join(normalized_query.strip().split()) + + # Check if query starts with read-only operations + read_only_patterns = [ + r"^\s*SELECT\s", + r"^\s*SHOW\s", + r"^\s*DESCRIBE\s", + r"^\s*DESC\s", + r"^\s*EXPLAIN\s", + r"^\s*WITH\s.*SELECT\s", # CTE with SELECT + ] + + return any(re.match(pattern, normalized_query, re.IGNORECASE) for pattern in read_only_patterns) + + @tool(requires_secrets=["GIBSONAI_API_KEY"]) -async def execute_query( +async def execute_read_query( context: ToolContext, query: Annotated[ str, - "The SQL query to execute against GibsonAI project database. Supports all SQL operations including SELECT, INSERT, UPDATE, DELETE, CREATE, ALTER, DROP, etc.", + "The read-only SQL query to execute against GibsonAI project database. " + "Only SELECT, SHOW, DESCRIBE, and EXPLAIN operations are permitted.", ], ) -> list[str]: """ - Execute a SQL query and return the results from the GibsonAI project relational database. + Execute a read-only SQL query and return the results from the GibsonAI + project relational database. 
- This tool supports all SQL operations including: + This tool supports only read operations including: * SELECT queries for data retrieval - * INSERT, UPDATE, DELETE for data manipulation - * CREATE, ALTER, DROP for schema management - * Any other valid SQL statements + * SHOW commands for metadata inspection + * DESCRIBE/DESC commands for table structure + * EXPLAIN commands for query analysis + * WITH clauses (Common Table Expressions) that contain SELECT operations When running queries, follow these rules which will help avoid errors: - * First discover the database schema in the GibsonAI project database when schema is not known. - * Discover all the tables in the database when the list of tables is not known. + * First discover the database schema in the GibsonAI project database when + schema is not known. + * Discover all the tables in the database when the list of tables is not + known. * Always use case-insensitive queries to match strings in the query. * Always trim strings in the query. * Prefer LIKE queries over direct string matches or regex queries. * Only join on columns that are indexed or the primary key. - For SELECT queries, unless otherwise specified, ensure that query has a LIMIT of 100 for all results. + For SELECT queries, unless otherwise specified, ensure that query has a + LIMIT of 100 for all results. """ + # Validate that the query is read-only + if not _is_read_only_query(query): + raise RetryableToolError( + "Only read-only queries (SELECT, SHOW, DESCRIBE, EXPLAIN) are permitted", + developer_message=f"Query '{query}' contains write operations which are not allowed.", + additional_prompt_content=( + "Please use the appropriate DML/DDL tools for data modification operations." 
+ ), + retry_after_ms=0, + ) + api_key = context.get_secret("GIBSONAI_API_KEY") client = GibsonAIClient(api_key) try: results = await client.execute_query(query) - return results except Exception as e: raise RetryableToolError( f"Query failed: {e}", @@ -46,3 +82,5 @@ async def execute_query( additional_prompt_content="Please check your query syntax and try again.", retry_after_ms=10, ) from e + else: + return results diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/update.py b/toolkits/gibsonai/arcade_gibsonai/tools/update.py new file mode 100644 index 000000000..e7a8024e2 --- /dev/null +++ b/toolkits/gibsonai/arcade_gibsonai/tools/update.py @@ -0,0 +1,176 @@ +import json +from typing import Annotated, Any + +from arcade_tdk import ToolContext, tool +from arcade_tdk.errors import RetryableToolError + +from ..api_client import GibsonAIClient + + +def _validate_update_conditions(conditions: list[dict[str, Any]]) -> None: + """Validate that all update conditions have required keys.""" + if not conditions: + raise ValueError("Update operations require at least one WHERE condition for safety") + + required_keys = {"column", "operator", "value"} + valid_operators = { + "=", + "!=", + "<>", + "<", + "<=", + ">", + ">=", + "LIKE", + "NOT LIKE", + "IN", + "NOT IN", + "IS NULL", + "IS NOT NULL", + } + + for i, condition in enumerate(conditions): + if not isinstance(condition, dict): + raise TypeError(f"Condition {i} must be a dictionary") + + missing_keys = required_keys - set(condition.keys()) + if missing_keys: + raise ValueError(f"Condition {i} missing required keys: {missing_keys}") + + if condition["operator"] not in valid_operators: + raise ValueError( + f"Condition {i} has invalid operator '{condition['operator']}'. 
" + f"Valid operators: {', '.join(sorted(valid_operators))}" + ) + + +def _build_update_query( + table_name: str, updates: dict[str, Any], conditions: list[dict[str, Any]], limit: int +) -> tuple[str, list[Any]]: + """Build UPDATE query with parameterized values.""" + # Build SET clause + set_parts = [] + values: list[Any] = [] + for column, value in updates.items(): + set_parts.append(f"{column} = ?") + values.append(value) + + set_clause = "SET " + ", ".join(set_parts) + + # Build WHERE clause + where_parts = [] + for condition in conditions: + column = condition["column"] + operator = condition["operator"] + value = condition["value"] + + if operator in ("IS NULL", "IS NOT NULL"): + where_parts.append(f"{column} {operator}") + elif operator in ("IN", "NOT IN"): + if isinstance(value, list | tuple): + placeholders = ", ".join("?" * len(value)) + where_parts.append(f"{column} {operator} ({placeholders})") + values.extend(value) + else: + raise ValueError(f"Value for {operator} must be a list or tuple") + else: + where_parts.append(f"{column} {operator} ?") + values.append(value) + + where_clause = "WHERE " + " AND ".join(where_parts) + + # Build complete query + query = f"UPDATE {table_name} {set_clause} {where_clause}" + if limit > 0: + query += f" LIMIT {limit}" + + return query, values + + +def _validate_update_inputs( + table_name: str, parsed_updates: dict, parsed_conditions: list, limit: int +) -> None: + """Validate update inputs and raise appropriate errors.""" + if not table_name or not isinstance(table_name, str): + raise ValueError("table_name must be a non-empty string") + + if not parsed_updates: + raise ValueError("updates must be a non-empty dictionary") + + if not parsed_conditions: + raise TypeError("conditions must be a non-empty list") + + if limit < 0: + raise ValueError("limit must be non-negative (0 = no limit)") + + +@tool(requires_secrets=["GIBSONAI_API_KEY"]) +async def update_records( + context: ToolContext, + table_name: Annotated[str, 
"Name of the table to update records in"], + updates: Annotated[ + str, + "JSON string containing column-value pairs to update. " + 'Example: \'{"name": "John", "age": 30}\'', + ], + conditions: Annotated[ + str, + "JSON string containing list of WHERE conditions. Each condition should have " + "'column', 'operator', and 'value' keys. " + 'Example: \'[{"column": "id", "operator": "=", "value": 1}]\'', + ], + limit: Annotated[int, "Optional LIMIT for safety. Set to 0 for no limit"] = 0, +) -> str: + """Update records in a table with specified conditions. + + This tool safely updates records in the specified table. It requires at least one + WHERE condition to prevent accidental updates to all records. + + Args: + table_name: Name of the table to update records in + updates: Dictionary of column names to new values + conditions: List of WHERE conditions for safety + limit: Optional LIMIT clause for additional safety (0 = no limit) + + Returns: + A message indicating the number of records updated + + Raises: + ValueError: If no conditions provided or invalid conditions + RetryableToolError: If the database operation fails + """ + try: + # Parse JSON parameters + try: + parsed_updates = json.loads(updates) + if not isinstance(parsed_updates, dict): + raise TypeError("Updates must be a JSON object") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format for updates: {e}") from e + + try: + parsed_conditions = json.loads(conditions) + if not isinstance(parsed_conditions, list): + raise TypeError("Conditions must be a JSON array") + except json.JSONDecodeError as e: + raise ValueError(f"Invalid JSON format for conditions: {e}") from e + + # Validate inputs + _validate_update_inputs(table_name, parsed_updates, parsed_conditions, limit) + + _validate_update_conditions(parsed_conditions) + + # Build query with parameterized values + query, values = _build_update_query(table_name, parsed_updates, parsed_conditions, limit) + + # Execute update + client = 
GibsonAIClient(context.get_secret("GIBSONAI_API_KEY")) + await client.execute_query(query, values) + + except ValueError as e: + raise RetryableToolError(f"Update validation error: {e!s}") + except Exception as e: + raise RetryableToolError(f"Failed to update records in table '{table_name}': {e!s}") + else: + # If we reach here, the update was successful + return f"Successfully updated records in table '{table_name}'" diff --git a/toolkits/gibsonai/evals/eval_gibsonai.py b/toolkits/gibsonai/evals/eval_gibsonai.py index c3e756dac..2baca2cbf 100644 --- a/toolkits/gibsonai/evals/eval_gibsonai.py +++ b/toolkits/gibsonai/evals/eval_gibsonai.py @@ -1,4 +1,3 @@ -from arcade_tdk import ToolCatalog from arcade_evals import ( EvalRubric, EvalSuite, @@ -6,9 +5,13 @@ tool_eval, ) from arcade_evals.critic import SimilarityCritic +from arcade_tdk import ToolCatalog import arcade_gibsonai -from arcade_gibsonai.tools.query import execute_query +from arcade_gibsonai.tools.delete import delete_records +from arcade_gibsonai.tools.insert import insert_records +from arcade_gibsonai.tools.query import execute_read_query +from arcade_gibsonai.tools.update import update_records # Evaluation rubric rubric = EvalRubric( @@ -16,7 +19,6 @@ warn_threshold=0.95, ) - catalog = ToolCatalog() catalog.add_module(arcade_gibsonai) @@ -27,20 +29,21 @@ def gibsonai_eval_suite() -> EvalSuite: name="GibsonAI Database Tools Evaluation", system_message=( "You are an AI assistant with access to GibsonAI database tools. " - "Use them to help the user execute queries against the GibsonAI database. " - "You can perform all SQL operations including SELECT, INSERT, UPDATE, DELETE, " - "CREATE, ALTER, DROP, and other schema management operations." + "Use them to help the user execute queries and database operations. " + "For read operations, use execute_read_query. For data modifications, " + "use the specific parameterized tools: insert_records, update_records, " + "and delete_records with proper validation." 
), catalog=catalog, rubric=rubric, ) - # SELECT query test + # SELECT query test (read-only) suite.add_case( name="Execute SELECT query", user_message="Can you run a simple SELECT query to get the current timestamp?", expected_tool_calls=[ - ExpectedToolCall(func=execute_query, args={"query": "SELECT NOW()"}) + ExpectedToolCall(func=execute_read_query, args={"query": "SELECT NOW()"}) ], rubric=rubric, critics=[ @@ -55,139 +58,144 @@ def gibsonai_eval_suite() -> EvalSuite: ], ) - # INSERT query test + # INSERT query test (using parameterized tool) suite.add_case( - name="Execute INSERT query", + name="Execute INSERT operation", user_message="Insert a new user with name 'John Doe' and email 'john@example.com' into the users table.", expected_tool_calls=[ ExpectedToolCall( - func=execute_query, + func=insert_records, args={ - "query": "INSERT INTO users (name, email) VALUES ('John Doe', 'john@example.com')" + "table_name": "users", + "records": '[{"name": "John Doe", "email": "john@example.com"}]', + "on_conflict": "", }, ) ], rubric=rubric, critics=[ - SimilarityCritic(critic_field="query", weight=0.8), + SimilarityCritic(critic_field="table_name", weight=0.4), + SimilarityCritic(critic_field="records", weight=0.6), ], additional_messages=[ {"role": "user", "content": "I need to add a new user to the database."}, { "role": "assistant", - "content": "I'll help you insert a new user into the users table.", + "content": "I'll help you insert a new user into the users table using the parameterized insert tool.", }, ], ) - # UPDATE query test + # UPDATE query test (using parameterized tool) suite.add_case( - name="Execute UPDATE query", + name="Execute UPDATE operation", user_message="Update the user with ID 1 to change their email to 'newemail@example.com'.", expected_tool_calls=[ ExpectedToolCall( - func=execute_query, + func=update_records, args={ - "query": "UPDATE users SET email = 'newemail@example.com' WHERE id = 1" + "table_name": "users", + "updates": '{"email": 
"newemail@example.com"}', + "conditions": '[{"column": "id", "operator": "=", "value": 1}]', + "limit": 0, }, ) ], rubric=rubric, critics=[ - SimilarityCritic(critic_field="query", weight=0.8), + SimilarityCritic(critic_field="table_name", weight=0.3), + SimilarityCritic(critic_field="updates", weight=0.4), + SimilarityCritic(critic_field="conditions", weight=0.3), ], additional_messages=[ {"role": "user", "content": "I need to update a user's email address."}, { "role": "assistant", - "content": "I'll help you update the user's email in the database.", + "content": "I'll help you update the user's email using the parameterized update tool.", }, ], ) - # DELETE query test + # DELETE query test (using parameterized tool) suite.add_case( - name="Execute DELETE query", + name="Execute DELETE operation", user_message="Delete the user with ID 5 from the users table.", expected_tool_calls=[ ExpectedToolCall( - func=execute_query, args={"query": "DELETE FROM users WHERE id = 5"} + func=delete_records, + args={ + "table_name": "users", + "conditions": '[{"column": "id", "operator": "=", "value": 5}]', + "limit": 0, + "confirm_deletion": True, + }, ) ], rubric=rubric, critics=[ - SimilarityCritic(critic_field="query", weight=0.8), + SimilarityCritic(critic_field="table_name", weight=0.3), + SimilarityCritic(critic_field="conditions", weight=0.4), + SimilarityCritic(critic_field="confirm_deletion", weight=0.3), ], additional_messages=[ {"role": "user", "content": "I need to remove a user from the database."}, { "role": "assistant", - "content": "I'll help you delete the user from the users table.", + "content": "I'll help you delete the user using the parameterized delete tool with safety confirmation.", }, ], ) - # CREATE TABLE query test + # SHOW TABLES test (read-only) suite.add_case( - name="Execute CREATE TABLE query", - user_message="Create a new table called 'products' with columns: id (integer primary key), name (varchar), price (decimal).", + name="Execute SHOW 
TABLES query", + user_message="Show me all the tables in the database.", expected_tool_calls=[ - ExpectedToolCall( - func=execute_query, - args={ - "query": "CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR(255), price DECIMAL(10,2))" - }, - ) + ExpectedToolCall(func=execute_read_query, args={"query": "SHOW TABLES"}) ], rubric=rubric, critics=[ SimilarityCritic(critic_field="query", weight=0.8), ], additional_messages=[ - { - "role": "user", - "content": "I need to create a new table for storing product information.", - }, + {"role": "user", "content": "I need to see what tables exist in the database."}, { "role": "assistant", - "content": "I'll help you create a products table with the specified columns.", + "content": "I'll show you all the tables using a SHOW TABLES query.", }, ], ) - # ALTER TABLE query test + # DESCRIBE test (read-only) suite.add_case( - name="Execute ALTER TABLE query", - user_message="Add a new column 'description' of type TEXT to the products table.", + name="Execute DESCRIBE query", + user_message="Describe the structure of the users table.", expected_tool_calls=[ - ExpectedToolCall( - func=execute_query, - args={"query": "ALTER TABLE products ADD COLUMN description TEXT"}, - ) + ExpectedToolCall(func=execute_read_query, args={"query": "DESCRIBE users"}) ], rubric=rubric, critics=[ SimilarityCritic(critic_field="query", weight=0.8), ], additional_messages=[ - { - "role": "user", - "content": "I need to add a description column to the products table.", - }, + {"role": "user", "content": "I need to understand the structure of the users table."}, { "role": "assistant", - "content": "I'll help you alter the products table to add a description column.", + "content": "I'll describe the users table structure for you.", }, ], ) - # DROP TABLE query test + # Complex SELECT with JOIN (read-only) suite.add_case( - name="Execute DROP TABLE query", - user_message="Drop the temporary_data table as it's no longer needed.", + name="Execute 
complex SELECT with JOIN", + user_message="Get all users with their order totals, joining users and orders tables.", expected_tool_calls=[ ExpectedToolCall( - func=execute_query, args={"query": "DROP TABLE temporary_data"} + func=execute_read_query, + args={ + "query": "SELECT u.name, u.email, SUM(o.total) as total_orders FROM users u LEFT JOIN orders o ON u.id = o.user_id GROUP BY u.id, u.name, u.email LIMIT 100" + }, ) ], rubric=rubric, @@ -195,10 +203,10 @@ def gibsonai_eval_suite() -> EvalSuite: SimilarityCritic(critic_field="query", weight=0.8), ], additional_messages=[ - {"role": "user", "content": "I need to remove the temporary_data table."}, + {"role": "user", "content": "I need to analyze user order data."}, { "role": "assistant", - "content": "I'll help you drop the temporary_data table.", + "content": "I'll create a query that joins users with their orders to show the totals.", }, ], ) diff --git a/toolkits/gibsonai/tests/test_gibsonai.py b/toolkits/gibsonai/tests/test_gibsonai.py index 1797e3367..d14038237 100644 --- a/toolkits/gibsonai/tests/test_gibsonai.py +++ b/toolkits/gibsonai/tests/test_gibsonai.py @@ -1,9 +1,15 @@ -import pytest +"""Tests for GibsonAI toolkit with simple parameter interfaces.""" + from unittest.mock import AsyncMock, patch + +import pytest +from arcade_core.errors import RetryableToolError, ToolExecutionError from arcade_tdk import ToolContext -from arcade_tdk.errors import RetryableToolError -from arcade_gibsonai.tools.query import execute_query +from arcade_gibsonai.tools.delete import delete_records +from arcade_gibsonai.tools.insert import insert_records +from arcade_gibsonai.tools.query import execute_read_query +from arcade_gibsonai.tools.update import update_records @pytest.mark.asyncio @@ -14,178 +20,234 @@ async def test_execute_select_query(): with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: mock_client = AsyncMock() - mock_client.execute_query.return_value = ["result1", "result2"] + 
@pytest.mark.asyncio
async def test_execute_read_query_with_conditions():
    """A SELECT with a WHERE clause is forwarded verbatim to the client."""
    fake_context = AsyncMock(spec=ToolContext)
    fake_context.get_secret.return_value = "test_api_key"

    with patch("arcade_gibsonai.tools.query.GibsonAIClient") as fake_client_cls:
        fake_client = AsyncMock()
        fake_client.execute_query.return_value = ["id,name", "1,John"]
        fake_client_cls.return_value = fake_client

        rows = await execute_read_query(
            context=fake_context, query="SELECT * FROM users WHERE id = 1"
        )

        assert len(rows) == 2
        fake_client.execute_query.assert_called_once_with("SELECT * FROM users WHERE id = 1")


@pytest.mark.asyncio
async def test_execute_non_read_query_raises_error():
    """Write statements are rejected before any client call is made."""
    fake_context = AsyncMock(spec=ToolContext)
    fake_context.get_secret.return_value = "test_api_key"

    with pytest.raises(RetryableToolError, match="Only read-only queries"):
        await execute_read_query(context=fake_context, query="DELETE FROM users WHERE id = 1")

    with pytest.raises(RetryableToolError, match="Only read-only queries"):
        await execute_read_query(
            context=fake_context, query="UPDATE users SET name = 'Bob' WHERE id = 1"
        )


@pytest.mark.asyncio
async def test_execute_query_failure():
    """A client-side failure surfaces as a RetryableToolError."""
    fake_context = AsyncMock(spec=ToolContext)
    fake_context.get_secret.return_value = "test_api_key"

    with patch("arcade_gibsonai.tools.query.GibsonAIClient") as fake_client_cls:
        fake_client = AsyncMock()
        fake_client.execute_query.side_effect = Exception("Database connection failed")
        fake_client_cls.return_value = fake_client

        with pytest.raises(RetryableToolError):
            await execute_read_query(context=fake_context, query="SELECT * FROM users")
== ["1 row deleted"] - mock_client.execute_query.assert_called_once_with( - "DELETE FROM users WHERE id = 5" + result = await insert_records( + context=mock_context, + table_name="users", + records='[{"name": "John", "email": "john@example.com"}]', + on_conflict="IGNORE", ) + # Result should be the raw response from GibsonAI + assert result == ["1 row inserted"] + @pytest.mark.asyncio -async def test_execute_create_table_query(): - """Test successful CREATE TABLE query execution.""" +async def test_insert_records_multiple(): + """Test inserting multiple records.""" mock_context = AsyncMock(spec=ToolContext) mock_context.get_secret.return_value = "test_api_key" - with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + with patch("arcade_gibsonai.tools.insert.GibsonAIClient") as mock_client_class: mock_client = AsyncMock() - mock_client.execute_query.return_value = ["Table created successfully"] + mock_client.execute_query.return_value = ["2 rows affected"] mock_client_class.return_value = mock_client - result = await execute_query( - mock_context, - "CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR(255))", + result = await insert_records( + context=mock_context, + table_name="users", + records='[{"name": "John", "email": "john@example.com"}, {"name": "Jane", "email": "jane@example.com"}]', + on_conflict="REPLACE", ) - assert result == ["Table created successfully"] - mock_client.execute_query.assert_called_once_with( - "CREATE TABLE products (id INTEGER PRIMARY KEY, name VARCHAR(255))" - ) + assert result == ["2 rows affected"] @pytest.mark.asyncio -async def test_execute_alter_table_query(): - """Test successful ALTER TABLE query execution.""" +async def test_insert_records_validation_errors(): + """Test various validation errors.""" mock_context = AsyncMock(spec=ToolContext) mock_context.get_secret.return_value = "test_api_key" - with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: - mock_client = 
AsyncMock() - mock_client.execute_query.return_value = ["Table altered successfully"] - mock_client_class.return_value = mock_client + # Test empty table name + with pytest.raises(RetryableToolError, match="Table name cannot be empty"): + await insert_records( + context=mock_context, table_name="", records='[{"name": "John"}]', on_conflict="IGNORE" + ) - result = await execute_query( - mock_context, "ALTER TABLE products ADD COLUMN description TEXT" + # Test empty records + with pytest.raises(RetryableToolError, match="At least one record is required"): + await insert_records( + context=mock_context, table_name="users", records="[]", on_conflict="IGNORE" ) - assert result == ["Table altered successfully"] - mock_client.execute_query.assert_called_once_with( - "ALTER TABLE products ADD COLUMN description TEXT" + # Test invalid JSON format (not an array) + with pytest.raises(RetryableToolError, match="Records must be a JSON array"): + await insert_records( + context=mock_context, + table_name="users", + records='{"name": "John"}', + on_conflict="IGNORE", + ) + + # Test malformed JSON + with pytest.raises(RetryableToolError, match="Invalid JSON format"): + await insert_records( + context=mock_context, + table_name="users", + records='[{"name": "John"', + on_conflict="IGNORE", ) @pytest.mark.asyncio -async def test_execute_drop_table_query(): - """Test successful DROP TABLE query execution.""" +async def test_update_records_success(): + """Test successful record update.""" mock_context = AsyncMock(spec=ToolContext) mock_context.get_secret.return_value = "test_api_key" - with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: + with patch("arcade_gibsonai.tools.update.GibsonAIClient") as mock_client_class: mock_client = AsyncMock() - mock_client.execute_query.return_value = ["Table dropped successfully"] + mock_client.execute_query.return_value = ["1 row affected"] mock_client_class.return_value = mock_client - result = await 
execute_query(mock_context, "DROP TABLE temporary_data") + result = await update_records( + context=mock_context, + table_name="users", + updates='{"name": "Johnny"}', + conditions='[{"column": "id", "operator": "=", "value": 1}]', + ) - assert result == ["Table dropped successfully"] - mock_client.execute_query.assert_called_once_with("DROP TABLE temporary_data") + assert "Successfully updated records in table 'users'" in result @pytest.mark.asyncio -async def test_execute_query_failure(): - """Test query execution failure.""" +async def test_update_records_validation_errors(): + """Test update validation errors.""" mock_context = AsyncMock(spec=ToolContext) mock_context.get_secret.return_value = "test_api_key" - with patch("arcade_gibsonai.tools.query.GibsonAIClient") as mock_client_class: - mock_client = AsyncMock() - mock_client.execute_query.side_effect = Exception("Database error") - mock_client_class.return_value = mock_client - - with pytest.raises(RetryableToolError) as exc_info: - await execute_query(mock_context, "SELECT * FROM users") + # Test missing conditions + with pytest.raises(RetryableToolError, match="conditions must be a non-empty list"): + await update_records( + context=mock_context, table_name="users", updates='{"name": "Johnny"}', conditions="[]" + ) - assert "Query failed: Database error" in str(exc_info.value) + # Test invalid table name + with pytest.raises(RetryableToolError, match="table_name must be a non-empty string"): + await update_records( + context=mock_context, + table_name="", + updates='{"name": "Johnny"}', + conditions='[{"column": "id", "operator": "=", "value": 1}]', + ) @pytest.mark.asyncio -async def test_execute_complex_query(): - """Test complex query with joins and conditions.""" +async def test_delete_records_success(): + """Test successful record deletion.""" mock_context = AsyncMock(spec=ToolContext) mock_context.get_secret.return_value = "test_api_key" - with patch("arcade_gibsonai.tools.query.GibsonAIClient") as 
mock_client_class: + with patch("arcade_gibsonai.tools.delete.GibsonAIClient") as mock_client_class: mock_client = AsyncMock() - mock_client.execute_query.return_value = ["complex_result"] + mock_client.execute_query.return_value = ["1 row affected"] mock_client_class.return_value = mock_client - complex_query = """ - SELECT u.name, p.title, o.total - FROM users u - JOIN orders o ON u.id = o.user_id - JOIN products p ON o.product_id = p.id - WHERE u.active = true - ORDER BY o.created_at DESC - LIMIT 50 - """ + result = await delete_records( + context=mock_context, + table_name="users", + conditions='[{"column": "id", "operator": "=", "value": 1}]', + confirm_deletion=True, + ) + + assert "Successfully deleted records from table 'users'" in result - result = await execute_query(mock_context, complex_query) - assert result == ["complex_result"] - mock_client.execute_query.assert_called_once_with(complex_query) +@pytest.mark.asyncio +async def test_delete_records_validation_errors(): + """Test delete validation errors.""" + mock_context = AsyncMock(spec=ToolContext) + mock_context.get_secret.return_value = "test_api_key" + + # Test missing confirmation + with pytest.raises(ToolExecutionError, match="Error in execution of DeleteRecords"): + await delete_records( + context=mock_context, + table_name="users", + conditions='[{"column": "id", "operator": "=", "value": 5}]', + confirm_deletion=False, + ) + + # Test empty table name + with pytest.raises(ToolExecutionError, match="Error in execution of DeleteRecords"): + await delete_records( + context=mock_context, + table_name="", + conditions='[{"column": "id", "operator": "=", "value": 5}]', + confirm_deletion=True, + ) From 0ac753ec797fdbb1ee0e7ecc047be64445d0cec6 Mon Sep 17 00:00:00 2001 From: Boburmirzo Date: Mon, 21 Jul 2025 01:03:31 +0300 Subject: [PATCH 3/3] introduce pydantic --- .../gibsonai/arcade_gibsonai/tools/delete.py | 176 ++++++++++------- .../gibsonai/arcade_gibsonai/tools/insert.py | 124 ++++++++---- 
.../gibsonai/arcade_gibsonai/tools/update.py | 185 +++++++++++------- toolkits/gibsonai/tests/test_gibsonai.py | 7 +- 4 files changed, 313 insertions(+), 179 deletions(-) diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/delete.py b/toolkits/gibsonai/arcade_gibsonai/tools/delete.py index c66b948e6..954293340 100644 --- a/toolkits/gibsonai/arcade_gibsonai/tools/delete.py +++ b/toolkits/gibsonai/arcade_gibsonai/tools/delete.py @@ -3,59 +3,89 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.errors import RetryableToolError +from pydantic import BaseModel, Field, field_validator from ..api_client import GibsonAIClient -def _validate_delete_conditions(conditions: list[dict[str, Any]]) -> None: - """Validate that all delete conditions have required keys.""" - if not conditions: - raise ValueError("Delete operations require at least one WHERE condition for safety") - - required_keys = {"column", "operator", "value"} - valid_operators = { - "=", - "!=", - "<>", - "<", - "<=", - ">", - ">=", - "LIKE", - "NOT LIKE", - "IN", - "NOT IN", - "IS NULL", - "IS NOT NULL", - } - - for i, condition in enumerate(conditions): - if not isinstance(condition, dict): - raise TypeError(f"Condition {i} must be a dictionary") - - missing_keys = required_keys - set(condition.keys()) - if missing_keys: - raise ValueError(f"Condition {i} missing required keys: {missing_keys}") - - if condition["operator"] not in valid_operators: - raise ValueError( - f"Condition {i} has invalid operator '{condition['operator']}'. 
" - f"Valid operators: {', '.join(sorted(valid_operators))}" - ) - - -def _build_delete_query( - table_name: str, conditions: list[dict[str, Any]], limit: int -) -> tuple[str, list[Any]]: - """Build DELETE query with parameterized values.""" +class DeleteCondition(BaseModel): + """Pydantic model for delete WHERE conditions.""" + + column: str = Field(..., min_length=1, description="Column name for the condition") + operator: str = Field(..., description="SQL operator for the condition") + value: Any = Field(..., description="Value for the condition") + + @field_validator("operator") + @classmethod + def validate_operator(cls, v: str) -> str: + """Validate SQL operator.""" + valid_operators = { + "=", + "!=", + "<>", + "<", + "<=", + ">", + ">=", + "LIKE", + "NOT LIKE", + "IN", + "NOT IN", + "IS NULL", + "IS NOT NULL", + } + if v not in valid_operators: + operators_str = ", ".join(sorted(valid_operators)) + raise ValueError(f"Invalid operator '{v}'. Valid operators: {operators_str}") + return v + + +class DeleteRequest(BaseModel): + """Pydantic model for validating delete requests.""" + + table_name: str = Field( + ..., min_length=1, description="Name of the table to delete records from" + ) + conditions: list[DeleteCondition] = Field( + ..., min_length=1, description="List of WHERE conditions for safety" + ) + limit: int = Field(default=0, ge=0, description="Optional LIMIT for safety") + confirm_deletion: bool = Field( + ..., description="Explicit confirmation required (must be True to proceed)" + ) + + @field_validator("table_name") + @classmethod + def validate_table_name(cls, v: str) -> str: + """Validate table name for security.""" + if not v.strip(): + raise ValueError("Table name cannot be empty") + + dangerous_keywords = [";", "--", "/*", "*/", "drop", "delete", "truncate"] + if any(keyword in v.lower() for keyword in dangerous_keywords): + raise ValueError("Invalid characters in table name") + + return v.strip() + + @field_validator("confirm_deletion") + 
@classmethod + def validate_confirmation(cls, v: bool) -> bool: + """Validate deletion confirmation.""" + if not v: + raise ValueError("confirm_deletion must be explicitly set to True to proceed") + return v + + +def _build_delete_query(request: DeleteRequest) -> tuple[str, list[Any]]: + """Build DELETE query with parameterized values from validated request.""" # Build WHERE clause where_parts = [] values: list[Any] = [] - for condition in conditions: - column = condition["column"] - operator = condition["operator"] - value = condition["value"] + for condition in request.conditions: + column = condition.column + operator = condition.operator + value = condition.value if operator in ("IS NULL", "IS NOT NULL"): where_parts.append(f"{column} {operator}") @@ -74,32 +104,46 @@ def _build_delete_query( # Build complete query - use parameterized query for safety # Note: table_name is validated above, not user-controlled - query_parts = ["DELETE FROM", table_name, where_clause] - if limit > 0: - query_parts.extend(["LIMIT", str(limit)]) + query_parts = ["DELETE FROM", request.table_name, where_clause] + if request.limit > 0: + query_parts.extend(["LIMIT", str(request.limit)]) query = " ".join(query_parts) return query, values -def _validate_delete_inputs( +def _create_delete_request( table_name: str, parsed_conditions: list, limit: int, confirm_deletion: bool, -) -> None: - """Validate delete inputs and raise appropriate errors.""" - if not table_name or not isinstance(table_name, str): - raise ValueError("table_name must be a non-empty string") - - if not parsed_conditions: - raise TypeError("conditions must be a non-empty list") - - if limit < 0: - raise ValueError("limit must be non-negative (0 = no limit)") - - if not confirm_deletion: - raise ValueError("confirm_deletion must be explicitly set to True to proceed") +) -> DeleteRequest: + """Create and validate DeleteRequest from parsed data.""" + try: + # Convert conditions to DeleteCondition models + condition_models 
= [ + DeleteCondition(column=cond["column"], operator=cond["operator"], value=cond["value"]) + for cond in parsed_conditions + ] + + return DeleteRequest( + table_name=table_name, + conditions=condition_models, + limit=limit, + confirm_deletion=confirm_deletion, + ) + except Exception as e: + # Convert Pydantic validation errors to more readable messages + error_msg = str(e) + if "String should have at least 1 character" in error_msg: + raise ValueError("Table name cannot be empty") from e + elif "List should have at least 1 item" in error_msg: + msg = "Delete operations require at least one WHERE condition for safety" + raise ValueError(msg) from e + elif "confirm_deletion must be explicitly set to True" in error_msg: + raise ValueError("confirm_deletion must be explicitly set to True to proceed") from e + else: + raise ValueError(f"Validation error: {error_msg}") from e @tool(requires_secrets=["GIBSONAI_API_KEY"]) @@ -144,13 +188,11 @@ async def delete_records( except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON format for conditions: {e}") from e - # Validate inputs - _validate_delete_inputs(table_name, parsed_conditions, limit, confirm_deletion) - - _validate_delete_conditions(parsed_conditions) + # Create and validate request using Pydantic + request = _create_delete_request(table_name, parsed_conditions, limit, confirm_deletion) # Build query with parameterized values - query, values = _build_delete_query(table_name, parsed_conditions, limit) + query, values = _build_delete_query(request) # Execute delete client = GibsonAIClient(context.get_secret("GIBSONAI_API_KEY")) diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/insert.py b/toolkits/gibsonai/arcade_gibsonai/tools/insert.py index 207b5cc21..a08aa2a32 100644 --- a/toolkits/gibsonai/arcade_gibsonai/tools/insert.py +++ b/toolkits/gibsonai/arcade_gibsonai/tools/insert.py @@ -3,41 +3,86 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.errors import RetryableToolError +from pydantic 
import BaseModel, Field, field_validator from ..api_client import GibsonAIClient -def _validate_record_columns_simple(records: list[dict[str, Any]]) -> list[str]: - """Validate that all records have the same columns and return column names.""" - if not records: - raise ValueError("At least one record is required") +class InsertRequest(BaseModel): + """Pydantic model for validating insert requests.""" - columns = list(records[0].keys()) - for i, record in enumerate(records[1:], 1): - if set(record.keys()) != set(columns): - raise ValueError(f"Record {i + 1} has different columns than the first record") + table_name: str = Field( + ..., min_length=1, description="Name of the table to insert records into" + ) + records: list[dict[str, Any]] = Field( + ..., min_length=1, description="List of records to insert" + ) + on_conflict: str = Field(default="", description="Conflict resolution strategy") - return columns + @field_validator("table_name") + @classmethod + def validate_table_name(cls, v: str) -> str: + """Validate table name for security.""" + if not v.strip(): + raise ValueError("Table name cannot be empty") + dangerous_keywords = [";", "--", "/*", "*/", "drop", "delete", "truncate"] + if any(keyword in v.lower() for keyword in dangerous_keywords): + raise ValueError("Invalid characters in table name") -def _build_insert_query_simple( - table_name: str, records: list[dict[str, Any]], on_conflict: str, columns: list[str] -) -> str: - """Build the INSERT SQL query from simple parameters.""" + return v.strip() + + @field_validator("records") + @classmethod + def validate_records_consistency(cls, v: list[dict[str, Any]]) -> list[dict[str, Any]]: + """Validate that all records have the same columns.""" + if not v: + raise ValueError("At least one record is required") + + if not all(isinstance(record, dict) for record in v): + raise ValueError("All records must be dictionaries") + + # Check column consistency + expected_columns = set(v[0].keys()) + for i, record in 
enumerate(v[1:], 1): + if set(record.keys()) != expected_columns: + msg = f"Record {i + 1} has different columns than the first record" + raise ValueError(msg) + + return v + + @field_validator("on_conflict") + @classmethod + def validate_on_conflict(cls, v: str) -> str: + """Validate on_conflict strategy.""" + if not v: + return "" + + valid_strategies = {"ignore", "replace", "update"} + if v.lower() not in valid_strategies: + strategies_str = ", ".join(valid_strategies) + raise ValueError(f"Invalid on_conflict strategy. Must be one of: {strategies_str}") + + return v.upper() + + +def _build_insert_query(request: InsertRequest) -> str: + """Build the INSERT SQL query from validated request.""" + columns = list(request.records[0].keys()) columns_str = ", ".join(f"`{col}`" for col in columns) conflict_clause = "" - if on_conflict.strip(): - if on_conflict.upper() == "IGNORE": + if request.on_conflict: + if request.on_conflict == "IGNORE": conflict_clause = " ON DUPLICATE KEY UPDATE id=id" - elif on_conflict.upper() == "REPLACE": + elif request.on_conflict == "REPLACE": conflict_clause = " ON DUPLICATE KEY UPDATE " + ", ".join( f"`{col}`=VALUES(`{col}`)" for col in columns ) # Build value groups value_groups = [] - for record in records: + for record in request.records: values = [] for col in columns: val = record[col] @@ -53,7 +98,7 @@ def _build_insert_query_simple( # Build complete query using proper SQL construction # Note: table_name is validated above, not user-controlled - query_parts = ["INSERT INTO", f"`{table_name}`", f"({columns_str})", "VALUES"] + query_parts = ["INSERT INTO", f"`{request.table_name}`", f"({columns_str})", "VALUES"] query_parts.append(", ".join(value_groups)) if conflict_clause: query_parts.append(conflict_clause) @@ -62,20 +107,6 @@ def _build_insert_query_simple( return query -def _validate_insert_inputs(table_name: str, parsed_records: list) -> None: - """Validate insert inputs and raise appropriate errors.""" - # Validate table name - 
if not table_name.strip(): - raise ValueError("Table name cannot be empty") - dangerous_keywords = [";", "--", "/*", "*/", "drop", "delete", "truncate"] - if any(keyword in table_name.lower() for keyword in dangerous_keywords): - raise ValueError("Invalid characters in table name") - - # Validate records - if not parsed_records: - raise ValueError("At least one record is required") - - @tool(requires_secrets=["GIBSONAI_API_KEY"]) async def insert_records( context: ToolContext, @@ -107,9 +138,6 @@ async def insert_records( The tool automatically generates properly formatted INSERT statements based on the validated input data. """ - api_key = context.get_secret("GIBSONAI_API_KEY") - client = GibsonAIClient(api_key) - try: # Parse JSON records try: @@ -119,14 +147,28 @@ async def insert_records( except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON format: {e}") from e - # Validate table name and records - _validate_insert_inputs(table_name, parsed_records) - - # Validate columns consistency across all records - columns = _validate_record_columns_simple(parsed_records) + # Create and validate request using Pydantic + try: + request = InsertRequest( + table_name=table_name, records=parsed_records, on_conflict=on_conflict + ) + except Exception as e: + # Convert Pydantic validation errors to more readable messages + error_msg = str(e) + if "String should have at least 1 character" in error_msg: + raise ValueError("Table name cannot be empty") from e + elif "List should have at least 1 item" in error_msg: + raise ValueError("At least one record is required") from e + elif "Invalid on_conflict strategy" in error_msg: + raise ValueError(error_msg) from e + else: + raise ValueError(f"Validation error: {error_msg}") from e # Build and execute the INSERT query - query = _build_insert_query_simple(table_name, parsed_records, on_conflict, columns) + query = _build_insert_query(request) + + api_key = context.get_secret("GIBSONAI_API_KEY") + client = 
GibsonAIClient(api_key) results = await client.execute_query(query) except ValueError as e: diff --git a/toolkits/gibsonai/arcade_gibsonai/tools/update.py b/toolkits/gibsonai/arcade_gibsonai/tools/update.py index e7a8024e2..d6ed1ceb1 100644 --- a/toolkits/gibsonai/arcade_gibsonai/tools/update.py +++ b/toolkits/gibsonai/arcade_gibsonai/tools/update.py @@ -3,55 +3,90 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.errors import RetryableToolError +from pydantic import BaseModel, Field, field_validator from ..api_client import GibsonAIClient -def _validate_update_conditions(conditions: list[dict[str, Any]]) -> None: - """Validate that all update conditions have required keys.""" - if not conditions: - raise ValueError("Update operations require at least one WHERE condition for safety") - - required_keys = {"column", "operator", "value"} - valid_operators = { - "=", - "!=", - "<>", - "<", - "<=", - ">", - ">=", - "LIKE", - "NOT LIKE", - "IN", - "NOT IN", - "IS NULL", - "IS NOT NULL", - } - - for i, condition in enumerate(conditions): - if not isinstance(condition, dict): - raise TypeError(f"Condition {i} must be a dictionary") - - missing_keys = required_keys - set(condition.keys()) - if missing_keys: - raise ValueError(f"Condition {i} missing required keys: {missing_keys}") - - if condition["operator"] not in valid_operators: - raise ValueError( - f"Condition {i} has invalid operator '{condition['operator']}'. 
" - f"Valid operators: {', '.join(sorted(valid_operators))}" - ) - - -def _build_update_query( - table_name: str, updates: dict[str, Any], conditions: list[dict[str, Any]], limit: int -) -> tuple[str, list[Any]]: - """Build UPDATE query with parameterized values.""" +class UpdateCondition(BaseModel): + """Pydantic model for update WHERE conditions.""" + + column: str = Field(..., min_length=1, description="Column name for the condition") + operator: str = Field(..., description="SQL operator for the condition") + value: Any = Field(..., description="Value for the condition") + + @field_validator("operator") + @classmethod + def validate_operator(cls, v: str) -> str: + """Validate SQL operator.""" + valid_operators = { + "=", + "!=", + "<>", + "<", + "<=", + ">", + ">=", + "LIKE", + "NOT LIKE", + "IN", + "NOT IN", + "IS NULL", + "IS NOT NULL", + } + if v not in valid_operators: + operators_str = ", ".join(sorted(valid_operators)) + raise ValueError(f"Invalid operator '{v}'. Valid operators: {operators_str}") + return v + + +class UpdateRequest(BaseModel): + """Pydantic model for validating update requests.""" + + table_name: str = Field(..., min_length=1, description="Name of the table to update records in") + updates: dict[str, Any] = Field( + ..., min_length=1, description="Dictionary of column-value pairs to update" + ) + conditions: list[UpdateCondition] = Field( + ..., min_length=1, description="List of WHERE conditions for safety" + ) + limit: int = Field(default=0, ge=0, description="Optional LIMIT for safety") + + @field_validator("table_name") + @classmethod + def validate_table_name(cls, v: str) -> str: + """Validate table name for security.""" + if not v.strip(): + raise ValueError("Table name cannot be empty") + + dangerous_keywords = [";", "--", "/*", "*/", "drop", "delete", "truncate"] + if any(keyword in v.lower() for keyword in dangerous_keywords): + raise ValueError("Invalid characters in table name") + + return v.strip() + + 
@field_validator("updates") + @classmethod + def validate_updates(cls, v: dict[str, Any]) -> dict[str, Any]: + """Validate updates dictionary.""" + if not v: + raise ValueError("Updates must be a non-empty dictionary") + + # Check for dangerous column names + for column in v: + if not isinstance(column, str) or not column.strip(): + raise ValueError("Column names must be non-empty strings") + + return v + + +def _build_update_query(request: UpdateRequest) -> tuple[str, list[Any]]: + """Build UPDATE query with parameterized values from validated request.""" # Build SET clause set_parts = [] values: list[Any] = [] - for column, value in updates.items(): + + for column, value in request.updates.items(): set_parts.append(f"{column} = ?") values.append(value) @@ -59,10 +94,10 @@ def _build_update_query( # Build WHERE clause where_parts = [] - for condition in conditions: - column = condition["column"] - operator = condition["operator"] - value = condition["value"] + for condition in request.conditions: + column = condition.column + operator = condition.operator + value = condition.value if operator in ("IS NULL", "IS NOT NULL"): where_parts.append(f"{column} {operator}") @@ -80,28 +115,42 @@ def _build_update_query( where_clause = "WHERE " + " AND ".join(where_parts) # Build complete query - query = f"UPDATE {table_name} {set_clause} {where_clause}" - if limit > 0: - query += f" LIMIT {limit}" + query = f"UPDATE {request.table_name} {set_clause} {where_clause}" + if request.limit > 0: + query += f" LIMIT {request.limit}" return query, values -def _validate_update_inputs( +def _create_update_request( table_name: str, parsed_updates: dict, parsed_conditions: list, limit: int -) -> None: - """Validate update inputs and raise appropriate errors.""" - if not table_name or not isinstance(table_name, str): - raise ValueError("table_name must be a non-empty string") - - if not parsed_updates: - raise ValueError("updates must be a non-empty dictionary") - - if not 
parsed_conditions: - raise TypeError("conditions must be a non-empty list") - - if limit < 0: - raise ValueError("limit must be non-negative (0 = no limit)") +) -> UpdateRequest: + """Create and validate UpdateRequest from parsed data.""" + try: + # Convert conditions to UpdateCondition models + condition_models = [ + UpdateCondition(column=cond["column"], operator=cond["operator"], value=cond["value"]) + for cond in parsed_conditions + ] + + return UpdateRequest( + table_name=table_name, + updates=parsed_updates, + conditions=condition_models, + limit=limit, + ) + except Exception as e: + # Convert Pydantic validation errors to more readable messages + error_msg = str(e) + if "String should have at least 1 character" in error_msg: + raise ValueError("Table name cannot be empty") from e + elif "List should have at least 1 item" in error_msg: + msg = "Update operations require at least one WHERE condition for safety" + raise ValueError(msg) from e + elif "Updates must be a non-empty dictionary" in error_msg: + raise ValueError("Updates must be a non-empty dictionary") from e + else: + raise ValueError(f"Validation error: {error_msg}") from e @tool(requires_secrets=["GIBSONAI_API_KEY"]) @@ -133,7 +182,7 @@ async def update_records( limit: Optional LIMIT clause for additional safety (0 = no limit) Returns: - A message indicating the number of records updated + A message indicating successful update Raises: ValueError: If no conditions provided or invalid conditions @@ -155,13 +204,11 @@ async def update_records( except json.JSONDecodeError as e: raise ValueError(f"Invalid JSON format for conditions: {e}") from e - # Validate inputs - _validate_update_inputs(table_name, parsed_updates, parsed_conditions, limit) - - _validate_update_conditions(parsed_conditions) + # Create and validate request using Pydantic + request = _create_update_request(table_name, parsed_updates, parsed_conditions, limit) # Build query with parameterized values - query, values = 
_build_update_query(table_name, parsed_updates, parsed_conditions, limit) + query, values = _build_update_query(request) # Execute update client = GibsonAIClient(context.get_secret("GIBSONAI_API_KEY")) diff --git a/toolkits/gibsonai/tests/test_gibsonai.py b/toolkits/gibsonai/tests/test_gibsonai.py index d14038237..c9fdcd3a9 100644 --- a/toolkits/gibsonai/tests/test_gibsonai.py +++ b/toolkits/gibsonai/tests/test_gibsonai.py @@ -192,13 +192,16 @@ async def test_update_records_validation_errors(): mock_context.get_secret.return_value = "test_api_key" # Test missing conditions - with pytest.raises(RetryableToolError, match="conditions must be a non-empty list"): + with pytest.raises( + RetryableToolError, + match="Update operations require at least one WHERE condition for safety", + ): await update_records( context=mock_context, table_name="users", updates='{"name": "Johnny"}', conditions="[]" ) # Test invalid table name - with pytest.raises(RetryableToolError, match="table_name must be a non-empty string"): + with pytest.raises(RetryableToolError, match="Table name cannot be empty"): await update_records( context=mock_context, table_name="",