Skip to content

Commit

Permalink
fix optional type for 3.9
Browse files Browse the repository at this point in the history
  • Loading branch information
vemonet committed Sep 17, 2024
1 parent 41fce7b commit 6e372ca
Show file tree
Hide file tree
Showing 5 changed files with 16 additions and 13 deletions.
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,10 @@ disallow_untyped_calls = true
disallow_untyped_defs = true
disallow_any_generics = true

# TODO: Use mypy or pyright for static type checking?
# TODO: Use mypy or pyright for static type checking? Mypy is slow
# But pyright does not recognize the type of the curies Converter ("unknown import symbol"), while mypy does.
# So I guess pyright is garbage, I don't have time to fix all their bs, they should be better, classic microsoft
# https://microsoft.github.io/pyright/#/configuration?id=sample-pyprojecttoml-file
# [tool.pyright]


# https://github.com/astral-sh/ruff#configuration
Expand Down
5 changes: 3 additions & 2 deletions src/sparql_llm/sparql_void_shapes_loader.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from langchain_core.document_loaders.base import BaseLoader
from langchain_core.documents import Document

Expand All @@ -13,8 +14,8 @@ class SparqlVoidShapesLoader(BaseLoader):
def __init__(
self,
endpoint_url: str,
namespaces_to_ignore: list[str] | None = None,
prefix_map: dict[str, str] | None = None,
namespaces_to_ignore: Optional[list[str]] = None,
prefix_map: Optional[dict[str, str]] = None,
verbose: bool = False,
):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/sparql_llm/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import json
from typing import Any
from typing import Any, Optional

import requests
from curies_rs import Converter
Expand Down Expand Up @@ -87,7 +87,7 @@ def get_void_dict(endpoint_url: str) -> TripleDict:
return void_dict


def query_sparql(query: str, endpoint_url: str, post: bool = False, timeout: int | None = None) -> Any:
def query_sparql(query: str, endpoint_url: str, post: bool = False, timeout: Optional[int] = None) -> Any:
"""Execute a SPARQL query on a SPARQL endpoint using requests"""
if post:
resp = requests.post(
Expand Down
10 changes: 5 additions & 5 deletions src/sparql_llm/validate_sparql.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import re
from collections import defaultdict
from typing import Any
from typing import Any, Optional

from curies_rs import Converter
from rdflib import Namespace, Variable
Expand All @@ -13,7 +13,7 @@
endpoint_pattern = re.compile(r"^#.*(https?://[^\s]+)", re.MULTILINE)


def extract_sparql_queries(md_resp: str) -> list[dict[str, str | None]]:
def extract_sparql_queries(md_resp: str) -> list[dict[str, Optional[str]]]:
"""Extract SPARQL queries and endpoint URL from a markdown response."""
extracted_queries = []
queries = queries_pattern.findall(md_resp)
Expand Down Expand Up @@ -134,7 +134,7 @@ def extract_triples(node: Any, endpoint: str):
return query_dict


def validate_sparql_with_void(query: str, endpoint_url: str, prefix_converter: Converter | None = None) -> set[str]:
def validate_sparql_with_void(query: str, endpoint_url: str, prefix_converter: Optional[Converter] = None) -> set[str]:
"""Validate SPARQL query using the VoID description of endpoints. Raise exception if errors found."""
if prefix_converter is None:
prefix_converter = get_prefix_converter(get_prefixes_for_endpoints([endpoint_url]))
Expand All @@ -145,8 +145,8 @@ def validate_triple_pattern(
void_dict: TripleDict,
endpoint: str,
issues: set[str],
parent_type: str | None = None,
parent_pred: str | None = None,
parent_type: Optional[str] = None,
parent_pred: Optional[str] = None,
) -> set[str]:
pred_dict = subj_dict.get(subj, {})
# Direct type provided for this entity
Expand Down
5 changes: 3 additions & 2 deletions src/sparql_llm/void_to_shex.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
from sparql_llm.utils import get_prefix_converter, get_prefixes_for_endpoints, get_void_dict, query_sparql

DEFAULT_NAMESPACES_TO_IGNORE = [
Expand All @@ -16,7 +17,7 @@ def ignore_namespaces(ns_to_ignore: list[str], cls: str) -> bool:


def get_shex_dict_from_void(
endpoint_url: str, prefix_map: dict[str, str] | None = None, namespaces_to_ignore: list[str] | None = None
endpoint_url: str, prefix_map: Optional[dict[str, str]] = None, namespaces_to_ignore: Optional[list[str]] = None
) -> dict[str, dict[str, str]]:
"""Get a dict of shex shapes from the VoID description."""
prefix_map = prefix_map or get_prefixes_for_endpoints([endpoint_url])
Expand Down Expand Up @@ -92,7 +93,7 @@ def get_shex_dict_from_void(
return shex_dict


def get_shex_from_void(endpoint_url: str, namespaces_to_ignore: list[str] | None = None) -> str:
def get_shex_from_void(endpoint_url: str, namespaces_to_ignore: Optional[list[str]] = None) -> str:
"""Function to build complete ShEx from VoID description with prefixes and all shapes"""
prefix_map = get_prefixes_for_endpoints([endpoint_url])
shex_dict = get_shex_dict_from_void(endpoint_url, prefix_map, namespaces_to_ignore)
Expand Down

0 comments on commit 6e372ca

Please sign in to comment.