1
1
from __future__ import annotations
2
2
3
- import atexit
4
3
import logging
5
4
import multiprocessing
6
5
import queue
6
+ import signal
7
7
import threading
8
8
import time
9
9
from concurrent .futures import ThreadPoolExecutor
@@ -81,9 +81,6 @@ def __init__(
81
81
82
82
self ._processing_pool_name : str = processing_pool_name or "unknown"
83
83
84
- def __del__ (self ) -> None :
85
- self .shutdown ()
86
-
87
84
def do_imports (self ) -> None :
88
85
for module in settings .TASKWORKER_IMPORTS :
89
86
__import__ (module )
@@ -99,10 +96,20 @@ def start(self) -> int:
99
96
self .start_result_thread ()
100
97
self .start_spawn_children_thread ()
101
98
102
- atexit .register (self .shutdown )
99
+ # Convert signals into KeyboardInterrupt.
100
+ # Running shutdown() within the signal handler can lead to deadlocks
101
+ def signal_handler (* args : Any ) -> None :
102
+ raise KeyboardInterrupt ()
103
103
104
- while True :
105
- self .run_once ()
104
+ signal .signal (signal .SIGINT , signal_handler )
105
+ signal .signal (signal .SIGTERM , signal_handler )
106
+
107
+ try :
108
+ while True :
109
+ self .run_once ()
110
+ except KeyboardInterrupt :
111
+ self .shutdown ()
112
+ raise
106
113
107
114
def run_once (self ) -> None :
108
115
"""Access point for tests to run a single worker loop"""
@@ -113,30 +120,33 @@ def shutdown(self) -> None:
113
120
Shutdown cleanly
114
121
Activate the shutdown event and drain results before terminating children.
115
122
"""
116
- if self ._shutdown_event .is_set ():
117
- return
118
-
119
- logger .info ("taskworker.worker.shutdown" )
123
+ logger .info ("taskworker.worker.shutdown.start" )
120
124
self ._shutdown_event .set ()
121
125
126
+ logger .info ("taskworker.worker.shutdown.spawn_children" )
127
+ if self ._spawn_children_thread :
128
+ self ._spawn_children_thread .join ()
129
+
130
+ logger .info ("taskworker.worker.shutdown.children" )
122
131
for child in self ._children :
123
132
child .terminate ()
133
+ for child in self ._children :
124
134
child .join ()
125
135
136
+ logger .info ("taskworker.worker.shutdown.result" )
126
137
if self ._result_thread :
127
- self ._result_thread .join ()
138
+ # Use a timeout as sometimes this thread can deadlock on the Event.
139
+ self ._result_thread .join (timeout = 5 )
128
140
129
- # Drain remaining results synchronously, as the thread will have terminated
130
- # when shutdown_event was set.
141
+ # Drain any remaining results synchronously
131
142
while True :
132
143
try :
133
144
result = self ._processed_tasks .get_nowait ()
134
145
self ._send_result (result , fetch = False )
135
146
except queue .Empty :
136
147
break
137
148
138
- if self ._spawn_children_thread :
139
- self ._spawn_children_thread .join ()
149
+ logger .info ("taskworker.worker.shutdown.complete" )
140
150
141
151
def _add_task (self ) -> bool :
142
152
"""
@@ -179,7 +189,7 @@ def start_result_thread(self) -> None:
179
189
"""
180
190
181
191
def result_thread () -> None :
182
- logger .debug ("taskworker.worker.result_thread_started " )
192
+ logger .debug ("taskworker.worker.result_thread.started " )
183
193
iopool = ThreadPoolExecutor (max_workers = self ._concurrency )
184
194
with iopool as executor :
185
195
while not self ._shutdown_event .is_set ():
@@ -193,7 +203,9 @@ def result_thread() -> None:
193
203
)
194
204
continue
195
205
196
- self ._result_thread = threading .Thread (target = result_thread )
206
+ self ._result_thread = threading .Thread (
207
+ name = "send-result" , target = result_thread , daemon = True
208
+ )
197
209
self ._result_thread .start ()
198
210
199
211
def _send_result (self , result : ProcessingResult , fetch : bool = True ) -> bool :
@@ -253,6 +265,7 @@ def _send_update_task(
253
265
)
254
266
# Use the shutdown_event as a sleep mechanism
255
267
self ._shutdown_event .wait (self ._setstatus_backoff_seconds )
268
+
256
269
try :
257
270
next_task = self .client .update_task (result , fetch_next )
258
271
self ._setstatus_backoff_seconds = 0
@@ -276,14 +289,15 @@ def _send_update_task(
276
289
277
290
def start_spawn_children_thread (self ) -> None :
278
291
def spawn_children_thread () -> None :
279
- logger .debug ("taskworker.worker.spawn_children_thread_started " )
292
+ logger .debug ("taskworker.worker.spawn_children_thread.started " )
280
293
while not self ._shutdown_event .is_set ():
281
294
self ._children = [child for child in self ._children if child .is_alive ()]
282
295
if len (self ._children ) >= self ._concurrency :
283
296
time .sleep (0.1 )
284
297
continue
285
298
for i in range (self ._concurrency - len (self ._children )):
286
299
process = self .mp_context .Process (
300
+ name = f"taskworker-child-{ i } " ,
287
301
target = child_process ,
288
302
args = (
289
303
self ._child_tasks ,
@@ -301,7 +315,9 @@ def spawn_children_thread() -> None:
301
315
extra = {"pid" : process .pid , "processing_pool" : self ._processing_pool_name },
302
316
)
303
317
304
- self ._spawn_children_thread = threading .Thread (target = spawn_children_thread )
318
+ self ._spawn_children_thread = threading .Thread (
319
+ name = "spawn-children" , target = spawn_children_thread , daemon = True
320
+ )
305
321
self ._spawn_children_thread .start ()
306
322
307
323
def fetch_task (self ) -> InflightTaskActivation | None :
0 commit comments