-
Notifications
You must be signed in to change notification settings - Fork 2.2k
[ENH] 1965 Split up embedding functions #2034
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
ffc5e91
1c28da4
6e4f190
40dee43
385dcc0
ed206d6
50076e4
0aeb92e
18926e9
1ec3d2a
3642058
0196264
929a8d4
c2e2cc8
2632601
6770d21
6ad7598
8f08d60
fc6b3c8
97dc885
5c56387
a548218
cbb0b03
84aa4cf
41a3e91
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,95 @@ | ||
| import inspect | ||
| import sys | ||
| from typing import List, Optional | ||
|
|
||
| from chromadb.api.types import Documents, EmbeddingFunction | ||
| from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import ( | ||
| AmazonBedrockEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import ( | ||
| create_langchain_embedding, | ||
| ) | ||
| from chromadb.utils.embedding_functions.cohere_embedding_function import ( | ||
| CohereEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.google_embedding_function import ( | ||
| GoogleGenerativeAiEmbeddingFunction, | ||
| GooglePalmEmbeddingFunction, | ||
| GoogleVertexEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.huggingface_embedding_function import ( | ||
| HuggingFaceEmbeddingFunction, | ||
| HuggingFaceEmbeddingServer, | ||
| ) | ||
| from chromadb.utils.embedding_functions.instructor_embedding_function import ( | ||
| InstructorEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.jina_embedding_function import ( | ||
| JinaEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.ollama_embedding_function import ( | ||
| OllamaEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import ( | ||
| ONNXMiniLM_L6_V2, | ||
| _verify_sha256, | ||
| ) | ||
| from chromadb.utils.embedding_functions.open_clip_embedding_function import ( | ||
| OpenCLIPEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.openai_embedding_function import ( | ||
| OpenAIEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.roboflow_embedding_function import ( | ||
| RoboflowEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import ( | ||
| SentenceTransformerEmbeddingFunction, | ||
| ) | ||
| from chromadb.utils.embedding_functions.text2vec_embedding_function import ( | ||
| Text2VecEmbeddingFunction, | ||
| ) | ||
|
|
||
# Public API of this package: the embedding-function classes and helpers
# re-exported from the per-provider submodules above, so that
# `from chromadb.utils.embedding_functions import X` keeps working after the
# monolithic module was split up.
# NOTE(review): `_verify_sha256` is private-named but exported — presumably
# for backward compatibility with the pre-split module; confirm callers.
__all__ = [
    "AmazonBedrockEmbeddingFunction",
    "create_langchain_embedding",
    "CohereEmbeddingFunction",
    "GoogleGenerativeAiEmbeddingFunction",
    "GooglePalmEmbeddingFunction",
    "GoogleVertexEmbeddingFunction",
    "HuggingFaceEmbeddingFunction",
    "HuggingFaceEmbeddingServer",
    "InstructorEmbeddingFunction",
    "JinaEmbeddingFunction",
    "OllamaEmbeddingFunction",
    "OpenCLIPEmbeddingFunction",
    "ONNXMiniLM_L6_V2",
    "OpenAIEmbeddingFunction",
    "RoboflowEmbeddingFunction",
    "SentenceTransformerEmbeddingFunction",
    "Text2VecEmbeddingFunction",
    "_verify_sha256",
]
|
|
||
# Detect thin-client builds.  The `chromadb.is_thin_client` module only
# exists in thin-client distributions, so an ImportError means we are
# running the full client (and may instantiate the local ONNX model).
try:
    from chromadb.is_thin_client import is_thin_client
except ImportError:
    is_thin_client = False
|
|
||
|
|
||
def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]:
    """Return the default embedding function, or None on thin clients.

    Thin clients ship without the local ONNX model, so no default
    embedding function is available there.
    """
    return None if is_thin_client else ONNXMiniLM_L6_V2()
|
|
||
|
|
||
# Names of the built-in embedding-function classes exported by this package.
# BUG FIX: after the split, every class is *imported* from a submodule, so
# filtering on `obj.__module__ == __name__` matched nothing and this list was
# always empty, breaking get_builtins().  Instead, keep every class visible in
# this module whose name is part of the public API, which restores the
# pre-split behavior of get_builtins().
_classes = [
    name
    for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
    if name in __all__
]


def get_builtins() -> List[str]:
    """Return the names of the built-in EmbeddingFunction classes.

    Returns:
        List[str]: Class names (e.g. "OpenAIEmbeddingFunction") exported by
        this package; computed once at import time.
    """
    return _classes
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,55 @@ | ||
| import json | ||
| import logging | ||
| from typing import Any | ||
|
|
||
| from chromadb.api.types import Documents, EmbeddingFunction, Embeddings | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    """EmbeddingFunction backed by Amazon Bedrock text-embedding models."""

    def __init__(
        self,
        session: Any,
        model_name: str = "amazon.titan-embed-text-v1",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use. You need to have boto3
                installed, `pip install boto3`.
            model_name (str, optional): Identifier of the model, defaults to
                "amazon.titan-embed-text-v1".
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """
        self._model_name = model_name
        # The bedrock-runtime client exposes invoke_model used in __call__.
        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document with one InvokeModel call; returns one vector per text."""
        results = []
        for document in input:
            payload = json.dumps({"inputText": document})
            response = self._client.invoke_model(
                body=payload,
                modelId=self._model_name,
                accept="application/json",
                contentType="application/json",
            )
            # The response body is a JSON stream containing the "embedding" key.
            results.append(json.load(response.get("body")).get("embedding"))
        return results
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,69 @@ | ||
| import logging | ||
| from typing import Any, List, Union | ||
|
|
||
| from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Images | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def create_langchain_embedding(langchain_embdding_fn: Any):  # type: ignore
    """Wrap a langchain embedding object as a Chroma EmbeddingFunction.

    Args:
        langchain_embdding_fn: An object implementing ``Embeddings`` from
            ``langchain_core`` (e.g. ``OpenAIEmbeddings``).

    Returns:
        A ``ChromaLangchainEmbeddingFunction`` instance bridging the two APIs.

    Raises:
        ValueError: If ``langchain-core`` is not installed.
    """
    # Import lazily so langchain-core is only required when this helper is used.
    try:
        from langchain_core.embeddings import Embeddings as LangchainEmbeddings
    except ImportError:
        raise ValueError(
            "The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
        )

    class ChromaLangchainEmbeddingFunction(
        LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]]  # type: ignore
    ):
        """
        This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
        """

        def __init__(self, embedding_function: LangchainEmbeddings) -> None:
            """
            Initialize the ChromaLangchainEmbeddingFunction

            Args:
                embedding_function : The embedding function implementing Embeddings from langchain_core.
            """
            self.embedding_function = embedding_function

        def embed_documents(self, documents: Documents) -> List[List[float]]:
            # Delegate straight to the wrapped langchain embedder.
            return self.embedding_function.embed_documents(documents)  # type: ignore

        def embed_query(self, query: str) -> List[float]:
            return self.embedding_function.embed_query(query)  # type: ignore

        def embed_image(self, uris: List[str]) -> List[List[float]]:
            # Not every langchain embedder supports images; probe for the
            # optional method rather than assuming it exists.
            if hasattr(self.embedding_function, "embed_image"):
                return self.embedding_function.embed_image(uris)  # type: ignore
            else:
                raise ValueError(
                    "The provided embedding function does not support image embeddings."
                )

        def __call__(self, input: Documents) -> Embeddings:  # type: ignore
            """
            Get the embeddings for a list of texts or images.

            Args:
                input (Documents | Images): A list of texts or images to get embeddings for.
                Images should be provided as a list of URIs passed through the langchain data loader

            Returns:
                Embeddings: The embeddings for the texts or images.

            Example:
                >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large"))
                >>> texts = ["Hello, world!", "How are you?"]
                >>> embeddings = langchain_embedding(texts)
            """
            # Due to langchain quirks, the dataloader returns a tuple if the input is uris of images
            if input[0] == "images":
                return self.embed_image(list(input[1]))  # type: ignore

            return self.embed_documents(list(input))  # type: ignore

    return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| import logging | ||
|
|
||
| from chromadb.api.types import Documents, EmbeddingFunction, Embeddings | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
    """EmbeddingFunction backed by the Cohere embed API.

    Requires the ``cohere`` package and a Cohere API key.
    """

    def __init__(self, api_key: str, model_name: str = "large"):
        # Import lazily so the dependency is only needed when this class is used.
        try:
            import cohere
        except ImportError:
            raise ValueError(
                "The cohere python package is not installed. Please install it with `pip install cohere`"
            )

        self._client = cohere.Client(api_key)
        self._model_name = model_name

    def __call__(self, input: Documents) -> Embeddings:
        """Embed all documents in a single Cohere API call."""
        # The embed response is iterable over one embedding per input text.
        response = self._client.embed(
            texts=input, model=self._model_name, input_type="search_document"
        )
        return list(response)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,110 @@ | ||
| import logging | ||
|
|
||
| import requests | ||
|
|
||
| from chromadb.api.types import Documents, EmbeddingFunction, Embeddings | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""

    def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"):
        # Validate required arguments up front with clear messages.
        if not api_key:
            raise ValueError("Please provide a PaLM API key.")
        if not model_name:
            raise ValueError("Please provide the model name.")

        # Import lazily so the dependency is only needed when this class is used.
        try:
            import google.generativeai as palm
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        palm.configure(api_key=api_key)
        self._palm = palm
        self._model_name = model_name

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each text with one generate_embeddings call; one vector per text."""
        results = []
        for text in input:
            response = self._palm.generate_embeddings(
                model=self._model_name, text=text
            )
            results.append(response["embedding"])
        return results
|
|
||
|
|
||
class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
    """EmbeddingFunction backed by Google's Generative AI embedding API.

    To use this EmbeddingFunction, you must have the google.generativeai
    Python package installed and have a Google API key.  Use
    task_type="RETRIEVAL_DOCUMENT" when embedding documents for storage and
    task_type="RETRIEVAL_QUERY" when embedding retrieval queries.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "models/embedding-001",
        task_type: str = "RETRIEVAL_DOCUMENT",
    ):
        # Validate required arguments up front with clear messages.
        if not api_key:
            raise ValueError("Please provide a Google API key.")
        if not model_name:
            raise ValueError("Please provide the model name.")

        # Import lazily so the dependency is only needed when this class is used.
        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        genai.configure(api_key=api_key)
        self._genai = genai
        self._model_name = model_name
        self._task_type = task_type
        # A title is only passed for document-retrieval embeddings.
        self._task_title = (
            "Embedding of single string"
            if self._task_type == "RETRIEVAL_DOCUMENT"
            else None
        )

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each text with one embed_content call; one vector per text."""
        results = []
        for text in input:
            response = self._genai.embed_content(
                model=self._model_name,
                content=text,
                task_type=self._task_type,
                title=self._task_title,
            )
            results.append(response["embedding"])
        return results
|
|
||
|
|
||
class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
    # Follow API Quickstart for Google Vertex AI
    # https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
    # Information about the text embedding modules in Google Vertex AI
    # https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
    def __init__(
        self,
        api_key: str,
        model_name: str = "textembedding-gecko",
        project_id: str = "cloud-large-language-models",
        region: str = "us-central1",
    ):
        """Initialize GoogleVertexEmbeddingFunction.

        Args:
            api_key (str): Bearer token used to authenticate against Vertex AI.
            model_name (str, optional): Text-embedding model identifier.
            project_id (str, optional): Google Cloud project hosting the model.
            region (str, optional): Google Cloud region of the endpoint.
        """
        # BUG FIX: the publisher path segment previously read "goole", which
        # made every request hit a nonexistent endpoint; it must be "google".
        self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model_name}:predict"
        self._session = requests.Session()
        self._session.headers.update({"Authorization": f"Bearer {api_key}"})

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each text with one :predict POST; texts whose response has no
        "predictions" key are silently skipped (best-effort, as before)."""
        embeddings = []
        for text in input:
            response = self._session.post(
                self._api_url, json={"instances": [{"content": text}]}
            ).json()

            # BUG FIX: "predictions" is a JSON *list* (one entry per instance);
            # indexing it with string keys raised TypeError.  We send a single
            # instance per request, so take the first prediction.
            if "predictions" in response:
                embeddings.append(response["predictions"][0]["embeddings"]["values"])

        return embeddings
Uh oh!
There was an error while loading. Please reload this page.