Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace aiosqlite with sqlite-anyio #22

Merged
merged 3 commits into from
Mar 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 81 additions & 71 deletions pycrdt_websocket/ystore.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
from pathlib import Path
from typing import AsyncIterator, Awaitable, Callable, cast

import aiosqlite
import anyio
from anyio import TASK_STATUS_IGNORED, Event, Lock, create_task_group
from anyio.abc import TaskGroup, TaskStatus
from pycrdt import Doc
from sqlite_anyio import Connection, connect

from .yutils import Decoder, get_new_path, write_var_uint

Expand Down Expand Up @@ -83,11 +83,12 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
if self._task_group is not None:
raise RuntimeError("YStore already running")

self.started.set()
self._starting = False
task_status.started()
async with create_task_group() as self._task_group:
self.started.set()
self._starting = False
task_status.started()

def stop(self) -> None:
async def stop(self) -> None:
"""Stop the store."""
if self._task_group is None:
raise RuntimeError("YStore not running")
Expand Down Expand Up @@ -300,6 +301,7 @@ class MySQLiteYStore(SQLiteYStore):
path: str
lock: Lock
db_initialized: Event
_db: Connection

def __init__(
self,
Expand Down Expand Up @@ -340,43 +342,54 @@ async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED):
self._starting = False
task_status.started()

async def stop(self) -> None:
"""Stop the store."""
if self.db_initialized.is_set():
await self._db.close()
await super().stop()

async def _init_db(self):
create_db = False
move_db = False
if not await anyio.Path(self.db_path).exists():
create_db = True
else:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
cursor = await db.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
cursor = await db.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"SELECT count(name) FROM sqlite_master "
"WHERE type='table' and name='yupdates'"
)
table_exists = (await cursor.fetchone())[0]
if table_exists:
await cursor.execute("pragma user_version")
version = (await cursor.fetchone())[0]
if version != self.version:
move_db = True
create_db = True
else:
create_db = True
await db.close()
if move_db:
new_path = await get_new_path(self.db_path)
self.log.warning("YStore version mismatch, moving %s to %s", self.db_path, new_path)
await anyio.Path(self.db_path).rename(new_path)
if create_db:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
await db.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await db.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await db.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
db = await connect(self.db_path)
cursor = await db.cursor()
await cursor.execute(
"CREATE TABLE yupdates (path TEXT NOT NULL, yupdate BLOB, "
"metadata BLOB, timestamp REAL NOT NULL)"
)
await cursor.execute(
"CREATE INDEX idx_yupdates_path_timestamp ON yupdates (path, timestamp)"
)
await cursor.execute(f"PRAGMA user_version = {self.version}")
await db.commit()
await db.close()
self._db = await connect(self.db_path)
self.db_initialized.set()

async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
Expand All @@ -388,17 +401,17 @@ async def read(self) -> AsyncIterator[tuple[bytes, bytes, float]]:
await self.db_initialized.wait()
try:
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
async with db.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
) as cursor:
found = False
async for update, metadata, timestamp in cursor:
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
cursor = await self._db.cursor()
await cursor.execute(
"SELECT yupdate, metadata, timestamp FROM yupdates WHERE path = ?",
(self.path,),
)
found = False
for update, metadata, timestamp in await cursor.fetchall():
found = True
yield update, metadata, timestamp
if not found:
raise YDocNotFound
except Exception:
raise YDocNotFound

Expand All @@ -410,38 +423,35 @@ async def write(self, data: bytes) -> None:
"""
await self.db_initialized.wait()
async with self.lock:
async with aiosqlite.connect(self.db_path) as db:
# first, determine time elapsed since last update
cursor = await db.execute(
"SELECT timestamp FROM yupdates WHERE path = ? "
"ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
async with db.execute(
"SELECT yupdate FROM yupdates WHERE path = ?", (self.path,)
) as cursor:
async for (update,) in cursor:
ydoc.apply_update(update)
# delete history
await db.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, squashed_update, metadata, time.time()),
)

# finally, write this update to the DB
# first, determine time elapsed since last update
cursor = await self._db.cursor()
await cursor.execute(
"SELECT timestamp FROM yupdates WHERE path = ? ORDER BY timestamp DESC LIMIT 1",
(self.path,),
)
row = await cursor.fetchone()
diff = (time.time() - row[0]) if row else 0

if self.document_ttl is not None and diff > self.document_ttl:
# squash updates
ydoc = Doc()
await cursor.execute("SELECT yupdate FROM yupdates WHERE path = ?", (self.path,))
for (update,) in await cursor.fetchall():
ydoc.apply_update(update)
# delete history
await cursor.execute("DELETE FROM yupdates WHERE path = ?", (self.path,))
# insert squashed updates
squashed_update = ydoc.get_update()
metadata = await self.get_metadata()
await db.execute(
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
(self.path, squashed_update, metadata, time.time()),
)
await db.commit()

# finally, write this update to the DB
metadata = await self.get_metadata()
await cursor.execute(
"INSERT INTO yupdates VALUES (?, ?, ?, ?)",
(self.path, data, metadata, time.time()),
)
await self._db.commit()
5 changes: 2 additions & 3 deletions pycrdt_websocket/yutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,10 +148,9 @@ async def get_new_path(path: str) -> str:
ext = p.suffix
p_noext = p.with_suffix("")
i = 1
dir_list = [p async for p in anyio.Path().iterdir()]
while True:
new_path = f"{p_noext}({i}){ext}"
if new_path not in dir_list:
if not await anyio.Path(new_path).exists():
break
i += 1
return str(new_path)
return new_path
11 changes: 6 additions & 5 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ classifiers = [
]
dependencies = [
"anyio >=3.6.2,<5",
"aiosqlite >=0.18.0,<1",
"sqlite-anyio >=0.2.0,<0.3.0",
"pycrdt >=0.8.7,<0.9.0",
]

Expand All @@ -38,9 +38,10 @@ test = [
"mypy",
"pre-commit",
"pytest",
"pytest-asyncio",
"websockets >=10.0",
"uvicorn",
"httpx-ws >=0.5.2",
"hypercorn >=0.16.0",
"trio >=0.25.0",
"sniffio",
]
docs = [
"mkdocs",
Expand Down Expand Up @@ -68,7 +69,7 @@ include = [

[tool.ruff]
line-length = 99
select = [
lint.select = [
"ASYNC", # flake8-async
"E", "F", "W", # default Flake8
"G", # flake8-logging-format
Expand Down
45 changes: 28 additions & 17 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import subprocess
from functools import partial
from socket import socket

import pytest
from anyio import Event, create_task_group
from hypercorn import Config
from pycrdt import Array, Doc
from websockets import serve
from sniffio import current_async_library
from utils import ensure_server_running

from pycrdt_websocket import WebsocketServer
from pycrdt_websocket import ASGIServer, WebsocketServer


class TestYDoc:
Expand All @@ -23,32 +27,39 @@ def update(self):


@pytest.fixture
async def yws_server(request):
async def yws_server(request, unused_tcp_port):
try:
kwargs = request.param
except Exception:
except AttributeError:
kwargs = {}
websocket_server = WebsocketServer(**kwargs)
app = ASGIServer(websocket_server)
config = Config()
config.bind = [f"localhost:{unused_tcp_port}"]
shutdown_event = Event()
if current_async_library() == "trio":
from hypercorn.trio import serve
else:
from hypercorn.asyncio import serve
try:
async with websocket_server, serve(websocket_server.serve, "127.0.0.1", 1234):
yield websocket_server
async with create_task_group() as tg, websocket_server:
tg.start_soon(
partial(serve, app, config, shutdown_trigger=shutdown_event.wait, mode="asgi")
)
await ensure_server_running("localhost", unused_tcp_port)
yield unused_tcp_port
shutdown_event.set()
except Exception:
pass


@pytest.fixture
def yjs_client(request):
client_id = request.param
p = subprocess.Popen(["node", f"tests/yjs_client_{client_id}.js"])
yield p
p.kill()


@pytest.fixture
def test_ydoc():
return TestYDoc()


@pytest.fixture
def anyio_backend():
return "asyncio"
def unused_tcp_port() -> int:
with socket() as sock:
sock.bind(("localhost", 0))
return sock.getsockname()[1]
Loading
Loading