diff --git a/clients/js/src/embeddings/VoyageAIEmbeddingFunction.ts b/clients/js/src/embeddings/VoyageAIEmbeddingFunction.ts new file mode 100644 index 00000000000..9ee06750b36 --- /dev/null +++ b/clients/js/src/embeddings/VoyageAIEmbeddingFunction.ts @@ -0,0 +1,89 @@ +import { IEmbeddingFunction } from "./IEmbeddingFunction"; + +export enum InputType { + DOCUMENT = "document", + QUERY = "query", +} + +export class VoyageAIEmbeddingFunction implements IEmbeddingFunction { + private modelName: string; + private apiUrl: string; + private batchSize: number; + private truncation?: boolean; + private inputType?: InputType; + private headers: { [key: string]: string }; + + constructor({ + voyageaiApiKey, + modelName, + batchSize, + truncation, + inputType, + }: { + voyageaiApiKey: string; + modelName: string; + batchSize?: number; + truncation?: boolean; + inputType?: InputType; + }) { + this.apiUrl = "https://api.voyageai.com/v1/embeddings"; + this.headers = { + Authorization: `Bearer ${voyageaiApiKey}`, + "Content-Type": "application/json", + }; + + this.modelName = modelName; + this.truncation = truncation; + this.inputType = inputType; + if (batchSize) { + this.batchSize = batchSize; + } else { + if (modelName in ["voyage-2", "voyage-02"]) { + this.batchSize = 72; + } else { + this.batchSize = 7; + } + } + } + + public async generate(texts: string[]) { + try { + if (texts.length > this.batchSize) { + throw new Error( + `The number of texts to embed exceeds the maximum batch size of ${this.batchSize}` + ); + } + + const response = await fetch(this.apiUrl, { + method: "POST", + headers: this.headers, + body: JSON.stringify({ + input: texts, + model: this.modelName, + truncation: this.truncation, + input_type: this.inputType, + }), + }); + + const data = (await response.json()) as { data: any[]; detail: string }; + if (!data || !data.data) { + throw new Error(data.detail); + } + + const embeddings: any[] = data.data; + const sortedEmbeddings = embeddings.sort((a, b) => a.index - b.index); + + const embeddingsChunks = sortedEmbeddings.map( + (result) => result.embedding + ); + + return embeddingsChunks; + } catch (error) { + if (error instanceof Error) { + throw new Error(`Error calling VoyageAI API: ${error.message}`); + } else { + throw new Error(`Error calling VoyageAI API: ${error}`); + } + } + } +} diff --git a/clients/js/src/index.ts b/clients/js/src/index.ts index cef9c9356f0..12ae91f3c26 100644 --- a/clients/js/src/index.ts +++ b/clients/js/src/index.ts @@ -2,6 +2,7 @@ export { ChromaClient } from "./ChromaClient"; export { AdminClient } from "./AdminClient"; export { CloudClient } from "./CloudClient"; export { Collection } from "./Collection"; + export { IEmbeddingFunction } from "./embeddings/IEmbeddingFunction"; export { OpenAIEmbeddingFunction } from "./embeddings/OpenAIEmbeddingFunction"; export { CohereEmbeddingFunction } from "./embeddings/CohereEmbeddingFunction"; @@ -10,6 +11,10 @@ export { DefaultEmbeddingFunction } from "./embeddings/DefaultEmbeddingFunction" export { HuggingFaceEmbeddingServerFunction } from "./embeddings/HuggingFaceEmbeddingServerFunction"; export { JinaEmbeddingFunction } from "./embeddings/JinaEmbeddingFunction"; export { GoogleGenerativeAiEmbeddingFunction } from "./embeddings/GoogleGeminiEmbeddingFunction"; +export { + VoyageAIEmbeddingFunction, + InputType, +} from "./embeddings/VoyageAIEmbeddingFunction"; export { OllamaEmbeddingFunction } from "./embeddings/OllamaEmbeddingFunction"; export { diff --git a/clients/js/test/add.collections.test.ts b/clients/js/test/add.collections.test.ts index 41b3de3fef5..ebd652a3570 100644 --- a/clients/js/test/add.collections.test.ts +++ b/clients/js/test/add.collections.test.ts @@ -3,9 +3,11 @@ import chroma from "./initClient"; import { DOCUMENTS, EMBEDDINGS, IDS } from "./data"; import { METADATAS } from "./data"; import { IncludeEnum } from "../src/types"; + import { OpenAIEmbeddingFunction } from "../src/embeddings/OpenAIEmbeddingFunction"; import { CohereEmbeddingFunction } from "../src/embeddings/CohereEmbeddingFunction"; import { OllamaEmbeddingFunction } from "../src/embeddings/OllamaEmbeddingFunction"; + test("it should add single embeddings to a collection", async () => { await chroma.reset(); const collection = await chroma.createCollection({ name: "test" }); diff --git a/clients/js/test/embeddings/voyage.test.ts b/clients/js/test/embeddings/voyage.test.ts new file mode 100644 index 00000000000..4c220a8758d --- /dev/null +++ b/clients/js/test/embeddings/voyage.test.ts @@ -0,0 +1,60 @@ +import chroma from "../initClient"; +import { DOCUMENTS, IDS } from "../data"; +import { IncludeEnum } from "../../src/types"; +import { + VoyageAIEmbeddingFunction, + InputType, +} from "../../src/embeddings/VoyageAIEmbeddingFunction"; + +if (!process.env.VOYAGE_API_KEY) { + test.skip("it should add VoyageAI embeddings", async () => {}); +} else { + test("it should add VoyageAI embeddings", async () => { + await chroma.reset(); + const embedder = new VoyageAIEmbeddingFunction({ + voyageaiApiKey: process.env.VOYAGE_API_KEY || "", + modelName: "voyage-2", + batchSize: 5, + inputType: InputType.DOCUMENT, + }); + const collection = await chroma.createCollection({ + name: "test", + embeddingFunction: embedder, + }); + const embeddings = await embedder.generate(DOCUMENTS); + await collection.add({ ids: IDS, embeddings: embeddings }); + const count = await collection.count(); + expect(count).toBe(3); + expect(embeddings.length).toBe(3); + expect(embeddings[0].length).toBe(1024); + expect(embeddings[1].length).toBe(1024); + expect(embeddings[2].length).toBe(1024); + var res = await collection.get({ + ids: IDS, + include: [IncludeEnum.Embeddings], + }); + expect(res.embeddings).toEqual(embeddings); // reverse because of the order of the ids + }); + + test("it should throw an exception when the batch size is smaller than the number of texts", async () => { + await chroma.reset(); + const embedder = new VoyageAIEmbeddingFunction({ + voyageaiApiKey: process.env.VOYAGE_API_KEY || "", + modelName: "voyage-2", + batchSize: 2, + inputType: InputType.DOCUMENT, + }); + const collection = await chroma.createCollection({ + name: "test", + embeddingFunction: embedder, + }); + try { + const embeddings = await embedder.generate(DOCUMENTS); + fail("Should throw an exception"); + } catch (e: any) { + expect(e.message).toBe( + "Error calling VoyageAI API: The number of texts to embed exceeds the maximum batch size of 2" + ); + } + }); +}