8
8
using an atomic transaction wrapper function.
9
9
10
10
"""
11
-
12
11
import sqlite3
13
12
import json
14
13
import pathlib
15
14
from functools import lru_cache
16
15
from jinja2 import Environment , BaseLoader , select_autoescape
16
+ from typing import Any , Callable , Optional , TypeVar
17
+ from pathlib import Path
17
18
19
+ Json = dict [str , Any ]
20
+ Identifier = int | str | None
18
21
19
22
@lru_cache (maxsize = None )
20
- def read_sql (sql_file ) :
23
+ def read_sql (sql_file : str ) -> str :
21
24
with open (pathlib .Path (__file__ ).parent .resolve () / "sql" / sql_file ) as f :
22
25
return f .read ()
23
26
24
27
25
28
class SqlTemplateLoader (BaseLoader ):
26
- def get_source (self , environment , template ) :
27
- return read_sql (template ), template , True
29
+ def get_source (self , environment : "Environment" , template : str ) -> tuple [ str , Optional [ str ], Optional [ Callable [[], bool ]]] :
30
+ return read_sql (template ), template , lambda : True
28
31
29
32
30
33
env = Environment (
@@ -37,7 +40,8 @@ def get_source(self, environment, template):
37
40
traverse_template = env .get_template ('traverse.template' )
38
41
39
42
40
- def atomic (db_file , cursor_exec_fn ):
43
+ T = TypeVar ("T" )
44
+ def atomic (db_file : Path , cursor_exec_fn : Callable [[sqlite3 .Cursor ], T ]) -> T :
41
45
connection = None
42
46
try :
43
47
connection = sqlite3 .connect (db_file )
@@ -51,37 +55,37 @@ def atomic(db_file, cursor_exec_fn):
51
55
return results
52
56
53
57
54
- def initialize (db_file , schema_file = 'schema.sql' ):
55
- def _init (cursor ) :
58
+ def initialize (db_file : Path , schema_file : str = 'schema.sql' ) -> None :
59
+ def _init (cursor : sqlite3 . Cursor ) -> None :
56
60
cursor .executescript (read_sql (schema_file ))
57
61
return atomic (db_file , _init )
58
62
59
63
60
- def _set_id (identifier , data ) :
64
+ def _set_id (identifier : Identifier , data : Json ) -> Json :
61
65
if identifier is not None :
62
66
data ["id" ] = identifier
63
67
return data
64
68
65
69
66
- def _insert_node (cursor , identifier , data ) :
70
+ def _insert_node (cursor : sqlite3 . Cursor , identifier : Identifier , data : Json ) -> None :
67
71
cursor .execute (read_sql ('insert-node.sql' ),
68
72
(json .dumps (_set_id (identifier , data )),))
69
73
70
74
71
- def add_node (data , identifier = None ):
72
- def _add_node (cursor ) :
75
+ def add_node (data : Json , identifier : Identifier = None ) -> Callable [[ sqlite3 . Cursor ], None ] :
76
+ def _add_node (cursor : sqlite3 . Cursor ) -> None :
73
77
_insert_node (cursor , identifier , data )
74
78
return _add_node
75
79
76
80
77
- def add_nodes (nodes , ids ):
78
- def _add_nodes (cursor ) :
81
+ def add_nodes (nodes : list [ Json ] , ids : list [ int | str ] ):
82
+ def _add_nodes (cursor : sqlite3 . Cursor ) -> None :
79
83
cursor .executemany (read_sql ('insert-node.sql' ), [(x ,) for x in map (
80
84
lambda node : json .dumps (_set_id (node [0 ], node [1 ])), zip (ids , nodes ))])
81
85
return _add_nodes
82
86
83
87
84
- def _upsert_node (cursor , identifier , data ) :
88
+ def _upsert_node (cursor : sqlite3 . Cursor , identifier : str | int , data : Json ) -> None :
85
89
current_data = find_node (identifier )(cursor )
86
90
if not current_data :
87
91
# no prior record exists, so regular insert
@@ -93,50 +97,50 @@ def _upsert_node(cursor, identifier, data):
93
97
'update-node.sql' ), (json .dumps (_set_id (identifier , updated_data )), identifier ,))
94
98
95
99
96
- def upsert_node (identifier , data ) :
97
- def _upsert (cursor ) :
100
+ def upsert_node (identifier : str , data : Json ) -> Callable [[ sqlite3 . Cursor ], None ] :
101
+ def _upsert (cursor : sqlite3 . Cursor ) -> None :
98
102
_upsert_node (cursor , identifier , data )
99
103
return _upsert
100
104
101
105
102
- def upsert_nodes (nodes , ids ) :
103
- def _upsert (cursor ) :
106
+ def upsert_nodes (nodes : list [ Json ] , ids : list [ str | int ]) -> Callable [[ sqlite3 . Cursor ], None ] :
107
+ def _upsert (cursor : sqlite3 . Cursor ) -> None :
104
108
for (id , node ) in zip (ids , nodes ):
105
109
_upsert_node (cursor , id , node )
106
110
return _upsert
107
111
108
112
109
- def connect_nodes (source_id , target_id , properties = {}):
110
- def _connect_nodes (cursor ):
113
+ def connect_nodes (source_id : Identifier , target_id : Identifier , properties : Json = {}) -> Callable [[ sqlite3 . Cursor ], None ] :
114
+ def _connect_nodes (cursor : sqlite3 . Cursor ):
111
115
cursor .execute (read_sql ('insert-edge.sql' ),
112
116
(source_id , target_id , json .dumps (properties ),))
113
117
return _connect_nodes
114
118
115
119
116
- def connect_many_nodes (sources , targets , properties ) :
117
- def _connect_nodes (cursor ):
120
+ def connect_many_nodes (sources : list [ str | int ] , targets : list [ str | int ] , properties : list [ Json ]) -> Callable [[ sqlite3 . Cursor ], None ] :
121
+ def _connect_nodes (cursor : sqlite3 . Cursor ):
118
122
cursor .executemany (read_sql (
119
123
'insert-edge.sql' ), [(x [0 ], x [1 ], json .dumps (x [2 ]),) for x in zip (sources , targets , properties )])
120
124
return _connect_nodes
121
125
122
126
123
- def remove_node (identifier ) :
124
- def _remove_node (cursor ) :
127
+ def remove_node (identifier : str ) -> Callable [[ sqlite3 . Cursor ], None ] :
128
+ def _remove_node (cursor : sqlite3 . Cursor ) -> None :
125
129
cursor .execute (read_sql ('delete-edge.sql' ), (identifier , identifier ,))
126
130
cursor .execute (read_sql ('delete-node.sql' ), (identifier ,))
127
131
return _remove_node
128
132
129
133
130
- def remove_nodes (identifiers ) :
131
- def _remove_node (cursor ):
134
+ def remove_nodes (identifiers : list [ str | int ]) -> Callable [[ sqlite3 . Cursor ], None ] :
135
+ def _remove_node (cursor : sqlite3 . Cursor ):
132
136
cursor .executemany (read_sql (
133
137
'delete-edge.sql' ), [(identifier , identifier ,) for identifier in identifiers ])
134
138
cursor .executemany (read_sql ('delete-node.sql' ),
135
139
[(identifier ,) for identifier in identifiers ])
136
140
return _remove_node
137
141
138
142
139
- def _generate_clause (key , predicate = None , joiner = None , tree = False , tree_with_key = False ):
143
+ def generate_clause (key : str , predicate : None | str = None , joiner : None | str = None , tree : bool = False , tree_with_key : bool = False ) -> str :
140
144
'''Given at minimum a key in the body json, generate a query clause
141
145
which can be bound to a corresponding value at point of execution'''
142
146
@@ -154,7 +158,7 @@ def _generate_clause(key, predicate=None, joiner=None, tree=False, tree_with_key
154
158
return clause_template .render (and_or = joiner , key = key , predicate = predicate , key_value = True )
155
159
156
160
157
- def _generate_query (where_clauses , result_column = None , key = None , tree = False ):
161
+ def _generate_query (where_clauses : list [ str ] , result_column : str | None = None , key : None | str = None , tree : bool = False ) -> str :
158
162
'''Generate the search query, selecting either the id or the body,
159
163
adding the json_tree function and optionally the key, as needed'''
160
164
@@ -170,43 +174,43 @@ def _generate_query(where_clauses, result_column=None, key=None, tree=False):
170
174
return search_template .render (result_column = result_column , search_clauses = where_clauses )
171
175
172
176
173
- def find_node (identifier ) :
174
- def _find_node (cursor ) :
177
+ def find_node (identifier : str | int ) -> Callable [[ sqlite3 . Cursor ], Json ] :
178
+ def _find_node (cursor : sqlite3 . Cursor ) -> Json :
175
179
query = _generate_query ([clause_template .render (id_lookup = True )])
176
180
result = cursor .execute (query , (identifier ,)).fetchone ()
177
181
return {} if not result else json .loads (result [0 ])
178
182
return _find_node
179
183
180
184
181
- def _parse_search_results (results , idx = 0 ):
182
- print (results )
185
+ def _parse_search_results (results :list [tuple [str , ...]], idx :int = 0 ) -> list [Json ]:
183
186
return [json .loads (item [idx ]) for item in results ]
184
187
185
188
186
- def find_nodes (where_clauses , bindings , tree_query = False , key = None ):
187
- def _find_nodes (cursor ):
189
+ def find_nodes (where_clauses : list [ str ] , bindings : tuple [ str , ...], tree_query : bool = False , key : str | None = None ):
190
+ def _find_nodes (cursor : sqlite3 . Cursor ):
188
191
query = _generate_query (where_clauses , key = key , tree = tree_query )
189
192
return _parse_search_results (cursor .execute (query , bindings ).fetchall ())
190
193
return _find_nodes
191
194
192
195
193
- def find_neighbors (with_bodies = False ):
196
+ def find_neighbors (with_bodies : bool = False ) -> str :
194
197
return traverse_template .render (with_bodies = with_bodies , inbound = True , outbound = True )
195
198
196
199
197
- def find_outbound_neighbors (with_bodies = False ):
200
+ def find_outbound_neighbors (with_bodies : bool = False ) -> str :
198
201
return traverse_template .render (with_bodies = with_bodies , outbound = True )
199
202
200
203
201
- def find_inbound_neighbors (with_bodies = False ):
204
+ def find_inbound_neighbors (with_bodies : bool = False ) -> str :
202
205
return traverse_template .render (with_bodies = with_bodies , inbound = True )
203
206
204
207
205
- def traverse (db_file , src , tgt = None , neighbors_fn = find_neighbors , with_bodies = False ):
206
- def _traverse (cursor ) :
207
- path = []
208
+ def traverse (db_file : Path , src : Identifier , tgt : Identifier = None , neighbors_fn : Callable [[ bool ], str ] = find_neighbors , with_bodies : bool = False ) -> list [ str | tuple [ str , str , str ]] :
209
+ def _traverse (cursor : sqlite3 . Cursor ) -> list [ str | tuple [ str , str , str ]] :
210
+ path : list [ str | tuple [ str , str , str ]] = []
208
211
target = json .dumps (tgt )
209
- for row in cursor .execute (neighbors_fn (with_bodies = with_bodies ), (src ,)):
212
+ neighbors_sql :str = neighbors_fn (with_bodies )
213
+ for row in cursor .execute (neighbors_sql , (src ,)):
210
214
if row :
211
215
if with_bodies :
212
216
identifier , obj , _ = row
@@ -223,21 +227,21 @@ def _traverse(cursor):
223
227
return atomic (db_file , _traverse )
224
228
225
229
226
- def connections_in ():
230
+ def connections_in () -> str :
227
231
return read_sql ('search-edges-inbound.sql' )
228
232
229
233
230
- def connections_out ():
234
+ def connections_out () -> str :
231
235
return read_sql ('search-edges-outbound.sql' )
232
236
233
237
234
- def get_connections_one_way (identifier , direction = connections_in ):
235
- def _get_connections (cursor ) :
238
+ def get_connections_one_way (identifier : Identifier , direction : Callable [[], str ] = connections_in ):
239
+ def _get_connections (cursor : sqlite3 . Cursor ) -> list [ tuple [ str , str , str ]] :
236
240
return cursor .execute (direction (), (identifier ,)).fetchall ()
237
241
return _get_connections
238
242
239
243
240
- def get_connections (identifier ) :
241
- def _get_connections (cursor ) :
244
+ def get_connections (identifier : str | int ) -> Callable [[ sqlite3 . Cursor ], list [ tuple [ str , str , str ]]] :
245
+ def _get_connections (cursor : sqlite3 . Cursor ) -> list [ tuple [ str , str , str ]] :
242
246
return cursor .execute (read_sql ('search-edges.sql' ), (identifier , identifier ,)).fetchall ()
243
247
return _get_connections
0 commit comments