Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 80 additions & 15 deletions tests/test_module_parsing_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
import os

from vibeprolog import PrologInterpreter
from vibeprolog import interpreter as interpreter_module


def test_reuses_parsed_modules_across_imports(tmp_path):
"""Ensure a dependency is parsed only once even when imported multiple times."""

def _interpreter_with_parse_counter():
prolog = PrologInterpreter()
parse_calls = 0

Expand All @@ -16,6 +15,13 @@ def hook():
parse_calls += 1

prolog._parser_invocation_hook = hook
return prolog, lambda: parse_calls


def test_reuses_parsed_modules_across_imports(tmp_path):
"""Ensure a dependency is parsed only once even when imported multiple times."""

prolog, parse_calls = _interpreter_with_parse_counter()

dep_path = tmp_path / "dep.pl"
dep_path.write_text(":- module(dep, [dep/1]).\n:- dynamic dep/1.\ndep(x).\n")
Expand All @@ -42,35 +48,94 @@ def hook():
)

prolog.consult(main_path)
first_parse_calls = parse_calls
first_parse_calls = parse_calls()

prolog.consult(main_path)

assert parse_calls == first_parse_calls
assert parse_calls() == first_parse_calls


def test_parsed_module_cache_invalidated_on_mtime_change(tmp_path):
"""The parsed-module cache should refresh when the source file changes."""

prolog = PrologInterpreter()
parse_calls = 0

def hook():
nonlocal parse_calls
parse_calls += 1

prolog._parser_invocation_hook = hook
prolog, parse_calls = _interpreter_with_parse_counter()

module_path = tmp_path / "cached.pl"
module_path.write_text(":- module(cached, [p/0]).\n:- dynamic p/0.\np.\n")

prolog.consult(module_path)
first_parse_calls = parse_calls
first_parse_calls = parse_calls()

stat_before = module_path.stat()
module_path.write_text(":- module(cached, [p/0]).\n:- dynamic p/0.\np.\np.\n")
os.utime(module_path, (stat_before.st_atime, stat_before.st_mtime + 1))

prolog.consult(module_path)

assert parse_calls > first_parse_calls
assert parse_calls() > first_parse_calls


def test_serialized_library_cache_hit_and_parity(monkeypatch, tmp_path):
"""Disk-backed cache is reused for shipped libraries without reparsing."""

library_root = tmp_path / "library"
library_root.mkdir()

module_path = library_root / "cached.pl"
module_path.write_text(
":- module(cached, [p/1]).\n"
"p(value).\n"
)

monkeypatch.setattr(interpreter_module, "LIBRARY_SEARCH_PATHS", [library_root])

warm_prolog, warm_calls = _interpreter_with_parse_counter()
warm_prolog.consult(module_path)

assert warm_calls() > 0
assert warm_prolog.has_solution("cached:p(value)")

cached_prolog, cached_calls = _interpreter_with_parse_counter()
cached_prolog.consult(module_path)

assert cached_calls() == 0
assert cached_prolog.has_solution("cached:p(value)")


def test_serialized_library_cache_invalidated_on_edit(monkeypatch, tmp_path):
"""Editing a shipped library file regenerates its serialized cache."""

library_root = tmp_path / "library"
library_root.mkdir()

module_path = library_root / "changing.pl"
module_path.write_text(
":- module(changing, [p/1]).\n"
"p(old).\n"
)

monkeypatch.setattr(interpreter_module, "LIBRARY_SEARCH_PATHS", [library_root])

warm_prolog, warm_calls = _interpreter_with_parse_counter()
warm_prolog.consult(module_path)

assert warm_calls() > 0

cached_prolog, cached_calls = _interpreter_with_parse_counter()
cached_prolog.consult(module_path)

assert cached_calls() == 0
assert cached_prolog.has_solution("changing:p(old)")

stat_before = module_path.stat()
module_path.write_text(
":- module(changing, [p/1]).\n"
"p(new).\n"
)
os.utime(module_path, (stat_before.st_atime, stat_before.st_mtime + 1))

refreshed_prolog, refreshed_calls = _interpreter_with_parse_counter()
refreshed_prolog.consult(module_path)

assert refreshed_calls() > 0
assert refreshed_prolog.has_solution("changing:p(new)")
112 changes: 112 additions & 0 deletions vibeprolog/interpreter.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Main Prolog interpreter interface."""

import copy
import pickle
import io
import re
import sys
Expand Down Expand Up @@ -64,6 +65,16 @@ class ParsedModuleCacheEntry(TypedDict):
file_mtime: float | None


class SerializedParsedModule(TypedDict):
version: int
parser_signature: tuple[int, tuple[tuple[str, str], ...], tuple[tuple[bool, bool, bool], ...]]
file_mtime: float | None
items: list[Clause | Directive]


SERIALIZED_PARSED_CACHE_VERSION = 1


class Module:
def __init__(self, name: str, exports: set[tuple[str, int]] | None):
self.name = name
Expand Down Expand Up @@ -913,6 +924,91 @@ def _safe_mtime(self, path: Path) -> float | None:
except OSError:
return None

def _library_root_for_path(self, path: Path) -> Path | None:
"""Return the library search root containing ``path`` if any."""

try:
resolved_path = path.resolve()
except OSError:
return None

for root in LIBRARY_SEARCH_PATHS:
try:
resolved_root = root.resolve()
except OSError:
continue
try:
resolved_path.relative_to(resolved_root)
except ValueError:
continue
return resolved_root
return None

def _library_cache_path(self, path: Path) -> Path | None:
"""Return the on-disk cache location for a shipped library file."""

library_root = self._library_root_for_path(path)
if library_root is None:
return None

try:
relative_path = path.resolve().relative_to(library_root)
except (OSError, ValueError):
return None

cache_dir = library_root / ".vibe_parsed_cache"
safe_name = "__".join(relative_path.parts) + ".pickle"
return cache_dir / safe_name

def _load_serialized_parsed_items(
self, cache_file: Path, file_mtime: float | None
) -> list[Clause | Directive] | None:
"""Load cached parsed items from disk if they are still valid."""

try:
with cache_file.open("rb") as handle:
payload: SerializedParsedModule = pickle.load(handle)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

security-high high

Loading data with pickle.load from a file can lead to arbitrary code execution if the file's content can be controlled by an attacker. The cache files are stored in .vibe_parsed_cache directories within the library search paths. If any of these library paths are in a world-writable location (e.g., /tmp), it could expose a security vulnerability. Please ensure library paths are in trusted, permission-controlled locations. For greater security, consider using a safer serialization format like JSON, although this may require more work to serialize/deserialize the complex objects.

except (OSError, pickle.PickleError):
return None
except Exception as e:
warnings.warn(f"Failed to load parsed module cache from {cache_file}: {e}", RuntimeWarning)
return None

signature = self._parser_config_signature()
if (
payload.get("version") != SERIALIZED_PARSED_CACHE_VERSION
or payload.get("file_mtime") != file_mtime
or tuple(payload.get("parser_signature", ())) != signature
):
return None

items = payload.get("items")
if not isinstance(items, list):
return None

return copy.deepcopy(items)

def _store_serialized_parsed_items(
self, cache_file: Path, file_mtime: float | None, items: list[Clause | Directive]
) -> None:
"""Persist parsed items to disk for shipped libraries."""

payload: SerializedParsedModule = {
"version": SERIALIZED_PARSED_CACHE_VERSION,
"parser_signature": self._parser_config_signature(),
"file_mtime": file_mtime,
"items": items,
}

try:
cache_file.parent.mkdir(parents=True, exist_ok=True)
with cache_file.open("wb") as handle:
pickle.dump(payload, handle, protocol=pickle.HIGHEST_PROTOCOL)
except OSError:
warnings.warn(
f"Failed to write parsed module cache to {cache_file}", RuntimeWarning
)

def _parser_config_signature(self) -> tuple[
int, tuple[tuple[str, str], ...], tuple[tuple[bool, bool, bool], ...]
]:
Expand Down Expand Up @@ -1892,6 +1988,18 @@ def _consult_code(
error_term = PrologError.syntax_error(str(exc), "consult/1")
raise PrologThrow(error_term)

disk_cache_path = None
if cache_path is not None:
disk_cache_path = self._library_cache_path(cache_path)
if cached_items is None and disk_cache_path is not None:
cached_items = self._load_serialized_parsed_items(
disk_cache_path, file_mtime
)
if cached_items is not None:
self._store_parsed_module_cache(
cache_path, file_mtime, cached_items
)

def _process_parsed_items(
parsed_items: list[Clause | Directive],
last_pred: tuple[str, str, int] | None,
Expand Down Expand Up @@ -1963,6 +2071,10 @@ def _process_parsed_items(

if cache_path is not None and cached_items is None:
self._store_parsed_module_cache(cache_path, file_mtime, parsed_items_for_cache)
if disk_cache_path is not None:
self._store_serialized_parsed_items(
disk_cache_path, file_mtime, parsed_items_for_cache
)

self.engine = PrologEngine(
self.clauses,
Expand Down