class TogetherAIEmbeddingFunction(EmbeddingFunction[Documents]):
    """Embedding function that calls the Together AI embeddings REST endpoint.

    References:
        Quick start: https://docs.together.ai/docs/embeddings-rest
        Model list:  https://docs.together.ai/docs/embedding-models
        API keys:    https://api.together.xyz/settings/api-keys
    """

    def __init__(
        self,
        api_key: str,
        model_name: str = "togethercomputer/m2-bert-80M-8k-retrieval",
        api_url: str = "https://api.together.xyz/v1/embeddings",
    ):
        """Store the endpoint configuration and prepare an authenticated session.

        Args:
            api_key: Together AI API key, sent as a Bearer token.
            model_name: Identifier of the embedding model to request.
            api_url: URL of the embeddings endpoint.
        """
        self._api_url = api_url
        self._model_name = model_name
        # A single session carries the auth/content headers for every request.
        self._session = requests.Session()
        self._session.headers.update(
            {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            }
        )

    def __call__(self, input: Documents) -> Embeddings:
        """Embed every document in ``input``, one request per document.

        Args:
            input: The documents to embed.

        Returns:
            One embedding per document, in the same order as ``input``.

        Raises:
            ValueError: If a request does not return HTTP 200, or if a
                successful response carries no embedding.
        """
        return [self._embed_text(text) for text in input]

    def _embed_text(self, text: str) -> Any:
        """Request and return the embedding for a single document."""
        response = self._session.post(
            self._api_url,
            json={
                "input": text,
                "model": self._model_name,
            },
        )

        if response.status_code != 200:
            raise ValueError(
                f"The API request to Together AI Endpoint failed with the status code : "
                f"{response.status_code}. Refer https://docs.together.ai/reference/embeddings "
                "more details"
            )

        # Validate the payload shape step by step so that a malformed success
        # response raises a clear ValueError instead of an IndexError/TypeError,
        # and so that an empty embedding is never returned silently.
        payload = response.json()
        data = payload.get("data") if isinstance(payload, dict) else None
        first = data[0] if isinstance(data, list) and data else None
        embedding = first.get("embedding") if isinstance(first, dict) else None
        if not embedding:
            raise ValueError(
                "The Together AI Endpoint returned a successful response "
                "that does not contain an embedding"
            )
        return embedding