Skip to content
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
ffc5e91
1965 - Init the module
nablabits Apr 11, 2024
1c28da4
1965 - Move over `AmazonBedrockEmbeddingFunction`
nablabits Apr 13, 2024
6e4f190
1965 - Move over `create_langchain_embedding`
nablabits Apr 13, 2024
40dee43
1965 - Move over `CohereEmbeddingFunction`
nablabits Apr 13, 2024
385dcc0
1965 - Move over `google_embedding_function`
nablabits Apr 13, 2024
ed206d6
1965 - Move over `huggingface_embedding_function`
nablabits Apr 13, 2024
50076e4
1965 - Move over `InstructorEmbeddingFunction`
nablabits Apr 13, 2024
0aeb92e
1965 - Move over `JinaEmbeddingFunction`
nablabits Apr 13, 2024
18926e9
1965 - Move over `OllamaEmbeddingFunction`
nablabits Apr 13, 2024
1ec3d2a
1965 - Move over `ONNXMiniLM_L6_V2`
nablabits Apr 13, 2024
3642058
1965 - Move over `OpenCLIPEmbeddingFunction`
nablabits Apr 13, 2024
0196264
1965 - Move over `OpenAIEmbeddingFunction`
nablabits Apr 13, 2024
929a8d4
1965 - Move over `RoboflowEmbeddingFunction`
nablabits Apr 13, 2024
c2e2cc8
1965 - Move over `SentenceTransformerEmbeddingFunction`
nablabits Apr 13, 2024
2632601
1965 - Move over `Text2VecEmbeddingFunction`
nablabits Apr 13, 2024
6770d21
1965 - Move remaining functions
nablabits Apr 13, 2024
6ad7598
1965 - Lint Files
nablabits Apr 20, 2024
8f08d60
1965 - Lint onnx embedding function
nablabits Apr 20, 2024
fc6b3c8
1965 - Ensure that `get_builtins()` holds after the migration.
nablabits Apr 24, 2024
97dc885
Merge branch 'main' into feature/1965-split-up-embedding-functions
nablabits May 14, 2024
5c56387
Automate imports of EF in module
atroyn Jun 20, 2024
a548218
Automate imports of EF in module
atroyn Jun 20, 2024
cbb0b03
Additional tests
atroyn Jun 20, 2024
84aa4cf
Merge branch 'main' into feature/1965-split-up-embedding-functions
atroyn Jun 20, 2024
41a3e91
httpx everywhere
atroyn Jun 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,031 changes: 0 additions & 1,031 deletions chromadb/utils/embedding_functions.py

This file was deleted.

95 changes: 95 additions & 0 deletions chromadb/utils/embedding_functions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import inspect
import sys
from typing import List, Optional

from chromadb.api.types import Documents, EmbeddingFunction
from chromadb.utils.embedding_functions.amazon_bedrock_embedding_function import (
AmazonBedrockEmbeddingFunction,
)
from chromadb.utils.embedding_functions.chroma_langchain_embedding_function import (
create_langchain_embedding,
)
from chromadb.utils.embedding_functions.cohere_embedding_function import (
CohereEmbeddingFunction,
)
from chromadb.utils.embedding_functions.google_embedding_function import (
GoogleGenerativeAiEmbeddingFunction,
GooglePalmEmbeddingFunction,
GoogleVertexEmbeddingFunction,
)
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingFunction,
HuggingFaceEmbeddingServer,
)
from chromadb.utils.embedding_functions.instructor_embedding_function import (
InstructorEmbeddingFunction,
)
from chromadb.utils.embedding_functions.jina_embedding_function import (
JinaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.ollama_embedding_function import (
OllamaEmbeddingFunction,
)
from chromadb.utils.embedding_functions.onnx_mini_lm_l6_v2 import (
ONNXMiniLM_L6_V2,
_verify_sha256,
)
from chromadb.utils.embedding_functions.open_clip_embedding_function import (
OpenCLIPEmbeddingFunction,
)
from chromadb.utils.embedding_functions.openai_embedding_function import (
OpenAIEmbeddingFunction,
)
from chromadb.utils.embedding_functions.roboflow_embedding_function import (
RoboflowEmbeddingFunction,
)
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import (
SentenceTransformerEmbeddingFunction,
)
from chromadb.utils.embedding_functions.text2vec_embedding_function import (
Text2VecEmbeddingFunction,
)

# Public API of the embedding_functions package.
# Keep this list in sync with the imports above: anything re-exported here is
# what `from chromadb.utils.embedding_functions import *` exposes.
__all__ = [
    "AmazonBedrockEmbeddingFunction",
    "create_langchain_embedding",
    "CohereEmbeddingFunction",
    "GoogleGenerativeAiEmbeddingFunction",
    "GooglePalmEmbeddingFunction",
    "GoogleVertexEmbeddingFunction",
    "HuggingFaceEmbeddingFunction",
    "HuggingFaceEmbeddingServer",
    "InstructorEmbeddingFunction",
    "JinaEmbeddingFunction",
    "OllamaEmbeddingFunction",
    "OpenCLIPEmbeddingFunction",
    "ONNXMiniLM_L6_V2",
    "OpenAIEmbeddingFunction",
    "RoboflowEmbeddingFunction",
    "SentenceTransformerEmbeddingFunction",
    "Text2VecEmbeddingFunction",
    "_verify_sha256",
]

# Thin-client builds ship a marker module `chromadb.is_thin_client`; when it is
# absent we are a full client, so default to False. DefaultEmbeddingFunction()
# below uses this flag to decide whether the local ONNX model is available.
try:
    from chromadb.is_thin_client import is_thin_client
except ImportError:
    is_thin_client = False


def DefaultEmbeddingFunction() -> Optional[EmbeddingFunction[Documents]]:
    """Return the default embedding function (local ONNX MiniLM-L6-v2).

    Returns:
        An ``ONNXMiniLM_L6_V2`` instance, or ``None`` on thin clients, which
        do not bundle the local model.
    """
    if is_thin_client:
        return None
    return ONNXMiniLM_L6_V2()


_classes = [
name
for name, obj in inspect.getmembers(sys.modules[__name__], inspect.isclass)
if obj.__module__ == __name__
]


def get_builtins() -> List[str]:
return _classes
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import json
import logging
from typing import Any

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)


class AmazonBedrockEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embed documents with a text-embedding model hosted on Amazon Bedrock."""

    def __init__(
        self,
        session: Any,
        model_name: str = "amazon.titan-embed-text-v1",
        **kwargs: Any,
    ):
        """Initialize AmazonBedrockEmbeddingFunction.

        Args:
            session (boto3.Session): The boto3 session to use. You need to have boto3
                installed, `pip install boto3`.
            model_name (str, optional): Identifier of the model, defaults to
                "amazon.titan-embed-text-v1".
            **kwargs: Additional arguments to pass to the boto3 client.

        Example:
            >>> import boto3
            >>> session = boto3.Session(profile_name="profile", region_name="us-east-1")
            >>> bedrock = AmazonBedrockEmbeddingFunction(session=session)
            >>> texts = ["Hello, world!", "How are you?"]
            >>> embeddings = bedrock(texts)
        """
        self._model_name = model_name
        # The "bedrock-runtime" client exposes invoke_model, used in __call__.
        self._client = session.client(
            service_name="bedrock-runtime",
            **kwargs,
        )

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document with one InvokeModel call per text."""
        media_type = "application/json"
        results = []
        for document in input:
            payload = json.dumps({"inputText": document})
            response = self._client.invoke_model(
                body=payload,
                modelId=self._model_name,
                accept=media_type,
                contentType=media_type,
            )
            # The response body is a streaming JSON object holding "embedding".
            results.append(json.load(response.get("body")).get("embedding"))
        return results
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import logging
from typing import Any, List, Union

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Images

logger = logging.getLogger(__name__)


def create_langchain_embedding(langchain_embdding_fn: Any):  # type: ignore
    """Wrap a langchain embedding object as a Chroma embedding function.

    Args:
        langchain_embdding_fn: An object implementing langchain_core's
            ``Embeddings`` interface. (Parameter name kept as-is, typo
            included — it is part of the public keyword interface.)

    Returns:
        An ``EmbeddingFunction`` usable with Chroma collections.

    Raises:
        ValueError: If the ``langchain_core`` package is not installed.
    """
    try:
        from langchain_core.embeddings import Embeddings as LangchainEmbeddings
    except ImportError:
        raise ValueError(
            "The langchain_core python package is not installed. Please install it with `pip install langchain-core`"
        )

    class ChromaLangchainEmbeddingFunction(
        LangchainEmbeddings, EmbeddingFunction[Union[Documents, Images]]  # type: ignore
    ):
        """
        This class is used as bridge between langchain embedding functions and custom chroma embedding functions.
        """

        def __init__(self, embedding_function: LangchainEmbeddings) -> None:
            """
            Initialize the ChromaLangchainEmbeddingFunction

            Args:
                embedding_function : The embedding function implementing Embeddings from langchain_core.
            """
            self.embedding_function = embedding_function

        def embed_documents(self, documents: Documents) -> List[List[float]]:
            return self.embedding_function.embed_documents(documents)  # type: ignore

        def embed_query(self, query: str) -> List[float]:
            return self.embedding_function.embed_query(query)  # type: ignore

        def embed_image(self, uris: List[str]) -> List[List[float]]:
            # Not every langchain embedding supports images; fail loudly if not.
            if not hasattr(self.embedding_function, "embed_image"):
                raise ValueError(
                    "The provided embedding function does not support image embeddings."
                )
            return self.embedding_function.embed_image(uris)  # type: ignore

        def __call__(self, input: Documents) -> Embeddings:  # type: ignore
            """
            Get the embeddings for a list of texts or images.

            Args:
                input (Documents | Images): A list of texts or images to get embeddings for.
                    Images should be provided as a list of URIs passed through the langchain data loader

            Returns:
                Embeddings: The embeddings for the texts or images.

            Example:
                >>> langchain_embedding = ChromaLangchainEmbeddingFunction(embedding_function=OpenAIEmbeddings(model="text-embedding-3-large"))
                >>> texts = ["Hello, world!", "How are you?"]
                >>> embeddings = langchain_embedding(texts)
            """
            # Due to langchain quirks, the dataloader returns a tuple if the
            # input is uris of images.
            if input[0] == "images":
                return self.embed_image(list(input[1]))  # type: ignore
            return self.embed_documents(list(input))  # type: ignore

    return ChromaLangchainEmbeddingFunction(embedding_function=langchain_embdding_fn)
27 changes: 27 additions & 0 deletions chromadb/utils/embedding_functions/cohere_embedding_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import logging

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)


class CohereEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embed documents using Cohere's embedding API."""

    def __init__(self, api_key: str, model_name: str = "large"):
        """Initialize the function with a Cohere API key and model identifier.

        Raises:
            ValueError: If the cohere package is not installed.
        """
        try:
            import cohere
        except ImportError:
            raise ValueError(
                "The cohere python package is not installed. Please install it with `pip install cohere`"
            )

        self._client = cohere.Client(api_key)
        self._model_name = model_name

    def __call__(self, input: Documents) -> Embeddings:
        # Call Cohere Embedding API for each document.
        response = self._client.embed(
            texts=input, model=self._model_name, input_type="search_document"
        )
        return list(response)
110 changes: 110 additions & 0 deletions chromadb/utils/embedding_functions/google_embedding_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import logging

import requests

from chromadb.api.types import Documents, EmbeddingFunction, Embeddings

logger = logging.getLogger(__name__)


class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""

    def __init__(self, api_key: str, model_name: str = "models/embedding-gecko-001"):
        """Configure the PaLM client.

        Raises:
            ValueError: If api_key or model_name is empty, or if the
                google-generativeai package is not installed.
        """
        if not api_key:
            raise ValueError("Please provide a PaLM API key.")
        if not model_name:
            raise ValueError("Please provide the model name.")

        try:
            import google.generativeai as palm
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        palm.configure(api_key=api_key)
        self._palm = palm
        self._model_name = model_name

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document with one generate_embeddings call per text."""
        results = []
        for text in input:
            response = self._palm.generate_embeddings(
                model=self._model_name, text=text
            )
            results.append(response["embedding"])
        return results


class GoogleGenerativeAiEmbeddingFunction(EmbeddingFunction[Documents]):
    """To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a Google API key.

    Use RETRIEVAL_DOCUMENT for the task_type for embedding, and RETRIEVAL_QUERY
    for the task_type for retrieval.
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "models/embedding-001",
        task_type: str = "RETRIEVAL_DOCUMENT",
    ):
        """Configure the Generative AI client.

        Raises:
            ValueError: If api_key or model_name is empty, or if the
                google-generativeai package is not installed.
        """
        if not api_key:
            raise ValueError("Please provide a Google API key.")
        if not model_name:
            raise ValueError("Please provide the model name.")

        try:
            import google.generativeai as genai
        except ImportError:
            raise ValueError(
                "The Google Generative AI python package is not installed. Please install it with `pip install google-generativeai`"
            )

        genai.configure(api_key=api_key)
        self._genai = genai
        self._model_name = model_name
        self._task_type = task_type
        # A title is only supplied for document-embedding requests.
        if self._task_type == "RETRIEVAL_DOCUMENT":
            self._task_title = "Embedding of single string"
        else:
            self._task_title = None

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document with one embed_content call per text."""
        results = []
        for text in input:
            response = self._genai.embed_content(
                model=self._model_name,
                content=text,
                task_type=self._task_type,
                title=self._task_title,
            )
            results.append(response["embedding"])
        return results


class GoogleVertexEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embed documents with a Google Vertex AI text-embedding model.

    Follow the API Quickstart for Google Vertex AI:
    https://cloud.google.com/vertex-ai/docs/generative-ai/start/quickstarts/api-quickstart
    Information about the text embedding modules in Google Vertex AI:
    https://cloud.google.com/vertex-ai/docs/generative-ai/embeddings/get-text-embeddings
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "textembedding-gecko",
        project_id: str = "cloud-large-language-models",
        region: str = "us-central1",
    ):
        """Build the predict endpoint URL and an authorized HTTP session.

        Args:
            api_key: Bearer token used to authorize requests.
            model_name: Vertex AI publisher model identifier.
            project_id: GCP project hosting the model.
            region: GCP region for the endpoint.
        """
        # BUG FIX: the publisher path segment was previously misspelled
        # "goole", which made every request target a nonexistent publisher.
        self._api_url = f"https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/google/models/{model_name}:predict"
        self._session = requests.Session()
        self._session.headers.update({"Authorization": f"Bearer {api_key}"})

    def __call__(self, input: Documents) -> Embeddings:
        """Embed each document with one predict request per text.

        Texts whose responses carry no "predictions" key are silently skipped
        (preserves the original best-effort behavior).
        """
        embeddings = []
        for text in input:
            response = self._session.post(
                self._api_url, json={"instances": [{"content": text}]}
            ).json()

            # BUG FIX: "predictions" is a JSON list (one entry per instance);
            # we send a single instance per request, so take element 0. The
            # previous code indexed the list with a string key, which raised
            # TypeError on every successful response.
            if "predictions" in response:
                embeddings.append(response["predictions"][0]["embeddings"]["values"])

        return embeddings
Loading