diff --git a/README.md b/README.md
index bb79a7dab..f105fd376 100644
--- a/README.md
+++ b/README.md
@@ -103,3 +103,15 @@
 See the examples in the [clients](./engine/clients) directory.
 
 Once all the necessary classes are implemented, you can register the engine in the [ClientFactory](./engine/clients/client_factory.py).
+### Doris Vector Search Support
+
+This repository now includes experimental support for benchmarking Apache Doris vector search via the [`doris_vector_search`](https://github.com/uchenily/doris_vector_search) SDK.
+
+No server-side Docker setup is provided for Doris yet, so start your Doris cluster separately, then invoke:
+
+```bash
+python run.py --engines "doris" --datasets "dbpedia-openai-100K-1536-angular"
+```
+
+Connection parameters and database/table names are adjusted through the engine configuration files in `experiments/configurations` (add a section for `doris`). The table schema is inferred automatically from the first upload batch; the dataset distance is mapped to `l2_distance` or `inner_product` (cosine is mapped to inner product on the assumption that vectors are normalized upstream).
+
diff --git a/TODO.md b/TODO.md
new file mode 100644
index 000000000..01ccf0192
--- /dev/null
+++ b/TODO.md
@@ -0,0 +1,3 @@
+- [ ] Fix Doris search init when the distance argument is a string
+- [ ] Ensure the Doris searcher opens an existing table or handles a missing one gracefully
+- [ ] Address missing-column detection when the table has just been created
\ No newline at end of file
diff --git a/engine/base_client/client.py b/engine/base_client/client.py
index 670768a97..f4104adea 100644
--- a/engine/base_client/client.py
+++ b/engine/base_client/client.py
@@ -90,6 +90,20 @@ def run_experiment(
             distance=dataset.config.distance, vector_size=dataset.config.vector_size
         )
 
+        # Ensure Doris components (and any other engine that needs it) know the vector dimension
+        vector_dim = dataset.config.vector_size
+        if vector_dim is not None:
+            self.configurator.collection_params.setdefault("vector_dim", vector_dim)
+            uploader_collection = self.uploader.upload_params.setdefault(
+                "collection_params", {}
+            )
+            uploader_collection.setdefault("vector_dim", vector_dim)
+            for searcher in self.searchers:
+                search_collection = searcher.search_params.setdefault(
+                    "collection_params", {}
+                )
+                search_collection.setdefault("vector_dim", vector_dim)
+
         reader = dataset.get_reader(execution_params.get("normalize", False))
 
         if skip_if_exists:
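A note on the `run_experiment` hunk above: `setdefault` deliberately keeps explicit configuration authoritative, so a `vector_dim` set in an experiment file is never overwritten by the value inferred from the dataset. A minimal sketch of that rule (the dictionary contents are illustrative, not from the repository):

```python
# setdefault only fills in vector_dim when the experiment config left it unset
upload_params = {"collection_params": {"vector_dim": 768}}  # explicit override
collection = upload_params.setdefault("collection_params", {})
collection.setdefault("vector_dim", 1536)  # value inferred from the dataset
assert collection["vector_dim"] == 768  # the explicit value wins
```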
diff --git a/engine/clients/client_factory.py b/engine/clients/client_factory.py
index a74df2ab4..f7df2da62 100644
--- a/engine/clients/client_factory.py
+++ b/engine/clients/client_factory.py
@@ -7,6 +7,7 @@
     BaseSearcher,
     BaseUploader,
 )
+from engine.clients.doris import DorisConfigurator, DorisSearcher, DorisUploader
 from engine.clients.elasticsearch import (
     ElasticConfigurator,
     ElasticSearcher,
@@ -39,6 +40,7 @@
     "opensearch": OpenSearchConfigurator,
     "redis": RedisConfigurator,
     "pgvector": PgVectorConfigurator,
+    "doris": DorisConfigurator,
 }
 
 ENGINE_UPLOADERS = {
@@ -49,6 +51,7 @@
     "opensearch": OpenSearchUploader,
     "redis": RedisUploader,
     "pgvector": PgVectorUploader,
+    "doris": DorisUploader,
 }
 
 ENGINE_SEARCHERS = {
@@ -59,6 +62,7 @@
     "opensearch": OpenSearchSearcher,
     "redis": RedisSearcher,
     "pgvector": PgVectorSearcher,
+    "doris": DorisSearcher,
 }
 
 
@@ -79,10 +83,18 @@ def _create_configurator(self, experiment) -> BaseConfigurator:
 
     def _create_uploader(self, experiment) -> BaseUploader:
         engine_uploader_class = ENGINE_UPLOADERS[experiment["engine"]]
+        upload_params = {**experiment.get("upload_params", {})}
+        # Propagate collection_params for engines that need database/table info during upload (e.g., doris)
+        if experiment["engine"] == "doris":
+            merged_collection = {
+                **experiment.get("collection_params", {}),
+                **upload_params.get("collection_params", {}),
+            }
+            upload_params["collection_params"] = merged_collection
         engine_uploader = engine_uploader_class(
             self.host,
             connection_params={**experiment.get("connection_params", {})},
-            upload_params={**experiment.get("upload_params", {})},
+            upload_params=upload_params,
         )
         return engine_uploader
 
@@ -90,15 +102,22 @@ def _create_searchers(self, experiment) -> List[BaseSearcher]:
         engine_searcher_class: Type[BaseSearcher] = ENGINE_SEARCHERS[
             experiment["engine"]
         ]
-
-        engine_searchers = [
-            engine_searcher_class(
-                self.host,
-                connection_params={**experiment.get("connection_params", {})},
-                search_params=search_params,
+        engine_searchers = []
+        for search_params in experiment.get("search_params", [{}]):
+            params = {**search_params}
+            if experiment["engine"] == "doris":
+                merged_collection = {
+                    **experiment.get("collection_params", {}),
+                    **params.get("collection_params", {}),
+                }
+                params["collection_params"] = merged_collection
+            engine_searchers.append(
+                engine_searcher_class(
+                    self.host,
+                    connection_params={**experiment.get("connection_params", {})},
+                    search_params=params,
+                )
             )
-            for search_params in experiment.get("search_params", [{}])
-        ]
 
         return engine_searchers
diff --git a/engine/clients/doris/__init__.py b/engine/clients/doris/__init__.py
new file mode 100644
index 000000000..8d67e2937
--- /dev/null
+++ b/engine/clients/doris/__init__.py
@@ -0,0 +1,9 @@
+from .configure import DorisConfigurator
+from .search import DorisSearcher
+from .upload import DorisUploader
+
+__all__ = [
+    "DorisConfigurator",
+    "DorisUploader",
+    "DorisSearcher",
+]
diff --git a/engine/clients/doris/config.py b/engine/clients/doris/config.py
new file mode 100644
index 000000000..c87b3a33f
--- /dev/null
+++ b/engine/clients/doris/config.py
@@ -0,0 +1,10 @@
+DEFAULT_DORIS_DATABASE = "benchmark"
+DEFAULT_DORIS_TABLE = "vectors"
+
+# Mapping from internal distances to the Doris metric_type
+DISTANCE_MAPPING = {
+    "l2": "l2_distance",
+    "dot": "inner_product",
+    # Cosine is equivalent to inner product if vectors are normalized upstream
+    "cosine": "inner_product",
+}
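One remark on the mapping in `config.py` above: for unit-length vectors, cosine similarity and the inner product coincide, which is why the benchmark can normalize vectors upstream and still use `inner_product`. A quick self-contained check (pure Python, no Doris required):

```python
import math

# For unit vectors, cos(a, b) = dot(a, b): normalize a, then compare both sides.
a = [3.0, 4.0]
b = [0.6, 0.8]  # already unit length
norm_a = math.sqrt(sum(x * x for x in a))
a_unit = [x / norm_a for x in a]
cosine = sum(x * y for x, y in zip(a_unit, b)) / (
    math.sqrt(sum(x * x for x in a_unit)) * math.sqrt(sum(y * y for y in b))
)
dot = sum(x * y for x, y in zip(a_unit, b))
assert abs(cosine - dot) < 1e-12
```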
diff --git a/engine/clients/doris/configure.py b/engine/clients/doris/configure.py
new file mode 100644
index 000000000..f6e40c097
--- /dev/null
+++ b/engine/clients/doris/configure.py
@@ -0,0 +1,76 @@
+from typing import Optional
+
+import mysql.connector
+from doris_vector_search import AuthOptions, DorisVectorClient
+from mysql.connector import ProgrammingError
+
+from benchmark.dataset import Dataset
+from engine.base_client.configure import BaseConfigurator
+from engine.base_client.distances import Distance
+from engine.clients.doris.config import (
+    DEFAULT_DORIS_DATABASE,
+    DEFAULT_DORIS_TABLE,
+    DISTANCE_MAPPING,
+)
+
+
+class DorisConfigurator(BaseConfigurator):
+    SPARSE_VECTOR_SUPPORT = False
+
+    DISTANCE_MAPPING = {
+        Distance.L2: DISTANCE_MAPPING["l2"],
+        Distance.DOT: DISTANCE_MAPPING["dot"],
+        Distance.COSINE: DISTANCE_MAPPING["cosine"],
+    }
+
+    def __init__(self, host, collection_params: dict, connection_params: dict):
+        super().__init__(host, collection_params, connection_params)
+
+        database = collection_params.get("database", DEFAULT_DORIS_DATABASE)
+        auth = AuthOptions(
+            host=connection_params.get("host", host),
+            query_port=connection_params.get("query_port", 9030),
+            http_port=connection_params.get("http_port", 8030),
+            user=connection_params.get("user", "root"),
+            password=connection_params.get("password", ""),
+        )
+        # Ensure the database exists before creating the main client
+        try:
+            tmp_conn = mysql.connector.connect(
+                host=auth.host,
+                port=auth.query_port,
+                user=auth.user,
+                password=auth.password,
+            )
+            cursor = tmp_conn.cursor()
+            cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
+            cursor.close()
+            tmp_conn.close()
+        except ProgrammingError:
+            # If we cannot create the database, proceed and let the actual client raise a clearer error
+            pass
+        self.client = DorisVectorClient(database=database, auth_options=auth)
+
+    def clean(self):
+        table_name = self.collection_params.get("table_name", DEFAULT_DORIS_TABLE)
+        try:
+            self.client.drop_table(table_name)
+        except Exception:
+            # The table may not exist; ignore
+            pass
+
+    def recreate(self, dataset: Dataset, collection_params) -> Optional[dict]:
+        # The Doris table and index are created lazily on the first upload batch
+        # so that the schema can be inferred; nothing to (re)create here
+        return {}
+
+    def execution_params(self, distance, vector_size) -> dict:
+        metric = self.DISTANCE_MAPPING.get(distance)
+        # Expose the metric mapping so the upload and search stages agree on it
+        return {"metric_type": metric}
+
+    def delete_client(self):
+        try:
+            self.client.close()
+        except Exception:
+            pass
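The bootstrap step in `__init__` above works because Doris speaks the MySQL wire protocol on the FE query port, so routine administration can use plain `mysql.connector`. A standalone sketch of the same step (host, port, and credentials are assumptions for a default local setup):

```python
import mysql.connector

# Connect to the Doris frontend over the MySQL protocol and bootstrap the database
conn = mysql.connector.connect(host="localhost", port=9030, user="root", password="")
cur = conn.cursor()
cur.execute("CREATE DATABASE IF NOT EXISTS `benchmark`")
cur.execute("SHOW DATABASES")
print([row[0] for row in cur.fetchall()])
cur.close()
conn.close()
```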
diff --git a/engine/clients/doris/search.py b/engine/clients/doris/search.py
new file mode 100644
index 000000000..ac10e89e0
--- /dev/null
+++ b/engine/clients/doris/search.py
@@ -0,0 +1,243 @@
+import atexit
+import math
+from typing import Dict, List, Tuple
+
+import mysql.connector
+from doris_vector_search import AuthOptions, DorisVectorClient
+from mysql.connector import ProgrammingError
+
+from dataset_reader.base_reader import Query
+from engine.base_client.distances import Distance
+from engine.base_client.search import BaseSearcher
+from engine.clients.doris.config import (
+    DEFAULT_DORIS_DATABASE,
+    DEFAULT_DORIS_TABLE,
+    DISTANCE_MAPPING,
+)
+
+
+class DorisSearcher(BaseSearcher):
+    search_params = {}
+    client: DorisVectorClient = None
+    table = None
+    id_column = "id"
+    metric_type = "l2_distance"
+    table_name = DEFAULT_DORIS_TABLE
+    vector_dim = None
+    _cleanup_registered = False
+
+    @classmethod
+    def get_mp_start_method(cls):
+        return "spawn"
+
+    @classmethod
+    def init_client(cls, host, distance, connection_params: dict, search_params: dict):
+        database = (
+            search_params.get("database")
+            or search_params.get("collection_params", {}).get("database")
+            or DEFAULT_DORIS_DATABASE
+        )
+        auth = AuthOptions(
+            host=connection_params.get("host", host),
+            query_port=connection_params.get("query_port", 9030),
+            http_port=connection_params.get("http_port", 8030),
+            user=connection_params.get("user", "root"),
+            password=connection_params.get("password", ""),
+        )
+        # Ensure the database exists before connecting
+        try:
+            tmp_conn = mysql.connector.connect(
+                host=auth.host,
+                port=auth.query_port,
+                user=auth.user,
+                password=auth.password,
+            )
+            cursor = tmp_conn.cursor()
+            cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
+            cursor.close()
+            tmp_conn.close()
+        except ProgrammingError:
+            pass
+        cls.client = DorisVectorClient(database=database, auth_options=auth)
+        cls.search_params = search_params
+        cls.table_name = (
+            search_params.get("table_name")
+            or search_params.get("collection_params", {}).get("table_name")
+            or DEFAULT_DORIS_TABLE
+        )
+        # The distance may arrive as a Distance enum or as a plain string
+        if isinstance(distance, Distance):
+            distance_key = distance.value
+        else:
+            distance_key = str(distance).lower()
+        cls.metric_type = DISTANCE_MAPPING.get(distance_key, "l2_distance")
+        cls.vector_dim = search_params.get("vector_dim") or search_params.get(
+            "collection_params", {}
+        ).get("vector_dim")
+        if cls.vector_dim:
+            try:
+                cls.vector_dim = int(cls.vector_dim)
+            except (TypeError, ValueError):
+                cls.vector_dim = None
+
+    def setup_search(self):
+        if self.__class__.table is None:
+            try:
+                self.__class__.table = self.__class__.client.open_table(
+                    self.__class__.table_name
+                )
+                if self.__class__.vector_dim:
+                    self.__class__.table.index_options.dim = self.__class__.vector_dim
+                # Detect the id column: assume it is the first column in the
+                # table schema and fall back to "id"
+                try:
+                    cols = self.__class__.table.column_names
+                    self.__class__.id_column = cols[0] if cols else "id"
+                except Exception:
+                    pass
+                # Apply session overrides for search if provided
+                cfg = self.search_params.get("config", {})
+                sessions = {}
+                # Accept either the doris-native key or a pgvector-like alias
+                if "hnsw_ef_search" in cfg:
+                    sessions["hnsw_ef_search"] = str(cfg["hnsw_ef_search"])
+                if "hnsw_ef" in cfg and "hnsw_ef_search" not in sessions:
+                    sessions["hnsw_ef_search"] = str(cfg["hnsw_ef"])
+                if sessions:
+                    try:
+                        self.__class__.client.with_sessions(sessions)
+                    except Exception:
+                        pass
+            except Exception as ex:
+                raise RuntimeError(
+                    f"Failed to open Doris table '{self.__class__.table_name}': {ex}"
+                )
+
+    @classmethod
+    def search_one(cls, query: Query, top: int) -> List[Tuple[int, float]]:
+        if cls.table is None:
+            cls.table = cls.client.open_table(cls.table_name)
+            if cls.vector_dim:
+                cls.table.index_options.dim = cls.vector_dim
+        vector_query = cls.table.search(query.vector, metric_type=cls.metric_type)
+        vector_query.limit(top)
+
+        res = cls._execute_vector_query(vector_query)
+        results = []
+        for row in res:
+            # The distance field may vary; try common keys, treating 0.0 as valid
+            distance = row.get("distance")
+            if distance is None:
+                distance = row.get("score", 0.0)
+            # An id of 0 is legitimate, so test for None explicitly
+            identifier = row.get(cls.id_column)
+            if identifier is None:
+                identifier = row.get("id")
+            if identifier is None:
+                continue
+            try:
+                identifier = int(identifier)
+            except (TypeError, ValueError):
+                # If the id cannot be cast, skip the row
+                continue
+            results.append((identifier, float(distance)))
+        return results
+
+    @classmethod
+    def _execute_vector_query(cls, vector_query) -> List[Dict[str, object]]:
+        select_columns = (
+            vector_query.selected_columns or vector_query.table.column_names
+        )
+        distance_range = None
+        if (
+            vector_query.distance_range_lower is not None
+            or vector_query.distance_range_upper is not None
+        ):
+            distance_range = (
+                vector_query.distance_range_lower,
+                vector_query.distance_range_upper,
+            )
+
+        where_conditions = vector_query.where_conditions or None
+
+        sql = vector_query.compiler.compile_vector_search_query(
+            table_name=vector_query.table.table_name,
+            query_vector=vector_query.query_vector,
+            vector_column=vector_query.vector_column,
+            limit=vector_query.limit_value,
+            distance_range=distance_range,
+            where_conditions=where_conditions,
+            selected_columns=select_columns,
+            metric_type=vector_query.metric_type,
+        )
+
+        cursor = vector_query.table._get_cursor(prepared=False)
+        cursor.execute(sql)
+        rows = cursor.fetchall() or []
+
+        processed: List[Dict[str, object]] = []
+        for raw_row in rows:
+            row_dict: Dict[str, object] = {}
+            for col_name, value in zip(select_columns, raw_row):
+                if isinstance(value, (bytes, bytearray)):
+                    value = value.decode("utf-8")
+                row_dict[col_name] = value
+
+            processed.append(
+                cls._postprocess_row(
+                    row_dict,
+                    vector_query.vector_column,
+                    vector_query.query_vector,
+                )
+            )
+
+        return processed
+
+    @classmethod
+    def _postprocess_row(
+        cls,
+        row: Dict[str, object],
+        vector_column: str,
+        query_vector: List[float],
+    ) -> Dict[str, object]:
+        vector_data = row.get(vector_column)
+
+        if (
+            isinstance(vector_data, str)
+            and vector_data.startswith("[")
+            and vector_data.endswith("]")
+        ):
+            try:
+                vector_values = [
+                    float(item.strip())
+                    for item in vector_data[1:-1].split(",")
+                    if item.strip()
+                ]
+            except ValueError:
+                vector_values = None
+            else:
+                # Parsing succeeded: store the decoded vector back on the row
+                row[vector_column] = vector_values
+        elif isinstance(vector_data, list):
+            vector_values = vector_data
+        else:
+            vector_values = None
+
+        if vector_values and len(vector_values) == len(query_vector):
+            if cls.metric_type == "inner_product":
+                score = sum(a * b for a, b in zip(query_vector, vector_values))
+                row.setdefault("score", score)
+                # Negate so that "smaller distance is better" still holds
+                row.setdefault("distance", -score)
+            else:
+                dist_sq = sum(
+                    (a - b) ** 2 for a, b in zip(query_vector, vector_values)
+                )
+                row.setdefault("distance", math.sqrt(dist_sq))
+
+        return row
+
+    @classmethod
+    def delete_client(cls):
+        try:
+            if cls.client:
+                cls.client.close()
+        finally:
+            cls.client = None
+            cls.table = None
+
+
+# Register cleanup once per interpreter
+if not DorisSearcher._cleanup_registered:
+    atexit.register(DorisSearcher.delete_client)
+    DorisSearcher._cleanup_registered = True
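Note the sign convention in `_postprocess_row`: for `inner_product`, the fallback stores `distance = -score`, so ranking by ascending distance puts the best match first regardless of the metric. A tiny illustration with made-up rows:

```python
# Sanity check of the fallback scoring: score = dot(q, v), distance = -score,
# so sorting by ascending distance ranks the largest dot product first.
query = [1.0, 0.0]
candidates = {1: [0.9, 0.1], 2: [0.2, 0.8]}
rows = []
for ident, vec in candidates.items():
    score = sum(a * b for a, b in zip(query, vec))
    rows.append({"id": ident, "score": score, "distance": -score})
rows.sort(key=lambda r: r["distance"])
assert [r["id"] for r in rows] == [1, 2]
```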
diff --git a/engine/clients/doris/upload.py b/engine/clients/doris/upload.py
new file mode 100644
index 000000000..3650f5594
--- /dev/null
+++ b/engine/clients/doris/upload.py
@@ -0,0 +1,165 @@
+import atexit
+from contextlib import closing
+from typing import List
+
+import mysql.connector
+from doris_vector_search import AuthOptions, DorisVectorClient, IndexOptions
+from mysql.connector import Error, ProgrammingError
+
+from dataset_reader.base_reader import Record
+from engine.base_client.upload import BaseUploader
+from engine.clients.doris.config import (
+    DEFAULT_DORIS_DATABASE,
+    DEFAULT_DORIS_TABLE,
+    DISTANCE_MAPPING,
+)
+
+
+class DorisUploader(BaseUploader):
+    client: DorisVectorClient = None
+    table = None
+    created = False
+    upload_params = {}
+    collection_params = {}
+    metric_type = "l2_distance"
+    vector_dim = None
+    _cleanup_registered = False
+
+    @classmethod
+    def get_mp_start_method(cls):
+        return "spawn"
+
+    @classmethod
+    def init_client(cls, host, distance, connection_params, upload_params):
+        # Prefer the database passed within collection_params (from the experiment file)
+        database = (
+            upload_params.get("database")
+            or upload_params.get("collection_params", {}).get("database")
+            or DEFAULT_DORIS_DATABASE
+        )
+        auth = AuthOptions(
+            host=connection_params.get("host", host),
+            query_port=connection_params.get("query_port", 9030),
+            http_port=connection_params.get("http_port", 8030),
+            user=connection_params.get("user", "root"),
+            password=connection_params.get("password", ""),
+        )
+        # Ensure the database exists
+        try:
+            with closing(
+                mysql.connector.connect(
+                    host=auth.host,
+                    port=auth.query_port,
+                    user=auth.user,
+                    password=auth.password,
+                )
+            ) as tmp_conn:
+                with closing(tmp_conn.cursor()) as cursor:
+                    cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{database}`")
+        except ProgrammingError:
+            pass
+
+        cls.upload_params = upload_params
+        cls.collection_params = upload_params.get("collection_params", {})
+        cls.vector_dim = upload_params.get("vector_dim") or cls.collection_params.get(
+            "vector_dim"
+        )
+        if cls.vector_dim:
+            try:
+                cls.vector_dim = int(cls.vector_dim)
+            except (TypeError, ValueError):
+                cls.vector_dim = None
+        # Map the distance (Distance enum or plain string) to a Doris metric type
+        if hasattr(distance, "value"):
+            cls.metric_type = DISTANCE_MAPPING.get(distance.value, "l2_distance")
+        else:
+            cls.metric_type = DISTANCE_MAPPING.get(
+                str(distance).lower(), "l2_distance"
+            )
+
+        if cls.client is not None:
+            return
+
+        cls.client = DorisVectorClient(database=database, auth_options=auth)
+
+    @classmethod
+    def upload_batch(cls, batch: List[Record]):
+        rows = []
+        for rec in batch:
+            row = {"id": rec.id, "vector": rec.vector}
+            if rec.metadata:
+                for k, v in rec.metadata.items():
+                    # Avoid overwriting the reserved id/vector keys
+                    if k not in row:
+                        row[k] = v
+            rows.append(row)
+
+        table_name = cls.collection_params.get("table_name", DEFAULT_DORIS_TABLE)
+        # Allow overriding the table name from the upload_params root
+        table_name = cls.upload_params.get("table_name", table_name)
+
+        if not cls.created:
+            if cls._table_exists(table_name):
+                cls.table = cls.client.open_table(table_name)
+                cls.created = True
+            else:
+                # The table does not exist yet; create it from the first batch
+                # so the schema can be inferred
+                hnsw_cfg = cls.upload_params.get("hnsw_config", {})
+                index_options = IndexOptions(
+                    metric_type=cls.metric_type,
+                    max_degree=hnsw_cfg.get("m", 32),
+                    ef_construction=hnsw_cfg.get("ef_construct", 40),
+                    dim=cls.vector_dim if cls.vector_dim else -1,
+                )
+                try:
+                    cls.table = cls.client.create_table(
+                        table_name,
+                        rows,
+                        create_index=True,
+                        index_options=index_options,
+                        overwrite=False,
+                    )
+                    cls.created = True
+                    # create_table already ingests the schema-inference rows,
+                    # so don't add this first batch a second time
+                    return
+                except Error as exc:
+                    # If the table was already created by another process, just open it
+                    if "already exists" in str(exc).lower():
+                        cls.table = cls.client.open_table(table_name)
+                        cls.created = True
+                    else:
+                        raise
+
+        if cls.table is None:
+            cls.table = cls.client.open_table(table_name)
+        if cls.vector_dim and cls.table:
+            cls.table.index_options.dim = cls.vector_dim
+        cls.table.add(rows)
+
+    @classmethod
+    def post_upload(cls, _distance):
+        return {}
+
+    @classmethod
+    def _table_exists(cls, table_name: str) -> bool:
+        if cls.client is None:
+            return False
+        try:
+            cursor = cls.client.connection.cursor()
+            try:
+                cursor.execute("SHOW TABLES LIKE %s", (table_name,))
+                return cursor.fetchone() is not None
+            finally:
+                cursor.close()
+        except Exception:
+            return False
+
+    @classmethod
+    def delete_client(cls):
+        try:
+            if cls.client:
+                cls.client.close()
+        finally:
+            cls.client = None
+            cls.table = None
+            cls.created = False
+
+
+# Register cleanup once per interpreter to silence resource warnings when using pools
+if not DorisUploader._cleanup_registered:
+    atexit.register(DorisUploader.delete_client)
+    DorisUploader._cleanup_registered = True
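A hypothetical smoke run of the uploader against a live local cluster. Everything here is an assumption: the ports and credentials reflect a default setup, the table name `smoke_vectors` is made up, and the sketch assumes `Record(id, vector, metadata)` as the reader record shape. Like the integration test at the end of this diff, it requires a running Doris and is not executed in CI:

```python
# Sketch only: requires a reachable Doris FE; adjust ports/credentials to your setup.
from dataset_reader.base_reader import Record
from engine.clients.doris.upload import DorisUploader

DorisUploader.init_client(
    host="localhost",
    distance="l2",  # a plain string is accepted; it is lower-cased and mapped
    connection_params={"query_port": 9030, "user": "root", "password": ""},
    upload_params={
        "collection_params": {"table_name": "smoke_vectors", "vector_dim": 4}
    },
)
# The first batch creates the table (schema inferred); later batches append
DorisUploader.upload_batch(
    [Record(id=i, vector=[0.1 * i, 0.2, 0.3, 0.4], metadata=None) for i in range(3)]
)
DorisUploader.delete_client()
```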
diff --git a/experiments/configurations/doris-single-node.json b/experiments/configurations/doris-single-node.json
new file mode 100644
index 000000000..5084d44de
--- /dev/null
+++ b/experiments/configurations/doris-single-node.json
@@ -0,0 +1,69 @@
+[
+  {
+    "name": "doris-default",
+    "engine": "doris",
+    "connection_params": {
+      "host": "localhost",
+      "query_port": 6937,
+      "http_port": 5937,
+      "user": "root",
+      "password": ""
+    },
+    "collection_params": {
+      "database": "qdrant_benchmark",
+      "table_name": "doris_default"
+    },
+    "upload_params": {
+      "parallel": 8,
+      "batch_size": 526336,
+      "collection_params": { "table_name": "doris_default" }
+    },
+    "search_params": [
+      { "parallel": 8, "table_name": "doris_default", "config": { "hnsw_ef_search": 128 } }
+    ]
+  },
+  {
+    "name": "doris-m-96-ef-512",
+    "engine": "doris",
+    "connection_params": {
+      "host": "localhost",
+      "query_port": 6937,
+      "http_port": 5937,
+      "user": "root",
+      "password": ""
+    },
+    "collection_params": { "database": "qdrant_benchmark", "table_name": "doris_m_96_ef_512" },
+    "upload_params": {
+      "parallel": 16,
+      "batch_size": 526336,
+      "collection_params": { "table_name": "doris_m_96_ef_512" },
+      "hnsw_config": { "m": 96, "ef_construct": 512 }
+    },
+    "search_params": [
+      { "parallel": 1, "table_name": "doris_m_96_ef_512", "config": { "hnsw_ef_search": 512 } },
+      { "parallel": 100, "table_name": "doris_m_96_ef_512", "config": { "hnsw_ef_search": 512 } }
+    ]
+  },
+  {
+    "name": "doris-m-64-ef-256",
+    "engine": "doris",
+    "connection_params": {
+      "host": "localhost",
+      "query_port": 6937,
+      "http_port": 5937,
+      "user": "root",
+      "password": ""
+    },
+    "collection_params": { "database": "qdrant_benchmark", "table_name": "doris_m_64_ef_256" },
+    "upload_params": {
+      "parallel": 16,
+      "batch_size": 526336,
+      "collection_params": { "table_name": "doris_m_64_ef_256" },
+      "hnsw_config": { "m": 64, "ef_construct": 256 }
+    },
+    "search_params": [
+      { "parallel": 1, "table_name": "doris_m_64_ef_256", "config": { "hnsw_ef_search": 256 } },
+      { "parallel": 100, "table_name": "doris_m_64_ef_256", "config": { "hnsw_ef_search": 256 } }
+    ]
+  }
+]
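With these configurations in place, individual variants are selected by experiment name. Upstream `run.py` matches `--engines` with shell-style wildcards, so a whole family can be benchmarked in one invocation (treat the exact glob behaviour as an assumption about the upstream CLI):

```bash
# Run a single variant...
python run.py --engines "doris-m-64-ef-256" --datasets "dbpedia-openai-100K-1536-angular"
# ...or every Doris variant defined in doris-single-node.json
python run.py --engines "doris-*" --datasets "dbpedia-openai-100K-1536-angular"
```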
"root", + "password": "" + }, + "collection_params": { "database": "qdrant_benchmark", "table_name": "doris_m_96_ef_512" }, + "upload_params": { + "parallel": 16, + "batch_size": 526336, + "collection_params": { "table_name": "doris_m_96_ef_512" }, + "hnsw_config": { "m": 96, "ef_construct": 512 } + }, + "search_params": [ + { "parallel": 1, "table_name": "doris_m_96_ef_512", "config": { "hnsw_ef_search": 512 } }, + { "parallel": 100, "table_name": "doris_m_96_ef_512", "config": { "hnsw_ef_search": 512 } } + ] + }, + { + "name": "doris-m-64-ef-256", + "engine": "doris", + "connection_params": { + "host": "localhost", + "query_port": 6937, + "http_port": 5937, + "user": "root", + "password": "" + }, + "collection_params": { "database": "qdrant_benchmark", "table_name": "doris_m_64_ef_256" }, + "upload_params": { + "parallel": 16, + "batch_size": 526336, + "collection_params": { "table_name": "doris_m_64_ef_256" }, + "hnsw_config": { "m": 64, "ef_construct": 256 } + }, + "search_params": [ + { "parallel": 1, "table_name": "doris_m_64_ef_256", "config": { "hnsw_ef_search": 256 } }, + { "parallel": 100, "table_name": "doris_m_64_ef_256", "config": { "hnsw_ef_search": 256 } } + ] + } +] diff --git a/pyproject.toml b/pyproject.toml index 17b0656a1..a71976e26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,8 @@ opensearch-py = "^2.3.2" tqdm = "^4.66.1" psycopg = {extras = ["binary"], version = "^3.1.17"} pgvector = "^0.2.4" +doris_vector_search = "^0.0.6" +pyarrow = "^21.0.0" [tool.poetry.group.dev.dependencies] pre-commit = "^2.20.0" diff --git a/tests/engine/clients/doris/test_doris_basic.py b/tests/engine/clients/doris/test_doris_basic.py new file mode 100644 index 000000000..7de17cc9e --- /dev/null +++ b/tests/engine/clients/doris/test_doris_basic.py @@ -0,0 +1,24 @@ +import pytest + +from engine.clients.client_factory import ClientFactory + + +@pytest.mark.skip( + reason="Requires running Doris instance; integration test skipped by default" +) +def test_doris_factory_registration(): + factory = ClientFactory(host="localhost") + experiment = { + "name": "doris-basic", + "engine": "doris", + "connection_params": {"host": "localhost", "query_port": 9030}, + "collection_params": {"database": "benchmark", "table_name": "vectors"}, + "upload_params": {"batch_size": 2}, + "search_params": [{"top": 5}], + } + + client = factory.build_client(experiment) + assert client.engine == "doris" + assert client.configurator is not None + assert client.uploader is not None + assert client.searchers, "Searchers list should not be empty"