@@ -47,6 +47,7 @@ def __init__(self, *,
47
47
self .redis_pool = redis_pool
48
48
self .loop = redis_pool ._loop
49
49
self .max_concurrent_tasks = max_concurrent_tasks
50
+ self .task_semaphore = asyncio .Semaphore (value = max_concurrent_tasks , loop = self .loop )
50
51
self .shutdown_delay = max (shutdown_delay , 0.1 )
51
52
self .timeout_seconds = timeout_seconds
52
53
self .burst_mode = burst_mode
@@ -71,6 +72,10 @@ async def __aexit__(self, exc_type, exc_val, exc_tb):
71
72
e = self .task_exception
72
73
raise TaskError (f'A processed task failed: { e .__class__ .__name__ } , { e } ' ) from e
73
74
75
+ @property
76
+ def jobs_in_progress (self ):
77
+ return self .max_concurrent_tasks - self .task_semaphore ._value
78
+
74
79
async def iter (self , * raw_queues : bytes , pop_timeout = 1 ):
75
80
"""
76
81
blpop jobs from redis queues and yield them. Waits for the number of tasks to drop below max_concurrent_tasks
@@ -88,17 +93,21 @@ async def iter(self, *raw_queues: bytes, pop_timeout=1):
88
93
work_logger .debug ('populating quit queue to prompt exit: %s' , quit_queue .decode ())
89
94
await self .redis .rpush (quit_queue , b'1' )
90
95
raw_queues = tuple (raw_queues ) + (quit_queue ,)
91
- while self .running :
96
+ while True :
97
+ await self .task_semaphore .acquire ()
98
+ if not self .running :
99
+ break
92
100
msg = await self .redis .blpop (* raw_queues , timeout = pop_timeout )
93
101
if msg is None :
94
102
yield None , None
103
+ self .task_semaphore .release ()
95
104
continue
96
105
raw_queue , raw_data = msg
97
106
if self .burst_mode and raw_queue == quit_queue :
98
107
work_logger .debug ('got job from the quit queue, stopping' )
99
108
break
109
+ work_logger .debug ('jobs in progress %d' , self .jobs_in_progress )
100
110
yield raw_queue , raw_data
101
- await self .wait ()
102
111
103
112
def add (self , coro , job , re_enqueue = False ):
104
113
"""
@@ -115,18 +124,6 @@ def add(self, coro, job, re_enqueue=False):
115
124
self .loop .call_later (self .timeout_seconds , self ._cancel_job , task , job )
116
125
self .pending_tasks .add (task )
117
126
118
- async def wait (self ):
119
- """
120
- Wait for a the number of pending tasks to drop bellow ``max_concurrent_tasks``
121
- """
122
- while True :
123
- pt_cnt = len (self .pending_tasks )
124
- if pt_cnt < self .max_concurrent_tasks :
125
- return
126
- work_logger .info ('%d pending tasks, waiting for one to finish' , pt_cnt )
127
- _ , self .pending_tasks = await asyncio .wait (self .pending_tasks , loop = self .loop ,
128
- return_when = asyncio .FIRST_COMPLETED )
129
-
130
127
async def finish (self , timeout = None ):
131
128
"""
132
129
Cancel all pending tasks and optionally re-enqueue jobs which haven't finished after the timeout.
@@ -150,6 +147,7 @@ async def finish(self, timeout=None):
150
147
self .pending_tasks = set ()
151
148
152
149
def _job_callback (self , task ):
150
+ self .task_semaphore .release ()
153
151
self .jobs_complete += 1
154
152
task_exception = task .exception ()
155
153
if task_exception :
0 commit comments