Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions fastembed/text/onnx_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,28 @@
size_in_GB=0.64,
sources=ModelSource(hf="mixedbread-ai/mxbai-embed-large-v1"),
model_file="onnx/model.onnx",
# Prefixes from https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1#usage
tasks={
"query_prefix": "Represent this sentence for searching relevant passages: ",
"passage_prefix": "",
},
),
DenseModelDescription(
model="mixedbread-ai/deepset-mxbai-embed-de-large-v1",
dim=1024,
description=(
"Text embeddings, Unimodal (text), German/English, 512 input tokens truncation, "
"Prefixes for queries/documents: necessary, 2024 year."
),
license="apache-2.0",
size_in_GB=1.94,
sources=ModelSource(hf="mixedbread-ai/deepset-mxbai-embed-de-large-v1"),
model_file="onnx/model.onnx",
# Prefixes from https://huggingface.co/mixedbread-ai/deepset-mxbai-embed-de-large-v1#usage
tasks={
"query_prefix": "query: ",
"passage_prefix": "passage: ",
},
),
DenseModelDescription(
model="snowflake/snowflake-arctic-embed-xs",
Expand Down Expand Up @@ -294,6 +316,50 @@ def embed(
**kwargs,
)

def query_embed(self, query: Union[str, Iterable[str]], **kwargs: Any) -> Iterable[NumpyArray]:
"""
Embeds queries with optional query prefix.

Args:
query (Union[str, Iterable[str]]): The query to embed, or an iterable e.g. list of queries.

Returns:
Iterable[NumpyArray]: The embeddings.
"""
# Check if model has query prefix
query_prefix = self.model_description.tasks.get("query_prefix", "") if self.model_description.tasks else ""

# Apply prefix if specified
if query_prefix:
if isinstance(query, str):
query = [query_prefix + query]
else:
query = [query_prefix + q for q in query]
elif isinstance(query, str):
query = [query]

yield from self.embed(query, **kwargs)

def passage_embed(self, texts: Iterable[str], **kwargs: Any) -> Iterable[NumpyArray]:
"""
Embeds passages with optional passage prefix.

Args:
texts (Iterable[str]): The list of texts to embed.
**kwargs: Additional keyword arguments to pass to the embed method.

Yields:
Iterable[NumpyArray]: The embeddings.
"""
# Check if model has passage prefix
passage_prefix = self.model_description.tasks.get("passage_prefix", "") if self.model_description.tasks else ""

# Apply prefix if specified
if passage_prefix:
texts = [passage_prefix + text for text in texts]

yield from self.embed(texts, **kwargs)

@classmethod
def _get_worker_class(cls) -> Type["TextEmbeddingWorker[NumpyArray]"]:
return OnnxTextEmbeddingWorker
Expand Down
3 changes: 3 additions & 0 deletions tests/test_text_onnx_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
"mixedbread-ai/mxbai-embed-large-v1": np.array(
[0.02295546, 0.03196154, 0.016512, -0.04031524, -0.0219634]
),
"mixedbread-ai/deepset-mxbai-embed-de-large-v1": np.array(
[0.00574683, 0.00185086, 0.00910093, -0.03800965, 0.00805963]
),
"snowflake/snowflake-arctic-embed-xs": np.array([0.0092, 0.0619, 0.0196, 0.009, -0.0114]),
"snowflake/snowflake-arctic-embed-s": np.array([-0.0416, -0.0867, 0.0209, 0.0554, -0.0272]),
"snowflake/snowflake-arctic-embed-m": np.array([-0.0329, 0.0364, 0.0481, 0.0016, 0.0328]),
Expand Down