Skip to content

Commit c4a2d1a

Browse files
committed
Replace aiosqlite with sqlite-anyio
1 parent 6fe12dd commit c4a2d1a

File tree

3 files changed

+98
-81
lines changed

3 files changed

+98
-81
lines changed

Diff for: pycrdt_websocket/ystore.py

+82-71
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@
1010
from pathlib import Path
1111
from typing import AsyncIterator, Awaitable, Callable, cast
1212

13-
import aiosqlite
1413
import anyio
1514
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
1615
from anyio.abc import TaskGroup, TaskStatus
1716
from pycrdt import Doc
17+
from sqlite_anyio import Connection, connect
1818

1919
from .yutils import Decoder, get_new_path, write_var_uint
2020

@@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
8383
if self._task_group is not None:
8484
raise RuntimeError("YStore already running")
8585

86-
self.started.set()
87-
self._starting = False
88-
task_status.started()
86+
async with create_task_group() as self._task_group:
87+
self.started.set()
88+
self._starting = False
89+
task_status.started()
8990

90-
def stop(self) -> None:
91+
async def stop(self) -> None:
9192
"""Stop the store."""
9293
if self._task_group is None:
9394
raise RuntimeError("YStore not running")
@@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore):
300301
path: str
301302
lock: Lock
302303
db_initialized: Event
304+
_db: Connection | None
303305

304306
def __init__(
305307
self,
@@ -319,6 +321,7 @@ def __init__(
319321
self.log = log or getLogger(__name__)
320322
self.lock = Lock()
321323
self.db_initialized = Event()
324+
self._db = None
322325

323326
async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
324327
"""Start the SQLiteYStore.
@@ -340,43 +343,54 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
340343
self._starting = False
341344
task_status.started()
342345

346+
async def stop(self) -> None:
347+
"""Stop the store."""
348+
if self._db is not None:
349+
await self._db.close()
350+
await super().stop()
351+
343352
async def _init_db(self):
344353
create_db = False
345354
move_db = False
346355
if not await anyio.Path(self.db_path).exists():
347356
create_db = True
348357
else:
349358
async with self.lock:
350-
async with aiosqlite.connect(self.db_path) as db:
351-
cursor = await db.execute(
352-
"SELECT count(name) FROM sqlite_master "
353-
"WHERE type='table' and name='yupdates'"
354-
)
355-
table_exists = (await cursor.fetchone())[0]
356-
if table_exists:
357-
cursor = await db.execute("pragma user_version")
358-
version = (await cursor.fetchone())[0]
359-
if version != self.version:
360-
move_db = True
361-
create_db = True
362-
else:
359+
db = await connect(self.db_path)
360+
cursor = await db.cursor()
361+
await cursor.execute(
362+
"SELECT count(name) FROM sqlite_master "
363+
"WHERE type='table' and name='yupdates'"
364+
)
365+
table_exists = (await cursor.fetchone())[0]
366+
if table_exists:
367+
await cursor.execute("pragma user_version")
368+
version = (await cursor.fetchone())[0]
369+
if version != self.version:
370+
move_db = True
363371
create_db = True
372+
else:
373+
create_db = True
374+
await db.close()
364375
if move_db:
365376
new_path = await get_new_path(self.db_path)
366377
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
367378
await anyio.Path(self.db_path).rename(new_path)
368379
if create_db:
369380
async with self.lock:
370-
async with aiosqlite.connect(self.db_path) as db:
371-
await db.execute(
372-
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
373-
"metadata BLOB, timestamp REAL NOT NULL)"
374-
)
375-
await db.execute(
376-
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
377-
)
378-
await db.execute(f"PRAGMA user_version = {self.version}")
379-
await db.commit()
381+
db = await connect(self.db_path)
382+
cursor = await db.cursor()
383+
await cursor.execute(
384+
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
385+
"metadata BLOB, timestamp REAL NOT NULL)"
386+
)
387+
await cursor.execute(
388+
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
389+
)
390+
await cursor.execute(f"PRAGMA user_version = {self.version}")
391+
await db.commit()
392+
await db.close()
393+
self._db = await connect(self.db_path)
380394
self.db_initialized.set()
381395

382396
async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
@@ -388,17 +402,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
388402
await self.db_initialized.wait()
389403
try:
390404
async with self.lock:
391-
async with aiosqlite.connect(self.db_path) as db:
392-
async with db.execute(
393-
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
394-
(self.path,),
395-
) as cursor:
396-
found = False
397-
async for update, metadata, timestamp in cursor:
398-
found = True
399-
yield update, metadata, timestamp
400-
if not found:
401-
raise YDocNotFound
405+
cursor = await self._db.cursor()
406+
await cursor.execute(
407+
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
408+
(self.path,),
409+
)
410+
found = False
411+
for update, metadata, timestamp in await cursor.fetchall():
412+
found = True
413+
yield update, metadata, timestamp
414+
if not found:
415+
raise YDocNotFound
402416
except Exception:
403417
raise YDocNotFound
404418

@@ -410,38 +424,35 @@ async def write(self, data: bytes) -> None:
410424
"""
411425
await self.db_initialized.wait()
412426
async with self.lock:
413-
async with aiosqlite.connect(self.db_path) as db:
414-
# first, determine time elapsed since last update
415-
cursor = await db.execute(
416-
"SELECT timestamp FROM yupdates WHERE path = ? "
417-
"ORDER BY timestamp DESC LIMIT 1",
418-
(self.path,),
419-
)
420-
row = await cursor.fetchone()
421-
diff = (time.time() - row[0]) if row else 0
422-
423-
if self.document_ttl is not None and diff > self.document_ttl:
424-
# squash updates
425-
ydoc = Doc()
426-
async with db.execute(
427-
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
428-
) as cursor:
429-
async for (update,) in cursor:
430-
ydoc.apply_update(update)
431-
# delete history
432-
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
433-
# insert squashed updates
434-
squashed_update = ydoc.get_update()
435-
metadata = await self.get_metadata()
436-
await db.execute(
437-
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
438-
(self.path, squashed_update, metadata, time.time()),
439-
)
440-
441-
# finally, write this update to the DB
427+
# first, determine time elapsed since last update
428+
cursor = await self._db.cursor()
429+
await cursor.execute(
430+
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
431+
(self.path,),
432+
)
433+
row = await cursor.fetchone()
434+
diff = (time.time() - row[0]) if row else 0
435+
436+
if self.document_ttl is not None and diff > self.document_ttl:
437+
# squash updates
438+
ydoc = Doc()
439+
await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,))
440+
for (update,) in await cursor.fetchall():
441+
ydoc.apply_update(update)
442+
# delete history
443+
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
444+
# insert squashed updates
445+
squashed_update = ydoc.get_update()
442446
metadata = await self.get_metadata()
443-
await db.execute(
447+
await cursor.execute(
444448
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
445-
(self.path, data, metadata, time.time()),
449+
(self.path, squashed_update, metadata, time.time()),
446450
)
447-
await db.commit()
451+
452+
# finally, write this update to the DB
453+
metadata = await self.get_metadata()
454+
await cursor.execute(
455+
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
456+
(self.path, data, metadata, time.time()),
457+
)
458+
await self._db.commit()

Diff for: pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ classifiers = [
2929
]
3030
dependencies = [
3131
"anyio >=3.6.2,<5",
32-
"aiosqlite >=0.18.0,<1",
32+
"sqlite-anyio >=0.2.0,<0.3.0",
3333
"pycrdt >=0.8.7,<0.9.0",
3434
]
3535

@@ -68,7 +68,7 @@ include = [
6868

6969
[tool.ruff]
7070
line-length = 99
71-
select = [
71+
lint.select = [
7272
"ASYNC", # flake8-async
7373
"E", "F", "W", # default Flake8
7474
"G", # flake8-logging-format

Diff for: tests/test_ystore.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from pathlib import Path
55
from unittest.mock import patch
66

7-
import aiosqlite
87
import pytest
8+
from sqlite_anyio import connect
99

1010
from pycrdt_websocket.ystore import SQLiteYStore, TempFileYStore
1111

@@ -59,31 +59,36 @@ async def test_ystore(YStore):
5959

6060
assert i == len(data)
6161

62+
await ystore.stop()
63+
6264

6365
@pytest.mark.anyio
6466
async def test_document_ttl_sqlite_ystore(test_ydoc):
6567
store_name = "my_store"
6668
ystore = MySQLiteYStore(store_name, delete_db=True)
6769
await ystore.start()
6870
now = time.time()
71+
db = await connect(ystore.db_path)
72+
cursor = await db.cursor()
6973

7074
for i in range(3):
7175
# assert that adding a record before document TTL doesn't delete document history
7276
with patch("time.time") as mock_time:
7377
mock_time.return_value = now
7478
await ystore.write(test_ydoc.update())
75-
async with aiosqlite.connect(ystore.db_path) as db:
76-
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[
77-
0
78-
] == i + 1
79+
assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[
80+
0
81+
] == i + 1
7982

8083
# assert that adding a record after document TTL deletes previous document history
8184
with patch("time.time") as mock_time:
8285
mock_time.return_value = now + ystore.document_ttl + 1
8386
await ystore.write(test_ydoc.update())
84-
async with aiosqlite.connect(ystore.db_path) as db:
85-
# two updates in DB: one squashed update and the new update
86-
assert (await (await db.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2
87+
# two updates in DB: one squashed update and the new update
88+
assert (await (await cursor.execute("SELECT count(*) FROM yupdates")).fetchone())[0] == 2
89+
90+
await db.close()
91+
await ystore.stop()
8792

8893

8994
@pytest.mark.anyio
@@ -97,3 +102,4 @@ async def test_version(YStore, caplog):
97102
await ystore.write(b"foo")
98103
YStore.version = prev_version
99104
assert "YStore version mismatch" in caplog.text
105+
await ystore.stop()

0 commit comments

Comments
 (0)