Skip to content

Commit ef7ff8a

Browse files
committed
Implement schema version tracking and improve database initialization handling
1 parent c0e38b1 commit ef7ff8a

1 file changed

Lines changed: 56 additions & 14 deletions

File tree

src/db/database.py

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,44 @@
55
import aiosqlite
66
import asyncio
77
import logging
8+
import threading
89
from pathlib import Path
910
from typing import AsyncIterator
1011
from contextlib import asynccontextmanager
12+
from datetime import datetime, timezone, timedelta
1113

1214
from src.config import DB_PATH
1315

1416
logger = logging.getLogger(__name__)
1517

1618
# Module-level connection pool (single shared connection with WAL mode)
1719
_db: aiosqlite.Connection | None = None
18-
_lock = asyncio.Lock()
20+
_initializing = False
1921

2022
# Schema version for consistency tracking
2123
SCHEMA_VERSION = 1
2224

2325

26+
def _now() -> str:
27+
return datetime.now(timezone.utc).isoformat()
28+
29+
2430
async def get_db() -> aiosqlite.Connection:
2531
"""Return the shared async database connection, initializing it if needed."""
26-
global _db
27-
if _db is None:
28-
async with _lock:
29-
if _db is None:
30-
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
31-
_db = await aiosqlite.connect(DB_PATH)
32-
_db.row_factory = aiosqlite.Row
33-
# WAL mode: allows concurrent reads while writing
34-
await _db.execute("PRAGMA journal_mode=WAL")
35-
await _db.execute("PRAGMA foreign_keys=ON")
36-
await init_schema(_db)
37-
logger.info(f"Database initialized at {DB_PATH}")
32+
global _db, _initializing
33+
if _db is None and not _initializing:
34+
_initializing = True
35+
try:
36+
Path(DB_PATH).parent.mkdir(parents=True, exist_ok=True)
37+
_db = await aiosqlite.connect(DB_PATH)
38+
_db.row_factory = aiosqlite.Row
39+
# WAL mode: allows concurrent reads while writing
40+
await _db.execute("PRAGMA journal_mode=WAL")
41+
await _db.execute("PRAGMA foreign_keys=ON")
42+
await init_schema(_db)
43+
logger.info(f"Database initialized at {DB_PATH}")
44+
finally:
45+
_initializing = False
3846
return _db
3947

4048

@@ -247,4 +255,38 @@ async def init_schema(db: aiosqlite.Connection) -> None:
247255
except Exception:
248256
pass
249257

250-
logger.info("Schema initialized.")
258+
# Record current schema version
259+
await db.execute(
260+
"INSERT OR REPLACE INTO schema_version (version, applied_at) VALUES (?, ?)",
261+
(SCHEMA_VERSION, _now())
262+
)
263+
await db.commit()
264+
265+
logger.info(f"Schema initialized (version {SCHEMA_VERSION}).")
266+
267+
268+
async def get_schema_version(db: aiosqlite.Connection) -> int | None:
269+
"""Get the current schema version from the database."""
270+
try:
271+
async with db.execute("SELECT version FROM schema_version ORDER BY version DESC LIMIT 1") as cur:
272+
row = await cur.fetchone()
273+
return row["version"] if row else None
274+
except Exception:
275+
return None
276+
277+
278+
async def verify_schema_consistency(db: aiosqlite.Connection) -> tuple[bool, str]:
279+
"""Verify that the database schema matches the expected version.
280+
281+
Returns:
282+
(is_consistent, message)
283+
"""
284+
try:
285+
current_version = await get_schema_version(db)
286+
if current_version is None:
287+
return False, "Schema version table not found"
288+
if current_version != SCHEMA_VERSION:
289+
return False, f"Schema version mismatch: expected {SCHEMA_VERSION}, got {current_version}"
290+
return True, f"Schema version {SCHEMA_VERSION} is consistent"
291+
except Exception as e:
292+
return False, f"Error checking schema: {e}"

0 commit comments

Comments
 (0)