10
10
from pathlib import Path
11
11
from typing import AsyncIterator , Awaitable , Callable , cast
12
12
13
- import aiosqlite
14
13
import anyio
15
14
from anyio import TASK_STATUS_IGNORED , Event , Lock , create_task_group
16
15
from anyio .abc import TaskGroup , TaskStatus
17
16
from pycrdt import Doc
17
+ from sqlite_anyio import Connection , connect
18
18
19
19
from .yutils import Decoder , get_new_path , write_var_uint
20
20
@@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
83
83
if self ._task_group is not None :
84
84
raise RuntimeError ("YStore already running" )
85
85
86
- self .started .set ()
87
- self ._starting = False
88
- task_status .started ()
86
+ async with create_task_group () as self ._task_group :
87
+ self .started .set ()
88
+ self ._starting = False
89
+ task_status .started ()
89
90
90
- def stop (self ) -> None :
91
+ async def stop (self ) -> None :
91
92
"""Stop the store."""
92
93
if self ._task_group is None :
93
94
raise RuntimeError ("YStore not running" )
@@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore):
300
301
path : str
301
302
lock : Lock
302
303
db_initialized : Event
304
+ _db : Connection | None
303
305
304
306
def __init__ (
305
307
self ,
@@ -319,6 +321,7 @@ def __init__(
319
321
self .log = log or getLogger (__name__ )
320
322
self .lock = Lock ()
321
323
self .db_initialized = Event ()
324
+ self ._db = None
322
325
323
326
async def start (self , * , task_status : TaskStatus [None ] = TASK_STATUS_IGNORED ):
324
327
"""Start the SQLiteYStore.
@@ -340,43 +343,54 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
340
343
self ._starting = False
341
344
task_status .started ()
342
345
346
+ async def stop (self ) -> None :
347
+ """Stop the store."""
348
+ if self ._db is not None :
349
+ await self ._db .close ()
350
+ await super ().stop ()
351
+
343
352
async def _init_db (self ):
344
353
create_db = False
345
354
move_db = False
346
355
if not await anyio .Path (self .db_path ).exists ():
347
356
create_db = True
348
357
else :
349
358
async with self .lock :
350
- async with aiosqlite .connect (self .db_path ) as db :
351
- cursor = await db .execute (
352
- "SELECT count(name) FROM sqlite_master "
353
- "WHERE type='table' and name='yupdates'"
354
- )
355
- table_exists = (await cursor .fetchone ())[0 ]
356
- if table_exists :
357
- cursor = await db .execute ("pragma user_version" )
358
- version = (await cursor .fetchone ())[0 ]
359
- if version != self .version :
360
- move_db = True
361
- create_db = True
362
- else :
359
+ db = await connect (self .db_path )
360
+ cursor = await db .cursor ()
361
+ await cursor .execute (
362
+ "SELECT count(name) FROM sqlite_master "
363
+ "WHERE type='table' and name='yupdates'"
364
+ )
365
+ table_exists = (await cursor .fetchone ())[0 ]
366
+ if table_exists :
367
+ await cursor .execute ("pragma user_version" )
368
+ version = (await cursor .fetchone ())[0 ]
369
+ if version != self .version :
370
+ move_db = True
363
371
create_db = True
372
+ else :
373
+ create_db = True
374
+ await db .close ()
364
375
if move_db :
365
376
new_path = await get_new_path (self .db_path )
366
377
self .log .warning ("YStore version mismatch, moving %s to %s" , self .db_path , new_path )
367
378
await anyio .Path (self .db_path ).rename (new_path )
368
379
if create_db :
369
380
async with self .lock :
370
- async with aiosqlite .connect (self .db_path ) as db :
371
- await db .execute (
372
- "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
373
- "metadata BLOB, timestamp REAL NOT NULL)"
374
- )
375
- await db .execute (
376
- "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
377
- )
378
- await db .execute (f"PRAGMA user_version = { self .version } " )
379
- await db .commit ()
381
+ db = await connect (self .db_path )
382
+ cursor = await db .cursor ()
383
+ await cursor .execute (
384
+ "CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
385
+ "metadata BLOB, timestamp REAL NOT NULL)"
386
+ )
387
+ await cursor .execute (
388
+ "CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
389
+ )
390
+ await cursor .execute (f"PRAGMA user_version = { self .version } " )
391
+ await db .commit ()
392
+ await db .close ()
393
+ self ._db = await connect (self .db_path )
380
394
self .db_initialized .set ()
381
395
382
396
async def read (self ) -> AsyncIterator [tuple [bytes , bytes , float ]]:
@@ -388,17 +402,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
388
402
await self .db_initialized .wait ()
389
403
try :
390
404
async with self .lock :
391
- async with aiosqlite . connect ( self .db_path ) as db :
392
- async with db .execute (
393
- "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?" ,
394
- (self .path ,),
395
- ) as cursor :
396
- found = False
397
- async for update , metadata , timestamp in cursor :
398
- found = True
399
- yield update , metadata , timestamp
400
- if not found :
401
- raise YDocNotFound
405
+ cursor = await self ._db . cursor ()
406
+ await cursor .execute (
407
+ "SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?" ,
408
+ (self .path ,),
409
+ )
410
+ found = False
411
+ for update , metadata , timestamp in await cursor . fetchall () :
412
+ found = True
413
+ yield update , metadata , timestamp
414
+ if not found :
415
+ raise YDocNotFound
402
416
except Exception :
403
417
raise YDocNotFound
404
418
@@ -410,38 +424,35 @@ async def write(self, data: bytes) -> None:
410
424
"""
411
425
await self .db_initialized .wait ()
412
426
async with self .lock :
413
- async with aiosqlite .connect (self .db_path ) as db :
414
- # first, determine time elapsed since last update
415
- cursor = await db .execute (
416
- "SELECT timestamp FROM yupdates WHERE path = ? "
417
- "ORDER BY timestamp DESC LIMIT 1" ,
418
- (self .path ,),
419
- )
420
- row = await cursor .fetchone ()
421
- diff = (time .time () - row [0 ]) if row else 0
422
-
423
- if self .document_ttl is not None and diff > self .document_ttl :
424
- # squash updates
425
- ydoc = Doc ()
426
- async with db .execute (
427
- "SELECT yupdate FROM yupdates WHERE path = ?" , (self .path ,)
428
- ) as cursor :
429
- async for (update ,) in cursor :
430
- ydoc .apply_update (update )
431
- # delete history
432
- await db .execute ("DELETE FROM yupdates WHERE path = ?" , (self .path ,))
433
- # insert squashed updates
434
- squashed_update = ydoc .get_update ()
435
- metadata = await self .get_metadata ()
436
- await db .execute (
437
- "INSERT INTO yupdates VALUES (?, ?, ?, ?)" ,
438
- (self .path , squashed_update , metadata , time .time ()),
439
- )
440
-
441
- # finally, write this update to the DB
427
+ # first, determine time elapsed since last update
428
+ cursor = await self ._db .cursor ()
429
+ await cursor .execute (
430
+ "SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1" ,
431
+ (self .path ,),
432
+ )
433
+ row = await cursor .fetchone ()
434
+ diff = (time .time () - row [0 ]) if row else 0
435
+
436
+ if self .document_ttl is not None and diff > self .document_ttl :
437
+ # squash updates
438
+ ydoc = Doc ()
439
+ await cursor .execute ("SELECT yupdate FROM yupdates WHERE path = ?" , (self .path ,))
440
+ for (update ,) in await cursor .fetchall ():
441
+ ydoc .apply_update (update )
442
+ # delete history
443
+ await cursor .execute ("DELETE FROM yupdates WHERE path = ?" , (self .path ,))
444
+ # insert squashed updates
445
+ squashed_update = ydoc .get_update ()
442
446
metadata = await self .get_metadata ()
443
- await db .execute (
447
+ await cursor .execute (
444
448
"INSERT INTO yupdates VALUES (?, ?, ?, ?)" ,
445
- (self .path , data , metadata , time .time ()),
449
+ (self .path , squashed_update , metadata , time .time ()),
446
450
)
447
- await db .commit ()
451
+
452
+ # finally, write this update to the DB
453
+ metadata = await self .get_metadata ()
454
+ await cursor .execute (
455
+ "INSERT INTO yupdates VALUES (?, ?, ?, ?)" ,
456
+ (self .path , data , metadata , time .time ()),
457
+ )
458
+ await self ._db .commit ()
0 commit comments