diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 90daecbe2..a6f0805e3 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -6,6 +6,7 @@ import contextvars import inspect import logging +import sys import threading import typing import uuid @@ -1565,6 +1566,7 @@ def __init__(self, logger: logging.Logger, extra: Mapping[str, Any] | None) -> N self.workflow_info_on_extra = True self.full_workflow_info_on_extra = False self.log_during_replay = False + self.disable_sandbox = False def process( self, msg: Any, kwargs: MutableMapping[str, Any] @@ -1598,7 +1600,27 @@ def process( kwargs["extra"] = {**extra, **(kwargs.get("extra") or {})} if msg_extra: msg = f"{msg} ({msg_extra})" - return (msg, kwargs) + return msg, kwargs + + def log( + self, + level: int, + msg: object, + *args: Any, + stacklevel: int = 1, + **kwargs: Any, + ): + """Override to potentially disable the sandbox.""" + if sys.version_info < (3, 11): + # An additional stacklevel is needed on 3.10 because it doesn't skip internal frames until after stacklevel + stacklevel += 1 # type: ignore[reportUnreachable] + stacklevel += 1 + if self.disable_sandbox: + with unsafe.sandbox_unrestricted(): + with unsafe.imports_passed_through(): + super().log(level, msg, *args, stacklevel=stacklevel, **kwargs) + else: + super().log(level, msg, *args, stacklevel=stacklevel, **kwargs) def isEnabledFor(self, level: int) -> bool: """Override to ignore replay logs.""" @@ -1613,6 +1635,12 @@ def base_logger(self) -> logging.Logger: """ return self.logger + def unsafe_disable_sandbox(self, value: bool = True): + """Disable the sandbox during log processing. + Can be turned back on with unsafe_disable_sandbox(False). + """ + self.disable_sandbox = value + logger = LoggerAdapter(logging.getLogger(__name__), None) """Logger that will have contextual workflow details embedded. diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index 8c7feae82..f21090e01 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -8432,3 +8432,52 @@ async def test_activity_failure_with_encoded_payload_is_decoded_in_workflow( run_timeout=timedelta(seconds=5), ) assert result == "Handled encrypted failure successfully" + + +@workflow.defn +class DisableLoggerSandbox: + @workflow.run + async def run(self): + workflow.logger.info("Running workflow") + + +class CustomLogHandler(logging.Handler): + def emit(self, record: logging.LogRecord) -> None: + import httpx # type: ignore[reportUnusedImport] + + +async def test_disable_logger_sandbox( + client: Client, +): + logger = workflow.logger.logger + logger.addHandler(CustomLogHandler()) + async with new_worker( + client, + DisableLoggerSandbox, + activities=[], + ) as worker: + with pytest.raises(WorkflowFailureError): + await client.execute_workflow( + DisableLoggerSandbox.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + run_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + workflow.logger.unsafe_disable_sandbox() + await client.execute_workflow( + DisableLoggerSandbox.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + run_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=1), + ) + workflow.logger.unsafe_disable_sandbox(False) + with pytest.raises(WorkflowFailureError): + await client.execute_workflow( + DisableLoggerSandbox.run, + id=f"workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + run_timeout=timedelta(seconds=1), + retry_policy=RetryPolicy(maximum_attempts=1), + )