Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions src/crawlee/_utils/recoverable_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
persist_state_kvs_name: str | None = None,
persist_state_kvs_id: str | None = None,
logger: logging.Logger,
key_value_store: None | KeyValueStore = None,
) -> None:
"""Initialize a new recoverable state object.

Expand All @@ -52,16 +53,23 @@ def __init__(
persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
If neither a name nor an id is supplied, the default store will be used.
logger: A logger instance for logging operations related to state persistence
key_value_store: KeyValueStore to use for persistence. If not provided, a system-wide KeyValueStore will be
used, based on service locator configuration.
"""
if key_value_store and (persist_state_kvs_name or persist_state_kvs_id):
raise ValueError(
'Cannot provide explicit key_value_store and persist_state_kvs_name or persist_state_kvs_id.'
)

self._default_state = default_state
self._state_type: type[TStateModel] = self._default_state.__class__
self._state: TStateModel | None = None
self._persistence_enabled = persistence_enabled
self._persist_state_key = persist_state_key
self._persist_state_kvs_name = persist_state_kvs_name
self._persist_state_kvs_id = persist_state_kvs_id
self._key_value_store: 'KeyValueStore | None' = None # noqa: UP037
self._log = logger
self._key_value_store = key_value_store

async def initialize(self) -> TStateModel:
"""Initialize the recoverable state.
Expand All @@ -79,9 +87,10 @@ async def initialize(self) -> TStateModel:
# Import here to avoid circular imports.
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415

self._key_value_store = await KeyValueStore.open(
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
)
if not self._key_value_store:
self._key_value_store = await KeyValueStore.open(
name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
)

await self._load_saved_state()

Expand Down
69 changes: 39 additions & 30 deletions src/crawlee/storage_clients/_file_system/_request_queue_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def __init__(
metadata: RequestQueueMetadata,
storage_dir: Path,
lock: asyncio.Lock,
recoverable_state: RecoverableState[RequestQueueState],
) -> None:
"""Initialize a new instance.

Expand All @@ -112,13 +113,7 @@ def __init__(
self._is_empty_cache: bool | None = None
"""Cache for is_empty result: None means unknown, True/False is cached state."""

self._state = RecoverableState[RequestQueueState](
default_state=RequestQueueState(),
persist_state_key='request_queue_state',
persistence_enabled=True,
persist_state_kvs_name=f'__RQ_STATE_{self._metadata.id}',
logger=logger,
)
self._state = recoverable_state
"""Recoverable state to maintain request ordering, in-progress status, and handled status."""

@override
Expand Down Expand Up @@ -187,14 +182,9 @@ async def open(
metadata = RequestQueueMetadata(**file_content)

if metadata.id == id:
client = cls(
metadata=metadata,
storage_dir=storage_dir,
lock=asyncio.Lock(),
client = await cls._create_client(
metadata=metadata, storage_dir=storage_dir, update_accessed_at=True
)
await client._state.initialize()
await client._discover_existing_requests()
await client._update_metadata(update_accessed_at=True)
found = True
break
finally:
Expand Down Expand Up @@ -224,15 +214,7 @@ async def open(

metadata.name = name

client = cls(
metadata=metadata,
storage_dir=storage_dir,
lock=asyncio.Lock(),
)

await client._state.initialize()
await client._discover_existing_requests()
await client._update_metadata(update_accessed_at=True)
client = await cls._create_client(metadata=metadata, storage_dir=storage_dir, update_accessed_at=True)

# Otherwise, create a new request queue client.
else:
Expand All @@ -248,13 +230,40 @@ async def open(
pending_request_count=0,
total_request_count=0,
)
client = cls(
metadata=metadata,
storage_dir=storage_dir,
lock=asyncio.Lock(),
)
await client._state.initialize()
await client._update_metadata()
client = await cls._create_client(metadata=metadata, storage_dir=storage_dir)

return client

@classmethod
async def _create_client(
cls, metadata: RequestQueueMetadata, storage_dir: Path, *, update_accessed_at: bool = False
) -> FileSystemRequestQueueClient:
"""Create client from metadata and storage directory."""
from crawlee.storage_clients import FileSystemStorageClient # noqa: PLC0415 avoid circular imports
from crawlee.storages._key_value_store import KeyValueStore # noqa: PLC0415 avoid circular imports

# Prepare kvs for recoverable state
kvs_client = await FileSystemStorageClient().create_kvs_client(name=f'__RQ_STATE_{metadata.id}')
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wait a minute. This is a subtle bug waiting to happen - the FileSystemStorageClient that creates this RQ client might have a different configuration than the one you create here.

Can you pass the actual storage client down to this method instead?

kvs_client_metadata = await kvs_client.get_metadata()
kvs = KeyValueStore(client=kvs_client, id=kvs_client_metadata.id, name=kvs_client_metadata.name)

# Create state
recoverable_state = RecoverableState[RequestQueueState](
default_state=RequestQueueState(),
persist_state_key='request_queue_state',
persistence_enabled=True,
logger=logger,
key_value_store=kvs,
)

# Create client
client = cls(
metadata=metadata, storage_dir=storage_dir, lock=asyncio.Lock(), recoverable_state=recoverable_state
)

await client._state.initialize()
await client._discover_existing_requests()
await client._update_metadata(update_accessed_at=update_accessed_at)

return client

Expand Down
6 changes: 2 additions & 4 deletions src/crawlee/storages/_key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from crawlee._types import JsonSerializable # noqa: TC001
from crawlee._utils.docs import docs_group
from crawlee._utils.recoverable_state import RecoverableState
from crawlee.storage_clients.models import KeyValueStoreMetadata

from ._base import Storage

Expand All @@ -23,8 +22,7 @@
from crawlee.storage_clients import StorageClient
from crawlee.storage_clients._base import KeyValueStoreClient
from crawlee.storage_clients.models import KeyValueStoreMetadata, KeyValueStoreRecordMetadata
else:
from crawlee._utils.recoverable_state import RecoverableState


T = TypeVar('T')

Expand Down Expand Up @@ -274,9 +272,9 @@ async def get_auto_saved_value(
cache[key] = recoverable_state = RecoverableState(
default_state=AutosavedValue(default_value),
persistence_enabled=True,
persist_state_kvs_id=self.id,
persist_state_key=key,
logger=logger,
key_value_store=self, # Use self for RecoverableState.
)

await recoverable_state.initialize()
Expand Down
81 changes: 47 additions & 34 deletions src/crawlee/storages/_storage_instance_manager.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Awaitable, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, TypeVar, cast

from crawlee.storage_clients._base import DatasetClient, KeyValueStoreClient, RequestQueueClient

from ._base import Storage

if TYPE_CHECKING:
from crawlee.configuration import Configuration

from ._base import Storage

T = TypeVar('T', bound='Storage')

StorageClientType = DatasetClient | KeyValueStoreClient | RequestQueueClient
Expand All @@ -19,6 +21,22 @@
"""Type alias for the client opener function."""


@dataclass
class _StorageClientCache:
"""Cache for specific storage client."""

by_id: defaultdict[type[Storage], defaultdict[str, Storage]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict())
)
"""Cache for storage instances by ID, separated by storage type."""
by_name: defaultdict[type[Storage], defaultdict[str, Storage]] = field(
default_factory=lambda: defaultdict(lambda: defaultdict())
)
"""Cache for storage instances by name, separated by storage type."""
default_instances: defaultdict[type[Storage], Storage] = field(default_factory=lambda: defaultdict())
"""Cache for default instances of each storage type."""


class StorageInstanceManager:
"""Manager for caching and managing storage instances.

Expand All @@ -27,14 +45,7 @@ class StorageInstanceManager:
"""

def __init__(self) -> None:
self._cache_by_id = dict[type[Storage], dict[str, Storage]]()
"""Cache for storage instances by ID, separated by storage type."""

self._cache_by_name = dict[type[Storage], dict[str, Storage]]()
"""Cache for storage instances by name, separated by storage type."""

self._default_instances = dict[type[Storage], Storage]()
"""Cache for default instances of each storage type."""
self._cache_by_storage_client: dict[str, _StorageClientCache] = defaultdict(_StorageClientCache)

async def open_storage_instance(
self,
Expand Down Expand Up @@ -64,19 +75,23 @@ async def open_storage_instance(
raise ValueError('Only one of "id" or "name" can be specified, not both.')

# Check for default instance
if id is None and name is None and cls in self._default_instances:
return cast('T', self._default_instances[cls])
if (
id is None
and name is None
and cls in self._cache_by_storage_client[client_opener.__qualname__].default_instances
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no guarantee that client_opener will be stable enough to use as a cache key. Something similar to id(storage_client) would be more appropriate IMO, but that would require a more sizable change to the code.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, with the current state, we do not have access to the storage_client before caching, so I was left with client_opener.

I was thinking about accessing storage_client through client_opener.__self__, but that would work only for bound methods of the storage clients, and client_opener can be any callable. That is why I opted for client_opener.__qualname__, which is generic enough and I think should be stable enough for the cache, as the most common case will be just one of our prepared factory methods, like MemoryStorageClient.create_kvs_client

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess that that's another consideration related to the instance cache mechanism. We should resolve that before v1.

):
return cast('T', self._cache_by_storage_client[client_opener.__qualname__].default_instances[cls])

# Check cache
if id is not None:
type_cache_by_id = self._cache_by_id.get(cls, {})
type_cache_by_id = self._cache_by_storage_client[client_opener.__qualname__].by_id[cls]
if id in type_cache_by_id:
cached_instance = type_cache_by_id[id]
if isinstance(cached_instance, cls):
return cached_instance

if name is not None:
type_cache_by_name = self._cache_by_name.get(cls, {})
type_cache_by_name = self._cache_by_storage_client[client_opener.__qualname__].by_name[cls]
if name in type_cache_by_name:
cached_instance = type_cache_by_name[name]
if isinstance(cached_instance, cls):
Expand All @@ -90,16 +105,13 @@ async def open_storage_instance(
instance_name = getattr(instance, 'name', None)

# Cache the instance
type_cache_by_id = self._cache_by_id.setdefault(cls, {})
type_cache_by_name = self._cache_by_name.setdefault(cls, {})

type_cache_by_id[instance.id] = instance
self._cache_by_storage_client[client_opener.__qualname__].by_id[cls][instance.id] = instance
if instance_name is not None:
type_cache_by_name[instance_name] = instance
self._cache_by_storage_client[client_opener.__qualname__].by_name[cls][instance_name] = instance

# Set as default if no id/name specified
if id is None and name is None:
self._default_instances[cls] = instance
self._cache_by_storage_client[client_opener.__qualname__].default_instances[cls] = instance

return instance

Expand All @@ -112,22 +124,23 @@ def remove_from_cache(self, storage_instance: Storage) -> None:
storage_type = type(storage_instance)

# Remove from ID cache
type_cache_by_id = self._cache_by_id.get(storage_type, {})
if storage_instance.id in type_cache_by_id:
del type_cache_by_id[storage_instance.id]

# Remove from name cache
if storage_instance.name is not None:
type_cache_by_name = self._cache_by_name.get(storage_type, {})
if storage_instance.name in type_cache_by_name:
for client_cache in self._cache_by_storage_client.values():
type_cache_by_id = client_cache.by_id[storage_type]
if storage_instance.id in type_cache_by_id:
del type_cache_by_id[storage_instance.id]

# Remove from name cache
type_cache_by_name = client_cache.by_name[storage_type]
if storage_instance.name in type_cache_by_name and storage_instance.name:
del type_cache_by_name[storage_instance.name]

# Remove from default instances
if storage_type in self._default_instances and self._default_instances[storage_type] is storage_instance:
del self._default_instances[storage_type]
# Remove from default instances
if (
storage_type in client_cache.default_instances
and client_cache.default_instances[storage_type] is storage_instance
):
del client_cache.default_instances[storage_type]

def clear_cache(self) -> None:
"""Clear all cached storage instances."""
self._cache_by_id.clear()
self._cache_by_name.clear()
self._default_instances.clear()
self._cache_by_storage_client = defaultdict(_StorageClientCache)
11 changes: 9 additions & 2 deletions tests/unit/storage_clients/_file_system/test_fs_rq_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import pytest

from crawlee import Request
from crawlee._service_locator import service_locator
from crawlee.configuration import Configuration
from crawlee.storage_clients import FileSystemStorageClient
from crawlee.storage_clients import FileSystemStorageClient, MemoryStorageClient

if TYPE_CHECKING:
from collections.abc import AsyncGenerator
Expand Down Expand Up @@ -55,14 +56,20 @@ async def test_file_and_directory_creation(configuration: Configuration) -> None
await client.drop()


async def test_request_file_persistence(rq_client: FileSystemRequestQueueClient) -> None:
@pytest.mark.parametrize('set_different_storage_client_in_service_locator', [True, False])
async def test_request_file_persistence(
rq_client: FileSystemRequestQueueClient, *, set_different_storage_client_in_service_locator: bool
) -> None:
"""Test that requests are properly persisted to files."""
requests = [
Request.from_url('https://example.com/1'),
Request.from_url('https://example.com/2'),
Request.from_url('https://example.com/3'),
]

if set_different_storage_client_in_service_locator:
service_locator.set_storage_client(MemoryStorageClient())

await rq_client.add_batch_of_requests(requests)

# Verify request files are created
Expand Down
27 changes: 27 additions & 0 deletions tests/unit/storages/test_key_value_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,3 +598,30 @@ async def test_record_exists_after_purge(kvs: KeyValueStore) -> None:
# Should no longer exist
assert await kvs.record_exists('key1') is False
assert await kvs.record_exists('key2') is False


async def test_get_auto_saved_value_with_multiple_storage_clients(tmp_path: Path) -> None:
"""Test that setting storage client through service locator does not break autosaved values in other clients."""
config = Configuration(
crawlee_storage_dir=str(tmp_path), # type: ignore[call-arg]
purge_on_start=True,
)

kvs1 = await KeyValueStore.open(storage_client=MemoryStorageClient(), configuration=config)

kvs2 = await KeyValueStore.open(
storage_client=FileSystemStorageClient(),
configuration=config,
)
assert kvs1 is not kvs2

expected_values = {'key': 'value'}
test_key = 'test_key'

autosaved_value = await kvs2.get_auto_saved_value(test_key)
assert autosaved_value == {}
autosaved_value.update(expected_values)

await kvs2.persist_autosaved_values()

assert await kvs2.get_value(test_key) == expected_values
Loading