Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions memori/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,18 @@ def __init__(self):
self.cockroachdb = False


class Embeddings:
    """Configuration defaults for sentence-embedding generation.

    Holds the default SentenceTransformer model name and the vector width
    used for zero-filled fallback embeddings when encoding fails.
    """

    def __init__(self):
        # Width of zero-vector fallbacks when the real dimension is unknown.
        self.fallback_dimension = 768
        # Default SentenceTransformer model identifier.
        self.model = "all-MiniLM-L6-v2"


class Config:
def __init__(self):
self.api_key = None
self.augmentation = None
self.cache = Cache()
self.embeddings = Embeddings()
self.enterprise = False
self.llm = Llm()
self.framework = Framework()
Expand Down
11 changes: 10 additions & 1 deletion memori/_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,18 @@ def find_similar_embeddings(
if not embeddings:
return []

query_dim = len(query_embedding)
if query_dim == 0:
return []

embeddings_list = []
id_list = []

for fact_id, raw in embeddings:
try:
parsed = parse_embedding(raw)
if parsed.ndim != 1 or parsed.shape[0] != query_dim:
continue
embeddings_list.append(parsed)
id_list.append(fact_id)
except Exception:
Expand All @@ -68,7 +74,10 @@ def find_similar_embeddings(
if not embeddings_list:
return []

embeddings_array = np.stack(embeddings_list, axis=0)
try:
embeddings_array = np.stack(embeddings_list, axis=0)
except ValueError:
return []

faiss.normalize_L2(embeddings_array)
query_array = np.asarray([query_embedding], dtype=np.float32)
Expand Down
76 changes: 52 additions & 24 deletions memori/llm/_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
import asyncio
import os
import struct
from collections.abc import Iterable
from typing import Any

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")

from sentence_transformers import SentenceTransformer

_MODEL_CACHE: dict[str, SentenceTransformer] = {}
_DEFAULT_DIMENSION = 768


def _get_model(model_name: str) -> SentenceTransformer:
Expand All @@ -27,18 +27,25 @@ def _get_model(model_name: str) -> SentenceTransformer:
return _MODEL_CACHE[model_name]


def format_embedding_for_db(embedding: list[float], dialect: str) -> Any:
"""Format embedding for database storage.
def _prepare_text_inputs(texts: str | Iterable[str]) -> list[str]:
if isinstance(texts, str):
return [texts]
return [t for t in texts if t]


def _embedding_dimension(model: Any, default: int) -> int:
try:
dim_value = model.get_sentence_embedding_dimension()
return int(dim_value) if dim_value is not None else default
except (RuntimeError, ValueError, AttributeError, TypeError):
return default


Args:
embedding: List of floats representing the embedding vector
dialect: Database dialect (postgresql, mysql, sqlite, mongodb)
def _zero_vectors(count: int, dim: int) -> list[list[float]]:
return [[0.0] * dim for _ in range(count)]

Returns:
Formatted embedding optimized for the target database:
- PostgreSQL/CockroachDB/MySQL/SQLite: Binary (BYTEA/BLOB) - compact & fast
- MongoDB: Binary (BinData) - compact & fast
"""

def format_embedding_for_db(embedding: list[float], dialect: str) -> Any:
binary_data = struct.pack(f"<{len(embedding)}f", *embedding)

if dialect == "mongodb":
Expand All @@ -48,36 +55,57 @@ def format_embedding_for_db(embedding: list[float], dialect: str) -> Any:
return bson.Binary(binary_data)
except ImportError:
return binary_data
else:
return binary_data
return binary_data


def embed_texts(
texts: str | list[str], model: str = "all-mpnet-base-v2"
texts: str | list[str],
model: str,
fallback_dimension: int,
) -> list[list[float]]:
inputs = [texts] if isinstance(texts, str) else [t for t in texts if t]
inputs = _prepare_text_inputs(texts)
if not inputs:
return []

try:
encoder = _get_model(model)
except (OSError, RuntimeError, ValueError):
return [[0.0] * _DEFAULT_DIMENSION for _ in inputs]
return _zero_vectors(len(inputs), fallback_dimension)

try:
embeddings = encoder.encode(inputs, convert_to_numpy=True)
return embeddings.tolist()
except (RuntimeError, ValueError):
except ValueError as e:
# Some models can raise "all input arrays must have the same shape" when
# encoding batches. Retry one-by-one to avoid internal stacking.
if "same shape" not in str(e):
raise

try:
dim_value = encoder.get_sentence_embedding_dimension()
dim = int(dim_value) if dim_value is not None else _DEFAULT_DIMENSION
except (RuntimeError, ValueError, AttributeError, TypeError):
dim = _DEFAULT_DIMENSION
return [[0.0] * dim for _ in inputs]
vectors: list[list[float]] = []
for text in inputs:
single = encoder.encode([text], convert_to_numpy=True)
vectors.append(single[0].tolist())

dim_set = {len(v) for v in vectors}
if len(dim_set) != 1:
raise ValueError("all input arrays must have the same shape") from e

return vectors
except Exception:
dim = _embedding_dimension(encoder, default=fallback_dimension)
return _zero_vectors(len(inputs), dim)
except RuntimeError:
dim = _embedding_dimension(encoder, default=fallback_dimension)
return _zero_vectors(len(inputs), dim)


async def embed_texts_async(
    texts: str | list[str],
    model: str,
    fallback_dimension: int,
) -> list[list[float]]:
    """Asynchronously embed *texts* by running ``embed_texts`` in the
    default thread-pool executor.

    Args:
        texts: A single string or a list of strings to encode.
        model: SentenceTransformer model name, passed through unchanged.
        fallback_dimension: Zero-vector width used when encoding fails.

    Returns:
        One embedding (list of floats) per surviving input text.
    """
    # get_running_loop() is the supported way to obtain the loop from inside
    # a coroutine; get_event_loop() is deprecated in this context (3.10+)
    # and returns the same running loop here, so behavior is unchanged.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        None, embed_texts, texts, model, fallback_dimension
    )
14 changes: 12 additions & 2 deletions memori/memory/augmentation/augmentations/memori/_augmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,12 @@ async def _process_api_response(self, api_response: dict) -> Memories:
]

if facts:
fact_embeddings = await embed_texts_async(facts)
embeddings_config = self.config.embeddings
fact_embeddings = await embed_texts_async(
facts,
model=embeddings_config.model,
fallback_dimension=embeddings_config.fallback_dimension,
)
api_response["entity"]["fact_embeddings"] = fact_embeddings

return Memories().configure_from_advanced_augmentation(api_response)
Expand All @@ -167,7 +172,12 @@ async def _schedule_entity_writes(
]

if facts_from_triples:
embeddings_from_triples = await embed_texts_async(facts_from_triples)
embeddings_config = self.config.embeddings
embeddings_from_triples = await embed_texts_async(
facts_from_triples,
model=embeddings_config.model,
fallback_dimension=embeddings_config.fallback_dimension,
)
facts_to_write = (facts_to_write or []) + facts_from_triples
embeddings_to_write = (
embeddings_to_write or []
Expand Down
7 changes: 6 additions & 1 deletion memori/memory/recall.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def search_facts(
if limit is None:
limit = self.config.recall_facts_limit

query_embedding = embed_texts(query)[0]
embeddings_config = self.config.embeddings
query_embedding = embed_texts(
query,
model=embeddings_config.model,
fallback_dimension=embeddings_config.fallback_dimension,
)[0]

facts = []
for attempt in range(MAX_RETRIES):
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ asyncio_mode = "auto"
addopts = [
"-v",
"--strict-markers",
"--ignore=tests/benchmarks",
"-m",
"not benchmark",
"--cov=memori",
Expand Down
26 changes: 26 additions & 0 deletions tests/benchmarks/_results.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from __future__ import annotations

import csv
from pathlib import Path
from typing import Any


def repo_root() -> Path:
    """Absolute path of the repository root.

    Resolves this file's location and walks up three levels (file -> its
    directory -> tests -> repo root).
    """
    here = Path(__file__).resolve()
    return here.parent.parent.parent


def results_dir() -> Path:
    """Return the repo-level ``results`` directory, creating it if absent."""
    target = repo_root() / "results"
    target.mkdir(parents=True, exist_ok=True)
    return target


def append_csv_row(path: str | Path, *, header: list[str], row: dict[str, Any]) -> None:
out_path = Path(path)
out_path.parent.mkdir(parents=True, exist_ok=True)
file_exists = out_path.exists()
with out_path.open("a", newline="") as f:
writer = csv.DictWriter(f, fieldnames=header)
if not file_exists:
writer.writeheader()
writer.writerow(row)
8 changes: 6 additions & 2 deletions tests/benchmarks/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def postgres_db_connection():
postgres_uri,
pool_pre_ping=True,
pool_recycle=300,
connect_args=connect_args if connect_args else None,
connect_args=connect_args,
)

try:
Expand Down Expand Up @@ -156,7 +156,11 @@ def entity_with_n_facts(memori_instance, fact_content_size, request):
memori_instance.attribution(entity_id=entity_id, process_id="benchmark-process")

facts = generate_facts_with_size(fact_count, fact_content_size)
fact_embeddings = embed_texts(facts)
fact_embeddings = embed_texts(
facts,
model=memori_instance.config.embeddings.model,
fallback_dimension=memori_instance.config.embeddings.fallback_dimension,
)

entity_db_id = memori_instance.config.storage.driver.entity.create(entity_id)
memori_instance.config.storage.driver.entity_fact.create(
Expand Down
Loading
Loading