From dfe9f99f151b5de029a43b9923ca88055d0e2a9c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Vill=C5=91=20Sz=C5=B1cs?= Date: Fri, 13 Sep 2024 13:26:11 +0200 Subject: [PATCH] [CALCITE-6529] Use persistent sessionContext in AvaticaCommonsHttpClientImpl --- .../remote/AvaticaCommonsHttpClientImpl.java | 30 ++++---- .../AvaticaCommonsHttpClientImplTest.java | 75 ++++++++++++++++++- .../HttpServerSpnegoWithoutJaasTest.java | 2 +- 3 files changed, 90 insertions(+), 17 deletions(-) diff --git a/core/src/main/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImpl.java b/core/src/main/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImpl.java index 39217c05c0..dcbfe123f0 100644 --- a/core/src/main/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImpl.java +++ b/core/src/main/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImpl.java @@ -91,6 +91,7 @@ public class AvaticaCommonsHttpClientImpl implements AvaticaHttpClient, HttpClie protected CredentialsProvider credentialsProvider = null; protected Lookup authRegistry = null; protected Object userToken; + protected HttpClientContext context; public AvaticaCommonsHttpClientImpl(URL url) { this.uri = toURI(Objects.requireNonNull(url)); @@ -109,23 +110,22 @@ protected void initializeClient(PoolingHttpClientConnectionManager pool, HttpClientBuilder httpClientBuilder = HttpClients.custom().setConnectionManager(pool) .setDefaultRequestConfig(requestConfig); this.client = httpClientBuilder.build(); + + this.context = HttpClientContext.create(); + // Set the credentials if they were provided. + if (null != this.credentialsProvider) { + context.setCredentialsProvider(credentialsProvider); + context.setAuthSchemeRegistry(authRegistry); + context.setAuthCache(authCache); + } + if (null != userToken) { + context.setUserToken(userToken); + } + } @Override public byte[] send(byte[] request) { while (true) { - HttpClientContext context = HttpClientContext.create(); - - // Set the credentials if they were provided. - if (null != this.credentialsProvider) { - context.setCredentialsProvider(credentialsProvider); - context.setAuthSchemeRegistry(authRegistry); - context.setAuthCache(authCache); - } - - if (null != userToken) { - context.setUserToken(userToken); - } - ByteArrayEntity entity = new ByteArrayEntity(request, ContentType.APPLICATION_OCTET_STREAM); // Create the client with the AuthSchemeRegistry and manager @@ -184,6 +184,8 @@ CloseableHttpResponse execute(HttpPost post, HttpClientContext context) throw new IllegalArgumentException("Unsupported authentiation type: " + authType); } this.authRegistry = authRegistryBuilder.build(); + context.setCredentialsProvider(credentialsProvider); + context.setAuthSchemeRegistry(authRegistry); } @Override public void setGSSCredential(GSSCredential credential) { @@ -205,6 +207,8 @@ CloseableHttpResponse execute(HttpPost post, HttpClientContext context) ((BasicCredentialsProvider) this.credentialsProvider) .setCredentials(anyAuthScope, EmptyCredentials.INSTANCE); } + context.setCredentialsProvider(credentialsProvider); + context.setAuthSchemeRegistry(authRegistry); } /** diff --git a/core/src/test/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImplTest.java b/core/src/test/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImplTest.java index d4dacdf441..2b89f16e26 100644 --- a/core/src/test/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImplTest.java +++ b/core/src/test/java/org/apache/calcite/avatica/remote/AvaticaCommonsHttpClientImplTest.java @@ -17,25 +17,33 @@ package org.apache.calcite.avatica.remote; import org.apache.calcite.avatica.AvaticaUtils; +import org.apache.calcite.avatica.ConnectionConfig; import org.apache.hc.client5.http.classic.methods.HttpPost; import org.apache.hc.client5.http.impl.classic.CloseableHttpResponse; -import org.apache.hc.client5.http.protocol.HttpClientContext; +import org.apache.hc.client5.http.impl.io.PoolingHttpClientConnectionManager; import org.apache.hc.core5.http.NoHttpResponseException; +import org.apache.hc.core5.http.io.entity.ByteArrayEntity; import org.apache.hc.core5.http.io.entity.StringEntity; import org.junit.Test; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; +import java.io.ByteArrayInputStream; import java.net.HttpURLConnection; import java.net.URL; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.any; import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doReturn; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static java.nio.charset.StandardCharsets.UTF_8; @@ -64,9 +72,11 @@ public class AvaticaCommonsHttpClientImplTest { final AvaticaCommonsHttpClientImpl client = spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1"))); + client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock( + ConnectionConfig.class)); doAnswer(failThenSucceed).when(client) - .execute(any(HttpPost.class), any(HttpClientContext.class)); + .execute(any(HttpPost.class), eq(client.context)); when(badResponse.getCode()).thenReturn(HttpURLConnection.HTTP_UNAVAILABLE); @@ -96,9 +106,11 @@ public class AvaticaCommonsHttpClientImplTest { final AvaticaCommonsHttpClientImpl client = spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1"))); + client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock( + ConnectionConfig.class)); doAnswer(failThenSucceed).when(client) - .execute(any(HttpPost.class), any(HttpClientContext.class)); + .execute(any(HttpPost.class), eq(client.context)); when(badResponse.getCode()).thenReturn(HttpURLConnection.HTTP_UNAVAILABLE); @@ -109,6 +121,63 @@ public class AvaticaCommonsHttpClientImplTest { assertEquals("success", AvaticaUtils.newStringUtf8(responseBytes)); } + @Test + public void testPersistentContextReusedAcrossRequests() throws Exception { + final AvaticaCommonsHttpClientImpl client = + spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1"))); + client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock( + ConnectionConfig.class)); + + CloseableHttpResponse response = mock(CloseableHttpResponse.class); + when(response.getCode()).thenReturn(HttpURLConnection.HTTP_OK); + + ByteArrayEntity entity = mock(ByteArrayEntity.class); + when(entity.getContent()).thenReturn(new ByteArrayInputStream(new byte[0])); + when(response.getEntity()).thenReturn(entity); + + doReturn(response).when(client) + .execute(any(HttpPost.class), eq(client.context)); + + client.send(new byte[0]); + client.send(new byte[0]); + + // Verify that the persistent context was reused and not created again + verify(client, times(2)).execute(any(HttpPost.class), + eq(client.context)); + } + + @Test + public void testPersistentContextThreadSafety() throws Exception { + final AvaticaCommonsHttpClientImpl client = + spy(new AvaticaCommonsHttpClientImpl(new URL("http://127.0.0.1"))); + client.setHttpClientPool(mock(PoolingHttpClientConnectionManager.class), mock( + ConnectionConfig.class)); + + doReturn(mock(CloseableHttpResponse.class)).when(client) + .execute(any(HttpPost.class), eq(client.context)); + + Runnable requestTask = () -> { + try { + client.send(new byte[0]); + } catch (Exception e) { + fail("Threaded request failed with exception: " + e.getMessage()); + } + }; + + int threadCount = 5; + Thread[] threads = new Thread[threadCount]; + for (int i = 0; i < threadCount; i++) { + threads[i] = new Thread(requestTask); + threads[i].start(); + } + + for (Thread thread : threads) { + thread.join(); + } + + verify(client, times(threadCount)).execute(any(HttpPost.class), eq(client.context)); + } + } // End AvaticaCommonsHttpClientImplTest.java diff --git a/server/src/test/java/org/apache/calcite/avatica/server/HttpServerSpnegoWithoutJaasTest.java b/server/src/test/java/org/apache/calcite/avatica/server/HttpServerSpnegoWithoutJaasTest.java index f7b56809c6..ad09b6f7db 100644 --- a/server/src/test/java/org/apache/calcite/avatica/server/HttpServerSpnegoWithoutJaasTest.java +++ b/server/src/test/java/org/apache/calcite/avatica/server/HttpServerSpnegoWithoutJaasTest.java @@ -233,8 +233,8 @@ private static void setupUsers(File keytabDir) throws KrbException { // Passes the GSSCredential into the HTTP client implementation final AvaticaCommonsHttpClientImpl httpClient = new AvaticaCommonsHttpClientImpl(httpServerUrl); - httpClient.setGSSCredential(credential); httpClient.setHttpClientPool(pool, config); + httpClient.setGSSCredential(credential); return httpClient.send(new byte[0]); }