diff --git a/nemoguardrails/embeddings/basic.py b/nemoguardrails/embeddings/basic.py
index cbd48ec62..a4e497762 100644
--- a/nemoguardrails/embeddings/basic.py
+++ b/nemoguardrails/embeddings/basic.py
@@ -15,9 +15,9 @@
 import asyncio
 import logging
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Union, cast
 
-from annoy import AnnoyIndex
+from annoy import AnnoyIndex  # type: ignore
 
 from nemoguardrails.embeddings.cache import cache_embeddings
 from nemoguardrails.embeddings.index import EmbeddingsIndex, IndexItem
@@ -45,26 +45,14 @@ class BasicEmbeddingsIndex(EmbeddingsIndex):
         max_batch_hold: The maximum time a batch is held before being processed
     """
 
-    embedding_model: str
-    embedding_engine: str
-    embedding_params: Dict[str, Any]
-    index: AnnoyIndex
-    embedding_size: int
-    cache_config: EmbeddingsCacheConfig
-    embeddings: List[List[float]]
-    search_threshold: float
-    use_batching: bool
-    max_batch_size: int
-    max_batch_hold: float
-
     def __init__(
         self,
-        embedding_model=None,
-        embedding_engine=None,
-        embedding_params=None,
-        index=None,
-        cache_config: Union[EmbeddingsCacheConfig, Dict[str, Any]] = None,
-        search_threshold: float = None,
+        embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2",
+        embedding_engine: str = "SentenceTransformers",
+        embedding_params: Optional[Dict[str, Any]] = None,
+        index: Optional[AnnoyIndex] = None,
+        cache_config: Optional[Union[EmbeddingsCacheConfig, Dict[str, Any]]] = None,
+        search_threshold: float = float("inf"),
         use_batching: bool = False,
         max_batch_size: int = 10,
         max_batch_hold: float = 0.01,
@@ -72,22 +60,23 @@ def __init__(
         """Initialize the BasicEmbeddingsIndex.
 
         Args:
-            embedding_model (str, optional): The model for computing embeddings. Defaults to None.
-            embedding_engine (str, optional): The engine for computing embeddings. Defaults to None.
-            index (AnnoyIndex, optional): The pre-existing index. Defaults to None.
-            cache_config (EmbeddingsCacheConfig | Dict[str, Any], optional): The cache configuration. Defaults to None.
+            embedding_model: The model for computing embeddings.
+            embedding_engine: The engine for computing embeddings.
+            index: The pre-existing index.
+            cache_config: The cache configuration.
+            search_threshold: The threshold for filtering search results.
             use_batching: Whether to batch requests when computing the embeddings.
             max_batch_size: The maximum size of a batch.
             max_batch_hold: The maximum time a batch is held before being processed
         """
         self._model: Optional[EmbeddingModel] = None
-        self._items = []
-        self._embeddings = []
+        self._items: List[IndexItem] = []
+        self._embeddings: List[List[float]] = []
         self.embedding_model = embedding_model
         self.embedding_engine = embedding_engine
         self.embedding_params = embedding_params or {}
         self._embedding_size = 0
-        self.search_threshold = search_threshold or float("inf")
+        self.search_threshold = search_threshold
         if isinstance(cache_config, Dict):
             self._cache_config = EmbeddingsCacheConfig(**cache_config)
         else:
@@ -95,12 +84,12 @@ def __init__(
         self._index = index
 
         # Data structures for batching embedding requests
-        self._req_queue = {}
-        self._req_results = {}
-        self._req_idx = 0
-        self._current_batch_finished_event = None
-        self._current_batch_full_event = None
-        self._current_batch_submitted = asyncio.Event()
+        self._req_queue: Dict[int, str] = {}
+        self._req_results: Dict[int, List[float]] = {}
+        self._req_idx: int = 0
+        self._current_batch_finished_event: Optional[asyncio.Event] = None
+        self._current_batch_full_event: Optional[asyncio.Event] = None
+        self._current_batch_submitted: asyncio.Event = asyncio.Event()
 
         # Initialize the batching configuration
         self.use_batching = use_batching
@@ -112,6 +101,11 @@ def embeddings_index(self):
         """Get the current embedding index"""
         return self._index
 
+    @embeddings_index.setter
+    def embeddings_index(self, index):
+        """Setter to allow replacing the index dynamically."""
+        self._index = index
+
     @property
     def cache_config(self):
         """Get the cache configuration."""
@@ -127,16 +121,14 @@ def embeddings(self):
         """Get the computed embeddings."""
         return self._embeddings
 
-    @embeddings_index.setter
-    def embeddings_index(self, index):
-        """Setter to allow replacing the index dynamically."""
-        self._index = index
-
     def _init_model(self):
         """Initialize the model used for computing the embeddings."""
+        model = self.embedding_model
+        engine = self.embedding_engine
+
         self._model = init_embedding_model(
-            embedding_model=self.embedding_model,
-            embedding_engine=self.embedding_engine,
+            embedding_model=model,
+            embedding_engine=engine,
             embedding_params=self.embedding_params,
         )
@@ -153,7 +145,9 @@ async def _get_embeddings(self, texts: List[str]) -> List[List[float]]:
         if self._model is None:
             self._init_model()
 
-        embeddings = await self._model.encode_async(texts)
+        # self._model can't be None here, or self._init_model() would throw a ValueError
+        model: EmbeddingModel = cast(EmbeddingModel, self._model)
+        embeddings = await model.encode_async(texts)
         return embeddings
 
     async def add_item(self, item: IndexItem):
@@ -199,6 +193,12 @@ async def _run_batch(self):
         """Runs the current batch of embeddings."""
         # Wait up to `max_batch_hold` time or until `max_batch_size` is reached.
 
+        if (
+            self._current_batch_full_event is None
+            or self._current_batch_finished_event is None
+        ):
+            raise RuntimeError("Batch events not initialized. This should not happen.")
+
         done, pending = await asyncio.wait(
             [
                 asyncio.create_task(asyncio.sleep(self.max_batch_hold)),
@@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]:
         self._req_idx += 1
         self._req_queue[req_id] = text
 
-        if self._current_batch_finished_event is None:
+        if (
+            self._current_batch_finished_event is None
+            or self._current_batch_full_event is None
+        ):
             self._current_batch_finished_event = asyncio.Event()
             self._current_batch_full_event = asyncio.Event()
             self._current_batch_submitted.clear()
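Note on the `basic.py` changes: every constructor parameter now has a typed, non-None default, and `search_threshold` defaults to `float("inf")` directly instead of being coerced with `or`. One behavioral nuance worth calling out: the old `search_threshold or float("inf")` turned an explicit `0.0` into infinity, while the new default preserves it. A minimal usage sketch (hypothetical, not part of the diff; assumes `nemoguardrails` and `annoy` are installed):

```python
# Hypothetical usage sketch of the retyped constructor.
from nemoguardrails.embeddings.basic import BasicEmbeddingsIndex

# All parameters have non-None defaults now, so a bare construction type-checks:
index = BasicEmbeddingsIndex()
assert index.embedding_model == "sentence-transformers/all-MiniLM-L6-v2"
assert index.search_threshold == float("inf")

# An explicit 0.0 threshold is preserved; the old `or float("inf")` idiom
# would have silently replaced it with infinity:
strict = BasicEmbeddingsIndex(search_threshold=0.0)
assert strict.search_threshold == 0.0

# The relocated `embeddings_index` setter still allows swapping in a
# pre-built Annoy index after construction:
# index.embeddings_index = prebuilt_annoy_index
```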
diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py
index 9abeb1de2..cdef48c27 100644
--- a/nemoguardrails/embeddings/cache.py
+++ b/nemoguardrails/embeddings/cache.py
@@ -20,7 +20,12 @@
 from abc import ABC, abstractmethod
 from functools import singledispatchmethod
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Optional
+
+try:
+    import redis  # type: ignore
+except ImportError:
+    redis = None  # type: ignore
 
 from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig
 
@@ -30,6 +35,8 @@
 class KeyGenerator(ABC):
     """Abstract class for key generators."""
 
+    name: str  # Class attribute that should be defined in subclasses
+
     @abstractmethod
     def generate_key(self, text: str) -> str:
         pass
@@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str:
 class CacheStore(ABC):
     """Abstract class for cache stores."""
 
+    name: str
+
     @abstractmethod
     def get(self, key):
         """Get a value from the cache."""
@@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore):
 
     name = "filesystem"
 
-    def __init__(self, cache_dir: str = None):
+    def __init__(self, cache_dir: Optional[str] = None):
         self._cache_dir = Path(cache_dir or ".cache/embeddings")
         self._cache_dir.mkdir(parents=True, exist_ok=True)
@@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore):
     name = "redis"
 
     def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0):
-        import redis
-
+        if redis is None:
+            raise ImportError(
+                "Could not import redis, please install it with `pip install redis`."
+            )
         self._redis = redis.Redis(host=host, port=port, db=db)
 
     def get(self, key):
@@ -207,9 +218,9 @@ def clear(self):
 class EmbeddingsCache:
     def __init__(
         self,
-        key_generator: KeyGenerator = None,
-        cache_store: CacheStore = None,
-        store_config: dict = None,
+        key_generator: KeyGenerator,
+        cache_store: CacheStore,
+        store_config: Optional[dict] = None,
     ):
         self._key_generator = key_generator
         self._cache_store = cache_store
@@ -218,7 +229,10 @@ def __init__(
     @classmethod
     def from_dict(cls, d: Dict[str, str]):
         key_generator = KeyGenerator.from_name(d.get("key_generator"))()
-        store_config = d.get("store_config")
+        store_config_raw = d.get("store_config")
+        store_config: dict = (
+            store_config_raw if isinstance(store_config_raw, dict) else {}
+        )
         cache_store = CacheStore.from_name(d.get("store"))(**store_config)
         return cls(key_generator=key_generator, cache_store=cache_store)
@@ -239,7 +253,7 @@ def get_config(self):
     def get(self, texts):
         raise NotImplementedError
 
-    @get.register
+    @get.register(str)
     def _(self, text: str):
         key = self._key_generator.generate_key(text)
         log.info(f"Fetching key {key} for text '{text[:20]}...' from cache")
@@ -248,7 +262,7 @@ def _(self, text: str):
 
         return result
 
-    @get.register
+    @get.register(list)
     def _(self, texts: list):
         cached = {}
@@ -266,13 +280,13 @@ def _(self, texts: list):
     def set(self, texts):
         raise NotImplementedError
 
-    @set.register
+    @set.register(str)
     def _(self, text: str, value: List[float]):
         key = self._key_generator.generate_key(text)
         log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.")
         self._cache_store.set(key, value)
 
-    @set.register
+    @set.register(list)
     def _(self, texts: list, values: List[List[float]]):
         for text, value in zip(texts, values):
             self.set(text, value)
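The `cache.py` changes combine two patterns: a module-level optional import for `redis` (so the missing dependency fails at construction time instead of import time) and explicit type arguments for `singledispatchmethod` registration, which type checkers resolve more reliably than annotation-based dispatch. A self-contained sketch of both patterns (illustrative only; `Demo` is a made-up class, not part of the codebase):

```python
from functools import singledispatchmethod

# Optional-dependency pattern: import once at module level, fail lazily.
try:
    import redis  # type: ignore
except ImportError:
    redis = None  # checked later, when a store object is actually built

class Demo:
    @singledispatchmethod
    def get(self, texts):
        raise NotImplementedError

    @get.register(str)  # explicit type, instead of inferring from annotations
    def _(self, text: str):
        return [text]

    @get.register(list)
    def _(self, texts: list):
        return [self.get(t) for t in texts]

d = Demo()
assert d.get("a") == ["a"]
assert d.get(["a", "b"]) == [["a"], ["b"]]
```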
This should not happen.") + done, pending = await asyncio.wait( [ asyncio.create_task(asyncio.sleep(self.max_batch_hold)), @@ -244,7 +244,10 @@ async def _batch_get_embeddings(self, text: str) -> List[float]: self._req_idx += 1 self._req_queue[req_id] = text - if self._current_batch_finished_event is None: + if ( + self._current_batch_finished_event is None + or self._current_batch_full_event is None + ): self._current_batch_finished_event = asyncio.Event() self._current_batch_full_event = asyncio.Event() self._current_batch_submitted.clear() diff --git a/nemoguardrails/embeddings/cache.py b/nemoguardrails/embeddings/cache.py index 9abeb1de2..cdef48c27 100644 --- a/nemoguardrails/embeddings/cache.py +++ b/nemoguardrails/embeddings/cache.py @@ -20,7 +20,12 @@ from abc import ABC, abstractmethod from functools import singledispatchmethod from pathlib import Path -from typing import Dict, List +from typing import Dict, List, Optional + +try: + import redis # type: ignore +except ImportError: + redis = None # type: ignore from nemoguardrails.rails.llm.config import EmbeddingsCacheConfig @@ -30,6 +35,8 @@ class KeyGenerator(ABC): """Abstract class for key generators.""" + name: str # Class attribute that should be defined in subclasses + @abstractmethod def generate_key(self, text: str) -> str: pass @@ -76,6 +83,8 @@ def generate_key(self, text: str) -> str: class CacheStore(ABC): """Abstract class for cache stores.""" + name: str + @abstractmethod def get(self, key): """Get a value from the cache.""" @@ -147,7 +156,7 @@ class FilesystemCacheStore(CacheStore): name = "filesystem" - def __init__(self, cache_dir: str = None): + def __init__(self, cache_dir: Optional[str] = None): self._cache_dir = Path(cache_dir or ".cache/embeddings") self._cache_dir.mkdir(parents=True, exist_ok=True) @@ -190,8 +199,10 @@ class RedisCacheStore(CacheStore): name = "redis" def __init__(self, host: str = "localhost", port: int = 6379, db: int = 0): - import redis - + if redis is None: + raise ImportError( + "Could not import redis, please install it with `pip install redis`." + ) self._redis = redis.Redis(host=host, port=port, db=db) def get(self, key): @@ -207,9 +218,9 @@ def clear(self): class EmbeddingsCache: def __init__( self, - key_generator: KeyGenerator = None, - cache_store: CacheStore = None, - store_config: dict = None, + key_generator: KeyGenerator, + cache_store: CacheStore, + store_config: Optional[dict] = None, ): self._key_generator = key_generator self._cache_store = cache_store @@ -218,7 +229,10 @@ def __init__( @classmethod def from_dict(cls, d: Dict[str, str]): key_generator = KeyGenerator.from_name(d.get("key_generator"))() - store_config = d.get("store_config") + store_config_raw = d.get("store_config") + store_config: dict = ( + store_config_raw if isinstance(store_config_raw, dict) else {} + ) cache_store = CacheStore.from_name(d.get("store"))(**store_config) return cls(key_generator=key_generator, cache_store=cache_store) @@ -239,7 +253,7 @@ def get_config(self): def get(self, texts): raise NotImplementedError - @get.register + @get.register(str) def _(self, text: str): key = self._key_generator.generate_key(text) log.info(f"Fetching key {key} for text '{text[:20]}...' 
from cache") @@ -248,7 +262,7 @@ def _(self, text: str): return result - @get.register + @get.register(list) def _(self, texts: list): cached = {} @@ -266,13 +280,13 @@ def _(self, texts: list): def set(self, texts): raise NotImplementedError - @set.register + @set.register(str) def _(self, text: str, value: List[float]): key = self._key_generator.generate_key(text) log.info(f"Cache miss for text '{text}'. Storing key {key} in cache.") self._cache_store.set(key, value) - @set.register + @set.register(list) def _(self, texts: list, values: List[List[float]]): for text, value in zip(texts, values): self.set(text, value) diff --git a/nemoguardrails/embeddings/providers/azureopenai.py b/nemoguardrails/embeddings/providers/azureopenai.py index 5c5906d5d..e77ab481a 100644 --- a/nemoguardrails/embeddings/providers/azureopenai.py +++ b/nemoguardrails/embeddings/providers/azureopenai.py @@ -46,17 +46,16 @@ class AzureEmbeddingModel(EmbeddingModel): def __init__(self, embedding_model: str): try: - from openai import AzureOpenAI + from openai import AzureOpenAI # type: ignore except ImportError: raise ImportError( - "Could not import openai, please install it with " - "`pip install openai`." + "Could not import openai, please install it with `pip install openai`." ) # Set Azure OpenAI API credentials self.client = AzureOpenAI( api_key=os.getenv("AZURE_OPENAI_API_KEY"), api_version=os.getenv("AZURE_OPENAI_API_VERSION"), - azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), + azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"), # type: ignore ) self.embedding_model = embedding_model diff --git a/nemoguardrails/embeddings/providers/cohere.py b/nemoguardrails/embeddings/providers/cohere.py index 34cee4156..704e0bcd7 100644 --- a/nemoguardrails/embeddings/providers/cohere.py +++ b/nemoguardrails/embeddings/providers/cohere.py @@ -14,7 +14,7 @@ # limitations under the License. import asyncio from contextvars import ContextVar -from typing import List +from typing import TYPE_CHECKING, List from .base import EmbeddingModel @@ -23,6 +23,10 @@ # is changed, it will fail. 
diff --git a/nemoguardrails/embeddings/providers/fastembed.py b/nemoguardrails/embeddings/providers/fastembed.py
index 1062e566f..1359f7ab5 100644
--- a/nemoguardrails/embeddings/providers/fastembed.py
+++ b/nemoguardrails/embeddings/providers/fastembed.py
@@ -42,7 +42,7 @@ class FastEmbedEmbeddingModel(EmbeddingModel):
     engine_name = "FastEmbed"
 
     def __init__(self, embedding_model: str, **kwargs):
-        from fastembed import TextEmbedding as Embedding
+        from fastembed import TextEmbedding as Embedding  # type: ignore
 
         # Enabling a short form model name for all-MiniLM-L6-v2.
         if embedding_model == "all-MiniLM-L6-v2":
diff --git a/nemoguardrails/embeddings/providers/google.py b/nemoguardrails/embeddings/providers/google.py
index cf55399af..1f78974e6 100644
--- a/nemoguardrails/embeddings/providers/google.py
+++ b/nemoguardrails/embeddings/providers/google.py
@@ -46,7 +46,7 @@ class GoogleEmbeddingModel(EmbeddingModel):
 
     def __init__(self, embedding_model: str, **kwargs):
         try:
-            from google import genai
+            from google import genai  # type: ignore[import]
         except ImportError:
             raise ImportError(
diff --git a/nemoguardrails/embeddings/providers/nim.py b/nemoguardrails/embeddings/providers/nim.py
index dd5690a4d..8ea9c1d0f 100644
--- a/nemoguardrails/embeddings/providers/nim.py
+++ b/nemoguardrails/embeddings/providers/nim.py
@@ -35,7 +35,7 @@ class NIMEmbeddingModel(EmbeddingModel):
 
     def __init__(self, embedding_model: str, **kwargs):
         try:
-            from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings
+            from langchain_nvidia_ai_endpoints import NVIDIAEmbeddings  # type: ignore
 
             self.model = embedding_model
             self.document_embedder = NVIDIAEmbeddings(model=embedding_model, **kwargs)
diff --git a/nemoguardrails/embeddings/providers/openai.py b/nemoguardrails/embeddings/providers/openai.py
index 83f83f8c2..bd12f2333 100644
--- a/nemoguardrails/embeddings/providers/openai.py
+++ b/nemoguardrails/embeddings/providers/openai.py
@@ -46,14 +46,14 @@ def __init__(
         **kwargs,
     ):
         try:
-            import openai
-            from openai import AsyncOpenAI, OpenAI
+            import openai  # type: ignore
+            from openai import AsyncOpenAI, OpenAI  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import openai, please install it with "
                 "`pip install openai`."
             )
-        if openai.__version__ < "1.0.0":
+        if openai.__version__ < "1.0.0":  # type: ignore
            raise RuntimeError(
                "`openai<1.0.0` is no longer supported. "
                "Please upgrade using `pip install openai>=1.0.0`."
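The `openai.__version__ < "1.0.0"` comparison above predates this diff and is a lexicographic string comparison, which is fragile for version numbers in general (for example, `"1.10.0" < "1.9.0"` is true as strings). A more robust check (a suggestion, not part of the change; assumes the `packaging` distribution is available) would be:

```python
import openai
from packaging.version import Version

# Semantic comparison instead of lexicographic string comparison.
if Version(openai.__version__) < Version("1.0.0"):
    raise RuntimeError(
        "`openai<1.0.0` is no longer supported. "
        "Please upgrade using `pip install openai>=1.0.0`."
    )
```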
diff --git a/nemoguardrails/embeddings/providers/sentence_transformers.py b/nemoguardrails/embeddings/providers/sentence_transformers.py
index 7ffcec712..cc7ce7be8 100644
--- a/nemoguardrails/embeddings/providers/sentence_transformers.py
+++ b/nemoguardrails/embeddings/providers/sentence_transformers.py
@@ -43,7 +43,7 @@ class SentenceTransformerEmbeddingModel(EmbeddingModel):
 
     def __init__(self, embedding_model: str, **kwargs):
         try:
-            from sentence_transformers import SentenceTransformer
+            from sentence_transformers import SentenceTransformer  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import sentence-transformers, please install it with "
@@ -51,7 +51,7 @@ def __init__(self, embedding_model: str, **kwargs):
             )
 
         try:
-            from torch import cuda
+            from torch import cuda  # type: ignore
         except ImportError:
             raise ImportError(
                 "Could not import torch, please install it with `pip install torch`."
diff --git a/pyproject.toml b/pyproject.toml
index 6be833997..b31308b9a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -157,6 +157,7 @@ pyright = "^1.1.405"
 include = [
     "nemoguardrails/rails/**",
     "nemoguardrails/actions/**",
+    "nemoguardrails/embeddings/**",
     "nemoguardrails/cli/**",
     "nemoguardrails/kb/**",
     "nemoguardrails/logging/**",
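With `nemoguardrails/embeddings/**` added to pyright's `include` list in `pyproject.toml`, the whole embeddings package is now type-checked, which is what motivates the `# type: ignore` comments on the untyped third-party imports throughout this diff. For reference, the guarded `torch` import in `sentence_transformers.py` typically feeds device selection along these lines (assumed from context; the method body is outside the diffed hunks):

```python
from torch import cuda  # type: ignore

# Pick the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if cuda.is_available() else "cpu"
```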