From 640e984969f1e03ff17d5cc4d06f84cf914a1fab Mon Sep 17 00:00:00 2001 From: David Li Date: Tue, 23 Jan 2024 16:47:47 -0500 Subject: [PATCH] feat(python/adbc_driver_manager): handle KeyboardInterrupt Fixes #1484. --- .github/workflows/native-unix.yml | 13 + ci/scripts/python_test.sh | 4 +- docker-compose.yml | 6 + .../tests/test_errors.py | 20 ++ python/adbc_driver_manager/MANIFEST.in | 2 + .../adbc_driver_manager/_blocking_impl.cc | 269 ++++++++++++++++++ .../adbc_driver_manager/_blocking_impl.h | 38 +++ .../adbc_driver_manager/_lib.pyi | 11 + .../adbc_driver_manager/_lib.pyx | 98 +++++++ .../adbc_driver_manager/dbapi.py | 17 +- python/adbc_driver_manager/pyproject.toml | 1 + python/adbc_driver_manager/setup.py | 3 + .../tests/test_blocking.py | 140 +++++++++ 13 files changed, 615 insertions(+), 7 deletions(-) create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc create mode 100644 python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h create mode 100644 python/adbc_driver_manager/tests/test_blocking.py diff --git a/.github/workflows/native-unix.yml b/.github/workflows/native-unix.yml index 142d5f4be7..c4ab9f44f8 100644 --- a/.github/workflows/native-unix.yml +++ b/.github/workflows/native-unix.yml @@ -477,7 +477,20 @@ jobs: - name: Test Python Driver Flight SQL shell: bash -l {0} run: | + # Can't use Docker on macOS + pushd $(pwd)/go/adbc + go build -o testserver ./driver/flightsql/cmd/testserver + popd + $(pwd)/go/adbc/testserver -host 0.0.0.0 -port 41414 & + while ! curl --http2-prior-knowledge -H "content-type: application/grpc" -v localhost:41414 -XPOST; + do + echo "Waiting for test server..." + jobs + sleep 5 + done + export ADBC_TEST_FLIGHTSQL_URI=grpc://localhost:41414 env BUILD_ALL=0 BUILD_DRIVER_FLIGHTSQL=1 ./ci/scripts/python_test.sh "$(pwd)" "$(pwd)/build" "$HOME/local" + kill %1 - name: Build Python Driver PostgreSQL shell: bash -l {0} run: | diff --git a/ci/scripts/python_test.sh b/ci/scripts/python_test.sh index f8d7091791..6f95b5898e 100755 --- a/ci/scripts/python_test.sh +++ b/ci/scripts/python_test.sh @@ -58,8 +58,8 @@ test_subproject() { fi echo "=== Testing ${subproject} ===" - echo env ${options[@]} python -m pytest -vv "${source_dir}/python/${subproject}/tests" - env ${options[@]} python -m pytest -vv "${source_dir}/python/${subproject}/tests" + echo env ${options[@]} python -m pytest -vvs --full-trace "${source_dir}/python/${subproject}/tests" + env ${options[@]} python -m pytest -vvs --full-trace "${source_dir}/python/${subproject}/tests" echo } diff --git a/docker-compose.yml b/docker-compose.yml index 89394e598f..789d5d450d 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -150,6 +150,12 @@ services: dockerfile: ci/docker/flightsql-test.dockerfile args: GO: ${GO} + healthcheck: + test: ["CMD", "curl", "--http2-prior-knowledge", "-XPOST", "-H", "content-type: application/grpc"] + interval: 5s + timeout: 30s + retries: 3 + start_period: 5m ports: - "41414:41414" volumes: diff --git a/python/adbc_driver_flightsql/tests/test_errors.py b/python/adbc_driver_flightsql/tests/test_errors.py index ed44b6a3fa..ee2b62d3ee 100644 --- a/python/adbc_driver_flightsql/tests/test_errors.py +++ b/python/adbc_driver_flightsql/tests/test_errors.py @@ -16,6 +16,8 @@ # under the License. import re +import threading +import time import google.protobuf.any_pb2 as any_pb2 import google.protobuf.wrappers_pb2 as wrappers_pb2 @@ -45,6 +47,24 @@ def test_query_cancel(test_dbapi): cur.fetchone() +def test_query_cancel_async(test_dbapi): + with test_dbapi.cursor() as cur: + cur.execute("forever") + + def _cancel(): + time.sleep(2) + cur.adbc_cancel() + + t = threading.Thread(target=_cancel, daemon=True) + t.start() + + with pytest.raises( + test_dbapi.OperationalError, + match=re.escape("CANCELLED: [FlightSQL] context canceled"), + ): + cur.fetchone() + + def test_query_error_fetch(test_dbapi): with test_dbapi.cursor() as cur: cur.execute("error_do_get") diff --git a/python/adbc_driver_manager/MANIFEST.in b/python/adbc_driver_manager/MANIFEST.in index 306c31144f..298ff3a9ca 100644 --- a/python/adbc_driver_manager/MANIFEST.in +++ b/python/adbc_driver_manager/MANIFEST.in @@ -22,6 +22,8 @@ include NOTICE.txt include adbc_driver_manager/adbc.h include adbc_driver_manager/adbc_driver_manager.cc include adbc_driver_manager/adbc_driver_manager.h +include adbc_driver_manager/_blocking_impl.cc +include adbc_driver_manager/_blocking_impl.h include adbc_driver_manager/_lib.pxd include adbc_driver_manager/_lib.pyi include adbc_driver_manager/_reader.pyi diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc new file mode 100644 index 0000000000..766b3964cd --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.cc @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "_blocking_impl.h" + +#if defined(_WIN32) +#define NOMINMAX +#define WIN32_LEAN_AND_MEAN +#include +#include +#include +#include +#else +#include +#include +#include +#endif + +#include +#include +#include +#include +#include + +namespace pyadbc_driver_manager { + +// This is somewhat derived from io_util.cc in arrow, but that implementation +// isn't easily used outside of Arrow's monolith. +namespace { +static std::once_flag kInitOnce; +// We may encounter errors below that we can't do anything about. Use this to +// print out an error, once. +static std::once_flag kWarnOnce; +// This thread reads from a pipe forever. Whenever it reads something, it +// calls the callback below. +static std::thread kCancelThread; + +static std::mutex cancel_mutex; +// This callback is registered by the Python side; basically it will call +// cancel() on an ADBC object. +static void (*cancel_callback)(void*) = nullptr; +// Callback state (a pointer to the ADBC PyObject). +static void* cancel_callback_data = nullptr; +// A nonblocking self-pipe. +static int pipe[2]; +#if defined(_WIN32) +void (*old_sigint)(int); +#else +// The old signal handler (most likely Python's). +struct sigaction old_sigint; +// Our signal handler (below). +struct sigaction our_sigint; +#endif + +std::string MakePipe() { + int rc = 0; +#if defined(__linux__) && defined(__GLIBC__) + rc = pipe2(pipe, O_CLOEXEC); +#elif defined(_WIN32) + rc = _pipe(pipe, 4096, _O_BINARY); +#else + rc = ::pipe(pipe); +#endif + + if (rc != 0) { + return std::strerror(errno); + } + +#if (!defined(__linux__) || !defined(__GLIBC__)) && !defined(_WIN32) + { + int flags = fcntl(pipe[0], F_GETFD, 0); + if (flags < 0) { + return std::strerror(errno); + } + rc = fcntl(pipe[0], F_SETFD, flags | FD_CLOEXEC); + if (rc < 0) { + return std::strerror(errno); + } + + flags = fcntl(pipe[1], F_GETFD, 0); + if (flags < 0) { + return std::strerror(errno); + } + rc = fcntl(pipe[1], F_SETFD, flags | FD_CLOEXEC); + if (rc < 0) { + return std::strerror(errno); + } + } +#endif + + // Make the write side nonblocking (the read side should stay blocking!) +#if defined(_WIN32) + const auto handle = reinterpret_cast(_get_osfhandle(pipe[1])); + DWORD mode = PIPE_NOWAIT; + if (!SetNamedPipeHandleState(handle, &mode, nullptr, nullptr)) { + DWORD last_error = GetLastError(); + LPVOID message; + + FormatMessage(FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS, + /*lpSource=*/nullptr, last_error, + MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT), + reinterpret_cast(&message), /*nSize=*/0, /*Arguments=*/nullptr); + + std::string buffer = "("; + buffer += std::to_string(last_error); + buffer += ") "; + buffer += reinterpret_cast(message); + LocalFree(message); + return buffer; + } +#else + { + int flags = fcntl(pipe[1], F_GETFL, 0); + if (flags < 0) { + return std::strerror(errno); + } + rc = fcntl(pipe[1], F_SETFL, flags | O_NONBLOCK); + if (rc < 0) { + return std::strerror(errno); + } + } +#endif + + return ""; +} + +void InterruptThread() { +#if defined(__APPLE__) + pthread_setname_np("AdbcInterrupt"); +#endif + + while (true) { + char buf = 0; + // Anytime something is written to the pipe, attempt to call the callback + auto bytes_read = read(pipe[0], &buf, 1); + if (bytes_read < 0) { + if (errno == EINTR) continue; + + // XXX: we failed reading from the pipe + std::string message = std::strerror(errno); + std::call_once(kWarnOnce, [&]() { + std::cerr << "adbc_driver_manager (native code): error handling interrupt: " + << message << std::endl; + }); + } else if (bytes_read > 0) { + // Save the callback locally instead of calling it under the lock, since + // otherwise we may deadlock with the Python side trying to call us + void (*local_callback)(void*) = nullptr; + void* local_callback_data = nullptr; + + { + std::lock_guard lock(cancel_mutex); + if (cancel_callback != nullptr) { + local_callback = cancel_callback; + local_callback_data = cancel_callback_data; + } + cancel_callback = nullptr; + cancel_callback_data = nullptr; + } + + if (local_callback != nullptr) { + local_callback(local_callback_data); + } + } + } +} + +// We can't do much about failures here, so ignore the result. If the pipe is +// full, that's fine; it just means the thread has fallen behind in processing +// earlier interrupts. +void SigintHandler(int) { +#if defined(_WIN32) + (void)_write(pipe[1], "X", 1); +#else + (void)write(pipe[1], "X", 1); +#endif +} + +} // namespace + +std::string InitBlockingCallback() { + std::string error; + std::call_once(kInitOnce, [&]() { + error = MakePipe(); + if (!error.empty()) { + return; + } + +#if !defined(_WIN32) + our_sigint.sa_handler = &SigintHandler; + our_sigint.sa_flags = 0; + sigemptyset(&our_sigint.sa_mask); +#endif + + kCancelThread = std::thread(InterruptThread); +#if defined(__linux__) + pthread_setname_np(kCancelThread.native_handle(), "AdbcInterrupt"); +#endif + kCancelThread.detach(); + }); + return error; +} + +std::string SetBlockingCallback(void (*callback)(void*), void* data) { + std::lock_guard lock(cancel_mutex); + cancel_callback = callback; + cancel_callback_data = data; + +#if defined(_WIN32) + if (old_sigint == nullptr) { + old_sigint = signal(SIGINT, &SigintHandler); + if (old_sigint == SIG_ERR) { + old_sigint = nullptr; + return std::strerror(errno); + } + } +#else + // Don't set the handler again if we're somehow called twice + if (old_sigint.sa_handler == nullptr && old_sigint.sa_sigaction == nullptr) { + int rc = sigaction(SIGINT, &our_sigint, &old_sigint); + if (rc != 0) { + return std::strerror(errno); + } + } +#endif + return ""; +} + +std::string ClearBlockingCallback() { + std::lock_guard lock(cancel_mutex); + cancel_callback = nullptr; + cancel_callback_data = nullptr; + +#if defined(_WIN32) + if (old_sigint != nullptr) { + auto rc = signal(SIGINT, old_sigint); + old_sigint = nullptr; + if (rc == SIG_ERR) { + return std::strerror(errno); + } + } +#else + if (old_sigint.sa_handler != nullptr || old_sigint.sa_sigaction != nullptr) { + int rc = sigaction(SIGINT, &old_sigint, nullptr); + std::memset(&old_sigint, 0, sizeof(old_sigint)); + if (rc != 0) { + return std::strerror(errno); + } + } +#endif + return ""; +} + +} // namespace pyadbc_driver_manager diff --git a/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h new file mode 100644 index 0000000000..ac76252f3e --- /dev/null +++ b/python/adbc_driver_manager/adbc_driver_manager/_blocking_impl.h @@ -0,0 +1,38 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Allow KeyboardInterrupt to function with ADBC in Python. +/// +/// Call SetBlockingCallback to register a callback. This will temporarily +/// suppress the Python SIGINT handler. When SIGINT is received, this module +/// will handle it by calling the callback. + +#include + +namespace pyadbc_driver_manager { + +/// \brief Set up internal state to handle. +/// \return An error message (or empty string). +std::string InitBlockingCallback(); +/// \brief Set the callback for when SIGINT is received. +/// \return An error message (or empty string). +std::string SetBlockingCallback(void (*callback)(void*), void* data); +/// \brief Clear the callback for when SIGINT is received. +/// \return An error message (or empty string). +std::string ClearBlockingCallback(); + +} // namespace pyadbc_driver_manager diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi index 7afada9ecc..2a818839a1 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -17,6 +17,7 @@ # NOTE: generated with mypy's stubgen, then hand-edited to fix things +import typing_extensions from typing import Any, ClassVar, Dict, List, Literal, Optional, Tuple, Union from typing import overload @@ -201,3 +202,13 @@ def _test_error( vendor_code: Optional[int], sqlstate: Optional[str], ) -> Error: ... + +_P = typing_extensions.ParamSpec("_P") +_T = typing.TypeVar("_T") + +def _blocking_call( + func: typing.Callable[_P, _T], + args: tuple, + kwargs: dict, + cancel: typing.Callable[[], None], +) -> _T: ... diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 91139100bb..79222d6082 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -20,8 +20,12 @@ """Low-level ADBC API.""" import enum +import functools import threading +import os import typing +import sys +import warnings from typing import List, Tuple cimport cpython @@ -33,6 +37,7 @@ from cpython.pycapsule cimport ( from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t from libc.stdlib cimport malloc, free from libc.string cimport memcpy, memset +from libcpp.string cimport string as c_string from libcpp.vector cimport vector as c_vector if typing.TYPE_CHECKING: @@ -1481,3 +1486,96 @@ cdef class AdbcStatement(_AdbcHandle): cdef const CAdbcError* PyAdbcErrorFromArrayStream( CArrowArrayStream* stream, CAdbcStatusCode* status): return AdbcErrorFromArrayStream(stream, status) + + +cdef extern from "_blocking_impl.h" nogil: + ctypedef void (*BlockingCallback)(void*) noexcept nogil + c_string CInitBlockingCallback"pyadbc_driver_manager::InitBlockingCallback"() + c_string CSetBlockingCallback"pyadbc_driver_manager::SetBlockingCallback"(BlockingCallback, void* data) + c_string CClearBlockingCallback"pyadbc_driver_manager::ClearBlockingCallback"() + + +@functools.cache +def _init_blocking_call(): + error = bytes(CInitBlockingCallback()).decode("utf-8") + if error: + warnings.warn( + f"Failed to initialize KeyboardInterrupt support: {error}", + RuntimeWarning, + ) + + +_blocking_lock = threading.Lock() +_blocking_exc = None + + +def _blocking_call_impl(func, args, kwargs, cancel): + """ + Run functions that are expected to block with a native SIGINT handler. + + Parameters + ---------- + """ + global _blocking_exc + + if threading.current_thread() is not threading.main_thread(): + return func(*args, **kwargs) + + _init_blocking_call() + + with _blocking_lock: + if _blocking_exc: + _blocking_exc = None + + # Set the callback for the background thread and save the signal handler + # TODO: ideally this would be no-op if already set + error = bytes( + CSetBlockingCallback(&_handle_blocking_call, cancel) + ).decode("utf-8") + if error: + warnings.warn( + f"Failed to set SIGINT handler: {error}", + RuntimeWarning, + ) + + try: + return func(*args, **kwargs) + except BaseException as e: + with _blocking_lock: + if _blocking_exc: + exc = _blocking_exc + _blocking_exc = None + raise e from exc[1].with_traceback(exc[2]) + raise e + finally: + # Restore the signal handler + error = bytes(CClearBlockingCallback()).decode("utf-8") + if error: + warnings.warn( + f"Failed to restore SIGINT handler: {error}", + RuntimeWarning, + ) + with _blocking_lock: + if _blocking_exc: + exc = _blocking_exc + _blocking_exc = None + raise exc[1].with_traceback(exc[2]) from KeyboardInterrupt + + +if os.name != "nt": + _blocking_call = _blocking_call_impl +else: + def _blocking_call(func, args, kwargs, cancel): + return func(*args, **kwargs) + + + +cdef void _handle_blocking_call(void* c_cancel) noexcept nogil: + with gil: + try: + cancel = c_cancel + cancel() + except: + with _blocking_lock: + global _blocking_exc + _blocking_exc = sys.exc_info() diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 1e86144c12..ee6f318d6b 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -55,6 +55,7 @@ import adbc_driver_manager from . import _lib, _reader +from ._lib import _blocking_call if typing.TYPE_CHECKING: import pandas @@ -677,9 +678,12 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None: parameters, which will each be bound in turn). """ self._prepare_execute(operation, parameters) - handle, self._rowcount = self._stmt.execute_query() + + handle, self._rowcount = _blocking_call( + self._stmt.execute_query, (), {}, self._stmt.cancel + ) self._results = _RowIterator( - _reader.AdbcRecordBatchReader._import_from_c(handle.address) + self._stmt, _reader.AdbcRecordBatchReader._import_from_c(handle.address) ) def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None: @@ -991,7 +995,7 @@ def adbc_read_partition(self, partition: bytes) -> None: handle = self._conn._conn.read_partition(partition) self._rowcount = -1 self._results = _RowIterator( - pyarrow.RecordBatchReader._import_from_c(handle.address) + self._stmt, pyarrow.RecordBatchReader._import_from_c(handle.address) ) @property @@ -1095,7 +1099,8 @@ def fetch_record_batch(self) -> pyarrow.RecordBatchReader: class _RowIterator(_Closeable): """Track state needed to iterate over the result set.""" - def __init__(self, reader: pyarrow.RecordBatchReader) -> None: + def __init__(self, stmt, reader: pyarrow.RecordBatchReader) -> None: + self._stmt = stmt self._reader = reader self._current_batch = None self._next_row = 0 @@ -1118,7 +1123,9 @@ def fetchone(self) -> Optional[tuple]: if self._current_batch is None or self._next_row >= len(self._current_batch): try: while True: - self._current_batch = self._reader.read_next_batch() + self._current_batch = _blocking_call( + self._reader.read_next_batch, (), {}, self._stmt.cancel + ) if self._current_batch.num_rows > 0: break self._next_row = 0 diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml index 0a03fa3ff9..d2db1f102f 100644 --- a/python/adbc_driver_manager/pyproject.toml +++ b/python/adbc_driver_manager/pyproject.toml @@ -23,6 +23,7 @@ license = {text = "Apache-2.0"} readme = "README.md" requires-python = ">=3.9" dynamic = ["version"] +dependencies = ["typing-extensions"] [project.optional-dependencies] dbapi = ["pandas", "pyarrow>=8.0.0"] diff --git a/python/adbc_driver_manager/setup.py b/python/adbc_driver_manager/setup.py index bbec1a01c6..1b6f1026ea 100644 --- a/python/adbc_driver_manager/setup.py +++ b/python/adbc_driver_manager/setup.py @@ -76,6 +76,8 @@ def get_version(pkg_path): if sys.platform == "win32": extra_compile_args = ["/std:c++17", "/DADBC_EXPORTING"] + if build_type == "debug": + extra_compile_args.extend(["/DEBUG:FULL"]) else: extra_compile_args = ["-std=c++17"] if build_type == "debug": @@ -93,6 +95,7 @@ def get_version(pkg_path): include_dirs=[str(source_root.joinpath("adbc_driver_manager").resolve())], language="c++", sources=[ + "adbc_driver_manager/_blocking_impl.cc", "adbc_driver_manager/_lib.pyx", "adbc_driver_manager/adbc_driver_manager.cc", ], diff --git a/python/adbc_driver_manager/tests/test_blocking.py b/python/adbc_driver_manager/tests/test_blocking.py new file mode 100644 index 0000000000..2af5e3d086 --- /dev/null +++ b/python/adbc_driver_manager/tests/test_blocking.py @@ -0,0 +1,140 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Direct tests of the SIGINT handler. + +Higher-level testing of SIGINT during queries appears to be flaky in CI due to +having to send the signal, so this tests the handler itself instead. +""" + +import os +import signal +import threading +import time + +import pytest + +from adbc_driver_manager import _lib + +# It works fine on the normal Windows builds, but not under the Conda builds +# where there is an unexplained/unreplicable crash, and so for now this is +# disabled on Windows +pytestmark = pytest.mark.skipif(os.name == "nt", reason="Disabled on Windows") + + +def _send_sigint(): + # Windows behavior is different + # https://stackoverflow.com/questions/35772001 + if os.name == "nt": + os.kill(os.getpid(), signal.CTRL_C_EVENT) + else: + os.kill(os.getpid(), signal.SIGINT) + + +def _blocking(event): + _send_sigint() + event.wait() + + +def test_sigint_fires(): + # Run the thing that fires SIGINT itself as the "blocking" call + event = threading.Event() + + def _cancel(): + event.set() + + _lib._blocking_call(_blocking, (event,), {}, _cancel) + + +def test_handler_restored(): + event = threading.Event() + _lib._blocking_call(_blocking, (event,), {}, event.set) + + # After it returns, this should raise KeyboardInterrupt like usual + with pytest.raises(KeyboardInterrupt): + _blocking(event) + # Needed on Windows so the handler runs before we exit the block (we + # won't sleep for the full time) + time.sleep(60) + + +def test_args_return(): + def _blocking(a, *, b): + return a, b + + assert _lib._blocking_call( + _blocking, + (1,), + {"b": 2}, + lambda: None, + ) == (1, 2) + + +def test_blocking_raise(): + def _blocking(): + raise ValueError("expected error") + + with pytest.raises(ValueError, match="expected error"): + _lib._blocking_call(_blocking, (), {}, lambda: None) + + +def test_cancel_raise(): + event = threading.Event() + + def _cancel(): + event.set() + raise ValueError("expected error") + + with pytest.raises(ValueError, match="expected error"): + _lib._blocking_call(_blocking, (event,), {}, _cancel) + + +def test_both_raise(): + event = threading.Event() + + def _blocking(event): + _send_sigint() + event.wait() + raise ValueError("expected error 1") + + def _cancel(): + event.set() + raise ValueError("expected error 2") + + with pytest.raises(ValueError, match="expected error 1") as excinfo: + _lib._blocking_call(_blocking, (event,), {}, _cancel) + assert excinfo.value.__cause__ is not None + with pytest.raises(ValueError, match="expected error 2"): + raise excinfo.value.__cause__ + + +def test_nested(): + # To be clear, don't ever do this. + event = threading.Event() + + def _wrap_blocking(): + _lib._blocking_call(_blocking, (event,), {}, event.set) + + _lib._blocking_call(_wrap_blocking, (), {}, lambda: None) + + # The original handler should be restored + with pytest.raises(KeyboardInterrupt): + _send_sigint() + # Needed on Windows so the handler runs before we exit the block (we + # won't sleep for the full time) + time.sleep(60)