diff --git a/memori/_config.py b/memori/_config.py index ea8dcfcf..c41fac41 100644 --- a/memori/_config.py +++ b/memori/_config.py @@ -26,11 +26,18 @@ def __init__(self): self.cockroachdb = False +class Embeddings: + def __init__(self): + self.model = "all-MiniLM-L6-v2" + self.fallback_dimension = 768 + + class Config: def __init__(self): self.api_key = None self.augmentation = None self.cache = Cache() + self.embeddings = Embeddings() self.enterprise = False self.llm = Llm() self.framework = Framework() diff --git a/memori/_search.py b/memori/_search.py index 3539ffb2..de8fb690 100644 --- a/memori/_search.py +++ b/memori/_search.py @@ -54,12 +54,18 @@ def find_similar_embeddings( if not embeddings: return [] + query_dim = len(query_embedding) + if query_dim == 0: + return [] + embeddings_list = [] id_list = [] for fact_id, raw in embeddings: try: parsed = parse_embedding(raw) + if parsed.ndim != 1 or parsed.shape[0] != query_dim: + continue embeddings_list.append(parsed) id_list.append(fact_id) except Exception: @@ -68,7 +74,10 @@ def find_similar_embeddings( if not embeddings_list: return [] - embeddings_array = np.stack(embeddings_list, axis=0) + try: + embeddings_array = np.stack(embeddings_list, axis=0) + except ValueError: + return [] faiss.normalize_L2(embeddings_array) query_array = np.asarray([query_embedding], dtype=np.float32) diff --git a/memori/llm/_embeddings.py b/memori/llm/_embeddings.py index a7214d42..f0082e3c 100644 --- a/memori/llm/_embeddings.py +++ b/memori/llm/_embeddings.py @@ -11,6 +11,7 @@ import asyncio import os import struct +from collections.abc import Iterable from typing import Any os.environ.setdefault("TOKENIZERS_PARALLELISM", "false") @@ -18,7 +19,6 @@ from sentence_transformers import SentenceTransformer _MODEL_CACHE: dict[str, SentenceTransformer] = {} -_DEFAULT_DIMENSION = 768 def _get_model(model_name: str) -> SentenceTransformer: @@ -27,18 +27,25 @@ def _get_model(model_name: str) -> SentenceTransformer: return _MODEL_CACHE[model_name] -def format_embedding_for_db(embedding: list[float], dialect: str) -> Any: - """Format embedding for database storage. 
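# A minimal sketch of the dimension guard added to find_similar_embeddings above,
# assuming parse_embedding is the inverse of format_embedding_for_db's
# little-endian float32 packing (only the packing side appears in this diff).
# Rows whose stored vector length differs from the query's are skipped before
# np.stack, so a row written under a different embedding model cannot break the
# similarity search.
import struct

import numpy as np


def _pack(vec: list[float]) -> bytes:
    return struct.pack(f"<{len(vec)}f", *vec)


def _unpack(raw: bytes) -> np.ndarray:
    # assumed parse_embedding behaviour: little-endian float32 buffer -> 1-D array
    return np.frombuffer(raw, dtype="<f4").astype(np.float32)


query = np.array([0.1, 0.2, 0.3], dtype=np.float32)
rows = [(1, _pack([1.0, 0.0, 0.0])), (2, _pack([0.5, 0.5, 0.5, 0.5]))]  # row 2 has the wrong dimension

kept_ids, kept_vectors = [], []
for fact_id, raw in rows:
    parsed = _unpack(raw)
    if parsed.ndim != 1 or parsed.shape[0] != len(query):
        continue  # mirrors the new guard: mismatched rows are silently dropped
    kept_ids.append(fact_id)
    kept_vectors.append(parsed)

assert kept_ids == [1]
matrix = np.stack(kept_vectors, axis=0)  # safe: every kept row matches the query dimension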
+def _prepare_text_inputs(texts: str | Iterable[str]) -> list[str]: + if isinstance(texts, str): + return [texts] + return [t for t in texts if t] + + +def _embedding_dimension(model: Any, default: int) -> int: + try: + dim_value = model.get_sentence_embedding_dimension() + return int(dim_value) if dim_value is not None else default + except (RuntimeError, ValueError, AttributeError, TypeError): + return default + - Args: - embedding: List of floats representing the embedding vector - dialect: Database dialect (postgresql, mysql, sqlite, mongodb) +def _zero_vectors(count: int, dim: int) -> list[list[float]]: + return [[0.0] * dim for _ in range(count)] - Returns: - Formatted embedding optimized for the target database: - - PostgreSQL/CockroachDB/MySQL/SQLite: Binary (BYTEA/BLOB) - compact & fast - - MongoDB: Binary (BinData) - compact & fast - """ + +def format_embedding_for_db(embedding: list[float], dialect: str) -> Any: binary_data = struct.pack(f"<{len(embedding)}f", *embedding) if dialect == "mongodb": @@ -48,36 +55,57 @@ def format_embedding_for_db(embedding: list[float], dialect: str) -> Any: return bson.Binary(binary_data) except ImportError: return binary_data - else: - return binary_data + return binary_data def embed_texts( - texts: str | list[str], model: str = "all-mpnet-base-v2" + texts: str | list[str], + model: str, + fallback_dimension: int, ) -> list[list[float]]: - inputs = [texts] if isinstance(texts, str) else [t for t in texts if t] + inputs = _prepare_text_inputs(texts) if not inputs: return [] try: encoder = _get_model(model) except (OSError, RuntimeError, ValueError): - return [[0.0] * _DEFAULT_DIMENSION for _ in inputs] + return _zero_vectors(len(inputs), fallback_dimension) try: embeddings = encoder.encode(inputs, convert_to_numpy=True) return embeddings.tolist() - except (RuntimeError, ValueError): + except ValueError as e: + # Some models can raise "all input arrays must have the same shape" when + # encoding batches. Retry one-by-one to avoid internal stacking. 
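# A minimal sketch of the contract this retry is aiming for, using a stand-in
# encoder rather than a real SentenceTransformer (assumption: only the batch
# call trips the numpy "same shape" error, while per-text calls succeed): the
# batch failure is retried text-by-text, and zero vectors are only used if the
# retry also fails or the per-text results disagree on dimension.
import numpy as np


class _FlakyEncoder:
    def encode(self, batch, convert_to_numpy=True):
        if len(batch) > 1:
            raise ValueError("all input arrays must have the same shape")
        return np.ones((1, 4), dtype=np.float32)


_enc = _FlakyEncoder()
try:
    _enc.encode(["a", "b"], convert_to_numpy=True)
except ValueError:
    _vectors = [_enc.encode([t], convert_to_numpy=True)[0].tolist() for t in ["a", "b"]]
assert _vectors == [[1.0, 1.0, 1.0, 1.0], [1.0, 1.0, 1.0, 1.0]]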
+ if "same shape" not in str(e): + raise + try: - dim_value = encoder.get_sentence_embedding_dimension() - dim = int(dim_value) if dim_value is not None else _DEFAULT_DIMENSION - except (RuntimeError, ValueError, AttributeError, TypeError): - dim = _DEFAULT_DIMENSION - return [[0.0] * dim for _ in inputs] + vectors: list[list[float]] = [] + for text in inputs: + single = encoder.encode([text], convert_to_numpy=True) + vectors.append(single[0].tolist()) + + dim_set = {len(v) for v in vectors} + if len(dim_set) != 1: + raise ValueError("all input arrays must have the same shape") from e + + return vectors + except Exception: + dim = _embedding_dimension(encoder, default=fallback_dimension) + return _zero_vectors(len(inputs), dim) + except RuntimeError: + dim = _embedding_dimension(encoder, default=fallback_dimension) + return _zero_vectors(len(inputs), dim) async def embed_texts_async( - texts: str | list[str], model: str = "all-mpnet-base-v2" + texts: str | list[str], + model: str, + fallback_dimension: int, ) -> list[list[float]]: loop = asyncio.get_event_loop() - return await loop.run_in_executor(None, embed_texts, texts, model) + return await loop.run_in_executor( + None, embed_texts, texts, model, fallback_dimension + ) diff --git a/memori/memory/augmentation/augmentations/memori/_augmentation.py b/memori/memory/augmentation/augmentations/memori/_augmentation.py index 366e6d5b..680793c6 100644 --- a/memori/memory/augmentation/augmentations/memori/_augmentation.py +++ b/memori/memory/augmentation/augmentations/memori/_augmentation.py @@ -140,7 +140,12 @@ async def _process_api_response(self, api_response: dict) -> Memories: ] if facts: - fact_embeddings = await embed_texts_async(facts) + embeddings_config = self.config.embeddings + fact_embeddings = await embed_texts_async( + facts, + model=embeddings_config.model, + fallback_dimension=embeddings_config.fallback_dimension, + ) api_response["entity"]["fact_embeddings"] = fact_embeddings return Memories().configure_from_advanced_augmentation(api_response) @@ -167,7 +172,12 @@ async def _schedule_entity_writes( ] if facts_from_triples: - embeddings_from_triples = await embed_texts_async(facts_from_triples) + embeddings_config = self.config.embeddings + embeddings_from_triples = await embed_texts_async( + facts_from_triples, + model=embeddings_config.model, + fallback_dimension=embeddings_config.fallback_dimension, + ) facts_to_write = (facts_to_write or []) + facts_from_triples embeddings_to_write = ( embeddings_to_write or [] diff --git a/memori/memory/recall.py b/memori/memory/recall.py index c171fb9c..2da68a1b 100644 --- a/memori/memory/recall.py +++ b/memori/memory/recall.py @@ -41,7 +41,12 @@ def search_facts( if limit is None: limit = self.config.recall_facts_limit - query_embedding = embed_texts(query)[0] + embeddings_config = self.config.embeddings + query_embedding = embed_texts( + query, + model=embeddings_config.model, + fallback_dimension=embeddings_config.fallback_dimension, + )[0] facts = [] for attempt in range(MAX_RETRIES): diff --git a/pyproject.toml b/pyproject.toml index 0ee2c04d..ff961f60 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -93,6 +93,7 @@ asyncio_mode = "auto" addopts = [ "-v", "--strict-markers", + "--ignore=tests/benchmarks", "-m", "not benchmark", "--cov=memori", diff --git a/tests/benchmarks/_results.py b/tests/benchmarks/_results.py new file mode 100644 index 00000000..fc187536 --- /dev/null +++ b/tests/benchmarks/_results.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import csv +from 
pathlib import Path +from typing import Any + + +def repo_root() -> Path: + return Path(__file__).resolve().parents[2] + + +def results_dir() -> Path: + path = repo_root() / "results" + path.mkdir(parents=True, exist_ok=True) + return path + + +def append_csv_row(path: str | Path, *, header: list[str], row: dict[str, Any]) -> None: + out_path = Path(path) + out_path.parent.mkdir(parents=True, exist_ok=True) + file_exists = out_path.exists() + with out_path.open("a", newline="") as f: + writer = csv.DictWriter(f, fieldnames=header) + if not file_exists: + writer.writeheader() + writer.writerow(row) diff --git a/tests/benchmarks/conftest.py b/tests/benchmarks/conftest.py index a2d56d72..3bda695d 100644 --- a/tests/benchmarks/conftest.py +++ b/tests/benchmarks/conftest.py @@ -40,7 +40,7 @@ def postgres_db_connection(): postgres_uri, pool_pre_ping=True, pool_recycle=300, - connect_args=connect_args if connect_args else None, + connect_args=connect_args, ) try: @@ -156,7 +156,11 @@ def entity_with_n_facts(memori_instance, fact_content_size, request): memori_instance.attribution(entity_id=entity_id, process_id="benchmark-process") facts = generate_facts_with_size(fact_count, fact_content_size) - fact_embeddings = embed_texts(facts) + fact_embeddings = embed_texts( + facts, + model=memori_instance.config.embeddings.model, + fallback_dimension=memori_instance.config.embeddings.fallback_dimension, + ) entity_db_id = memori_instance.config.storage.driver.entity.create(entity_id) memori_instance.config.storage.driver.entity_fact.create( diff --git a/tests/benchmarks/fixtures/sample_facts.py b/tests/benchmarks/fixtures/sample_facts.py new file mode 100644 index 00000000..ec120c08 --- /dev/null +++ b/tests/benchmarks/fixtures/sample_facts.py @@ -0,0 +1,394 @@ +# 1,000 synthetic user facts as semantic triples (subject, predicate, object) +# Copy/paste into your project. 
Produces: user_data = {"user": "John", "facts": [(s,p,o), ...]} + + +def build_user_data() -> dict: + facts: list[tuple[str, str, str]] = [] + + def add(s: str, p: str, o: str) -> None: + facts.append((s, p, o)) + + # ----------------------------- + # Core profile facts (diverse) + # ----------------------------- + add("John", "type", "Person") + add("John", "has_given_name", "John") + add("John", "prefers_language", "English") + add("John", "prefers_time_format", "12-hour") + add("John", "prefers_temperature_unit", "Fahrenheit") + add("John", "prefers_distance_unit", "miles") + add("John", "has_home_city", "St_Louis") + add("John", "has_home_state", "Missouri") + add("John", "has_home_country", "United_States") + add("John", "has_role", "Software_Engineer") + add("John", "works_in", "Technology") + add("John", "interested_in", "AI") + add("John", "interested_in", "cloud_infrastructure") + add("John", "interested_in", "personal_finance") + add("John", "interested_in", "fitness") + add("John", "interested_in", "cooking") + add("John", "interested_in", "travel") + add("John", "interested_in", "productivity") + add("John", "interested_in", "home_improvement") + add("John", "uses_llm_for", "writing_assistance") + add("John", "uses_llm_for", "coding_help") + add("John", "uses_llm_for", "travel_planning") + add("John", "uses_llm_for", "learning_new_topics") + add("John", "uses_llm_for", "recipe_ideas") + add("John", "uses_llm_for", "career_advice") + add("John", "uses_llm_for", "data_analysis") + add("John", "uses_llm_for", "brainstorming") + add("John", "prefers_response_style", "structured") + add("John", "prefers_response_style", "actionable") + add("John", "concerned_about", "latency") + add("John", "concerned_about", "cost") + add("John", "concerned_about", "privacy") + add("John", "primary_os", "macOS") + add("John", "primary_browser", "Chrome") + add("John", "primary_editor", "VS_Code") + add("John", "primary_shell", "zsh") + add("John", "primary_email_provider", "Gmail") + add("John", "primary_calendar", "Google_Calendar") + add("John", "primary_messaging_app", "Slack") + add("John", "primary_code_host", "GitHub") + add("John", "prefers_document_format", "Markdown") + add("John", "prefers_spreadsheet_tool", "Google_Sheets") + add("John", "uses_password_manager", "1Password") + add("John", "uses_cloud_storage", "Google_Drive") + add("John", "uses_issue_tracker", "GitHub_Issues") + add("John", "uses_ci_cd", "GitHub_Actions") + add("John", "prefers_container_runtime", "Docker") + add("John", "prefers_database", "PostgreSQL") + add("John", "prefers_cache", "Redis") + add("John", "prefers_language_for_backend", "Python") + add("John", "knows_language", "JavaScript") + add("John", "knows_language", "SQL") + add("John", "knows_language", "Bash") + add("John", "learning_language", "Rust") + add("John", "prefers_testing_framework", "pytest") + add("John", "prefers_package_manager", "uv") + add("John", "uses_llm_provider", "OpenAI") + add("John", "uses_llm_provider", "Anthropic") + add("John", "uses_model_family", "GPT") + add("John", "uses_model_family", "Claude") + add("John", "has_hobby", "fantasy_football") + add("John", "follows_sport", "NFL") + add("John", "follows_sport", "NHL") + add("John", "prefers_coffee_drink", "latte") + add("John", "prefers_breakfast", "oatmeal") + add("John", "prefers_lunch", "salad") + add("John", "prefers_dinner", "grilled_chicken") + add("John", "commutes_by", "car") + add("John", "prefers_meeting_platform", "Zoom") + + # ----------------------------- + # Family 
/ household facts + # ----------------------------- + family = [ + ("Spouse_01", "spouse"), + ("Child_01", "child"), + ("Child_02", "child"), + ("Parent_01", "parent"), + ("Parent_02", "parent"), + ("Sibling_01", "sibling"), + ("Sibling_02", "sibling"), + ("Pet_01", "pet"), + ] + for i, (entity, rel) in enumerate(family, start=1): + add(entity, "type", "Person" if "Pet" not in entity else "Animal") + add("John", f"has_{rel}", entity) + add(entity, "has_first_name", entity.split("_")[0]) + add(entity, "located_in", f"City_{(i % 25) + 1:03d}") + add( + entity, + "preferred_contact_method", + ["text", "call", "email"][i % 3] if "Pet" not in entity else "n/a", + ) + + # ----------------------------- + # Friends (varied interests) + # ----------------------------- + friend_interests = [ + "music", + "travel", + "tech", + "food", + "sports", + "finance", + "fitness", + "gaming", + "books", + "photography", + ] + contact_methods = ["email", "text", "slack", "signal", "call"] + for i in range(1, 61): # 60 friends + f = f"Friend_{i:03d}" + add(f, "type", "Person") + add("John", "has_friend", f) + add(f, "located_in", f"City_{(i % 80) + 1:03d}") + add(f, "interested_in", friend_interests[i % len(friend_interests)]) + add(f, "preferred_contact_method", contact_methods[i % len(contact_methods)]) + + # ----------------------------- + # Coworkers (teams, roles, tools) + # ----------------------------- + roles = [ + "Backend_Engineer", + "Frontend_Engineer", + "DevOps_Engineer", + "Product_Manager", + "Designer", + "Data_Engineer", + ] + teams = ["Platform", "Infra", "Product", "Data", "Security", "Growth"] + tools = ["Jira", "Linear", "Confluence", "Notion", "Figma", "Datadog", "Grafana"] + for i in range(1, 41): # 40 coworkers + c = f"Coworker_{i:03d}" + add(c, "type", "Person") + add("John", "works_with", c) + add(c, "has_role", roles[i % len(roles)]) + add(c, "member_of_team", teams[i % len(teams)]) + add(c, "uses_tool", tools[i % len(tools)]) + + # ----------------------------- + # Places: cities, venues, travel + # ----------------------------- + for i in range(1, 81): # 80 cities + city = f"City_{i:03d}" + add(city, "type", "City") + add(city, "in_country", "United_States" if i <= 60 else "International") + add("John", "has_visited", city if i <= 45 else f"Planned_{city}") + add( + city, + "has_timezone", + "America/Chicago" + if i % 3 == 0 + else "America/New_York" + if i % 3 == 1 + else "America/Los_Angeles", + ) + + venue_types = [ + "Restaurant", + "Coffee_Shop", + "Gym", + "Airport", + "Hotel", + "Park", + "Museum", + "Stadium", + ] + for i in range(1, 81): # 80 venues + v = f"Venue_{i:03d}" + add(v, "type", venue_types[i % len(venue_types)]) + add(v, "located_in", f"City_{(i % 80) + 1:03d}") + add("John", "likes_place", v if i % 4 != 0 else f"Neutral_{v}") + add(v, "has_price_tier", ["$", "$$", "$$$"][i % 3]) + + # ----------------------------- + # Devices, apps, services, accounts + # ----------------------------- + devices = [ + "MacBook_Pro", + "iPhone", + "iPad", + "AirPods", + "Smart_TV", + "Router", + "NAS", + "Mechanical_Keyboard", + "Gaming_PC", + "Monitor_34inch", + "Standing_Desk", + "Ergonomic_Chair", + "Fitness_Tracker", + "Smart_Speaker", + "Kindle", + "External_SSD", + "Webcam", + "Microphone", + "Printer", + "Smart_Thermostat", + ] + for d in devices: + add(d, "type", "Device") + add("John", "owns_device", d) + add(d, "used_for", "work" if "MacBook" in d or "Monitor" in d else "personal") + add(d, "has_status", "active") + + apps = [ + "Notion", + "Google_Drive", + 
"Google_Calendar", + "Gmail", + "Slack", + "Zoom", + "VS_Code", + "Docker_Desktop", + "GitHub", + "GitHub_Desktop", + "Postman", + "TablePlus", + "DBeaver", + "Obsidian", + "Todoist", + "Spotify", + "YouTube", + "Netflix", + "Hulu", + "Amazon_Prime", + "Strava", + "MyFitnessPal", + "Pocket", + "Kindle_App", + "Signal", + "Discord", + "Reddit", + "X", + "LinkedIn", + "Google_Maps", + ] + for a in apps: + add(a, "type", "Application") + add("John", "uses_app", a) + add( + a, + "category", + "productivity" + if a + in { + "Notion", + "Todoist", + "Obsidian", + "Google_Calendar", + "Google_Drive", + "Gmail", + } + else "communication" + if a in {"Slack", "Zoom", "Signal", "Discord"} + else "media", + ) + add(a, "access_method", "mobile_and_desktop" if a not in {"Smart_TV"} else "tv") + + services = [ + "Banking_Service", + "Credit_Card", + "Mortgage_Lender", + "Insurance_Auto", + "Insurance_Home", + "Electric_Utility", + "Water_Utility", + "Internet_ISP", + "Mobile_Carrier", + "Cloud_Provider", + ] + for i, s in enumerate(services, start=1): + add(s, "type", "Service") + add("John", "has_service_account", s) + add(s, "billing_cycle", "monthly") + add(s, "has_priority", "high" if i <= 3 else "medium") + + # ----------------------------- + # Projects, repos, documents, artifacts + # ----------------------------- + project_domains = [ + "LLM_Tooling", + "Backend_API", + "Data_Pipeline", + "Infra_As_Code", + "Personal_Website", + "Home_Budget", + ] + for i in range(1, 51): # 50 projects + p = f"Project_{i:03d}" + add(p, "type", "Project") + add("John", "owns_project", p) + add(p, "domain", project_domains[i % len(project_domains)]) + add(p, "has_status", "active" if i % 5 != 0 else "paused") + add(p, "uses_language", ["Python", "TypeScript", "SQL", "Go"][i % 4]) + add(p, "uses_platform", ["AWS", "GCP", "Vercel", "DigitalOcean"][i % 4]) + + doc_types = [ + "Spec", + "Design_Doc", + "Runbook", + "Postmortem", + "Resume", + "Budget_Spreadsheet", + "Travel_Itinerary", + "Grocery_List", + ] + for i in range(1, 81): # 80 docs + doc = f"Doc_{i:03d}" + add(doc, "type", "Document") + add("John", "created_document", doc) + add(doc, "document_type", doc_types[i % len(doc_types)]) + add(doc, "stored_in", "Google_Drive" if i % 2 == 0 else "Notion") + + # ----------------------------- + # Goals, tasks, routines, preferences (LLM-relevant) + # ----------------------------- + goal_areas = ["Career", "Health", "Finance", "Learning", "Home", "Travel"] + for i in range(1, 121): # 120 goals + g = f"Goal_{i:03d}" + add(g, "type", "Goal") + add("John", "has_goal", g) + add(g, "goal_area", goal_areas[i % len(goal_areas)]) + add(g, "target_timeframe", ["this_quarter", "this_year", "next_year"][i % 3]) + + routines = [ + "Morning_Routine", + "Workout_Routine", + "Meal_Prep_Routine", + "Weekly_Planning", + "Budget_Review", + "Code_Review_Habit", + ] + for i in range(1, 101): # 100 routines/habits + r = f"Routine_{i:03d}" + add(r, "type", "Routine") + add("John", "follows_routine", r) + add(r, "routine_template", routines[i % len(routines)]) + add(r, "frequency", ["daily", "weekly", "monthly"][i % 3]) + + # ----------------------------- + # LLM conversation topics / entities John might discuss + # ----------------------------- + topics = [ + "Trip_Planning", + "Interview_Prep", + "System_Design", + "API_Debugging", + "SQL_Optimization", + "Docker_Troubleshooting", + "CI_CD_Failures", + "Cost_Optimization", + "Latency_Tuning", + "Meal_Planning", + "Workout_Programming", + "Budgeting", + "Insurance_Comparison", + 
"Home_Repairs", + "Writing_Emails", + "Writing_Docs", + "Refactoring_Code", + "Testing_Strategy", + "Monitoring_Alerts", + "Productivity_Systems", + ] + for i in range(1, 201): # 200 topic facts + t = f"Topic_{i:03d}" + add(t, "type", "Topic") + add("John", "asks_llm_about", t) + add(t, "topic_name", topics[i % len(topics)]) + add(t, "priority", ["high", "medium", "low"][i % 3]) + + # ----------------------------- + # Normalize to exactly 1,000 facts + # ----------------------------- + facts_1000 = facts[:1000] + return {"user": "John", "facts": facts_1000} + + +user_data = build_user_data() +assert len(user_data["facts"]) == 1000 diff --git a/tests/benchmarks/test_recall_accuracy.py b/tests/benchmarks/test_recall_accuracy.py index cdf38a09..098cc3be 100644 --- a/tests/benchmarks/test_recall_accuracy.py +++ b/tests/benchmarks/test_recall_accuracy.py @@ -1,45 +1,640 @@ +import datetime +import os import random +import statistics +from math import sqrt +from typing import TypedDict +from uuid import uuid4 +import pytest + +from memori._config import Config +from memori.llm import _embeddings as embeddings_mod +from memori.llm._embeddings import embed_texts from memori.memory.recall import Recall +from tests.benchmarks._results import append_csv_row, results_dir +from tests.benchmarks.fixtures.sample_facts import build_user_data +from tests.benchmarks.semantic_accuracy_dataset import DATASET as CURATED_DATASET +from tests.benchmarks.semantic_accuracy_metrics import ( + mrr, +) + + +def _embeddings_available() -> bool: + # If the embedding model can't load, Memori falls back to all-zeros embeddings. + # That makes semantic accuracy meaningless, so we skip instead of failing. + cfg = Config() + vec = embed_texts( + "sanity check", + model=cfg.embeddings.model, + fallback_dimension=cfg.embeddings.fallback_dimension, + )[0] + return any(v != 0.0 for v in vec) + + +def _generate_hard_distractors( + count: int, *, rng: random.Random, forbidden: set[str] +) -> list[str]: + cities = ["London", "Berlin", "Rome", "Madrid", "Lisbon", "Dublin", "Vienna"] + colors = ["red", "green", "yellow", "purple", "orange", "black", "white"] + foods = ["sushi", "tacos", "ramen", "burgers", "pasta", "salad", "ice cream"] + drinks = ["tea", "sparkling water", "matcha", "hot chocolate", "juice"] + companies = ["Acme Corp", "Globex", "Initech", "Hooli", "Soylent", "Umbrella"] + activities = ["running", "swimming", "reading", "gaming", "cycling", "yoga"] + themes = ["light mode", "system theme", "high contrast mode"] + birthdays = ["April 1st", "May 20th", "June 7th", "July 30th", "Oct 12th"] + pets = ["1 cat", "3 cats", "2 dogs", "a dog", "a cat", "no pets"] + + templates = [ + lambda v: f"User lives in {v}", + lambda v: f"User's favorite color is {v}", + lambda v: f"User likes {v}", + lambda v: f"User works at {v}", + lambda v: f"User enjoys {v}", + lambda v: f"User prefers {v}", + lambda v: f"User's birthday is {v}", + lambda v: f"User has {v}", + ] + values = [ + cities, + colors, + foods + drinks, + companies, + activities, + themes, + birthdays, + pets, + ] + + distractors: list[str] = [] + for i in range(count): + idx = i % len(templates) + base = templates[idx](rng.choice(values[idx])) + candidate = f"{base} (id: d{i})" + if candidate in forbidden: + candidate = f"{base} (note: alt) (id: d{i})" + distractors.append(candidate) + + return distractors + + +def _strip_id_suffix(text: str) -> str: + idx = text.rfind(" (id:") + if idx == -1 or not text.endswith(")"): + return text + return text[:idx] + + +def 
_fact_to_text(subject: str, predicate: str, obj: str, *, fact_id: int) -> str: + subj = subject.replace("_", " ") + pred = predicate.replace("_", " ") + obj_text = obj.replace("_", " ") + return f"{subj} {pred} {obj_text} (id: {fact_id})" + + +class _SemanticAccuracyDataset(TypedDict): + corpus_facts: list[str] + queries: dict[str, list[str]] + + +def _t_critical_975(df: int) -> float: + # 97.5% quantiles for Student-t (two-sided 95% CI), df 1..10 + # If df is larger, normal approximation is fine for our benchmark reporting. + table = { + 1: 12.706, + 2: 4.303, + 3: 3.182, + 4: 2.776, + 5: 2.571, + 6: 2.447, + 7: 2.365, + 8: 2.306, + 9: 2.262, + 10: 2.228, + } + return table.get(df, 1.96) + + +def _mean_ci_95(values: list[float]) -> tuple[float, float, float]: + if not values: + return 0.0, 0.0, 0.0 + if len(values) == 1: + return values[0], values[0], values[0] + + mean = statistics.fmean(values) + stdev = statistics.stdev(values) + df = len(values) - 1 + half_width = _t_critical_975(df) * (stdev / sqrt(len(values))) + return mean, mean - half_width, mean + half_width + + +def _default_semantic_accuracy_csv_path() -> str: + return str(results_dir() / "semantic_accuracy.csv") + + +def _default_semantic_accuracy_curated_csv_path() -> str: + return str(results_dir() / "semantic_accuracy_curated.csv") + + +def _build_semantic_accuracy_dataset_from_sample_facts() -> _SemanticAccuracyDataset: + triples: list[tuple[str, str, str]] = build_user_data()["facts"] + fact_texts = [ + _fact_to_text(s, p, o, fact_id=i) for i, (s, p, o) in enumerate(triples) + ] + + def _facts(subject: str, pred: str) -> list[str]: + results: list[str] = [] + for i, (s, p, _) in enumerate(triples): + if s == subject and p == pred: + results.append(fact_texts[i]) + return results + + # We generate queries later (with varying seeds / sizes). 
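# A small worked example of the 95% confidence-interval helper above: with
# three repeat scores, the half-width is t(0.975, df=2) * s / sqrt(n), where
# t = 4.303 comes from the table in _t_critical_975.
import statistics
from math import sqrt

repeat_scores = [0.80, 0.90, 0.85]
mean = statistics.fmean(repeat_scores)                 # 0.85
half_width = 4.303 * (statistics.stdev(repeat_scores) / sqrt(len(repeat_scores)))
assert round(mean, 2) == 0.85
assert round(half_width, 3) == 0.124                   # CI is roughly (0.726, 0.974)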
+ return {"corpus_facts": fact_texts, "queries": {}} + + +def _subject_variants(subject: str) -> list[str]: + if subject == "John": + return ["John", "I", "me", "the user", "this user"] + + subj = subject.replace("_", " ") + variants = {subj, subj.lower()} + + if subject.startswith("Coworker_") or subject.startswith("Friend_"): + _, num = subject.split("_", 1) + label = subject.split("_", 1)[0].lower() + variants.update( + { + f"{label} {num}", + f"{label} #{num}", + f"my {label} {num}", + f"My {label} {num}", + f"{label.title()} {num}", + } + ) + + return sorted(variants) + + +def _predicate_question_variants(subject: str, predicate: str) -> list[str]: + pred_words = predicate.replace("_", " ") + + templates: list[str] + if predicate == "member_of_team": + templates = [ + "Which team is {subj} a member of?", + "What team is {subj} on?", + "Which team does {subj} belong to?", + "What is {subj}'s team?", + ] + elif predicate == "located_in": + templates = [ + "Where is {subj} located?", + "What city is {subj} in?", + "Where can I find {subj}?", + ] + elif predicate == "in_country": + templates = [ + "What country is {subj} in?", + "Which country is {subj} in?", + ] + elif predicate == "has_timezone": + templates = [ + "What timezone does {subj} have?", + "What is {subj}'s timezone?", + ] + elif predicate == "uses_llm_for": + templates = [ + "What does {subj} use LLM for?", + "What does {subj} use an LLM for?", + "Why does {subj} use LLMs?", + ] + elif predicate == "uses_llm_provider": + templates = [ + "Which LLM providers does {subj} use?", + "What LLM providers does {subj} use?", + ] + elif predicate == "uses_model_family": + templates = [ + "Which model families does {subj} use?", + "What model families does {subj} use?", + ] + elif predicate == "type": + templates = [ + "What type is {subj}?", + "What kind of entity is {subj}?", + ] + elif predicate.startswith("has_"): + rest = predicate.removeprefix("has_").replace("_", " ") + templates = [ + f"What {rest} does {{subj}} have?", + f"Which {rest} does {{subj}} have?", + ] + else: + templates = [ + f"What is {{subj}} {pred_words}?", + f"What does {{subj}} {pred_words}?", + ] + + questions: list[str] = [] + for subj in _subject_variants(subject): + for t in templates: + questions.append(t.format(subj=subj)) + + # Deduplicate but preserve a stable order + seen: set[str] = set() + out: list[str] = [] + for q in questions: + if q in seen: + continue + out.append(q) + seen.add(q) + return out + + +def _generate_query_set( + *, + triples: list[tuple[str, str, str]], + fact_texts: list[str], + rng: random.Random, + query_count: int, + expected_cap: int = 5, +) -> dict[str, list[str]]: + # Group facts by (subject, predicate) so we can accept any of the values. + grouped: dict[tuple[str, str], list[str]] = {} + for i, (s, p, _) in enumerate(triples): + grouped.setdefault((s, p), []).append(_strip_id_suffix(fact_texts[i])) + + items = list(grouped.items()) + rng.shuffle(items) + + queries: dict[str, list[str]] = {} + + # Always include the coworker-035 query if present (user requested). 
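# A minimal sketch of the fan-out performed by the two helpers above, assuming
# the same template-substitution shape (the real helpers cover many more
# predicates and subject spellings): one (subject, predicate) pair yields
# several phrasings of the same question, which is what makes the generated
# query sets paraphrase-heavy rather than exact-match.
subject_variants = ["coworker 035", "Coworker 035", "my coworker 035"]
templates = [
    "Which team is {subj} a member of?",
    "What team is {subj} on?",
]
questions = [t.format(subj=s) for s in subject_variants for t in templates]
assert len(questions) == 6
assert "Which team is my coworker 035 a member of?" in questions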
+ key = ("Coworker_035", "member_of_team") + if key in grouped: + variants = _predicate_question_variants(*key) + queries[rng.choice(variants)] = grouped[key][:1] + + for (s, p), expected in items: + if len(queries) >= query_count: + break + variants = _predicate_question_variants(s, p) + rng.shuffle(variants) + for question in variants: + if question in queries: + continue + queries[question] = expected[:expected_cap] + break + + return queries -def test_recall_accuracy_topk(memori_instance, entity_with_n_facts): +@pytest.mark.skipif(not _embeddings_available(), reason="Embedding model unavailable") +@pytest.mark.parametrize( + "total_records", + [10, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000], + ids=lambda n: f"n{n}", +) +def test_semantic_recall_accuracy(memori_instance, total_records): """ - Accuracy proxy: for a sample of stored facts, querying with the exact fact text - should retrieve that fact in top-k (ideally top-1). + Semantic accuracy evaluation (the "right way"): + - seed a labeled dataset of facts + - run a labeled set of queries + - report hit@k (a.k.a. recall-as-hit), plus MRR + """ + dataset = _build_semantic_accuracy_dataset_from_sample_facts() + corpus_facts = dataset["corpus_facts"] + triples: list[tuple[str, str, str]] = build_user_data()["facts"] + + query_count = int(os.environ.get("SEMANTIC_ACCURACY_QUERY_COUNT", "150")) + repeats = int(os.environ.get("SEMANTIC_ACCURACY_REPEATS", "5")) + base_seed = int(os.environ.get("SEMANTIC_ACCURACY_BASE_SEED", "123")) + + if query_count < 1: + pytest.skip("SEMANTIC_ACCURACY_QUERY_COUNT must be >= 1") + if repeats < 1: + pytest.skip("SEMANTIC_ACCURACY_REPEATS must be >= 1") + + if os.environ.get("SEMANTIC_ACCURACY_DUMP_FACTS") == "1": + limit_raw = os.environ.get("SEMANTIC_ACCURACY_DUMP_FACTS_LIMIT") + limit = int(limit_raw) if limit_raw else None + facts_to_print = corpus_facts[:limit] if limit else corpus_facts + for fact in facts_to_print: + print(fact) + pytest.skip("Dumped semantic accuracy facts") + + # Repeat full evaluation with different seeds/query sets to report variance. + per_run: dict[str, list[float]] = { + "hit@1": [], + "hit@3": [], + "hit@5": [], + "mrr": [], + } + + for rep in range(repeats): + rng = random.Random(base_seed + rep) + queries = _generate_query_set( + triples=triples, + fact_texts=corpus_facts, + rng=rng, + query_count=query_count, + ) + labeled_norm_set = {fact for expected in queries.values() for fact in expected} + if not labeled_norm_set: + pytest.skip("No labeled facts available from generated query set") + + if total_records < len(labeled_norm_set): + pytest.skip( + f"total_records={total_records} is smaller than labeled fact count={len(labeled_norm_set)}" + ) + + distractor_count = total_records - len(labeled_norm_set) + distractor_pool = [ + f for f in corpus_facts if _strip_id_suffix(f) not in labeled_norm_set + ] + forbidden = set(labeled_norm_set) + + # Nested distractors per (total_records, rep): deterministic shuffle + prefix. 
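# A tiny worked example of the metrics reported below. The reciprocal-rank
# helper here is assumed to match what the imported mrr() computes (first
# relevant item at rank r contributes 1/r; no relevant item contributes 0).
def _reciprocal_rank(relevant: set[str], retrieved: list[str]) -> float:
    for rank, item in enumerate(retrieved, start=1):
        if item in relevant:
            return 1.0 / rank
    return 0.0


relevant = {"User works at Initech"}
retrieved = ["User lives in Berlin", "User works at Initech", "User enjoys cycling"]
assert not any(f in retrieved[:1] for f in relevant)   # hit@1 = 0.0
assert any(f in retrieved[:3] for f in relevant)       # hit@3 = 1.0 (and so hit@5 = 1.0)
assert _reciprocal_rank(relevant, retrieved) == 0.5    # first hit at rank 2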
+ distractor_pool_shuffled = list(distractor_pool) + rng.shuffle(distractor_pool_shuffled) + + if distractor_count <= len(distractor_pool_shuffled): + distractors = distractor_pool_shuffled[:distractor_count] + else: + distractors = list(distractor_pool_shuffled) + remaining = distractor_count - len(distractor_pool_shuffled) + distractors.extend( + _generate_hard_distractors(remaining, rng=rng, forbidden=forbidden) + ) + + rng.shuffle(distractors) + + labeled_facts_with_ids: list[str] = [] + labeled_with_ids_set: set[str] = set() + for full_fact in corpus_facts: + norm = _strip_id_suffix(full_fact) + if norm in labeled_norm_set and full_fact not in labeled_with_ids_set: + labeled_facts_with_ids.append(full_fact) + labeled_with_ids_set.add(full_fact) + + facts = list(labeled_facts_with_ids) + rng.shuffle(facts) + facts.extend(distractors) + + # IMPORTANT: use a unique entity per run to avoid accumulating facts + # across reruns (especially if embedding dimensions/models change). + entity_id = f"semantic-accuracy-entity-{total_records}-rep{rep}-{uuid4()}" + memori_instance.attribution(entity_id=entity_id, process_id="semantic-accuracy") + entity_db_id = memori_instance.config.storage.driver.entity.create(entity_id) + + fact_embeddings = embed_texts( + facts, + model=memori_instance.config.embeddings.model, + fallback_dimension=memori_instance.config.embeddings.fallback_dimension, + ) + memori_instance.config.storage.driver.entity_fact.create( + entity_db_id, facts, fact_embeddings + ) + + memori_instance.config.recall_embeddings_limit = total_records + recall = Recall(memori_instance.config) + + scores = { + "hit@1": [], + "hit@3": [], + "hit@5": [], + "mrr": [], + } + + debug_limit = int(os.environ.get("SEMANTIC_ACCURACY_DEBUG_LIMIT", "50")) + debug_printed = 0 + + for query, expected in queries.items(): + results = recall.search_facts(query=query, limit=5, entity_id=entity_db_id) + retrieved = [r.get("content", "") for r in results] + retrieved_norm = [_strip_id_suffix(r) for r in retrieved] + + relevant = set(expected) + scores["hit@1"].append( + 1.0 if any(f in retrieved_norm[:1] for f in relevant) else 0.0 + ) + scores["hit@3"].append( + 1.0 if any(f in retrieved_norm[:3] for f in relevant) else 0.0 + ) + scores["hit@5"].append( + 1.0 if any(f in retrieved_norm[:5] for f in relevant) else 0.0 + ) + scores["mrr"].append(mrr(relevant, retrieved_norm)) + + if ( + os.environ.get("SEMANTIC_ACCURACY_DEBUG") == "1" + and debug_printed < debug_limit + ): + hit_rank: int | None = None + for i, item in enumerate(retrieved_norm, start=1): + if item in relevant: + hit_rank = i + break + + tag = "HIT" if hit_rank is not None else "MISS" + extra = f"rank={hit_rank} " if hit_rank is not None else "" + print( + f"[semantic-accuracy][debug][{tag}] total={total_records} rep={rep} " + f"query={query!r} {extra}" + f"expected={expected!r} retrieved={retrieved_norm!r}" + ) + debug_printed += 1 - This validates the end-to-end recall pipeline returns the correct row given an - exact-match query (embedding + DB pull + FAISS + content fetch). 
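# A minimal sketch of the top-up rule used here: real corpus facts serve as
# distractors first, and synthetic hard distractors are generated only for the
# shortfall, so the stored record count always lands exactly on total_records.
pool = [f"pool fact {i}" for i in range(3)]
needed = 5
chosen = pool[:needed]
if needed > len(pool):
    chosen += [f"generated distractor {i}" for i in range(needed - len(pool))]
assert len(chosen) == needed
assert chosen[:len(pool)] == pool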
+ per_run["hit@1"].append(statistics.fmean(scores["hit@1"])) + per_run["hit@3"].append(statistics.fmean(scores["hit@3"])) + per_run["hit@5"].append(statistics.fmean(scores["hit@5"])) + per_run["mrr"].append(statistics.fmean(scores["mrr"])) + + db_type = getattr(memori_instance, "_benchmark_db_type", "unknown") + + hit5_mean, hit5_lo, hit5_hi = _mean_ci_95(per_run["hit@5"]) + hit1_mean, hit1_lo, hit1_hi = _mean_ci_95(per_run["hit@1"]) + hit3_mean, hit3_lo, hit3_hi = _mean_ci_95(per_run["hit@3"]) + mrr_mean, mrr_lo, mrr_hi = _mean_ci_95(per_run["mrr"]) + + hit5_min = min(per_run["hit@5"]) + hit5_max = max(per_run["hit@5"]) + + print( + f"[semantic-accuracy] db={db_type} total={total_records} " + f"queries={query_count} repeats={repeats} " + f"hit@5(min/mean/max)={hit5_min:.3f}/{hit5_mean:.3f}/{hit5_max:.3f} " + f"hit@5_ci95=({hit5_lo:.3f},{hit5_hi:.3f}) " + f"hit@3_mean={hit3_mean:.3f} ci95=({hit3_lo:.3f},{hit3_hi:.3f}) " + f"hit@1_mean={hit1_mean:.3f} ci95=({hit1_lo:.3f},{hit1_hi:.3f}) " + f"mrr_mean={mrr_mean:.3f} ci95=({mrr_lo:.3f},{mrr_hi:.3f})" + ) + + csv_path = ( + os.environ.get("SEMANTIC_ACCURACY_CSV_PATH") + or _default_semantic_accuracy_csv_path() + ) + run_id = str(uuid4()) + ts = datetime.datetime.now(datetime.UTC).isoformat() + header = [ + "timestamp_utc", + "run_id", + "db", + "total_records", + "query_count", + "repeats", + "base_seed", + "embedding_model", + "embedding_default_dim", + "hit1_mean", + "hit1_ci_lo", + "hit1_ci_hi", + "hit3_mean", + "hit3_ci_lo", + "hit3_ci_hi", + "hit5_min", + "hit5_mean", + "hit5_max", + "hit5_ci_lo", + "hit5_ci_hi", + "mrr_mean", + "mrr_ci_lo", + "mrr_ci_hi", + ] + append_csv_row( + csv_path, + header=header, + row={ + "timestamp_utc": ts, + "run_id": run_id, + "db": db_type, + "total_records": total_records, + "query_count": query_count, + "repeats": repeats, + "base_seed": base_seed, + "embedding_model": getattr(embeddings_mod, "_DEFAULT_MODEL", ""), + "embedding_default_dim": getattr(embeddings_mod, "_DEFAULT_DIMENSION", ""), + "hit1_mean": hit1_mean, + "hit1_ci_lo": hit1_lo, + "hit1_ci_hi": hit1_hi, + "hit3_mean": hit3_mean, + "hit3_ci_lo": hit3_lo, + "hit3_ci_hi": hit3_hi, + "hit5_min": hit5_min, + "hit5_mean": hit5_mean, + "hit5_max": hit5_max, + "hit5_ci_lo": hit5_lo, + "hit5_ci_hi": hit5_hi, + "mrr_mean": mrr_mean, + "mrr_ci_lo": mrr_lo, + "mrr_ci_hi": mrr_hi, + }, + ) + + # We intentionally don't hard-fail on aggressive thresholds here because the goal + # is to *benchmark* accuracy as N grows. The printed metrics are the artifact. + + +@pytest.mark.skipif(not _embeddings_available(), reason="Embedding model unavailable") +@pytest.mark.parametrize( + "distractor_count", [0, 200], ids=["no_distractors", "plus200_distractors"] +) +def test_semantic_recall_accuracy_curated(memori_instance, distractor_count): + """ + Semantic accuracy benchmark on a small curated dataset. + + This provides a stable baseline (fixed facts + fixed queries) that is easier to + defend over time than purely generated query sets. 
""" - entity_db_id = entity_with_n_facts["entity_db_id"] - facts = entity_with_n_facts["facts"] + facts = list(CURATED_DATASET["facts"]) + queries = CURATED_DATASET["queries"] - rng = random.Random(42) - sample_size = min(10, len(facts)) - sampled = rng.sample(facts, k=sample_size) + rng = random.Random(123) + forbidden = set(facts) + distractors = _generate_hard_distractors( + distractor_count, rng=rng, forbidden=forbidden + ) + rng.shuffle(distractors) + facts.extend(distractors) + rng.shuffle(facts) + + entity_id = f"semantic-accuracy-curated-{distractor_count}-{uuid4()}" + memori_instance.attribution( + entity_id=entity_id, process_id="semantic-accuracy-curated" + ) + entity_db_id = memori_instance.config.storage.driver.entity.create(entity_id) + fact_embeddings = embed_texts( + facts, + model=memori_instance.config.embeddings.model, + fallback_dimension=memori_instance.config.embeddings.fallback_dimension, + ) + memori_instance.config.storage.driver.entity_fact.create( + entity_db_id, facts, fact_embeddings + ) + + memori_instance.config.recall_embeddings_limit = len(facts) recall = Recall(memori_instance.config) - top1_hits = 0 - top5_hits = 0 + k = 5 + hit1: list[float] = [] + hit3: list[float] = [] + hit5: list[float] = [] + mrr_scores: list[float] = [] - for fact in sampled: - results = recall.search_facts(query=fact, limit=5, entity_id=entity_db_id) - contents = [r.get("content") for r in results] + for query, expected in queries.items(): + results = recall.search_facts(query=query, limit=k, entity_id=entity_db_id) + retrieved = [r.get("content", "") for r in results] + retrieved_norm = [_strip_id_suffix(r) for r in retrieved] - if contents and contents[0] == fact: - top1_hits += 1 - if fact in contents: - top5_hits += 1 + relevant_norm = {_strip_id_suffix(e) for e in expected} + hit1.append(1.0 if any(f in retrieved_norm[:1] for f in relevant_norm) else 0.0) + hit3.append(1.0 if any(f in retrieved_norm[:3] for f in relevant_norm) else 0.0) + hit5.append(1.0 if any(f in retrieved_norm[:5] for f in relevant_norm) else 0.0) + mrr_scores.append(mrr(relevant_norm, retrieved_norm)) - # Print a small summary if running with -s - db_type = entity_with_n_facts["db_type"] - n = entity_with_n_facts["fact_count"] - size = entity_with_n_facts["content_size"] + db_type = getattr(memori_instance, "_benchmark_db_type", "unknown") print( - f"[recall-accuracy] db={db_type} n={n} size={size} " - f"top1={top1_hits}/{sample_size} top5={top5_hits}/{sample_size}" + f"[semantic-accuracy-curated] db={db_type} total={len(facts)} " + f"distractors={distractor_count} " + f"hit@1={statistics.fmean(hit1):.3f} " + f"hit@3={statistics.fmean(hit3):.3f} " + f"hit@5={statistics.fmean(hit5):.3f} " + f"mrr={statistics.fmean(mrr_scores):.3f}" ) - # Hard assertions: exact-match should always be in top-5 for this pipeline. 
- assert top5_hits == sample_size + curated_csv_path = ( + os.environ.get("SEMANTIC_ACCURACY_CURATED_CSV_PATH") + or _default_semantic_accuracy_curated_csv_path() + ) + curated_header = [ + "timestamp_utc", + "run_id", + "db", + "total_records", + "distractor_count", + "embedding_model", + "embedding_default_dim", + "hit1_mean", + "hit3_mean", + "hit5_mean", + "mrr_mean", + ] + append_csv_row( + curated_csv_path, + header=curated_header, + row={ + "timestamp_utc": datetime.datetime.now(datetime.UTC).isoformat(), + "run_id": str(uuid4()), + "db": db_type, + "total_records": len(facts), + "distractor_count": distractor_count, + "embedding_model": getattr(embeddings_mod, "_DEFAULT_MODEL", ""), + "embedding_default_dim": getattr(embeddings_mod, "_DEFAULT_DIMENSION", ""), + "hit1_mean": statistics.fmean(hit1), + "hit3_mean": statistics.fmean(hit3), + "hit5_mean": statistics.fmean(hit5), + "mrr_mean": statistics.fmean(mrr_scores), + }, + ) diff --git a/tests/benchmarks/test_recall_benchmarks.py b/tests/benchmarks/test_recall_benchmarks.py index c455dfd6..1d639a88 100644 --- a/tests/benchmarks/test_recall_benchmarks.py +++ b/tests/benchmarks/test_recall_benchmarks.py @@ -1,13 +1,66 @@ """Performance benchmarks for Memori recall functionality.""" +import datetime +import os +from time import perf_counter + import pytest +from memori._config import Config from memori._search import find_similar_embeddings from memori.llm._embeddings import embed_texts from memori.memory.recall import Recall +from tests.benchmarks._results import append_csv_row, results_dir from tests.benchmarks.memory_utils import measure_peak_rss_bytes +def _default_benchmark_csv_path() -> str: + return str(results_dir() / "recall_benchmarks.csv") + + +def _write_benchmark_row(*, benchmark, row: dict[str, object]) -> None: + csv_path = ( + os.environ.get("BENCHMARK_RESULTS_CSV_PATH") or _default_benchmark_csv_path() + ) + stats = getattr(benchmark, "stats", None) + row_out: dict[str, object] = dict(row) + row_out["timestamp_utc"] = datetime.datetime.now(datetime.UTC).isoformat() + + for key in ( + "mean", + "stddev", + "median", + "min", + "max", + "rounds", + "iterations", + "ops", + ): + value = getattr(stats, key, None) if stats is not None else None + if value is not None: + row_out[key] = value + + header = [ + "timestamp_utc", + "test", + "db", + "fact_count", + "query_size", + "retrieval_limit", + "one_shot_seconds", + "peak_rss_bytes", + "mean", + "stddev", + "median", + "min", + "max", + "rounds", + "iterations", + "ops", + ] + append_csv_row(csv_path, header=header, row=row_out) + + @pytest.mark.benchmark class TestQueryEmbeddingBenchmarks: """Benchmarks for query embedding generation.""" @@ -15,46 +68,122 @@ class TestQueryEmbeddingBenchmarks: def test_benchmark_query_embedding_short(self, benchmark, sample_queries): """Benchmark embedding generation for short queries.""" query = sample_queries["short"][0] + cfg = Config() def _embed(): - return embed_texts(query) + return embed_texts( + query, + model=cfg.embeddings.model, + fallback_dimension=cfg.embeddings.fallback_dimension, + ) + start = perf_counter() result = benchmark(_embed) + one_shot_seconds = perf_counter() - start assert len(result) > 0 assert len(result[0]) > 0 + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "query_embedding_short", + "db": "", + "fact_count": "", + "query_size": "short", + "retrieval_limit": "", + "one_shot_seconds": one_shot_seconds, + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) def 
test_benchmark_query_embedding_medium(self, benchmark, sample_queries): """Benchmark embedding generation for medium-length queries.""" query = sample_queries["medium"][0] + cfg = Config() def _embed(): - return embed_texts(query) + return embed_texts( + query, + model=cfg.embeddings.model, + fallback_dimension=cfg.embeddings.fallback_dimension, + ) + start = perf_counter() result = benchmark(_embed) + one_shot_seconds = perf_counter() - start assert len(result) > 0 assert len(result[0]) > 0 + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "query_embedding_medium", + "db": "", + "fact_count": "", + "query_size": "medium", + "retrieval_limit": "", + "one_shot_seconds": one_shot_seconds, + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) def test_benchmark_query_embedding_long(self, benchmark, sample_queries): """Benchmark embedding generation for long queries.""" query = sample_queries["long"][0] + cfg = Config() def _embed(): - return embed_texts(query) + return embed_texts( + query, + model=cfg.embeddings.model, + fallback_dimension=cfg.embeddings.fallback_dimension, + ) + start = perf_counter() result = benchmark(_embed) + one_shot_seconds = perf_counter() - start assert len(result) > 0 assert len(result[0]) > 0 + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "query_embedding_long", + "db": "", + "fact_count": "", + "query_size": "long", + "retrieval_limit": "", + "one_shot_seconds": one_shot_seconds, + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) def test_benchmark_query_embedding_batch(self, benchmark, sample_queries): """Benchmark embedding generation for multiple queries at once.""" queries = sample_queries["short"][:5] + cfg = Config() def _embed(): - return embed_texts(queries) + return embed_texts( + queries, + model=cfg.embeddings.model, + fallback_dimension=cfg.embeddings.fallback_dimension, + ) + start = perf_counter() result = benchmark(_embed) + one_shot_seconds = perf_counter() - start assert len(result) == len(queries) assert all(len(emb) > 0 for emb in result) + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "query_embedding_batch", + "db": "", + "fact_count": "", + "query_size": "batch", + "retrieval_limit": "", + "one_shot_seconds": one_shot_seconds, + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) @pytest.mark.benchmark @@ -79,6 +208,18 @@ def _retrieve(): result = benchmark(_retrieve) assert len(result) == fact_count assert all("id" in row and "content_embedding" in row for row in result) + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "db_embedding_retrieval", + "db": entity_with_n_facts["db_type"], + "fact_count": fact_count, + "query_size": "", + "retrieval_limit": "", + "one_shot_seconds": "", + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) @pytest.mark.benchmark @@ -120,6 +261,18 @@ def _retrieve(): result = benchmark(_retrieve) assert len(result) == len(fact_ids) assert all("id" in row and "content" in row for row in result) + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "db_fact_content_retrieval", + "db": entity_with_n_facts["db_type"], + "fact_count": entity_with_n_facts["fact_count"], + "query_size": "", + "retrieval_limit": retrieval_limit, + "one_shot_seconds": "", + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) @pytest.mark.benchmark @@ -140,7 +293,11 @@ def test_benchmark_semantic_search( # Pre-generate query embedding (not part of 
benchmark) query = sample_queries["short"][0] - query_embedding = embed_texts(query)[0] + query_embedding = embed_texts( + query, + model=memori_instance.config.embeddings.model, + fallback_dimension=memori_instance.config.embeddings.fallback_dimension, + )[0] # Benchmark only the similarity search def _search(): @@ -156,17 +313,39 @@ def _search(): assert all( isinstance(item[0], int) and isinstance(item[1], float) for item in result ) + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "semantic_search_faiss", + "db": entity_with_n_facts["db_type"], + "fact_count": fact_count, + "query_size": "short", + "retrieval_limit": "", + "one_shot_seconds": "", + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) @pytest.mark.benchmark class TestEndToEndRecallBenchmarks: """Benchmarks for end-to-end recall (embed query + DB + FAISS + content fetch).""" + @pytest.mark.parametrize( + "query_size", + ["short", "medium", "long"], + ids=["short_query", "medium_query", "long_query"], + ) def test_benchmark_end_to_end_recall( - self, benchmark, memori_instance, entity_with_n_facts, sample_queries + self, + benchmark, + memori_instance, + entity_with_n_facts, + sample_queries, + query_size, ): entity_db_id = entity_with_n_facts["entity_db_id"] - query = sample_queries["short"][0] + query = sample_queries[query_size][0] recall = Recall(memori_instance.config) @@ -177,6 +356,20 @@ def _recall(): if peak_rss is not None: benchmark.extra_info["peak_rss_bytes"] = peak_rss + start = perf_counter() result = benchmark(_recall) + one_shot_seconds = perf_counter() - start assert isinstance(result, list) assert len(result) <= 5 + _write_benchmark_row( + benchmark=benchmark, + row={ + "test": "end_to_end_recall", + "db": entity_with_n_facts["db_type"], + "fact_count": entity_with_n_facts["fact_count"], + "query_size": query_size, + "retrieval_limit": "", + "one_shot_seconds": one_shot_seconds, + "peak_rss_bytes": benchmark.extra_info.get("peak_rss_bytes", ""), + }, + ) diff --git a/tests/benchmarks/test_recall_semantic_accuracy.py b/tests/benchmarks/test_recall_semantic_accuracy.py deleted file mode 100644 index ad89e0d6..00000000 --- a/tests/benchmarks/test_recall_semantic_accuracy.py +++ /dev/null @@ -1,144 +0,0 @@ -import random - -import pytest - -from memori.llm._embeddings import embed_texts -from memori.memory.recall import Recall -from tests.benchmarks.semantic_accuracy_dataset import DATASET -from tests.benchmarks.semantic_accuracy_metrics import ( - mrr, - ndcg_at_k, - precision_at_k, - recall_at_k, -) - - -def _embeddings_available() -> bool: - # If the embedding model can't load, Memori falls back to all-zeros embeddings. - # That makes semantic accuracy meaningless, so we skip instead of failing. 
- vec = embed_texts("sanity check")[0] - return any(v != 0.0 for v in vec) - - -def _generate_hard_distractors( - count: int, *, rng: random.Random, forbidden: set[str] -) -> list[str]: - cities = ["London", "Berlin", "Rome", "Madrid", "Lisbon", "Dublin", "Vienna"] - colors = ["red", "green", "yellow", "purple", "orange", "black", "white"] - foods = ["sushi", "tacos", "ramen", "burgers", "pasta", "salad", "ice cream"] - drinks = ["tea", "sparkling water", "matcha", "hot chocolate", "juice"] - companies = ["Acme Corp", "Globex", "Initech", "Hooli", "Soylent", "Umbrella"] - activities = ["running", "swimming", "reading", "gaming", "cycling", "yoga"] - themes = ["light mode", "system theme", "high contrast mode"] - birthdays = ["April 1st", "May 20th", "June 7th", "July 30th", "Oct 12th"] - pets = ["1 cat", "3 cats", "2 dogs", "a dog", "a cat", "no pets"] - - templates = [ - lambda v: f"User lives in {v}", - lambda v: f"User's favorite color is {v}", - lambda v: f"User likes {v}", - lambda v: f"User works at {v}", - lambda v: f"User enjoys {v}", - lambda v: f"User prefers {v}", - lambda v: f"User's birthday is {v}", - lambda v: f"User has {v}", - ] - values = [ - cities, - colors, - foods + drinks, - companies, - activities, - themes, - birthdays, - pets, - ] - - distractors: list[str] = [] - for i in range(count): - idx = i % len(templates) - base = templates[idx](rng.choice(values[idx])) - candidate = f"{base} (id: d{i})" - if candidate in forbidden: - candidate = f"{base} (note: alt) (id: d{i})" - distractors.append(candidate) - - return distractors - - -@pytest.mark.skipif(not _embeddings_available(), reason="Embedding model unavailable") -@pytest.mark.parametrize( - "total_records", [10, 100, 500, 1000, 5000], ids=lambda n: f"n{n}" -) -def test_semantic_recall_accuracy(memori_instance, total_records): - """ - Semantic accuracy evaluation (the "right way"): - - seed a labeled dataset of facts - - run a labeled set of queries - - compute standard IR metrics (Recall@k, Precision@k, MRR, nDCG@k) - """ - # Seed dataset facts + distractors into a fresh entity - facts = list(DATASET["facts"]) - queries = DATASET["queries"] - - # Expand to the requested total size by adding distractors. - # This lets us evaluate how accuracy changes as the number of stored records grows. - if total_records < len(facts): - pytest.skip( - f"total_records={total_records} is smaller than labeled fact count={len(facts)}" - ) - - distractor_count = total_records - len(facts) - rng = random.Random(123) - forbidden = set(facts) - distractors = _generate_hard_distractors( - distractor_count, rng=rng, forbidden=forbidden - ) - rng.shuffle(distractors) - facts.extend(distractors) - - entity_id = f"semantic-accuracy-entity-{total_records}" - memori_instance.attribution(entity_id=entity_id, process_id="semantic-accuracy") - entity_db_id = memori_instance.config.storage.driver.entity.create(entity_id) - - fact_embeddings = embed_texts(facts) - memori_instance.config.storage.driver.entity_fact.create( - entity_db_id, facts, fact_embeddings - ) - - # Make the evaluation honest: search across the full corpus for this N. - # Otherwise recall will only consider the first `recall_embeddings_limit` rows (default 1000). 
-    memori_instance.config.recall_embeddings_limit = total_records
-
-    recall = Recall(memori_instance.config)
-
-    k = 5
-    scores = {
-        "recall@5": [],
-        "precision@5": [],
-        "mrr": [],
-        "ndcg@5": [],
-    }
-
-    for query, expected in queries.items():
-        relevant = set(expected)
-        results = recall.search_facts(query=query, limit=k, entity_id=entity_db_id)
-        retrieved = [r.get("content", "") for r in results]
-
-        scores["recall@5"].append(recall_at_k(relevant, retrieved, k))
-        scores["precision@5"].append(precision_at_k(relevant, retrieved, k))
-        scores["mrr"].append(mrr(relevant, retrieved))
-        scores["ndcg@5"].append(ndcg_at_k(relevant, retrieved, k))
-
-    # Aggregate (mean) metrics
-    mean_scores = {k: sum(v) / len(v) for k, v in scores.items()}
-
-    db_type = getattr(memori_instance, "_benchmark_db_type", "unknown")
-    print(
-        f"[semantic-accuracy] db={db_type} total={total_records} "
-        f"labeled={len(DATASET['facts'])} distractors={distractor_count} "
-        f"embeddings_limit={memori_instance.config.recall_embeddings_limit} {mean_scores}"
-    )
-
-    # We intentionally don't hard-fail on aggressive thresholds here because the goal
-    # is to *benchmark* accuracy as N grows. The printed metrics are the artifact.
diff --git a/tests/llm/test_llm_embeddings.py b/tests/llm/test_llm_embeddings.py
index 67d5afb4..21cbf592 100644
--- a/tests/llm/test_llm_embeddings.py
+++ b/tests/llm/test_llm_embeddings.py
@@ -14,6 +14,7 @@
 import numpy as np
 import pytest

+from memori._config import Config
 from memori.llm._embeddings import (
     _get_model,
     embed_texts,
@@ -131,13 +132,18 @@ def test_get_model_different_models():


 def test_embed_texts_single_string():
+    cfg = Config()
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_model = Mock()
         mock_embeddings = np.array([[0.1, 0.2, 0.3]])
         mock_model.encode.return_value = mock_embeddings
         mock_get_model.return_value = mock_model

-        result = embed_texts("Hello world")
+        result = embed_texts(
+            "Hello world",
+            model=cfg.embeddings.model,
+            fallback_dimension=cfg.embeddings.fallback_dimension,
+        )

         assert len(result) == 1
         assert result[0] == pytest.approx([0.1, 0.2, 0.3])
@@ -145,13 +151,18 @@ def test_embed_texts_single_string():


 def test_embed_texts_list_of_strings():
+    cfg = Config()
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_model = Mock()
         mock_embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
         mock_model.encode.return_value = mock_embeddings
         mock_get_model.return_value = mock_model

-        result = embed_texts(["Hello", "World"])
+        result = embed_texts(
+            ["Hello", "World"],
+            model=cfg.embeddings.model,
+            fallback_dimension=cfg.embeddings.fallback_dimension,
+        )

         assert len(result) == 2
         assert result[0] == pytest.approx([0.1, 0.2, 0.3])
@@ -159,31 +170,46 @@ def test_embed_texts_list_of_strings():


 def test_embed_texts_empty_list():
-    result = embed_texts([])
+    cfg = Config()
+    result = embed_texts(
+        [],
+        model=cfg.embeddings.model,
+        fallback_dimension=cfg.embeddings.fallback_dimension,
+    )
     assert result == []


 def test_embed_texts_empty_string():
+    cfg = Config()
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_model = Mock()
         mock_embeddings = np.array([[0.1, 0.2, 0.3]])
         mock_model.encode.return_value = mock_embeddings
         mock_get_model.return_value = mock_model

-        result = embed_texts("")
+        result = embed_texts(
+            "",
+            model=cfg.embeddings.model,
+            fallback_dimension=cfg.embeddings.fallback_dimension,
+        )

         assert len(result) == 1
         mock_model.encode.assert_called_once_with([""], convert_to_numpy=True)


 def test_embed_texts_filters_empty_strings():
+    cfg = Config()
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_model = Mock()
         mock_embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]])
         mock_model.encode.return_value = mock_embeddings
         mock_get_model.return_value = mock_model

-        result = embed_texts(["Hello", "", "World", ""])
+        result = embed_texts(
+            ["Hello", "", "World", ""],
+            model=cfg.embeddings.model,
+            fallback_dimension=cfg.embeddings.fallback_dimension,
+        )

         assert len(result) == 2
         mock_model.encode.assert_called_once_with(
@@ -198,7 +224,7 @@ def test_embed_texts_custom_model():
         mock_model.encode.return_value = mock_embeddings
         mock_get_model.return_value = mock_model

-        result = embed_texts("test", model="custom-model")
+        result = embed_texts("test", model="custom-model", fallback_dimension=1024)

         mock_get_model.assert_called_once_with("custom-model")
         assert len(result) == 1
@@ -208,11 +234,13 @@ def test_embed_texts_model_load_failure():
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_get_model.side_effect = OSError("Model not found")

-        result = embed_texts(["Hello", "World"])
+        result = embed_texts(
+            ["Hello", "World"], model="missing-model", fallback_dimension=7
+        )

         assert len(result) == 2
-        assert result[0] == [0.0] * 768
-        assert result[1] == [0.0] * 768
+        assert result[0] == [0.0] * 7
+        assert result[1] == [0.0] * 7


 def test_embed_texts_encode_failure():
@@ -222,12 +250,30 @@
         mock_model.get_sentence_embedding_dimension.return_value = 384
         mock_get_model.return_value = mock_model

-        result = embed_texts(["Hello"])
+        result = embed_texts(["Hello"], model="test-model", fallback_dimension=1024)

         assert len(result) == 1
         assert result[0] == [0.0] * 384


+def test_embed_texts_shape_error_retries_and_pools(mocker):
+    mock_model = mocker.Mock()
+    # First batched call (convert_to_numpy=True) fails with a numpy stack error;
+    # the per-text retries then succeed.
+    mock_model.encode.side_effect = [
+        ValueError("all input arrays must have the same shape"),
+        np.ones((1, 4), dtype=np.float32),
+        np.zeros((1, 4), dtype=np.float32),
+    ]
+    mock_model.get_sentence_embedding_dimension.return_value = 4
+    mocker.patch("memori.llm._embeddings._get_model", return_value=mock_model)
+
+    out = embed_texts(["a", "b"], model="test-model", fallback_dimension=1024)
+
+    assert out == [[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0]]
+    assert mock_model.encode.call_count == 3
+
+
 def test_embed_texts_encode_failure_with_dimension_error():
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_model = Mock()
@@ -237,34 +283,35 @@
         )
         mock_get_model.return_value = mock_model

-        result = embed_texts(["Hello"])
+        result = embed_texts(["Hello"], model="test-model", fallback_dimension=11)

         assert len(result) == 1
-        assert result[0] == [0.0] * 768
+        assert result[0] == [0.0] * 11


 def test_embed_texts_model_load_runtime_error():
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_get_model.side_effect = RuntimeError("Runtime error")

-        result = embed_texts("test")
+        result = embed_texts("test", model="test-model", fallback_dimension=9)

         assert len(result) == 1
-        assert result[0] == [0.0] * 768
+        assert result[0] == [0.0] * 9


 def test_embed_texts_model_load_value_error():
     with patch("memori.llm._embeddings._get_model") as mock_get_model:
         mock_get_model.side_effect = ValueError("Value error")

-        result = embed_texts("test")
+        result = embed_texts("test", model="test-model", fallback_dimension=9)

         assert len(result) == 1
-        assert result[0] == [0.0] * 768
+        assert result[0] == [0.0] * 9


 @pytest.mark.asyncio
 async def test_embed_texts_async_single_string():
+    cfg = Config()
     mock_result = [[0.1, 0.2, 0.3]]

     async def mock_run_in_executor(executor, func, *args):
@@ -273,7 +320,11 @@ async def mock_run_in_executor(executor, func, *args):
     with patch("asyncio.get_event_loop") as mock_loop:
         mock_loop.return_value.run_in_executor = mock_run_in_executor

-        result = await embed_texts_async("Hello world")
+        result = await embed_texts_async(
+            "Hello world",
+            model=cfg.embeddings.model,
+            fallback_dimension=cfg.embeddings.fallback_dimension,
+        )

         assert len(result) == 1
         assert result[0] == pytest.approx([0.1, 0.2, 0.3])
@@ -281,6 +332,7 @@ async def mock_run_in_executor(executor, func, *args):

 @pytest.mark.asyncio
 async def test_embed_texts_async_list():
+    cfg = Config()
     mock_result = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

     async def mock_run_in_executor(executor, func, *args):
@@ -289,7 +341,11 @@ async def mock_run_in_executor(executor, func, *args):
     with patch("asyncio.get_event_loop") as mock_loop:
         mock_loop.return_value.run_in_executor = mock_run_in_executor

-        result = await embed_texts_async(["Hello", "World"])
+        result = await embed_texts_async(
+            ["Hello", "World"],
+            model=cfg.embeddings.model,
+            fallback_dimension=cfg.embeddings.fallback_dimension,
+        )

         assert len(result) == 2
         assert result[0] == pytest.approx([0.1, 0.2, 0.3])
@@ -306,7 +362,9 @@ async def mock_run_in_executor(executor, func, *args):
     with patch("asyncio.get_event_loop") as mock_loop:
         mock_loop.return_value.run_in_executor = mock_run_in_executor

-        result = await embed_texts_async("test", model="custom-model")
+        result = await embed_texts_async(
+            "test", model="custom-model", fallback_dimension=1024
+        )

         assert len(result) == 1
         assert result[0] == pytest.approx([0.1, 0.2, 0.3])
diff --git a/tests/memory/test_recall.py b/tests/memory/test_recall.py
index d888bc09..f2baa6f6 100644
--- a/tests/memory/test_recall.py
+++ b/tests/memory/test_recall.py
@@ -114,7 +114,11 @@ def test_search_facts_success():
         assert result[0]["content"] == "User likes pizza"
         assert result[1]["content"] == "User lives in NYC"

-        mock_embed.assert_called_once_with("What do I like?")
+        mock_embed.assert_called_once_with(
+            "What do I like?",
+            model=config.embeddings.model,
+            fallback_dimension=config.embeddings.fallback_dimension,
+        )
         mock_search.assert_called_once_with(
             config.storage.driver.entity_fact,
             1,
@@ -271,7 +275,11 @@ def test_search_facts_embeds_query_correctly():

         recall.search_facts("My test query", entity_id=1)

-        mock_embed.assert_called_once_with("My test query")
+        mock_embed.assert_called_once_with(
+            "My test query",
+            model=config.embeddings.model,
+            fallback_dimension=config.embeddings.fallback_dimension,
+        )
         mock_search.assert_called_once()

         assert mock_search.call_args[0][2] == [0.1, 0.2, 0.3, 0.4, 0.5]
diff --git a/tests/test_search.py b/tests/test_search.py
index 2cd2a60b..eb28cd28 100644
--- a/tests/test_search.py
+++ b/tests/test_search.py
@@ -149,6 +149,19 @@ def test_find_similar_embeddings_dimension_mismatch():
     assert result == []


+def test_find_similar_embeddings_mixed_dimensions():
+    embeddings = [
+        (1, [1.0, 0.0, 0.0]),  # 3D
+        (2, [0.0, 1.0]),  # 2D
+        (3, [0.0, 0.0, 1.0]),  # 3D
+    ]
+    query = [1.0, 0.0, 0.0]
+    result = find_similar_embeddings(embeddings, query, limit=5)
+
+    # Should ignore mismatched dimensions, not crash.
+    assert [fact_id for fact_id, _ in result] == [1, 3]
+
+
 def test_find_similar_embeddings_mixed_formats():
     embeddings = [
         (1, json.dumps([1.0, 0.0, 0.0])),