diff --git a/chromadb/api/types.py b/chromadb/api/types.py index ad519bc0990..7f80f3fa9dd 100644 --- a/chromadb/api/types.py +++ b/chromadb/api/types.py @@ -1,7 +1,7 @@ from typing import Optional, Union, TypeVar, List, Dict, Any, Tuple, cast from numpy.typing import NDArray import numpy as np -from typing_extensions import Literal, TypedDict, Protocol, runtime_checkable +from typing_extensions import Literal, TypedDict, Protocol import chromadb.errors as errors from chromadb.types import ( Metadata, @@ -56,7 +56,7 @@ def maybe_cast_one_to_many_ids(target: OneOrMany[ID]) -> IDs: def maybe_cast_one_to_many_embedding( - target: Union[OneOrMany[Embedding], OneOrMany[np.ndarray]] # type: ignore[type-arg] + target: Union[OneOrMany[Embedding], OneOrMany[np.ndarray]] ) -> Embeddings: if isinstance(target, List): # One Embedding @@ -101,7 +101,7 @@ def maybe_cast_one_to_many_document(target: OneOrMany[Document]) -> Documents: # Images -ImageDType = Union[np.uint, np.int_, np.float_] # type: ignore[name-defined] +ImageDType = Union[np.uint, np.int_, np.float_] Image = NDArray[ImageDType] Images = List[Image] @@ -184,7 +184,6 @@ class IndexMetadata(TypedDict): time_created: float -@runtime_checkable class EmbeddingFunction(Protocol[D]): def __call__(self, input: D) -> Embeddings: ... 
@@ -200,10 +199,8 @@ def __call__(self: EmbeddingFunction[D], input: D) -> Embeddings: setattr(cls, "__call__", __call__) - def embed_with_retries( - self, input: D, **retry_kwargs: Dict[str, Any] - ) -> Embeddings: - return cast(Embeddings, retry(**retry_kwargs)(self.__call__)(input)) + def embed_with_retries(self, input: D, **retry_kwargs: Dict) -> Embeddings: + return retry(**retry_kwargs)(self.__call__)(input) def validate_embedding_function( diff --git a/chromadb/test/ef/test_default_ef.py b/chromadb/test/ef/test_default_ef.py index a80ccd2813b..6d8fb623698 100644 --- a/chromadb/test/ef/test_default_ef.py +++ b/chromadb/test/ef/test_default_ef.py @@ -7,10 +7,7 @@ import pytest from hypothesis import given, settings -from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ( - ONNXMiniLM_L6_V2, - _verify_sha256, -) +from chromadb.utils.embedding_functions import ONNXMiniLM_L6_V2, _verify_sha256 def unique_by(x: Hashable) -> Hashable: diff --git a/chromadb/test/ef/test_ef.py b/chromadb/test/ef/test_ef.py deleted file mode 100644 index c93502e3fc8..00000000000 --- a/chromadb/test/ef/test_ef.py +++ /dev/null @@ -1,53 +0,0 @@ -from chromadb.utils import embedding_functions -from chromadb.api.types import EmbeddingFunction - - -def test_get_builtins_holds() -> None: - """ - Ensure that `get_builtins` is consistent after the ef migration. - - This test is intended to be temporary until the ef migration is complete as - these expected builtins are likely to grow as long as users add new - embedding functions. 
- - REMOVE ME ON THE NEXT EF ADDITION - """ - expected_builtins = { - "AmazonBedrockEmbeddingFunction", - "CohereEmbeddingFunction", - "GoogleGenerativeAiEmbeddingFunction", - "GooglePalmEmbeddingFunction", - "GoogleVertexEmbeddingFunction", - "HuggingFaceEmbeddingFunction", - "HuggingFaceEmbeddingServer", - "InstructorEmbeddingFunction", - "JinaEmbeddingFunction", - "ONNXMiniLM_L6_V2", - "OllamaEmbeddingFunction", - "OpenAIEmbeddingFunction", - "OpenCLIPEmbeddingFunction", - "RoboflowEmbeddingFunction", - "SentenceTransformerEmbeddingFunction", - "Text2VecEmbeddingFunction", - "ChromaLangchainEmbeddingFunction", - } - - assert expected_builtins == embedding_functions.get_builtins() - - -def test_default_ef_exists() -> None: - assert hasattr(embedding_functions, "DefaultEmbeddingFunction") - default_ef = embedding_functions.DefaultEmbeddingFunction() - - assert default_ef is not None - assert isinstance(default_ef, EmbeddingFunction) - - -def test_ef_imports() -> None: - for ef in embedding_functions.get_builtins(): - # Langchain embedding function is a special snowflake - if ef == "ChromaLangchainEmbeddingFunction": - continue - assert hasattr(embedding_functions, ef) - assert isinstance(getattr(embedding_functions, ef), type) - assert issubclass(getattr(embedding_functions, ef), EmbeddingFunction) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py new file mode 100644 index 00000000000..3b0aeff13fb --- /dev/null +++ b/chromadb/utils/embedding_functions.py @@ -0,0 +1,1029 @@ +import hashlib +import logging +from functools import cached_property + +from tenacity import stop_after_attempt, wait_random, retry, retry_if_exception + +from chromadb.api.types import ( + Document, + Documents, + Embedding, + Image, + Images, + EmbeddingFunction, + Embeddings, + is_image, + is_document, +) + +from io import BytesIO +from pathlib import Path +import os +import tarfile +import httpx +from typing import TYPE_CHECKING, Any, Dict, List, 
Mapping, Optional, Union, cast +import numpy as np +import numpy.typing as npt +import importlib +import inspect +import json +import sys +import base64 + +try: + from chromadb.is_thin_client import is_thin_client +except ImportError: + is_thin_client = False + +if TYPE_CHECKING: + from onnxruntime import InferenceSession + from tokenizers import Tokenizer + +logger = logging.getLogger(__name__) + + +def _verify_sha256(fname: str, expected_sha256: str) -> bool: + sha256_hash = hashlib.sha256() + with open(fname, "rb") as f: + # Read and update hash in chunks to avoid using too much memory + for byte_block in iter(lambda: f.read(4096), b""): + sha256_hash.update(byte_block) + + return sha256_hash.hexdigest() == expected_sha256 + + +class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): + # Since we do dynamic imports we have to type this as Any + models: Dict[str, Any] = {} + + # If you have a beefier machine, try "gtr-t5-large". + # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html + def __init__( + self, + model_name: str = "all-MiniLM-L6-v2", + device: str = "cpu", + normalize_embeddings: bool = False, + **kwargs: Any, + ): + """Initialize SentenceTransformerEmbeddingFunction. + + Args: + model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2" + device (str, optional): Device used for computation, defaults to "cpu" + normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False + **kwargs: Additional arguments to pass to the SentenceTransformer model. + """ + if model_name not in self.models: + try: + from sentence_transformers import SentenceTransformer + except ImportError: + raise ValueError( + "The sentence_transformers python package is not installed. 
Please install it with `pip install sentence_transformers`" + ) + self.models[model_name] = SentenceTransformer( + model_name, device=device, **kwargs + ) + self._model = self.models[model_name] + self._normalize_embeddings = normalize_embeddings + + def __call__(self, input: Documents) -> Embeddings: + return cast( + Embeddings, + self._model.encode( + list(input), + convert_to_numpy=True, + normalize_embeddings=self._normalize_embeddings, + ).tolist(), + ) + + +class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): + try: + from text2vec import SentenceModel + except ImportError: + raise ValueError( + "The text2vec python package is not installed. Please install it with `pip install text2vec`" + ) + self._model = SentenceModel(model_name_or_path=model_name) + + def __call__(self, input: Documents) -> Embeddings: + return cast( + Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist() + ) # noqa E501 + + +class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__( + self, + api_key: Optional[str] = None, + model_name: str = "text-embedding-ada-002", + organization_id: Optional[str] = None, + api_base: Optional[str] = None, + api_type: Optional[str] = None, + api_version: Optional[str] = None, + deployment_id: Optional[str] = None, + default_headers: Optional[Mapping[str, str]] = None, + ): + """ + Initialize the OpenAIEmbeddingFunction. + Args: + api_key (str, optional): Your API key for the OpenAI API. If not + provided, it will raise an error to provide an OpenAI API key. + organization_id(str, optional): The OpenAI organization ID if applicable + model_name (str, optional): The name of the model to use for text + embeddings. Defaults to "text-embedding-ada-002". + api_base (str, optional): The base path for the API. If not provided, + it will use the base path for the OpenAI API. 
This can be used to + point to a different deployment, such as an Azure deployment. + api_type (str, optional): The type of the API deployment. This can be + used to specify a different deployment, such as 'azure'. If not + provided, it will use the default OpenAI deployment. + api_version (str, optional): The api version for the API. If not provided, + it will use the api version for the OpenAI API. This can be used to + point to a different deployment, such as an Azure deployment. + deployment_id (str, optional): Deployment ID for Azure OpenAI. + default_headers (Mapping, optional): A mapping of default headers to be sent with each API request. + + """ + try: + import openai + except ImportError: + raise ValueError( + "The openai python package is not installed. Please install it with `pip install openai`" + ) + + if api_key is not None: + openai.api_key = api_key + # If the api key is still not set, raise an error + elif openai.api_key is None: + raise ValueError( + "Please provide an OpenAI API key. You can get one at https://platform.openai.com/account/api-keys" + ) + + if api_base is not None: + openai.api_base = api_base + + if api_version is not None: + openai.api_version = api_version + + self._api_type = api_type + if api_type is not None: + openai.api_type = api_type + + if organization_id is not None: + openai.organization = organization_id + + self._v1 = openai.__version__.startswith("1.") + if self._v1: + if api_type == "azure": + self._client = openai.AzureOpenAI( + api_key=api_key, + api_version=api_version, + azure_endpoint=api_base, + default_headers=default_headers, + ).embeddings + else: + self._client = openai.OpenAI( + api_key=api_key, base_url=api_base, default_headers=default_headers + ).embeddings + else: + self._client = openai.Embedding + self._model_name = model_name + self._deployment_id = deployment_id + + def __call__(self, input: Documents) -> Embeddings: + # replace newlines, which can negatively affect performance. 
+ input = [t.replace("\n", " ") for t in input] + + # Call the OpenAI Embedding API + if self._v1: + embeddings = self._client.create( + input=input, model=self._deployment_id or self._model_name + ).data + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e.index) + + # Return just the embeddings + return cast(Embeddings, [result.embedding for result in sorted_embeddings]) + else: + if self._api_type == "azure": + embeddings = self._client.create( + input=input, engine=self._deployment_id or self._model_name + )["data"] + else: + embeddings = self._client.create(input=input, model=self._model_name)[ + "data" + ] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) + + # Return just the embeddings + return cast( + Embeddings, [result["embedding"] for result in sorted_embeddings] + ) + + +class CohereEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__(self, api_key: str, model_name: str = "large"): + try: + import cohere + except ImportError: + raise ValueError( + "The cohere python package is not installed. Please install it with `pip install cohere`" + ) + + self._client = cohere.Client(api_key) + self._model_name = model_name + + def __call__(self, input: Documents) -> Embeddings: + # Call Cohere Embedding API for each document. + return [ + embeddings + for embeddings in self._client.embed( + texts=input, model=self._model_name, input_type="search_document" + ) + ] + + +class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the HuggingFace API. + It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2". + """ + + def __init__( + self, api_key: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2" + ): + """ + Initialize the HuggingFaceEmbeddingFunction. 
+ + Args: + api_key (str): Your API key for the HuggingFace API. + model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2". + """ + self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}" + self._session = httpx.Client() + self._session.headers.update({"Authorization": f"Bearer {api_key}"}) + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. + + Example: + >>> hugging_face = HuggingFaceEmbeddingFunction(api_key="your_api_key") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = hugging_face(texts) + """ + # Call HuggingFace Embedding API for each document + return cast( + Embeddings, + self._session.post( + self._api_url, + json={"inputs": input, "options": {"wait_for_model": True}}, + ).json(), + ) + + +class JinaEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the Jina AI API. + It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en". + """ + + def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"): + """ + Initialize the JinaEmbeddingFunction. + + Args: + api_key (str): Your API key for the Jina AI API. + model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en". + """ + self._model_name = model_name + self._api_url = "https://api.jina.ai/v1/embeddings" + self._session = httpx.Client() + self._session.headers.update( + {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"} + ) + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. 
+ + Returns: + Embeddings: The embeddings for the texts. + + Example: + >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key") + >>> input = ["Hello, world!", "How are you?"] + >>> embeddings = jina_ai_fn(input) + """ + # Call Jina AI Embedding API + resp = self._session.post( + self._api_url, json={"input": input, "model": self._model_name} + ).json() + if "data" not in resp: + raise RuntimeError(resp["detail"]) + + embeddings = resp["data"] + + # Sort resulting embeddings by index + sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) + + # Return just the embeddings + return cast(Embeddings, [result["embedding"] for result in sorted_embeddings]) + + +class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): + # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda" + # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list + def __init__( + self, + model_name: str = "hkunlp/instructor-base", + device: str = "cpu", + instruction: Optional[str] = None, + ): + try: + from InstructorEmbedding import INSTRUCTOR + except ImportError: + raise ValueError( + "The InstructorEmbedding python package is not installed. Please install it with `pip install InstructorEmbedding`" + ) + self._model = INSTRUCTOR(model_name, device=device) + self._instruction = instruction + + def __call__(self, input: Documents) -> Embeddings: + if self._instruction is None: + return cast(Embeddings, self._model.encode(input).tolist()) + + texts_with_instructions = [[self._instruction, text] for text in input] + + return cast(Embeddings, self._model.encode(texts_with_instructions).tolist()) + + +# In order to remove dependencies on sentence-transformers, which in turn depends on +# pytorch and sentence-piece we have created a default ONNX embedding function that +# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. 
+# visit https://github.com/chroma-core/onnx-embedding for the source code to generate +# and verify the ONNX model. +class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): + MODEL_NAME = "all-MiniLM-L6-v2" + DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME + EXTRACTED_FOLDER_NAME = "onnx" + ARCHIVE_FILENAME = "onnx.tar.gz" + MODEL_DOWNLOAD_URL = ( + "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" + ) + _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3" + + # https://github.com/python/mypy/issues/7291 mypy makes you type the constructor if + # no args + def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: + # Import dependencies on demand to mirror other embedding functions. This + # breaks typechecking, thus the ignores. + # convert the list to set for unique values + if preferred_providers and not all( + [isinstance(i, str) for i in preferred_providers] + ): + raise ValueError("Preferred providers must be a list of strings") + # check for duplicate providers + if preferred_providers and len(preferred_providers) != len( + set(preferred_providers) + ): + raise ValueError("Preferred providers must be unique") + self._preferred_providers = preferred_providers + try: + # Equivalent to import onnxruntime + self.ort = importlib.import_module("onnxruntime") + except ImportError: + raise ValueError( + "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`" + ) + try: + # Equivalent to from tokenizers import Tokenizer + self.Tokenizer = importlib.import_module("tokenizers").Tokenizer + except ImportError: + raise ValueError( + "The tokenizers python package is not installed. Please install it with `pip install tokenizers`" + ) + try: + # Equivalent to from tqdm import tqdm + self.tqdm = importlib.import_module("tqdm").tqdm + except ImportError: + raise ValueError( + "The tqdm python package is not installed. 
Please install it with `pip install tqdm`" + ) + + # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 + # Download with tqdm to preserve the sentence-transformers experience + @retry( + reraise=True, + stop=stop_after_attempt(3), + wait=wait_random(min=1, max=3), + retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)), + ) + def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: + with httpx.stream("GET", url) as resp: + total = int(resp.headers.get("content-length", 0)) + with open(fname, "wb") as file, self.tqdm( + desc=str(fname), + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in resp.iter_bytes(chunk_size=chunk_size): + size = file.write(data) + bar.update(size) + if not _verify_sha256(fname, self._MODEL_SHA256): + # if the integrity of the file is not verified, remove it + os.remove(fname) + raise ValueError( + f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file." 
+ ) + + # Use pytorches default epsilon for division by zero + # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html + def _normalize(self, v: npt.NDArray) -> npt.NDArray: + norm = np.linalg.norm(v, axis=1) + norm[norm == 0] = 1e-12 + return cast(npt.NDArray, v / norm[:, np.newaxis]) + + def _forward(self, documents: List[str], batch_size: int = 32) -> npt.NDArray: + # We need to cast to the correct type because the type checker doesn't know that init_model_and_tokenizer will set the values + self.tokenizer = cast(self.Tokenizer, self.tokenizer) + self.model = cast(self.ort.InferenceSession, self.model) + all_embeddings = [] + for i in range(0, len(documents), batch_size): + batch = documents[i : i + batch_size] + encoded = [self.tokenizer.encode(d) for d in batch] + input_ids = np.array([e.ids for e in encoded]) + attention_mask = np.array([e.attention_mask for e in encoded]) + onnx_input = { + "input_ids": np.array(input_ids, dtype=np.int64), + "attention_mask": np.array(attention_mask, dtype=np.int64), + "token_type_ids": np.array( + [np.zeros(len(e), dtype=np.int64) for e in input_ids], + dtype=np.int64, + ), + } + model_output = self.model.run(None, onnx_input) + last_hidden_state = model_output[0] + # Perform mean pooling with attention weighting + input_mask_expanded = np.broadcast_to( + np.expand_dims(attention_mask, -1), last_hidden_state.shape + ) + embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip( + input_mask_expanded.sum(1), a_min=1e-9, a_max=None + ) + embeddings = self._normalize(embeddings).astype(np.float32) + all_embeddings.append(embeddings) + return np.concatenate(all_embeddings) + + @cached_property + def tokenizer(self) -> "Tokenizer": + tokenizer = self.Tokenizer.from_file( + os.path.join( + self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json" + ) + ) + # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128 + # 
https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480 + tokenizer.enable_truncation(max_length=256) + tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256) + return tokenizer + + @cached_property + def model(self) -> "InferenceSession": + if self._preferred_providers is None or len(self._preferred_providers) == 0: + if len(self.ort.get_available_providers()) > 0: + logger.debug( + f"WARNING: No ONNX providers provided, defaulting to available providers: " + f"{self.ort.get_available_providers()}" + ) + self._preferred_providers = self.ort.get_available_providers() + elif not set(self._preferred_providers).issubset( + set(self.ort.get_available_providers()) + ): + raise ValueError( + f"Preferred providers must be subset of available providers: {self.ort.get_available_providers()}" + ) + + # Suppress onnxruntime warnings. This produces logspew, mainly when onnx tries to use CoreML, which doesn't fit this model. 
+ so = self.ort.SessionOptions() + so.log_severity_level = 3 + + return self.ort.InferenceSession( + os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"), + # Since 1.9 onnyx runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html + # This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs + providers=self._preferred_providers, + sess_options=so, + ) + + def __call__(self, input: Documents) -> Embeddings: + # Only download the model when it is actually used + self._download_model_if_not_exists() + return cast(Embeddings, self._forward(input).tolist()) + + def _download_model_if_not_exists(self) -> None: + onnx_files = [ + "config.json", + "model.onnx", + "special_tokens_map.json", + "tokenizer_config.json", + "tokenizer.json", + "vocab.txt", + ] + extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME) + onnx_files_exist = True + for f in onnx_files: + if not os.path.exists(os.path.join(extracted_folder, f)): + onnx_files_exist = False + break + # Model is not downloaded yet + if not onnx_files_exist: + os.makedirs(self.DOWNLOAD_PATH, exist_ok=True) + if not os.path.exists( + os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME) + ) or not _verify_sha256( + os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + self._MODEL_SHA256, + ): + self._download( + url=self.MODEL_DOWNLOAD_URL, + fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + ) + with tarfile.open( + name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), + mode="r:gz", + ) as tar: + tar.extractall(path=self.DOWNLOAD_PATH) + + +def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: + if is_thin_client: + return None + else: + return ONNXMiniLM_L6_V2() + + +class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): + """To use this EmbeddingFunction, you must have the google.generativeai 
Python package installed and have a PaLM API key.""" + + def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"): + if not api_key: + raise ValueError("Please provide a PaLM API key.") + + if not model_name: + raise ValueError("Please provide the model name.") + + try: + import google.generativeai as palm + except ImportError: + raise ValueError( + "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`" + ) + + palm.configure(api_key=api_key) + self._palm = palm + self._model_name = model_name + + def __call__(self, input: Documents) -> Embeddings: + return [ + self._palm.generate_embeddings(model=self._model_name, text=text)[ + "embedding" + ] + for text in input + ] + + +class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]): + """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key.""" + + """Use RETRIEVAL_DOCUMENT for the task_type for embedding, and RETRIEVAL_QUERY for the task_type for retrieval.""" + + def __init__( + self, + api_key: str, + model_name: str = "models/embedding-001", + task_type: str = "RETRIEVAL_DOCUMENT", + ): + if not api_key: + raise ValueError("Please provide a Google API key.") + + if not model_name: + raise ValueError("Please provide the model name.") + + try: + import google.generativeai as genai + except ImportError: + raise ValueError( + "The Google Generative AI python package is not installed. 
Please install it with `pip install google-generativeai`" + ) + + genai.configure(api_key=api_key) + self._genai = genai + self._model_name = model_name + self._task_type = task_type + self._task_title = None + if self._task_type == "RETRIEVAL_DOCUMENT": + self._task_title = "Embedding of single string" + + def __call__(self, input: Documents) -> Embeddings: + return [ + self._genai.embed_content( + model=self._model_name, + content=text, + task_type=self._task_type, + title=self._task_title, + )["embedding"] + for text in input + ] + + +class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]): + # Follow API Quickstart for Google Vertex AI + # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart + # Information about the text embedding modules in Google Vertex AI + # https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings + def __init__( + self, + api_key: str, + model_name: str = "textembedding-gecko", + project_id: str = "cloud-large-language-models", + region: str = "us-central1", + ): + self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model_name}:predict" + self._session = httpx.Client() + self._session.headers.update({"Authorization": f"Bearer {api_key}"}) + + def __call__(self, input: Documents) -> Embeddings: + embeddings = [] + for text in input: + response = self._session.post( + self._api_url, json={"instances": [{"content": text}]} + ).json() + + if "predictions" in response: + embeddings.append(response["predictions"]["embeddings"]["values"]) + + return embeddings + + +class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): + def __init__( + self, + model_name: str = "ViT-B-32", + checkpoint: str = "laion2b_s34b_b79k", + device: Optional[str] = "cpu", + ) -> None: + try: + import open_clip + except ImportError: + raise ValueError( + "The open_clip python package is not installed. 
Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip" + ) + try: + self._torch = importlib.import_module("torch") + except ImportError: + raise ValueError( + "The torch python package is not installed. Please install it with `pip install torch`" + ) + + try: + self._PILImage = importlib.import_module("PIL.Image") + except ImportError: + raise ValueError( + "The PIL python package is not installed. Please install it with `pip install pillow`" + ) + + model, _, preprocess = open_clip.create_model_and_transforms( + model_name=model_name, pretrained=checkpoint + ) + self._model = model + self._model.to(device) + self._preprocess = preprocess + self._tokenizer = open_clip.get_tokenizer(model_name=model_name) + + def _encode_image(self, image: Image) -> Embedding: + pil_image = self._PILImage.fromarray(image) + with self._torch.no_grad(): + image_features = self._model.encode_image( + self._preprocess(pil_image).unsqueeze(0) + ) + image_features /= image_features.norm(dim=-1, keepdim=True) + return cast(Embedding, image_features.squeeze().tolist()) + + def _encode_text(self, text: Document) -> Embedding: + with self._torch.no_grad(): + text_features = self._model.encode_text(self._tokenizer(text)) + text_features /= text_features.norm(dim=-1, keepdim=True) + return cast(Embedding, text_features.squeeze().tolist()) + + def __call__(self, input: Union[Documents, Images]) -> Embeddings: + embeddings: Embeddings = [] + for item in input: + if is_image(item): + embeddings.append(self._encode_image(cast(Image, item))) + elif is_document(item): + embeddings.append(self._encode_text(cast(Document, item))) + return embeddings + + +class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): + def __init__(self, api_key: str = "", api_url="https://infer.roboflow.com") -> None: + """ + Create a RoboflowEmbeddingFunction. + + Args: + api_key (str): Your API key for the Roboflow API. 
+ + api_url (str, optional): The URL of the Roboflow API. Defaults to "https://infer.roboflow.com". + """ + if not api_key: + api_key = os.environ.get("ROBOFLOW_API_KEY") + + self._api_url = api_url + self._api_key = api_key + + try: + self._PILImage = importlib.import_module("PIL.Image") + except ImportError: + raise ValueError( + "The PIL python package is not installed. Please install it with `pip install pillow`" + ) + + def __call__(self, input: Union[Documents, Images]) -> Embeddings: + embeddings = [] + + for item in input: + if is_image(item): + image = self._PILImage.fromarray(item) + + buffer = BytesIO() + image.save(buffer, format="JPEG") + base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8") + + infer_clip_payload = { + "image": { + "type": "base64", + "value": base64_image, + }, + } + + res = httpx.post( + f"{self._api_url}/clip/embed_image?api_key={self._api_key}", + json=infer_clip_payload, + ) + + result = res.json()["embeddings"] + + embeddings.append(result[0]) + + elif is_document(item): + infer_clip_payload = { + "text": item, + } + + res = httpx.post( + f"{self._api_url}/clip/embed_text?api_key={self._api_key}", + json=infer_clip_payload, + ) + + result = res.json()["embeddings"] + + embeddings.append(result[0]) + + return embeddings + + +class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): + def __init__( + self, + session: "boto3.Session", # noqa: F821 # Quote for forward reference + model_name: str = "amazon.titan-embed-text-v1", + **kwargs: Any, + ): + """Initialize AmazonBedrockEmbeddingFunction. + + Args: + session (boto3.Session): The boto3 session to use. + model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" + **kwargs: Additional arguments to pass to the boto3 client. 
+ + Example: + >>> import boto3 + >>> session = boto3.Session(profile_name="profile", region_name="us-east-1") + >>> bedrock = AmazonBedrockEmbeddingFunction(session=session) + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = bedrock(texts) + """ + + self._model_name = model_name + + self._client = session.client( + service_name="bedrock-runtime", + **kwargs, + ) + + def __call__(self, input: Documents) -> Embeddings: + accept = "application/json" + content_type = "application/json" + embeddings = [] + for text in input: + input_body = {"inputText": text} + body = json.dumps(input_body) + response = self._client.invoke_model( + body=body, + modelId=self._model_name, + accept=accept, + contentType=content_type, + ) + embedding = json.load(response.get("body")).get("embedding") + embeddings.append(embedding) + return embeddings + + +class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): + """ + This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference). + The embedding model is configured in the server. + """ + + def __init__(self, url: str): + """ + Initialize the HuggingFaceEmbeddingServer. + + Args: + url (str): The URL of the HuggingFace Embedding Server. + """ + try: + import httpx + except ImportError: + raise ValueError( + "The httpx python package is not installed. Please install it with `pip install httpx`" + ) + self._api_url = f"{url}" + self._session = httpx.Client() + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + texts (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. 
+ + Example: + >>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = hugging_face(texts) + """ + # Call HuggingFace Embedding Server API for each document + return cast( + Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() + ) + + +def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore + try: + from langchain_core.embeddings import Embeddings as LangchainEmbeddings + except ImportError: + raise ValueError( + "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" + ) + + class ChromaLangchainEmbeddingFunction( + LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore + ): + """ + This class is used as bridge between langchain embedding functions and custom chroma embedding functions. + """ + + def __init__(self, embedding_function: LangchainEmbeddings) -> None: + """ + Initialize the ChromaLangchainEmbeddingFunction + + Args: + embedding_function : The embedding function implementing Embeddings from langchain_core. + """ + self.embedding_function = embedding_function + + def embed_documents(self, documents: Documents) -> List[List[float]]: + return self.embedding_function.embed_documents(documents) # type: ignore + + def embed_query(self, query: str) -> List[float]: + return self.embedding_function.embed_query(query) # type: ignore + + def embed_image(self, uris: List[str]) -> List[List[float]]: + if hasattr(self.embedding_function, "embed_image"): + return self.embedding_function.embed_image(uris) # type: ignore + else: + raise ValueError( + "The provided embedding function does not support image embeddings." + ) + + def __call__(self, input: Documents) -> Embeddings: # type: ignore + """ + Get the embeddings for a list of texts or images. + + Args: + input (Documents | Images): A list of texts or images to get embeddings for. 
+ Images should be provided as a list of URIs passed through the langchain data loader + + Returns: + Embeddings: The embeddings for the texts or images. + + Example: + >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large")) + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = langchain_embedding(texts) + """ + # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images + if input[0] == "images": + return self.embed_image(list(input[1])) # type: ignore + + return self.embed_documents(list(input)) # type: ignore + + return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) + + +class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): + """ + This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). + """ + + def __init__(self, url: str, model_name: str) -> None: + """ + Initialize the Ollama Embedding Function. + + Args: + url (str): The URL of the Ollama Server. + model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). + """ + try: + import httpx + except ImportError: + raise ValueError( + "The httpx python package is not installed. Please install it with `pip install httpx`" + ) + self._api_url = f"{url}" + self._model_name = model_name + self._session = httpx.Client() + + def __call__(self, input: Documents) -> Embeddings: + """ + Get the embeddings for a list of texts. + + Args: + input (Documents): A list of texts to get embeddings for. + + Returns: + Embeddings: The embeddings for the texts. 
+ + Example: + >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text") + >>> texts = ["Hello, world!", "How are you?"] + >>> embeddings = ollama_ef(texts) + """ + # Call Ollama Server API for each document + texts = input if isinstance(input, list) else [input] + embeddings = [ + self._session.post( + self._api_url, json={"model": self._model_name, "prompt": text} + ).json() + for text in texts + ] + return cast( + Embeddings, + [ + embedding["embedding"] + for embedding in embeddings + if "embedding" in embedding + ], + ) + + +# List of all classes in this module +_classes = [ + name + for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass) + if obj.__module__ == __name__ +] + + +def get_builtins() -> List[str]: + return _classes diff --git a/chromadb/utils/embedding_functions/__init__.py b/chromadb/utils/embedding_functions/__init__.py deleted file mode 100644 index 2f0bf0f5cf2..00000000000 --- a/chromadb/utils/embedding_functions/__init__.py +++ /dev/null @@ -1,61 +0,0 @@ -import os -import importlib -import pkgutil -from types import ModuleType -from typing import Optional, Set, cast - -from chromadb.api.types import Documents, EmbeddingFunction - -# Langchain embedding function is a special snowflake -from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import ( # noqa: F401 - create_langchain_embedding, -) - -_all_classes: Set[str] = set() -_all_classes.add("ChromaLangchainEmbeddingFunction") - -try: - from chromadb.is_thin_client import is_thin_client -except ImportError: - is_thin_client = False - - -def _import_all_efs() -> Set[str]: - imported_classes = set() - _module_dir = os.path.dirname(__file__) - for _, module_name, _ in pkgutil.iter_modules([_module_dir]): - # Skip the current module - if module_name == __name__: - continue - - module: ModuleType = importlib.import_module(f"{__name__}.{module_name}") - - for attr_name in dir(module): - attr = 
getattr(module, attr_name) - if ( - isinstance(attr, type) - and issubclass(attr, EmbeddingFunction) - and attr is not EmbeddingFunction # Don't re-export the type - ): - globals()[attr.__name__] = attr - imported_classes.add(attr.__name__) - return imported_classes - - -_all_classes.update(_import_all_efs()) - - -# Define and export the default embedding function -def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]: - if is_thin_client: - return None - else: - return cast( - EmbeddingFunction[Documents], - # This is implicitly imported above - ONNXMiniLM_L6_V2(), # type: ignore[name-defined] # noqa: F821 - ) - - -def get_builtins() -> Set[str]: - return _all_classes diff --git a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py b/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py deleted file mode 100644 index 67103ab7ffd..00000000000 --- a/chromadb/utils/embedding_functions/amazon_bedrock_embedding_function.py +++ /dev/null @@ -1,55 +0,0 @@ -import json -import logging -from typing import Any - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__( - self, - session: Any, - model_name: str = "amazon.titan-embed-text-v1", - **kwargs: Any, - ): - """Initialize AmazonBedrockEmbeddingFunction. - - Args: - session (boto3.Session): The boto3 session to use. You need to have boto3 - installed, `pip install boto3`. - model_name (str, optional): Identifier of the model, defaults to "amazon.titan-embed-text-v1" - **kwargs: Additional arguments to pass to the boto3 client. 
- - Example: - >>> import boto3 - >>> session = boto3.Session(profile_name="profile", region_name="us-east-1") - >>> bedrock = AmazonBedrockEmbeddingFunction(session=session) - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = bedrock(texts) - """ - - self._model_name = model_name - - self._client = session.client( - service_name="bedrock-runtime", - **kwargs, - ) - - def __call__(self, input: Documents) -> Embeddings: - accept = "application/json" - content_type = "application/json" - embeddings = [] - for text in input: - input_body = {"inputText": text} - body = json.dumps(input_body) - response = self._client.invoke_model( - body=body, - modelId=self._model_name, - accept=accept, - contentType=content_type, - ) - embedding = json.load(response.get("body")).get("embedding") - embeddings.append(embedding) - return embeddings diff --git a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py b/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py deleted file mode 100644 index 445cca5b128..00000000000 --- a/chromadb/utils/embedding_functions/chroma_langchain_embedding_function.py +++ /dev/null @@ -1,69 +0,0 @@ -import logging -from typing import Any, List, Union - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Images - -logger = logging.getLogger(__name__) - - -def create_langchain_embedding(langchain_embdding_fn: Any): # type: ignore - try: - from langchain_core.embeddings import Embeddings as LangchainEmbeddings - except ImportError: - raise ValueError( - "The langchain_core python package is not installed. Please install it with `pip install langchain-core`" - ) - - class ChromaLangchainEmbeddingFunction( - LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]] # type: ignore - ): - """ - This class is used as bridge between langchain embedding functions and custom chroma embedding functions. 
- """ - - def __init__(self, embedding_function: LangchainEmbeddings) -> None: - """ - Initialize the ChromaLangchainEmbeddingFunction - - Args: - embedding_function : The embedding function implementing Embeddings from langchain_core. - """ - self.embedding_function = embedding_function - - def embed_documents(self, documents: Documents) -> List[List[float]]: - return self.embedding_function.embed_documents(documents) # type: ignore - - def embed_query(self, query: str) -> List[float]: - return self.embedding_function.embed_query(query) # type: ignore - - def embed_image(self, uris: List[str]) -> List[List[float]]: - if hasattr(self.embedding_function, "embed_image"): - return self.embedding_function.embed_image(uris) # type: ignore - else: - raise ValueError( - "The provided embedding function does not support image embeddings." - ) - - def __call__(self, input: Documents) -> Embeddings: # type: ignore - """ - Get the embeddings for a list of texts or images. - - Args: - input (Documents | Images): A list of texts or images to get embeddings for. - Images should be provided as a list of URIs passed through the langchain data loader - - Returns: - Embeddings: The embeddings for the texts or images. 
- - Example: - >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large")) - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = langchain_embedding(texts) - """ - # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images - if input[0] == "images": - return self.embed_image(list(input[1])) # type: ignore - - return self.embed_documents(list(input)) # type: ignore - - return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn) diff --git a/chromadb/utils/embedding_functions/cohere_embedding_function.py b/chromadb/utils/embedding_functions/cohere_embedding_function.py deleted file mode 100644 index ef9c33e24b9..00000000000 --- a/chromadb/utils/embedding_functions/cohere_embedding_function.py +++ /dev/null @@ -1,27 +0,0 @@ -import logging - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class CohereEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__(self, api_key: str, model_name: str = "large"): - try: - import cohere - except ImportError: - raise ValueError( - "The cohere python package is not installed. Please install it with `pip install cohere`" - ) - - self._client = cohere.Client(api_key) - self._model_name = model_name - - def __call__(self, input: Documents) -> Embeddings: - # Call Cohere Embedding API for each document. 
- return [ - embeddings - for embeddings in self._client.embed( - texts=input, model=self._model_name, input_type="search_document" - ) - ] diff --git a/chromadb/utils/embedding_functions/google_embedding_function.py b/chromadb/utils/embedding_functions/google_embedding_function.py deleted file mode 100644 index 0534d790674..00000000000 --- a/chromadb/utils/embedding_functions/google_embedding_function.py +++ /dev/null @@ -1,110 +0,0 @@ -import logging - -import httpx - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]): - """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key.""" - - def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"): - if not api_key: - raise ValueError("Please provide a PaLM API key.") - - if not model_name: - raise ValueError("Please provide the model name.") - - try: - import google.generativeai as palm - except ImportError: - raise ValueError( - "The Google Generative AI python package is not installed. 
Please install it with `pip install google-generativeai`" - ) - - palm.configure(api_key=api_key) - self._palm = palm - self._model_name = model_name - - def __call__(self, input: Documents) -> Embeddings: - return [ - self._palm.generate_embeddings(model=self._model_name, text=text)[ - "embedding" - ] - for text in input - ] - - -class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]): - """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key.""" - - """Use RETRIEVAL_DOCUMENT for the task_type for embedding, and RETRIEVAL_QUERY for the task_type for retrieval.""" - - def __init__( - self, - api_key: str, - model_name: str = "models/embedding-001", - task_type: str = "RETRIEVAL_DOCUMENT", - ): - if not api_key: - raise ValueError("Please provide a Google API key.") - - if not model_name: - raise ValueError("Please provide the model name.") - - try: - import google.generativeai as genai - except ImportError: - raise ValueError( - "The Google Generative AI python package is not installed. 
Please install it with `pip install google-generativeai`" - ) - - genai.configure(api_key=api_key) - self._genai = genai - self._model_name = model_name - self._task_type = task_type - self._task_title = None - if self._task_type == "RETRIEVAL_DOCUMENT": - self._task_title = "Embedding of single string" - - def __call__(self, input: Documents) -> Embeddings: - return [ - self._genai.embed_content( - model=self._model_name, - content=text, - task_type=self._task_type, - title=self._task_title, - )["embedding"] - for text in input - ] - - -class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]): - # Follow API Quickstart for Google Vertex AI - # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart - # Information about the text embedding modules in Google Vertex AI - # https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings - def __init__( - self, - api_key: str, - model_name: str = "textembedding-gecko", - project_id: str = "cloud-large-language-models", - region: str = "us-central1", - ): - self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/goole/models/{model_name}:predict" - self._session = httpx.Client() - self._session.headers.update({"Authorization": f"Bearer {api_key}"}) - - def __call__(self, input: Documents) -> Embeddings: - embeddings = [] - for text in input: - response = self._session.post( - self._api_url, json={"instances": [{"content": text}]} - ).json() - - if "predictions" in response: - embeddings.append(response["predictions"]["embeddings"]["values"]) - - return embeddings diff --git a/chromadb/utils/embedding_functions/huggingface_embedding_function.py b/chromadb/utils/embedding_functions/huggingface_embedding_function.py deleted file mode 100644 index 376a98fa4ae..00000000000 --- a/chromadb/utils/embedding_functions/huggingface_embedding_function.py +++ /dev/null @@ -1,90 +0,0 @@ -import logging -from typing 
import cast - -import httpx - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class HuggingFaceEmbeddingFunction(EmbeddingFunction[Documents]): - """ - This class is used to get embeddings for a list of texts using the HuggingFace API. - It requires an API key and a model name. The default model name is "sentence-transformers/all-MiniLM-L6-v2". - """ - - def __init__( - self, api_key: str, model_name: str = "sentence-transformers/all-MiniLM-L6-v2" - ): - """ - Initialize the HuggingFaceEmbeddingFunction. - - Args: - api_key (str): Your API key for the HuggingFace API. - model_name (str, optional): The name of the model to use for text embeddings. Defaults to "sentence-transformers/all-MiniLM-L6-v2". - """ - self._api_url = f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model_name}" - self._session = httpx.Client() - self._session.headers.update({"Authorization": f"Bearer {api_key}"}) - - def __call__(self, input: Documents) -> Embeddings: - """ - Get the embeddings for a list of texts. - - Args: - texts (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the texts. - - Example: - >>> hugging_face = HuggingFaceEmbeddingFunction(api_key="your_api_key") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = hugging_face(texts) - """ - # Call HuggingFace Embedding API for each document - return cast( - Embeddings, - self._session.post( - self._api_url, - json={"inputs": input, "options": {"wait_for_model": True}}, - ).json(), - ) - - -class HuggingFaceEmbeddingServer(EmbeddingFunction[Documents]): - """ - This class is used to get embeddings for a list of texts using the HuggingFace Embedding server (https://github.com/huggingface/text-embeddings-inference). - The embedding model is configured in the server. - """ - - def __init__(self, url: str): - """ - Initialize the HuggingFaceEmbeddingServer. 
- - Args: - url (str): The URL of the HuggingFace Embedding Server. - """ - self._api_url = f"{url}" - self._session = httpx.Client() - - def __call__(self, input: Documents) -> Embeddings: - """ - Get the embeddings for a list of texts. - - Args: - texts (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the texts. - - Example: - >>> hugging_face = HuggingFaceEmbeddingServer(url="http://localhost:8080/embed") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = hugging_face(texts) - """ - # Call HuggingFace Embedding Server API for each document - return cast( - Embeddings, self._session.post(self._api_url, json={"inputs": input}).json() - ) diff --git a/chromadb/utils/embedding_functions/instructor_embedding_function.py b/chromadb/utils/embedding_functions/instructor_embedding_function.py deleted file mode 100644 index a9ea6b26038..00000000000 --- a/chromadb/utils/embedding_functions/instructor_embedding_function.py +++ /dev/null @@ -1,33 +0,0 @@ -import logging -from typing import Optional, cast - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class InstructorEmbeddingFunction(EmbeddingFunction[Documents]): - # If you have a GPU with at least 6GB try model_name = "hkunlp/instructor-xl" and device = "cuda" - # for a full list of options: https://github.com/HKUNLP/instructor-embedding#model-list - def __init__( - self, - model_name: str = "hkunlp/instructor-base", - device: str = "cpu", - instruction: Optional[str] = None, - ): - try: - from InstructorEmbedding import INSTRUCTOR - except ImportError: - raise ValueError( - "The InstructorEmbedding python package is not installed. 
Please install it with `pip install InstructorEmbedding`" - ) - self._model = INSTRUCTOR(model_name, device=device) - self._instruction = instruction - - def __call__(self, input: Documents) -> Embeddings: - if self._instruction is None: - return cast(Embeddings, self._model.encode(input).tolist()) - - texts_with_instructions = [[self._instruction, text] for text in input] - - return cast(Embeddings, self._model.encode(texts_with_instructions).tolist()) diff --git a/chromadb/utils/embedding_functions/jina_embedding_function.py b/chromadb/utils/embedding_functions/jina_embedding_function.py deleted file mode 100644 index f631bef4df8..00000000000 --- a/chromadb/utils/embedding_functions/jina_embedding_function.py +++ /dev/null @@ -1,60 +0,0 @@ -import logging -from typing import List, cast, Union - -import httpx - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class JinaEmbeddingFunction(EmbeddingFunction[Documents]): - """ - This class is used to get embeddings for a list of texts using the Jina AI API. - It requires an API key and a model name. The default model name is "jina-embeddings-v2-base-en". - """ - - def __init__(self, api_key: str, model_name: str = "jina-embeddings-v2-base-en"): - """ - Initialize the JinaEmbeddingFunction. - - Args: - api_key (str): Your API key for the Jina AI API. - model_name (str, optional): The name of the model to use for text embeddings. Defaults to "jina-embeddings-v2-base-en". - """ - self._model_name = model_name - self._api_url = "https://api.jina.ai/v1/embeddings" - self._session = httpx.Client() - self._session.headers.update( - {"Authorization": f"Bearer {api_key}", "Accept-Encoding": "identity"} - ) - - def __call__(self, input: Documents) -> Embeddings: - """ - Get the embeddings for a list of texts. - - Args: - texts (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the texts. 
- - Example: - >>> jina_ai_fn = JinaEmbeddingFunction(api_key="your_api_key") - >>> input = ["Hello, world!", "How are you?"] - >>> embeddings = jina_ai_fn(input) - """ - # Call Jina AI Embedding API - resp = self._session.post( - self._api_url, json={"input": input, "model": self._model_name} - ).json() - if "data" not in resp: - raise RuntimeError(resp["detail"]) - - embeddings: List[dict[str, Union[str, List[float]]]] = resp["data"] - - # Sort resulting embeddings by index - sorted_embeddings = sorted(embeddings, key=lambda e: e["index"]) - - # Return just the embeddings - return cast(Embeddings, [result["embedding"] for result in sorted_embeddings]) diff --git a/chromadb/utils/embedding_functions/ollama_embedding_function.py b/chromadb/utils/embedding_functions/ollama_embedding_function.py deleted file mode 100644 index a6293e36075..00000000000 --- a/chromadb/utils/embedding_functions/ollama_embedding_function.py +++ /dev/null @@ -1,58 +0,0 @@ -import logging -from typing import Union, cast - -import httpx - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class OllamaEmbeddingFunction(EmbeddingFunction[Documents]): - """ - This class is used to generate embeddings for a list of texts using the Ollama Embedding API (https://github.com/ollama/ollama/blob/main/docs/api.md#generate-embeddings). - """ - - def __init__(self, url: str, model_name: str) -> None: - """ - Initialize the Ollama Embedding Function. - - Args: - url (str): The URL of the Ollama Server. - model_name (str): The name of the model to use for text embeddings. E.g. "nomic-embed-text" (see https://ollama.com/library for available models). - """ - self._api_url = f"{url}" - self._model_name = model_name - self._session = httpx.Client() - - def __call__(self, input: Union[Documents, str]) -> Embeddings: - """ - Get the embeddings for a list of texts. - - Args: - input (Documents): A list of texts to get embeddings for. 
- - Returns: - Embeddings: The embeddings for the texts. - - Example: - >>> ollama_ef = OllamaEmbeddingFunction(url="http://localhost:11434/api/embeddings", model_name="nomic-embed-text") - >>> texts = ["Hello, world!", "How are you?"] - >>> embeddings = ollama_ef(texts) - """ - # Call Ollama Server API for each document - texts = input if isinstance(input, list) else [input] - embeddings = [ - self._session.post( - self._api_url, json={"model": self._model_name, "prompt": text} - ).json() - for text in texts - ] - return cast( - Embeddings, - [ - embedding["embedding"] - for embedding in embeddings - if "embedding" in embedding - ], - ) diff --git a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py b/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py deleted file mode 100644 index 3120f3ffad8..00000000000 --- a/chromadb/utils/embedding_functions/onnx_mini_lm_l6_v2.py +++ /dev/null @@ -1,236 +0,0 @@ -import hashlib -import importlib -import logging -import os -import tarfile -from functools import cached_property -from pathlib import Path -from typing import List, Optional, cast - -import numpy as np -import numpy.typing as npt -import httpx -from onnxruntime import InferenceSession, get_available_providers, SessionOptions -from tenacity import retry, retry_if_exception, stop_after_attempt, wait_random -from tokenizers import Tokenizer - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -def _verify_sha256(fname: str, expected_sha256: str) -> bool: - sha256_hash = hashlib.sha256() - with open(fname, "rb") as f: - # Read and update hash in chunks to avoid using too much memory - for byte_block in iter(lambda: f.read(4096), b""): - sha256_hash.update(byte_block) - - return sha256_hash.hexdigest() == expected_sha256 - - -# In order to remove dependencies on sentence-transformers, which in turn depends on -# pytorch and sentence-piece we have created a default ONNX embedding function that 
-# implements the same functionality as "all-MiniLM-L6-v2" from sentence-transformers. -# visit https://github.com/chroma-core/onnx-embedding for the source code to generate -# and verify the ONNX model. -class ONNXMiniLM_L6_V2(EmbeddingFunction[Documents]): - MODEL_NAME = "all-MiniLM-L6-v2" - DOWNLOAD_PATH = Path.home() / ".cache" / "chroma" / "onnx_models" / MODEL_NAME - EXTRACTED_FOLDER_NAME = "onnx" - ARCHIVE_FILENAME = "onnx.tar.gz" - MODEL_DOWNLOAD_URL = ( - "https://chroma-onnx-models.s3.amazonaws.com/all-MiniLM-L6-v2/onnx.tar.gz" - ) - _MODEL_SHA256 = "913d7300ceae3b2dbc2c50d1de4baacab4be7b9380491c27fab7418616a16ec3" - - # https://github.com/python/mypy/issues/7291 mypy makes you type the constructor if - # no args - def __init__(self, preferred_providers: Optional[List[str]] = None) -> None: - # Import dependencies on demand to mirror other embedding functions. This - # breaks typechecking, thus the ignores. - # convert the list to set for unique values - if preferred_providers and not all( - [isinstance(i, str) for i in preferred_providers] - ): - raise ValueError("Preferred providers must be a list of strings") - # check for duplicate providers - if preferred_providers and len(preferred_providers) != len( - set(preferred_providers) - ): - raise ValueError("Preferred providers must be unique") - self._preferred_providers = preferred_providers - try: - # Equivalent to import onnxruntime - self.ort = importlib.import_module("onnxruntime") - except ImportError: - raise ValueError( - "The onnxruntime python package is not installed. Please install it with `pip install onnxruntime`" - ) - try: - # Equivalent to from tokenizers import Tokenizer - self.Tokenizer = importlib.import_module("tokenizers").Tokenizer - except ImportError: - raise ValueError( - "The tokenizers python package is not installed. 
Please install it with `pip install tokenizers`" - ) - try: - # Equivalent to from tqdm import tqdm - self.tqdm = importlib.import_module("tqdm").tqdm - except ImportError: - raise ValueError( - "The tqdm python package is not installed. Please install it with `pip install tqdm`" - ) - - # Borrowed from https://gist.github.com/yanqd0/c13ed29e29432e3cf3e7c38467f42f51 - # Download with tqdm to preserve the sentence-transformers experience - @retry( # type: ignore - reraise=True, - stop=stop_after_attempt(3), - wait=wait_random(min=1, max=3), - retry=retry_if_exception(lambda e: "does not match expected SHA256" in str(e)), - ) - def _download(self, url: str, fname: str, chunk_size: int = 1024) -> None: - """ - Download the onnx model from the URL and save it to the file path. - - About ignored types: - tenacity.retry decorator is a bit convoluted when it comes to type annotations - which makes mypy unhappy. If some smart folk knows how to fix this in an - elegant way, please do so. - """ - with httpx.stream("GET", url) as resp: - total = int(resp.headers.get("content-length", 0)) - with open(fname, "wb") as file, self.tqdm( - desc=str(fname), - total=total, - unit="iB", - unit_scale=True, - unit_divisor=1024, - ) as bar: - for data in resp.iter_bytes(chunk_size=chunk_size): - size = file.write(data) - bar.update(size) - if not _verify_sha256(fname, self._MODEL_SHA256): - # if the integrity of the file is not verified, remove it - os.remove(fname) - raise ValueError( - f"Downloaded file {fname} does not match expected SHA256 hash. Corrupted download or malicious file." 
- ) - - # Use pytorches default epsilon for division by zero - # https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html - def _normalize(self, v: npt.NDArray[np.float32]) -> npt.NDArray[np.float32]: - norm = np.linalg.norm(v, axis=1) - norm[norm == 0] = 1e-12 - return cast(npt.NDArray[np.float32], v / norm[:, np.newaxis]) - - def _forward( - self, documents: List[str], batch_size: int = 32 - ) -> npt.NDArray[np.float32]: - all_embeddings = [] - for i in range(0, len(documents), batch_size): - batch = documents[i : i + batch_size] - encoded = [self.tokenizer.encode(d) for d in batch] - input_ids = np.array([e.ids for e in encoded]) - attention_mask = np.array([e.attention_mask for e in encoded]) - onnx_input = { - "input_ids": np.array(input_ids, dtype=np.int64), - "attention_mask": np.array(attention_mask, dtype=np.int64), - "token_type_ids": np.array( - [np.zeros(len(e), dtype=np.int64) for e in input_ids], - dtype=np.int64, - ), - } - model_output = self.model.run(None, onnx_input) - last_hidden_state = model_output[0] - # Perform mean pooling with attention weighting - input_mask_expanded = np.broadcast_to( - np.expand_dims(attention_mask, -1), last_hidden_state.shape - ) - embeddings = np.sum(last_hidden_state * input_mask_expanded, 1) / np.clip( - input_mask_expanded.sum(1), a_min=1e-9, a_max=None - ) - embeddings = self._normalize(embeddings).astype(np.float32) - all_embeddings.append(embeddings) - return np.concatenate(all_embeddings) - - @cached_property - def tokenizer(self) -> Tokenizer: - tokenizer = Tokenizer.from_file( - os.path.join( - self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "tokenizer.json" - ) - ) - # max_seq_length = 256, for some reason sentence-transformers uses 256 even though the HF config has a max length of 128 - # https://github.com/UKPLab/sentence-transformers/blob/3e1929fddef16df94f8bc6e3b10598a98f46e62d/docs/_static/html/models_en_sentence_embeddings.html#LL480 - tokenizer.enable_truncation(max_length=256) - 
tokenizer.enable_padding(pad_id=0, pad_token="[PAD]", length=256) - return tokenizer - - @cached_property - def model(self) -> InferenceSession: - if self._preferred_providers is None or len(self._preferred_providers) == 0: - if len(get_available_providers()) > 0: - logger.debug( - f"WARNING: No ONNX providers provided, defaulting to available providers: " - f"{get_available_providers()}" - ) - self._preferred_providers = get_available_providers() - elif not set(self._preferred_providers).issubset( - set(get_available_providers()) - ): - raise ValueError( - f"Preferred providers must be subset of available providers: {get_available_providers()}" - ) - - # Suppress onnxruntime warnings. This produces logspew, mainly when onnx tries to use CoreML, which doesn't fit this model. - so = SessionOptions() - so.log_severity_level = 3 - - return InferenceSession( - os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME, "model.onnx"), - # Since 1.9 onnyx runtime requires providers to be specified when there are multiple available - https://onnxruntime.ai/docs/api/python/api_summary.html - # This is probably not ideal but will improve DX as no exceptions will be raised in multi-provider envs - providers=self._preferred_providers, - sess_options=so, - ) - - def __call__(self, input: Documents) -> Embeddings: - # Only download the model when it is actually used - self._download_model_if_not_exists() - return cast(Embeddings, self._forward(input).tolist()) - - def _download_model_if_not_exists(self) -> None: - onnx_files = [ - "config.json", - "model.onnx", - "special_tokens_map.json", - "tokenizer_config.json", - "tokenizer.json", - "vocab.txt", - ] - extracted_folder = os.path.join(self.DOWNLOAD_PATH, self.EXTRACTED_FOLDER_NAME) - onnx_files_exist = True - for f in onnx_files: - if not os.path.exists(os.path.join(extracted_folder, f)): - onnx_files_exist = False - break - # Model is not downloaded yet - if not onnx_files_exist: - os.makedirs(self.DOWNLOAD_PATH, 
exist_ok=True) - if not os.path.exists( - os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME) - ) or not _verify_sha256( - os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), - self._MODEL_SHA256, - ): - self._download( - url=self.MODEL_DOWNLOAD_URL, - fname=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), - ) - with tarfile.open( - name=os.path.join(self.DOWNLOAD_PATH, self.ARCHIVE_FILENAME), - mode="r:gz", - ) as tar: - tar.extractall(path=self.DOWNLOAD_PATH) diff --git a/chromadb/utils/embedding_functions/open_clip_embedding_function.py b/chromadb/utils/embedding_functions/open_clip_embedding_function.py deleted file mode 100644 index 712cd871905..00000000000 --- a/chromadb/utils/embedding_functions/open_clip_embedding_function.py +++ /dev/null @@ -1,77 +0,0 @@ -import importlib -import logging -from typing import Optional, Union, cast - -from chromadb.api.types import ( - Document, - Documents, - Embedding, - EmbeddingFunction, - Embeddings, - Image, - Images, - is_document, - is_image, -) - -logger = logging.getLogger(__name__) - - -class OpenCLIPEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__( - self, - model_name: str = "ViT-B-32", - checkpoint: str = "laion2b_s34b_b79k", - device: Optional[str] = "cpu", - ) -> None: - try: - import open_clip - except ImportError: - raise ValueError( - "The open_clip python package is not installed. Please install it with `pip install open-clip-torch`. https://github.com/mlfoundations/open_clip" - ) - try: - self._torch = importlib.import_module("torch") - except ImportError: - raise ValueError( - "The torch python package is not installed. Please install it with `pip install torch`" - ) - - try: - self._PILImage = importlib.import_module("PIL.Image") - except ImportError: - raise ValueError( - "The PIL python package is not installed. 
Please install it with `pip install pillow`" - ) - - model, _, preprocess = open_clip.create_model_and_transforms( - model_name=model_name, pretrained=checkpoint - ) - self._model = model - self._model.to(device) - self._preprocess = preprocess - self._tokenizer = open_clip.get_tokenizer(model_name=model_name) - - def _encode_image(self, image: Image) -> Embedding: - pil_image = self._PILImage.fromarray(image) - with self._torch.no_grad(): - image_features = self._model.encode_image( - self._preprocess(pil_image).unsqueeze(0) - ) - image_features /= image_features.norm(dim=-1, keepdim=True) - return cast(Embedding, image_features.squeeze().tolist()) - - def _encode_text(self, text: Document) -> Embedding: - with self._torch.no_grad(): - text_features = self._model.encode_text(self._tokenizer(text)) - text_features /= text_features.norm(dim=-1, keepdim=True) - return cast(Embedding, text_features.squeeze().tolist()) - - def __call__(self, input: Union[Documents, Images]) -> Embeddings: - embeddings: Embeddings = [] - for item in input: - if is_image(item): - embeddings.append(self._encode_image(cast(Image, item))) - elif is_document(item): - embeddings.append(self._encode_text(cast(Document, item))) - return embeddings diff --git a/chromadb/utils/embedding_functions/openai_embedding_function.py b/chromadb/utils/embedding_functions/openai_embedding_function.py deleted file mode 100644 index 03eff5437b3..00000000000 --- a/chromadb/utils/embedding_functions/openai_embedding_function.py +++ /dev/null @@ -1,138 +0,0 @@ -import logging -from typing import Mapping, Optional, cast - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class OpenAIEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__( - self, - api_key: Optional[str] = None, - model_name: str = "text-embedding-ada-002", - organization_id: Optional[str] = None, - api_base: Optional[str] = None, - api_type: Optional[str] = None, - 
api_version: Optional[str] = None, - deployment_id: Optional[str] = None, - default_headers: Optional[Mapping[str, str]] = None, - ): - """ - Initialize the OpenAIEmbeddingFunction. - Args: - api_key (str, optional): Your API key for the OpenAI API. If not - provided, it will raise an error to provide an OpenAI API key. - organization_id(str, optional): The OpenAI organization ID if applicable - model_name (str, optional): The name of the model to use for text - embeddings. Defaults to "text-embedding-ada-002". - api_base (str, optional): The base path for the API. If not provided, - it will use the base path for the OpenAI API. This can be used to - point to a different deployment, such as an Azure deployment. - api_type (str, optional): The type of the API deployment. This can be - used to specify a different deployment, such as 'azure'. If not - provided, it will use the default OpenAI deployment. - api_version (str, optional): The api version for the API. If not provided, - it will use the api version for the OpenAI API. This can be used to - point to a different deployment, such as an Azure deployment. - deployment_id (str, optional): Deployment ID for Azure OpenAI. - default_headers (Mapping, optional): A mapping of default headers to be sent with each API request. - - """ - try: - import openai - except ImportError: - raise ValueError( - "The openai python package is not installed. Please install it with `pip install openai`" - ) - - if api_key is not None: - openai.api_key = api_key - # If the api key is still not set, raise an error - elif openai.api_key is None: - raise ValueError( - "Please provide an OpenAI API key. 
You can get one at https://platform.openai.com/account/api-keys" - ) - - if api_base is not None: - openai.api_base = api_base - - if api_version is not None: - openai.api_version = api_version - - self._api_type = api_type - if api_type is not None: - openai.api_type = api_type - - if organization_id is not None: - openai.organization = organization_id - - self._v1 = openai.__version__.startswith("1.") - if self._v1: - if api_type == "azure": - self._client = openai.AzureOpenAI( - api_key=api_key, - api_version=api_version, - azure_endpoint=api_base, - default_headers=default_headers, - ).embeddings - else: - self._client = openai.OpenAI( - api_key=api_key, base_url=api_base, default_headers=default_headers - ).embeddings - else: - self._client = openai.Embedding - self._model_name = model_name - self._deployment_id = deployment_id - - def __call__(self, input: Documents) -> Embeddings: - """ - Generate the embeddings for the given `input`. - - # About ignoring types - We are not enforcing the openai library, therefore, `mypy` has hard times trying - to figure out what the types are for `self._client.create()` which throws an - error when trying to sort the list. If, eventually we include the `openai` lib - we can remove the type ignore tag. - - Args: - input (Documents): A list of texts to get embeddings for. - - Returns: - Embeddings: The embeddings for the given input sorted by index - """ - # replace newlines, which can negatively affect performance. 
- input = [t.replace("\n", " ") for t in input] - - # Call the OpenAI Embedding API - if self._v1: - embeddings = self._client.create( - input=input, model=self._deployment_id or self._model_name - ).data - - # Sort resulting embeddings by index - sorted_embeddings = sorted( - embeddings, key=lambda e: e.index # type: ignore - ) - - # Return just the embeddings - return cast(Embeddings, [result.embedding for result in sorted_embeddings]) - else: - if self._api_type == "azure": - embeddings = self._client.create( - input=input, engine=self._deployment_id or self._model_name - )["data"] - else: - embeddings = self._client.create(input=input, model=self._model_name)[ - "data" - ] - - # Sort resulting embeddings by index - sorted_embeddings = sorted( - embeddings, key=lambda e: e["index"] # type: ignore - ) - - # Return just the embeddings - return cast( - Embeddings, [result["embedding"] for result in sorted_embeddings] - ) diff --git a/chromadb/utils/embedding_functions/roboflow_embedding_function.py b/chromadb/utils/embedding_functions/roboflow_embedding_function.py deleted file mode 100644 index b118aa01c64..00000000000 --- a/chromadb/utils/embedding_functions/roboflow_embedding_function.py +++ /dev/null @@ -1,87 +0,0 @@ -import base64 -import importlib -import logging -import os -from io import BytesIO -from typing import Union - -import httpx - -from chromadb.api.types import ( - Documents, - EmbeddingFunction, - Embeddings, - Images, - is_document, - is_image, -) - -logger = logging.getLogger(__name__) - - -class RoboflowEmbeddingFunction(EmbeddingFunction[Union[Documents, Images]]): - def __init__( - self, api_key: str = "", api_url: str = "https://infer.roboflow.com" - ) -> None: - """ - Create a RoboflowEmbeddingFunction. - - Args: - api_key (str): Your API key for the Roboflow API. - api_url (str, optional): The URL of the Roboflow API. Defaults to "https://infer.roboflow.com". 
- """ - if not api_key: - api_key = os.environ.get("ROBOFLOW_API_KEY", "") - - self._api_url = api_url - self._api_key = api_key - - try: - self._PILImage = importlib.import_module("PIL.Image") - except ImportError: - raise ValueError( - "The PIL python package is not installed. Please install it with `pip install pillow`" - ) - - def __call__(self, input: Union[Documents, Images]) -> Embeddings: - embeddings = [] - - for item in input: - if is_image(item): - image = self._PILImage.fromarray(item) - - buffer = BytesIO() - image.save(buffer, format="JPEG") - base64_image = base64.b64encode(buffer.getvalue()).decode("utf-8") - - infer_clip_payload_image = { - "image": { - "type": "base64", - "value": base64_image, - }, - } - - res = httpx.post( - f"{self._api_url}/clip/embed_image?api_key={self._api_key}", - json=infer_clip_payload_image, - ) - - result = res.json()["embeddings"] - - embeddings.append(result[0]) - - elif is_document(item): - infer_clip_payload_text = { - "text": input, - } - - res = httpx.post( - f"{self._api_url}/clip/embed_text?api_key={self._api_key}", - json=infer_clip_payload_text, - ) - - result = res.json()["embeddings"] - - embeddings.append(result[0]) - - return embeddings diff --git a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py b/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py deleted file mode 100644 index 2ca530b0a30..00000000000 --- a/chromadb/utils/embedding_functions/sentence_transformer_embedding_function.py +++ /dev/null @@ -1,51 +0,0 @@ -import logging -from typing import Any, Dict, cast - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class SentenceTransformerEmbeddingFunction(EmbeddingFunction[Documents]): - # Since we do dynamic imports we have to type this as Any - models: Dict[str, Any] = {} - - # If you have a beefier machine, try "gtr-t5-large". 
- # for a full list of options: https://huggingface.co/sentence-transformers, https://www.sbert.net/docs/pretrained_models.html - def __init__( - self, - model_name: str = "all-MiniLM-L6-v2", - device: str = "cpu", - normalize_embeddings: bool = False, - **kwargs: Any, - ): - """Initialize SentenceTransformerEmbeddingFunction. - - Args: - model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2" - device (str, optional): Device used for computation, defaults to "cpu" - normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False - **kwargs: Additional arguments to pass to the SentenceTransformer model. - """ - if model_name not in self.models: - try: - from sentence_transformers import SentenceTransformer - except ImportError: - raise ValueError( - "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`" - ) - self.models[model_name] = SentenceTransformer( - model_name, device=device, **kwargs - ) - self._model = self.models[model_name] - self._normalize_embeddings = normalize_embeddings - - def __call__(self, input: Documents) -> Embeddings: - return cast( - Embeddings, - self._model.encode( - list(input), - convert_to_numpy=True, - normalize_embeddings=self._normalize_embeddings, - ).tolist(), - ) diff --git a/chromadb/utils/embedding_functions/text2vec_embedding_function.py b/chromadb/utils/embedding_functions/text2vec_embedding_function.py deleted file mode 100644 index 86a45deff24..00000000000 --- a/chromadb/utils/embedding_functions/text2vec_embedding_function.py +++ /dev/null @@ -1,22 +0,0 @@ -import logging -from typing import cast - -from chromadb.api.types import Documents, EmbeddingFunction, Embeddings - -logger = logging.getLogger(__name__) - - -class Text2VecEmbeddingFunction(EmbeddingFunction[Documents]): - def __init__(self, model_name: str = "shibing624/text2vec-base-chinese"): - try: - from text2vec import 
SentenceModel - except ImportError: - raise ValueError( - "The text2vec python package is not installed. Please install it with `pip install text2vec`" - ) - self._model = SentenceModel(model_name_or_path=model_name) - - def __call__(self, input: Documents) -> Embeddings: - return cast( - Embeddings, self._model.encode(list(input), convert_to_numpy=True).tolist() - ) # noqa E501