@@ -284,6 +284,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
284
284
285
285
async def connect (self ):
286
286
"""Connects to the Redis server if not already connected"""
287
+ await self .connect_check_health (check_health = True )
288
+
289
+ async def connect_check_health (self , check_health : bool = True ):
287
290
if self .is_connected :
288
291
return
289
292
try :
@@ -302,7 +305,7 @@ async def connect(self):
302
305
try :
303
306
if not self .redis_connect_func :
304
307
# Use the default on_connect function
305
- await self .on_connect ( )
308
+ await self .on_connect_check_health ( check_health = check_health )
306
309
else :
307
310
# Use the passed function redis_connect_func
308
311
(
@@ -341,6 +344,9 @@ def get_protocol(self):
341
344
342
345
async def on_connect (self ) -> None :
343
346
"""Initialize the connection, authenticate and select a database"""
347
+ await self .on_connect_check_health (check_health = True )
348
+
349
+ async def on_connect_check_health (self , check_health : bool = True ) -> None :
344
350
self ._parser .on_connect (self )
345
351
parser = self ._parser
346
352
@@ -398,7 +404,7 @@ async def on_connect(self) -> None:
398
404
# update cluster exception classes
399
405
self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
400
406
self ._parser .on_connect (self )
401
- await self .send_command ("HELLO" , self .protocol )
407
+ await self .send_command ("HELLO" , self .protocol , check_health = check_health )
402
408
response = await self .read_response ()
403
409
# if response.get(b"proto") != self.protocol and response.get(
404
410
# "proto"
@@ -407,18 +413,35 @@ async def on_connect(self) -> None:
407
413
408
414
# if a client_name is given, set it
409
415
if self .client_name :
410
- await self .send_command ("CLIENT" , "SETNAME" , self .client_name )
416
+ await self .send_command (
417
+ "CLIENT" ,
418
+ "SETNAME" ,
419
+ self .client_name ,
420
+ check_health = check_health ,
421
+ )
411
422
if str_if_bytes (await self .read_response ()) != "OK" :
412
423
raise ConnectionError ("Error setting client name" )
413
424
414
425
# set the library name and version, pipeline for lower startup latency
415
426
if self .lib_name :
416
- await self .send_command ("CLIENT" , "SETINFO" , "LIB-NAME" , self .lib_name )
427
+ await self .send_command (
428
+ "CLIENT" ,
429
+ "SETINFO" ,
430
+ "LIB-NAME" ,
431
+ self .lib_name ,
432
+ check_health = check_health ,
433
+ )
417
434
if self .lib_version :
418
- await self .send_command ("CLIENT" , "SETINFO" , "LIB-VER" , self .lib_version )
435
+ await self .send_command (
436
+ "CLIENT" ,
437
+ "SETINFO" ,
438
+ "LIB-VER" ,
439
+ self .lib_version ,
440
+ check_health = check_health ,
441
+ )
419
442
# if a database is specified, switch to it. Also pipeline this
420
443
if self .db :
421
- await self .send_command ("SELECT" , self .db )
444
+ await self .send_command ("SELECT" , self .db , check_health = check_health )
422
445
423
446
# read responses from pipeline
424
447
for _ in (sent for sent in (self .lib_name , self .lib_version ) if sent ):
@@ -480,8 +503,8 @@ async def send_packed_command(
480
503
self , command : Union [bytes , str , Iterable [bytes ]], check_health : bool = True
481
504
) -> None :
482
505
if not self .is_connected :
483
- await self .connect ( )
484
- elif check_health :
506
+ await self .connect_check_health ( check_health = False )
507
+ if check_health :
485
508
await self .check_health ()
486
509
487
510
try :
0 commit comments