From 31eacbdc54f6c21aadaf3a85204fbc9e8fa99902 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 25 Oct 2024 13:00:59 -0400 Subject: [PATCH 1/4] WIP on threadpool impl of query_namespaces --- pinecone/data/index.py | 102 +++++++++++++++++- .../query_results_aggregator.py | 0 pinecone/grpc/index_grpc_asyncio.py | 2 +- .../test_query_results_aggregator.py | 2 +- 4 files changed, 102 insertions(+), 4 deletions(-) rename pinecone/{grpc => data}/query_results_aggregator.py (100%) rename tests/{unit_grpc => unit}/test_query_results_aggregator.py (99%) diff --git a/pinecone/data/index.py b/pinecone/data/index.py index cc1bae6c..d5695abb 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -27,6 +27,8 @@ from pinecone.core.openapi.data.api.data_plane_api import DataPlaneApi from ..utils import setup_openapi_client, parse_non_empty_args from .vector_factory import VectorFactory +from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults +from multiprocessing.pool import ApplyResult __all__ = [ "Index", @@ -361,7 +363,7 @@ def query( Union[SparseValues, Dict[str, Union[List[float], List[int]]]] ] = None, **kwargs, - ) -> QueryResponse: + ) -> Union[QueryResponse, ApplyResult[QueryResponse]]: """ The Query operation searches a namespace, using a query vector. It retrieves the ids of the most similar items in a namespace, along with their similarity scores. @@ -403,6 +405,39 @@ def query( and namespace name. """ + response = self._query( + *args, + top_k=top_k, + vector=vector, + id=id, + namespace=namespace, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + **kwargs, + ) + + if kwargs.get("async_req", False): + return response + else: + return parse_query_response(response) + + def _query( + self, + *args, + top_k: int, + vector: Optional[List[float]] = None, + id: Optional[str] = None, + namespace: Optional[str] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[ + Union[SparseValues, Dict[str, Union[List[float], List[int]]]] + ] = None, + **kwargs, + ) -> QueryResponse: if len(args) > 0: raise ValueError( "The argument order for `query()` has changed; please use keyword arguments instead of positional arguments. Example: index.query(vector=[0.1, 0.2, 0.3], top_k=10, namespace='my_namespace')" @@ -435,7 +470,70 @@ def query( ), **{k: v for k, v in kwargs.items() if k in _OPENAPI_ENDPOINT_PARAMS}, ) - return parse_query_response(response) + return response + + def query_namespaces( + self, + vector: List[float], + namespaces: List[str], + top_k: Optional[int] = None, + filter: Optional[Dict[str, Union[str, float, int, bool, List, dict]]] = None, + include_values: Optional[bool] = None, + include_metadata: Optional[bool] = None, + sparse_vector: Optional[ + Union[SparseValues, Dict[str, Union[List[float], List[int]]]] + ] = None, + show_progress: Optional[bool] = True, + **kwargs, + ) -> QueryNamespacesResults: + if len(namespaces) == 0: + raise ValueError("At least one namespace must be specified") + if len(vector) == 0: + raise ValueError("Query vector must not be empty") + + # The caller may only want the top_k=1 result across all queries, + # but we need to get at least 2 results from each query in order to + # aggregate them correctly. So we'll temporarily set topK to 2 for the + # subqueries, and then we'll take the topK=1 results from the aggregated + # results. + overall_topk = top_k if top_k is not None else 10 + aggregator = QueryResultsAggregator(top_k=overall_topk) + subquery_topk = overall_topk if overall_topk > 2 else 2 + + target_namespaces = set(namespaces) # dedup namespaces + async_results = [ + self.query( + vector=vector, + namespace=ns, + top_k=subquery_topk, + filter=filter, + include_values=include_values, + include_metadata=include_metadata, + sparse_vector=sparse_vector, + async_req=True, + **kwargs, + ) + for ns in target_namespaces + ] + + for result in async_results: + response = result.get() + aggregator.add_results(response) + + final_results = aggregator.get_results() + return final_results + + # with tqdm( + # total=len(query_tasks), disable=not show_progress, desc="Querying namespaces" + # ) as pbar: + # for query_task in asyncio.as_completed(query_tasks): + # response = await query_task + # pbar.update(1) + # async with aggregator_lock: + # aggregator.add_results(response) + + # final_results = aggregator.get_results() + # return final_results @validate_and_convert_errors def update( diff --git a/pinecone/grpc/query_results_aggregator.py b/pinecone/data/query_results_aggregator.py similarity index 100% rename from pinecone/grpc/query_results_aggregator.py rename to pinecone/data/query_results_aggregator.py diff --git a/pinecone/grpc/index_grpc_asyncio.py b/pinecone/grpc/index_grpc_asyncio.py index 2dbdda77..fac1aef8 100644 --- a/pinecone/grpc/index_grpc_asyncio.py +++ b/pinecone/grpc/index_grpc_asyncio.py @@ -42,7 +42,7 @@ parse_sparse_values_arg, ) from .vector_factory_grpc import VectorFactoryGRPC -from .query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults +from ..data.query_results_aggregator import QueryResultsAggregator, QueryNamespacesResults class GRPCIndexAsyncio(GRPCIndexBase): diff --git a/tests/unit_grpc/test_query_results_aggregator.py b/tests/unit/test_query_results_aggregator.py similarity index 99% rename from tests/unit_grpc/test_query_results_aggregator.py rename to tests/unit/test_query_results_aggregator.py index b4c78802..73c7f61d 100644 --- a/tests/unit_grpc/test_query_results_aggregator.py +++ b/tests/unit/test_query_results_aggregator.py @@ -1,4 +1,4 @@ -from pinecone.grpc.query_results_aggregator import ( +from pinecone.data.query_results_aggregator import ( QueryResultsAggregator, QueryResultsAggregatorInvalidTopKError, QueryResultsAggregregatorNotEnoughResultsError, From d7a1c30b6a18eead49902e3f9fd4cdebcc605d8e Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 25 Oct 2024 13:11:35 -0400 Subject: [PATCH 2/4] WIP on threadpool impl of query_namespaces --- pinecone/data/index.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/pinecone/data/index.py b/pinecone/data/index.py index d5695abb..a3b31d3d 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -363,7 +363,7 @@ def query( Union[SparseValues, Dict[str, Union[List[float], List[int]]]] ] = None, **kwargs, - ) -> Union[QueryResponse, ApplyResult[QueryResponse]]: + ) -> Union[QueryResponse, ApplyResult]: """ The Query operation searches a namespace, using a query vector. It retrieves the ids of the most similar items in a namespace, along with their similarity scores. @@ -483,7 +483,6 @@ def query_namespaces( sparse_vector: Optional[ Union[SparseValues, Dict[str, Union[List[float], List[int]]]] ] = None, - show_progress: Optional[bool] = True, **kwargs, ) -> QueryNamespacesResults: if len(namespaces) == 0: @@ -523,18 +522,6 @@ def query_namespaces( final_results = aggregator.get_results() return final_results - # with tqdm( - # total=len(query_tasks), disable=not show_progress, desc="Querying namespaces" - # ) as pbar: - # for query_task in asyncio.as_completed(query_tasks): - # response = await query_task - # pbar.update(1) - # async with aggregator_lock: - # aggregator.add_results(response) - - # final_results = aggregator.get_results() - # return final_results - @validate_and_convert_errors def update( self, From 489801e3a8eeed2c14a22d2d9fa2e35598b408a2 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 25 Oct 2024 14:00:37 -0400 Subject: [PATCH 3/4] Retries on threadpool async_apply --- pinecone/core/openapi/shared/api_client.py | 63 +++++++++++++++------- 1 file changed, 44 insertions(+), 19 deletions(-) diff --git a/pinecone/core/openapi/shared/api_client.py b/pinecone/core/openapi/shared/api_client.py index dda97ec5..7ec644c5 100644 --- a/pinecone/core/openapi/shared/api_client.py +++ b/pinecone/core/openapi/shared/api_client.py @@ -8,6 +8,24 @@ import typing from urllib.parse import quote from urllib3.fields import RequestField +import time +import random + +def retry_api_call( + func, args=(), kwargs={}, retries=3, backoff=1, jitter=0.5 +): + attempts = 0 + while attempts < retries: + try: + return func(*args, **kwargs) # Attempt to call __call_api + except Exception as e: + attempts += 1 + if attempts >= retries: + print(f"API call failed after {attempts} attempts: {e}") + raise # Re-raise exception if retries are exhausted + sleep_time = backoff * (2 ** (attempts - 1)) + random.uniform(0, jitter) + # print(f"Retrying ({attempts}/{retries}) in {sleep_time:.2f} seconds after error: {e}") + time.sleep(sleep_time) from pinecone.core.openapi.shared import rest @@ -397,25 +415,32 @@ def call_api( ) return self.pool.apply_async( - self.__call_api, - ( - resource_path, - method, - path_params, - query_params, - header_params, - body, - post_params, - files, - response_type, - auth_settings, - _return_http_data_only, - collection_formats, - _preload_content, - _request_timeout, - _host, - _check_type, - ), + retry_api_call, + args=( + self.__call_api, # Pass the API call function as the first argument + ( + resource_path, + method, + path_params, + query_params, + header_params, + body, + post_params, + files, + response_type, + auth_settings, + _return_http_data_only, + collection_formats, + _preload_content, + _request_timeout, + _host, + _check_type, + ), + {}, # empty kwargs dictionary + 3, # retries + 1, # backoff time + 0.5 # jitter + ) ) def request( From 8c8a960a7d3adf37d46f9746505223a4f8e69499 Mon Sep 17 00:00:00 2001 From: Jen Hamon Date: Fri, 25 Oct 2024 14:03:01 -0400 Subject: [PATCH 4/4] Add errors decorator --- pinecone/data/index.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pinecone/data/index.py b/pinecone/data/index.py index a3b31d3d..ba5769b3 100644 --- a/pinecone/data/index.py +++ b/pinecone/data/index.py @@ -472,6 +472,7 @@ def _query( ) return response + @validate_and_convert_errors def query_namespaces( self, vector: List[float],