Skip to content

Commit 3bf5569

Browse files
committed
Adds typing info
1 parent 788f9ef commit 3bf5569

File tree

10 files changed

+132
-102
lines changed

10 files changed

+132
-102
lines changed

.vscode/settings.json

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
{
2+
"python.testing.pytestArgs": [],
3+
"python.testing.unittestEnabled": false,
4+
"python.testing.pytestEnabled": true,
5+
"python.analysis.diagnosticMode": "workspace",
6+
"python.analysis.typeCheckingMode": "strict",
7+
"files.exclude": {
8+
".pytest_cache": true,
9+
".venv": true,
10+
"**/__pycache__": true
11+
}
12+
}

README.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -101,16 +101,16 @@ The nodes can be searched by their ids:
101101
Searches can also use combinations of other attributes, both as strict equality, or using `LIKE` in combination with a trailing `%` for "starts with" or `%` at both ends for "contains":
102102

103103
```
104-
>>> db.atomic(apple, db.find_nodes([db._generate_clause('name', predicate='LIKE')], ('Steve%',)))
104+
>>> db.atomic(apple, db.find_nodes([db.generate_clause('name', predicate='LIKE')], ('Steve%',)))
105105
[{'name': 'Steve Wozniak', 'type': ['person', 'engineer', 'founder'], 'id': 2, 'nickname': 'Woz'}, {'name': 'Steve Jobs', 'type': ['person', 'designer', 'founder'], 'id': 3}]
106-
>>> db.atomic(apple, db.find_nodes([db._generate_clause('name', predicate='LIKE'), db._generate_clause('name', predicate='LIKE', joiner='OR')], ('%Woz%', '%Markkula',)))
106+
>>> db.atomic(apple, db.find_nodes([db.generate_clause('name', predicate='LIKE'), db.generate_clause('name', predicate='LIKE', joiner='OR')], ('%Woz%', '%Markkula',)))
107107
[{'name': 'Steve Wozniak', 'type': ['person', 'engineer', 'founder'], 'id': 2, 'nickname': 'Woz'}, {'name': 'Mike Markkula', 'type': ['person', 'investor'], 'id': 5}]
108108
```
109109

110110
More complex queries to introspect the json body, using the [sqlite json_tree() function](https://www.sqlite.org/json1.html), are also possible, such as this query for every node whose `type` array contains the value `founder`:
111111

112112
```
113-
>>> db.atomic(apple, db.find_nodes([db._generate_clause('type', tree=True)], ('founder',), tree_query=True, key='type'))
113+
>>> db.atomic(apple, db.find_nodes([db.generate_clause('type', tree=True)], ('founder',), tree_query=True, key='type'))
114114
[{'name': 'Steve Wozniak', 'type': ['person', 'engineer', 'founder'], 'id': 2, 'nickname': 'Woz'}, {'name': 'Steve Jobs', 'type': ['person', 'designer', 'founder'], 'id': 3}, {'name': 'Ronald Wayne', 'type': ['person', 'administrator', 'founder'], 'id': 4}]
115115
```
116116

pyproject.toml

+3
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ requires = [
44
"wheel"
55
]
66
build-backend = "setuptools.build_meta"
7+
testpaths = [
8+
"tests"
9+
]

setup.cfg

+6
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,9 @@
22
# This includes the license file(s) in the wheel.
33
# https://wheel.readthedocs.io/en/stable/user_guide.html#including-license-files-in-the-generated-wheel-file
44
license_files = LICENSE
5+
[tool:pytest]
6+
minversion = 6.0
7+
addopts = -ra -q
8+
testpaths =
9+
tests
10+
pythonpath = src tests

setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from setuptools import setup, find_packages
1+
from setuptools import setup
22
import pathlib
33

44
here = pathlib.Path(__file__).parent.resolve()

src/simple_graph_sqlite/database.py

+51-47
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,26 @@
88
using an atomic transaction wrapper function.
99
1010
"""
11-
1211
import sqlite3
1312
import json
1413
import pathlib
1514
from functools import lru_cache
1615
from jinja2 import Environment, BaseLoader, select_autoescape
16+
from typing import Any, Callable, Optional, TypeVar
17+
from pathlib import Path
1718

19+
Json = dict[str, Any]
20+
Identifier = int|str|None
1821

1922
@lru_cache(maxsize=None)
20-
def read_sql(sql_file):
23+
def read_sql(sql_file: str) -> str:
2124
with open(pathlib.Path(__file__).parent.resolve() / "sql" / sql_file) as f:
2225
return f.read()
2326

2427

2528
class SqlTemplateLoader(BaseLoader):
26-
def get_source(self, environment, template):
27-
return read_sql(template), template, True
29+
def get_source(self, environment: "Environment", template: str) -> tuple[str, Optional[str], Optional[Callable[[], bool]]]:
30+
return read_sql(template), template, lambda: True
2831

2932

3033
env = Environment(
@@ -37,7 +40,8 @@ def get_source(self, environment, template):
3740
traverse_template = env.get_template('traverse.template')
3841

3942

40-
def atomic(db_file, cursor_exec_fn):
43+
T = TypeVar("T")
44+
def atomic(db_file: Path, cursor_exec_fn: Callable[[sqlite3.Cursor], T]) -> T:
4145
connection = None
4246
try:
4347
connection = sqlite3.connect(db_file)
@@ -51,37 +55,37 @@ def atomic(db_file, cursor_exec_fn):
5155
return results
5256

5357

54-
def initialize(db_file, schema_file='schema.sql'):
55-
def _init(cursor):
58+
def initialize(db_file: Path, schema_file:str='schema.sql') -> None:
59+
def _init(cursor: sqlite3.Cursor) -> None:
5660
cursor.executescript(read_sql(schema_file))
5761
return atomic(db_file, _init)
5862

5963

60-
def _set_id(identifier, data):
64+
def _set_id(identifier: Identifier, data: Json) -> Json:
6165
if identifier is not None:
6266
data["id"] = identifier
6367
return data
6468

6569

66-
def _insert_node(cursor, identifier, data):
70+
def _insert_node(cursor: sqlite3.Cursor, identifier: Identifier, data: Json) -> None:
6771
cursor.execute(read_sql('insert-node.sql'),
6872
(json.dumps(_set_id(identifier, data)),))
6973

7074

71-
def add_node(data, identifier=None):
72-
def _add_node(cursor):
75+
def add_node(data: Json, identifier:Identifier=None) -> Callable[[sqlite3.Cursor], None]:
76+
def _add_node(cursor: sqlite3.Cursor) -> None:
7377
_insert_node(cursor, identifier, data)
7478
return _add_node
7579

7680

77-
def add_nodes(nodes, ids):
78-
def _add_nodes(cursor):
81+
def add_nodes(nodes: list[Json], ids: list[int|str]):
82+
def _add_nodes(cursor: sqlite3.Cursor) -> None:
7983
cursor.executemany(read_sql('insert-node.sql'), [(x,) for x in map(
8084
lambda node: json.dumps(_set_id(node[0], node[1])), zip(ids, nodes))])
8185
return _add_nodes
8286

8387

84-
def _upsert_node(cursor, identifier, data):
88+
def _upsert_node(cursor: sqlite3.Cursor, identifier:str|int, data: Json) -> None:
8589
current_data = find_node(identifier)(cursor)
8690
if not current_data:
8791
# no prior record exists, so regular insert
@@ -93,50 +97,50 @@ def _upsert_node(cursor, identifier, data):
9397
'update-node.sql'), (json.dumps(_set_id(identifier, updated_data)), identifier,))
9498

9599

96-
def upsert_node(identifier, data):
97-
def _upsert(cursor):
100+
def upsert_node(identifier: str, data: Json) -> Callable[[sqlite3.Cursor], None]:
101+
def _upsert(cursor: sqlite3.Cursor) -> None:
98102
_upsert_node(cursor, identifier, data)
99103
return _upsert
100104

101105

102-
def upsert_nodes(nodes, ids):
103-
def _upsert(cursor):
106+
def upsert_nodes(nodes: list[Json], ids: list[str|int]) -> Callable[[sqlite3.Cursor], None]:
107+
def _upsert(cursor: sqlite3.Cursor) -> None:
104108
for (id, node) in zip(ids, nodes):
105109
_upsert_node(cursor, id, node)
106110
return _upsert
107111

108112

109-
def connect_nodes(source_id, target_id, properties={}):
110-
def _connect_nodes(cursor):
113+
def connect_nodes(source_id: Identifier, target_id: Identifier, properties: Json={}) -> Callable[[sqlite3.Cursor], None]:
114+
def _connect_nodes(cursor: sqlite3.Cursor):
111115
cursor.execute(read_sql('insert-edge.sql'),
112116
(source_id, target_id, json.dumps(properties),))
113117
return _connect_nodes
114118

115119

116-
def connect_many_nodes(sources, targets, properties):
117-
def _connect_nodes(cursor):
120+
def connect_many_nodes(sources: list[str|int], targets: list[str|int], properties: list[Json]) -> Callable[[sqlite3.Cursor], None]:
121+
def _connect_nodes(cursor: sqlite3.Cursor):
118122
cursor.executemany(read_sql(
119123
'insert-edge.sql'), [(x[0], x[1], json.dumps(x[2]),) for x in zip(sources, targets, properties)])
120124
return _connect_nodes
121125

122126

123-
def remove_node(identifier):
124-
def _remove_node(cursor):
127+
def remove_node(identifier: str) -> Callable[[sqlite3.Cursor], None]:
128+
def _remove_node(cursor: sqlite3.Cursor) -> None:
125129
cursor.execute(read_sql('delete-edge.sql'), (identifier, identifier,))
126130
cursor.execute(read_sql('delete-node.sql'), (identifier,))
127131
return _remove_node
128132

129133

130-
def remove_nodes(identifiers):
131-
def _remove_node(cursor):
134+
def remove_nodes(identifiers: list[str|int]) -> Callable[[sqlite3.Cursor], None]:
135+
def _remove_node(cursor: sqlite3.Cursor):
132136
cursor.executemany(read_sql(
133137
'delete-edge.sql'), [(identifier, identifier,) for identifier in identifiers])
134138
cursor.executemany(read_sql('delete-node.sql'),
135139
[(identifier,) for identifier in identifiers])
136140
return _remove_node
137141

138142

139-
def _generate_clause(key, predicate=None, joiner=None, tree=False, tree_with_key=False):
143+
def generate_clause(key:str, predicate:None|str=None, joiner:None|str=None, tree:bool=False, tree_with_key:bool=False) -> str:
140144
'''Given at minimum a key in the body json, generate a query clause
141145
which can be bound to a corresponding value at point of execution'''
142146

@@ -154,7 +158,7 @@ def _generate_clause(key, predicate=None, joiner=None, tree=False, tree_with_key
154158
return clause_template.render(and_or=joiner, key=key, predicate=predicate, key_value=True)
155159

156160

157-
def _generate_query(where_clauses, result_column=None, key=None, tree=False):
161+
def _generate_query(where_clauses:list[str], result_column:str|None=None, key:None|str=None, tree:bool=False) -> str:
158162
'''Generate the search query, selecting either the id or the body,
159163
adding the json_tree function and optionally the key, as needed'''
160164

@@ -170,43 +174,43 @@ def _generate_query(where_clauses, result_column=None, key=None, tree=False):
170174
return search_template.render(result_column=result_column, search_clauses=where_clauses)
171175

172176

173-
def find_node(identifier):
174-
def _find_node(cursor):
177+
def find_node(identifier: str|int) -> Callable[[sqlite3.Cursor], Json]:
178+
def _find_node(cursor: sqlite3.Cursor) -> Json:
175179
query = _generate_query([clause_template.render(id_lookup=True)])
176180
result = cursor.execute(query, (identifier,)).fetchone()
177181
return {} if not result else json.loads(result[0])
178182
return _find_node
179183

180184

181-
def _parse_search_results(results, idx=0):
182-
print(results)
185+
def _parse_search_results(results:list[tuple[str, ...]], idx:int=0) -> list[Json]:
183186
return [json.loads(item[idx]) for item in results]
184187

185188

186-
def find_nodes(where_clauses, bindings, tree_query=False, key=None):
187-
def _find_nodes(cursor):
189+
def find_nodes(where_clauses:list[str], bindings:tuple[str, ...], tree_query:bool=False, key:str|None=None):
190+
def _find_nodes(cursor:sqlite3.Cursor):
188191
query = _generate_query(where_clauses, key=key, tree=tree_query)
189192
return _parse_search_results(cursor.execute(query, bindings).fetchall())
190193
return _find_nodes
191194

192195

193-
def find_neighbors(with_bodies=False):
196+
def find_neighbors(with_bodies:bool=False) -> str:
194197
return traverse_template.render(with_bodies=with_bodies, inbound=True, outbound=True)
195198

196199

197-
def find_outbound_neighbors(with_bodies=False):
200+
def find_outbound_neighbors(with_bodies:bool=False) -> str:
198201
return traverse_template.render(with_bodies=with_bodies, outbound=True)
199202

200203

201-
def find_inbound_neighbors(with_bodies=False):
204+
def find_inbound_neighbors(with_bodies:bool=False) -> str:
202205
return traverse_template.render(with_bodies=with_bodies, inbound=True)
203206

204207

205-
def traverse(db_file, src, tgt=None, neighbors_fn=find_neighbors, with_bodies=False):
206-
def _traverse(cursor):
207-
path = []
208+
def traverse(db_file:Path, src:Identifier, tgt:Identifier=None, neighbors_fn:Callable[[bool], str]=find_neighbors, with_bodies:bool=False) -> list[str|tuple[str,str,str]]:
209+
def _traverse(cursor: sqlite3.Cursor) -> list[str|tuple[str, str, str]]:
210+
path: list[str|tuple[str, str, str]] = []
208211
target = json.dumps(tgt)
209-
for row in cursor.execute(neighbors_fn(with_bodies=with_bodies), (src,)):
212+
neighbors_sql:str = neighbors_fn(with_bodies)
213+
for row in cursor.execute(neighbors_sql, (src,)):
210214
if row:
211215
if with_bodies:
212216
identifier, obj, _ = row
@@ -223,21 +227,21 @@ def _traverse(cursor):
223227
return atomic(db_file, _traverse)
224228

225229

226-
def connections_in():
230+
def connections_in() -> str:
227231
return read_sql('search-edges-inbound.sql')
228232

229233

230-
def connections_out():
234+
def connections_out() -> str:
231235
return read_sql('search-edges-outbound.sql')
232236

233237

234-
def get_connections_one_way(identifier, direction=connections_in):
235-
def _get_connections(cursor):
238+
def get_connections_one_way(identifier: Identifier, direction: Callable[[], str] = connections_in):
239+
def _get_connections(cursor: sqlite3.Cursor) -> list[tuple[str, str, str]]:
236240
return cursor.execute(direction(), (identifier,)).fetchall()
237241
return _get_connections
238242

239243

240-
def get_connections(identifier):
241-
def _get_connections(cursor):
244+
def get_connections(identifier: str|int) -> Callable[[sqlite3.Cursor], list[tuple[str, str, str]]]:
245+
def _get_connections(cursor: sqlite3.Cursor) -> list[tuple[str, str, str]]:
242246
return cursor.execute(read_sql('search-edges.sql'), (identifier, identifier,)).fetchall()
243247
return _get_connections

0 commit comments

Comments
 (0)