From 60688ee1f6c09005b10327a9a635a2d76763d8c5 Mon Sep 17 00:00:00 2001 From: Yunlu Wen Date: Fri, 21 Feb 2025 12:06:06 +0800 Subject: [PATCH 1/2] add BoundedKeepAliveProvider (#986) --- .../keepalive/BoundedKeepAliveProvider.java | 202 ++++++++++++++++++ .../java/net/schmizz/keepalive/KeepAlive.java | 7 + src/main/java/net/schmizz/sshj/SSHClient.java | 2 +- .../BoundedKeepAliveProviderTest.java | 95 ++++++++ .../sshj/test/SshServerExtension.java | 9 + 5 files changed, 314 insertions(+), 1 deletion(-) create mode 100644 src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java create mode 100644 src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java diff --git a/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java b/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java new file mode 100644 index 000000000..9a8806e6f --- /dev/null +++ b/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java @@ -0,0 +1,202 @@ +package net.schmizz.keepalive; + +import net.schmizz.sshj.Config; +import net.schmizz.sshj.connection.ConnectionException; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; +import org.slf4j.Logger; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.Condition; +import java.util.concurrent.locks.ReentrantLock; + +/** + * This implementation manages all {@link KeepAlive}s using configured number of threads. It works like a + * thread pool, thus {@link BoundedKeepAliveProvider#shutdown()} must be called to clean up resources. + *
+ * This provider uses {@link KeepAliveRunner#doKeepAlive()} as delegate, so it supports maxKeepAliveCount + * parameter. All instances provided by this provider have identical configuration. + */ +public class BoundedKeepAliveProvider extends KeepAliveProvider { + + public int maxKeepAliveCount = 3; + public int keepAliveInterval = 5; + + protected final KeepAliveMonitor monitor; + + + public BoundedKeepAliveProvider(Config config, int numberOfThreads) { + this.monitor = new KeepAliveMonitor(config, numberOfThreads); + } + + public void setKeepAliveInterval(int interval) { + keepAliveInterval = interval; + } + + public void setMaxKeepAliveCount(int count) { + maxKeepAliveCount = count; + } + + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new Impl(connection, "bounded-keepalive-impl"); + } + + public void shutdown() throws InterruptedException { + monitor.shutdown(); + } + + class Impl extends KeepAlive { + + private final KeepAliveRunner delegate; + + protected Impl(ConnectionImpl conn, String name) { + super(conn, name); + this.delegate = new KeepAliveRunner(conn); + + // take care here, some parameters are set to both delegate and this + this.delegate.setMaxAliveCount(BoundedKeepAliveProvider.this.maxKeepAliveCount); + super.keepAliveInterval = BoundedKeepAliveProvider.this.keepAliveInterval; + } + + @Override + protected void doKeepAlive() throws TransportException, ConnectionException { + delegate.doKeepAlive(); + } + + @Override + public void startKeepAlive() { + monitor.register(this); + } + + } + + protected static class KeepAliveMonitor { + + private final int numberOfThreads; + private final PriorityBlockingQueue Q = + new PriorityBlockingQueue<>(32, Comparator.comparingLong(w -> w.nextTimeMillis)); + private long idleSleepMillis = 100; + private static final List workerThreads = new ArrayList<>(); + volatile boolean started = false; + private final Logger logger; + + private final ReentrantLock lock = new ReentrantLock(); + private final Condition shutDown = lock.newCondition(); + private final AtomicInteger shutDownCnt = new AtomicInteger(0); + + public KeepAliveMonitor(Config config, int numberOfThreads) { + this.numberOfThreads = numberOfThreads; + logger = config.getLoggerFactory().getLogger(KeepAliveMonitor.class); + } + + // made public for test + public void register(KeepAlive keepAlive) { + if (!started) { + start(); + } + Q.add(new Wrapper(keepAlive)); + } + + public void setIdleSleepMillis(long idleSleepMillis) { + this.idleSleepMillis = idleSleepMillis; + } + + void unregister(KeepAlive keepAlive) { + Q.removeIf(w -> keepAlive == w.keepAlive); + } + + private void sleep() { + sleep(idleSleepMillis); + } + + private void sleep(long millis) { + try { + Thread.sleep(millis); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + + private synchronized void start() { + if (started) { + return; + } + + for (int i = 0; i < numberOfThreads; i++) { + Thread t = new Thread(this::doStart); + workerThreads.add(t); + } + workerThreads.forEach(Thread::start); + started = true; + } + + + private void doStart() { + while (!Thread.currentThread().isInterrupted()) { + Wrapper wrapper; + + if (Q.isEmpty() || (wrapper = Q.poll()) == null) { + sleep(); + continue; + } + + long currentTimeMillis = System.currentTimeMillis(); + if (wrapper.nextTimeMillis > currentTimeMillis) { + long sleepMillis = wrapper.nextTimeMillis - currentTimeMillis; + logger.debug("{} millis until next check, sleep", sleepMillis); + sleep(sleepMillis); + } + + try { + wrapper.keepAlive.doKeepAlive(); + Q.add(wrapper.reschedule()); + } catch (Exception e) { + // If we weren't interrupted, kill the transport, then this exception was unexpected. + // Else we're in shutdown-mode already, so don't forcibly kill the transport. + if (!Thread.currentThread().isInterrupted()) { + wrapper.keepAlive.conn.getTransport().die(e); + } + } + } + lock.lock(); + try { + if (shutDownCnt.incrementAndGet() == numberOfThreads) { + shutDown.signal(); + } + } finally { + lock.unlock(); + } + } + + private synchronized void shutdown() throws InterruptedException { + if (workerThreads.isEmpty()) { + return; + } + for (Thread t : workerThreads) { + t.interrupt(); + } + lock.lock(); + logger.info("waiting for all {} threads to finish", numberOfThreads); + shutDown.await(); + } + + private static class Wrapper { + private final KeepAlive keepAlive; + private final long nextTimeMillis; + + private Wrapper(KeepAlive keepAlive) { + this.keepAlive = keepAlive; + this.nextTimeMillis = System.currentTimeMillis() + keepAlive.keepAliveInterval * 1000L; + } + + private Wrapper reschedule() { + return new Wrapper(keepAlive); + } + } + } +} diff --git a/src/main/java/net/schmizz/keepalive/KeepAlive.java b/src/main/java/net/schmizz/keepalive/KeepAlive.java index 05e771f73..badbb948d 100644 --- a/src/main/java/net/schmizz/keepalive/KeepAlive.java +++ b/src/main/java/net/schmizz/keepalive/KeepAlive.java @@ -89,4 +89,11 @@ public void run() { } protected abstract void doKeepAlive() throws TransportException, ConnectionException; + + /** + * Start keep-alive loop. Implementations MUST NOT block current thread. + */ + public void startKeepAlive() { + start(); + } } diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index 78b91c5f7..4ea44f346 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -808,7 +808,7 @@ protected void onConnect() final KeepAlive keepAliveThread = conn.getKeepAlive(); if (keepAliveThread.isEnabled()) { ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); - keepAliveThread.start(); + keepAliveThread.startKeepAlive(); } } diff --git a/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java b/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java new file mode 100644 index 000000000..ff49e557c --- /dev/null +++ b/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java @@ -0,0 +1,95 @@ +package com.hierynomus.sshj.keepalive; + +import com.hierynomus.sshj.test.SshServerExtension; +import net.schmizz.keepalive.BoundedKeepAliveProvider; +import net.schmizz.keepalive.KeepAlive; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.connection.ConnectionException; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +class EventuallyFailKeepAlive extends KeepAlive { + // they can survive first 2 checks, and fail at 3rd + int failAfter = 2; + volatile int current = 0; + + protected EventuallyFailKeepAlive(ConnectionImpl conn, String name) { + super(conn, name); + setKeepAliveInterval(1); + } + + @Override + protected void doKeepAlive() throws TransportException, ConnectionException { + current++; + if (current > failAfter) { + throw new ConnectionException("failed"); + } + } +} + +public class BoundedKeepAliveProviderTest { + + static BoundedKeepAliveProvider kp; + static final DefaultConfig defaultConfig = new DefaultConfig(); + + + @BeforeAll + static void setUpBeforeClass() throws Exception { + + kp = new BoundedKeepAliveProvider(defaultConfig, 2) { + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new EventuallyFailKeepAlive(connection, "test") { + @Override + public void startKeepAlive() { + monitor.register(this); + } + }; + } + }; + } + + @RegisterExtension + public SshServerExtension fixture = new SshServerExtension(); + + void testWithConnections(int numOfConnections) throws IOException, InterruptedException { + List clients = setupClients(numOfConnections); + for (SSHClient client : clients) { + fixture.connectClient(client); + } + // first two checks are ok + Thread.sleep(2000); + Assertions.assertTrue(clients.stream().allMatch(SSHClient::isConnected)); + + // wait for 3rd check to take place, we wait additional 100ms for it to finish + Thread.sleep(1100); + Assertions.assertTrue(clients.stream().noneMatch(SSHClient::isConnected)); + Assertions.assertEquals(0, fixture.getServer().getActiveSessions().size()); + } + + @Test + void testBoundedKeepAlive() throws IOException, InterruptedException { + // 2 threads can handle 64 connections + testWithConnections(64); + } + + private List setupClients(int numOfConnections) { + List clients = new ArrayList<>(); + defaultConfig.setKeepAliveProvider(kp); + + for (int i = 0; i < numOfConnections; i++) { + final SSHClient sshClient = fixture.createClient(defaultConfig); + clients.add(sshClient); + } + return clients; + } +} diff --git a/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java b/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java index 1d07bbe36..9711530f7 100644 --- a/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java +++ b/src/test/java/com/hierynomus/sshj/test/SshServerExtension.java @@ -97,6 +97,15 @@ public SSHClient setupClient(Config config) { return client; } + /** + * create a new uncached client + */ + public SSHClient createClient(Config config) { + SSHClient client = new SSHClient(config); + client.addHostKeyVerifier(fingerprint); + return client; + } + public SSHClient getClient() { if (client != null) { return client; From 38af9b695464751d8164bff602032b3ee9a531ac Mon Sep 17 00:00:00 2001 From: Yunlu Wen Date: Fri, 21 Feb 2025 16:17:34 +0800 Subject: [PATCH 2/2] add int tests (#986) --- .../com/hierynomus/sshj/KeepAliveTest.java | 59 +++++++++++++++++++ .../com/hierynomus/sshj/SshdContainer.java | 19 +++++- .../keepalive/BoundedKeepAliveProvider.java | 29 +++++---- .../BoundedKeepAliveProviderTest.java | 15 ++--- 4 files changed, 99 insertions(+), 23 deletions(-) create mode 100644 src/itest/java/com/hierynomus/sshj/KeepAliveTest.java diff --git a/src/itest/java/com/hierynomus/sshj/KeepAliveTest.java b/src/itest/java/com/hierynomus/sshj/KeepAliveTest.java new file mode 100644 index 000000000..c95bfe6ce --- /dev/null +++ b/src/itest/java/com/hierynomus/sshj/KeepAliveTest.java @@ -0,0 +1,59 @@ +package com.hierynomus.sshj; + +import net.schmizz.keepalive.BoundedKeepAliveProvider; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; +import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.LoggerFactory; +import net.schmizz.sshj.transport.verification.PromiscuousVerifier; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.testcontainers.junit.jupiter.Container; + +import java.util.ArrayList; +import java.util.List; + +public class KeepAliveTest { + @Container + SshdContainer sshd = new SshdContainer(SshdContainer.Builder + .defaultBuilder() + .withAllKeys() + .withPackages("iptables") + .withPrivileged(true)); + + @Test + void testKeepAlive() throws Exception { + sshd.start(); + + Config config = new DefaultConfig(); + BoundedKeepAliveProvider p = new BoundedKeepAliveProvider(LoggerFactory.DEFAULT, 4); + p.setKeepAliveInterval(1); + p.setMaxKeepAliveCount(1); + config.setKeepAliveProvider(p); + List clients = new ArrayList<>(); + for (int i=0; i<10; i++) { + SSHClient c = new SSHClient(config); + c.addHostKeyVerifier(new PromiscuousVerifier()); + c.connect("127.0.0.1", sshd.getFirstMappedPort()); + c.authPassword("sshj", "ultrapassword"); + var sess = c.startSession(); + sess.allocateDefaultPTY(); + clients.add(c); + } + + for (SSHClient client : clients) { + Assertions.assertTrue(client.isConnected()); + } + + var res = sshd.execInContainer("iptables", "-A", "INPUT", "-p", "tcp", "--dport", "22", "-j", "DROP"); + Assertions.assertEquals(0, res.getExitCode()); + // wait for keepalive to take action + Thread.sleep(2000); + + for (SSHClient client : clients) { + Assertions.assertFalse(client.isConnected()); + } + + p.shutdown(); + } +} diff --git a/src/itest/java/com/hierynomus/sshj/SshdContainer.java b/src/itest/java/com/hierynomus/sshj/SshdContainer.java index 91b531e68..8d17560fb 100644 --- a/src/itest/java/com/hierynomus/sshj/SshdContainer.java +++ b/src/itest/java/com/hierynomus/sshj/SshdContainer.java @@ -106,13 +106,24 @@ public static class Builder implements Consumer { private List hostKeys = new ArrayList<>(); private List certificates = new ArrayList<>(); private @NotNull SshdConfigBuilder sshdConfig = SshdConfigBuilder.defaultBuilder(); + private boolean privileged = false; + private List packages = new ArrayList<>(); public static Builder defaultBuilder() { Builder b = new Builder(); - return b; } + public @NotNull Builder withPrivileged(boolean privileged) { + this.privileged = privileged; + return this; + } + + public @NotNull Builder withPackages(@NotNull String... packages) { + this.packages.addAll(List.of(packages)); + return this; + } + public @NotNull Builder withSshdConfig(@NotNull SshdConfigBuilder sshdConfig) { this.sshdConfig = sshdConfig; @@ -153,6 +164,9 @@ public void accept(@NotNull DockerfileBuilder builder) { builder.expose(22); builder.copy("entrypoint.sh", "/entrypoint.sh"); + if (!packages.isEmpty()) { + builder.run("apk add --no-cache " + String.join(" ", packages)); + } builder.add("authorized_keys", "/home/sshj/.ssh/authorized_keys"); builder.copy("test-container/trusted_ca_keys", "/etc/ssh/trusted_ca_keys"); @@ -201,6 +215,9 @@ public SshdContainer() { public SshdContainer(SshdContainer.Builder builder) { this(builder.buildInner()); + if (builder.privileged) { + withPrivilegedMode(true); + } } public SshdContainer(@NotNull Future future) { diff --git a/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java b/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java index 9a8806e6f..5702b2aaa 100644 --- a/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java +++ b/src/main/java/net/schmizz/keepalive/BoundedKeepAliveProvider.java @@ -1,6 +1,7 @@ package net.schmizz.keepalive; import net.schmizz.sshj.Config; +import net.schmizz.sshj.common.LoggerFactory; import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.ConnectionImpl; import net.schmizz.sshj.transport.TransportException; @@ -29,8 +30,8 @@ public class BoundedKeepAliveProvider extends KeepAliveProvider { protected final KeepAliveMonitor monitor; - public BoundedKeepAliveProvider(Config config, int numberOfThreads) { - this.monitor = new KeepAliveMonitor(config, numberOfThreads); + public BoundedKeepAliveProvider(LoggerFactory loggerFactory, int numberOfThreads) { + this.monitor = new KeepAliveMonitor(loggerFactory, numberOfThreads); } public void setKeepAliveInterval(int interval) { @@ -76,22 +77,24 @@ public void startKeepAlive() { } protected static class KeepAliveMonitor { + private final Logger logger; - private final int numberOfThreads; - private final PriorityBlockingQueue Q = + private final PriorityBlockingQueue q = new PriorityBlockingQueue<>(32, Comparator.comparingLong(w -> w.nextTimeMillis)); - private long idleSleepMillis = 100; private static final List workerThreads = new ArrayList<>(); + + private volatile long idleSleepMillis = 100; + private final int numberOfThreads; + volatile boolean started = false; - private final Logger logger; private final ReentrantLock lock = new ReentrantLock(); private final Condition shutDown = lock.newCondition(); private final AtomicInteger shutDownCnt = new AtomicInteger(0); - public KeepAliveMonitor(Config config, int numberOfThreads) { + public KeepAliveMonitor(LoggerFactory loggerFactory, int numberOfThreads) { this.numberOfThreads = numberOfThreads; - logger = config.getLoggerFactory().getLogger(KeepAliveMonitor.class); + logger = loggerFactory.getLogger(KeepAliveMonitor.class); } // made public for test @@ -99,17 +102,13 @@ public void register(KeepAlive keepAlive) { if (!started) { start(); } - Q.add(new Wrapper(keepAlive)); + q.add(new Wrapper(keepAlive)); } public void setIdleSleepMillis(long idleSleepMillis) { this.idleSleepMillis = idleSleepMillis; } - void unregister(KeepAlive keepAlive) { - Q.removeIf(w -> keepAlive == w.keepAlive); - } - private void sleep() { sleep(idleSleepMillis); } @@ -140,7 +139,7 @@ private void doStart() { while (!Thread.currentThread().isInterrupted()) { Wrapper wrapper; - if (Q.isEmpty() || (wrapper = Q.poll()) == null) { + if (q.isEmpty() || (wrapper = q.poll()) == null) { sleep(); continue; } @@ -154,7 +153,7 @@ private void doStart() { try { wrapper.keepAlive.doKeepAlive(); - Q.add(wrapper.reschedule()); + q.add(wrapper.reschedule()); } catch (Exception e) { // If we weren't interrupted, kill the transport, then this exception was unexpected. // Else we're in shutdown-mode already, so don't forcibly kill the transport. diff --git a/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java b/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java index ff49e557c..0ee50f6c4 100644 --- a/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java +++ b/src/test/java/com/hierynomus/sshj/keepalive/BoundedKeepAliveProviderTest.java @@ -5,6 +5,7 @@ import net.schmizz.keepalive.KeepAlive; import net.schmizz.sshj.DefaultConfig; import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.LoggerFactory; import net.schmizz.sshj.connection.ConnectionException; import net.schmizz.sshj.connection.ConnectionImpl; import net.schmizz.sshj.transport.TransportException; @@ -19,7 +20,7 @@ class EventuallyFailKeepAlive extends KeepAlive { // they can survive first 2 checks, and fail at 3rd - int failAfter = 2; + int failAfter = 1; volatile int current = 0; protected EventuallyFailKeepAlive(ConnectionImpl conn, String name) { @@ -45,7 +46,7 @@ public class BoundedKeepAliveProviderTest { @BeforeAll static void setUpBeforeClass() throws Exception { - kp = new BoundedKeepAliveProvider(defaultConfig, 2) { + kp = new BoundedKeepAliveProvider(LoggerFactory.DEFAULT, 2) { @Override public KeepAlive provide(ConnectionImpl connection) { return new EventuallyFailKeepAlive(connection, "test") { @@ -67,19 +68,19 @@ void testWithConnections(int numOfConnections) throws IOException, InterruptedEx fixture.connectClient(client); } // first two checks are ok - Thread.sleep(2000); + Thread.sleep(1000); Assertions.assertTrue(clients.stream().allMatch(SSHClient::isConnected)); - // wait for 3rd check to take place, we wait additional 100ms for it to finish - Thread.sleep(1100); + // wait for 2nd check to take place, we wait additional 200ms for it to finish + Thread.sleep(1200); Assertions.assertTrue(clients.stream().noneMatch(SSHClient::isConnected)); Assertions.assertEquals(0, fixture.getServer().getActiveSessions().size()); } @Test void testBoundedKeepAlive() throws IOException, InterruptedException { - // 2 threads can handle 64 connections - testWithConnections(64); + // 2 threads can handle 32 connections + testWithConnections(32); } private List setupClients(int numOfConnections) {