From 9439a42bea68fd886e9647df5fc4e3e63f786086 Mon Sep 17 00:00:00 2001 From: david Date: Tue, 2 Jul 2024 18:19:44 +1000 Subject: [PATCH] black repo --- function_app.py | 7 +- prez/app.py | 9 +- prez/config.py | 5 +- prez/exceptions/model_exceptions.py | 2 +- prez/routers/api_extras_examples.py | 5 +- prez/routers/identifier.py | 2 +- prez/routers/ogc_router.py | 102 +++++++++----------- prez/routers/sparql.py | 51 ++++++---- prez/services/curie_functions.py | 2 +- prez/services/exception_catchers.py | 9 +- prez/services/link_generation.py | 27 ++++-- prez/services/query_generation/classes.py | 4 +- prez/services/query_generation/count.py | 79 ++++++++++------ prez/services/query_generation/cql.py | 5 +- prez/services/query_generation/umbrella.py | 4 +- tests/_test_cql_fuseki.py | 8 +- tests/test_query_construction.py | 105 ++++++++++++--------- tests/test_sparql.py | 2 +- 18 files changed, 233 insertions(+), 195 deletions(-) diff --git a/function_app.py b/function_app.py index 3601c4c4..d6498618 100644 --- a/function_app.py +++ b/function_app.py @@ -9,7 +9,9 @@ if assemble_app is None: - raise RuntimeError("Cannot import prez in the Azure function app. Check requirements.py and pyproject.toml.") + raise RuntimeError( + "Cannot import prez in the Azure function app. Check requirements.py and pyproject.toml." + ) # This is the base URL path that Prez routes will stem from @@ -32,7 +34,8 @@ if __name__ == "__main__": from azure.functions import HttpRequest, Context import asyncio - req = HttpRequest("GET", "/v", headers={}, body=b'') + + req = HttpRequest("GET", "/v", headers={}, body=b"") context = dict() loop = asyncio.get_event_loop() diff --git a/prez/app.py b/prez/app.py index 13348fc8..4eccf03b 100755 --- a/prez/app.py +++ b/prez/app.py @@ -16,12 +16,15 @@ load_local_data_to_oxigraph, get_oxrdflib_store, get_system_store, - load_system_data_to_oxigraph, load_annotations_data_to_oxigraph, get_annotations_store, + load_system_data_to_oxigraph, + load_annotations_data_to_oxigraph, + get_annotations_store, ) from prez.exceptions.model_exceptions import ( ClassNotFoundException, URINotFoundException, - NoProfilesException, InvalidSPARQLQueryException, + NoProfilesException, + InvalidSPARQLQueryException, ) from prez.repositories import RemoteSparqlRepo, PyoxigraphRepo, OxrdflibRepo from prez.routers.identifier import router as identifier_router @@ -161,7 +164,7 @@ def assemble_app( ClassNotFoundException: catch_class_not_found_exception, URINotFoundException: catch_uri_not_found_exception, NoProfilesException: catch_no_profiles_exception, - InvalidSPARQLQueryException: catch_invalid_sparql_query + InvalidSPARQLQueryException: catch_invalid_sparql_query, }, **kwargs ) diff --git a/prez/config.py b/prez/config.py index 000495d1..b1df20ac 100755 --- a/prez/config.py +++ b/prez/config.py @@ -52,10 +52,7 @@ class Settings(BaseSettings): SDO.description, ] provenance_predicates: Optional[List[URIRef]] = [DCTERMS.provenance] - other_predicates: Optional[List[URIRef]] = [ - SDO.color, - REG.status - ] + other_predicates: Optional[List[URIRef]] = [SDO.color, REG.status] sparql_repo_type: str = "remote" sparql_timeout: int = 30 log_level: str = "INFO" diff --git a/prez/exceptions/model_exceptions.py b/prez/exceptions/model_exceptions.py index 1f01e890..8e729247 100755 --- a/prez/exceptions/model_exceptions.py +++ b/prez/exceptions/model_exceptions.py @@ -44,4 +44,4 @@ class InvalidSPARQLQueryException(Exception): def __init__(self, error: str): self.message = f"Invalid SPARQL query: {error}" - super().__init__(self.message) \ No newline at end of file + super().__init__(self.message) diff --git a/prez/routers/api_extras_examples.py b/prez/routers/api_extras_examples.py index 90a6c67c..8fe404c2 100644 --- a/prez/routers/api_extras_examples.py +++ b/prez/routers/api_extras_examples.py @@ -5,10 +5,7 @@ responses = json.loads(responses_json.read_text()) cql_json_examples_dir = Path(__file__).parent.parent / "examples/cql" cql_examples = { - file.stem: { - "summary": file.stem, - "value": json.loads(file.read_text()) - } + file.stem: {"summary": file.stem, "value": json.loads(file.read_text())} for file in cql_json_examples_dir.glob("*.json") } diff --git a/prez/routers/identifier.py b/prez/routers/identifier.py index d909eb84..67639c01 100755 --- a/prez/routers/identifier.py +++ b/prez/routers/identifier.py @@ -4,8 +4,8 @@ from rdflib.term import _is_valid_uri from prez.dependencies import get_data_repo -from prez.services.query_generation.identifier import get_foaf_homepage_query from prez.services.curie_functions import get_uri_for_curie_id, get_curie_id_for_uri +from prez.services.query_generation.identifier import get_foaf_homepage_query router = APIRouter(tags=["Identifier Resolution"]) diff --git a/prez/routers/ogc_router.py b/prez/routers/ogc_router.py index d61c7bf4..8a3efb6a 100755 --- a/prez/routers/ogc_router.py +++ b/prez/routers/ogc_router.py @@ -10,7 +10,8 @@ get_negotiated_pmts, get_profile_nodeshape, get_endpoint_structure, - generate_concept_hierarchy_query, cql_post_parser_dependency, + generate_concept_hierarchy_query, + cql_post_parser_dependency, ) from prez.models.query_params import QueryParams from prez.reference_data.prez_ns import EP, ONT, OGCE @@ -26,71 +27,63 @@ router = APIRouter(tags=["ogcprez"]) -@router.get( - path="/search", - summary="Search", - name=OGCE["search"], - responses=responses -) +@router.get(path="/search", summary="Search", name=OGCE["search"], responses=responses) @router.get( "/profiles", summary="List Profiles", name=EP["system/profile-listing"], - responses=responses + responses=responses, ) @router.get( - path="/cql", - summary="CQL GET endpoint", - name=OGCE["cql-get"], - responses=responses + path="/cql", summary="CQL GET endpoint", name=OGCE["cql-get"], responses=responses ) @router.get( "/catalogs", summary="Catalog Listing", name=OGCE["catalog-listing"], - responses=responses + responses=responses, ) @router.get( "/catalogs/{catalogId}/collections", summary="Collection Listing", name=OGCE["collection-listing"], openapi_extra=openapi_extras.get("collection-listing"), - responses=responses + responses=responses, ) @router.get( "/catalogs/{catalogId}/collections/{collectionId}/items", summary="Item Listing", name=OGCE["item-listing"], openapi_extra=openapi_extras.get("item-listing"), - responses=responses + responses=responses, ) @router.get( "/concept-hierarchy/{parent_curie}/top-concepts", summary="Top Concepts", name=OGCE["top-concepts"], openapi_extra=openapi_extras.get("top-concepts"), - responses=responses + responses=responses, ) @router.get( "/concept-hierarchy/{parent_curie}/narrowers", summary="Narrowers", name=OGCE["narrowers"], openapi_extra=openapi_extras.get("narrowers"), - responses=responses + responses=responses, ) async def listings( - query_params: QueryParams = Depends(), - endpoint_nodeshape: NodeShape = Depends(get_endpoint_nodeshapes), - pmts: NegotiatedPMTs = Depends(get_negotiated_pmts), - endpoint_structure: tuple[str, ...] = Depends(get_endpoint_structure), - profile_nodeshape: NodeShape = Depends(get_profile_nodeshape), - cql_parser: CQLParser = Depends(cql_get_parser_dependency), - search_query: ConstructQuery = Depends(generate_search_query), - concept_hierarchy_query: ConceptHierarchyQuery = Depends( - generate_concept_hierarchy_query - ), - data_repo: Repo = Depends(get_data_repo), - system_repo: Repo = Depends(get_system_repo), + query_params: QueryParams = Depends(), + endpoint_nodeshape: NodeShape = Depends(get_endpoint_nodeshapes), + pmts: NegotiatedPMTs = Depends(get_negotiated_pmts), + endpoint_structure: tuple[str, ...] = Depends(get_endpoint_structure), + profile_nodeshape: NodeShape = Depends(get_profile_nodeshape), + cql_parser: CQLParser = Depends(cql_get_parser_dependency), + search_query: ConstructQuery = Depends(generate_search_query), + concept_hierarchy_query: ConceptHierarchyQuery = Depends( + generate_concept_hierarchy_query + ), + data_repo: Repo = Depends(get_data_repo), + system_repo: Repo = Depends(get_system_repo), ): return await listing_function( data_repo=data_repo, @@ -112,26 +105,20 @@ async def listings( summary="CQL POST endpoint", name=OGCE["cql-post"], openapi_extra={ - "requestBody": { - "content": { - "application/json": { - "examples": cql_examples - } - } - } + "requestBody": {"content": {"application/json": {"examples": cql_examples}}} }, - responses=responses + responses=responses, ) async def cql_post_listings( - query_params: QueryParams = Depends(), - endpoint_nodeshape: NodeShape = Depends(get_endpoint_nodeshapes), - pmts: NegotiatedPMTs = Depends(get_negotiated_pmts), - endpoint_structure: tuple[str, ...] = Depends(get_endpoint_structure), - profile_nodeshape: NodeShape = Depends(get_profile_nodeshape), - cql_parser: CQLParser = Depends(cql_post_parser_dependency), - search_query: ConstructQuery = Depends(generate_search_query), - data_repo: Repo = Depends(get_data_repo), - system_repo: Repo = Depends(get_system_repo), + query_params: QueryParams = Depends(), + endpoint_nodeshape: NodeShape = Depends(get_endpoint_nodeshapes), + pmts: NegotiatedPMTs = Depends(get_negotiated_pmts), + endpoint_structure: tuple[str, ...] = Depends(get_endpoint_structure), + profile_nodeshape: NodeShape = Depends(get_profile_nodeshape), + cql_parser: CQLParser = Depends(cql_post_parser_dependency), + search_query: ConstructQuery = Depends(generate_search_query), + data_repo: Repo = Depends(get_data_repo), + system_repo: Repo = Depends(get_system_repo), ): return await listing_function( data_repo=data_repo, @@ -160,45 +147,42 @@ async def cql_post_listings( @router.get( - path="/object", - summary="Object", - name=EP["system/object"], - responses=responses + path="/object", summary="Object", name=EP["system/object"], responses=responses ) @router.get( path="/profiles/{profile_curie}", summary="Profile", name=EP["system/profile-object"], openapi_extra=openapi_extras.get("profile-object"), - responses=responses + responses=responses, ) @router.get( path="/catalogs/{catalogId}", summary="Catalog Object", name=OGCE["catalog-object"], openapi_extra=openapi_extras.get("catalog-object"), - responses=responses + responses=responses, ) @router.get( path="/catalogs/{catalogId}/collections/{collectionId}", summary="Collection Object", name=OGCE["collection-object"], openapi_extra=openapi_extras.get("collection-object"), - responses=responses + responses=responses, ) @router.get( path="/catalogs/{catalogId}/collections/{collectionId}/items/{itemId}", summary="Item Object", name=OGCE["item-object"], openapi_extra=openapi_extras.get("item-object"), - responses=responses + responses=responses, ) async def objects( - pmts: NegotiatedPMTs = Depends(get_negotiated_pmts), - endpoint_structure: tuple[str, ...] = Depends(get_endpoint_structure), - profile_nodeshape: NodeShape = Depends(get_profile_nodeshape), - data_repo: Repo = Depends(get_data_repo), - system_repo: Repo = Depends(get_system_repo), + pmts: NegotiatedPMTs = Depends(get_negotiated_pmts), + endpoint_structure: tuple[str, ...] = Depends(get_endpoint_structure), + profile_nodeshape: NodeShape = Depends(get_profile_nodeshape), + data_repo: Repo = Depends(get_data_repo), + system_repo: Repo = Depends(get_system_repo), ): return await object_function( data_repo=data_repo, diff --git a/prez/routers/sparql.py b/prez/routers/sparql.py index 3eb2c2a0..a9a4a7b0 100755 --- a/prez/routers/sparql.py +++ b/prez/routers/sparql.py @@ -9,7 +9,7 @@ from starlette.requests import Request from starlette.responses import StreamingResponse -from prez.dependencies import get_data_repo, get_system_repo, get_negotiated_pmts +from prez.dependencies import get_data_repo, get_system_repo from prez.renderers.renderer import return_annotated_rdf from prez.repositories import Repo from prez.services.connegp_service import NegotiatedPMTs @@ -21,29 +21,35 @@ @router.post("/sparql") async def sparql_post_passthrough( - # To maintain compatibility with the other SPARQL endpoints, - # /sparql POST endpoint is not a JSON API, it uses - # values encoded with x-www-form-urlencoded - query: Annotated[str, Form()], - # Pydantic validation prevents update queries (the Form would need to be "update") - request: Request, - repo: Repo = Depends(get_data_repo), - system_repo: Repo = Depends(get_system_repo), + # To maintain compatibility with the other SPARQL endpoints, + # /sparql POST endpoint is not a JSON API, it uses + # values encoded with x-www-form-urlencoded + query: Annotated[str, Form()], + # Pydantic validation prevents update queries (the Form would need to be "update") + request: Request, + repo: Repo = Depends(get_data_repo), + system_repo: Repo = Depends(get_system_repo), ): - return await sparql_endpoint_handler(query, request, repo, system_repo, method="POST") + return await sparql_endpoint_handler( + query, request, repo, system_repo, method="POST" + ) @router.get("/sparql") async def sparql_get_passthrough( - query: str, - request: Request, - repo: Repo = Depends(get_data_repo), - system_repo: Repo = Depends(get_system_repo), + query: str, + request: Request, + repo: Repo = Depends(get_data_repo), + system_repo: Repo = Depends(get_system_repo), ): - return await sparql_endpoint_handler(query, request, repo, system_repo, method="GET") + return await sparql_endpoint_handler( + query, request, repo, system_repo, method="GET" + ) -async def sparql_endpoint_handler(query: str, request: Request, repo: Repo, system_repo, method="GET"): +async def sparql_endpoint_handler( + query: str, request: Request, repo: Repo, system_repo, method="GET" +): pmts = NegotiatedPMTs( **{ "headers": request.headers, @@ -72,13 +78,14 @@ async def sparql_endpoint_handler(query: str, request: Request, repo: Repo, syst media_type=non_anot_mediatype, headers=pmts.generate_response_headers(), ) - query_result: 'httpx.Response' = await repo.sparql(query, request.headers.raw, method=method) + query_result: "httpx.Response" = await repo.sparql( + query, request.headers.raw, method=method + ) if isinstance(query_result, dict): return JSONResponse(content=query_result) elif isinstance(query_result, Graph): return Response( - content=query_result.serialize(format="text/turtle"), - status_code=200 + content=query_result.serialize(format="text/turtle"), status_code=200 ) dispositions = query_result.headers.get_list("Content-Disposition") @@ -92,7 +99,11 @@ async def sparql_endpoint_handler(query: str, request: Request, repo: Repo, syst # remove transfer-encoding chunked, disposition=attachment, and content-length headers = dict() for k, v in query_result.headers.items(): - if k.lower() not in ("transfer-encoding", "content-disposition", "content-length"): + if k.lower() not in ( + "transfer-encoding", + "content-disposition", + "content-length", + ): headers[k] = v content = await query_result.aread() await query_result.aclose() diff --git a/prez/services/curie_functions.py b/prez/services/curie_functions.py index 425bd0c3..44b21ca1 100755 --- a/prez/services/curie_functions.py +++ b/prez/services/curie_functions.py @@ -42,7 +42,7 @@ def generate_new_prefix(uri): else: ns = f'{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.rsplit("/", 1)[0]}/' - split_prefix_path = ns[:-1].rsplit('/', 1) + split_prefix_path = ns[:-1].rsplit("/", 1) if len(split_prefix_path) > 1: to_generate_prefix_from = split_prefix_path[-1].lower() # attempt to just use the last part of the path prior to the fragment or "identifier" diff --git a/prez/services/exception_catchers.py b/prez/services/exception_catchers.py index 76642860..2a4e05d5 100755 --- a/prez/services/exception_catchers.py +++ b/prez/services/exception_catchers.py @@ -5,7 +5,8 @@ from prez.exceptions.model_exceptions import ( ClassNotFoundException, URINotFoundException, - NoProfilesException, InvalidSPARQLQueryException, + NoProfilesException, + InvalidSPARQLQueryException, ) @@ -53,11 +54,13 @@ async def catch_no_profiles_exception(request: Request, exc: NoProfilesException ) -async def catch_invalid_sparql_query(request: Request, exc: InvalidSPARQLQueryException): +async def catch_invalid_sparql_query( + request: Request, exc: InvalidSPARQLQueryException +): return JSONResponse( status_code=400, content={ "error": "Bad Request", "detail": exc.message, }, - ) \ No newline at end of file + ) diff --git a/prez/services/link_generation.py b/prez/services/link_generation.py index 1c7db6b3..79043fc0 100755 --- a/prez/services/link_generation.py +++ b/prez/services/link_generation.py @@ -46,11 +46,11 @@ async def add_prez_links(graph: Graph, repo: Repo, endpoint_structure): async def _link_generation( - uri: URIRef, - repo: Repo, - klasses, - graph: Graph, - endpoint_structure: str = settings.endpoint_structure, + uri: URIRef, + repo: Repo, + klasses, + graph: Graph, + endpoint_structure: str = settings.endpoint_structure, ): """ Generates links for the given URI if it is not already cached. @@ -69,8 +69,13 @@ async def _link_generation( available_nodeshapes = await get_nodeshapes_constraining_class(klasses, uri) # ignore CQL and Search nodeshapes as we do not want to generate links for these. available_nodeshapes = [ - ns for ns in available_nodeshapes - if ns.uri not in [URIRef('http://example.org/ns#CQL'), URIRef('http://example.org/ns#Search')] + ns + for ns in available_nodeshapes + if ns.uri + not in [ + URIRef("http://example.org/ns#CQL"), + URIRef("http://example.org/ns#Search"), + ] ] # run queries for available nodeshapes to get link components for ns in available_nodeshapes: @@ -128,7 +133,7 @@ async def get_nodeshapes_constraining_class(klasses, uri): async def add_links_to_graph_and_cache( - curie_for_uri, graph, members_link, object_link, uri + curie_for_uri, graph, members_link, object_link, uri ): """ Adds links and identifiers to the given graph and cache. @@ -139,7 +144,7 @@ async def add_links_to_graph_and_cache( (uri, DCTERMS.identifier, Literal(curie_for_uri, datatype=PREZ.identifier), uri) ) if ( - members_link + members_link ): # TODO need to confirm the link value doesn't match the existing link value, as multiple endpoints can deliver the same class/have different links for the same URI existing_members_link = list( links_ids_graph_cache.quads((uri, PREZ["members"], None, uri)) @@ -197,7 +202,9 @@ async def get_link_components(ns, repo): where_clause=WhereClause( group_graph_pattern=GroupGraphPattern( content=GroupGraphPatternSub( - triples_block=TriplesBlock.from_tssp_list(ns.tssp_list[::-1]), # reversed for performance + triples_block=TriplesBlock.from_tssp_list( + ns.tssp_list[::-1] + ), # reversed for performance graph_patterns_or_triples_blocks=ns.gpnt_list, ) ) diff --git a/prez/services/query_generation/classes.py b/prez/services/query_generation/classes.py index aa73580b..804ba0d0 100755 --- a/prez/services/query_generation/classes.py +++ b/prez/services/query_generation/classes.py @@ -30,8 +30,8 @@ class ClassesSelectQuery(SubSelect): """ def __init__( - self, - iris: list[IRI], + self, + iris: list[IRI], ): class_var = Var(value="class") uris_var = Var(value="uri") diff --git a/prez/services/query_generation/count.py b/prez/services/query_generation/count.py index 289e3750..f21b72a6 100755 --- a/prez/services/query_generation/count.py +++ b/prez/services/query_generation/count.py @@ -1,9 +1,38 @@ -from sparql_grammar_pydantic import IRI, Var, TriplesSameSubject, SubSelect, SelectClause, \ - WhereClause, GroupGraphPattern, GroupGraphPatternSub, SolutionModifier, LimitOffsetClauses, \ - LimitClause, Expression, PrimaryExpression, BuiltInCall, Aggregate, ConditionalOrExpression, \ - ConditionalAndExpression, ValueLogical, RelationalExpression, NumericExpression, AdditiveExpression, \ - MultiplicativeExpression, UnaryExpression, NumericLiteral, RDFLiteral, Bind, GraphPatternNotTriples, \ - GroupOrUnionGraphPattern, ConstructTemplate, BlankNode, Anon, ConstructTriples, ConstructQuery +from sparql_grammar_pydantic import ( + IRI, + Var, + TriplesSameSubject, + SubSelect, + SelectClause, + WhereClause, + GroupGraphPattern, + GroupGraphPatternSub, + SolutionModifier, + LimitOffsetClauses, + LimitClause, + Expression, + PrimaryExpression, + BuiltInCall, + Aggregate, + ConditionalOrExpression, + ConditionalAndExpression, + ValueLogical, + RelationalExpression, + NumericExpression, + AdditiveExpression, + MultiplicativeExpression, + UnaryExpression, + NumericLiteral, + RDFLiteral, + Bind, + GraphPatternNotTriples, + GroupOrUnionGraphPattern, + ConstructTemplate, + BlankNode, + Anon, + ConstructTriples, + ConstructQuery, +) from prez.config import settings @@ -35,16 +64,14 @@ def __init__(self, original_subselect: SubSelect): limit = settings.listing_count_limit limit_plus_one = limit + 1 inner_ss = SubSelect( - select_clause=SelectClause( - variables_or_all=[Var(value="focus_node")] - ), + select_clause=SelectClause(variables_or_all=[Var(value="focus_node")]), where_clause=original_subselect.where_clause, solution_modifier=SolutionModifier( limit_offset=LimitOffsetClauses( limit_clause=LimitClause(limit=limit_plus_one) ), ), - values_clause=original_subselect.values_clause + values_clause=original_subselect.values_clause, ) count_expression = Expression.from_primary_expression( PrimaryExpression( @@ -64,10 +91,8 @@ def __init__(self, original_subselect: SubSelect): variables_or_all=[(count_expression, Var(value="count"))], ), where_clause=WhereClause( - group_graph_pattern=GroupGraphPattern( - content=inner_ss - ) - ) + group_graph_pattern=GroupGraphPattern(content=inner_ss) + ), ) outer_ss_ggp = GroupGraphPattern(content=outer_ss) count_equals_1001_expr = Expression( @@ -101,7 +126,7 @@ def __init__(self, original_subselect: SubSelect): ) ) ) - ) + ), ) ) ] @@ -109,14 +134,14 @@ def __init__(self, original_subselect: SubSelect): ] ) ) - gt_1000_exp = Expression.from_primary_expression(PrimaryExpression(content=RDFLiteral(value=f">{limit}"))) + gt_1000_exp = Expression.from_primary_expression( + PrimaryExpression(content=RDFLiteral(value=f">{limit}")) + ) str_count_exp = Expression.from_primary_expression( PrimaryExpression( content=BuiltInCall.create_with_one_expr( function_name="STR", - expression=PrimaryExpression( - content=Var(value="count") - ) + expression=PrimaryExpression(content=Var(value="count")), ) ) ) @@ -125,15 +150,11 @@ def __init__(self, original_subselect: SubSelect): PrimaryExpression( content=BuiltInCall( function_name="IF", - arguments=[ - count_equals_1001_expr, - gt_1000_exp, - str_count_exp - ] + arguments=[count_equals_1001_expr, gt_1000_exp, str_count_exp], ) ) ), - var=Var(value="count_str") + var=Var(value="count_str"), ) wc = WhereClause( group_graph_pattern=GroupGraphPattern( @@ -141,14 +162,10 @@ def __init__(self, original_subselect: SubSelect): graph_patterns_or_triples_blocks=[ GraphPatternNotTriples( content=GroupOrUnionGraphPattern( - group_graph_patterns=[ - outer_ss_ggp - ] + group_graph_patterns=[outer_ss_ggp] ) ), - GraphPatternNotTriples( - content=bind - ) + GraphPatternNotTriples(content=bind), ] ) ) diff --git a/prez/services/query_generation/cql.py b/prez/services/query_generation/cql.py index 06ff67cd..9ec26348 100755 --- a/prez/services/query_generation/cql.py +++ b/prez/services/query_generation/cql.py @@ -149,7 +149,6 @@ def _add_triple(self, ggps, subject, predicate, object): else: ggps.triples_block = TriplesBlock(triples=tssp) - def _handle_comparison(self, operator, args, existing_ggps=None): self.var_counter += 1 ggps = existing_ggps if existing_ggps is not None else GroupGraphPatternSub() @@ -282,7 +281,9 @@ def _handle_spatial(self, operator, args, existing_ggps=None): def get_wkt_from_coords(self, coordinates, geom_type): shapely_spatial_class = cql_to_shapely_mapping.get(geom_type) if not shapely_spatial_class: - raise NotImplementedError(f"Geometry Class for \"{geom_type}\" not found in Shapely.") + raise NotImplementedError( + f'Geometry Class for "{geom_type}" not found in Shapely.' + ) wkt = shapely_spatial_class(coordinates).wkt return wkt diff --git a/prez/services/query_generation/umbrella.py b/prez/services/query_generation/umbrella.py index 9cf8d40e..b75d5dc0 100755 --- a/prez/services/query_generation/umbrella.py +++ b/prez/services/query_generation/umbrella.py @@ -97,7 +97,9 @@ def __init__( # for listing queries only, add an inner select to the where clause ss_gpotb = [] if inner_select_tssp_list: - inner_select_tssp_list = sorted(inner_select_tssp_list, key=lambda x: str(x), reverse=True) # grouping for performance + inner_select_tssp_list = sorted( + inner_select_tssp_list, key=lambda x: str(x), reverse=True + ) # grouping for performance ss_gpotb.append(TriplesBlock.from_tssp_list(inner_select_tssp_list)) if inner_select_gpnt: ss_gpotb.extend(inner_select_gpnt) diff --git a/tests/_test_cql_fuseki.py b/tests/_test_cql_fuseki.py index 4b7239fd..bd81e987 100755 --- a/tests/_test_cql_fuseki.py +++ b/tests/_test_cql_fuseki.py @@ -65,9 +65,7 @@ def test_spatial_contains(client_fuseki): cql = json.load(f) cql_str = json.dumps(cql) cql_encoded = quote_plus(cql_str) - response = client_fuseki.get( - f"/cql?filter={cql_encoded}&_mediatype=text/turtle" - ) + response = client_fuseki.get(f"/cql?filter={cql_encoded}&_mediatype=text/turtle") response_graph = Graph().parse(data=response.text) print(response_graph.serialize(format="turtle")) print("x") @@ -93,9 +91,7 @@ def test_spatial_contains_like(client_fuseki): cql = json.load(f) cql_str = json.dumps(cql) cql_encoded = quote_plus(cql_str) - response = client_fuseki.get( - f"/cql?filter={cql_encoded}" - ) + response = client_fuseki.get(f"/cql?filter={cql_encoded}") response_graph = Graph().parse(data=response.text) print(response_graph.serialize(format="turtle")) diff --git a/tests/test_query_construction.py b/tests/test_query_construction.py index 8f486175..afbd57da 100755 --- a/tests/test_query_construction.py +++ b/tests/test_query_construction.py @@ -3,12 +3,43 @@ import pytest from rdflib import RDF, RDFS, SKOS from rdflib.namespace import GEO -from sparql_grammar_pydantic import IRI, Var, TriplesSameSubject, TriplesSameSubjectPath, SubSelect, SelectClause, \ - WhereClause, GroupGraphPattern, GroupGraphPatternSub, TriplesBlock, SolutionModifier, LimitOffsetClauses, \ - LimitClause, Expression, PrimaryExpression, BuiltInCall, Aggregate, ConditionalOrExpression, \ - ConditionalAndExpression, ValueLogical, RelationalExpression, NumericExpression, AdditiveExpression, \ - MultiplicativeExpression, UnaryExpression, NumericLiteral, RDFLiteral, Bind, GraphPatternNotTriples, \ - GroupOrUnionGraphPattern, ConstructTemplate, BlankNode, Anon, ConstructTriples, ConstructQuery +from sparql_grammar_pydantic import ( + IRI, + Var, + TriplesSameSubject, + TriplesSameSubjectPath, + SubSelect, + SelectClause, + WhereClause, + GroupGraphPattern, + GroupGraphPatternSub, + TriplesBlock, + SolutionModifier, + LimitOffsetClauses, + LimitClause, + Expression, + PrimaryExpression, + BuiltInCall, + Aggregate, + ConditionalOrExpression, + ConditionalAndExpression, + ValueLogical, + RelationalExpression, + NumericExpression, + AdditiveExpression, + MultiplicativeExpression, + UnaryExpression, + NumericLiteral, + RDFLiteral, + Bind, + GraphPatternNotTriples, + GroupOrUnionGraphPattern, + ConstructTemplate, + BlankNode, + Anon, + ConstructTriples, + ConstructQuery, +) from prez.services.query_generation.classes import ClassesSelectQuery from prez.services.query_generation.concept_hierarchy import ConceptHierarchyQuery @@ -86,11 +117,11 @@ def test_search_query_regex(): ), ], construct_tss_list=sq.construct_triples.to_tss_list() - + [ - TriplesSameSubject.from_spo( - IRI(value="https://s"), IRI(value="https://p"), IRI(value="https://o") - ) - ], + + [ + TriplesSameSubject.from_spo( + IRI(value="https://s"), IRI(value="https://p"), IRI(value="https://o") + ) + ], inner_select_vars=sq.inner_select_vars, inner_select_gpnt=[sq.inner_select_gpnt], limit=sq.limit, @@ -160,9 +191,7 @@ def test_concept_hierarchy_narrowers(): def test_count_query(): inner_ss = SubSelect( - select_clause=SelectClause( - variables_or_all=[Var(value="focus_node")] - ), + select_clause=SelectClause(variables_or_all=[Var(value="focus_node")]), where_clause=WhereClause( group_graph_pattern=GroupGraphPattern( content=GroupGraphPatternSub( @@ -172,7 +201,9 @@ def test_count_query(): TriplesSameSubjectPath.from_spo( subject=Var(value="focus_node"), predicate=IRI(value=RDF.type), - object=IRI(value="http://www.w3.org/ns/sosa/Sampling") + object=IRI( + value="http://www.w3.org/ns/sosa/Sampling" + ), ) ] ) @@ -181,10 +212,8 @@ def test_count_query(): ) ), solution_modifier=SolutionModifier( - limit_offset=LimitOffsetClauses( - limit_clause=LimitClause(limit=1001) - ), - ) + limit_offset=LimitOffsetClauses(limit_clause=LimitClause(limit=1001)), + ), ) count_expression = Expression.from_primary_expression( PrimaryExpression( @@ -204,10 +233,8 @@ def test_count_query(): variables_or_all=[(count_expression, Var(value="count"))], ), where_clause=WhereClause( - group_graph_pattern=GroupGraphPattern( - content=inner_ss - ) - ) + group_graph_pattern=GroupGraphPattern(content=inner_ss) + ), ) outer_ss_ggp = GroupGraphPattern(content=outer_ss) count_equals_1001_expr = Expression( @@ -234,14 +261,12 @@ def test_count_query(): base_expression=MultiplicativeExpression( base_expression=UnaryExpression( primary_expression=PrimaryExpression( - content=NumericLiteral( - value=1001 - ) + content=NumericLiteral(value=1001) ) ) ) ) - ) + ), ) ) ] @@ -249,14 +274,14 @@ def test_count_query(): ] ) ) - gt_1000_exp = Expression.from_primary_expression(PrimaryExpression(content=RDFLiteral(value=">1000"))) + gt_1000_exp = Expression.from_primary_expression( + PrimaryExpression(content=RDFLiteral(value=">1000")) + ) str_count_exp = Expression.from_primary_expression( PrimaryExpression( content=BuiltInCall.create_with_one_expr( function_name="STR", - expression=PrimaryExpression( - content=Var(value="count") - ) + expression=PrimaryExpression(content=Var(value="count")), ) ) ) @@ -265,15 +290,11 @@ def test_count_query(): PrimaryExpression( content=BuiltInCall( function_name="IF", - arguments=[ - count_equals_1001_expr, - gt_1000_exp, - str_count_exp - ] + arguments=[count_equals_1001_expr, gt_1000_exp, str_count_exp], ) ) ), - var=Var(value="count_str") + var=Var(value="count_str"), ) wc = WhereClause( group_graph_pattern=GroupGraphPattern( @@ -281,14 +302,10 @@ def test_count_query(): graph_patterns_or_triples_blocks=[ GraphPatternNotTriples( content=GroupOrUnionGraphPattern( - group_graph_patterns=[ - outer_ss_ggp - ] + group_graph_patterns=[outer_ss_ggp] ) ), - GraphPatternNotTriples( - content=bind - ) + GraphPatternNotTriples(content=bind), ] ) ) @@ -307,6 +324,6 @@ def test_count_query(): query = ConstructQuery( construct_template=construct_template, where_clause=wc, - solution_modifier=SolutionModifier() + solution_modifier=SolutionModifier(), ) print(query) diff --git a/tests/test_sparql.py b/tests/test_sparql.py index f790c8df..81e678fb 100755 --- a/tests/test_sparql.py +++ b/tests/test_sparql.py @@ -57,4 +57,4 @@ def test_insert_as_query(client): "format": "application/x-www-form-urlencoded", }, ) - assert r.status_code == 400 \ No newline at end of file + assert r.status_code == 400