[PoC] Remote debugger based on debugpy #9007

Draft
wants to merge 4 commits into base: main
84 changes: 84 additions & 0 deletions distributed/debugpy.py
@@ -0,0 +1,84 @@
from __future__ import annotations

import importlib.util
import logging
import sys
import threading

import dask.config

logger = logging.getLogger(__name__)

DEBUGPY_ENABLED: bool = dask.config.get("distributed.diagnostics.debugpy.enabled")
DEBUGPY_PORT: int = dask.config.get("distributed.diagnostics.debugpy.port")


def _check_debugpy_installed():
    if importlib.util.find_spec("debugpy") is None:
        raise ModuleNotFoundError(
            "Dask debugger requires debugpy. Please make sure it is installed."
        )


LOCK = threading.Lock()


def _ensure_debugpy_listens() -> tuple[str, int]:
    import debugpy

    from distributed.worker import get_worker

    worker = get_worker()

    with LOCK:
        if endpoint := worker.extensions.get("debugpy", None):
            return endpoint
        endpoint = debugpy.listen(("0.0.0.0", DEBUGPY_PORT))
        worker.extensions["debugpy"] = endpoint
        return endpoint


def breakpointhook() -> None:
    import debugpy

    host, port = _ensure_debugpy_listens()
    if not debugpy.is_client_connected():
        logger.warning(
            "Breakpoint encountered; waiting for client to attach to %s:%d...",
            host,
            port,
        )
        debugpy.wait_for_client()

    debugpy.breakpoint()


def post_mortem() -> None:
    # Based on https://github.com/microsoft/debugpy/issues/723
    import debugpy

    host, port = _ensure_debugpy_listens()
    if not debugpy.is_client_connected():
        logger.warning(
            "Exception encountered; waiting for client to attach to %s:%d...",
            host,
            port,
        )
        debugpy.wait_for_client()

    import pydevd

    py_db = pydevd.get_global_debugger()
    thread = threading.current_thread()
    additional_info = py_db.set_additional_thread_info(thread)
    additional_info.is_tracing += 1
    try:
        error = sys.exc_info()
        py_db.stop_on_unhandled_exception(py_db, thread, additional_info, error)
    finally:
        additional_info.is_tracing -= 1


if DEBUGPY_ENABLED:
    _check_debugpy_installed()
    sys.breakpointhook = breakpointhook
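
As a usage sketch, not part of the diff itself: a task that calls the built-in breakpoint() goes through the overridden sys.breakpointhook on the worker, which starts the debugpy listener on the configured port and blocks until a DAP client (for example a VS Code "Remote Attach" configuration pointed at the worker host and port 5678) connects. The scheduler address and the flaky function below are illustrative.

from distributed import Client


def flaky(x):
    if x == 3:
        # On the worker this resolves to distributed.debugpy.breakpointhook
        breakpoint()
    return x + 1


client = Client("tcp://scheduler-host:8786")  # illustrative address
futures = client.map(flaky, range(5))
# The worker executing flaky(3) logs a warning and waits for a debugger to
# attach before hitting the breakpoint; gather() returns once the debug
# session resumes execution.
results = client.gather(futures)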
11 changes: 10 additions & 1 deletion distributed/distributed-schema.yaml
@@ -1061,7 +1061,16 @@ properties:
                minimum: 0
                description: |
                  The maximum number of erred tasks to remember.

          debugpy:
            type: object
            description: Configuration settings for Dask's remote debugger
            properties:
              enabled:
                type: boolean
                description: Enable remote debugging.
              port:
                type: integer
                description: Port used by the debug adapter to listen on.
      p2p:
        type: object
        description: Configuration for P2P shuffles
3 changes: 3 additions & 0 deletions distributed/distributed.yaml
@@ -305,6 +305,9 @@ distributed:
        - get_output_via_markers\.py
    erred-tasks:
      max-history: 100
    debugpy:
      enabled: True
      port: 5678

  p2p:
    comm:
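
Assuming dask's usual configuration machinery applies to these new keys (the mapping itself is not part of this diff), the defaults above can also be overridden before a worker starts, either via environment variables such as DASK_DISTRIBUTED__DIAGNOSTICS__DEBUGPY__ENABLED or programmatically. Note that distributed/debugpy.py reads both values at import time, so they must be in place before distributed.worker is first imported in the worker process. A small sketch:

import dask

# Hypothetical override of the defaults shipped in distributed.yaml; this must
# run before distributed.worker (and therefore distributed.debugpy) is imported,
# because that module captures the config values at import time.
dask.config.set(
    {
        "distributed.diagnostics.debugpy.enabled": True,
        "distributed.diagnostics.debugpy.port": 5679,
    }
)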
4 changes: 3 additions & 1 deletion distributed/worker.py
@@ -58,7 +58,7 @@
    typename,
)

from distributed import preloading, profile, utils
from distributed import debugpy, preloading, profile, utils
from distributed.batched import BatchedSend
from distributed.collections import LRU
from distributed.comm import Comm, connect, get_address_host, parse_address
@@ -2995,6 +2995,7 @@ def _run_task_simple(
        msg: RunTaskFailure = error_message(e)  # type: ignore
        msg["op"] = "task-erred"
        msg["actual_exception"] = e
        debugpy.post_mortem()
    else:
        msg: RunTaskSuccess = {  # type: ignore
            "op": "task-finished",
@@ -3041,6 +3042,7 @@ async def _run_task_async(
        msg: RunTaskFailure = error_message(e)  # type: ignore
        msg["op"] = "task-erred"
        msg["actual_exception"] = e
        debugpy.post_mortem()
    else:
        msg: RunTaskSuccess = {  # type: ignore
            "op": "task-finished",
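
As a second usage sketch, also not part of the diff: when a task raises, the except branch above now calls debugpy.post_mortem() on the worker, which waits for a debugger to attach and presents the failing frame before the task-erred message is sent back. The scheduler address and the divide function are illustrative.

from distributed import Client


def divide(a, b):
    # ZeroDivisionError here triggers debugpy.post_mortem() in _run_task_simple
    return a / b


client = Client("tcp://scheduler-host:8786")  # illustrative address
future = client.submit(divide, 1, 0)
# The worker logs a warning ("Exception encountered; waiting for client to
# attach ...") and holds the failure for post-mortem inspection; result()
# re-raises ZeroDivisionError on the client once the worker-side debug
# session finishes.
result = future.result()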