diff --git a/README.md b/README.md index 7bfe6193..c072c3e0 100644 --- a/README.md +++ b/README.md @@ -143,6 +143,27 @@ JSON output: } ``` +#### Hybrid Search + +Hybrid search combines traditional keyword search with semantic search for more relevant results. You need to have an embedder configured in your index settings to use this feature. + +```python +# Using hybrid search with the search method +index.search( + 'action movie', + { + "hybrid": {"semanticRatio": 0.5, "embedder": "default"} + } +) +``` + +The `semanticRatio` parameter (between 0 and 1) controls the balance between keyword search and semantic search: +- 0: Only keyword search +- 1: Only semantic search +- Values in between: A mix of both approaches + +The `embedder` parameter specifies which configured embedder to use for the semantic search component. + #### Custom Search With Filters If you want to enable filtering, you must add your attributes to the `filterableAttributes` index setting. diff --git a/meilisearch/index.py b/meilisearch/index.py index 9f6e176d..812c3a57 100644 --- a/meilisearch/index.py +++ b/meilisearch/index.py @@ -24,19 +24,22 @@ from meilisearch.config import Config from meilisearch.errors import version_error_hint_message from meilisearch.models.document import Document, DocumentsResults -from meilisearch.models.index import ( +from meilisearch.models.embedders import ( Embedders, - Faceting, + EmbedderType, HuggingFaceEmbedder, - IndexStats, - LocalizedAttributes, OllamaEmbedder, OpenAiEmbedder, + RestEmbedder, + UserProvidedEmbedder, +) +from meilisearch.models.index import ( + Faceting, + IndexStats, + LocalizedAttributes, Pagination, ProximityPrecision, - RestEmbedder, TypoTolerance, - UserProvidedEmbedder, ) from meilisearch.models.task import Task, TaskInfo, TaskResults from meilisearch.task import TaskHandler @@ -277,14 +280,21 @@ def get_stats(self) -> IndexStats: def search(self, query: str, opt_params: Optional[Mapping[str, Any]] = None) -> Dict[str, Any]: """Search in the index. + https://www.meilisearch.com/docs/reference/api/search + Parameters ---------- query: String containing the searched word(s) opt_params (optional): Dictionary containing optional query parameters. 
- Note: The vector parameter is only available in Meilisearch >= v1.13.0 - https://www.meilisearch.com/docs/reference/api/search#search-in-an-index + Common parameters include: + - hybrid: Dict with 'semanticRatio' and 'embedder' fields for hybrid search + - vector: Array of numbers for vector search + - retrieveVectors: Boolean to include vector data in search results + - filter: Filter queries by an attribute's value + - limit: Maximum number of documents returned + - offset: Number of documents to skip Returns ------- @@ -298,7 +308,9 @@ def search(self, query: str, opt_params: Optional[Mapping[str, Any]] = None) -> """ if opt_params is None: opt_params = {} + body = {"q": query, **opt_params} + return self.http.post( f"{self.config.paths.index}/{self.uid}/{self.config.paths.search}", body=body, @@ -955,14 +967,7 @@ def get_settings(self) -> Dict[str, Any]: ) if settings.get("embedders"): - embedders: dict[ - str, - OpenAiEmbedder - | HuggingFaceEmbedder - | OllamaEmbedder - | RestEmbedder - | UserProvidedEmbedder, - ] = {} + embedders: dict[str, EmbedderType] = {} for k, v in settings["embedders"].items(): if v.get("source") == "openAi": embedders[k] = OpenAiEmbedder(**v) @@ -988,6 +993,26 @@ def update_settings(self, body: MutableMapping[str, Any]) -> TaskInfo: ---------- body: Dictionary containing the settings of the index. + Supported settings include: + - 'rankingRules': List of ranking rules + - 'distinctAttribute': Attribute for deduplication + - 'searchableAttributes': Attributes that can be searched + - 'displayedAttributes': Attributes to display in search results + - 'stopWords': Words ignored in search queries + - 'synonyms': Dictionary of synonyms + - 'filterableAttributes': Attributes that can be used for filtering + - 'sortableAttributes': Attributes that can be used for sorting + - 'typoTolerance': Settings for typo tolerance + - 'pagination': Settings for pagination + - 'faceting': Settings for faceting + - 'dictionary': List of custom dictionary words + - 'separatorTokens': List of separator tokens + - 'nonSeparatorTokens': List of non-separator tokens + - 'embedders': Dictionary of embedder configurations for AI-powered search + - 'searchCutoffMs': Maximum search time in milliseconds + - 'proximityPrecision': Precision for proximity ranking + - 'localizedAttributes': Settings for localized attributes + More information: https://www.meilisearch.com/docs/reference/api/settings#update-settings @@ -1000,7 +1025,8 @@ def update_settings(self, body: MutableMapping[str, Any]) -> TaskInfo: Raises ------ MeilisearchApiError - An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors + An error containing details about why Meilisearch can't process your request. + Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors """ if body.get("embedders"): for _, v in body["embedders"].items(): @@ -1879,10 +1905,13 @@ def reset_non_separator_tokens(self) -> TaskInfo: def get_embedders(self) -> Embedders | None: """Get embedders of the index. + Retrieves the current embedder configuration from Meilisearch. + Returns ------- - settings: - The embedders settings of the index. + Embedders: + The embedders settings of the index, or None if no embedders are configured. + Contains a dictionary of embedder configurations, where keys are embedder names. 
Raises ------ @@ -1894,24 +1923,21 @@ if not response: return None - embedders: dict[ - str, - OpenAiEmbedder - | HuggingFaceEmbedder - | OllamaEmbedder - | RestEmbedder - | UserProvidedEmbedder, - ] = {} + embedders: dict[str, EmbedderType] = {} for k, v in response.items(): - if v.get("source") == "openAi": + source = v.get("source") + if source == "openAi": embedders[k] = OpenAiEmbedder(**v) - elif v.get("source") == "ollama": - embedders[k] = OllamaEmbedder(**v) - elif v.get("source") == "huggingFace": + elif source == "huggingFace": embedders[k] = HuggingFaceEmbedder(**v) - elif v.get("source") == "rest": + elif source == "ollama": + embedders[k] = OllamaEmbedder(**v) + elif source == "rest": embedders[k] = RestEmbedder(**v) + elif source == "userProvided": + embedders[k] = UserProvidedEmbedder(**v) else: + # Default to UserProvidedEmbedder for unknown sources embedders[k] = UserProvidedEmbedder(**v) return Embedders(embedders=embedders) @@ -1919,10 +1945,13 @@ def update_embedders(self, body: Union[MutableMapping[str, Any], None]) -> TaskInfo: """Update embedders of the index. + + Updates the embedder configuration for the index. The embedder configuration + determines how Meilisearch generates vector embeddings for documents. + Parameters ---------- body: dict - Dictionary containing the embedders. + Dictionary mapping embedder names to their configuration. Returns ------- @@ -1933,13 +1962,28 @@ def update_embedders(self, body: Union[MutableMapping[str, Any], None]) -> TaskI Raises ------ MeilisearchApiError - An error containing details about why Meilisearch can't process your request. Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors + An error containing details about why Meilisearch can't process your request. + Meilisearch error codes are described here: https://www.meilisearch.com/docs/reference/errors/error_codes#meilisearch-errors """ + if body: + embedders: dict[str, EmbedderType] = {} + for k, v in body.items(): + source = v.get("source") + if source == "openAi": + embedders[k] = OpenAiEmbedder(**v) + elif source == "huggingFace": + embedders[k] = HuggingFaceEmbedder(**v) + elif source == "ollama": + embedders[k] = OllamaEmbedder(**v) + elif source == "rest": + embedders[k] = RestEmbedder(**v) + elif source == "userProvided": + embedders[k] = UserProvidedEmbedder(**v) + else: + # Default to UserProvidedEmbedder for unknown sources + embedders[k] = UserProvidedEmbedder(**v) - if body: - for _, v in body.items(): - if "documentTemplateMaxBytes" in v and v["documentTemplateMaxBytes"] is None: - del v["documentTemplateMaxBytes"] + # The settings/embedders endpoint expects a flat mapping of embedder names to + # configurations (as the tests in this change pass it), so serialize the validated + # models back to that shape. exclude_none keeps unset optional fields out of the + # payload, replacing the earlier per-field stripping of a null documentTemplateMaxBytes. + body = {k: v.model_dump(by_alias=True, exclude_none=True) for k, v in embedders.items()} task = self.http.patch(self.__settings_url_for(self.config.paths.embedders), body) @@ -1948,6 +1992,8 @@ def reset_embedders(self) -> TaskInfo: """Reset embedders of the index to default values. + Removes all embedder configurations from the index.
+ Returns ------- task_info: diff --git a/meilisearch/models/embedders.py b/meilisearch/models/embedders.py new file mode 100644 index 00000000..01ba7b3c --- /dev/null +++ b/meilisearch/models/embedders.py @@ -0,0 +1,213 @@ +from __future__ import annotations + +from typing import Any, Dict, Optional, Union + +from camel_converter.pydantic_base import CamelBase + + +class Distribution(CamelBase): + """Distribution settings for embedders. + + Parameters + ---------- + mean: float + Mean value between 0 and 1 + sigma: float + Sigma value between 0 and 1 + """ + + mean: float + sigma: float + + +class OpenAiEmbedder(CamelBase): + """OpenAI embedder configuration. + + Parameters + ---------- + source: str + The embedder source, must be "openAi" + url: Optional[str] + The URL Meilisearch contacts when querying the embedder + api_key: Optional[str] + Authentication token Meilisearch should send with each request to the embedder + model: Optional[str] + The model your embedder uses when generating vectors (defaults to text-embedding-3-small) + dimensions: Optional[int] + Number of dimensions in the chosen model + document_template: Optional[str] + Template defining the data Meilisearch sends to the embedder + document_template_max_bytes: Optional[int] + Maximum allowed size of rendered document template (defaults to 400) + distribution: Optional[Distribution] + Describes the natural distribution of search results + binary_quantized: Optional[bool] + Once set to true, irreversibly converts all vector dimensions to 1-bit values + """ + + source: str = "openAi" + url: Optional[str] = None + api_key: Optional[str] = None + model: Optional[str] = None # Defaults to text-embedding-3-small + dimensions: Optional[int] = None # Uses the model default + document_template: Optional[str] = None + document_template_max_bytes: Optional[int] = None # Default to 400 + distribution: Optional[Distribution] = None + binary_quantized: Optional[bool] = None + + +class HuggingFaceEmbedder(CamelBase): + """HuggingFace embedder configuration. + + Parameters + ---------- + source: str + The embedder source, must be "huggingFace" + url: Optional[str] + The URL Meilisearch contacts when querying the embedder + model: Optional[str] + The model your embedder uses when generating vectors (defaults to BAAI/bge-base-en-v1.5) + dimensions: Optional[int] + Number of dimensions in the chosen model + revision: Optional[str] + Model revision hash + document_template: Optional[str] + Template defining the data Meilisearch sends to the embedder + document_template_max_bytes: Optional[int] + Maximum allowed size of rendered document template (defaults to 400) + distribution: Optional[Distribution] + Describes the natural distribution of search results + binary_quantized: Optional[bool] + Once set to true, irreversibly converts all vector dimensions to 1-bit values + """ + + source: str = "huggingFace" + url: Optional[str] = None + model: Optional[str] = None # Defaults to BAAI/bge-base-en-v1.5 + dimensions: Optional[int] = None + revision: Optional[str] = None + document_template: Optional[str] = None + document_template_max_bytes: Optional[int] = None # Default to 400 + distribution: Optional[Distribution] = None + binary_quantized: Optional[bool] = None + + +class OllamaEmbedder(CamelBase): + """Ollama embedder configuration. 
+ + Parameters + ---------- + source: str + The embedder source, must be "ollama" + url: Optional[str] + The URL Meilisearch contacts when querying the embedder (defaults to http://localhost:11434/api/embeddings) + api_key: Optional[str] + Authentication token Meilisearch should send with each request to the embedder + model: Optional[str] + The model your embedder uses when generating vectors + dimensions: Optional[int] + Number of dimensions in the chosen model + document_template: Optional[str] + Template defining the data Meilisearch sends to the embedder + document_template_max_bytes: Optional[int] + Maximum allowed size of rendered document template (defaults to 400) + distribution: Optional[Distribution] + Describes the natural distribution of search results + binary_quantized: Optional[bool] + Once set to true, irreversibly converts all vector dimensions to 1-bit values + """ + + source: str = "ollama" + url: Optional[str] = None + api_key: Optional[str] = None + model: Optional[str] = None + dimensions: Optional[int] = None + document_template: Optional[str] = None + document_template_max_bytes: Optional[int] = None + distribution: Optional[Distribution] = None + binary_quantized: Optional[bool] = None + + +class RestEmbedder(CamelBase): + """REST API embedder configuration. + + Parameters + ---------- + source: str + The embedder source, must be "rest" + url: Optional[str] + The URL Meilisearch contacts when querying the embedder + api_key: Optional[str] + Authentication token Meilisearch should send with each request to the embedder + dimensions: Optional[int] + Number of dimensions in the embeddings + document_template: Optional[str] + Template defining the data Meilisearch sends to the embedder + document_template_max_bytes: Optional[int] + Maximum allowed size of rendered document template (defaults to 400) + request: Dict[str, Any] + A JSON value representing the request Meilisearch makes to the remote embedder + response: Dict[str, Any] + A JSON value representing the response Meilisearch expects from the remote embedder + headers: Optional[Dict[str, str]] + Custom headers to send with the request + distribution: Optional[Distribution] + Describes the natural distribution of search results + binary_quantized: Optional[bool] + Once set to true, irreversibly converts all vector dimensions to 1-bit values + """ + + source: str = "rest" + url: Optional[str] = None + api_key: Optional[str] = None + dimensions: Optional[int] = None + document_template: Optional[str] = None + document_template_max_bytes: Optional[int] = None + request: Dict[str, Any] + response: Dict[str, Any] + headers: Optional[Dict[str, str]] = None + distribution: Optional[Distribution] = None + binary_quantized: Optional[bool] = None + + +class UserProvidedEmbedder(CamelBase): + """User-provided embedder configuration.
+ + Parameters + ---------- + source: str + The embedder source, must be "userProvided" + dimensions: int + Number of dimensions in the embeddings + distribution: Optional[Distribution] + Describes the natural distribution of search results + binary_quantized: Optional[bool] + Once set to true, irreversibly converts all vector dimensions to 1-bit values + """ + + source: str = "userProvided" + dimensions: int + distribution: Optional[Distribution] = None + binary_quantized: Optional[bool] = None + + +# Type alias for the embedder union type +EmbedderType = Union[ + OpenAiEmbedder, + HuggingFaceEmbedder, + OllamaEmbedder, + RestEmbedder, + UserProvidedEmbedder, +] + + +class Embedders(CamelBase): + """Container for embedder configurations. + + Parameters + ---------- + embedders: Dict[str, Union[OpenAiEmbedder, HuggingFaceEmbedder, OllamaEmbedder, RestEmbedder, UserProvidedEmbedder]] + Dictionary of embedder configurations, where keys are embedder names + """ + + embedders: Dict[str, EmbedderType] diff --git a/meilisearch/models/index.py b/meilisearch/models/index.py index 5457be56..f23cb458 100644 --- a/meilisearch/models/index.py +++ b/meilisearch/models/index.py @@ -1,7 +1,7 @@ from __future__ import annotations from enum import Enum -from typing import Any, Dict, Iterator, List, Optional, Union +from typing import Any, Dict, Iterator, List, Optional from camel_converter import to_snake from camel_converter.pydantic_base import CamelBase @@ -62,65 +62,3 @@ class EmbedderDistribution(CamelBase): class LocalizedAttributes(CamelBase): attribute_patterns: List[str] locales: List[str] - - -class OpenAiEmbedder(CamelBase): - source: str = "openAi" - url: Optional[str] = None - model: Optional[str] = None # Defaults to text-embedding-3-small - dimensions: Optional[int] = None # Uses the model default - api_key: Optional[str] = None # Can be provided through a CLI option or environment variable - document_template: Optional[str] = None - document_template_max_bytes: Optional[int] = None # Default to 400 - distribution: Optional[EmbedderDistribution] = None - binary_quantized: Optional[bool] = None - - -class HuggingFaceEmbedder(CamelBase): - source: str = "huggingFace" - model: Optional[str] = None # Defaults to BAAI/bge-base-en-v1.5 - revision: Optional[str] = None - document_template: Optional[str] = None - document_template_max_bytes: Optional[int] = None # Default to 400 - distribution: Optional[EmbedderDistribution] = None - binary_quantized: Optional[bool] = None - - -class OllamaEmbedder(CamelBase): - source: str = "ollama" - url: Optional[str] = None - api_key: Optional[str] = None - model: str - document_template: Optional[str] = None - document_template_max_bytes: Optional[int] = None # Default to 400 - distribution: Optional[EmbedderDistribution] = None - binary_quantized: Optional[bool] = None - - -class RestEmbedder(CamelBase): - source: str = "rest" - url: str - api_key: Optional[str] # required for protected APIs - document_template: Optional[str] = None - document_template_max_bytes: Optional[int] = None # Default to 400 - request: Dict[str, Any] - response: Dict[str, Any] - distribution: Optional[EmbedderDistribution] = None - headers: Optional[Dict[str, Any]] = None - binary_quantized: Optional[bool] = None - - -class UserProvidedEmbedder(CamelBase): - source: str = "userProvided" - dimensions: int - distribution: Optional[EmbedderDistribution] = None - binary_quantized: Optional[bool] = None - - -class Embedders(CamelBase): - embedders: Dict[ - str, - Union[ - OpenAiEmbedder, 
HuggingFaceEmbedder, OllamaEmbedder, RestEmbedder, UserProvidedEmbedder - ], - ] diff --git a/tests/conftest.py b/tests/conftest.py index 7a814ff7..54e53e4e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -7,7 +7,7 @@ import meilisearch from meilisearch.errors import MeilisearchApiError -from meilisearch.models.index import OpenAiEmbedder, UserProvidedEmbedder +from meilisearch.models.embedders import OpenAiEmbedder, UserProvidedEmbedder from tests import common diff --git a/tests/index/test_index_search_meilisearch.py b/tests/index/test_index_search_meilisearch.py index 8939e9c5..8b087d1c 100644 --- a/tests/index/test_index_search_meilisearch.py +++ b/tests/index/test_index_search_meilisearch.py @@ -503,11 +503,28 @@ def test_show_ranking_score(index_with_documents): def test_vector_search(index_with_documents_and_vectors): + """Tests vector search with hybrid parameters.""" response = index_with_documents_and_vectors().search( "", opt_params={"vector": [0.1, 0.2], "hybrid": {"semanticRatio": 1.0, "embedder": "default"}}, ) assert len(response["hits"]) > 0 + # Check that semanticHitCount field is present in the response + assert "semanticHitCount" in response + # With semanticRatio = 1.0, all hits should be semantic + assert response["semanticHitCount"] == len(response["hits"]) + + +def test_hybrid_search(index_with_documents_and_vectors): + """Tests hybrid search with semantic ratio and embedder.""" + response = index_with_documents_and_vectors().search( + "movie", opt_params={"hybrid": {"semanticRatio": 0.5, "embedder": "default"}} + ) + assert len(response["hits"]) > 0 + # Check that semanticHitCount field is present in the response + assert "semanticHitCount" in response + # semanticHitCount should be an integer + assert isinstance(response["semanticHitCount"], int) def test_search_distinct(index_with_documents): @@ -534,3 +551,70 @@ def test_search_ranking_threshold(query, ranking_score_threshold, expected, inde query, {"rankingScoreThreshold": ranking_score_threshold} ) assert len(response["hits"]) == expected + + +def test_vector_search_with_retrieve_vectors(index_with_documents_and_vectors): + """Tests vector search with retrieveVectors parameter.""" + response = index_with_documents_and_vectors().search( + "", + opt_params={ + "vector": [0.1, 0.2], + "retrieveVectors": True, + "hybrid": {"semanticRatio": 1.0, "embedder": "default"}, + }, + ) + assert len(response["hits"]) > 0 + # Check that the first hit has a _vectors field + assert "_vectors" in response["hits"][0] + # Check that the _vectors field contains the default embedder + assert "default" in response["hits"][0]["_vectors"] + + +def test_get_similar_documents_with_identical_vectors(empty_index): + """Tests get_similar_documents method with documents having identical vectors.""" + # Create documents with identical vector embeddings + identical_vector = [0.5, 0.5] + documents = [ + {"id": "doc1", "title": "Document 1", "_vectors": {"default": identical_vector}}, + {"id": "doc2", "title": "Document 2", "_vectors": {"default": identical_vector}}, + {"id": "doc3", "title": "Document 3", "_vectors": {"default": identical_vector}}, + # Add a document with a different vector to verify it's not returned first + {"id": "doc4", "title": "Document 4", "_vectors": {"default": [0.1, 0.1]}}, + ] + + # Set up the index with the documents + index = empty_index() + + # Configure the embedder + settings_update_task = index.update_embedders( + { + "default": { + "source": "userProvided", + "dimensions": 2, + } + } + ) + 
index.wait_for_task(settings_update_task.task_uid) + + # Add the documents + document_addition_task = index.add_documents(documents) + index.wait_for_task(document_addition_task.task_uid) + + # Test get_similar_documents with doc1 + response = index.get_similar_documents({"id": "doc1", "embedder": "default"}) + + # Verify response structure + assert isinstance(response, dict) + assert "hits" in response + assert len(response["hits"]) >= 2 # Should find at least doc2 and doc3 + assert "id" in response + assert response["id"] == "doc1" + + # Verify that doc2 and doc3 are in the results (they have identical vectors to doc1) + result_ids = [hit["id"] for hit in response["hits"]] + assert "doc2" in result_ids + assert "doc3" in result_ids + + # Verify that doc4 is not the first result (it has a different vector) + if "doc4" in result_ids: + assert result_ids[0] != "doc4" diff --git a/tests/settings/test_settings.py b/tests/settings/test_settings.py index 8bbdafa2..147001de 100644 --- a/tests/settings/test_settings.py +++ b/tests/settings/test_settings.py @@ -1,7 +1,7 @@ # pylint: disable=redefined-outer-name import pytest -from meilisearch.models.index import OpenAiEmbedder, UserProvidedEmbedder +from meilisearch.models.embedders import OpenAiEmbedder, UserProvidedEmbedder @pytest.fixture diff --git a/tests/settings/test_settings_embedders.py b/tests/settings/test_settings_embedders.py index f932bbae..5baf2e09 100644 --- a/tests/settings/test_settings_embedders.py +++ b/tests/settings/test_settings_embedders.py @@ -1,6 +1,6 @@ # pylint: disable=redefined-outer-name -from meilisearch.models.index import OpenAiEmbedder, UserProvidedEmbedder +from meilisearch.models.embedders import OpenAiEmbedder, UserProvidedEmbedder def test_get_default_embedders(empty_index): @@ -42,3 +42,144 @@ def test_reset_embedders(new_embedders, empty_index): assert isinstance(response_get.embedders["open_ai"], OpenAiEmbedder) response_last = index.get_embedders() assert response_last is None + + +def test_openai_embedder_format(empty_index): + """Tests that OpenAi embedder has the required fields and proper format.""" + index = empty_index() + + openai_embedder = { + "openai": { + "source": "openAi", + "apiKey": "test-key", + "model": "text-embedding-3-small", + "dimensions": 1536, + "documentTemplateMaxBytes": 400, + "distribution": {"mean": 0.5, "sigma": 0.1}, + "binaryQuantized": False, + } + } + response = index.update_embedders(openai_embedder) + index.wait_for_task(response.task_uid) + embedders = index.get_embedders() + assert embedders.embedders["openai"].source == "openAi" + assert embedders.embedders["openai"].model == "text-embedding-3-small" + assert embedders.embedders["openai"].dimensions == 1536 + assert hasattr(embedders.embedders["openai"], "document_template") + assert embedders.embedders["openai"].document_template_max_bytes == 400 + assert embedders.embedders["openai"].distribution.mean == 0.5 + assert embedders.embedders["openai"].distribution.sigma == 0.1 + assert embedders.embedders["openai"].binary_quantized is False + + +def test_huggingface_embedder_format(empty_index): + """Tests that HuggingFace embedder has the required fields and proper format.""" + index = empty_index() + + huggingface_embedder = { + "huggingface": { + "source": "huggingFace", + "model": "BAAI/bge-base-en-v1.5", + "revision": "main", + "documentTemplateMaxBytes": 400, + "distribution": {"mean": 0.5, "sigma": 0.1}, + "binaryQuantized": False, + } + } + response = index.update_embedders(huggingface_embedder) + 
index.wait_for_task(response.task_uid) + embedders = index.get_embedders() + assert embedders.embedders["huggingface"].source == "huggingFace" + assert embedders.embedders["huggingface"].model == "BAAI/bge-base-en-v1.5" + assert embedders.embedders["huggingface"].revision == "main" + assert hasattr(embedders.embedders["huggingface"], "document_template") + assert embedders.embedders["huggingface"].document_template_max_bytes == 400 + assert embedders.embedders["huggingface"].distribution.mean == 0.5 + assert embedders.embedders["huggingface"].distribution.sigma == 0.1 + assert embedders.embedders["huggingface"].binary_quantized is False + + +def test_ollama_embedder_format(empty_index): + """Tests that Ollama embedder has the required fields and proper format.""" + index = empty_index() + + ollama_embedder = { + "ollama": { + "source": "ollama", + "url": "http://localhost:11434/api/embeddings", + "apiKey": "test-key", + "model": "llama2", + "dimensions": 4096, + "documentTemplateMaxBytes": 400, + "distribution": {"mean": 0.5, "sigma": 0.1}, + "binaryQuantized": False, + } + } + response = index.update_embedders(ollama_embedder) + index.wait_for_task(response.task_uid) + embedders = index.get_embedders() + assert embedders.embedders["ollama"].source == "ollama" + assert embedders.embedders["ollama"].url == "http://localhost:11434/api/embeddings" + assert embedders.embedders["ollama"].model == "llama2" + assert embedders.embedders["ollama"].dimensions == 4096 + assert hasattr(embedders.embedders["ollama"], "document_template") + assert embedders.embedders["ollama"].document_template_max_bytes == 400 + assert embedders.embedders["ollama"].distribution.mean == 0.5 + assert embedders.embedders["ollama"].distribution.sigma == 0.1 + assert embedders.embedders["ollama"].binary_quantized is False + + +def test_rest_embedder_format(empty_index): + """Tests that Rest embedder has the required fields and proper format.""" + index = empty_index() + + rest_embedder = { + "rest": { + "source": "rest", + "url": "http://localhost:8000/embed", + "apiKey": "test-key", + "dimensions": 512, + "documentTemplateMaxBytes": 400, + "request": {"model": "MODEL_NAME", "input": "{{text}}"}, + "response": {"result": {"data": ["{{embedding}}"]}}, + "headers": {"Authorization": "Bearer test-key"}, + "distribution": {"mean": 0.5, "sigma": 0.1}, + "binaryQuantized": False, + } + } + response = index.update_embedders(rest_embedder) + index.wait_for_task(response.task_uid) + embedders = index.get_embedders() + assert embedders.embedders["rest"].source == "rest" + assert embedders.embedders["rest"].url == "http://localhost:8000/embed" + assert embedders.embedders["rest"].dimensions == 512 + assert hasattr(embedders.embedders["rest"], "document_template") + assert embedders.embedders["rest"].document_template_max_bytes == 400 + assert embedders.embedders["rest"].request == {"model": "MODEL_NAME", "input": "{{text}}"} + assert embedders.embedders["rest"].response == {"result": {"data": ["{{embedding}}"]}} + assert embedders.embedders["rest"].headers == {"Authorization": "Bearer test-key"} + assert embedders.embedders["rest"].distribution.mean == 0.5 + assert embedders.embedders["rest"].distribution.sigma == 0.1 + assert embedders.embedders["rest"].binary_quantized is False + + +def test_user_provided_embedder_format(empty_index): + """Tests that UserProvided embedder has the required fields and proper format.""" + index = empty_index() + + user_provided_embedder = { + "user_provided": { + "source": "userProvided", + 
"dimensions": 512, + "distribution": {"mean": 0.5, "sigma": 0.1}, + "binaryQuantized": False, + } + } + response = index.update_embedders(user_provided_embedder) + index.wait_for_task(response.task_uid) + embedders = index.get_embedders() + assert embedders.embedders["user_provided"].source == "userProvided" + assert embedders.embedders["user_provided"].dimensions == 512 + assert embedders.embedders["user_provided"].distribution.mean == 0.5 + assert embedders.embedders["user_provided"].distribution.sigma == 0.1 + assert embedders.embedders["user_provided"].binary_quantized is False