Skip to content
Closed
21 changes: 17 additions & 4 deletions chromadb/utils/embedding_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,18 +1027,27 @@ class FastEmbedEmbeddingFunction(EmbeddingFunction[Documents]):
def __init__(
self,
model_name: str = "BAAI/bge-small-en-v1.5",
batch_size: int = 256,
cache_dir: Optional[str] = None,
threads: Optional[int] = None,
parallel: Optional[int] = None,
**kwargs,
) -> None:
"""
Initialize fastembed.TextEmbedding

Args:
model_name (str): The name of the model to use.
cache_dir (str, optional): The path to the model cache directory.
model_name (str): The name of the model to use. Defaults to `"BAAI/bge-small-en-v1.5"`.
batch_size (int): Batch size for encoding. Higher values will use more memory, but be faster.\
Defaults to 256.
cache_dir (str, optional): The path to the model cache directory.\
Can also be set using the `FASTEMBED_CACHE_PATH` env variable.
threads (int, optional): The number of threads single onnxruntime session can use..
threads (int, optional): The number of threads a single onnxruntime session can use.
parallel (int, optional): If `>1`, data-parallel encoding will be used, recommended for offline encoding of large datasets.\
If `0`, use all available cores.\
If `None`, don't use data-parallel processing, use default onnxruntime threading instead.\
Defaults to None.
**kwargs: Additional options to pass to fastembed.TextEmbedding

Raises:
ValueError: If the model_name is not in the format <org>/<model> e.g. BAAI/bge-base-en.
Expand All @@ -1049,6 +1058,8 @@ def __init__(
raise ValueError(
"The 'fastembed' package is not installed. Please install it with `pip install fastembed`"
)
self._batch_size = batch_size
self._parallel = parallel
self._model = TextEmbedding(
model_name=model_name, cache_dir=cache_dir, threads=threads, **kwargs
)
Expand All @@ -1068,7 +1079,9 @@ def __call__(self, input: Documents) -> Embeddings:
>>> texts = ["Hello, world!", "How are you?"]
>>> embeddings = fastembed_ef(texts)
"""
embeddings = self._model.embed(input)
embeddings = self._model.embed(
input, batch_size=self._batch_size, parallel=self._parallel
)
return cast(
Embeddings,
[embedding.tolist() for embedding in embeddings],
Expand Down