Skip to content

Commit 9fbb627

Browse files
committed
pr comment updates, add tests, move sempahore
1 parent 915de1a commit 9fbb627

File tree

1 file changed

+113
-9
lines changed

1 file changed

+113
-9
lines changed

src/rate_limit.rs

+113-9
Original file line numberDiff line numberDiff line change
@@ -334,6 +334,15 @@ impl RateLimit for RedisRateLimit {
334334
fn try_acquire(self: Arc<Self>, addr: IpAddr) -> Result<Ticket, RateLimitError> {
335335
self.clone().start_background_tasks();
336336

337+
let permit = match self.semaphore.clone().try_acquire_owned() {
338+
Ok(permit) => permit,
339+
Err(_) => {
340+
return Err(RateLimitError::Limit {
341+
reason: "Maximum connection limit reached for this server instance".to_string(),
342+
});
343+
}
344+
};
345+
337346
let mut conn = match self.redis_client.get_connection() {
338347
Ok(conn) => conn,
339348
Err(e) => {
@@ -404,15 +413,6 @@ impl RateLimit for RedisRateLimit {
404413
});
405414
}
406415

407-
let permit = match self.semaphore.clone().try_acquire_owned() {
408-
Ok(permit) => permit,
409-
Err(_) => {
410-
return Err(RateLimitError::Limit {
411-
reason: "Maximum connection limit reached for this server instance".to_string(),
412-
});
413-
}
414-
};
415-
416416
let ip_instance_connections: usize = match conn.incr(self.ip_instance_key(&addr), 1) {
417417
Ok(count) => count,
418418
Err(e) => {
@@ -599,6 +599,110 @@ mod tests {
599599
assert!(c4.is_ok());
600600
}
601601

602+
#[tokio::test]
603+
async fn test_global_limits_with_multiple_ips() {
604+
let user_1 = IpAddr::from_str("127.0.0.1").unwrap();
605+
let user_2 = IpAddr::from_str("127.0.0.2").unwrap();
606+
let user_3 = IpAddr::from_str("127.0.0.3").unwrap();
607+
608+
let rate_limiter = Arc::new(InMemoryRateLimit::new(4, 3));
609+
610+
let ticket_1_1 = rate_limiter.clone().try_acquire(user_1).unwrap();
611+
let ticket_1_2 = rate_limiter.clone().try_acquire(user_1).unwrap();
612+
613+
let ticket_2_1 = rate_limiter.clone().try_acquire(user_2).unwrap();
614+
let ticket_2_2 = rate_limiter.clone().try_acquire(user_2).unwrap();
615+
616+
assert_eq!(
617+
rate_limiter
618+
.inner
619+
.lock()
620+
.unwrap()
621+
.semaphore
622+
.available_permits(),
623+
0
624+
);
625+
626+
// Try user_3 - should fail due to global limit
627+
let result = rate_limiter.clone().try_acquire(user_3);
628+
assert!(result.is_err());
629+
assert_eq!(
630+
result.err().unwrap().to_string(),
631+
"Rate Limit Reached: Global limit"
632+
);
633+
634+
drop(ticket_1_1);
635+
636+
let ticket_3_1 = rate_limiter.clone().try_acquire(user_3).unwrap();
637+
638+
drop(ticket_1_2);
639+
drop(ticket_2_1);
640+
drop(ticket_2_2);
641+
drop(ticket_3_1);
642+
643+
assert_eq!(
644+
rate_limiter
645+
.inner
646+
.lock()
647+
.unwrap()
648+
.semaphore
649+
.available_permits(),
650+
4
651+
);
652+
assert_eq!(
653+
rate_limiter.inner.lock().unwrap().active_connections.len(),
654+
0
655+
);
656+
}
657+
658+
#[tokio::test]
659+
async fn test_per_ip_limits_remain_enforced() {
660+
let user_1 = IpAddr::from_str("127.0.0.1").unwrap();
661+
let user_2 = IpAddr::from_str("127.0.0.2").unwrap();
662+
663+
let rate_limiter = Arc::new(InMemoryRateLimit::new(5, 2));
664+
665+
let ticket_1_1 = rate_limiter.clone().try_acquire(user_1).unwrap();
666+
let ticket_1_2 = rate_limiter.clone().try_acquire(user_1).unwrap();
667+
668+
let result = rate_limiter.clone().try_acquire(user_1);
669+
assert!(result.is_err());
670+
assert_eq!(
671+
result.err().unwrap().to_string(),
672+
"Rate Limit Reached: IP limit exceeded"
673+
);
674+
675+
let ticket_2_1 = rate_limiter.clone().try_acquire(user_2).unwrap();
676+
drop(ticket_1_1);
677+
678+
let ticket_1_3 = rate_limiter.clone().try_acquire(user_1).unwrap();
679+
680+
let result = rate_limiter.clone().try_acquire(user_1);
681+
assert!(result.is_err());
682+
assert_eq!(
683+
result.err().unwrap().to_string(),
684+
"Rate Limit Reached: IP limit exceeded"
685+
);
686+
687+
drop(ticket_1_2);
688+
drop(ticket_1_3);
689+
drop(ticket_2_1);
690+
691+
assert_eq!(
692+
rate_limiter
693+
.inner
694+
.lock()
695+
.unwrap()
696+
.semaphore
697+
.available_permits(),
698+
5
699+
);
700+
assert_eq!(
701+
rate_limiter.inner.lock().unwrap().active_connections.len(),
702+
0
703+
);
704+
}
705+
602706
#[tokio::test]
603707
#[cfg(all(feature = "integration", test))]
604708
async fn test_instance_tracking_and_cleanup() {

0 commit comments

Comments
 (0)