diff --git a/autobot-backend/services/autoresearch/routes_test.py b/autobot-backend/services/autoresearch/routes_test.py new file mode 100644 index 000000000..7639bac98 --- /dev/null +++ b/autobot-backend/services/autoresearch/routes_test.py @@ -0,0 +1,457 @@ +# AutoBot - AI-Powered Automation Platform +# Copyright (c) 2025 mrveiss +# Author: mrveiss +""" +AutoResearch Routes Integration Tests + +Issue #2637: Tests for REST API endpoints covering experiment CRUD, +baseline management, status, and cancellation. +""" + +from __future__ import annotations + +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from services.autoresearch.models import ( + Experiment, + ExperimentState, + ExperimentStats, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# App + client setup (auth middleware bypassed for unit tests) +# --------------------------------------------------------------------------- + + +def _build_app( + store: AsyncMock | None = None, + runner: MagicMock | None = None, +) -> FastAPI: + """Build a FastAPI app with the autoresearch router and mocked deps.""" + with patch("auth_middleware.check_admin_permission", return_value=True): + from services.autoresearch.routes import router + + app = FastAPI() + app.include_router(router, prefix="/autoresearch") + + if store is not None: + app.state.autoresearch_store = store + if runner is not None: + app.state.autoresearch_runner = runner + + return app + + +def _make_store() -> AsyncMock: + """Build a mock ExperimentStore.""" + store = AsyncMock() + store.save_experiment = AsyncMock() + store.get_experiment = AsyncMock(return_value=None) + store.list_experiments = AsyncMock(return_value=[]) + store.get_baseline = AsyncMock(return_value=None) + store.set_baseline = AsyncMock() + store.get_stats = AsyncMock( + 
return_value=ExperimentStats(total_experiments=0) + ) + return store + + +def _make_runner(is_running: bool = False) -> MagicMock: + """Build a mock ExperimentRunner.""" + runner = MagicMock() + runner.is_running = is_running + runner.run_experiment = AsyncMock() + runner.cancel = AsyncMock() + return runner + + +def _sample_experiment(**overrides) -> Experiment: + """Build a sample experiment for test responses.""" + defaults = { + "id": "exp-001", + "hypothesis": "Test hypothesis", + "state": ExperimentState.PENDING, + } + defaults.update(overrides) + return Experiment(**defaults) + + +# --------------------------------------------------------------------------- +# POST /autoresearch/experiments +# --------------------------------------------------------------------------- + + +class TestCreateExperiment: + """Tests for POST /autoresearch/experiments.""" + + def test_create_experiment_success(self): + store = _make_store() + runner = _make_runner(is_running=False) + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments", + json={ + "hypothesis": "Increase learning rate", + "description": "Test higher LR", + "tags": ["lr_sweep"], + }, + ) + + assert response.status_code == 200 + data = response.json() + assert "id" in data + assert data["state"] == "pending" + store.save_experiment.assert_called_once() + + def test_create_experiment_with_hyperparams(self): + store = _make_store() + runner = _make_runner(is_running=False) + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments", + json={ + "hypothesis": "Custom LR", + "hyperparams": {"learning_rate": 1e-2, "batch_size": 128}, + }, + ) + + assert response.status_code == 200 + + def test_create_experiment_minimal_payload(self): + store = _make_store() + runner = _make_runner(is_running=False) + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + 
response = client.post( + "/autoresearch/experiments", + json={}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["state"] == "pending" + + def test_create_experiment_conflict_when_running(self): + store = _make_store() + runner = _make_runner(is_running=True) + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments", + json={"hypothesis": "Won't run"}, + ) + + assert response.status_code == 409 + assert "already running" in response.json()["detail"].lower() + + def test_create_experiment_hypothesis_too_long(self): + store = _make_store() + runner = _make_runner(is_running=False) + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments", + json={"hypothesis": "x" * 1001}, + ) + + assert response.status_code == 422 + + def test_create_experiment_description_too_long(self): + store = _make_store() + runner = _make_runner(is_running=False) + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments", + json={"description": "x" * 5001}, + ) + + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /autoresearch/experiments +# --------------------------------------------------------------------------- + + +class TestListExperiments: + """Tests for GET /autoresearch/experiments.""" + + def test_list_empty(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments") + + assert response.status_code == 200 + data = response.json() + assert data["experiments"] == [] + assert data["count"] == 0 + assert data["offset"] == 0 + + def test_list_returns_experiments(self): + store = _make_store() + exp1 = _sample_experiment(id="exp-1") + exp2 = 
_sample_experiment(id="exp-2") + store.list_experiments.return_value = [exp1, exp2] + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments") + + assert response.status_code == 200 + data = response.json() + assert data["count"] == 2 + assert data["experiments"][0]["id"] == "exp-1" + + def test_list_with_state_filter(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + client.get("/autoresearch/experiments?state=kept") + + call_kwargs = store.list_experiments.call_args[1] + assert call_kwargs["state"] == ExperimentState.KEPT + + def test_list_with_invalid_state(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments?state=bogus") + + assert response.status_code == 400 + assert "Invalid state" in response.json()["detail"] + + def test_list_with_pagination(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + client.get("/autoresearch/experiments?limit=10&offset=5") + + call_kwargs = store.list_experiments.call_args[1] + assert call_kwargs["limit"] == 10 + assert call_kwargs["offset"] == 5 + + def test_list_limit_out_of_range(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments?limit=999") + assert response.status_code == 422 + + def test_list_negative_offset(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments?offset=-1") + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /autoresearch/experiments/{experiment_id} +# 
--------------------------------------------------------------------------- + + +class TestGetExperiment: + """Tests for GET /autoresearch/experiments/{experiment_id}.""" + + def test_get_existing_experiment(self): + store = _make_store() + exp = _sample_experiment(id="exp-001", hypothesis="found it") + store.get_experiment.return_value = exp + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments/exp-001") + + assert response.status_code == 200 + data = response.json() + assert data["id"] == "exp-001" + assert data["hypothesis"] == "found it" + + def test_get_nonexistent_experiment(self): + store = _make_store() + store.get_experiment.return_value = None + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments/nonexistent") + + assert response.status_code == 404 + assert "not found" in response.json()["detail"].lower() + + +# --------------------------------------------------------------------------- +# GET /autoresearch/experiments/stats +# --------------------------------------------------------------------------- + + +class TestGetStats: + """Tests for GET /autoresearch/experiments/stats.""" + + def test_get_stats_empty(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments/stats") + + assert response.status_code == 200 + data = response.json() + assert data["total_experiments"] == 0 + + def test_get_stats_with_data(self): + store = _make_store() + store.get_stats.return_value = ExperimentStats( + total_experiments=10, + completed=5, + kept=3, + failed=2, + best_val_bpb=5.2, + ) + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.get("/autoresearch/experiments/stats") + + assert response.status_code == 200 + data = response.json() + assert 
data["total_experiments"] == 10 + assert data["kept"] == 3 + assert data["best_val_bpb"] == 5.2 + + +# --------------------------------------------------------------------------- +# POST /autoresearch/experiments/baseline +# --------------------------------------------------------------------------- + + +class TestSetBaseline: + """Tests for POST /autoresearch/experiments/baseline.""" + + def test_set_baseline_success(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments/baseline", + json={"val_bpb": 6.0}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["baseline_val_bpb"] == 6.0 + store.set_baseline.assert_called_once_with(6.0) + + def test_set_baseline_missing_field(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments/baseline", + json={}, + ) + + assert response.status_code == 422 + + def test_set_baseline_invalid_type(self): + store = _make_store() + app = _build_app(store=store, runner=_make_runner()) + client = TestClient(app) + + response = client.post( + "/autoresearch/experiments/baseline", + json={"val_bpb": "not_a_number"}, + ) + + assert response.status_code == 422 + + +# --------------------------------------------------------------------------- +# GET /autoresearch/status +# --------------------------------------------------------------------------- + + +class TestGetStatus: + """Tests for GET /autoresearch/status.""" + + def test_status_idle(self): + store = _make_store() + runner = _make_runner(is_running=False) + store.get_baseline.return_value = None + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.get("/autoresearch/status") + + assert response.status_code == 200 + data = response.json() + assert data["running"] is False + assert 
data["baseline_val_bpb"] is None + + def test_status_running_with_baseline(self): + store = _make_store() + runner = _make_runner(is_running=True) + store.get_baseline.return_value = 6.0 + app = _build_app(store=store, runner=runner) + client = TestClient(app) + + response = client.get("/autoresearch/status") + + assert response.status_code == 200 + data = response.json() + assert data["running"] is True + assert data["baseline_val_bpb"] == 6.0 + + +# --------------------------------------------------------------------------- +# POST /autoresearch/cancel +# --------------------------------------------------------------------------- + + +class TestCancelExperiment: + """Tests for POST /autoresearch/cancel.""" + + def test_cancel_running_experiment(self): + runner = _make_runner(is_running=True) + app = _build_app(store=_make_store(), runner=runner) + client = TestClient(app) + + response = client.post("/autoresearch/cancel") + + assert response.status_code == 200 + assert response.json()["status"] == "cancelled" + runner.cancel.assert_called_once() + + def test_cancel_when_not_running(self): + runner = _make_runner(is_running=False) + app = _build_app(store=_make_store(), runner=runner) + client = TestClient(app) + + response = client.post("/autoresearch/cancel") + + assert response.status_code == 409 + assert "not currently running" in response.json()["detail"].lower() diff --git a/autobot-backend/services/autoresearch/runner_test.py b/autobot-backend/services/autoresearch/runner_test.py new file mode 100644 index 000000000..d74058506 --- /dev/null +++ b/autobot-backend/services/autoresearch/runner_test.py @@ -0,0 +1,616 @@ +# AutoBot - AI-Powered Automation Platform +# Copyright (c) 2025 mrveiss +# Author: mrveiss +""" +ExperimentRunner Unit Tests + +Issue #2637: Comprehensive tests for subprocess execution, timeout handling, +cancellation, concurrent run rejection, evaluation logic, and parameter +validation. 
+""" + +from __future__ import annotations + +import asyncio +import logging +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from services.autoresearch.config import AutoResearchConfig +from services.autoresearch.models import ( + Experiment, + ExperimentResult, + ExperimentState, + HyperParams, +) +from services.autoresearch.runner import ( + ExperimentRunner, + _EXTRA_KEY_PATTERN, + _RESERVED_KEYS, +) + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _make_config(**overrides) -> AutoResearchConfig: + """Build a test config with sensible defaults.""" + defaults = { + "default_training_timeout": 10, + "improvement_threshold": 0.01, + } + defaults.update(overrides) + return AutoResearchConfig(**defaults) + + +def _make_store() -> AsyncMock: + """Build a mock ExperimentStore.""" + store = AsyncMock() + store.save_experiment = AsyncMock() + store.get_baseline = AsyncMock(return_value=None) + store.set_baseline = AsyncMock() + return store + + +def _make_parser(result: ExperimentResult | None = None) -> MagicMock: + """Build a mock parser returning a predetermined result.""" + parser = MagicMock() + if result is None: + result = ExperimentResult(val_bpb=5.5, steps_completed=5000) + parser.parse.return_value = result + return parser + + +def _make_runner( + config: AutoResearchConfig | None = None, + store: AsyncMock | None = None, + parser: MagicMock | None = None, +) -> ExperimentRunner: + """Build an ExperimentRunner with all I/O mocked.""" + return ExperimentRunner( + config=config or _make_config(), + store=store or _make_store(), + parser=parser or _make_parser(), + ) + + +def _make_experiment(**overrides) -> Experiment: + """Build a test Experiment.""" + defaults = { + "hypothesis": "test hypothesis", + "state": ExperimentState.PENDING, + } + 
defaults.update(overrides) + return Experiment(**defaults) + + +# --------------------------------------------------------------------------- +# _validate_extra_params tests +# --------------------------------------------------------------------------- + + +class TestValidateExtraParams: + """Tests for ExperimentRunner._validate_extra_params.""" + + def test_valid_params_accepted(self): + ExperimentRunner._validate_extra_params( + {"custom_lr": 0.001, "warmup_ratio": 0.1} + ) + + def test_reserved_key_rejected(self): + for key in ("max_steps", "learning_rate", "batch_size"): + with pytest.raises(ValueError, match="conflicts with a built-in flag"): + ExperimentRunner._validate_extra_params({key: 42}) + + def test_uppercase_key_rejected(self): + with pytest.raises(ValueError, match="must be lowercase"): + ExperimentRunner._validate_extra_params({"BadKey": 1}) + + def test_key_starting_with_digit_rejected(self): + with pytest.raises(ValueError, match="must be lowercase"): + ExperimentRunner._validate_extra_params({"1bad": 1}) + + def test_key_with_dash_rejected(self): + with pytest.raises(ValueError, match="must be lowercase"): + ExperimentRunner._validate_extra_params({"bad-key": 1}) + + def test_key_too_long_rejected(self): + long_key = "a" * 65 + with pytest.raises(ValueError, match="must be lowercase"): + ExperimentRunner._validate_extra_params({long_key: 1}) + + def test_non_scalar_value_rejected(self): + with pytest.raises(ValueError, match="must be scalar"): + ExperimentRunner._validate_extra_params({"key": [1, 2, 3]}) + + def test_dict_value_rejected(self): + with pytest.raises(ValueError, match="must be scalar"): + ExperimentRunner._validate_extra_params({"key": {"nested": True}}) + + def test_string_with_double_dash_rejected(self): + with pytest.raises(ValueError, match="cannot contain '--'"): + ExperimentRunner._validate_extra_params({"key": "--inject"}) + + def test_string_too_long_rejected(self): + with pytest.raises(ValueError, match="256 chars"): + 
ExperimentRunner._validate_extra_params({"key": "x" * 257}) + + def test_bool_value_accepted(self): + ExperimentRunner._validate_extra_params({"use_flash": True}) + + def test_int_value_accepted(self): + ExperimentRunner._validate_extra_params({"seed": 42}) + + def test_empty_dict_accepted(self): + ExperimentRunner._validate_extra_params({}) + + def test_all_reserved_keys_are_lowercase(self): + """Sanity check: all reserved keys match the key pattern format.""" + for key in _RESERVED_KEYS: + assert _EXTRA_KEY_PATTERN.match(key), f"Reserved key '{key}' doesn't match pattern" + + +# --------------------------------------------------------------------------- +# _build_command tests +# --------------------------------------------------------------------------- + + +class TestBuildCommand: + """Tests for ExperimentRunner._build_command.""" + + def test_basic_command_structure(self): + runner = _make_runner() + exp = _make_experiment() + cmd = runner._build_command(exp) + + assert cmd[0] == runner.config.python_bin + assert str(runner.config.train_script) in cmd[1] + assert any("--max_steps=" in arg for arg in cmd) + assert any("--learning_rate=" in arg for arg in cmd) + + def test_extra_params_appended(self): + runner = _make_runner() + exp = _make_experiment( + hyperparams=HyperParams(extra={"seed": 42, "use_flash": True}) + ) + cmd = runner._build_command(exp) + + assert "--seed=42" in cmd + assert "--use_flash=True" in cmd + + def test_custom_python_executable(self): + config = _make_config() + config.python_executable = "/usr/bin/python3.12" + runner = _make_runner(config=config) + exp = _make_experiment() + cmd = runner._build_command(exp) + + assert cmd[0] == "/usr/bin/python3.12" + + def test_all_hyperparams_included(self): + runner = _make_runner() + hp = HyperParams() + exp = _make_experiment(hyperparams=hp) + cmd = runner._build_command(exp) + + expected_flags = [ + "max_steps", "learning_rate", "batch_size", "block_size", + "n_layer", "n_head", "n_embd", 
"dropout", + "warmup_steps", "weight_decay", + ] + for flag in expected_flags: + assert any(f"--{flag}=" in arg for arg in cmd), ( + f"Missing --{flag} in command" + ) + + +# --------------------------------------------------------------------------- +# _execute_training tests (subprocess mocking) +# --------------------------------------------------------------------------- + + +class TestExecuteTraining: + """Tests for ExperimentRunner._execute_training with mocked subprocess.""" + + @pytest.mark.asyncio + async def test_successful_training(self): + expected_result = ExperimentResult(val_bpb=5.5, steps_completed=5000) + parser = _make_parser(result=expected_result) + runner = _make_runner(parser=parser) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock( + return_value=(b"step 5000 | train loss 4.0 | val loss 4.1 | val_bpb 5.5\n", None) + ) + mock_process.returncode = 0 + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + result = await runner._execute_training(exp) + + assert result.val_bpb == 5.5 + parser.parse.assert_called_once() + + @pytest.mark.asyncio + async def test_nonzero_exit_code(self): + runner = _make_runner() + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock( + return_value=(b"segfault", None) + ) + mock_process.returncode = 1 + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + result = await runner._execute_training(exp) + + assert not result.success + assert "exited with code 1" in result.error_message + assert result.raw_output == "segfault" + + @pytest.mark.asyncio + async def test_timeout_kills_process(self): + config = _make_config(default_training_timeout=1) + runner = _make_runner(config=config) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock( + side_effect=asyncio.TimeoutError() + ) + mock_process.kill = AsyncMock() + mock_process.wait = 
AsyncMock() + mock_process.returncode = None + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + runner._current_process = mock_process + result = await runner._execute_training(exp) + + assert not result.success + assert "timed out" in result.error_message + mock_process.kill.assert_called_once() + + @pytest.mark.asyncio + async def test_empty_stdout_handled(self): + parser = _make_parser(result=ExperimentResult(error_message="Empty training output")) + runner = _make_runner(parser=parser) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock(return_value=(b"", None)) + mock_process.returncode = 0 + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + result = await runner._execute_training(exp) + + # Parser called with empty string + parser.parse.assert_called_once() + call_args = parser.parse.call_args + assert call_args[0][0] == "" + + @pytest.mark.asyncio + async def test_none_stdout_handled(self): + parser = _make_parser(result=ExperimentResult(error_message="Empty training output")) + runner = _make_runner(parser=parser) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock(return_value=(None, None)) + mock_process.returncode = 0 + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + result = await runner._execute_training(exp) + + parser.parse.assert_called_once() + call_args = parser.parse.call_args + assert call_args[0][0] == "" + + +# --------------------------------------------------------------------------- +# run_experiment tests (full flow) +# --------------------------------------------------------------------------- + + +class TestRunExperiment: + """Tests for ExperimentRunner.run_experiment end-to-end flow.""" + + @pytest.mark.asyncio + async def test_successful_run_sets_completed(self): + store = _make_store() + result = ExperimentResult(val_bpb=5.5, steps_completed=5000) + parser = 
_make_parser(result=result) + runner = _make_runner(store=store, parser=parser) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock(return_value=(b"ok", None)) + mock_process.returncode = 0 + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + updated = await runner.run_experiment(exp) + + assert updated.result is not None + assert updated.result.val_bpb == 5.5 + assert updated.started_at is not None + assert updated.completed_at is not None + # Store should be called twice: once for RUNNING, once for final state + assert store.save_experiment.call_count == 2 + + @pytest.mark.asyncio + async def test_failed_run_sets_failed_state(self): + store = _make_store() + runner = _make_runner(store=store) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock(return_value=(b"error", None)) + mock_process.returncode = 1 + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + updated = await runner.run_experiment(exp) + + assert updated.state == ExperimentState.FAILED + assert updated.result.error_message is not None + + @pytest.mark.asyncio + async def test_exception_during_training_sets_failed(self): + store = _make_store() + runner = _make_runner(store=store) + exp = _make_experiment() + + with patch( + "asyncio.create_subprocess_exec", + side_effect=OSError("No such file"), + ): + updated = await runner.run_experiment(exp) + + assert updated.state == ExperimentState.FAILED + assert "No such file" in updated.result.error_message + + @pytest.mark.asyncio + async def test_running_flag_cleared_after_completion(self): + store = _make_store() + result = ExperimentResult(val_bpb=5.5) + runner = _make_runner(store=store, parser=_make_parser(result=result)) + exp = _make_experiment() + + mock_process = AsyncMock() + mock_process.communicate = AsyncMock(return_value=(b"ok", None)) + mock_process.returncode = 0 + + with 
patch("asyncio.create_subprocess_exec", return_value=mock_process): + await runner.run_experiment(exp) + + assert runner.is_running is False + assert runner._current_process is None + + @pytest.mark.asyncio + async def test_running_flag_cleared_after_failure(self): + store = _make_store() + runner = _make_runner(store=store) + exp = _make_experiment() + + with patch( + "asyncio.create_subprocess_exec", + side_effect=RuntimeError("boom"), + ): + await runner.run_experiment(exp) + + assert runner.is_running is False + + @pytest.mark.asyncio + async def test_concurrent_run_rejected(self): + store = _make_store() + runner = _make_runner(store=store) + # Simulate already running + runner._running = True + + exp = _make_experiment() + with pytest.raises(RuntimeError, match="already running"): + await runner.run_experiment(exp) + + @pytest.mark.asyncio + async def test_cancellation_sets_failed_state(self): + store = _make_store() + runner = _make_runner(store=store) + exp = _make_experiment() + + async def _slow_communicate(): + await asyncio.sleep(10) + return (b"", None) + + mock_process = AsyncMock() + mock_process.communicate = _slow_communicate + mock_process.returncode = None + mock_process.kill = MagicMock() + mock_process.wait = AsyncMock() + + with patch("asyncio.create_subprocess_exec", return_value=mock_process): + task = asyncio.create_task(runner.run_experiment(exp)) + # Give the task a moment to start + await asyncio.sleep(0.05) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + assert exp.state == ExperimentState.FAILED + assert exp.result is not None + assert "cancelled" in exp.result.error_message.lower() + assert runner.is_running is False + + +# --------------------------------------------------------------------------- +# _evaluate_result tests +# --------------------------------------------------------------------------- + + +class TestEvaluateResult: + """Tests for ExperimentRunner._evaluate_result decision logic.""" + + 
@pytest.mark.asyncio + async def test_first_experiment_becomes_baseline(self): + store = _make_store() + store.get_baseline.return_value = None + runner = _make_runner(store=store) + + exp = _make_experiment( + state=ExperimentState.COMPLETED, + result=ExperimentResult(val_bpb=5.5), + ) + await runner._evaluate_result(exp) + + assert exp.state == ExperimentState.KEPT + assert exp.baseline_val_bpb == 5.5 + store.set_baseline.assert_called_once_with(5.5) + + @pytest.mark.asyncio + async def test_improvement_above_threshold_kept(self): + store = _make_store() + store.get_baseline.return_value = 6.0 + config = _make_config(improvement_threshold=0.01) + runner = _make_runner(config=config, store=store) + + exp = _make_experiment( + state=ExperimentState.COMPLETED, + result=ExperimentResult(val_bpb=5.5), + ) + await runner._evaluate_result(exp) + + assert exp.state == ExperimentState.KEPT + assert exp.baseline_val_bpb == 6.0 + store.set_baseline.assert_called_once_with(5.5) + + @pytest.mark.asyncio + async def test_improvement_below_threshold_discarded(self): + store = _make_store() + store.get_baseline.return_value = 6.0 + config = _make_config(improvement_threshold=1.0) + runner = _make_runner(config=config, store=store) + + exp = _make_experiment( + state=ExperimentState.COMPLETED, + result=ExperimentResult(val_bpb=5.5), + ) + await runner._evaluate_result(exp) + + assert exp.state == ExperimentState.DISCARDED + assert exp.baseline_val_bpb == 6.0 + store.set_baseline.assert_not_called() + + @pytest.mark.asyncio + async def test_worse_result_discarded(self): + store = _make_store() + store.get_baseline.return_value = 5.0 + runner = _make_runner(store=store) + + exp = _make_experiment( + state=ExperimentState.COMPLETED, + result=ExperimentResult(val_bpb=5.5), + ) + await runner._evaluate_result(exp) + + assert exp.state == ExperimentState.DISCARDED + + @pytest.mark.asyncio + async def test_no_result_skips_evaluation(self): + store = _make_store() + runner = 
_make_runner(store=store) + + exp = _make_experiment(state=ExperimentState.COMPLETED, result=None) + await runner._evaluate_result(exp) + + # State should remain COMPLETED (not changed to KEPT or DISCARDED) + assert exp.state == ExperimentState.COMPLETED + store.get_baseline.assert_not_called() + + @pytest.mark.asyncio + async def test_no_val_bpb_skips_evaluation(self): + store = _make_store() + runner = _make_runner(store=store) + + exp = _make_experiment( + state=ExperimentState.COMPLETED, + result=ExperimentResult(val_bpb=None), + ) + await runner._evaluate_result(exp) + + assert exp.state == ExperimentState.COMPLETED + store.get_baseline.assert_not_called() + + @pytest.mark.asyncio + async def test_exact_threshold_kept(self): + """Improvement exactly equal to threshold should be KEPT.""" + store = _make_store() + store.get_baseline.return_value = 6.0 + config = _make_config(improvement_threshold=0.5) + runner = _make_runner(config=config, store=store) + + exp = _make_experiment( + state=ExperimentState.COMPLETED, + result=ExperimentResult(val_bpb=5.5), + ) + await runner._evaluate_result(exp) + + assert exp.state == ExperimentState.KEPT + + +# --------------------------------------------------------------------------- +# cancel tests +# --------------------------------------------------------------------------- + + +class TestCancel: + """Tests for ExperimentRunner.cancel.""" + + @pytest.mark.asyncio + async def test_cancel_kills_running_process(self): + runner = _make_runner() + mock_process = MagicMock() + mock_process.returncode = None + mock_process.kill = MagicMock() + runner._current_process = mock_process + + await runner.cancel() + mock_process.kill.assert_called_once() + + @pytest.mark.asyncio + async def test_cancel_noop_when_no_process(self): + runner = _make_runner() + runner._current_process = None + # Should not raise + await runner.cancel() + + @pytest.mark.asyncio + async def test_cancel_noop_when_process_finished(self): + runner = 
        # (continuation of TestCancel.test_cancel_noop_when_process_finished)
        mock_process = MagicMock()
        mock_process.returncode = 0  # already finished
        mock_process.kill = MagicMock()
        runner._current_process = mock_process

        await runner.cancel()
        # A finished process must not be killed again.
        mock_process.kill.assert_not_called()


# ---------------------------------------------------------------------------
# is_running property tests
# ---------------------------------------------------------------------------


class TestIsRunning:
    """Tests for ExperimentRunner.is_running property."""

    def test_initially_not_running(self):
        runner = _make_runner()
        assert runner.is_running is False

    def test_reflects_internal_flag(self):
        # The public property mirrors the private _running flag.
        runner = _make_runner()
        runner._running = True
        assert runner.is_running is True
diff --git a/autobot-backend/services/autoresearch/store_chromadb_test.py b/autobot-backend/services/autoresearch/store_chromadb_test.py
new file mode 100644
index 000000000..e146c6abf
--- /dev/null
+++ b/autobot-backend/services/autoresearch/store_chromadb_test.py
# AutoBot - AI-Powered Automation Platform
# Copyright (c) 2025 mrveiss
# Author: mrveiss
"""
ExperimentStore ChromaDB Indexing Tests

Issue #2637: Tests for _index_in_chromadb, _build_document, _build_metadata,
and the conditional indexing logic in save_experiment.
"""

from __future__ import annotations

import logging
from unittest.mock import AsyncMock, patch

import pytest

from services.autoresearch.config import AutoResearchConfig
from services.autoresearch.models import (
    Experiment,
    ExperimentResult,
    ExperimentState,
)
from services.autoresearch.store import ExperimentStore

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------


def _make_store(
    mock_redis: AsyncMock | None = None,
    mock_collection: AsyncMock | None = None,
) -> ExperimentStore:
    """Build an ExperimentStore with mocked Redis and ChromaDB."""
    store = ExperimentStore(AutoResearchConfig())
    if mock_redis is None:
        # Default Redis mock: every accessor the store uses is stubbed so the
        # save/load paths run without a live connection.
        mock_redis = AsyncMock()
        mock_redis.hset = AsyncMock()
        mock_redis.hget = AsyncMock(return_value=None)
        mock_redis.zadd = AsyncMock()
        mock_redis.sadd = AsyncMock()
        mock_redis.srem = AsyncMock()
        mock_redis.get = AsyncMock(return_value=None)
        mock_redis.set = AsyncMock()
    store._redis = mock_redis

    # Leaving the collection unset exercises the lazy-init path in
    # _index_in_chromadb.
    if mock_collection is not None:
        store._chromadb_collection = mock_collection

    return store


def _make_experiment(
    state: ExperimentState = ExperimentState.COMPLETED,
    val_bpb: float | None = 5.5,
    hypothesis: str = "Test higher learning rate",
    description: str = "Increase LR from 3e-4 to 1e-3",
    code_diff: str = "",
    tags: list | None = None,
    baseline: float | None = 6.0,
) -> Experiment:
    """Build a test experiment with optional result."""
    exp = Experiment(
        hypothesis=hypothesis,
        description=description,
        code_diff=code_diff,
        tags=tags or [],
        state=state,
        baseline_val_bpb=baseline,
    )
    # A result is attached only when val_bpb is given; pass None to model an
    # experiment that never produced metrics.
    if val_bpb is not None:
        exp.result = ExperimentResult(val_bpb=val_bpb)
    return exp


# ---------------------------------------------------------------------------
# _build_document tests
# ---------------------------------------------------------------------------


class TestBuildDocument:
    """Tests for ExperimentStore._build_document."""

    def test_includes_hypothesis(self):
        store = _make_store()
        experiment = _make_experiment(hypothesis="Increase dropout")
        document = store._build_document(experiment)
        assert "Increase dropout" in document

    def test_includes_description(self):
        store = _make_store()
        experiment = _make_experiment(
            description="Raise dropout from 0.2 to 0.3"
        )
        document = store._build_document(experiment)
        assert "Raise dropout" in document

    def test_includes_val_bpb(self):
        store = _make_store()
        document = store._build_document(_make_experiment(val_bpb=5.5))
        assert "5.5" in document

    def test_includes_improvement_when_available(self):
        store = _make_store()
        experiment = _make_experiment(val_bpb=5.5, baseline=6.0)
        document = store._build_document(experiment)
        # 6.0 - 5.5 rendered with four decimal places.
        assert "Improvement" in document
        assert "0.5000" in document

    def test_no_improvement_without_baseline(self):
        store = _make_store()
        experiment = _make_experiment(val_bpb=5.5, baseline=None)
        assert "Improvement" not in store._build_document(experiment)

    def test_includes_truncated_code_diff(self):
        store = _make_store()
        oversized_diff = "x" * 600  # longer than the 500-char cap
        experiment = _make_experiment(code_diff=oversized_diff)
        document = store._build_document(experiment)
        assert "Code change:" in document
        # Should be truncated to 500 chars
        assert len(document.split("Code change:\n")[1]) == 500

    def test_no_code_diff_section_when_empty(self):
        store = _make_store()
        experiment = _make_experiment(code_diff="")
        assert "Code change:" not in store._build_document(experiment)

    def test_no_result_omits_val_bpb(self):
        store = _make_store()
        experiment = _make_experiment(val_bpb=None)
        experiment.result = None
        assert "val_bpb" not in store._build_document(experiment)


# ---------------------------------------------------------------------------
# _build_metadata tests
# ---------------------------------------------------------------------------


class TestBuildMetadata:
    """Tests for ExperimentStore._build_metadata."""

    def test_includes_state(self):
        store = _make_store()
        experiment = _make_experiment(state=ExperimentState.KEPT)
        metadata = store._build_metadata(experiment)
        assert metadata["state"] == "kept"

    def test_includes_created_at(self):
        store = _make_store()
        metadata = store._build_metadata(_make_experiment())
        # Stored as an epoch float so the vector store can range-filter on it.
        assert "created_at" in metadata
        assert isinstance(metadata["created_at"], float)

    def test_includes_val_bpb_when_available(self):
        store = _make_store()
        metadata = store._build_metadata(_make_experiment(val_bpb=5.5))
        assert metadata["val_bpb"] == 5.5

    def test_no_val_bpb_when_no_result(self):
        store = _make_store()
        experiment = _make_experiment(val_bpb=None)
        experiment.result = None
        metadata = store._build_metadata(experiment)
        assert "val_bpb" not in metadata

    def test_includes_improvement(self):
        store = _make_store()
        experiment = _make_experiment(val_bpb=5.5, baseline=6.0)
        metadata = store._build_metadata(experiment)
        assert metadata["improvement"] == 0.5

    def test_no_improvement_without_baseline(self):
        store = _make_store()
        experiment = _make_experiment(val_bpb=5.5, baseline=None)
        metadata = store._build_metadata(experiment)
        assert "improvement" not in metadata

    def test_includes_tags_as_csv(self):
        store = _make_store()
        experiment = _make_experiment(tags=["lr_sweep", "dropout"])
        metadata = store._build_metadata(experiment)
        assert metadata["tags"] == "lr_sweep,dropout"

    def test_no_tags_key_when_empty(self):
        store = _make_store()
        metadata = store._build_metadata(_make_experiment(tags=[]))
        assert "tags" not in metadata


# ---------------------------------------------------------------------------
# _index_in_chromadb tests
# ---------------------------------------------------------------------------


class TestIndexInChromadb:
    """Tests for ExperimentStore._index_in_chromadb."""

    @pytest.mark.asyncio
    async def test_upserts_to_collection(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(val_bpb=5.5)

        await store._index_in_chromadb(exp)

        # upsert is called with keyword args: parallel ids/documents/metadatas
        # lists of length 1.
        collection.upsert.assert_called_once()
        call_kwargs = collection.upsert.call_args[1]
        assert call_kwargs["ids"] == [exp.id]
        assert len(call_kwargs["documents"]) == 1
        assert len(call_kwargs["metadatas"]) == 1

    @pytest.mark.asyncio
    async def test_document_contains_hypothesis(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(hypothesis="Reduce block size")

        await store._index_in_chromadb(exp)

        doc = collection.upsert.call_args[1]["documents"][0]
        assert "Reduce block size" in doc

    @pytest.mark.asyncio
    async def test_metadata_contains_state(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(state=ExperimentState.KEPT)

        await store._index_in_chromadb(exp)

        meta = collection.upsert.call_args[1]["metadatas"][0]
        assert meta["state"] == "kept"

    @pytest.mark.asyncio
    async def test_chromadb_error_logged_not_raised(self):
        """ChromaDB failures should be logged but not propagate."""
        collection = AsyncMock()
        collection.upsert = AsyncMock(
            side_effect=RuntimeError("ChromaDB down")
        )
        store = _make_store(mock_collection=collection)
        exp = _make_experiment()

        # Should not raise
        await store._index_in_chromadb(exp)

    @pytest.mark.asyncio
    async def test_lazy_init_chromadb_on_first_call(self):
        """When _chromadb_collection is None, _get_chromadb is called."""
        store = _make_store()
        store._chromadb_collection = None  # force lazy init

        mock_collection = AsyncMock()
        mock_collection.upsert = AsyncMock()
        mock_client = AsyncMock()
        mock_client.get_or_create_collection = AsyncMock(
            return_value=mock_collection
        )

        # Patch the client factory so lazy init resolves to our mocks.
        with patch(
            "utils.chromadb_client.get_async_chromadb_client",
            new_callable=AsyncMock,
            return_value=mock_client,
        ) as mock_get_client:
            exp = _make_experiment()
            await store._index_in_chromadb(exp)

        # Lazy init happened exactly once, and the indexed write went through.
        mock_get_client.assert_called_once()
        mock_client.get_or_create_collection.assert_called_once()
        mock_collection.upsert.assert_called_once()


# ---------------------------------------------------------------------------
# save_experiment conditional indexing tests
# ---------------------------------------------------------------------------


class TestSaveExperimentIndexing:
    """Tests for the ChromaDB indexing trigger in save_experiment."""

    @pytest.mark.asyncio
    async def test_completed_experiment_indexed(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(state=ExperimentState.COMPLETED, val_bpb=5.5)

        await store.save_experiment(exp)

        collection.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_kept_experiment_indexed(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(state=ExperimentState.KEPT, val_bpb=5.5)

        await store.save_experiment(exp)

        collection.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_failed_experiment_not_indexed(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(state=ExperimentState.FAILED, val_bpb=None)
        exp.result = None

        await store.save_experiment(exp)

        collection.upsert.assert_not_called()

    @pytest.mark.asyncio
    async def test_pending_experiment_not_indexed(self):
        collection = AsyncMock()
        collection.upsert = AsyncMock()
        store = _make_store(mock_collection=collection)
        exp = _make_experiment(state=ExperimentState.PENDING, val_bpb=None)
        exp.result = None

        await store.save_experiment(exp)

        collection.upsert.assert_not_called()

    @pytest.mark.asyncio
def test_discarded_experiment_not_indexed(self): + collection = AsyncMock() + collection.upsert = AsyncMock() + store = _make_store(mock_collection=collection) + exp = _make_experiment(state=ExperimentState.DISCARDED, val_bpb=5.5) + + await store.save_experiment(exp) + + collection.upsert.assert_not_called() + + @pytest.mark.asyncio + async def test_state_transition_cleans_old_index(self): + """When old_state differs from current, srem is called.""" + store = _make_store() + exp = _make_experiment(state=ExperimentState.KEPT, val_bpb=5.5) + + await store.save_experiment(exp, old_state=ExperimentState.RUNNING) + + store._redis.srem.assert_called_once() + call_args = store._redis.srem.call_args[0] + assert "running" in call_args[0] + + @pytest.mark.asyncio + async def test_same_state_no_srem(self): + """When old_state equals current state, srem should not be called.""" + store = _make_store() + exp = _make_experiment(state=ExperimentState.PENDING, val_bpb=None) + exp.result = None + + await store.save_experiment(exp, old_state=ExperimentState.PENDING) + + store._redis.srem.assert_not_called()