diff --git a/learning_resources_search/api.py b/learning_resources_search/api.py index 6d68e3050f..b729eaf9f9 100644 --- a/learning_resources_search/api.py +++ b/learning_resources_search/api.py @@ -364,7 +364,7 @@ def generate_filter_clause( path: str, value: str, *, case_sensitive: bool, _current_path_length=1 ): """ - Generate search clause for a single filter path abd value. + Generate search clause for a single filter path and value. Args: path (str): Search index on which to filter @@ -595,8 +595,13 @@ def percolate_matches_for_document(document_id): return percolated_queries -def add_text_query_to_search( - search, text, search_params, query_type_query, use_hybrid_search +def add_text_query_to_search( # noqa: PLR0913 + search, + text, + search_params, + query_type_query, + use_hybrid_search, + filter_clauses, ): if search_params.get("endpoint") == CONTENT_FILE_TYPE: text_query = generate_content_file_text_clause(text) @@ -672,11 +677,13 @@ def add_text_query_to_search( "query_text": text, "model_id": model_id, "k": HYBRID_SEARCH_KNN_K_VALUE, + "filter": {"bool": {"should": list(filter_clauses.values())}}, } } } pagination_depth = search_params.get("limit", 10) * 3 + search = search.extra( query={ "hybrid": { @@ -685,6 +692,7 @@ def add_text_query_to_search( } } ) + else: search = search.query(text_query) @@ -703,15 +711,15 @@ def construct_search(search_params): # noqa: C901 Returns: opensearch_dsl.Search: an opensearch search instance """ + use_hybrid_search = search_params.get("search_mode") == HYBRID_SEARCH_MODE if ( not search_params.get("resource_type") and search_params.get("endpoint") != CONTENT_FILE_TYPE + and not use_hybrid_search ): search_params["resource_type"] = list(LEARNING_RESOURCE_TYPES) - use_hybrid_search = search_params.get("search_mode") == HYBRID_SEARCH_MODE - indexes = relevant_indexes( search_params.get("resource_type"), search_params.get("aggregations"), @@ -741,18 +749,23 @@ def construct_search(search_params): # noqa: C901 else: query_type_query = {"exists": {"field": "resource_type"}} + filter_clauses = generate_filter_clauses(search_params) + if search_params.get("q"): text = re.sub("[\u201c\u201d]", '"', search_params.get("q")) search = add_text_query_to_search( - search, text, search_params, query_type_query, use_hybrid_search + search, + text, + search_params, + query_type_query, + use_hybrid_search, + filter_clauses, ) else: search = search.query(query_type_query) - filter_clauses = generate_filter_clauses(search_params) - search = search.post_filter("bool", must=list(filter_clauses.values())) if search_params.get("aggregations"): diff --git a/learning_resources_search/api_test.py b/learning_resources_search/api_test.py index bdb9c485a7..c04e2bebb4 100644 --- a/learning_resources_search/api_test.py +++ b/learning_resources_search/api_test.py @@ -2731,6 +2731,40 @@ def test_execute_learn_search_with_hybrid_search(mocker, settings, opensearch): "query_text": "math", "model_id": "vector_model_id", "k": 5, + "filter": { + "bool": { + "should": [ + { + "bool": { + "should": [ + { + "term": { + "resource_type": { + "value": "course", + "case_insensitive": True, + } + } + } + ] + } + }, + { + "bool": { + "should": [ + { + "term": { + "free": { + "value": True, + "case_insensitive": True, + } + } + } + ] + } + }, + ] + } + }, } } },