Skip to content

refactor: RecoverableState abstraction #1172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Apr 30, 2025
Merged
153 changes: 153 additions & 0 deletions src/crawlee/_utils/recoverable_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic, TypeVar

from pydantic import BaseModel

from crawlee import service_locator
from crawlee.events._types import Event, EventPersistStateData
from crawlee.storages._key_value_store import KeyValueStore

if TYPE_CHECKING:
import logging

TStateModel = TypeVar('TStateModel', bound=BaseModel)


class RecoverableState(Generic[TStateModel]):
    """A class for managing persistent recoverable state using a Pydantic model.

    This class facilitates state persistence to a `KeyValueStore`, allowing data to be saved and retrieved
    across migrations or restarts. It manages the loading, saving, and resetting of state data,
    with optional persistence capabilities.

    The state is represented by a Pydantic model that can be serialized to and deserialized from JSON.
    When persistence is enabled, `initialize` registers `persist_state` as a listener for
    `Event.PERSIST_STATE` and `teardown` removes it again.

    Type Parameters:
        TStateModel: A Pydantic BaseModel type that defines the structure of the state data.
            Typically, it should be inferred from the `default_state` constructor parameter.
    """

    def __init__(
        self,
        *,
        default_state: TStateModel,
        persist_state_key: str,
        persistence_enabled: bool = False,
        persist_state_kvs_name: str | None = None,
        persist_state_kvs_id: str | None = None,
        logger: logging.Logger,
    ) -> None:
        """Initialize a new recoverable state object.

        Args:
            default_state: The default state model instance to use when no persisted state is found.
                A deep copy is made each time the default is used, so this instance is never mutated.
            persist_state_key: The key under which the state is stored in the KeyValueStore.
            persistence_enabled: Flag to enable or disable state persistence.
            persist_state_kvs_name: The name of the KeyValueStore to use for persistence.
                If neither a name nor an id is supplied, the default store will be used.
            persist_state_kvs_id: The identifier of the KeyValueStore to use for persistence.
                If neither a name nor an id is supplied, the default store will be used.
            logger: A logger instance for logging operations related to state persistence.
        """
        self._default_state = default_state
        # Remember the concrete model class so persisted data can be validated back into it.
        self._state_type: type[TStateModel] = self._default_state.__class__
        self._state: TStateModel | None = None
        self._persistence_enabled = persistence_enabled
        self._persist_state_key = persist_state_key
        self._persist_state_kvs_name = persist_state_kvs_name
        self._persist_state_kvs_id = persist_state_kvs_id
        # Opened lazily in `initialize`, and only when persistence is enabled.
        self._key_value_store: KeyValueStore | None = None
        self._log = logger

    async def initialize(self) -> TStateModel:
        """Initialize the recoverable state.

        This method must be called before using the recoverable state. It loads the saved state
        if persistence is enabled and registers the object to listen for PERSIST_STATE events.
        With persistence disabled, the state is simply a deep copy of the default state.

        Returns:
            The loaded state model.
        """
        if not self._persistence_enabled:
            self._state = self._fresh_default_state()
            return self.current_value

        self._key_value_store = await KeyValueStore.open(
            name=self._persist_state_kvs_name, id=self._persist_state_kvs_id
        )

        await self._load_saved_state()

        event_manager = service_locator.get_event_manager()
        event_manager.on(event=Event.PERSIST_STATE, listener=self.persist_state)

        return self.current_value

    async def teardown(self) -> None:
        """Clean up resources used by the recoverable state.

        If persistence is enabled, this method deregisters the object from PERSIST_STATE events
        and persists the current state one last time. With persistence disabled it does nothing.
        """
        if not self._persistence_enabled:
            return

        event_manager = service_locator.get_event_manager()
        event_manager.off(event=Event.PERSIST_STATE, listener=self.persist_state)
        await self.persist_state()

    @property
    def current_value(self) -> TStateModel:
        """Get the current state.

        Raises:
            RuntimeError: If the state has not been loaded yet (see `initialize`).
        """
        if self._state is None:
            raise RuntimeError('Recoverable state has not yet been loaded')

        return self._state

    async def reset(self) -> None:
        """Reset the state to the default values and clear any persisted state.

        Resets the current state to the default state and, if persistence is enabled,
        clears the persisted state from the KeyValueStore.

        Raises:
            RuntimeError: If persistence is enabled but the KeyValueStore has not been opened yet.
        """
        self._state = self._fresh_default_state()

        if self._persistence_enabled:
            if self._key_value_store is None:
                raise RuntimeError('Recoverable state has not yet been initialized')

            # Storing `None` under the key is how the persisted record is cleared.
            await self._key_value_store.set_value(self._persist_state_key, None)

    async def persist_state(self, event_data: EventPersistStateData | None = None) -> None:
        """Persist the current state to the KeyValueStore.

        This method is typically called in response to a PERSIST_STATE event, but can also be called
        directly when needed. When persistence is disabled, it is a no-op once the state is loaded.

        Args:
            event_data: Optional data associated with a PERSIST_STATE event.

        Raises:
            RuntimeError: If the state has not been loaded, or if persistence is enabled but
                the KeyValueStore has not been opened yet.
        """
        if self._state is None:
            raise RuntimeError('Recoverable state has not yet been initialized')

        # With persistence disabled no store is ever opened, so there is nothing to write.
        # Checking this before the store guard avoids a misleading "not initialized" error.
        if not self._persistence_enabled:
            return

        if self._key_value_store is None:
            raise RuntimeError('Recoverable state has not yet been initialized')

        self._log.debug(
            f'Persisting RecoverableState (model={self._state_type.__name__}, event_data={event_data}).'
        )

        await self._key_value_store.set_value(
            self._persist_state_key,
            self._state.model_dump(mode='json', by_alias=True),
            'application/json',
        )

    async def _load_saved_state(self) -> None:
        """Load the state from the KeyValueStore, falling back to a copy of the default state.

        Raises:
            RuntimeError: If the KeyValueStore has not been opened yet.
        """
        if self._key_value_store is None:
            raise RuntimeError('Recoverable state has not yet been initialized')

        stored_state = await self._key_value_store.get_value(self._persist_state_key)
        if stored_state is None:
            self._state = self._fresh_default_state()
        else:
            self._state = self._state_type.model_validate(stored_state)

    def _fresh_default_state(self) -> TStateModel:
        """Return an independent deep copy of the default state."""
        return self._default_state.model_copy(deep=True)
Original file line number Diff line number Diff line change
@@ -67,10 +67,11 @@ class _NonPersistentStatistics(Statistics):

def __init__(self) -> None:
    """Set up the statistics with `StatisticsState` as the state model and mark them active."""
    # Delegate to the parent constructor, passing only the state model class.
    super().__init__(state_model=StatisticsState)
    # Flag the instance as active immediately after construction.
    self._active = True

async def __aenter__(self) -> Self:
    """Enter the async context: mark the instance active and initialize its state.

    Returns:
        This instance, so it can be bound by `async with ... as`.
    """
    self._active = True
    # Initialize the underlying state object before running the post-initialization hook.
    await self._state.initialize()
    # NOTE(review): `_after_initialize` is defined outside this view — presumably a parent-class hook; confirm.
    self._after_initialize()
    return self

async def __aexit__(
@@ -201,7 +202,12 @@ async def adaptive_pre_navigation_hook_pw(context: PlaywrightPreNavCrawlingConte
static_crawler.pre_navigation_hook(adaptive_pre_navigation_hook_static)
playwright_crawler.pre_navigation_hook(adaptive_pre_navigation_hook_pw)

self._additional_context_managers = [*self._additional_context_managers, playwright_crawler._browser_pool] # noqa: SLF001 # Intentional access to private member.
self._additional_context_managers = [
*self._additional_context_managers,
static_crawler.statistics,
playwright_crawler.statistics,
playwright_crawler._browser_pool, # noqa: SLF001 # Intentional access to private member.
]

# Sub crawler pipeline related
self._pw_context_pipeline = playwright_crawler._context_pipeline # noqa:SLF001 # Intentional access to private member.
@@ -376,6 +382,8 @@ async def _run_request_handler(self, context: BasicCrawlingContext) -> None:

context.log.debug(f'Running browser request handler for {context.request.url}')

old_state_copy = None

if should_detect_rendering_type:
# Save copy of global state from `use_state` before it can be mutated by browser crawl.
# This copy will be used in the static crawl to make sure they both run with same conditions and to
14 changes: 5 additions & 9 deletions src/crawlee/fingerprint_suite/_types.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

from typing import Annotated, Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from crawlee._utils.docs import docs_group

@@ -14,6 +14,8 @@

@docs_group('Data structures')
class ScreenOptions(BaseModel):
model_config = ConfigDict(extra='forbid', populate_by_name=True)

"""Defines the screen constrains for the fingerprint generator."""

min_width: Annotated[float | None, Field(alias='minWidth')] = None
@@ -28,15 +30,13 @@ class ScreenOptions(BaseModel):
max_height: Annotated[float | None, Field(alias='maxHeight')] = None
"""Maximal screen height constraint for the fingerprint generator."""

class Config:
extra = 'forbid'
populate_by_name = True


@docs_group('Data structures')
class HeaderGeneratorOptions(BaseModel):
"""Collection of header related attributes that can be used by the fingerprint generator."""

model_config = ConfigDict(extra='forbid', populate_by_name=True)

browsers: list[SupportedBrowserType] | None = None
"""List of BrowserSpecifications to generate the headers for."""

@@ -56,7 +56,3 @@ class HeaderGeneratorOptions(BaseModel):

strict: bool | None = None
"""If true, the generator will throw an error if it cannot generate headers based on the input."""

class Config:
extra = 'forbid'
populate_by_name = True
59 changes: 50 additions & 9 deletions src/crawlee/sessions/_models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,20 @@
from __future__ import annotations

from datetime import datetime, timedelta
from typing import Annotated
from typing import Annotated, Any

from pydantic import BaseModel, ConfigDict, Field
from pydantic import (
BaseModel,
BeforeValidator,
ConfigDict,
Field,
GetPydanticSchema,
PlainSerializer,
computed_field,
)

from ._cookies import CookieParam
from ._session import Session


class SessionModel(BaseModel):
@@ -31,11 +40,43 @@ class SessionPoolModel(BaseModel):

model_config = ConfigDict(populate_by_name=True)

persistence_enabled: Annotated[bool, Field(alias='persistenceEnabled')]
persist_state_kvs_name: Annotated[str | None, Field(alias='persistStateKvsName')]
persist_state_key: Annotated[str, Field(alias='persistStateKey')]
max_pool_size: Annotated[int, Field(alias='maxPoolSize')]
session_count: Annotated[int, Field(alias='sessionCount')]
usable_session_count: Annotated[int, Field(alias='usableSessionCount')]
retired_session_count: Annotated[int, Field(alias='retiredSessionCount')]
sessions: Annotated[list[SessionModel], Field(alias='sessions')]

sessions: Annotated[
dict[
str,
Annotated[
Session, GetPydanticSchema(lambda _, handler: handler(Any))
], # handler(Any) is fine - we validate manually in the BeforeValidator
],
Field(alias='sessions'),
PlainSerializer(
lambda value: [session.get_state().model_dump(by_alias=True) for session in value.values()],
return_type=list,
),
BeforeValidator(
lambda value: {
session.id: session
for item in value
if (session := Session.from_model(SessionModel.model_validate(item, by_alias=True)))
}
),
]

@computed_field(alias='sessionCount')  # type: ignore[prop-decorator]
@property
def session_count(self) -> int:
    """Get the total number of sessions currently maintained in the pool."""
    # The docstring is kept verbatim: pydantic uses it as the computed field's schema description.
    pool = self.sessions
    return len(pool)

@computed_field(alias='usableSessionCount')  # type: ignore[prop-decorator]
@property
def usable_session_count(self) -> int:
    """Get the number of sessions that are currently usable."""
    # The docstring is kept verbatim: pydantic uses it as the computed field's schema description.
    # Only the session objects matter here, so iterate the values and count lazily
    # instead of unpacking `.items()` and materializing a throwaway list.
    return sum(1 for session in self.sessions.values() if session.is_usable)

@computed_field(alias='retiredSessionCount')  # type: ignore[prop-decorator]
@property
def retired_session_count(self) -> int:
    """Get the number of sessions that are no longer usable."""
    # The docstring is kept verbatim: pydantic uses it as the computed field's schema description.
    # Retired sessions are whatever remains after subtracting the usable ones from the total.
    total = self.session_count
    usable = self.usable_session_count
    return total - usable
7 changes: 5 additions & 2 deletions src/crawlee/sessions/_session.py
Original file line number Diff line number Diff line change
@@ -9,11 +9,12 @@
from crawlee._utils.crypto import crypto_random_object_id
from crawlee._utils.docs import docs_group
from crawlee.sessions._cookies import CookieParam, SessionCookies
from crawlee.sessions._models import SessionModel

if TYPE_CHECKING:
from http.cookiejar import CookieJar

from crawlee.sessions._models import SessionModel

logger = getLogger(__name__)


@@ -146,6 +147,8 @@ def get_state(self, *, as_dict: Literal[False]) -> SessionModel: ...

def get_state(self, *, as_dict: bool = False) -> SessionModel | dict:
"""Retrieve the current state of the session either as a model or as a dictionary."""
from ._models import SessionModel

model = SessionModel(
id=self._id,
max_age=self._max_age,
@@ -157,7 +160,7 @@ def get_state(self, *, as_dict: bool = False) -> SessionModel | dict:
max_usage_count=self._max_usage_count,
error_score=self._error_score,
cookies=self._cookies.get_cookies_as_dicts(),
blocked_status_codes=self._blocked_status_codes,
blocked_status_codes=list(self._blocked_status_codes),
)
if as_dict:
return model.model_dump()
Loading