@@ -334,6 +334,15 @@ impl RateLimit for RedisRateLimit {
334
334
fn try_acquire ( self : Arc < Self > , addr : IpAddr ) -> Result < Ticket , RateLimitError > {
335
335
self . clone ( ) . start_background_tasks ( ) ;
336
336
337
+ let permit = match self . semaphore . clone ( ) . try_acquire_owned ( ) {
338
+ Ok ( permit) => permit,
339
+ Err ( _) => {
340
+ return Err ( RateLimitError :: Limit {
341
+ reason : "Maximum connection limit reached for this server instance" . to_string ( ) ,
342
+ } ) ;
343
+ }
344
+ } ;
345
+
337
346
let mut conn = match self . redis_client . get_connection ( ) {
338
347
Ok ( conn) => conn,
339
348
Err ( e) => {
@@ -404,15 +413,6 @@ impl RateLimit for RedisRateLimit {
404
413
} ) ;
405
414
}
406
415
407
- let permit = match self . semaphore . clone ( ) . try_acquire_owned ( ) {
408
- Ok ( permit) => permit,
409
- Err ( _) => {
410
- return Err ( RateLimitError :: Limit {
411
- reason : "Maximum connection limit reached for this server instance" . to_string ( ) ,
412
- } ) ;
413
- }
414
- } ;
415
-
416
416
let ip_instance_connections: usize = match conn. incr ( self . ip_instance_key ( & addr) , 1 ) {
417
417
Ok ( count) => count,
418
418
Err ( e) => {
@@ -599,6 +599,110 @@ mod tests {
599
599
assert ! ( c4. is_ok( ) ) ;
600
600
}
601
601
602
+ #[ tokio:: test]
603
+ async fn test_global_limits_with_multiple_ips ( ) {
604
+ let user_1 = IpAddr :: from_str ( "127.0.0.1" ) . unwrap ( ) ;
605
+ let user_2 = IpAddr :: from_str ( "127.0.0.2" ) . unwrap ( ) ;
606
+ let user_3 = IpAddr :: from_str ( "127.0.0.3" ) . unwrap ( ) ;
607
+
608
+ let rate_limiter = Arc :: new ( InMemoryRateLimit :: new ( 4 , 3 ) ) ;
609
+
610
+ let ticket_1_1 = rate_limiter. clone ( ) . try_acquire ( user_1) . unwrap ( ) ;
611
+ let ticket_1_2 = rate_limiter. clone ( ) . try_acquire ( user_1) . unwrap ( ) ;
612
+
613
+ let ticket_2_1 = rate_limiter. clone ( ) . try_acquire ( user_2) . unwrap ( ) ;
614
+ let ticket_2_2 = rate_limiter. clone ( ) . try_acquire ( user_2) . unwrap ( ) ;
615
+
616
+ assert_eq ! (
617
+ rate_limiter
618
+ . inner
619
+ . lock( )
620
+ . unwrap( )
621
+ . semaphore
622
+ . available_permits( ) ,
623
+ 0
624
+ ) ;
625
+
626
+ // Try user_3 - should fail due to global limit
627
+ let result = rate_limiter. clone ( ) . try_acquire ( user_3) ;
628
+ assert ! ( result. is_err( ) ) ;
629
+ assert_eq ! (
630
+ result. err( ) . unwrap( ) . to_string( ) ,
631
+ "Rate Limit Reached: Global limit"
632
+ ) ;
633
+
634
+ drop ( ticket_1_1) ;
635
+
636
+ let ticket_3_1 = rate_limiter. clone ( ) . try_acquire ( user_3) . unwrap ( ) ;
637
+
638
+ drop ( ticket_1_2) ;
639
+ drop ( ticket_2_1) ;
640
+ drop ( ticket_2_2) ;
641
+ drop ( ticket_3_1) ;
642
+
643
+ assert_eq ! (
644
+ rate_limiter
645
+ . inner
646
+ . lock( )
647
+ . unwrap( )
648
+ . semaphore
649
+ . available_permits( ) ,
650
+ 4
651
+ ) ;
652
+ assert_eq ! (
653
+ rate_limiter. inner. lock( ) . unwrap( ) . active_connections. len( ) ,
654
+ 0
655
+ ) ;
656
+ }
657
+
658
+ #[ tokio:: test]
659
+ async fn test_per_ip_limits_remain_enforced ( ) {
660
+ let user_1 = IpAddr :: from_str ( "127.0.0.1" ) . unwrap ( ) ;
661
+ let user_2 = IpAddr :: from_str ( "127.0.0.2" ) . unwrap ( ) ;
662
+
663
+ let rate_limiter = Arc :: new ( InMemoryRateLimit :: new ( 5 , 2 ) ) ;
664
+
665
+ let ticket_1_1 = rate_limiter. clone ( ) . try_acquire ( user_1) . unwrap ( ) ;
666
+ let ticket_1_2 = rate_limiter. clone ( ) . try_acquire ( user_1) . unwrap ( ) ;
667
+
668
+ let result = rate_limiter. clone ( ) . try_acquire ( user_1) ;
669
+ assert ! ( result. is_err( ) ) ;
670
+ assert_eq ! (
671
+ result. err( ) . unwrap( ) . to_string( ) ,
672
+ "Rate Limit Reached: IP limit exceeded"
673
+ ) ;
674
+
675
+ let ticket_2_1 = rate_limiter. clone ( ) . try_acquire ( user_2) . unwrap ( ) ;
676
+ drop ( ticket_1_1) ;
677
+
678
+ let ticket_1_3 = rate_limiter. clone ( ) . try_acquire ( user_1) . unwrap ( ) ;
679
+
680
+ let result = rate_limiter. clone ( ) . try_acquire ( user_1) ;
681
+ assert ! ( result. is_err( ) ) ;
682
+ assert_eq ! (
683
+ result. err( ) . unwrap( ) . to_string( ) ,
684
+ "Rate Limit Reached: IP limit exceeded"
685
+ ) ;
686
+
687
+ drop ( ticket_1_2) ;
688
+ drop ( ticket_1_3) ;
689
+ drop ( ticket_2_1) ;
690
+
691
+ assert_eq ! (
692
+ rate_limiter
693
+ . inner
694
+ . lock( )
695
+ . unwrap( )
696
+ . semaphore
697
+ . available_permits( ) ,
698
+ 5
699
+ ) ;
700
+ assert_eq ! (
701
+ rate_limiter. inner. lock( ) . unwrap( ) . active_connections. len( ) ,
702
+ 0
703
+ ) ;
704
+ }
705
+
602
706
#[ tokio:: test]
603
707
#[ cfg( all( feature = "integration" , test) ) ]
604
708
async fn test_instance_tracking_and_cleanup ( ) {
0 commit comments