Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions tests/unit/test_evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import tempfile
import os
import csv
import shutil
import json


@dataclass
Expand Down Expand Up @@ -3973,6 +3975,215 @@ def _test_filter_first_exploration(self, filter_first_exploration: float):
self.assertGreaterEqual(filter_first_exploration, 0.20)
self.assertLessEqual(filter_first_exploration, 0.40)

def test_checkpoint_and_resume(self):
"""Test that VespaNNParameterOptimizer can save and resume progress."""
temp_dir = tempfile.mkdtemp()
run_name = "test_checkpoint_run"
queries = [q for q, _ in self.queries_with_hitratios]

try:
# Create and distribute queries to buckets
optimizer1 = VespaNNParameterOptimizer(
app=self.mock_app,
queries=queries,
hits=100,
buckets_per_percent=2,
run_name=run_name,
resume=False,
)
optimizer1.state_dir = temp_dir
optimizer1.run_state_file = os.path.join(temp_dir, f"{run_name}_state.json")

# Distribute queries to buckets
optimizer1.determine_hit_ratios_and_distribute_to_buckets(
optimizer1.queries
)

# Create bucket report matching run() method format
bucket_report = optimizer1.get_bucket_report()

# Save the stage (should only save bucket_report, not full queries)
optimizer1._save_stage("bucket_distribution", bucket_report)

# Verify state file structure
self.assertTrue(os.path.exists(optimizer1.run_state_file))
with open(optimizer1.run_state_file, "r") as f:
state = json.load(f)

# Verify state file has correct structure
self.assertIn("run_name", state)
self.assertIn("created_at", state)
self.assertIn("completed_stages", state)
self.assertIn("metadata", state)
self.assertIn("bucket_indices", state)

# Verify bucket_indices contains only integers (not full query objects)
for bucket in state["bucket_indices"]:
self.assertIsInstance(bucket, list)
for item in bucket:
self.assertIsInstance(
item,
int,
f"Bucket should contain query indices (ints), not queries. Found: {type(item)}",
)

# Verify completed_stages structure matches report format
bucket_stage = state["completed_stages"]["bucket_distribution"]
self.assertIn("completed_at", bucket_stage)
self.assertIn("data", bucket_stage)

# Verify the saved data matches the bucket_report structure (not containing queries)
saved_bucket_data = bucket_stage["data"]
self.assertEqual(
saved_bucket_data["buckets_per_percent"],
bucket_report["buckets_per_percent"],
)
self.assertEqual(
saved_bucket_data["bucket_interval_width"],
bucket_report["bucket_interval_width"],
)
self.assertEqual(
saved_bucket_data["non_empty_buckets"],
bucket_report["non_empty_buckets"],
)
self.assertEqual(
saved_bucket_data["filtered_out_ratios"],
bucket_report["filtered_out_ratios"],
)
self.assertEqual(
saved_bucket_data["hit_ratios"], bucket_report["hit_ratios"]
)
self.assertEqual(
saved_bucket_data["query_distribution"],
bucket_report["query_distribution"],
)

# Test resuming from saved state
optimizer2 = VespaNNParameterOptimizer(
app=self.mock_app,
queries=queries,
hits=100,
buckets_per_percent=2,
run_name=run_name,
resume=True,
)
optimizer2.state_dir = temp_dir
optimizer2.run_state_file = os.path.join(temp_dir, f"{run_name}_state.json")
optimizer2.load_state_if_exists()

# Verify bucket_distribution stage is complete
self.assertTrue(optimizer2._is_stage_complete("bucket_distribution"))

# Verify loaded data matches saved data
loaded_data = optimizer2._get_stage_data("bucket_distribution")
self.assertEqual(
loaded_data["buckets_per_percent"], bucket_report["buckets_per_percent"]
)
self.assertEqual(
loaded_data["non_empty_buckets"], bucket_report["non_empty_buckets"]
)
self.assertEqual(
loaded_data["filtered_out_ratios"], bucket_report["filtered_out_ratios"]
)

# Verify buckets were reconstructed correctly (same structure, query indices)
self.assertEqual(len(optimizer2.buckets), len(optimizer1.buckets))
for i in range(len(optimizer1.buckets)):
self.assertEqual(
optimizer2.buckets[i],
optimizer1.buckets[i],
f"Bucket {i} indices should match after resume",
)

# Verify bucket distribution metrics are preserved
self.assertEqual(
optimizer2.get_non_empty_buckets(), optimizer1.get_non_empty_buckets()
)
self.assertEqual(
optimizer2.get_filtered_out_ratios(),
optimizer1.get_filtered_out_ratios(),
)
self.assertEqual(
optimizer2.get_query_distribution(), optimizer1.get_query_distribution()
)

# Test that other stages are not yet complete
self.assertFalse(optimizer2._is_stage_complete("filter_first_exploration"))
self.assertFalse(optimizer2._is_stage_complete("filter_first_threshold"))
self.assertFalse(optimizer2._is_stage_complete("approximate_threshold"))
self.assertFalse(optimizer2._is_stage_complete("post_filter_threshold"))

# Test saving and loading additional stages with report-like structure
filter_exploration_report = {
"suggestion": 0.42,
"benchmarks": {0.0: [1.5, 2.0], 1.0: [2.5, 3.0]},
"recall_measurements": {0.0: [0.85, 0.90], 1.0: [0.95, 0.98]},
}
optimizer2._save_stage(
"filter_first_exploration", filter_exploration_report
)

# Reload state and verify structure
with open(optimizer2.run_state_file, "r") as f:
state = json.load(f)

filter_stage = state["completed_stages"]["filter_first_exploration"]
self.assertEqual(filter_stage["data"]["suggestion"], 0.42)
self.assertIn("benchmarks", filter_stage["data"])
self.assertIn("recall_measurements", filter_stage["data"])

# Verify retrieved report matches saved report
retrieved_report = optimizer2._get_stage_data("filter_first_exploration")
self.assertEqual(retrieved_report, filter_exploration_report)

finally:
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)

def test_checkpoint_resume_false(self):
"""Test that resume=False creates a new run even if state file exists."""
temp_dir = tempfile.mkdtemp()
run_name = "test_no_resume_run"

try:
# Create a state file manually
state_file = os.path.join(temp_dir, f"{run_name}_state.json")
initial_state = {
"run_name": run_name,
"created_at": "2025-12-24T00:00:00",
"completed_stages": {
"bucket_distribution": {
"completed_at": "2025-12-24T01:00:00",
"data": {"buckets": [], "bucket_report": {}},
}
},
"metadata": {},
}
with open(state_file, "w") as f:
json.dump(initial_state, f)

# Create optimizer with resume=False
queries = [q for q, _ in self.queries_with_hitratios]
optimizer = VespaNNParameterOptimizer(
app=self.mock_app,
queries=queries,
hits=100,
run_name=run_name,
resume=False,
)

# Override state_dir
optimizer.state_dir = temp_dir
optimizer.run_state_file = state_file
optimizer.load_state_if_exists()

# Verify that the stage is NOT marked as complete (fresh start)
self.assertFalse(optimizer._is_stage_complete("bucket_distribution"))

finally:
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)


class TestBucketedMetricResults(unittest.TestCase):
"""Test the BucketedMetricResults class."""
Expand Down
Loading
Loading