2 changes: 2 additions & 0 deletions .gitignore
@@ -153,3 +153,5 @@ dev/cleanup.py

.python-version
.databricks-login.json

.test_*.ipynb
179 changes: 179 additions & 0 deletions src/databricks/labs/blueprint/_logging_context.py
@@ -0,0 +1,179 @@
"""internall plumbing for passing logging context (dict) to logger instances"""
Contributor:
Suggested change
"""internall plumbing for passing logging context (dict) to logger instances"""
"""Internal plumbing for passing logging context (dict) to logger instances."""


import dataclasses
import inspect
import logging
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from contextvars import ContextVar
from functools import partial, wraps
from types import MappingProxyType
from typing import TYPE_CHECKING, Annotated, Any, TypeVar, get_origin

if TYPE_CHECKING:
Contributor:
In general we don't use TYPE_CHECKING guards: our IDEs and mypy work fine without it, and it can interfere with some refactoring operations.

I think this code will work fine without the guard?

T = TypeVar("T")
# SkipLogging[list[str]] will be treated by type checkers as list[str], because that's what Annotated is at runtime
# without this workaround, type checkers would complain about wrong types at call sites
# https://github.com/python/typing/discussions/1229
SkipLogging = Annotated[T, ...]
else:

@dataclasses.dataclass(slots=True)
class SkipLogging:
"""`@logging_context_params` will ignore parameters annotated with this class."""

def __class_getitem__(cls, item: Any) -> Any:
return Annotated[item, SkipLogging()]


_CTX: ContextVar = ContextVar("ctx", default={})
Contributor:
The name of the variable and the first parameter are supposed to match, for example:

Suggested change
_CTX: ContextVar = ContextVar("ctx", default={})
_CTX: ContextVar = ContextVar("_CTX", default={})

That said, there's a bigger safety issue we need to think about here, because the dictionary holding the context is mutable. This leads to some weird consequences:

  • The default of {} will be problematic: I expect it to allow the context to be shared between contexts. Context `a` obtains the default (via current_context(), for example) and updates it to include `k=v`; if context `b` then obtains the default, it will include the `k=v` from context `a`, because it was all the same dictionary instance.
  • The .reset() semantics can be surprising because they assume immutable values, but the dictionary is not.

From a practical point of view, I can see most of the code is avoiding this issue by creating new dictionaries, but there's a path to problems via current_context().

Some things that would help here:

  • Type-hint _CTX as a ContextVar[Mapping[str,Any]]: this will (I think?) allow the type-checker to detect updates to the returned value, useful for verifying our own code doesn't accidentally update the instance.
  • Either make current_context() protected or return a defensive copy of the internal dictionary.
  • A comment explaining the situation so that the next maintainer doesn't accidentally break it. ;)

Contributor:
[Later note: I've thought about this some more, and also think it's safest to have no default value, but rather use .get(None) when fetching, and initialise if None is returned.]

Contributor (Author):
Good catch with that, let me handle the default internally and not worry about {}.
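
A minimal sketch of how those suggestions could combine (no mutable default, a Mapping type-hint, lazy initialisation, and a defensive read-only copy from current_context()); this is illustrative only, not necessarily the code that will land in the PR:

from collections.abc import Mapping
from contextvars import ContextVar
from types import MappingProxyType
from typing import Any

# No default value: each context lazily gets its own empty mapping instead of
# sharing one mutable dict across contexts.
_CTX: ContextVar[Mapping[str, Any]] = ContextVar("_CTX")


def current_context() -> Mapping[str, Any]:
    """Return a read-only snapshot of the current logging context."""
    ctx = _CTX.get(None)
    if ctx is None:
        return MappingProxyType({})
    # Defensive copy wrapped in a read-only proxy so callers cannot mutate internal state.
    return MappingProxyType(dict(ctx))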



def _params_str(params: dict[str, Any]):
Contributor:
Suggested change
def _params_str(params: dict[str, Any]):
def _params_str(params: dict[str, Any]) -> str:

return ", ".join(f"{k}={v!r}" for k, v in params.items())


def _get_skip_logging_param_names(sig: inspect.Signature):
Contributor:
Please annotate the return-type.

"""Generates list of parameters names having SkipLogging annotation"""
for name, param in sig.parameters.items():
ann = param.annotation

# only consider annotation
if not ann or get_origin(ann) is not Annotated:
continue

# there can be many annotations for each param
for meta in ann.__metadata__:
# type checker thinks SkipLogging is a generic, despite it being Annotated
if meta and isinstance(meta, SkipLogging): # type: ignore
yield name


def _skip_dict_key(params: dict, keys_to_skip: set):
Contributor:
Can we type-hint this properly, please?

return {k: v for k, v in params.items() if k not in keys_to_skip}


def current_context():
Contributor:
Once _CTX is type-hinted, I think the return type here can also be added in.

"""Returns dictionary of current context set via `with loggin_context(...)` context manager or `@logging_context_params` decorator
Example:
current_context()
>>> {'foo': 'bar', 'a': 2}
"""
return _CTX.get()
Comment on lines +56 to +64
Contributor:
As discussed elsewhere, this is currently problematic because it's a public API and returns mutable (internal state).

The logging facility is often relied upon as a source of truth when debugging confusing and unusual situations, so this really needs to be handled in a defensive manner so that we know it's not (itself) a further source of problems.

(This is definitely fixable, but I think it's sufficiently important to emphasise why it should be addressed.)



def current_context_repr():
Contributor:
Is this used anywhere? I don't see any references to it, maybe left over from debugging?

"""Returns repr like "key1=val1, key2=val2" string representation of current_context(), or "" in case context is empty"""
return _params_str(current_context())


@contextmanager
def logging_context(**kwds):
Contributor:
Type-hint please, for the return-type.

(Without this the body isn't checked.)

"""Context manager adding keywords to current loging context. Thread and async safe.
Example:
with logging_context(foo="bar", a=2):
logger.info("hello")
>>> 2025-06-06 07:15:09,329 - __main__ - INFO - hello (foo='bar', a=2)
"""
# Get the current context and update it with new keywords
current_ctx = _CTX.get()
new_ctx = {**current_ctx, **kwds}
token = _CTX.set(MappingProxyType(new_ctx))
try:
yield _CTX.get()
Contributor:
Hmm. This is another place where we expose the (mutable) internal state.

Stepping back a bit, does it need to yield anything? (I have an open mind here… I don't see why users would want to capture the current context, but also don't see any reason to prevent it, assuming we can do it safely.)

I suspect it should either be:

Suggested change
yield _CTX.get()
yield

(If there's no real use-case.)

or

Suggested change
yield _CTX.get()
yield current_context()

(If we want to expose it, it needs to be done safely.)

except Exception as e:
# python 3.11+: https://docs.python.org/3.11/tutorial/errors.html#enriching-exceptions-with-notes
# https://peps.python.org/pep-0678/
if hasattr(e, "add_note"):
# __notes__ (list[str]) is only defined if notes were added; otherwise it is absent
# we only want to add a note if there was no note before, otherwise chaining contexts causes problems
if not getattr(e, "__notes__", None):
e.add_note(f"Context: {_params_str(current_context())}")

raise
finally:
_CTX.reset(token)
Contributor:
This is good to see, although bear in mind my earlier comments that the .reset() mechanism assumes values stored in the context are immutable: it can't restore mutable values within the context variable to their earlier state.
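
To illustrate that caveat with a standalone snippet (not project code): .reset() restores which object the variable points at, but it cannot undo in-place mutation of that object.

from contextvars import ContextVar

var: ContextVar[dict] = ContextVar("var")
shared = {"a": 1}
token = var.set(shared)

var.get()["b"] = 2  # mutate the stored dict in place
var.reset(token)    # restores the variable to its previous (unset) state...

print(shared)  # {'a': 1, 'b': 2} -- ...but the in-place mutation survives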



def logging_context_params(func=None, **extra_context):
Contributor:
Type-hints, please. ;)

"""Decorator that automatically adds all the function parameters to current logging context.
Any passed keyward arguments in will be added to the context. Function parameters take precendnce over the extra keywords in case the names would overlap.
Parameters annotated with `SkipLogging` will be ignored from being added to the context.
Example:
@logging_context_params(foo="bar")
def do_math(a: int, b: SkipLogging[int]):
r = pow(a, b)
logger.info(f"result of {a}**{b} is {r}")
return r
>>> 2025-06-06 07:15:09,329 - __main__ - INFO - result of 2**8 is 256 (foo='bar', a=2)
Note:
- the `a` parameter will be logged; a type annotation is optional
- the `b` parameter won't be logged because it is annotated with `SkipLogging`
- the `foo` parameter will be logged because it is passed as extra context to the decorator
"""

if func is None:
return partial(logging_context_params, **extra_context)

# use the function's signature to bind positional args to parameter names
sig = inspect.signature(func)
skip_params = set(_get_skip_logging_param_names(sig))

@wraps(func)
def wrapper(*args, **kwds):
# only bind if there are positional args
# extra context has lower priority than any of the args
# skip_params is used to filter out parameters that are annotated with SkipLogging

if args:
bound = sig.bind(*args, **kwds)
ctx_data = {**extra_context, **_skip_dict_key(bound.arguments, skip_params)}
else:
ctx_data = {**extra_context, **_skip_dict_key(kwds, skip_params)}

with logging_context(**ctx_data):
return func(*args, **kwds)

return wrapper


class LoggingContextInjectingFilter(logging.Filter):
"""Adds current_context() to the log record."""

def filter(self, record):
Contributor:
Type-hints, please.

# https://docs.python.org/3/howto/logging-cookbook.html#using-filters-to-impart-contextual-information
# https://docs.python.org/3/howto/logging-cookbook.html#use-of-contextvars
ctx = current_context()
record.context = f"{_params_str(ctx)}" if ctx else ""
record.context_msg = f" ({record.context})" if record.context else ""
Comment on lines +157 to +158
Contributor:
I'm curious why we're attaching the context rendered as a string? I would think attaching it as a dictionary is preferable?

Maybe something like:

Suggested change
record.context = f"{_params_str(ctx)}" if ctx else ""
record.context_msg = f" ({record.context})" if record.context else ""
record.context = ctx
record.context_msg = f" ({_params_str(ctx)})" if ctx else ""

Contributor (Author):
I've put it as a string to remove the quotes around the key names; it takes less space this way and feels easier to read:

d = { 'process': 'a', 'name': 'Alice', 'priority': 70}

# {'process': 'a', 'name': 'Alice', 'priority': 70}
print(d) 

# {'process': 'a', 'name': 'Alice', 'priority': 70}
print(repr(d)) 

# process = 'a', name = 'Alice', priority = 70
print(", ".join(f"{k} = {v!r}" for k, v in d.items()))

but all in all, I don't have a strong opinion

return True


class LoggingThreadPoolExecutor(ThreadPoolExecutor):
"""ThreadPoolExecutor drop in replacement that will apply current loging context to all new started threads."""

def __init__(self, max_workers=None, thread_name_prefix="", initializer=None, initargs=()):
self.__current_context = current_context()
self.__wrapped_initializer = initializer

super().__init__(
max_workers=max_workers,
thread_name_prefix=thread_name_prefix,
initializer=self._logging_context_init,
initargs=initargs,
)

def _logging_context_init(self, *args):
_CTX.set(self.__current_context)
if self.__wrapped_initializer:
self.__wrapped_initializer(*args)
Comment on lines +162 to +179
Contributor:
I was mildly surprised that we need this: I thought that the threading module and friends were updated with contextvars support, but it looks like only asyncio was retrofitted.

With this in mind, you're correct that context variables aren't properly transferred into execution contexts in other threads. I suspect this is partly because it's not always clear at which point the context should be preferred.

For example, using a basic Thread you could choose to capture the context during __init__() or during start(). Ensuring this happens means doing something like:

class ContextThread(threading.Thread):
    def __init__(self, *args, **kw) -> None:
        self._captured_ctx = contextvars.copy_context()
        return super().__init__(*args, **kw)
    def run(self):
        transferred_ctx = self._captured_ctx
        del self._captured_ctx
        return transferred_ctx.run(super().run)

Returning to the situation at hand: ThreadPoolExecutor. For this what really matters is the context when a task is submitted to the pool. That is, during .submit() the context of the caller should be captured so that when the submitted function is executed the captured context of the submitter is in place. (This is a bit fiddly, but my example above shows the basics of it… it's even cleaner because there's no temporary state required on the pool instance.)

I'd like to see this implemented for contextvars in general, irrespective of the logging changes that we're making. Maybe it should be a separate PR, just to limit the scope and ensure we get that right independently of these changes. What do you think?

Contributor (Author):
I would prefer to patch this in this PR; otherwise it's quite hard to prove the usefulness of this feature when we don't have tests for multithreading, and that's where I see its benefits. In non-multithreaded code, just reading the logs top-down is enough to understand the context.

I can make it work in submit(); will update the code soon.
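
For reference, a rough sketch of what that submit()-time capture could look like (hypothetical class name; not the code in this PR): each task runs inside a copy of the submitter's context, so context variables set by the caller are visible in the worker thread.

import contextvars
from collections.abc import Callable
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any


class ContextPropagatingThreadPoolExecutor(ThreadPoolExecutor):
    """ThreadPoolExecutor that propagates the caller's contextvars to each submitted task."""

    def submit(self, fn: Callable[..., Any], /, *args: Any, **kwargs: Any) -> Future:
        # Capture the caller's context at submission time; each task gets its own copy.
        captured = contextvars.copy_context()
        return super().submit(captured.run, fn, *args, **kwargs)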

29 changes: 27 additions & 2 deletions src/databricks/labs/blueprint/logger.py
@@ -4,6 +4,23 @@
import sys
from typing import TextIO

from ._logging_context import (
LoggingContextInjectingFilter,
SkipLogging,
current_context,
logging_context,
logging_context_params,
)

__all__ = [
"NiceFormatter",
"install_logger",
"current_context",
"SkipLogging",
"logging_context_params",
"logging_context",
]


class NiceFormatter(logging.Formatter):
"""A nice formatter for logging. It uses colors and bold text if the console supports it."""
@@ -36,7 +53,7 @@ def __init__(self, *, probe_tty: bool = False, stream: TextIO = sys.stdout) -> N
stream: the output stream to which the formatter will write, used to check if it is a console.
probe_tty: If true, the formatter will enable color support if the output stream appears to be a console.
"""
super().__init__(fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s", datefmt="%H:%M:%S")
super().__init__(fmt="%(asctime)s %(levelname)s [%(name)s] %(message)s%(context_msg)s", datefmt="%H:%M:%S")
# Used to colorize the level names.
self._levels = {
logging.DEBUG: self._bold(f"{self.CYAN} DEBUG"),
@@ -88,7 +105,12 @@ def format(self, record: logging.LogRecord) -> str:
color_marker = self._msg_colors[record.levelno]

thread_name = f"[{record.threadName}]" if record.threadName != "MainThread" else ""
return f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.RESET}"

# safe check, just in case injection filter is removed
context_repr = record.context if hasattr(record, "context") else ""
context_msg = f" {self.GRAY}({context_repr}){self.RESET}" if context_repr else ""

return f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.RESET}{context_msg}"
Comment on lines +108 to +113
Contributor:
Hmm, how about?

Suggested change
# safe check, just in case injection filter is removed
context_repr = record.context if hasattr(record, "context") else ""
context_msg = f" {self.GRAY}({context_repr}){self.RESET}" if context_repr else ""
return f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.RESET}{context_msg}"
context_msg = getattr("record", "context_msg", "")
if context_msg:
formatted = f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.GRAY}{context_msg}{self.RESET}
else:
formatted = f"{self.GRAY}{timestamp}{self.RESET} {level} {color_marker}[{name}]{thread_name} {msg}{self.RESET}
return formatted



def install_logger(
@@ -102,6 +124,7 @@ def install_logger(
- All existing handlers will be removed.
- A new handler will be installed with our custom formatter. It will be configured to emit logs at the given level
(default: DEBUG) or higher, to the specified stream (default: sys.stderr).
- A new (injection) filter will be added that attaches the current logging context (as `context`) to all log records.

Args:
level: The logging level to set for the console handler.
@@ -115,6 +138,8 @@
root.removeHandler(handler)
console_handler = logging.StreamHandler(stream)
console_handler.setFormatter(NiceFormatter(stream=stream))
console_handler.addFilter(LoggingContextInjectingFilter())
console_handler.setLevel(level)

root.addHandler(console_handler)
return console_handler
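
For context, a small usage sketch of how the pieces in this diff fit together once installed (the logger name and context keys below are made up):

import logging

from databricks.labs.blueprint.logger import (
    install_logger,
    logging_context,
    logging_context_params,
)

install_logger()  # installs NiceFormatter plus the context-injecting filter on the root handler
logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)


@logging_context_params(component="uploader")
def upload(path: str) -> None:
    logger.info("uploading")


with logging_context(job_id=42):
    # The emitted line ends with the merged context, roughly:
    #   ... INFO [demo] uploading (job_id=42, component='uploader', path='/tmp/report.csv')
    upload("/tmp/report.csv")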
9 changes: 5 additions & 4 deletions src/databricks/labs/blueprint/parallel.py
Contributor:
Here the plot thickens… ideally we'd capture the context (as discussed earlier) when the task to run is being created, but by the time we reach here that's already happened. That said, we should be capturing the caller context during task submission and ensuring that when each task runs on the other thread, it starts within the captured context. (The task may modify it further, of course.)

@@ -8,9 +8,10 @@
import re
import threading
from collections.abc import Callable, Collection, Sequence
from concurrent.futures import ThreadPoolExecutor
from typing import Generic, TypeVar

from ._logging_context import LoggingThreadPoolExecutor

MIN_THREADS = 8

Result = TypeVar("Result")
@@ -61,12 +62,12 @@ def gather(
return cls(name, tasks, num_threads=num_threads)._run()

@classmethod
def strict(cls, name: str, tasks: Sequence[Task[Result]]) -> Collection[Result]:
def strict(cls, name: str, tasks: Sequence[Task[Result]], num_threads: int | None = None) -> Collection[Result]:
"""Run tasks in parallel and raise ManyError if any task fails"""
# this dunder variable is hiding this method from tracebacks, making it cleaner
# for the user to see the actual error without too much noise.
__tracebackhide__ = True # pylint: disable=unused-variable
collected, errs = cls.gather(name, tasks)
collected, errs = cls.gather(name, tasks, num_threads)
if errs:
if len(errs) == 1:
raise errs[0]
@@ -114,7 +115,7 @@ def _on_finish(self, given_cnt: int, collected_cnt: int, failed_cnt: int):
def _execute(self):
"""Run tasks in parallel and return futures"""
thread_name_prefix = re.sub(r"\W+", "_", self._name)
with ThreadPoolExecutor(self._num_threads, thread_name_prefix) as pool:
with LoggingThreadPoolExecutor(self._num_threads, thread_name_prefix) as pool:
futures = []
for task in self._tasks:
if task is None:
37 changes: 34 additions & 3 deletions tests/unit/test_logger.py
@@ -11,7 +11,12 @@

import pytest

from databricks.labs.blueprint.logger import NiceFormatter, install_logger
from databricks.labs.blueprint._logging_context import LoggingContextInjectingFilter
from databricks.labs.blueprint.logger import (
NiceFormatter,
install_logger,
logging_context,
)


class LogCaptureHandler(logging.Handler):
@@ -32,6 +37,7 @@ def record_capturing(cls, logger: logging.Logger) -> Generator[LogCaptureHandler, None, None]:
def record_capturing(cls, logger: logging.Logger) -> Generator[LogCaptureHandler, None, None]:
"""Temporarily capture all log records, in addition to existing handling."""
handler = LogCaptureHandler()
handler.addFilter(LoggingContextInjectingFilter())
logger.addHandler(handler)
try:
yield handler
@@ -49,7 +55,9 @@ class LoggingSystemFixture:
def __init__(self) -> None:
self.output_buffer = io.StringIO()
self.root = logging.RootLogger(logging.WARNING)
self.root.addHandler(logging.StreamHandler(self.output_buffer))
handler = logging.StreamHandler(self.output_buffer)
handler.addFilter(LoggingContextInjectingFilter())
self.root.addHandler(handler)
self.manager = logging.Manager(self.root)

def getLogger(self, name: str) -> logging.Logger:
@@ -84,6 +92,8 @@ def test_install_logger(logging_system) -> None:
# Verify that the root logger was configured as expected.
assert root.level == logging.FATAL # remains unchanged
assert root.handlers == [handler]
assert len(handler.filters) == 1
assert isinstance(handler.filters[0], LoggingContextInjectingFilter)
assert handler.level == logging.INFO
assert isinstance(handler.formatter, NiceFormatter)

@@ -98,7 +108,8 @@ def test_installed_logger_logging(logging_system) -> None:
logger = logging_system.getLogger(__file__)
logger.debug("This is a debug message")
logger.info("This is an info message")
logger.warning("This is a warning message")
with logging_context(foo="bar-warning"):
logger.warning("This is a warning message")
logger.error("This is an error message", exc_info=KeyError(123))
logger.critical("This is a critical message")

Expand All @@ -107,6 +118,7 @@ def test_installed_logger_logging(logging_system) -> None:
assert "This is a debug message" in output
assert "This is an info message" in output
assert "This is a warning message" in output
assert "(foo='bar-warning')" in output
assert "This is an error message\nKeyError: 123" in output
assert "This is a critical message" in output

@@ -348,3 +360,22 @@ def test_formatter_format_exception(use_colors: bool) -> None:
" raise RuntimeError(exception_message)",
]
assert exception == "RuntimeError: Test exception."


@pytest.mark.parametrize("use_colors", (True, False), ids=("with_colors", "without_colors"))
def test_formatter_with_logging_context(use_colors: bool) -> None:
"""Ensure the formatter correctly formats message when logging_context is used"""
formatter = NiceFormatter()
formatter.colors = use_colors

with logging_context(foo="bar", baz="zak"):
record = create_record(logging.DEBUG, " This is a test message for logging context")
assert hasattr(record, "context")
assert record.context == "foo='bar', baz='zak'"
formatted = formatter.format(record)
assert record.context in formatted, "record context not in formatted"
stripped = _strip_sgr_sequences(formatted) if use_colors else formatted
assert record.context in stripped, "record context not in stripped"

# H:M:S LEVEL [logger_name] message (logging_context)
assert stripped.endswith(" This is a test message for logging context (foo='bar', baz='zak')")