diff --git a/pom.xml b/pom.xml index 9e597c0178..ffcaa969f4 100644 --- a/pom.xml +++ b/pom.xml @@ -550,6 +550,8 @@ **/DriverInfo.java **/ClientSetInfo*.java **/ClientCommandsTest*.java + **/Delay*.java + **/SentineledConnectionProviderReconnectionTest.java diff --git a/src/main/java/redis/clients/jedis/builders/SentinelClientBuilder.java b/src/main/java/redis/clients/jedis/builders/SentinelClientBuilder.java index 24cae18eb3..214836f06f 100644 --- a/src/main/java/redis/clients/jedis/builders/SentinelClientBuilder.java +++ b/src/main/java/redis/clients/jedis/builders/SentinelClientBuilder.java @@ -1,9 +1,12 @@ package redis.clients.jedis.builders; +import java.time.Duration; import java.util.Set; import redis.clients.jedis.*; import redis.clients.jedis.providers.ConnectionProvider; import redis.clients.jedis.providers.SentineledConnectionProvider; +import redis.clients.jedis.util.Delay; +import redis.clients.jedis.util.JedisAsserts; /** * Builder for creating JedisSentineled instances (Redis Sentinel connections). @@ -16,11 +19,16 @@ public abstract class SentinelClientBuilder extends AbstractClientBuilder, C> { + private static final Delay DEFAULT_RESUBSCRIBE_DELAY = Delay.constant(Duration.ofMillis(5000)); + // Sentinel-specific configuration fields private String masterName = null; private Set sentinels = null; private JedisClientConfig sentinelClientConfig = null; + // delay between re-subscribing to sentinel nodes after a disconnection + private Delay sentinellReconnectDelay = DEFAULT_RESUBSCRIBE_DELAY; + /** * Sets the master name for the Redis Sentinel configuration. *

@@ -60,6 +68,21 @@ public SentinelClientBuilder sentinelClientConfig(JedisClientConfig sentinelC return this; } + /** + * Sets the delay between re-subscribing to sentinel node after a disconnection.* + *

+ * In case connection to sentinel nodes is lost, the client will try to reconnect to them. This + * method sets the delay between re-subscribing to sentinel nodes after a disconnection. + *

+ * @param reconnectDelay + * @return + */ + public SentinelClientBuilder sentinelReconnectDelay(Delay reconnectDelay) { + JedisAsserts.notNull(reconnectDelay, "reconnectDelay must not be null"); + this.sentinellReconnectDelay = reconnectDelay; + return this; + } + @Override protected SentinelClientBuilder self() { return this; @@ -68,7 +91,7 @@ protected SentinelClientBuilder self() { @Override protected ConnectionProvider createDefaultConnectionProvider() { return new SentineledConnectionProvider(this.masterName, this.clientConfig, this.cache, - this.poolConfig, this.sentinels, this.sentinelClientConfig); + this.poolConfig, this.sentinels, this.sentinelClientConfig, sentinellReconnectDelay); } @Override diff --git a/src/main/java/redis/clients/jedis/providers/SentineledConnectionProvider.java b/src/main/java/redis/clients/jedis/providers/SentineledConnectionProvider.java index c3b13c6016..df74b67003 100644 --- a/src/main/java/redis/clients/jedis/providers/SentineledConnectionProvider.java +++ b/src/main/java/redis/clients/jedis/providers/SentineledConnectionProvider.java @@ -1,5 +1,6 @@ package redis.clients.jedis.providers; +import java.time.Duration; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -25,6 +26,7 @@ import redis.clients.jedis.csc.Cache; import redis.clients.jedis.exceptions.JedisConnectionException; import redis.clients.jedis.exceptions.JedisException; +import redis.clients.jedis.util.Delay; import redis.clients.jedis.util.IOUtils; import redis.clients.jedis.util.Pool; @@ -34,6 +36,9 @@ public class SentineledConnectionProvider implements ConnectionProvider { protected static final long DEFAULT_SUBSCRIBE_RETRY_WAIT_TIME_MILLIS = 5000; + static final Delay DEFAULT_RESUBSCRIBE_DELAY = Delay + .constant(Duration.ofMillis(DEFAULT_SUBSCRIBE_RETRY_WAIT_TIME_MILLIS)); + private volatile HostAndPort currentMaster; private volatile ConnectionPool pool; @@ -50,10 +55,14 @@ public class SentineledConnectionProvider implements ConnectionProvider { private final JedisClientConfig sentinelClientConfig; - private final long subscribeRetryWaitTimeMillis; + private final Delay resubscribeDelay; private final Lock initPoolLock = new ReentrantLock(true); + private final SentinelConnectionFactory sentinelConnectionFactory; + + private final Sleeper sleeper; + public SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, Set sentinels, final JedisClientConfig sentinelClientConfig) { this(masterName, masterClientConfig, null, null, sentinels, sentinelClientConfig); @@ -73,25 +82,49 @@ public SentineledConnectionProvider(String masterName, final JedisClientConfig m } @Experimental - public SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, - Cache clientSideCache, final GenericObjectPoolConfig poolConfig, - Set sentinels, final JedisClientConfig sentinelClientConfig) { + public SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, Cache clientSideCache, + final GenericObjectPoolConfig poolConfig, Set sentinels, + final JedisClientConfig sentinelClientConfig) { this(masterName, masterClientConfig, clientSideCache, poolConfig, sentinels, sentinelClientConfig, - DEFAULT_SUBSCRIBE_RETRY_WAIT_TIME_MILLIS); + DEFAULT_RESUBSCRIBE_DELAY); } public SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, - final GenericObjectPoolConfig poolConfig, + final GenericObjectPoolConfig poolConfig, Set sentinels, + final JedisClientConfig sentinelClientConfig, final long subscribeRetryWaitTimeMillis) { + this(masterName, masterClientConfig, null, poolConfig, sentinels, sentinelClientConfig, + Delay.constant(Duration.ofMillis(subscribeRetryWaitTimeMillis))); + } + + /** + * @deprecated use + * {@link #SentineledConnectionProvider(String, JedisClientConfig, Cache, GenericObjectPoolConfig, Set, JedisClientConfig, Delay)} + */ + @Experimental + @Deprecated + public SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, + Cache clientSideCache, final GenericObjectPoolConfig poolConfig, Set sentinels, final JedisClientConfig sentinelClientConfig, final long subscribeRetryWaitTimeMillis) { - this(masterName, masterClientConfig, null, poolConfig, sentinels, sentinelClientConfig, subscribeRetryWaitTimeMillis); + + this(masterName, masterClientConfig, clientSideCache, poolConfig, sentinels, + sentinelClientConfig, Delay.constant(Duration.ofMillis(subscribeRetryWaitTimeMillis))); } @Experimental public SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, Cache clientSideCache, final GenericObjectPoolConfig poolConfig, Set sentinels, final JedisClientConfig sentinelClientConfig, - final long subscribeRetryWaitTimeMillis) { + final Delay resubscribeDelay) { + this(masterName, masterClientConfig, clientSideCache, poolConfig, sentinels, + sentinelClientConfig, resubscribeDelay, null, null); + } + + SentineledConnectionProvider(String masterName, final JedisClientConfig masterClientConfig, + Cache clientSideCache, final GenericObjectPoolConfig poolConfig, + Set sentinels, final JedisClientConfig sentinelClientConfig, + final Delay resubscribeDelay, SentinelConnectionFactory sentinelConnectionFactory, + Sleeper sleeper) { this.masterName = masterName; this.masterClientConfig = masterClientConfig; @@ -99,7 +132,12 @@ public SentineledConnectionProvider(String masterName, final JedisClientConfig m this.masterPoolConfig = poolConfig; this.sentinelClientConfig = sentinelClientConfig; - this.subscribeRetryWaitTimeMillis = subscribeRetryWaitTimeMillis; + this.resubscribeDelay = resubscribeDelay; + + this.sentinelConnectionFactory = sentinelConnectionFactory != null ? sentinelConnectionFactory + : defaultSentinelConnectionFactory(); + + this.sleeper = sleeper != null ? sleeper : Thread::sleep; HostAndPort master = initSentinels(sentinels); initMaster(master); @@ -191,7 +229,8 @@ private HostAndPort initSentinels(Set sentinels) { LOG.debug("Connecting to Sentinel {}...", sentinel); - try (Jedis jedis = new Jedis(sentinel, sentinelClientConfig)) { + try (Jedis jedis = sentinelConnectionFactory.createConnection(sentinel, + sentinelClientConfig)) { List masterAddr = jedis.sentinelGetMasterAddrByName(masterName); @@ -254,6 +293,7 @@ protected class SentinelListener extends Thread { protected final HostAndPort node; protected volatile Jedis sentinelJedis; protected AtomicBoolean running = new AtomicBoolean(false); + protected long subscribeAttempt = 0; public SentinelListener(HostAndPort node) { super(String.format("%s-SentinelListener-[%s]", masterName, node.toString())); @@ -266,14 +306,13 @@ public void run() { running.set(true); while (running.get()) { - try { // double check that it is not being shutdown if (!running.get()) { break; } - sentinelJedis = new Jedis(node, sentinelClientConfig); + sentinelJedis = sentinelConnectionFactory.createConnection(node, sentinelClientConfig); // code for active refresh List masterAddr = sentinelJedis.sentinelGetMasterAddrByName(masterName); @@ -284,6 +323,14 @@ public void run() { } sentinelJedis.subscribe(new JedisPubSub() { + @Override + public void onSubscribe(String channel, int subscribedChannels) { + // Successfully subscribed - reset attempt counter + subscribeAttempt = 0; + LOG.debug("Successfully subscribed to {} on Sentinel {}. Reset attempt counter.", + channel, node); + } + @Override public void onMessage(String channel, String message) { LOG.debug("Sentinel {} published: {}.", node, message); @@ -295,14 +342,13 @@ public void onMessage(String channel, String message) { if (masterName.equals(switchMasterMsg[0])) { initMaster(toHostAndPort(switchMasterMsg[3], switchMasterMsg[4])); } else { - LOG.debug( - "Ignoring message on +switch-master for master {}. Our master is {}.", - switchMasterMsg[0], masterName); + LOG.debug("Ignoring message on +switch-master for master {}. Our master is {}.", + switchMasterMsg[0], masterName); } } else { LOG.error("Invalid message received on sentinel {} on channel +switch-master: {}.", - node, message); + node, message); } } }, "+switch-master"); @@ -310,10 +356,12 @@ public void onMessage(String channel, String message) { } catch (JedisException e) { if (running.get()) { - LOG.error("Lost connection to Sentinel {}. Sleeping {}ms and retrying.", node, - subscribeRetryWaitTimeMillis, e); + long subscribeRetryWaitTimeMillis = resubscribeDelay.delay(subscribeAttempt).toMillis(); + subscribeAttempt++; + LOG.warn("Lost connection to Sentinel {}. Sleeping {}ms and retrying.", node, + subscribeRetryWaitTimeMillis, e); try { - Thread.sleep(subscribeRetryWaitTimeMillis); + sleeper.sleep(subscribeRetryWaitTimeMillis); } catch (InterruptedException se) { LOG.error("Sleep interrupted.", se); } @@ -340,4 +388,22 @@ public void shutdown() { } } } + + protected SentinelConnectionFactory defaultSentinelConnectionFactory() { + return (node, config) -> new Jedis(node, config); + } + + @FunctionalInterface + protected interface Sleeper { + + void sleep(long millis) throws InterruptedException; + + } + + @FunctionalInterface + protected interface SentinelConnectionFactory { + + Jedis createConnection(HostAndPort node, JedisClientConfig config); + + } } diff --git a/src/main/java/redis/clients/jedis/util/Delay.java b/src/main/java/redis/clients/jedis/util/Delay.java new file mode 100644 index 0000000000..55b5d56d49 --- /dev/null +++ b/src/main/java/redis/clients/jedis/util/Delay.java @@ -0,0 +1,84 @@ +package redis.clients.jedis.util; + +import java.time.Duration; +import java.util.concurrent.ThreadLocalRandom; + +public abstract class Delay { + + protected Delay() { + } + + /** + * Calculate a specific delay based on the attempt. + * @param attempt the attempt to calculate the delay from. + * @return the calculated delay. + */ + public abstract Duration delay(long attempt); + + /** + * Creates a constant delay. + * @param delay the constant delay duration + * @return a Delay that always returns the same duration + */ + public static Delay constant(Duration delay) { + return new ConstantDelay(delay); + } + + /** + * Creates an exponential delay with equal jitter. Based on AWS exponential backoff strategy: + * https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/ Formula: temp = + * min(upper, base * 2^attempt) sleep = temp/2 + random_between(0, temp/2) result = max(lower, + * sleep) + * @param lower the minimum delay duration (lower bound) + * @param upper the maximum delay duration (upper bound) + * @param base the base delay duration + * @return a Delay with exponential backoff and equal jitter + */ + public static Delay exponentialWithJitter(Duration lower, Duration upper, Duration base) { + return new EqualJitterDelay(lower, upper, base); + } + + static class ConstantDelay extends Delay { + + private final Duration delay; + + ConstantDelay(Duration delay) { + this.delay = delay; + } + + @Override + public Duration delay(long attempt) { + return delay; + } + } + + static class EqualJitterDelay extends Delay { + + private final long lowerMillis; + private final long upperMillis; + private final long baseMillis; + + EqualJitterDelay(Duration lower, Duration upper, Duration base) { + this.lowerMillis = lower.toMillis(); + this.upperMillis = upper.toMillis(); + this.baseMillis = base.toMillis(); + } + + @Override + public Duration delay(long attempt) { + // temp = min(upper, base * 2^attempt) + long exponential = baseMillis * (1L << Math.min(attempt, 62)); + long temp = Math.min(upperMillis, exponential); + + // sleep = temp/2 + random_between(0, temp/2) + long half = temp / 2; + long jitter = ThreadLocalRandom.current().nextLong(half + 1); + long delayMillis = half + jitter; + + // Apply lower bound + delayMillis = Math.max(lowerMillis, delayMillis); + + return Duration.ofMillis(delayMillis); + } + } +} diff --git a/src/test/java/redis/clients/jedis/providers/SentineledConnectionProviderReconnectionTest.java b/src/test/java/redis/clients/jedis/providers/SentineledConnectionProviderReconnectionTest.java new file mode 100644 index 0000000000..9dcfae976a --- /dev/null +++ b/src/test/java/redis/clients/jedis/providers/SentineledConnectionProviderReconnectionTest.java @@ -0,0 +1,122 @@ +package redis.clients.jedis.providers; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.time.Duration; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Tag; +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import redis.clients.jedis.HostAndPort; +import redis.clients.jedis.Jedis; +import redis.clients.jedis.JedisClientConfig; +import redis.clients.jedis.JedisPubSub; +import redis.clients.jedis.exceptions.JedisConnectionException; +import redis.clients.jedis.providers.SentineledConnectionProvider.SentinelConnectionFactory; +import redis.clients.jedis.util.Delay; + +/** + * Unit tests for SentineledConnectionProvider reconnection logic. Tests connection provider's + * ability to reconnect to sentinel nodes with configured delay. + */ +@ExtendWith(MockitoExtension.class) +@Tag("unit") +public class SentineledConnectionProviderReconnectionTest { + + private static final String MASTER_NAME = "mymaster"; + + private static final HostAndPort SENTINEL_1 = new HostAndPort("localhost", 26379); + + private static final HostAndPort SENTINEL_2 = new HostAndPort("localhost", 26380); + + private static final HostAndPort MASTER = new HostAndPort("localhost", 6379); + + private Set sentinels; + + private JedisClientConfig masterConfig; + + private JedisClientConfig sentinelConfig; + + private SentineledConnectionProvider provider; + + @Mock + private Jedis mockJedis; + + @Mock + private SentinelConnectionFactory sentinelConnectionFactory; + + @Mock + private Delay reconnectDelay; + + @BeforeEach + void setUp() { + sentinels = new HashSet<>(); + sentinels.add(SENTINEL_1); + sentinels.add(SENTINEL_2); + + masterConfig = mock(JedisClientConfig.class); + sentinelConfig = mock(JedisClientConfig.class); + } + + @AfterEach + void tearDown() { + if (provider != null) { + provider.close(); + } + } + + @Test + void testReconnectToSentinelWithConfiguredDelay() throws InterruptedException { + // Capture delay values passed to sleeper + CopyOnWriteArrayList capturedDelays = new CopyOnWriteArrayList<>(); + // await for 3 reconnect attempts + CountDownLatch reconnectAttempts = new CountDownLatch(3); + + // Mock dependencies + SentineledConnectionProvider.Sleeper capturingSleeper = millis -> { + capturedDelays.add(millis); + reconnectAttempts.countDown(); + }; + + // Simulate sentinel connection failures (disconnect scenario) + long expectedDelay = 100L; + when(mockJedis.sentinelGetMasterAddrByName(MASTER_NAME)) + .thenReturn(Arrays.asList(MASTER.getHost(), String.valueOf(MASTER.getPort()))); + + Jedis failingJedis = mock(Jedis.class); + doThrow(new JedisConnectionException("Connection lost")).when(failingJedis) + .subscribe(any(JedisPubSub.class), anyString()); + when(sentinelConnectionFactory.createConnection(any(), any())) + .thenAnswer(invocation -> mockJedis).thenAnswer(invocation -> failingJedis); + when(reconnectDelay.delay(anyLong())).thenReturn(Duration.ofMillis(expectedDelay)); + + // Create provider + provider = new SentineledConnectionProvider(MASTER_NAME, masterConfig, null, null, sentinels, + sentinelConfig, reconnectDelay, sentinelConnectionFactory, capturingSleeper); + + // Verify reconnection attempts happen with configured delay + assertTrue(reconnectAttempts.await(100, TimeUnit.MILLISECONDS), + "Should attempt to reconnect at least 3 times after disconnect"); + + // Assert all delays match the configured value + assertTrue(capturedDelays.size() >= 3, "Should have captured at least 3 delay values"); + for (Long delay : capturedDelays) { + assertEquals(expectedDelay, delay, + "Sleeper should be called with configured delay of " + expectedDelay + "ms"); + } + } + +} diff --git a/src/test/java/redis/clients/jedis/util/DelayTest.java b/src/test/java/redis/clients/jedis/util/DelayTest.java new file mode 100644 index 0000000000..c2a19802e4 --- /dev/null +++ b/src/test/java/redis/clients/jedis/util/DelayTest.java @@ -0,0 +1,144 @@ +package redis.clients.jedis.util; + +import org.junit.jupiter.api.Test; + +import java.time.Duration; + +import static org.junit.jupiter.api.Assertions.*; + +public class DelayTest { + + @Test + public void testConstantDelay() { + Delay delay = Delay.constant(Duration.ofMillis(100)); + + // Constant delay should return the same value for all attempts + assertEquals(100, delay.delay(0).toMillis()); + assertEquals(100, delay.delay(1).toMillis()); + assertEquals(100, delay.delay(5).toMillis()); + assertEquals(100, delay.delay(100).toMillis()); + } + + @Test + public void testExponentialWithJitterBounds() { + Duration lower = Duration.ofMillis(50); + Duration upper = Duration.ofSeconds(10); + Duration base = Duration.ofMillis(100); + + Delay delay = Delay.exponentialWithJitter(lower, upper, base); + + // Test multiple attempts to verify bounds + for (int attempt = 0; attempt < 20; attempt++) { + Duration result = delay.delay(attempt); + long millis = result.toMillis(); + + // Verify lower bound + assertTrue(millis >= lower.toMillis(), String.format( + "Attempt %d: delay %d should be >= lower bound %d", attempt, millis, lower.toMillis())); + + // Verify upper bound + assertTrue(millis <= upper.toMillis(), String.format( + "Attempt %d: delay %d should be <= upper bound %d", attempt, millis, upper.toMillis())); + } + } + + @Test + public void testExponentialWithJitterGrowth() { + Duration lower = Duration.ofMillis(10); + Duration upper = Duration.ofSeconds(60); + Duration base = Duration.ofMillis(100); + + Delay delay = Delay.exponentialWithJitter(lower, upper, base); + + // Collect multiple samples for each attempt to verify growth trend + int samples = 100; + + // Attempt 0: base * 2^0 = 100ms, range [50-100]ms + long sum0 = 0; + for (int i = 0; i < samples; i++) { + sum0 += delay.delay(0).toMillis(); + } + long avg0 = sum0 / samples; + + // Attempt 1: base * 2^1 = 200ms, range [100-200]ms + long sum1 = 0; + for (int i = 0; i < samples; i++) { + sum1 += delay.delay(1).toMillis(); + } + long avg1 = sum1 / samples; + + // Attempt 2: base * 2^2 = 400ms, range [200-400]ms + long sum2 = 0; + for (int i = 0; i < samples; i++) { + sum2 += delay.delay(2).toMillis(); + } + long avg2 = sum2 / samples; + + // Verify exponential growth: avg1 should be roughly 2x avg0, avg2 should be roughly 2x avg1 + assertTrue(avg1 > avg0, "Average delay should increase with attempts"); + assertTrue(avg2 > avg1, "Average delay should continue to increase"); + } + + @Test + public void testExponentialWithJitterEqualJitterFormula() { + Duration lower = Duration.ofMillis(0); + Duration upper = Duration.ofSeconds(10); + Duration base = Duration.ofMillis(100); + + Delay delay = Delay.exponentialWithJitter(lower, upper, base); + + // For attempt 0: temp = min(10000, 100 * 2^0) = 100 + // Equal jitter: delay = temp/2 + random[0, temp/2] = 50 + random[0, 50] + // Range should be [50, 100] + for (int i = 0; i < 50; i++) { + long millis = delay.delay(0).toMillis(); + assertTrue(millis >= 50 && millis <= 100, + String.format("Attempt 0: delay %d should be in range [50, 100]", millis)); + } + + // For attempt 1: temp = min(10000, 100 * 2^1) = 200 + // Equal jitter: delay = 100 + random[0, 100] + // Range should be [100, 200] + for (int i = 0; i < 50; i++) { + long millis = delay.delay(1).toMillis(); + assertTrue(millis >= 100 && millis <= 200, + String.format("Attempt 1: delay %d should be in range [100, 200]", millis)); + } + } + + @Test + public void testExponentialWithJitterUpperBoundCapping() { + Duration lower = Duration.ofMillis(10); + Duration upper = Duration.ofMillis(500); + Duration base = Duration.ofMillis(100); + + Delay delay = Delay.exponentialWithJitter(lower, upper, base); + + // For high attempts, exponential should be capped at upper bound + // Attempt 10: base * 2^10 = 102400ms, but capped at 500ms + // Equal jitter: delay = 250 + random[0, 250] + // Range should be [250, 500] + for (int i = 0; i < 50; i++) { + long millis = delay.delay(10).toMillis(); + assertTrue(millis >= 250 && millis <= 500, + String.format("Attempt 10: delay %d should be in range [250, 500] (capped)", millis)); + } + } + + @Test + public void testExponentialWithJitterLowerBoundEnforcement() { + Duration lower = Duration.ofMillis(200); + Duration upper = Duration.ofSeconds(10); + Duration base = Duration.ofMillis(100); + + Delay delay = Delay.exponentialWithJitter(lower, upper, base); + + // For attempt 0: temp = 100, equal jitter would give [50, 100] + // But lower bound is 200, so all values should be >= 200 + for (int i = 0; i < 50; i++) { + long millis = delay.delay(0).toMillis(); + assertTrue(millis >= 200, + String.format("Attempt 0: delay %d should be >= lower bound 200", millis)); + } + } +}