Skip to content

Commit a4b4639

Browse files
committed
Fixed infinitely recursive health checks.
1 parent fa3a0ca commit a4b4639

File tree

2 files changed

+61
-15
lines changed

2 files changed

+61
-15
lines changed

Diff for: redis/asyncio/connection.py

+31-8
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
284284

285285
async def connect(self):
286286
"""Connects to the Redis server if not already connected"""
287+
await self.connect_check_health(check_health=True)
288+
289+
async def connect_check_health(self, check_health: bool = True):
287290
if self.is_connected:
288291
return
289292
try:
@@ -302,7 +305,7 @@ async def connect(self):
302305
try:
303306
if not self.redis_connect_func:
304307
# Use the default on_connect function
305-
await self.on_connect()
308+
await self.on_connect_check_health(check_health=check_health)
306309
else:
307310
# Use the passed function redis_connect_func
308311
(
@@ -341,6 +344,9 @@ def get_protocol(self):
341344

342345
async def on_connect(self) -> None:
343346
"""Initialize the connection, authenticate and select a database"""
347+
await self.on_connect_check_health(check_health=True)
348+
349+
async def on_connect_check_health(self, check_health: bool = True) -> None:
344350
self._parser.on_connect(self)
345351
parser = self._parser
346352

@@ -398,7 +404,7 @@ async def on_connect(self) -> None:
398404
# update cluster exception classes
399405
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
400406
self._parser.on_connect(self)
401-
await self.send_command("HELLO", self.protocol)
407+
await self.send_command("HELLO", self.protocol, check_health=check_health)
402408
response = await self.read_response()
403409
# if response.get(b"proto") != self.protocol and response.get(
404410
# "proto"
@@ -407,18 +413,35 @@ async def on_connect(self) -> None:
407413

408414
# if a client_name is given, set it
409415
if self.client_name:
410-
await self.send_command("CLIENT", "SETNAME", self.client_name)
416+
await self.send_command(
417+
"CLIENT",
418+
"SETNAME",
419+
self.client_name,
420+
check_health=check_health,
421+
)
411422
if str_if_bytes(await self.read_response()) != "OK":
412423
raise ConnectionError("Error setting client name")
413424

414425
# set the library name and version, pipeline for lower startup latency
415426
if self.lib_name:
416-
await self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
427+
await self.send_command(
428+
"CLIENT",
429+
"SETINFO",
430+
"LIB-NAME",
431+
self.lib_name,
432+
check_health=check_health,
433+
)
417434
if self.lib_version:
418-
await self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
435+
await self.send_command(
436+
"CLIENT",
437+
"SETINFO",
438+
"LIB-VER",
439+
self.lib_version,
440+
check_health=check_health,
441+
)
419442
# if a database is specified, switch to it. Also pipeline this
420443
if self.db:
421-
await self.send_command("SELECT", self.db)
444+
await self.send_command("SELECT", self.db, check_health=check_health)
422445

423446
# read responses from pipeline
424447
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -480,8 +503,8 @@ async def send_packed_command(
480503
self, command: Union[bytes, str, Iterable[bytes]], check_health: bool = True
481504
) -> None:
482505
if not self.is_connected:
483-
await self.connect()
484-
elif check_health:
506+
await self.connect_check_health(check_health=False)
507+
if check_health:
485508
await self.check_health()
486509

487510
try:

Diff for: redis/connection.py

+30-7
Original file line numberDiff line numberDiff line change
@@ -372,6 +372,9 @@ def set_parser(self, parser_class):
372372

373373
def connect(self):
374374
"Connects to the Redis server if not already connected"
375+
self.connect_check_health(check_health=True)
376+
377+
def connect_check_health(self, check_health: bool = True):
375378
if self._sock:
376379
return
377380
try:
@@ -387,7 +390,7 @@ def connect(self):
387390
try:
388391
if self.redis_connect_func is None:
389392
# Use the default on_connect function
390-
self.on_connect()
393+
self.on_connect_check_health(check_health=check_health)
391394
else:
392395
# Use the passed function redis_connect_func
393396
self.redis_connect_func(self)
@@ -417,6 +420,9 @@ def _error_message(self, exception):
417420
return format_error_message(self._host_error(), exception)
418421

419422
def on_connect(self):
423+
self.on_connect_check_health(check_health=True)
424+
425+
def on_connect_check_health(self, check_health: bool = True):
420426
"Initialize the connection, authenticate and select a database"
421427
self._parser.on_connect(self)
422428
parser = self._parser
@@ -475,7 +481,7 @@ def on_connect(self):
475481
# update cluster exception classes
476482
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
477483
self._parser.on_connect(self)
478-
self.send_command("HELLO", self.protocol)
484+
self.send_command("HELLO", self.protocol, check_health=check_health)
479485
self.handshake_metadata = self.read_response()
480486
if (
481487
self.handshake_metadata.get(b"proto") != self.protocol
@@ -485,28 +491,45 @@ def on_connect(self):
485491

486492
# if a client_name is given, set it
487493
if self.client_name:
488-
self.send_command("CLIENT", "SETNAME", self.client_name)
494+
self.send_command(
495+
"CLIENT",
496+
"SETNAME",
497+
self.client_name,
498+
check_health=check_health,
499+
)
489500
if str_if_bytes(self.read_response()) != "OK":
490501
raise ConnectionError("Error setting client name")
491502

492503
try:
493504
# set the library name and version
494505
if self.lib_name:
495-
self.send_command("CLIENT", "SETINFO", "LIB-NAME", self.lib_name)
506+
self.send_command(
507+
"CLIENT",
508+
"SETINFO",
509+
"LIB-NAME",
510+
self.lib_name,
511+
check_health=check_health,
512+
)
496513
self.read_response()
497514
except ResponseError:
498515
pass
499516

500517
try:
501518
if self.lib_version:
502-
self.send_command("CLIENT", "SETINFO", "LIB-VER", self.lib_version)
519+
self.send_command(
520+
"CLIENT",
521+
"SETINFO",
522+
"LIB-VER",
523+
self.lib_version,
524+
check_health=check_health,
525+
)
503526
self.read_response()
504527
except ResponseError:
505528
pass
506529

507530
# if a database is specified, switch to it
508531
if self.db:
509-
self.send_command("SELECT", self.db)
532+
self.send_command("SELECT", self.db, check_health=check_health)
510533
if str_if_bytes(self.read_response()) != "OK":
511534
raise ConnectionError("Invalid Database")
512535

@@ -548,7 +571,7 @@ def check_health(self):
548571
def send_packed_command(self, command, check_health=True):
549572
"""Send an already packed command to the Redis server"""
550573
if not self._sock:
551-
self.connect()
574+
self.connect_check_health(check_health=False)
552575
# guard against health check recursion
553576
if check_health:
554577
self.check_health()

0 commit comments

Comments
 (0)