@@ -293,6 +293,9 @@ def set_parser(self, parser_class: Type[BaseParser]) -> None:
293
293
294
294
async def connect (self ):
295
295
"""Connects to the Redis server if not already connected"""
296
+ await self .connect_check_health (check_health = True )
297
+
298
+ async def connect_check_health (self , check_health : bool = True ):
296
299
if self .is_connected :
297
300
return
298
301
try :
@@ -311,7 +314,7 @@ async def connect(self):
311
314
try :
312
315
if not self .redis_connect_func :
313
316
# Use the default on_connect function
314
- await self .on_connect ( )
317
+ await self .on_connect_check_health ( check_health = check_health )
315
318
else :
316
319
# Use the passed function redis_connect_func
317
320
(
@@ -350,6 +353,9 @@ def get_protocol(self):
350
353
351
354
async def on_connect (self ) -> None :
352
355
"""Initialize the connection, authenticate and select a database"""
356
+ await self .on_connect_check_health (check_health = True )
357
+
358
+ async def on_connect_check_health (self , check_health : bool = True ) -> None :
353
359
self ._parser .on_connect (self )
354
360
parser = self ._parser
355
361
@@ -407,7 +413,7 @@ async def on_connect(self) -> None:
407
413
# update cluster exception classes
408
414
self ._parser .EXCEPTION_CLASSES = parser .EXCEPTION_CLASSES
409
415
self ._parser .on_connect (self )
410
- await self .send_command ("HELLO" , self .protocol )
416
+ await self .send_command ("HELLO" , self .protocol , check_health = check_health )
411
417
response = await self .read_response ()
412
418
# if response.get(b"proto") != self.protocol and response.get(
413
419
# "proto"
@@ -416,18 +422,35 @@ async def on_connect(self) -> None:
416
422
417
423
# if a client_name is given, set it
418
424
if self .client_name :
419
- await self .send_command ("CLIENT" , "SETNAME" , self .client_name )
425
+ await self .send_command (
426
+ "CLIENT" ,
427
+ "SETNAME" ,
428
+ self .client_name ,
429
+ check_health = check_health ,
430
+ )
420
431
if str_if_bytes (await self .read_response ()) != "OK" :
421
432
raise ConnectionError ("Error setting client name" )
422
433
423
434
# set the library name and version, pipeline for lower startup latency
424
435
if self .lib_name :
425
- await self .send_command ("CLIENT" , "SETINFO" , "LIB-NAME" , self .lib_name )
436
+ await self .send_command (
437
+ "CLIENT" ,
438
+ "SETINFO" ,
439
+ "LIB-NAME" ,
440
+ self .lib_name ,
441
+ check_health = check_health ,
442
+ )
426
443
if self .lib_version :
427
- await self .send_command ("CLIENT" , "SETINFO" , "LIB-VER" , self .lib_version )
444
+ await self .send_command (
445
+ "CLIENT" ,
446
+ "SETINFO" ,
447
+ "LIB-VER" ,
448
+ self .lib_version ,
449
+ check_health = check_health ,
450
+ )
428
451
# if a database is specified, switch to it. Also pipeline this
429
452
if self .db :
430
- await self .send_command ("SELECT" , self .db )
453
+ await self .send_command ("SELECT" , self .db , check_health = check_health )
431
454
432
455
# read responses from pipeline
433
456
for _ in (sent for sent in (self .lib_name , self .lib_version ) if sent ):
@@ -489,8 +512,8 @@ async def send_packed_command(
489
512
self , command : Union [bytes , str , Iterable [bytes ]], check_health : bool = True
490
513
) -> None :
491
514
if not self .is_connected :
492
- await self .connect ( )
493
- elif check_health :
515
+ await self .connect_check_health ( check_health = False )
516
+ if check_health :
494
517
await self .check_health ()
495
518
496
519
try :
0 commit comments