Merged
Changes from 11 commits
2 changes: 1 addition & 1 deletion application_sdk/activities/metadata_extraction/sql.py
@@ -850,7 +850,7 @@ async def transform_data(
dataframe=dataframe, **workflow_args
)
await transformed_output.write_daft_dataframe(transform_metadata)
return await transformed_output.get_statistics()
return await transformed_output.get_statistics(typename=typename)

@activity.defn
@auto_heartbeater
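With this change, transform_data forwards the entity typename into get_statistics, which passes it down to write_statistics so that per-output statistics can be attributed to a type and rolled up per phase (see outputs/__init__.py below). A rough sketch of the resulting statistics dictionary is shown here; only keys visible in this diff appear, and the values are made up:

# Hypothetical shape of the statistics dict once typename is threaded through.
# Only keys that appear in this diff are shown; values are illustrative.
{
    "total_record_count": 1250,
    "typename": "table",
    # ...remaining fields written by write_statistics are unchanged...
}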
2 changes: 2 additions & 0 deletions application_sdk/decorators/__init__.py
@@ -0,0 +1,2 @@


89 changes: 89 additions & 0 deletions application_sdk/decorators/method_lock.py
@@ -0,0 +1,89 @@
from __future__ import annotations

import asyncio
from functools import wraps
from typing import Any, Awaitable, Callable, Optional

from temporalio import activity

from application_sdk.clients.redis import RedisClientAsync
from application_sdk.constants import (
APPLICATION_NAME,
IS_LOCKING_DISABLED,
)
from application_sdk.observability.logger_adaptor import get_logger

logger = get_logger(__name__)


def lock_per_run(
lock_name: Optional[str] = None, ttl_seconds: int = 10
) -> Callable[[Callable[..., Awaitable[Any]]], Callable[..., Awaitable[Any]]]:
"""Serialize an async method within an activity per workflow run.
Uses Redis SET NX EX for acquisition and an owner-verified release.
The lock key is namespaced and scoped to the current workflow run:
``{APPLICATION_NAME}:meth:{method_name}:run:{workflow_run_id}``.
Args:
lock_name: Optional explicit lock name. Defaults to the wrapped method's name.
ttl_seconds: Lock TTL in seconds. Should cover worst-case wait + execution time.
Returns:
A decorator for async callables to guard them with a per-run distributed lock.
"""

def _decorate(
fn: Callable[..., Awaitable[Any]]
) -> Callable[..., Awaitable[Any]]:
@wraps(fn)
async def _wrapped(*args: Any, **kwargs: Any) -> Any:
if IS_LOCKING_DISABLED:
return await fn(*args, **kwargs)

run_id = activity.info().workflow_run_id
name = lock_name or fn.__name__

resource_id = f"{APPLICATION_NAME}:meth:{name}:run:{run_id}"
owner_id = f"{APPLICATION_NAME}:{run_id}"

async with RedisClientAsync() as rc:
# Acquire with retry
retry_count = 0
while True:
logger.debug(f"Attempting to acquire lock: {resource_id}, owner: {owner_id}")
acquired = await rc._acquire_lock(
resource_id, owner_id, ttl_seconds
)
if acquired:
logger.info(f"Lock acquired: {resource_id}, owner: {owner_id}")
break
retry_count += 1
logger.debug(
f"Lock not available, retrying (attempt {retry_count}): {resource_id}"
)
await asyncio.sleep(5)

try:
return await fn(*args, **kwargs)
finally:
# Best-effort release; TTL guarantees cleanup if this fails
try:
logger.debug(f"Releasing lock: {resource_id}, owner: {owner_id}")
released, result = await rc._release_lock(resource_id, owner_id)
if released:
logger.info(f"Lock released successfully: {resource_id}")
else:
logger.warning(
f"Lock release failed (may already be released): {resource_id}, result: {result}"
)
except Exception as e:
logger.warning(
f"Exception during lock release for {resource_id}: {e}. TTL will handle cleanup."
)

return _wrapped

return _decorate
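For reference, the acquisition/release pattern the docstring names (SET NX EX plus an owner-verified release) looks roughly like the sketch below when written directly against redis-py's asyncio client. This is an illustration of the pattern only, not the SDK's actual RedisClientAsync._acquire_lock / _release_lock implementation.

# Illustrative sketch of "SET NX EX + owner-verified release"; assumes redis-py's
# asyncio client, not the SDK's RedisClientAsync internals.
import redis.asyncio as redis

_RELEASE_SCRIPT = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
    return redis.call('DEL', KEYS[1])
else
    return 0
end
"""


async def acquire(r: redis.Redis, resource_id: str, owner_id: str, ttl_seconds: int) -> bool:
    # SET key value NX EX ttl: succeeds only if the key does not already exist,
    # and the TTL guarantees eventual cleanup if the holder crashes.
    return bool(await r.set(resource_id, owner_id, nx=True, ex=ttl_seconds))


async def release(r: redis.Redis, resource_id: str, owner_id: str) -> bool:
    # Delete the key only if we still own it, so a lock that expired and was
    # re-acquired by another owner is never released by mistake.
    return bool(await r.eval(_RELEASE_SCRIPT, 1, resource_id, owner_id))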


100 changes: 97 additions & 3 deletions application_sdk/outputs/__init__.py
@@ -25,15 +25,18 @@
from temporalio import activity

from application_sdk.activities.common.models import ActivityStatistics
from application_sdk.activities.common.utils import get_object_store_prefix
from application_sdk.activities.common.utils import get_object_store_prefix, build_output_path
from application_sdk.common.dataframe_utils import is_empty_dataframe
from application_sdk.observability.logger_adaptor import get_logger
from application_sdk.observability.metrics_adaptor import MetricType
from application_sdk.services.objectstore import ObjectStore
from application_sdk.constants import TEMPORARY_PATH
from application_sdk.decorators.method_lock import lock_per_run

logger = get_logger(__name__)
activity.logger = logger


if TYPE_CHECKING:
import daft # type: ignore
import pandas as pd
@@ -71,6 +74,19 @@ class Output(ABC):
current_buffer_size_bytes: int
partitions: List[int]

def _infer_phase_from_path(self) -> Optional[str]:
"""Infer phase from output path by checking for raw/transformed directories.

Returns:
Optional[str]: "Extract" for raw, "Transform" for transformed, else None.
"""
path_parts = str(self.output_path).split("/")
if "raw" in path_parts:
return "Extract"
if "transformed" in path_parts:
return "Transform"
return None

def estimate_dataframe_record_size(self, dataframe: "pd.DataFrame") -> int:
"""Estimate File size of a DataFrame by sampling a few records."""
if len(dataframe) == 0:
@@ -330,7 +346,7 @@ async def get_statistics(
Exception: If there's an error writing the statistics
"""
try:
statistics = await self.write_statistics()
statistics = await self.write_statistics(typename)
if not statistics:
raise ValueError("No statistics data available")
statistics = ActivityStatistics.model_validate(statistics)
@@ -390,7 +406,7 @@ async def _flush_buffer(self, chunk: "pd.DataFrame", chunk_part: int):
logger.error(f"Error flushing buffer to files: {str(e)}")
raise e

async def write_statistics(self) -> Optional[Dict[str, Any]]:
async def write_statistics(self, typename: Optional[str] = None) -> Optional[Dict[str, Any]]:
"""Write statistics about the output to a JSON file.

This method writes statistics including total record count and chunk count
@@ -418,6 +434,84 @@ async def write_statistics(self) -> Optional[Dict[str, Any]]:
source=output_file_name,
destination=destination_file_path,
)

if typename:
statistics["typename"] = typename
# Update aggregated statistics at run root in object store
try:
await self._update_run_aggregate(destination_file_path, statistics)
except Exception as e:
logger.warning(f"Failed to update aggregated statistics: {str(e)}")
return statistics
except Exception as e:
logger.error(f"Error writing statistics: {str(e)}")

    # TODO: Do we need locking here?
@lock_per_run()
async def _update_run_aggregate(
self, per_path_destination: str, statistics: Dict[str, Any]
) -> None:
"""Aggregate stats into a single file at the workflow run root.

Args:
per_path_destination: Object store destination path for this stats file
(used as key in the aggregate map)
statistics: The statistics dictionary to store
"""
inferred_phase = self._infer_phase_from_path()
if inferred_phase is None:
logger.info("Phase could not be inferred from path. Skipping aggregation.")
return

logger.info(f"Starting _update_run_aggregate for phase: {inferred_phase}")
workflow_run_root_relative = build_output_path()
output_file_name = f"{TEMPORARY_PATH}{workflow_run_root_relative}/statistics.json.ignore"
destination_file_path = get_object_store_prefix(output_file_name)

# Load existing aggregate from object store if present
# Structure: {"Extract": {"typename": {"record_count": N}}, "Transform": {...}, "Publish": {...}}
aggregate_by_phase: Dict[str, Dict[str, Dict[str, Any]]] = {
"Extract": {},
"Transform": {},
"Publish": {}
}

try:
# Download existing aggregate file if present
await ObjectStore.download_file(
source=destination_file_path,
destination=output_file_name,
)
# Load existing JSON structure
with open(output_file_name, "r") as f:
existing_aggregate = orjson.loads(f.read())
# Phase-based structure
aggregate_by_phase.update(existing_aggregate)
logger.info(f"Successfully loaded existing aggregates")
except Exception:
logger.info(
"No existing aggregate found or failed to read. Initializing a new aggregate structure."
)

# Accumulate statistics by typename within the phase
typename = statistics.get("typename", "unknown")

if typename not in aggregate_by_phase[inferred_phase]:
aggregate_by_phase[inferred_phase][typename] = {
"record_count": 0
}

logger.info(f"Accumulating statistics for phase '{inferred_phase}', typename '{typename}': +{statistics['total_record_count']} records")

# Accumulate the record count
aggregate_by_phase[inferred_phase][typename]["record_count"] += statistics["total_record_count"]

with open(output_file_name, "w") as f:
f.write(orjson.dumps(aggregate_by_phase).decode("utf-8"))
logger.info(f"Successfully updated aggregate with accumulated stats for phase '{inferred_phase}'")

# Upload aggregate to object store
await ObjectStore.upload_file(
source=output_file_name,
destination=destination_file_path,
)
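Because _update_run_aggregate performs a read-modify-write against a single object-store file, the @lock_per_run() guard above serializes concurrent activities within the same workflow run so that increments are not lost. The aggregated statistics.json.ignore at the run root keys record counts by phase and typename; after a few activities it might look like the following (typenames and counts are made up for illustration):

# Illustrative content of the aggregated statistics.json.ignore
# (structure taken from _update_run_aggregate; typenames and counts are invented).
{
    "Extract": {
        "database": {"record_count": 2},
        "table": {"record_count": 1250},
    },
    "Transform": {
        "table": {"record_count": 1250},
    },
    "Publish": {},
}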