diff --git a/src/main/java/com/hierynomus/protocol/commons/HostResolver.java b/src/main/java/com/hierynomus/protocol/commons/HostResolver.java new file mode 100644 index 00000000..d44bc31a --- /dev/null +++ b/src/main/java/com/hierynomus/protocol/commons/HostResolver.java @@ -0,0 +1,33 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.protocol.commons; + +import java.net.InetAddress; +import java.net.UnknownHostException; + +public interface HostResolver { + public static final HostResolver DEFAULT = new DefaultHostResolver(); + + InetAddress[] resolveHost(String host) throws UnknownHostException; + + public static class DefaultHostResolver implements HostResolver { + @Override + public InetAddress[] resolveHost(String host) throws UnknownHostException { + return InetAddress.getAllByName(host); + } + } +} + diff --git a/src/main/java/com/hierynomus/smbj/SmbConfig.java b/src/main/java/com/hierynomus/smbj/SmbConfig.java index 2347ae29..f1e92603 100644 --- a/src/main/java/com/hierynomus/smbj/SmbConfig.java +++ b/src/main/java/com/hierynomus/smbj/SmbConfig.java @@ -37,6 +37,7 @@ import com.hierynomus.mssmb2.SMB2GlobalCapability; import com.hierynomus.ntlm.NtlmConfig; import com.hierynomus.protocol.commons.Factory; +import com.hierynomus.protocol.commons.HostResolver; import com.hierynomus.protocol.commons.socket.ProxySocketFactory; import com.hierynomus.security.SecurityProvider; import com.hierynomus.security.bc.BCSecurityProvider; @@ -74,6 +75,7 @@ public final class SmbConfig { private Set dialects; private List> authenticators; private SocketFactory socketFactory; + private HostResolver hostResolver; private Random random; private UUID clientGuid; private boolean signingRequired; @@ -102,6 +104,7 @@ public static Builder builder() { .withClientGuid(UUID.randomUUID()) .withSecurityProvider(getDefaultSecurityProvider()) .withSocketFactory(new ProxySocketFactory()) + .withHostResolver(HostResolver.DEFAULT) .withSigningRequired(false) .withDfsEnabled(false) .withMultiProtocolNegotiate(false) @@ -153,6 +156,7 @@ private SmbConfig(SmbConfig other) { dialects.addAll(other.dialects); authenticators.addAll(other.authenticators); socketFactory = other.socketFactory; + hostResolver = other.hostResolver; random = other.random; clientGuid = other.clientGuid; signingRequired = other.signingRequired; @@ -244,6 +248,10 @@ public SocketFactory getSocketFactory() { return socketFactory; } + public HostResolver getHostResolver() { + return hostResolver; + } + public GSSContextConfig getClientGSSContextConfig() { return clientGSSContextConfig; } @@ -317,6 +325,14 @@ public Builder withSocketFactory(SocketFactory socketFactory) { return this; } + public Builder withHostResolver(HostResolver hostResolver) { + if (hostResolver == null) { + throw new IllegalArgumentException("Host resolver may not be null"); + } + config.hostResolver = hostResolver; + return this; + } + public Builder withDialects(SMB2Dialect... dialects) { return withDialects(Arrays.asList(dialects)); } diff --git a/src/main/java/com/hierynomus/smbj/connection/Connection.java b/src/main/java/com/hierynomus/smbj/connection/Connection.java index 60235f9a..dc0af53d 100644 --- a/src/main/java/com/hierynomus/smbj/connection/Connection.java +++ b/src/main/java/com/hierynomus/smbj/connection/Connection.java @@ -20,6 +20,7 @@ import java.io.Closeable; import java.io.IOException; +import java.net.InetAddress; import java.net.InetSocketAddress; import java.util.UUID; import java.util.concurrent.Future; @@ -29,6 +30,7 @@ import com.hierynomus.mssmb.SMB1PacketFactory; import com.hierynomus.mssmb2.*; import com.hierynomus.mssmb2.messages.SMB2Cancel; +import com.hierynomus.protocol.commons.HostResolver; import com.hierynomus.protocol.commons.buffer.Buffer; import com.hierynomus.protocol.commons.concurrent.CancellableFuture; import com.hierynomus.protocol.commons.concurrent.Futures; @@ -43,6 +45,7 @@ import com.hierynomus.smbj.SmbConfig; import com.hierynomus.smbj.auth.AuthenticationContext; import com.hierynomus.smbj.common.Pooled; +import com.hierynomus.smbj.common.SMBRuntimeException; import com.hierynomus.smbj.connection.packet.DeadLetterPacketHandler; import com.hierynomus.smbj.connection.packet.IncomingPacketHandler; import com.hierynomus.smbj.connection.packet.SMB1PacketHandler; @@ -95,13 +98,14 @@ public SMBClient getClient() { private SmbConfig config; TransportLayer> transport; + private final HostResolver hostResolver; private final SMBEventBus bus; private final ReentrantLock lock = new ReentrantLock(); public Connection(SmbConfig config, SMBClient client, SMBEventBus bus, ServerList serverList) { this.config = config; this.client = client; - this.transport = config.getTransportLayerFactory().createTransportLayer(new PacketHandlers<>(new SMBPacketSerializer(), this, converter), config); + this.hostResolver = config.getHostResolver(); this.bus = bus; this.serverList = serverList; init(); @@ -127,6 +131,7 @@ public Connection(Connection connection) { this.client = connection.client; this.config = connection.config; this.transport = connection.transport; + this.hostResolver = connection.hostResolver; this.bus = connection.bus; this.serverList = connection.serverList; init(); @@ -136,7 +141,27 @@ public void connect(String hostname, int port) throws IOException { if (isConnected()) { throw new IllegalStateException(format("This connection is already connected to %s", getRemoteHostname())); } - transport.connect(new InetSocketAddress(hostname, port)); + try { + InetAddress[] address = hostResolver.resolveHost(hostname); + for (InetAddress inetAddress : address) { + try { + transport.connect(new InetSocketAddress(inetAddress, port)); + onConnect(hostname, port); + return; + } catch (IOException e) { + logger.debug("Failed to connect to {} on address {}", hostname, inetAddress.getHostAddress(), e); + } + } + transport.connect(new InetSocketAddress(hostname, port)); + } catch (IOException e) { + throw new IOException("Failed to resolve hostname " + hostname, e); + } + + logger.error("Failed to connect to {}", hostname); + throw new SMBRuntimeException("Failed to connect to " + hostname); + } + + private void onConnect(String hostname, int port) throws IOException { this.connectionContext = new ConnectionContext(config.getClientGuid(), hostname, port, config); new SMBProtocolNegotiator(this, config, connectionContext).negotiateDialect(); this.signatory.init(); diff --git a/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy b/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy index 35a32344..702937d8 100644 --- a/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/SMBClientSpec.groovy @@ -17,12 +17,13 @@ package com.hierynomus.smbj import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor import com.hierynomus.smbj.testing.StubTransportLayerFactory +import com.hierynomus.smbj.testing.StubHostResolver import spock.lang.Specification class SMBClientSpec extends Specification { def processor = new DefaultPacketProcessor() - def config = SmbConfig.builder().withTransportLayerFactory(new StubTransportLayerFactory(processor)).build() + def config = SmbConfig.builder().withTransportLayerFactory(new StubTransportLayerFactory(processor)).withHostResolver(StubHostResolver.INSTANCE).build() def "should return same connection for same host/port combo"() { given: diff --git a/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy b/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy index 197f54ab..532f4031 100644 --- a/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy +++ b/src/test/groovy/com/hierynomus/smbj/connection/ConnectionSpec.groovy @@ -39,6 +39,7 @@ import com.hierynomus.smbj.testing.PacketProcessor.NoOpPacketProcessor import com.hierynomus.smbj.testing.PacketProcessor.DefaultPacketProcessor import com.hierynomus.smbj.testing.StubAuthenticator import com.hierynomus.smbj.testing.StubTransportLayerFactory +import com.hierynomus.smbj.testing.StubHostResolver import net.engio.mbassy.listener.Handler import spock.lang.Specification @@ -51,6 +52,7 @@ class ConnectionSpec extends Specification { private SmbConfig smbConfig(packetProcessor) { SmbConfig.builder() .withTransportLayerFactory(new StubTransportLayerFactory(new DefaultPacketProcessor().wrap(packetProcessor))) + .withHostResolver(new StubHostResolver()) .withAuthenticators(new StubAuthenticator.Factory()) .build() } diff --git a/src/test/java/com/hierynomus/smbj/testing/StubHostResolver.java b/src/test/java/com/hierynomus/smbj/testing/StubHostResolver.java new file mode 100644 index 00000000..681269a1 --- /dev/null +++ b/src/test/java/com/hierynomus/smbj/testing/StubHostResolver.java @@ -0,0 +1,30 @@ +/* + * Copyright (C)2016 - SMBJ Contributors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.hierynomus.smbj.testing; + +import java.net.InetAddress; + +import com.hierynomus.protocol.commons.HostResolver; + +public class StubHostResolver implements HostResolver { + public static final StubHostResolver INSTANCE = new StubHostResolver(); + + public InetAddress[] resolveHost(String host) { + return new InetAddress[]{ + InetAddress.getLoopbackAddress() + }; + } +} diff --git a/src/test/java/com/hierynomus/smbj/testing/Utils.java b/src/test/java/com/hierynomus/smbj/testing/Utils.java index 58fc96d1..c8db5a0c 100644 --- a/src/test/java/com/hierynomus/smbj/testing/Utils.java +++ b/src/test/java/com/hierynomus/smbj/testing/Utils.java @@ -21,7 +21,9 @@ public class Utils { public static SmbConfig config(PacketProcessor processor) { return SmbConfig.builder() - .withTransportLayerFactory(new StubTransportLayerFactory<>(new DefaultPacketProcessor().wrap(processor))) + .withTransportLayerFactory( + new StubTransportLayerFactory<>(new DefaultPacketProcessor().wrap(processor))) + .withHostResolver(StubHostResolver.INSTANCE) .withAuthenticators(new StubAuthenticator.Factory()).build(); } }