From 270567d8873a1819a3a58cbe3319447f8808ef1c Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Wed, 19 Mar 2025 06:46:21 +0000 Subject: [PATCH 01/27] core: Added changes to DelayedStream.setStream() should cancel the provided stream if not using it --- .../grpc/internal/DelayedClientTransport.java | 21 +++-- .../java/io/grpc/internal/DelayedStream.java | 13 ++- .../internal/DelayedClientTransportTest.java | 91 ++++++++++++++----- .../io/grpc/internal/DelayedStreamTest.java | 51 ++++++++++- .../grpc/internal/ManagedChannelImplTest.java | 11 ++- 5 files changed, 148 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 8ff755af3eb..e919a47ae2e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -178,13 +178,6 @@ private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamT if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) { pendingStream.lastPickStatus = pickResult.getStatus(); } - pendingStreams.add(pendingStream); - if (getPendingStreamsCount() == 1) { - syncContext.executeLater(reportTransportInUse); - } - for (ClientStreamTracer streamTracer : tracers) { - streamTracer.createPendingStream(); - } return pendingStream; } @@ -363,6 +356,20 @@ private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { this.tracers = tracers; } + @Override + public void start(ClientStreamListener listener) { + super.start(listener); + synchronized (lock) { + pendingStreams.add(this); + if (getPendingStreamsCount() == 1) { + syncContext.executeLater(reportTransportInUse); + } + for (ClientStreamTracer streamTracer : tracers) { + streamTracer.createPendingStream(); + } + } + } + /** Runnable may be null. */ private Runnable createRealStream(ClientTransport transport, String authorityOverride) { ClientStream realStream; diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 2ca4630d6a1..0bb6372ead0 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -125,11 +125,22 @@ public void appendTimeoutInsight(InsightBuilder insight) { @CheckReturnValue final Runnable setStream(ClientStream stream) { ClientStreamListener savedListener; + ClientStream oldStream = null; + boolean cancelOldStream = false; + synchronized (this) { - // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { + oldStream = realStream; + cancelOldStream = listener != null; + } + if (oldStream != null && !cancelOldStream) { return null; } + + if (cancelOldStream) { + oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream")); + } + setRealStream(checkNotNull(stream, "stream")); savedListener = listener; if (savedListener == null) { diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 902c2835a92..0128f1fdbd1 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -170,6 +170,7 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void newStreamThenAssignTransportThenShutdown() { ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); + stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.reprocess(mockPicker); @@ -177,12 +178,12 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); + fakeExecutor.runDueTasks(); assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream( same(method), same(headers), same(callOptions), ArgumentMatchers.any()); - stream.start(streamListener); - verify(mockRealStream).start(same(streamListener)); + verify(mockRealStream).start(any(ClientStreamListener.class)); } @Test public void transportTerminatedThenAssignTransport() { @@ -271,14 +272,41 @@ public void uncaughtException(Thread t, Throwable e) { verifyNoMoreInteractions(mockRealStream); } + @Test + public void newStreamThenShutDownNow() { + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); + assertEquals(1,delayedTransport.getPendingStreamsCount()); + delayedTransport.shutdownNow(Status.UNAVAILABLE); + verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportTerminated(); + verify(streamListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertEquals(0,delayedTransport.getPendingStreamsCount()); + assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); + } + + @Test + public void testDelayedClientTransportPendingStreamsOnShutDown() { + ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers); + ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers); + assertEquals(0, delayedTransport.getPendingStreamsCount()); + clientStream.start(streamListener); + clientStream1.start(streamListener); + assertEquals(2, delayedTransport.getPendingStreamsCount()); + delayedTransport.shutdownNow(Status.UNAVAILABLE); + assertEquals(0, delayedTransport.getPendingStreamsCount()); + } + @Test public void newStreamThenShutdownTransportThenCancelStream() { ClientStream stream = delayedTransport.newStream( - method, new Metadata(), CallOptions.DEFAULT, tracers); + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); - stream.start(streamListener); stream.cancel(Status.CANCELLED); verify(transportListener).transportTerminated(); assertEquals(0, delayedTransport.getPendingStreamsCount()); @@ -348,33 +376,39 @@ public void uncaughtException(Thread t, Throwable e) { ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); PickSubchannelArgsMatcher ff1args = new PickSubchannelArgsMatcher(method, headers, - failFastCallOptions); + failFastCallOptions); + transportListener.transportInUse(true); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions, tracers); + method2, headers2, failFastCallOptions, tracers); + ff2.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff2args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions, tracers); + method, headers, failFastCallOptions, tracers); + ff3.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff3args = new PickSubchannelArgsMatcher(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions, tracers); + method2, headers2, failFastCallOptions, tracers); + ff4.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff4args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions, tracers); + method, headers, waitForReadyCallOptions, tracers); + wfr1.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr1args = new PickSubchannelArgsMatcher(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions, tracers); + method2, headers2, waitForReadyCallOptions, tracers); + wfr2.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr2args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( - wfr3Executor.getScheduledExecutorService()); + wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); @@ -382,7 +416,8 @@ public void uncaughtException(Thread t, Throwable e) { PickSubchannelArgsMatcher wfr3args = new PickSubchannelArgsMatcher(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions, tracers); + method2, headers2, waitForReadyCallOptions, tracers); + wfr4.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr4args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); @@ -478,7 +513,8 @@ public void uncaughtException(Thread t, Throwable e) { // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions, tracers); + method, headers, waitForReadyCallOptions, tracers); + wfr5.start(mock(ClientStreamListener.class)); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( eqPickSubchannelArgs(method, headers, waitForReadyCallOptions)); @@ -626,12 +662,14 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable { verify(picker, never()).pickSubchannel(any(PickSubchannelArgs.class)); Thread sideThread = new Thread("sideThread") { - @Override - public void run() { - // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers, callOptions, tracers); - } - }; + @Override + public void run() { + // Will call pick Subchannel and wait on barrier + ClientStream clientStream = + delayedTransport.newStream(method, headers, callOptions, tracers); + clientStream.start(streamListener); + } + }; sideThread.start(); PickSubchannelArgsMatcher args = new PickSubchannelArgsMatcher(method, headers, callOptions); @@ -659,12 +697,14 @@ public void run() { ////////// Phase 2: reprocess() with a different picker // Create the second stream Thread sideThread2 = new Thread("sideThread2") { - @Override - public void run() { - // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers2, callOptions, tracers); - } - }; + @Override + public void run() { + // Will call pickSubchannel and wait on barrier + ClientStream clientStream = delayedTransport + .newStream(method, headers2, callOptions, tracers); + clientStream.start(streamListener); + } + }; sideThread2.start(); // The second stream will see the first picker verify(picker, timeout(5000)).pickSubchannel(argThat(args2)); @@ -730,6 +770,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); stream.start(streamListener); + transportListener.transportInUse(true); assertTrue(delayedTransport.hasPendingStreams()); verify(transportListener).transportInUse(true); } diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index a47bea9f4ab..2902be027a9 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -46,6 +46,7 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,6 +85,39 @@ public void setStream_setAuthority() { inOrder.verify(realStream).start(any(ClientStreamListener.class)); } + @Test + public void testSetStreamReplaceOldStreamProperly() { + ClientStream oldStream = mock(ClientStream.class); + ClientStream newStream = mock(ClientStream.class); + + // First stream set, but never started + callMeMaybe(stream.setStream(oldStream)); + callMeMaybe(stream.setStream(newStream)); + // Verify old stream was canceled + verify(oldStream,never()).cancel(any(Status.class)); + // Ensure new stream is properly set + verifyNoMoreInteractions(newStream); + } + + @Test + public void testSetStreamStartCancelsOldStreamProperly() { + ClientStream oldStream = mock(ClientStream.class); + ClientStream newStream = mock(ClientStream.class); + + // First stream set, but never started + callMeMaybe(stream.setStream(oldStream)); + stream.start(listener); + try { + callMeMaybe(stream.setStream(newStream)); + } catch (IllegalStateException e) { + assertNotNull(e.getMessage()); + } + // Verify old stream was canceled + verify(oldStream).cancel(any(Status.class)); + // Ensure new stream is properly set + verifyNoMoreInteractions(newStream); + } + @Test(expected = IllegalStateException.class) public void start_afterStart() { stream.start(listener); @@ -333,17 +367,28 @@ public void setStreamTwice() { stream.start(listener); callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - callMeMaybe(stream.setStream(mock(ClientStream.class))); + try { + callMeMaybe(stream.setStream(mock(ClientStream.class))); + } catch (IllegalStateException e) { + assertEquals("realStream already set to realStream",e.getMessage()); + } stream.flush(); verify(realStream).flush(); } @Test public void cancelThenSetStream() { - stream.start(listener); - stream.cancel(Status.CANCELLED); + try { + stream.cancel(Status.CANCELLED); + Assert.fail("Should have thrown"); + } catch (IllegalStateException e) { + assertEquals("May only be called after start", e.getMessage()); + } callMeMaybe(stream.setStream(realStream)); + stream.start(listener); stream.isReady(); + verify(realStream).start(same(listener)); + verify(realStream).isReady(); verifyNoMoreInteractions(realStream); } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 21ccf1095df..d1bf205205a 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -2920,8 +2920,13 @@ public void idleMode_resetsDelayedTransportPicker() { // Move channel to idle timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); + executor.runDueTasks(); assertEquals(IDLE, channel.getState(false)); + //Force transport re-creation explicitly + channel.getState(true); + executor.runDueTasks(); + // This call should be buffered, but will move the channel out of idle ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener2, new Metadata()); @@ -2947,15 +2952,15 @@ public void idleMode_resetsDelayedTransportPicker() { transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) - .thenReturn(PickResult.withSubchannel(subchannel)); + .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper2, READY, mockPicker); assertEquals(READY, channel.getState(false)); executor.runDueTasks(); // Verify the buffered call was drained verify(mockTransport).newStream( - same(method), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any()); + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } From 8ac96783c8a51d3c623a6a326129fdae40e96f7a Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Thu, 20 Mar 2025 06:23:05 +0000 Subject: [PATCH 02/27] core: Added changes to DelayedStream.setStream() should cancel the provided stream if not using it --- .../internal/DelayedClientTransportTest.java | 13 +++++++++++-- .../io/grpc/internal/DelayedStreamTest.java | 19 ++++++++----------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 0128f1fdbd1..394f8e2da86 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -226,8 +226,10 @@ public void uncaughtException(Thread t, Throwable e) { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); + assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.cancel(Status.CANCELLED); + assertEquals(0, delayedTransport.getPendingStreamsCount()); verify(streamListener).closed( same(Status.CANCELLED), same(RpcProgress.PROCESSED), any(Metadata.class)); @@ -273,7 +275,7 @@ public void uncaughtException(Thread t, Throwable e) { } @Test - public void newStreamThenShutDownNow() { + public void testNewStreamThenShutDownNow() { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); @@ -283,6 +285,7 @@ public void newStreamThenShutDownNow() { verify(transportListener).transportTerminated(); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertEquals(0,delayedTransport.getPendingStreamsCount()); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } @@ -291,11 +294,14 @@ public void newStreamThenShutDownNow() { public void testDelayedClientTransportPendingStreamsOnShutDown() { ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers); ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers); + assertEquals(0, delayedTransport.getPendingStreamsCount()); clientStream.start(streamListener); clientStream1.start(streamListener); + assertEquals(2, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdownNow(Status.UNAVAILABLE); + assertEquals(0, delayedTransport.getPendingStreamsCount()); } @@ -350,7 +356,9 @@ public void testDelayedClientTransportPendingStreamsOnShutDown() { assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } - @Test public void reprocessSemantics() { + @Test + @SuppressWarnings("DirectInvocationOnMock") + public void reprocessSemantics() { CallOptions failFastCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 1); CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 2) .withWaitForReady(); @@ -754,6 +762,7 @@ public void reprocess_addOptionalLabelCallsTracer() throws Exception { } @Test + @SuppressWarnings("DirectInvocationOnMock") public void newStream_racesWithReprocessIdleMode() throws Exception { SubchannelPicker picker = new SubchannelPicker() { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 2902be027a9..bcc0b7f8675 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -107,11 +108,8 @@ public void testSetStreamStartCancelsOldStreamProperly() { // First stream set, but never started callMeMaybe(stream.setStream(oldStream)); stream.start(listener); - try { - callMeMaybe(stream.setStream(newStream)); - } catch (IllegalStateException e) { - assertNotNull(e.getMessage()); - } + assertThrows(IllegalStateException.class, + () -> callMeMaybe(stream.setStream(mock(ClientStream.class)))); // Verify old stream was canceled verify(oldStream).cancel(any(Status.class)); // Ensure new stream is properly set @@ -363,15 +361,14 @@ public void setStreamThenStartThenCancelled() { } @Test - public void setStreamTwice() { + public void testSetStreamTwice() { stream.start(listener); callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - try { - callMeMaybe(stream.setStream(mock(ClientStream.class))); - } catch (IllegalStateException e) { - assertEquals("realStream already set to realStream",e.getMessage()); - } + IllegalStateException e = assertThrows(IllegalStateException.class, () -> + callMeMaybe(stream.setStream(mock(ClientStream.class))) + ); + assertEquals("realStream already set to realStream", e.getMessage()); stream.flush(); verify(realStream).flush(); } From cac52d4e9e6e5e4252ae0bb81622e94d36b29f0f Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Fri, 21 Mar 2025 07:01:18 +0000 Subject: [PATCH 03/27] core: Fixed internal review points --- core/src/main/java/io/grpc/internal/DelayedStream.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 0bb6372ead0..15b45ea5d3b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -136,11 +136,9 @@ final Runnable setStream(ClientStream stream) { if (oldStream != null && !cancelOldStream) { return null; } - if (cancelOldStream) { oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream")); } - setRealStream(checkNotNull(stream, "stream")); savedListener = listener; if (savedListener == null) { From 00eb166591d52e1d5229159f40c9db19f3055188 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Tue, 18 Mar 2025 15:17:13 -0700 Subject: [PATCH 04/27] xds: Include XdsConfig as a CallOption This allows Filters to access the xds configuration for their own processing. From gRFC A83: > This data is available via the XdsConfig attribute introduced in A74. > If the xDS ConfigSelector is not already passing that attribute to the > filters, it will need to be changed to do so. --- xds/src/main/java/io/grpc/xds/XdsNameResolver.java | 14 +++++++++++--- .../test/java/io/grpc/xds/XdsNameResolverTest.java | 4 ++++ 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 5c1b3105c45..123d3a77172 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -94,6 +94,8 @@ final class XdsNameResolver extends NameResolver { static final CallOptions.Key CLUSTER_SELECTION_KEY = CallOptions.Key.create("io.grpc.xds.CLUSTER_SELECTION_KEY"); + static final CallOptions.Key XDS_CONFIG_CALL_OPTION_KEY = + CallOptions.Key.create("io.grpc.xds.XDS_CONFIG_CALL_OPTION_KEY"); static final CallOptions.Key RPC_HASH_KEY = CallOptions.Key.create("io.grpc.xds.RPC_HASH_KEY"); static final CallOptions.Key AUTO_HOST_REWRITE_KEY = @@ -467,6 +469,7 @@ public Result selectConfig(PickSubchannelArgs args) { "Failed to parse service config (method config)")); } final String finalCluster = cluster; + final XdsConfig xdsConfig = routingCfg.xdsConfig; final long hash = generateHash(routeAction.hashPolicies(), headers); class ClusterSelectionInterceptor implements ClientInterceptor { @Override @@ -475,6 +478,7 @@ public ClientCall interceptCall( final Channel next) { CallOptions callOptionsForCluster = callOptions.withOption(CLUSTER_SELECTION_KEY, finalCluster) + .withOption(XDS_CONFIG_CALL_OPTION_KEY, xdsConfig) .withOption(RPC_HASH_KEY, hash); if (routeAction.autoHostRewrite()) { callOptionsForCluster = callOptionsForCluster.withOption(AUTO_HOST_REWRITE_KEY, true); @@ -801,7 +805,7 @@ private void updateRoutes( } // Make newly added clusters selectable by config selector and deleted clusters no longer // selectable. - routingConfig = new RoutingConfig(httpMaxStreamDurationNano, routesData.build()); + routingConfig = new RoutingConfig(xdsConfig, httpMaxStreamDurationNano, routesData.build()); for (String cluster : deletedClusters) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); if (count == 0) { @@ -879,17 +883,21 @@ private void cleanUpRoutes(Status error) { * VirtualHost-level configuration for request routing. */ private static class RoutingConfig { - private final long fallbackTimeoutNano; + final XdsConfig xdsConfig; + final long fallbackTimeoutNano; final ImmutableList routes; final Status errorStatus; - private RoutingConfig(long fallbackTimeoutNano, ImmutableList routes) { + private RoutingConfig( + XdsConfig xdsConfig, long fallbackTimeoutNano, ImmutableList routes) { + this.xdsConfig = checkNotNull(xdsConfig, "xdsConfig"); this.fallbackTimeoutNano = fallbackTimeoutNano; this.routes = checkNotNull(routes, "routes"); this.errorStatus = null; } private RoutingConfig(Status errorStatus) { + this.xdsConfig = null; this.fallbackTimeoutNano = 0; this.routes = null; this.errorStatus = checkNotNull(errorStatus, "errorStatus"); diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 7425e3e31de..371c4213738 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -1672,6 +1672,10 @@ private void assertCallSelectClusterResult( clientCall.start(new NoopClientCallListener<>(), new Metadata()); assertThat(testCall.callOptions.getOption(XdsNameResolver.CLUSTER_SELECTION_KEY)) .isEqualTo("cluster:" + expectedCluster); + XdsConfig xdsConfig = + testCall.callOptions.getOption(XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY); + assertThat(xdsConfig).isNotNull(); + assertThat(xdsConfig.getClusters()).containsKey(expectedCluster); // Without "cluster:" prefix @SuppressWarnings("unchecked") Map config = (Map) result.getConfig(); if (expectedTimeoutSec != null) { From 8cd4d5a5a03abe640ffbcd8494867378725f67b2 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 5 Mar 2025 13:29:55 -0800 Subject: [PATCH 05/27] xds: Assert XdsNR's cluster ref counting is consistent It is much harder to debug refcounting problems when we ignore impossible situations. So make such impossible cases complain loudly so the bug is obvious. --- .../main/java/io/grpc/xds/XdsNameResolver.java | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index 123d3a77172..bbe36bdd744 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -537,17 +537,21 @@ private boolean retainCluster(String cluster) { private void releaseCluster(final String cluster) { int count = clusterRefs.get(cluster).refCount.decrementAndGet(); + if (count < 0) { + throw new AssertionError(); + } if (count == 0) { syncContext.execute(new Runnable() { @Override public void run() { - if (clusterRefs.get(cluster).refCount.get() == 0) { - clusterRefs.remove(cluster); - if (resolveState.lastConfigOrStatus.hasValue()) { - updateResolutionResult(resolveState.lastConfigOrStatus.getValue()); - } else { - resolveState.cleanUpRoutes(resolveState.lastConfigOrStatus.getStatus()); - } + if (clusterRefs.get(cluster).refCount.get() != 0) { + throw new AssertionError(); + } + clusterRefs.remove(cluster); + if (resolveState.lastConfigOrStatus.hasValue()) { + updateResolutionResult(resolveState.lastConfigOrStatus.getValue()); + } else { + resolveState.cleanUpRoutes(resolveState.lastConfigOrStatus.getStatus()); } } }); From 2295bbe1f7e02a747dc31fb080868d7773abe715 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Thu, 20 Mar 2025 22:31:16 -0700 Subject: [PATCH 06/27] xds: Expose filter names to filter instances (#11971) This is to support gRFC A83 xDS GCP Authentication Filter: > Otherwise, the filter will look in the CDS resource's metadata for a > key corresponding to the filter's instance name. --- xds/src/main/java/io/grpc/xds/FaultFilter.java | 2 +- xds/src/main/java/io/grpc/xds/Filter.java | 2 +- .../main/java/io/grpc/xds/GcpAuthenticationFilter.java | 2 +- xds/src/main/java/io/grpc/xds/InternalRbacFilter.java | 2 +- xds/src/main/java/io/grpc/xds/RbacFilter.java | 2 +- xds/src/main/java/io/grpc/xds/RouterFilter.java | 2 +- xds/src/main/java/io/grpc/xds/XdsNameResolver.java | 3 ++- xds/src/main/java/io/grpc/xds/XdsServerWrapper.java | 3 ++- .../java/io/grpc/xds/GrpcXdsClientImplDataTest.java | 2 +- xds/src/test/java/io/grpc/xds/RbacFilterTest.java | 10 ++++++---- xds/src/test/java/io/grpc/xds/StatefulFilter.java | 2 +- xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java | 2 +- .../test/java/io/grpc/xds/XdsServerWrapperTest.java | 4 ++-- 13 files changed, 21 insertions(+), 17 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/FaultFilter.java b/xds/src/main/java/io/grpc/xds/FaultFilter.java index 2012fd36b62..0f3bb5b0557 100644 --- a/xds/src/main/java/io/grpc/xds/FaultFilter.java +++ b/xds/src/main/java/io/grpc/xds/FaultFilter.java @@ -99,7 +99,7 @@ public boolean isClientFilter() { } @Override - public FaultFilter newInstance() { + public FaultFilter newInstance(String name) { return INSTANCE; } diff --git a/xds/src/main/java/io/grpc/xds/Filter.java b/xds/src/main/java/io/grpc/xds/Filter.java index aa326b55ad7..416d929becf 100644 --- a/xds/src/main/java/io/grpc/xds/Filter.java +++ b/xds/src/main/java/io/grpc/xds/Filter.java @@ -87,7 +87,7 @@ default boolean isServerFilter() { *
  • Filter name+typeUrl in FilterChain's HCM.http_filters.
  • * */ - Filter newInstance(); + Filter newInstance(String name); /** * Parses the top-level filter config from raw proto message. The message may be either a {@link diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index 41687817c47..add885c6416 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -64,7 +64,7 @@ public boolean isClientFilter() { } @Override - public GcpAuthenticationFilter newInstance() { + public GcpAuthenticationFilter newInstance(String name) { return new GcpAuthenticationFilter(); } diff --git a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java index cedb3f4c85b..476adbf9cfd 100644 --- a/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/InternalRbacFilter.java @@ -33,7 +33,7 @@ public static ServerInterceptor createInterceptor(RBAC rbac) { throw new IllegalArgumentException( String.format("Failed to parse Rbac policy: %s", filterConfig.errorDetail)); } - return new RbacFilter.Provider().newInstance() + return new RbacFilter.Provider().newInstance("internalRbacFilter") .buildServerInterceptor(filterConfig.config, null); } } diff --git a/xds/src/main/java/io/grpc/xds/RbacFilter.java b/xds/src/main/java/io/grpc/xds/RbacFilter.java index 2bc4eeb846b..d91884735e9 100644 --- a/xds/src/main/java/io/grpc/xds/RbacFilter.java +++ b/xds/src/main/java/io/grpc/xds/RbacFilter.java @@ -89,7 +89,7 @@ public boolean isServerFilter() { } @Override - public RbacFilter newInstance() { + public RbacFilter newInstance(String name) { return INSTANCE; } diff --git a/xds/src/main/java/io/grpc/xds/RouterFilter.java b/xds/src/main/java/io/grpc/xds/RouterFilter.java index 939bd0b12ab..504c4213149 100644 --- a/xds/src/main/java/io/grpc/xds/RouterFilter.java +++ b/xds/src/main/java/io/grpc/xds/RouterFilter.java @@ -56,7 +56,7 @@ public boolean isServerFilter() { } @Override - public RouterFilter newInstance() { + public RouterFilter newInstance(String name) { return INSTANCE; } diff --git a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java index bbe36bdd744..7704a4a09db 100644 --- a/xds/src/main/java/io/grpc/xds/XdsNameResolver.java +++ b/xds/src/main/java/io/grpc/xds/XdsNameResolver.java @@ -704,7 +704,8 @@ private void updateActiveFilters(@Nullable List filterConfigs Filter.Provider provider = filterRegistry.get(typeUrl); checkNotNull(provider, "provider %s", typeUrl); - Filter filter = activeFilters.computeIfAbsent(filterKey, k -> provider.newInstance()); + Filter filter = activeFilters.computeIfAbsent( + filterKey, k -> provider.newInstance(namedFilter.name)); checkNotNull(filter, "filter %s", filterKey); filtersToShutdown.remove(filterKey); } diff --git a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java index e0185974861..6625bd8178a 100644 --- a/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java +++ b/xds/src/main/java/io/grpc/xds/XdsServerWrapper.java @@ -560,7 +560,8 @@ private void updateActiveFiltersForChain( Filter.Provider provider = filterRegistry.get(typeUrl); checkNotNull(provider, "provider %s", typeUrl); - Filter filter = chainFilters.computeIfAbsent(filterKey, k -> provider.newInstance()); + Filter filter = chainFilters.computeIfAbsent( + filterKey, k -> provider.newInstance(namedFilter.name)); checkNotNull(filter, "filter %s", filterKey); filtersToShutdown.remove(filterKey); } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index 90b83320d63..bfaa17245cf 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -1267,7 +1267,7 @@ public boolean isClientFilter() { } @Override - public TestFilter newInstance() { + public TestFilter newInstance(String name) { return new TestFilter(); } diff --git a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java index 7f195693d84..334e159dd1d 100644 --- a/xds/src/test/java/io/grpc/xds/RbacFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/RbacFilterTest.java @@ -80,6 +80,8 @@ public class RbacFilterTest { StringMatcher.newBuilder().setExact("/" + PATH).setIgnoreCase(true).build(); private static final RbacFilter.Provider FILTER_PROVIDER = new RbacFilter.Provider(); + private final String name = "theFilterName"; + @Test public void filterType_serverOnly() { assertThat(FILTER_PROVIDER.isClientFilter()).isFalse(); @@ -259,7 +261,7 @@ public void testAuthorizationInterceptor() { OrMatcher.create(AlwaysTrueMatcher.INSTANCE)); AuthConfig authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.ALLOW); - FILTER_PROVIDER.newInstance().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler, never()).startCall(eq(mockServerCall), any(Metadata.class)); ArgumentCaptor captor = ArgumentCaptor.forClass(Status.class); @@ -271,7 +273,7 @@ public void testAuthorizationInterceptor() { authconfig = AuthConfig.create(Collections.singletonList(policyMatcher), GrpcAuthorizationEngine.Action.DENY); - FILTER_PROVIDER.newInstance().buildServerInterceptor(RbacConfig.create(authconfig), null) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(RbacConfig.create(authconfig), null) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); } @@ -322,7 +324,7 @@ public void overrideConfig() { RbacConfig override = FILTER_PROVIDER.parseFilterConfigOverride(Any.pack(rbacPerRoute)).config; assertThat(override).isEqualTo(RbacConfig.create(null)); ServerInterceptor interceptor = - FILTER_PROVIDER.newInstance().buildServerInterceptor(original, override); + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(original, override); assertThat(interceptor).isNull(); policyMatcher = PolicyMatcher.create("policy-matcher-override", @@ -332,7 +334,7 @@ public void overrideConfig() { GrpcAuthorizationEngine.Action.ALLOW); override = RbacConfig.create(authconfig); - FILTER_PROVIDER.newInstance().buildServerInterceptor(original, override) + FILTER_PROVIDER.newInstance(name).buildServerInterceptor(original, override) .interceptCall(mockServerCall, new Metadata(), mockHandler); verify(mockHandler).startCall(eq(mockServerCall), any(Metadata.class)); verify(mockServerCall).getAttributes(); diff --git a/xds/src/test/java/io/grpc/xds/StatefulFilter.java b/xds/src/test/java/io/grpc/xds/StatefulFilter.java index 162dd380daf..4ef662c7ccd 100644 --- a/xds/src/test/java/io/grpc/xds/StatefulFilter.java +++ b/xds/src/test/java/io/grpc/xds/StatefulFilter.java @@ -108,7 +108,7 @@ public boolean isServerFilter() { } @Override - public synchronized StatefulFilter newInstance() { + public synchronized StatefulFilter newInstance(String name) { StatefulFilter filter = new StatefulFilter(counter++); instances.put(filter.idx, filter); return filter; diff --git a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java index 371c4213738..622084d4306 100644 --- a/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsNameResolverTest.java @@ -219,7 +219,7 @@ public void setUp() { // Lenient: suppress [MockitoHint] Unused warning, only used in resolved_fault* tests. lenient() .doReturn(new FaultFilter(mockRandom, new AtomicLong())) - .when(faultFilterProvider).newInstance(); + .when(faultFilterProvider).newInstance(any(String.class)); FilterRegistry filterRegistry = FilterRegistry.newRegistry().register( ROUTER_FILTER_PROVIDER, diff --git a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java index b866e10c559..e5f0f44cbae 100644 --- a/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsServerWrapperTest.java @@ -1135,7 +1135,7 @@ public void run() { Filter.Provider filterProvider = mock(Filter.Provider.class); when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); when(filterProvider.isServerFilter()).thenReturn(true); - when(filterProvider.newInstance()).thenReturn(filter); + when(filterProvider.newInstance(any(String.class))).thenReturn(filter); filterRegistry.register(filterProvider); FilterConfig f0 = mock(FilterConfig.class); @@ -1208,7 +1208,7 @@ public void run() { Filter.Provider filterProvider = mock(Filter.Provider.class); when(filterProvider.typeUrls()).thenReturn(new String[]{"filter-type-url"}); when(filterProvider.isServerFilter()).thenReturn(true); - when(filterProvider.newInstance()).thenReturn(filter); + when(filterProvider.newInstance(any(String.class))).thenReturn(filter); filterRegistry.register(filterProvider); FilterConfig f0 = mock(FilterConfig.class); From 9e6ece3527ae92eed2218acc87b15955de577452 Mon Sep 17 00:00:00 2001 From: Alex Panchenko <440271+panchenko@users.noreply.github.com> Date: Fri, 21 Mar 2025 09:30:24 +0200 Subject: [PATCH 07/27] Replace usages of deprecated ExpectedException in grpc-api and grpc-core (#11962) --- api/build.gradle | 1 + api/src/test/java/io/grpc/MetadataTest.java | 35 +++++-------- .../java/io/grpc/MethodDescriptorTest.java | 6 --- .../java/io/grpc/ServerInterceptorsTest.java | 18 +++---- .../io/grpc/ServerServiceDefinitionTest.java | 19 ++----- .../java/io/grpc/ServiceDescriptorTest.java | 51 +++++++++---------- .../internal/AbstractClientStreamTest.java | 23 +++------ .../internal/AbstractServerStreamTest.java | 30 ++++------- .../ConnectivityStateManagerTest.java | 8 +-- .../io/grpc/internal/DnsNameResolverTest.java | 28 +++++----- .../java/io/grpc/internal/GrpcUtilTest.java | 23 ++++----- .../grpc/internal/InternalSubchannelTest.java | 13 ++--- .../java/io/grpc/internal/JsonParserTest.java | 51 ++++++------------- .../ManagedChannelImplBuilderTest.java | 33 +++++------- .../ManagedChannelServiceConfigTest.java | 49 +++++++----------- .../io/grpc/internal/MessageDeframerTest.java | 32 +++++------- .../io/grpc/internal/ServerCallImplTest.java | 42 ++++++++------- .../java/io/grpc/internal/ServerImplTest.java | 16 +++--- .../grpc/internal/AbstractTransportTest.java | 10 +--- 19 files changed, 185 insertions(+), 303 deletions(-) diff --git a/api/build.gradle b/api/build.gradle index dc3eaea3f4e..415a17f61f8 100644 --- a/api/build.gradle +++ b/api/build.gradle @@ -47,6 +47,7 @@ dependencies { testImplementation project(':grpc-core') testImplementation project(':grpc-testing') testImplementation libraries.guava.testlib + testImplementation libraries.truth signature (libraries.signature.java) { artifact { diff --git a/api/src/test/java/io/grpc/MetadataTest.java b/api/src/test/java/io/grpc/MetadataTest.java index 14ba8ca9b23..a858fff5e5a 100644 --- a/api/src/test/java/io/grpc/MetadataTest.java +++ b/api/src/test/java/io/grpc/MetadataTest.java @@ -16,6 +16,7 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; import static java.nio.charset.StandardCharsets.US_ASCII; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertArrayEquals; @@ -24,6 +25,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -37,9 +39,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.Locale; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -49,9 +49,6 @@ @RunWith(JUnit4.class) public class MetadataTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private static final Metadata.BinaryMarshaller FISH_MARSHALLER = new Metadata.BinaryMarshaller() { @Override @@ -65,7 +62,7 @@ public Fish parseBytes(byte[] serialized) { } }; - private static class FishStreamMarsaller implements Metadata.BinaryStreamMarshaller { + private static class FishStreamMarshaller implements Metadata.BinaryStreamMarshaller { @Override public InputStream toStream(Fish fish) { return new ByteArrayInputStream(FISH_MARSHALLER.toBytes(fish)); @@ -82,7 +79,7 @@ public Fish parseStream(InputStream stream) { } private static final Metadata.BinaryStreamMarshaller FISH_STREAM_MARSHALLER = - new FishStreamMarsaller(); + new FishStreamMarshaller(); /** A pattern commonly used to avoid unnecessary serialization of immutable objects. */ private static final class FakeFishStream extends InputStream { @@ -121,10 +118,9 @@ public Fish parseStream(InputStream stream) { @Test public void noPseudoHeaders() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid character"); - - Metadata.Key.of(":test-bin", FISH_MARSHALLER); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Metadata.Key.of(":test-bin", FISH_MARSHALLER)); + assertThat(e).hasMessageThat().isEqualTo("Invalid character ':' in key name ':test-bin'"); } @Test @@ -186,8 +182,7 @@ public void testGetAllNoRemove() { Iterator i = metadata.getAll(KEY).iterator(); assertEquals(lance, i.next()); - thrown.expect(UnsupportedOperationException.class); - i.remove(); + assertThrows(UnsupportedOperationException.class, i::remove); } @Test @@ -271,17 +266,15 @@ public void mergeExpands() { @Test public void shortBinaryKeyName() { - thrown.expect(IllegalArgumentException.class); - - Metadata.Key.of("-bin", FISH_MARSHALLER); + assertThrows(IllegalArgumentException.class, () -> Metadata.Key.of("-bin", FISH_MARSHALLER)); } @Test public void invalidSuffixBinaryKeyName() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Binary header is named"); - - Metadata.Key.of("nonbinary", FISH_MARSHALLER); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> Metadata.Key.of("nonbinary", FISH_MARSHALLER)); + assertThat(e).hasMessageThat() + .isEqualTo("Binary header is named nonbinary. It must end with -bin"); } @Test @@ -415,7 +408,7 @@ public void streamedValueDifferentMarshaller() { h.put(KEY_STREAMED, salmon); // Get using a different marshaller instance. - Fish fish = h.get(copyKey(KEY_STREAMED, new FishStreamMarsaller())); + Fish fish = h.get(copyKey(KEY_STREAMED, new FishStreamMarshaller())); assertEquals(salmon, fish); } diff --git a/api/src/test/java/io/grpc/MethodDescriptorTest.java b/api/src/test/java/io/grpc/MethodDescriptorTest.java index 9431190984b..e068e0c1108 100644 --- a/api/src/test/java/io/grpc/MethodDescriptorTest.java +++ b/api/src/test/java/io/grpc/MethodDescriptorTest.java @@ -26,9 +26,7 @@ import io.grpc.MethodDescriptor.Marshaller; import io.grpc.MethodDescriptor.MethodType; import io.grpc.testing.TestMethodDescriptors; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -37,10 +35,6 @@ */ @RunWith(JUnit4.class) public class MethodDescriptorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void createMethodDescriptor() { MethodDescriptor descriptor = MethodDescriptor.newBuilder() diff --git a/api/src/test/java/io/grpc/ServerInterceptorsTest.java b/api/src/test/java/io/grpc/ServerInterceptorsTest.java index abfb3540fe4..b84b3838afa 100644 --- a/api/src/test/java/io/grpc/ServerInterceptorsTest.java +++ b/api/src/test/java/io/grpc/ServerInterceptorsTest.java @@ -19,6 +19,7 @@ import static com.google.common.collect.Iterables.getOnlyElement; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.same; @@ -40,7 +41,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentMatchers; @@ -55,10 +55,6 @@ public class ServerInterceptorsTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Mock private Marshaller requestMarshaller; @@ -111,21 +107,21 @@ public void makeSureExpectedMocksUnused() { public void npeForNullServiceDefinition() { ServerServiceDefinition serviceDef = null; List interceptors = Arrays.asList(); - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDef, interceptors); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDef, interceptors)); } @Test public void npeForNullInterceptorList() { - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDefinition, (List) null); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDefinition, (List) null)); } @Test public void npeForNullInterceptor() { List interceptors = Arrays.asList((ServerInterceptor) null); - thrown.expect(NullPointerException.class); - ServerInterceptors.intercept(serviceDefinition, interceptors); + assertThrows(NullPointerException.class, + () -> ServerInterceptors.intercept(serviceDefinition, interceptors)); } @Test diff --git a/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java b/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java index 6a84d640d78..9e43302e210 100644 --- a/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java +++ b/api/src/test/java/io/grpc/ServerServiceDefinitionTest.java @@ -18,14 +18,13 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import java.util.ArrayList; import java.util.Collections; import java.util.HashSet; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -52,9 +51,6 @@ public class ServerServiceDefinitionTest { = ServerMethodDefinition.create(method1, methodHandler1); private ServerMethodDefinition methodDef2 = ServerMethodDefinition.create(method2, methodHandler2); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); @Test public void noMethods() { @@ -91,9 +87,7 @@ public void addMethod_duplicateName() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(method1, methodHandler1); - thrown.expect(IllegalStateException.class); - ssd.addMethod(diffMethod1, methodHandler2) - .build(); + assertThrows(IllegalStateException.class, () -> ssd.addMethod(diffMethod1, methodHandler2)); } @Test @@ -101,8 +95,7 @@ public void buildMisaligned_extraMethod() { ServiceDescriptor sd = new ServiceDescriptor(serviceName); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(methodDef1); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test @@ -110,16 +103,14 @@ public void buildMisaligned_diffMethodInstance() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd) .addMethod(diffMethod1, methodHandler1); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test public void buildMisaligned_missingMethod() { ServiceDescriptor sd = new ServiceDescriptor(serviceName, method1); ServerServiceDefinition.Builder ssd = ServerServiceDefinition.builder(sd); - thrown.expect(IllegalStateException.class); - ssd.build(); + assertThrows(IllegalStateException.class, ssd::build); } @Test diff --git a/api/src/test/java/io/grpc/ServiceDescriptorTest.java b/api/src/test/java/io/grpc/ServiceDescriptorTest.java index a05858680d5..89bdead3632 100644 --- a/api/src/test/java/io/grpc/ServiceDescriptorTest.java +++ b/api/src/test/java/io/grpc/ServiceDescriptorTest.java @@ -16,17 +16,18 @@ package io.grpc; +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; +import com.google.common.truth.StringSubject; import io.grpc.MethodDescriptor.MethodType; import io.grpc.testing.TestMethodDescriptors; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.List; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -36,32 +37,27 @@ @RunWith(JUnit4.class) public class ServiceDescriptorTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void failsOnNullName() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("name"); - - new ServiceDescriptor(null, Collections.>emptyList()); + List> methods = Collections.emptyList(); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor(null, methods)); + assertThat(e).hasMessageThat().isEqualTo("name"); } @Test public void failsOnNullMethods() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("methods"); - - new ServiceDescriptor("name", (Collection>) null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor("name", (Collection>) null)); + assertThat(e).hasMessageThat().isEqualTo("methods"); } @Test public void failsOnNullMethod() { - thrown.expect(NullPointerException.class); - thrown.expectMessage("method"); - - new ServiceDescriptor("name", Collections.>singletonList(null)); + List> methods = Collections.singletonList(null); + NullPointerException e = assertThrows(NullPointerException.class, + () -> new ServiceDescriptor("name", methods)); + assertThat(e).hasMessageThat().isEqualTo("method"); } @Test @@ -69,15 +65,17 @@ public void failsOnNonMatchingNames() { List> descriptors = Collections.>singletonList( MethodDescriptor.newBuilder() .setType(MethodType.UNARY) - .setFullMethodName(MethodDescriptor.generateFullMethodName("wrongservice", "method")) + .setFullMethodName(MethodDescriptor.generateFullMethodName("wrongService", "method")) .setRequestMarshaller(TestMethodDescriptors.voidMarshaller()) .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("service names"); - - new ServiceDescriptor("name", descriptors); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> new ServiceDescriptor("fooService", descriptors)); + StringSubject error = assertThat(e).hasMessageThat(); + error.contains("service names"); + error.contains("fooService"); + error.contains("wrongService"); } @Test @@ -96,10 +94,9 @@ public void failsOnNonDuplicateNames() { .setResponseMarshaller(TestMethodDescriptors.voidMarshaller()) .build()); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("duplicate"); - - new ServiceDescriptor("name", descriptors); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> new ServiceDescriptor("name", descriptors)); + assertThat(e).hasMessageThat().isEqualTo("duplicate name name/method"); } @Test diff --git a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java index ad3b59030d7..18fafe6557d 100644 --- a/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractClientStreamTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -57,7 +58,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -76,8 +76,6 @@ public class AbstractClientStreamTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final StatsTraceContext statsTraceCtx = StatsTraceContext.NOOP; private final TransportTracer transportTracer = new TransportTracer(); @@ -136,9 +134,7 @@ public void cancel_failsOnNull() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); stream.start(listener); - thrown.expect(NullPointerException.class); - - stream.cancel(null); + assertThrows(NullPointerException.class, () -> stream.cancel(null)); } @Test @@ -164,9 +160,7 @@ public void startFailsOnNullListener() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); - thrown.expect(NullPointerException.class); - - stream.start(null); + assertThrows(NullPointerException.class, () -> stream.start(null)); } @Test @@ -174,9 +168,7 @@ public void cantCallStartTwice() { AbstractClientStream stream = new BaseAbstractClientStream(allocator, statsTraceCtx, transportTracer); stream.start(mockListener); - thrown.expect(IllegalStateException.class); - - stream.start(mockListener); + assertThrows(IllegalStateException.class, () -> stream.start(mockListener)); } @Test @@ -188,8 +180,7 @@ public void inboundDataReceived_failsOnNullFrame() { TransportState state = stream.transportState(); - thrown.expect(NullPointerException.class); - state.inboundDataReceived(null); + assertThrows(NullPointerException.class, () -> state.inboundDataReceived(null)); } @Test @@ -212,8 +203,8 @@ public void inboundHeadersReceived_failsIfStatusReported() { TransportState state = stream.transportState(); - thrown.expect(IllegalStateException.class); - state.inboundHeadersReceived(new Metadata()); + Metadata headers = new Metadata(); + assertThrows(IllegalStateException.class, () -> state.inboundHeadersReceived(headers)); } @Test diff --git a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java index b41d45e972e..137ba19bfea 100644 --- a/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java +++ b/core/src/test/java/io/grpc/internal/AbstractServerStreamTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; @@ -45,9 +46,7 @@ import java.util.Queue; import java.util.concurrent.TimeUnit; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -60,9 +59,6 @@ public class AbstractServerStreamTest { private static final int TIMEOUT_MS = 1000; private static final int MAX_MESSAGE_SIZE = 100; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); - private final WritableBufferAllocator allocator = new WritableBufferAllocator() { @Override public WritableBuffer allocate(int capacityHint) { @@ -226,9 +222,9 @@ public void completeWithoutClose() { public void setListener_setOnlyOnce() { TransportState state = stream.transportState(); state.setListener(new ServerStreamListenerBase()); - thrown.expect(IllegalStateException.class); - state.setListener(new ServerStreamListenerBase()); + ServerStreamListenerBase listener2 = new ServerStreamListenerBase(); + assertThrows(IllegalStateException.class, () -> state.setListener(listener2)); } @Test @@ -238,8 +234,7 @@ public void listenerReady_onlyOnce() { TransportState state = stream.transportState(); - thrown.expect(IllegalStateException.class); - state.onStreamAllocated(); + assertThrows(IllegalStateException.class, state::onStreamAllocated); } @Test @@ -255,8 +250,7 @@ public void listenerReady_readyCalled() { public void setListener_failsOnNull() { TransportState state = stream.transportState(); - thrown.expect(NullPointerException.class); - state.setListener(null); + assertThrows(NullPointerException.class, () -> state.setListener(null)); } // TODO(ericgribkoff) This test is only valid if deframeInTransportThread=true, as otherwise the @@ -284,9 +278,7 @@ public void messagesAvailable(MessageProducer producer) { @Test public void writeHeaders_failsOnNullHeaders() { - thrown.expect(NullPointerException.class); - - stream.writeHeaders(null, true); + assertThrows(NullPointerException.class, () -> stream.writeHeaders(null, true)); } @Test @@ -336,16 +328,13 @@ public void writeMessage_closesStream() throws Exception { @Test public void close_failsOnNullStatus() { - thrown.expect(NullPointerException.class); - - stream.close(null, new Metadata()); + Metadata trailers = new Metadata(); + assertThrows(NullPointerException.class, () -> stream.close(null, trailers)); } @Test public void close_failsOnNullMetadata() { - thrown.expect(NullPointerException.class); - - stream.close(Status.INTERNAL, null); + assertThrows(NullPointerException.class, () -> stream.close(Status.INTERNAL, null)); } @Test @@ -451,4 +440,3 @@ public int streamId() { } } } - diff --git a/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java b/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java index 2a759a4f386..dfd6ed56a1e 100644 --- a/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java +++ b/core/src/test/java/io/grpc/internal/ConnectivityStateManagerTest.java @@ -27,9 +27,7 @@ import io.grpc.ConnectivityState; import java.util.LinkedList; import java.util.concurrent.Executor; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -38,10 +36,6 @@ */ @RunWith(JUnit4.class) public class ConnectivityStateManagerTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - private final FakeClock executor = new FakeClock(); private final ConnectivityStateManager state = new ConnectivityStateManager(); private final LinkedList sink = new LinkedList<>(); @@ -75,7 +69,7 @@ public void run() { assertEquals(1, sink.size()); assertEquals(TRANSIENT_FAILURE, sink.poll()); } - + @Test public void registerCallbackAfterStateChanged() { state.gotoState(CONNECTING); diff --git a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java index be304ad326b..130c01d1c04 100644 --- a/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java +++ b/core/src/test/java/io/grpc/internal/DnsNameResolverTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; @@ -35,6 +36,7 @@ import static org.mockito.Mockito.when; import com.google.common.base.Stopwatch; +import com.google.common.base.VerifyException; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.net.InetAddresses; @@ -82,7 +84,6 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.DisableOnDebug; -import org.junit.rules.ExpectedException; import org.junit.rules.TestRule; import org.junit.rules.Timeout; import org.junit.runner.RunWith; @@ -99,8 +100,6 @@ public class DnsNameResolverTest { @Rule public final TestRule globalTimeout = new DisableOnDebug(Timeout.seconds(10)); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); private final Map serviceConfig = new LinkedHashMap<>(); @@ -914,9 +913,10 @@ public HttpConnectProxiedSocketAddress proxyFor(SocketAddress targetAddress) { public void maybeChooseServiceConfig_failsOnMisspelling() { Map bad = new LinkedHashMap<>(); bad.put("parcentage", 1.0); - thrown.expectMessage("Bad key"); - - DnsNameResolver.maybeChooseServiceConfig(bad, new Random(), "host"); + Random random = new Random(); + VerifyException e = assertThrows(VerifyException.class, + () -> DnsNameResolver.maybeChooseServiceConfig(bad, random, "host")); + assertThat(e).hasMessageThat().isEqualTo("Bad key: parcentage=1.0"); } @Test @@ -1155,25 +1155,25 @@ public void parseTxtResults_misspelledName() throws Exception { } @Test - public void parseTxtResults_badTypeFails() throws Exception { + public void parseTxtResults_badTypeFails() { List txtRecords = new ArrayList<>(); txtRecords.add("some_record"); txtRecords.add("grpc_config={}"); - thrown.expect(ClassCastException.class); - thrown.expectMessage("wrong type"); - DnsNameResolver.parseTxtResults(txtRecords); + ClassCastException e = assertThrows(ClassCastException.class, + () -> DnsNameResolver.parseTxtResults(txtRecords)); + assertThat(e).hasMessageThat().isEqualTo("wrong type {}"); } @Test - public void parseTxtResults_badInnerTypeFails() throws Exception { + public void parseTxtResults_badInnerTypeFails() { List txtRecords = new ArrayList<>(); txtRecords.add("some_record"); txtRecords.add("grpc_config=[\"bogus\"]"); - thrown.expect(ClassCastException.class); - thrown.expectMessage("not object"); - DnsNameResolver.parseTxtResults(txtRecords); + ClassCastException e = assertThrows(ClassCastException.class, + () -> DnsNameResolver.parseTxtResults(txtRecords)); + assertThat(e).hasMessageThat().isEqualTo("value bogus for idx 0 in [bogus] is not object"); } @Test diff --git a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java index 39acb582d28..229c593ef80 100644 --- a/core/src/test/java/io/grpc/internal/GrpcUtilTest.java +++ b/core/src/test/java/io/grpc/internal/GrpcUtilTest.java @@ -22,6 +22,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -41,7 +42,6 @@ import java.util.ArrayList; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -57,8 +57,6 @@ public class GrpcUtilTest { new ClientStreamTracer() {} }; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Captor @@ -201,9 +199,7 @@ public void urlAuthorityEscape_unicodeAreNotEncoded() { @Test public void checkAuthority_failsOnNull() { - thrown.expect(NullPointerException.class); - - GrpcUtil.checkAuthority(null); + assertThrows(NullPointerException.class, () -> GrpcUtil.checkAuthority(null)); } @Test @@ -229,19 +225,18 @@ public void checkAuthority_succeedsOnIpV6() { @Test public void checkAuthority_failsOnInvalidAuthority() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - - GrpcUtil.checkAuthority("[ : : 1]"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> GrpcUtil.checkAuthority("[ : : 1]")); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test public void checkAuthority_userInfoNotAllowed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Userinfo"); - - GrpcUtil.checkAuthority("foo@valid"); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> GrpcUtil.checkAuthority("foo@valid")); + assertThat(e).hasMessageThat() + .isEqualTo("Userinfo must not be present on authority: 'foo@valid'"); } @Test diff --git a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java index b75fd43a743..bed722f5f3a 100644 --- a/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java +++ b/core/src/test/java/io/grpc/internal/InternalSubchannelTest.java @@ -27,6 +27,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -65,7 +66,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; @@ -79,9 +79,6 @@ public class InternalSubchannelTest { @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private static final String AUTHORITY = "fakeauthority"; private static final String USER_AGENT = "mosaic"; @@ -544,8 +541,9 @@ public void constructor_eagListWithNull_throws() { public void updateAddresses_emptyEagList_throws() { SocketAddress addr = new FakeSocketAddress(); createInternalSubchannel(addr); - thrown.expect(IllegalArgumentException.class); - internalSubchannel.updateAddresses(Arrays.asList()); + List newAddressGroups = Collections.emptyList(); + assertThrows(IllegalArgumentException.class, + () -> internalSubchannel.updateAddresses(newAddressGroups)); } @Test @@ -553,8 +551,7 @@ public void updateAddresses_eagListWithNull_throws() { SocketAddress addr = new FakeSocketAddress(); createInternalSubchannel(addr); List eags = Arrays.asList((EquivalentAddressGroup) null); - thrown.expect(NullPointerException.class); - internalSubchannel.updateAddresses(eags); + assertThrows(NullPointerException.class, () -> internalSubchannel.updateAddresses(eags)); } @Test public void updateAddresses_intersecting_ready() { diff --git a/core/src/test/java/io/grpc/internal/JsonParserTest.java b/core/src/test/java/io/grpc/internal/JsonParserTest.java index cfee566fa4a..a0dd81c20ce 100644 --- a/core/src/test/java/io/grpc/internal/JsonParserTest.java +++ b/core/src/test/java/io/grpc/internal/JsonParserTest.java @@ -17,15 +17,14 @@ package io.grpc.internal; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import com.google.gson.stream.MalformedJsonException; import java.io.EOFException; import java.io.IOException; import java.util.ArrayList; import java.util.LinkedHashMap; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -35,10 +34,6 @@ @RunWith(JUnit4.class) public class JsonParserTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void emptyObject() throws IOException { assertEquals(new LinkedHashMap(), JsonParser.parse("{}")); @@ -75,45 +70,33 @@ public void nullValue() throws IOException { } @Test - public void nanFails() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("NaN"); + public void nanFails() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("NaN")); } @Test - public void objectEarlyEnd() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{foo:}"); + public void objectEarlyEnd() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{foo:}")); } @Test - public void earlyEndArray() throws IOException { - thrown.expect(EOFException.class); - - JsonParser.parse("[1, 2, "); + public void earlyEndArray() { + assertThrows(EOFException.class, () -> JsonParser.parse("[1, 2, ")); } @Test - public void arrayMissingElement() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("[1, 2, ]"); + public void arrayMissingElement() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("[1, 2, ]")); } @Test - public void objectMissingElement() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{1: "); + public void objectMissingElement() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{1: ")); } @Test - public void objectNoName() throws IOException { - thrown.expect(MalformedJsonException.class); - - JsonParser.parse("{: 1"); + public void objectNoName() { + assertThrows(MalformedJsonException.class, () -> JsonParser.parse("{: 1")); } @Test @@ -125,9 +108,7 @@ public void objectStringName() throws IOException { } @Test - public void duplicate() throws IOException { - thrown.expect(IllegalArgumentException.class); - - JsonParser.parse("{\"hi\": 2, \"hi\": 3}"); + public void duplicate() { + assertThrows(IllegalArgumentException.class, () -> JsonParser.parse("{\"hi\": 2, \"hi\": 3}")); } -} \ No newline at end of file +} diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java index cf131a79d87..861412653fb 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplBuilderTest.java @@ -23,6 +23,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.doReturn; @@ -67,7 +68,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.Mock; @@ -99,8 +99,6 @@ public ClientCall interceptCall( }; @Rule public final MockitoRule mocks = MockitoJUnit.rule(); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @Mock private ClientTransportFactory mockClientTransportFactory; @@ -424,10 +422,9 @@ public void checkAuthority_validAuthorityAllowed() { @Test public void checkAuthority_invalidAuthorityFailed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - - builder.checkAuthority(DUMMY_AUTHORITY_INVALID); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.checkAuthority(DUMMY_AUTHORITY_INVALID)); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test @@ -450,11 +447,10 @@ public void enableCheckAuthority_validAuthorityAllowed() { @Test public void disableCheckAuthority_invalidAuthorityFailed() { - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Invalid authority"); - builder.disableCheckAuthority().enableCheckAuthority(); - builder.checkAuthority(DUMMY_AUTHORITY_INVALID); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> builder.checkAuthority(DUMMY_AUTHORITY_INVALID)); + assertThat(e).hasMessageThat().isEqualTo("Invalid authority: [ : : 1]"); } @Test @@ -680,14 +676,12 @@ public void perRpcBufferLimit() { @Test public void retryBufferSizeInvalidArg() { - thrown.expect(IllegalArgumentException.class); - builder.retryBufferSize(0L); + assertThrows(IllegalArgumentException.class, () -> builder.retryBufferSize(0L)); } @Test public void perRpcBufferLimitInvalidArg() { - thrown.expect(IllegalArgumentException.class); - builder.perRpcBufferLimit(0L); + assertThrows(IllegalArgumentException.class, () -> builder.perRpcBufferLimit(0L)); } @Test @@ -710,8 +704,7 @@ public void defaultServiceConfig_nullKey() { Map config = new HashMap<>(); config.put(null, "val"); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -721,8 +714,7 @@ public void defaultServiceConfig_intKey() { Map config = new HashMap<>(); config.put("key", subConfig); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test @@ -730,8 +722,7 @@ public void defaultServiceConfig_intValue() { Map config = new HashMap<>(); config.put("key", 3); - thrown.expect(IllegalArgumentException.class); - builder.defaultServiceConfig(config); + assertThrows(IllegalArgumentException.class, () -> builder.defaultServiceConfig(config)); } @Test diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java index 493714dfd41..fefc37e4fdc 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelServiceConfigTest.java @@ -20,6 +20,7 @@ import static io.grpc.MethodDescriptor.MethodType.UNARY; import static io.grpc.Status.Code.UNAVAILABLE; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; import com.google.common.collect.ImmutableList; @@ -34,19 +35,13 @@ import io.grpc.testing.TestMethodDescriptors; import java.util.Collections; import java.util.Map; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @RunWith(JUnit4.class) public class ManagedChannelServiceConfigTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); - @Test public void managedChannelServiceConfig_shouldParseHealthCheckingConfig() throws Exception { Map rawServiceConfig = @@ -79,10 +74,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateMethod() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name1, name2)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate method"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate method name service/method"); } @Test @@ -92,10 +86,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateService() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name1, name2)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate service service"); } @Test @@ -107,10 +100,9 @@ public void createManagedChannelServiceConfig_failsOnDuplicateServiceMultipleCon Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig1, methodConfig2)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("Duplicate service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("Duplicate service service"); } @Test @@ -119,10 +111,9 @@ public void createManagedChannelServiceConfig_failsOnMethodNameWithEmptyServiceN Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service name for method method1"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method1"); } @Test @@ -131,10 +122,9 @@ public void createManagedChannelServiceConfig_failsOnMethodNameWithoutServiceNam Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service name for method method1"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method1"); } @Test @@ -143,10 +133,9 @@ public void createManagedChannelServiceConfig_failsOnMissingServiceName() { Map methodConfig = ImmutableMap.of("name", ImmutableList.of(name)); Map serviceConfig = ImmutableMap.of("methodConfig", ImmutableList.of(methodConfig)); - thrown.expect(IllegalArgumentException.class); - thrown.expectMessage("missing service"); - - ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null); + IllegalArgumentException e = assertThrows(IllegalArgumentException.class, + () -> ManagedChannelServiceConfig.fromServiceConfig(serviceConfig, true, 3, 4, null)); + assertThat(e).hasMessageThat().isEqualTo("missing service name for method method"); } @Test diff --git a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java index 8f1b908e999..54758bc096f 100644 --- a/core/src/test/java/io/grpc/internal/MessageDeframerTest.java +++ b/core/src/test/java/io/grpc/internal/MessageDeframerTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.GrpcUtil.DEFAULT_MAX_MESSAGE_SIZE; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assume.assumeTrue; import static org.mockito.ArgumentMatchers.anyInt; @@ -53,10 +54,8 @@ import java.util.concurrent.TimeUnit; import java.util.zip.GZIPOutputStream; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; import org.junit.experimental.runners.Enclosed; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.junit.runners.Parameterized; @@ -341,9 +340,6 @@ public Void answer(InvocationOnMock invocation) throws Throwable { @RunWith(JUnit4.class) public static class SizeEnforcingInputStreamTests { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public final ExpectedException thrown = ExpectedException.none(); private TestBaseStreamTracer tracer = new TestBaseStreamTracer(); private StatsTraceContext statsTraceCtx = new StatsTraceContext(new StreamTracer[]{tracer}); @@ -381,11 +377,12 @@ public void sizeEnforcingInputStream_readByteAboveLimit() throws IOException { new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - while (stream.read() != -1) { - } + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> { + while (stream.read() != -1) { + } + }); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -427,10 +424,10 @@ public void sizeEnforcingInputStream_readAboveLimit() throws IOException { byte[] buf = new byte[10]; try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - stream.read(buf, 0, buf.length); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, + () -> stream.read(buf, 0, buf.length)); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } @@ -470,10 +467,9 @@ public void sizeEnforcingInputStream_skipAboveLimit() throws IOException { new MessageDeframer.SizeEnforcingInputStream(in, 2, statsTraceCtx); try { - thrown.expect(StatusRuntimeException.class); - thrown.expectMessage("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds"); - - stream.skip(4); + StatusRuntimeException e = assertThrows(StatusRuntimeException.class, () -> stream.skip(4)); + assertThat(e).hasMessageThat() + .isEqualTo("RESOURCE_EXHAUSTED: Decompressed gRPC message exceeds maximum size 2"); } finally { stream.close(); } diff --git a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java index 652c94a4640..7394c83eab2 100644 --- a/core/src/test/java/io/grpc/internal/ServerCallImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerCallImplTest.java @@ -16,12 +16,14 @@ package io.grpc.internal; +import static com.google.common.truth.Truth.assertThat; import static io.grpc.internal.GrpcUtil.CONTENT_LENGTH_KEY; import static java.nio.charset.StandardCharsets.UTF_8; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; @@ -54,7 +56,6 @@ import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -64,8 +65,6 @@ @RunWith(JUnit4.class) public class ServerCallImplTest { - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @Mock private ServerStream stream; @@ -175,20 +174,20 @@ public void sendHeader_contentLengthDiscarded() { @Test public void sendHeader_failsOnSecondCall() { call.sendHeaders(new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("sendHeaders has already been called"); - - call.sendHeaders(new Metadata()); + Metadata headers = new Metadata(); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendHeaders(headers)); + assertThat(e).hasMessageThat().isEqualTo("sendHeaders has already been called"); } @Test public void sendHeader_failsOnClosed() { call.close(Status.CANCELLED, new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("call is closed"); - - call.sendHeaders(new Metadata()); + Metadata headers = new Metadata(); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendHeaders(headers)); + assertThat(e).hasMessageThat().isEqualTo("call is closed"); } @Test @@ -204,18 +203,16 @@ public void sendMessage_failsOnClosed() { call.sendHeaders(new Metadata()); call.close(Status.CANCELLED, new Metadata()); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("call is closed"); - - call.sendMessage(1234L); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendMessage(1234L)); + assertThat(e).hasMessageThat().isEqualTo("call is closed"); } @Test public void sendMessage_failsIfheadersUnsent() { - thrown.expect(IllegalStateException.class); - thrown.expectMessage("sendHeaders has not been called"); - - call.sendMessage(1234L); + IllegalStateException e = assertThrows(IllegalStateException.class, + () -> call.sendMessage(1234L)); + assertThat(e).hasMessageThat().isEqualTo("sendHeaders has not been called"); } @Test @@ -490,9 +487,10 @@ public void streamListener_unexpectedRuntimeException() { InputStream inputStream = UNARY_METHOD.streamRequest(1234L); - thrown.expect(RuntimeException.class); - thrown.expectMessage("unexpected exception"); - streamListener.messagesAvailable(new SingleMessageProducer(inputStream)); + SingleMessageProducer producer = new SingleMessageProducer(inputStream); + RuntimeException e = assertThrows(RuntimeException.class, + () -> streamListener.messagesAvailable(producer)); + assertThat(e).hasMessageThat().isEqualTo("unexpected exception"); } private static class LongMarshaller implements Marshaller { diff --git a/core/src/test/java/io/grpc/internal/ServerImplTest.java b/core/src/test/java/io/grpc/internal/ServerImplTest.java index 3125edca1e6..2ddaba751e4 100644 --- a/core/src/test/java/io/grpc/internal/ServerImplTest.java +++ b/core/src/test/java/io/grpc/internal/ServerImplTest.java @@ -26,6 +26,7 @@ import static org.junit.Assert.assertNotSame; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.AdditionalAnswers.delegatesTo; @@ -104,7 +105,6 @@ import org.junit.BeforeClass; import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -140,8 +140,6 @@ public boolean shouldAccept(Runnable runnable) { }; private static final String AUTHORITY = "some_authority"; - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule public final ExpectedException thrown = ExpectedException.none(); @Rule public final MockitoRule mocks = MockitoJUnit.rule(); @BeforeClass @@ -1228,7 +1226,7 @@ public void testStreamClose_deadlineExceededTriggersImmediateCancellation() thro assertFalse(context.get().isCancelled()); assertEquals(1, timer.forwardNanos(1)); - + assertTrue(callReference.get().isCancelled()); assertTrue(context.get().isCancelled()); assertThat(context.get().cancellationCause()).isNotNull(); @@ -1260,9 +1258,8 @@ public List getListenSocketAddresses() { public void getPortBeforeStartedFails() { transportServer = new SimpleServer(); createServer(); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("started"); - server.getPort(); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> server.getPort()); + assertThat(e).hasMessageThat().isEqualTo("Not started"); } @Test @@ -1271,9 +1268,8 @@ public void getPortAfterTerminationFails() throws Exception { createAndStartServer(); server.shutdown(); server.awaitTermination(); - thrown.expect(IllegalStateException.class); - thrown.expectMessage("terminated"); - server.getPort(); + IllegalStateException e = assertThrows(IllegalStateException.class, () -> server.getPort()); + assertThat(e).hasMessageThat().isEqualTo("Already terminated"); } @Test diff --git a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java index 4a518895db6..1f4c2b41f15 100644 --- a/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java +++ b/core/src/testFixtures/java/io/grpc/internal/AbstractTransportTest.java @@ -24,6 +24,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.junit.Assume.assumeTrue; @@ -76,9 +77,7 @@ import java.util.concurrent.TimeoutException; import org.junit.After; import org.junit.Before; -import org.junit.Rule; import org.junit.Test; -import org.junit.rules.ExpectedException; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; import org.mockito.ArgumentCaptor; @@ -209,10 +208,6 @@ public ServerStreamTracer newServerStreamTracer(String fullMethodName, Metadata } })); - @SuppressWarnings("deprecation") // https://github.com/grpc/grpc-java/issues/7467 - @Rule - public ExpectedException thrown = ExpectedException.none(); - @Before public void setUp() { server = newServer(Arrays.asList(serverStreamTracerFactory)); @@ -396,8 +391,7 @@ public void serverAlreadyListening() throws Exception { port = ((InetSocketAddress) addr).getPort(); } InternalServer server2 = newServer(port, Arrays.asList(serverStreamTracerFactory)); - thrown.expect(IOException.class); - server2.start(new MockServerListener()); + assertThrows(IOException.class, () -> server2.start(new MockServerListener())); } @Test From 77216bf876b689f0ea37e643d1995f62bb13931c Mon Sep 17 00:00:00 2001 From: yifeizhuang Date: Fri, 21 Mar 2025 15:19:25 -0700 Subject: [PATCH 08/27] otel tracing: fix span names (#11974) --- .../OpenTelemetryTracingModule.java | 4 ++-- .../OpenTelemetryTracingModuleTest.java | 18 +++++++++--------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java index 838ee0797a7..8c42a189ac2 100644 --- a/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java +++ b/opentelemetry/src/main/java/io/grpc/opentelemetry/OpenTelemetryTracingModule.java @@ -446,7 +446,7 @@ private void recordOutboundMessageSentEvent(Span span, if (optionalWireSize != -1 && optionalWireSize != optionalUncompressedSize) { attributesBuilder.put("message-size-compressed", optionalWireSize); } - span.addEvent("Outbound message sent", attributesBuilder.build()); + span.addEvent("Outbound message", attributesBuilder.build()); } private void recordInboundCompressedMessage(Span span, int seqNo, long optionalWireSize) { @@ -460,7 +460,7 @@ private void recordInboundMessageSize(Span span, int seqNo, long bytes) { AttributesBuilder attributesBuilder = io.opentelemetry.api.common.Attributes.builder(); attributesBuilder.put("sequence-number", seqNo); attributesBuilder.put("message-size", bytes); - span.addEvent("Inbound message received", attributesBuilder.build()); + span.addEvent("Inbound message", attributesBuilder.build()); } private String generateErrorStatusDescription(io.grpc.Status status) { diff --git a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java index b4486bcf2e4..bca6be94b9f 100644 --- a/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java +++ b/opentelemetry/src/test/java/io/grpc/opentelemetry/OpenTelemetryTracingModuleTest.java @@ -231,7 +231,7 @@ public void clientBasicTracingMocking() { List events = eventNameCaptor.getAllValues(); List attributes = attributesCaptor.getAllValues(); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(0)); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -241,7 +241,7 @@ public void clientBasicTracingMocking() { attributes.get(0)); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(1)); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -313,7 +313,7 @@ public void clientBasicTracingRule() { assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); assertEquals( - "Inbound message received" , + "Inbound message" , clientSpanEvents.get(1).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -323,7 +323,7 @@ public void clientBasicTracingRule() { clientSpanEvents.get(1).getAttributes()); assertEquals( - "Inbound message received" , + "Inbound message" , clientSpanEvents.get(2).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -342,7 +342,7 @@ public void clientBasicTracingRule() { assertTrue(clientSpanEvents.get(0).getAttributes().isEmpty()); assertEquals( - "Outbound message sent" , + "Outbound message" , attemptSpanEvents.get(1).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -352,7 +352,7 @@ public void clientBasicTracingRule() { attemptSpanEvents.get(1).getAttributes()); assertEquals( - "Outbound message sent" , + "Outbound message" , attemptSpanEvents.get(2).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -518,7 +518,7 @@ public void serverBasicTracingNoHeaders() { List events = spans.get(0).getEvents(); assertEquals(events.size(), 4); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(0).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -529,7 +529,7 @@ public void serverBasicTracingNoHeaders() { events.get(0).getAttributes()); assertEquals( - "Outbound message sent" , + "Outbound message" , events.get(1).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() @@ -549,7 +549,7 @@ public void serverBasicTracingNoHeaders() { events.get(2).getAttributes()); assertEquals( - "Inbound message received" , + "Inbound message" , events.get(3).getName()); assertEquals( io.opentelemetry.api.common.Attributes.builder() From d28fcb266529b656c2bd50bd2a061f05c861a675 Mon Sep 17 00:00:00 2001 From: Ashley Zhang Date: Fri, 21 Mar 2025 15:19:40 -0700 Subject: [PATCH 09/27] xds: add support for custom per-target credentials on the transport (#11951) --- .../io/grpc/xds/GrpcXdsTransportFactory.java | 54 +++++++++---- .../InternalSharedXdsClientPoolProvider.java | 10 ++- .../grpc/xds/SharedXdsClientPoolProvider.java | 50 ++++++++---- .../grpc/xds/GrpcXdsClientImplTestBase.java | 3 +- .../grpc/xds/GrpcXdsTransportFactoryTest.java | 7 +- .../io/grpc/xds/LoadReportClientTest.java | 14 ++-- .../xds/SharedXdsClientPoolProviderTest.java | 77 +++++++++++++++++++ .../io/grpc/xds/XdsClientFallbackTest.java | 34 +++++--- 8 files changed, 198 insertions(+), 51 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java index 74c28ba2d2d..0da51bf47f7 100644 --- a/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java +++ b/xds/src/main/java/io/grpc/xds/GrpcXdsTransportFactory.java @@ -19,6 +19,7 @@ import static com.google.common.base.Preconditions.checkNotNull; import com.google.common.annotations.VisibleForTesting; +import io.grpc.CallCredentials; import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ClientCall; @@ -34,35 +35,50 @@ final class GrpcXdsTransportFactory implements XdsTransportFactory { - static final GrpcXdsTransportFactory DEFAULT_XDS_TRANSPORT_FACTORY = - new GrpcXdsTransportFactory(); + private final CallCredentials callCredentials; + + GrpcXdsTransportFactory(CallCredentials callCredentials) { + this.callCredentials = callCredentials; + } @Override public XdsTransport create(Bootstrapper.ServerInfo serverInfo) { - return new GrpcXdsTransport(serverInfo); + return new GrpcXdsTransport(serverInfo, callCredentials); } @VisibleForTesting public XdsTransport createForTest(ManagedChannel channel) { - return new GrpcXdsTransport(channel); + return new GrpcXdsTransport(channel, callCredentials); } @VisibleForTesting static class GrpcXdsTransport implements XdsTransport { private final ManagedChannel channel; + private final CallCredentials callCredentials; public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo) { + this(serverInfo, null); + } + + @VisibleForTesting + public GrpcXdsTransport(ManagedChannel channel) { + this(channel, null); + } + + public GrpcXdsTransport(Bootstrapper.ServerInfo serverInfo, CallCredentials callCredentials) { String target = serverInfo.target(); ChannelCredentials channelCredentials = (ChannelCredentials) serverInfo.implSpecificConfig(); this.channel = Grpc.newChannelBuilder(target, channelCredentials) .keepAliveTime(5, TimeUnit.MINUTES) .build(); + this.callCredentials = callCredentials; } @VisibleForTesting - public GrpcXdsTransport(ManagedChannel channel) { + public GrpcXdsTransport(ManagedChannel channel, CallCredentials callCredentials) { this.channel = checkNotNull(channel, "channel"); + this.callCredentials = callCredentials; } @Override @@ -72,7 +88,8 @@ public StreamingCall createStreamingCall( MethodDescriptor.Marshaller respMarshaller) { Context prevContext = Context.ROOT.attach(); try { - return new XdsStreamingCall<>(fullMethodName, reqMarshaller, respMarshaller); + return new XdsStreamingCall<>( + fullMethodName, reqMarshaller, respMarshaller, callCredentials); } finally { Context.ROOT.detach(prevContext); } @@ -89,16 +106,21 @@ private class XdsStreamingCall implements private final ClientCall call; - public XdsStreamingCall(String methodName, MethodDescriptor.Marshaller reqMarshaller, - MethodDescriptor.Marshaller respMarshaller) { - this.call = channel.newCall( - MethodDescriptor.newBuilder() - .setFullMethodName(methodName) - .setType(MethodDescriptor.MethodType.BIDI_STREAMING) - .setRequestMarshaller(reqMarshaller) - .setResponseMarshaller(respMarshaller) - .build(), - CallOptions.DEFAULT); // TODO(zivy): support waitForReady + public XdsStreamingCall( + String methodName, + MethodDescriptor.Marshaller reqMarshaller, + MethodDescriptor.Marshaller respMarshaller, + CallCredentials callCredentials) { + this.call = + channel.newCall( + MethodDescriptor.newBuilder() + .setFullMethodName(methodName) + .setType(MethodDescriptor.MethodType.BIDI_STREAMING) + .setRequestMarshaller(reqMarshaller) + .setResponseMarshaller(respMarshaller) + .build(), + CallOptions.DEFAULT.withCallCredentials( + callCredentials)); // TODO(zivy): support waitForReady } @Override diff --git a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java index 85b59fabfa0..9c98bba93cf 100644 --- a/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/InternalSharedXdsClientPoolProvider.java @@ -16,6 +16,7 @@ package io.grpc.xds; +import io.grpc.CallCredentials; import io.grpc.Internal; import io.grpc.MetricRecorder; import io.grpc.internal.ObjectPool; @@ -42,6 +43,13 @@ public static ObjectPool getOrCreate(String target) public static ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) throws XdsInitializationException { - return SharedXdsClientPoolProvider.getDefaultProvider().getOrCreate(target, metricRecorder); + return getOrCreate(target, metricRecorder, null); + } + + public static ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { + return SharedXdsClientPoolProvider.getDefaultProvider() + .getOrCreate(target, metricRecorder, transportCallCredentials); } } diff --git a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java index 2bc7be4a014..5302880d48c 100644 --- a/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java +++ b/xds/src/main/java/io/grpc/xds/SharedXdsClientPoolProvider.java @@ -17,11 +17,11 @@ package io.grpc.xds; import static com.google.common.base.Preconditions.checkNotNull; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableList; import com.google.errorprone.annotations.concurrent.GuardedBy; +import io.grpc.CallCredentials; import io.grpc.MetricRecorder; import io.grpc.internal.ExponentialBackoffPolicy; import io.grpc.internal.GrpcUtil; @@ -87,6 +87,12 @@ public ObjectPool get(String target) { @Override public ObjectPool getOrCreate(String target, MetricRecorder metricRecorder) throws XdsInitializationException { + return getOrCreate(target, metricRecorder, null); + } + + public ObjectPool getOrCreate( + String target, MetricRecorder metricRecorder, CallCredentials transportCallCredentials) + throws XdsInitializationException { ObjectPool ref = targetToXdsClientMap.get(target); if (ref == null) { synchronized (lock) { @@ -102,7 +108,9 @@ public ObjectPool getOrCreate(String target, MetricRecorder metricRec if (bootstrapInfo.servers().isEmpty()) { throw new XdsInitializationException("No xDS server provided"); } - ref = new RefCountedXdsClientObjectPool(bootstrapInfo, target, metricRecorder); + ref = + new RefCountedXdsClientObjectPool( + bootstrapInfo, target, metricRecorder, transportCallCredentials); targetToXdsClientMap.put(target, ref); } } @@ -126,6 +134,7 @@ class RefCountedXdsClientObjectPool implements ObjectPool { private final BootstrapInfo bootstrapInfo; private final String target; // The target associated with the xDS client. private final MetricRecorder metricRecorder; + private final CallCredentials transportCallCredentials; private final Object lock = new Object(); @GuardedBy("lock") private ScheduledExecutorService scheduler; @@ -137,11 +146,21 @@ class RefCountedXdsClientObjectPool implements ObjectPool { private XdsClientMetricReporterImpl metricReporter; @VisibleForTesting - RefCountedXdsClientObjectPool(BootstrapInfo bootstrapInfo, String target, - MetricRecorder metricRecorder) { + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, String target, MetricRecorder metricRecorder) { + this(bootstrapInfo, target, metricRecorder, null); + } + + @VisibleForTesting + RefCountedXdsClientObjectPool( + BootstrapInfo bootstrapInfo, + String target, + MetricRecorder metricRecorder, + CallCredentials transportCallCredentials) { this.bootstrapInfo = checkNotNull(bootstrapInfo); this.target = target; this.metricRecorder = metricRecorder; + this.transportCallCredentials = transportCallCredentials; } @Override @@ -153,16 +172,19 @@ public XdsClient getObject() { } scheduler = SharedResourceHolder.get(GrpcUtil.TIMER_SERVICE); metricReporter = new XdsClientMetricReporterImpl(metricRecorder, target); - xdsClient = new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, - bootstrapInfo, - scheduler, - BACKOFF_POLICY_PROVIDER, - GrpcUtil.STOPWATCH_SUPPLIER, - TimeProvider.SYSTEM_TIME_PROVIDER, - MessagePrinter.INSTANCE, - new TlsContextManagerImpl(bootstrapInfo), - metricReporter); + GrpcXdsTransportFactory xdsTransportFactory = + new GrpcXdsTransportFactory(transportCallCredentials); + xdsClient = + new XdsClientImpl( + xdsTransportFactory, + bootstrapInfo, + scheduler, + BACKOFF_POLICY_PROVIDER, + GrpcUtil.STOPWATCH_SUPPLIER, + TimeProvider.SYSTEM_TIME_PROVIDER, + MessagePrinter.INSTANCE, + new TlsContextManagerImpl(bootstrapInfo), + metricReporter); metricReporter.setXdsClient(xdsClient); } refCount++; diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java index 51c07cb3537..36131464d08 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplTestBase.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; @@ -4193,7 +4192,7 @@ public void serverFailureMetricReport_forRetryAndBackoff() { private XdsClientImpl createXdsClient(String serverUri) { BootstrapInfo bootstrapInfo = buildBootStrap(serverUri); return new XdsClientImpl( - DEFAULT_XDS_TRANSPORT_FACTORY, + new GrpcXdsTransportFactory(null), bootstrapInfo, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java index 703e429fa23..66e0d4b3198 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsTransportFactoryTest.java @@ -92,9 +92,10 @@ public void onCompleted() { @Test public void callApis() throws Exception { XdsTransportFactory.XdsTransport xdsTransport = - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.create( - Bootstrapper.ServerInfo.create("localhost:" + server.getPort(), - InsecureChannelCredentials.create())); + new GrpcXdsTransportFactory(null) + .create( + Bootstrapper.ServerInfo.create( + "localhost:" + server.getPort(), InsecureChannelCredentials.create())); MethodDescriptor methodDescriptor = AggregatedDiscoveryServiceGrpc.getStreamAggregatedResourcesMethod(); XdsTransportFactory.StreamingCall streamingCall = diff --git a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java index c11a3a6e0d2..9bdf86132b6 100644 --- a/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java +++ b/xds/src/test/java/io/grpc/xds/LoadReportClientTest.java @@ -178,11 +178,15 @@ public void cancelled(Context context) { when(backoffPolicy2.nextBackoffNanos()) .thenReturn(TimeUnit.SECONDS.toNanos(2L), TimeUnit.SECONDS.toNanos(20L)); addFakeStatsData(); - lrsClient = new LoadReportClient(loadStatsManager, - GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY.createForTest(channel), - NODE, - syncContext, fakeClock.getScheduledExecutorService(), backoffPolicyProvider, - fakeClock.getStopwatchSupplier()); + lrsClient = + new LoadReportClient( + loadStatsManager, + new GrpcXdsTransportFactory(null).createForTest(channel), + NODE, + syncContext, + fakeClock.getScheduledExecutorService(), + backoffPolicyProvider, + fakeClock.getStopwatchSupplier()); syncContext.execute(new Runnable() { @Override public void run() { diff --git a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java index 4fb77f0be42..86e4fc83a8c 100644 --- a/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java +++ b/xds/src/test/java/io/grpc/xds/SharedXdsClientPoolProviderTest.java @@ -18,20 +18,36 @@ import static com.google.common.truth.Truth.assertThat; +import static io.grpc.Metadata.ASCII_STRING_MARSHALLER; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyNoMoreInteractions; import static org.mockito.Mockito.when; +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.common.util.concurrent.SettableFuture; +import io.grpc.CallCredentials; +import io.grpc.Grpc; import io.grpc.InsecureChannelCredentials; +import io.grpc.InsecureServerCredentials; +import io.grpc.Metadata; import io.grpc.MetricRecorder; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.auth.MoreCallCredentials; import io.grpc.internal.ObjectPool; import io.grpc.xds.SharedXdsClientPoolProvider.RefCountedXdsClientObjectPool; +import io.grpc.xds.XdsListenerResource.LdsUpdate; import io.grpc.xds.client.Bootstrapper.BootstrapInfo; import io.grpc.xds.client.Bootstrapper.ServerInfo; import io.grpc.xds.client.EnvoyProtoData.Node; import io.grpc.xds.client.XdsClient; +import io.grpc.xds.client.XdsClient.ResourceWatcher; import io.grpc.xds.client.XdsInitializationException; import java.util.Collections; +import java.util.concurrent.TimeUnit; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -54,9 +70,12 @@ public class SharedXdsClientPoolProviderTest { private final Node node = Node.newBuilder().setId("SharedXdsClientPoolProviderTest").build(); private final MetricRecorder metricRecorder = new MetricRecorder() {}; private static final String DUMMY_TARGET = "dummy"; + static final Metadata.Key AUTHORIZATION_METADATA_KEY = + Metadata.Key.of("Authorization", ASCII_STRING_MARSHALLER); @Mock private GrpcBootstrapperImpl bootstrapper; + @Mock private ResourceWatcher ldsResourceWatcher; @Test public void noServer() throws XdsInitializationException { @@ -138,4 +157,62 @@ public void refCountedXdsClientObjectPool_getObjectCreatesNewInstanceIfAlreadySh assertThat(xdsClient2).isNotSameInstanceAs(xdsClient1); xdsClientPool.returnObject(xdsClient2); } + + private class CallCredsServerInterceptor implements ServerInterceptor { + private SettableFuture tokenFuture = SettableFuture.create(); + + @Override + public ServerCall.Listener interceptCall( + ServerCall serverCall, + Metadata metadata, + ServerCallHandler next) { + tokenFuture.set(metadata.get(AUTHORIZATION_METADATA_KEY)); + return next.startCall(serverCall, metadata); + } + + public String getTokenWithTimeout(long timeout, TimeUnit unit) throws Exception { + return tokenFuture.get(timeout, unit); + } + } + + @Test + public void xdsClient_usesCallCredentials() throws Exception { + // Set up fake xDS server + XdsTestControlPlaneService fakeXdsService = new XdsTestControlPlaneService(); + CallCredsServerInterceptor callCredentialsInterceptor = new CallCredsServerInterceptor(); + Server xdsServer = + Grpc.newServerBuilderForPort(0, InsecureServerCredentials.create()) + .addService(fakeXdsService) + .intercept(callCredentialsInterceptor) + .build() + .start(); + String xdsServerUri = "localhost:" + xdsServer.getPort(); + + // Set up bootstrap & xDS client pool provider + ServerInfo server = ServerInfo.create(xdsServerUri, InsecureChannelCredentials.create()); + BootstrapInfo bootstrapInfo = + BootstrapInfo.builder().servers(Collections.singletonList(server)).node(node).build(); + when(bootstrapper.bootstrap()).thenReturn(bootstrapInfo); + SharedXdsClientPoolProvider provider = new SharedXdsClientPoolProvider(bootstrapper); + + // Create custom xDS transport CallCredentials + CallCredentials sampleCreds = + MoreCallCredentials.from( + OAuth2Credentials.create(new AccessToken("token", /* expirationTime= */ null))); + + // Create xDS client that uses the CallCredentials on the transport + ObjectPool xdsClientPool = + provider.getOrCreate("target", metricRecorder, sampleCreds); + XdsClient xdsClient = xdsClientPool.getObject(); + xdsClient.watchXdsResource( + XdsListenerResource.getInstance(), "someLDSresource", ldsResourceWatcher); + + // Wait for xDS server to get the request and verify that it received the CallCredentials + assertThat(callCredentialsInterceptor.getTokenWithTimeout(5, TimeUnit.SECONDS)) + .isEqualTo("Bearer token"); + + // Clean up + xdsClientPool.returnObject(xdsClient); + xdsServer.shutdownNow(); + } } diff --git a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java index 97c2695f209..036b9f6f55d 100644 --- a/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsClientFallbackTest.java @@ -18,7 +18,6 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; -import static io.grpc.xds.GrpcXdsTransportFactory.DEFAULT_XDS_TRANSPORT_FACTORY; import static org.mockito.AdditionalAnswers.delegatesTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -442,9 +441,14 @@ public void fallbackFromBadUrlToGoodOne() { String garbageUri = "some. garbage"; String validUri = "localhost:" + mainXdsServer.getServer().getPort(); - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(garbageUri, validUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri, validUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); @@ -462,9 +466,14 @@ public void testGoodUrlFollowedByBadUrl() { String garbageUri = "some. garbage"; String validUri = "localhost:" + mainXdsServer.getServer().getPort(); - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(validUri, garbageUri), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(validUri, garbageUri), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); verify(ldsWatcher, timeout(5000)).onChanged( @@ -481,9 +490,14 @@ public void testTwoBadUrl() { String garbageUri1 = "some. garbage"; String garbageUri2 = "other garbage"; - XdsClientImpl client = CommonBootstrapperTestUtils.createXdsClient( - Arrays.asList(garbageUri1, garbageUri2), DEFAULT_XDS_TRANSPORT_FACTORY, fakeClock, - new ExponentialBackoffPolicy.Provider(), MessagePrinter.INSTANCE, xdsClientMetricReporter); + XdsClientImpl client = + CommonBootstrapperTestUtils.createXdsClient( + Arrays.asList(garbageUri1, garbageUri2), + new GrpcXdsTransportFactory(null), + fakeClock, + new ExponentialBackoffPolicy.Provider(), + MessagePrinter.INSTANCE, + xdsClientMetricReporter); client.watchXdsResource(XdsListenerResource.getInstance(), MAIN_SERVER, ldsWatcher); fakeClock.forwardTime(20, TimeUnit.SECONDS); From eeded6fcc9b657adedc5a4d75651453955bf291b Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Mon, 24 Mar 2025 21:32:53 +0000 Subject: [PATCH 10/27] core: Log any exception during panic because of exception panic() calls a good amount of code, so it could get another exception. The SynchronizationContext is running on an arbitrary thread and we don't want to propagate this secondary exception up its stack (to be handled by its UncaughtExceptionHandler); it we wanted that we'd propagate the original exception. This second exception will only be seen in the logs; the first exception was logged and will be used to fail RPCs. Also related to http://yaqs/8493785598685872128 and b692b9d26 --- .../src/main/java/io/grpc/internal/ManagedChannelImpl.java | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java index dd4d0aef5be..1b51c2dbb32 100644 --- a/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java +++ b/core/src/main/java/io/grpc/internal/ManagedChannelImpl.java @@ -187,7 +187,12 @@ public void uncaughtException(Thread t, Throwable e) { Level.SEVERE, "[" + getLogId() + "] Uncaught exception in the SynchronizationContext. Panic!", e); - panic(e); + try { + panic(e); + } catch (Throwable anotherT) { + logger.log( + Level.SEVERE, "[" + getLogId() + "] Uncaught exception while panicking", anotherT); + } } }); From 036cd413430c8ef3c522934c5a4b8499aa01d0d3 Mon Sep 17 00:00:00 2001 From: jiangyuan Date: Tue, 25 Mar 2025 20:12:28 +0800 Subject: [PATCH 11/27] services: Avoid cancellation exceptions when notifying watchers that already have their connections cancelled (#11934) Some clients watching health status can cancel their watch and `HealthService` when trying to notify these watchers were getting CANCELLED exception because there was no cancellation handler set on the `StreamObserver`. This change sets the cancellation handler that removes the watcher from the set of watcher clients to be notified of the health status. --- .../protobuf/services/HealthServiceImpl.java | 30 ++++++++++++------- .../services/HealthStatusManagerTest.java | 18 +++++++++++ 2 files changed, 38 insertions(+), 10 deletions(-) diff --git a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java index 2efe4b3951a..5cd294b4fbe 100644 --- a/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java +++ b/services/src/main/java/io/grpc/protobuf/services/HealthServiceImpl.java @@ -27,6 +27,7 @@ import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import java.util.HashMap; import java.util.IdentityHashMap; @@ -83,6 +84,11 @@ public void watch(HealthCheckRequest request, final StreamObserver responseObserver) { final String service = request.getService(); synchronized (watchLock) { + if (responseObserver instanceof ServerCallStreamObserver) { + ((ServerCallStreamObserver) responseObserver).setOnCancelHandler(() -> { + removeWatcher(service, responseObserver); + }); + } ServingStatus status = statusMap.get(service); responseObserver.onNext(getResponseForWatch(status)); IdentityHashMap, Boolean> serviceWatchers = @@ -98,21 +104,25 @@ public void watch(HealthCheckRequest request, @Override // Called when the client has closed the stream public void cancelled(Context context) { - synchronized (watchLock) { - IdentityHashMap, Boolean> serviceWatchers = - watchers.get(service); - if (serviceWatchers != null) { - serviceWatchers.remove(responseObserver); - if (serviceWatchers.isEmpty()) { - watchers.remove(service); - } - } - } + removeWatcher(service, responseObserver); } }, MoreExecutors.directExecutor()); } + void removeWatcher(String service, StreamObserver responseObserver) { + synchronized (watchLock) { + IdentityHashMap, Boolean> serviceWatchers = + watchers.get(service); + if (serviceWatchers != null) { + serviceWatchers.remove(responseObserver); + if (serviceWatchers.isEmpty()) { + watchers.remove(service); + } + } + } + } + void setStatus(String service, ServingStatus status) { synchronized (watchLock) { if (terminal) { diff --git a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java index 87d4ac29be8..b2652e92771 100644 --- a/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java +++ b/services/src/test/java/io/grpc/protobuf/services/HealthStatusManagerTest.java @@ -18,6 +18,11 @@ import static com.google.common.truth.Truth.assertThat; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import io.grpc.BindableService; import io.grpc.Context; @@ -28,6 +33,7 @@ import io.grpc.health.v1.HealthCheckResponse; import io.grpc.health.v1.HealthCheckResponse.ServingStatus; import io.grpc.health.v1.HealthGrpc; +import io.grpc.stub.ServerCallStreamObserver; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcServerRule; import java.util.ArrayDeque; @@ -109,6 +115,18 @@ public void enterTerminalState_watch() throws Exception { assertThat(obs.responses).isEmpty(); } + @Test + @SuppressWarnings("unchecked") + public void serverCallStreamObserver_watch() throws Exception { + manager.setStatus(SERVICE1, ServingStatus.SERVING); + ServerCallStreamObserver observer = mock(ServerCallStreamObserver.class); + service.watch(HealthCheckRequest.newBuilder().setService(SERVICE1).build(), observer); + + verify(observer, times(1)) + .onNext(eq(HealthCheckResponse.newBuilder().setStatus(ServingStatus.SERVING).build())); + verify(observer, times(1)).setOnCancelHandler(any(Runnable.class)); + } + @Test public void enterTerminalState_ignoreClear() throws Exception { manager.setStatus(SERVICE1, ServingStatus.SERVING); From 61bb8784f64f32419ab3149f5a30e566e366eccb Mon Sep 17 00:00:00 2001 From: Abhishek Agrawal <81427947+AgraVator@users.noreply.github.com> Date: Wed, 26 Mar 2025 06:13:05 +0000 Subject: [PATCH 12/27] fix: cleans up FileWatcherCertificateProvider in XdsSecurityClientServerTest --- .../io/grpc/xds/XdsSecurityClientServerTest.java | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java index cd3ef293369..380c0591812 100644 --- a/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java +++ b/xds/src/test/java/io/grpc/xds/XdsSecurityClientServerTest.java @@ -130,6 +130,7 @@ public class XdsSecurityClientServerTest { private FakeXdsClient xdsClient = new FakeXdsClient(); private FakeXdsClientPoolFactory fakePoolFactory = new FakeXdsClientPoolFactory(xdsClient); private static final String OVERRIDE_AUTHORITY = "foo.test.google.fr"; + private Attributes sslContextAttributes; @Parameters(name = "enableSpiffe={0}") public static Collection data() { @@ -152,6 +153,14 @@ public void tearDown() throws IOException { NameResolverRegistry.getDefaultRegistry().deregister(fakeNameResolverFactory); } FileWatcherCertificateProviderProvider.enableSpiffe = originalEnableSpiffe; + if (sslContextAttributes != null) { + SslContextProviderSupplier sslContextProviderSupplier = sslContextAttributes.get( + SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER); + if (sslContextProviderSupplier != null) { + sslContextProviderSupplier.close(); + } + sslContextAttributes = null; + } } @Test @@ -651,7 +660,7 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( InetSocketAddress socketAddress = new InetSocketAddress(Inet4Address.getLoopbackAddress(), port); tlsContextManagerForClient = new TlsContextManagerImpl(bootstrapInfoForClient); - Attributes attrs = + sslContextAttributes = (upstreamTlsContext != null) ? Attributes.newBuilder() .set(SecurityProtocolNegotiators.ATTR_SSL_CONTEXT_PROVIDER_SUPPLIER, @@ -660,7 +669,7 @@ private SimpleServiceGrpc.SimpleServiceBlockingStub getBlockingStub( .build() : Attributes.EMPTY; fakeNameResolverFactory.setServers( - ImmutableList.of(new EquivalentAddressGroup(socketAddress, attrs))); + ImmutableList.of(new EquivalentAddressGroup(socketAddress, sslContextAttributes))); return SimpleServiceGrpc.newBlockingStub(cleanupRule.register(channelBuilder.build())); } From 3a35a76acf42f24bb3c0be1f891b326152f1a21f Mon Sep 17 00:00:00 2001 From: Alex Panchenko <440271+panchenko@users.noreply.github.com> Date: Wed, 26 Mar 2025 10:19:21 +0200 Subject: [PATCH 13/27] core: Use java.time.Time.getNano in InstantTimeProvider without reflection (#11977) Fixes #11975 --- .../io/grpc/internal/InstantTimeProvider.java | 30 ++++--------------- .../internal/TimeProviderResolverFactory.java | 4 +-- .../internal/InstantTimeProviderTest.java | 3 +- 3 files changed, 9 insertions(+), 28 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/InstantTimeProvider.java b/core/src/main/java/io/grpc/internal/InstantTimeProvider.java index 38c840d2594..12996163753 100644 --- a/core/src/main/java/io/grpc/internal/InstantTimeProvider.java +++ b/core/src/main/java/io/grpc/internal/InstantTimeProvider.java @@ -18,37 +18,19 @@ import static com.google.common.math.LongMath.saturatedAdd; -import java.lang.reflect.InvocationTargetException; -import java.lang.reflect.Method; +import java.time.Instant; import java.util.concurrent.TimeUnit; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; /** * {@link InstantTimeProvider} resolves InstantTimeProvider which implements {@link TimeProvider}. */ final class InstantTimeProvider implements TimeProvider { - private Method now; - private Method getNano; - private Method getEpochSecond; - - public InstantTimeProvider(Class instantClass) { - try { - this.now = instantClass.getMethod("now"); - this.getNano = instantClass.getMethod("getNano"); - this.getEpochSecond = instantClass.getMethod("getEpochSecond"); - } catch (NoSuchMethodException ex) { - throw new AssertionError(ex); - } - } - @Override + @IgnoreJRERequirement public long currentTimeNanos() { - try { - Object instant = now.invoke(null); - int nanos = (int) getNano.invoke(instant); - long epochSeconds = (long) getEpochSecond.invoke(instant); - return saturatedAdd(TimeUnit.SECONDS.toNanos(epochSeconds), nanos); - } catch (IllegalAccessException | InvocationTargetException ex) { - throw new RuntimeException(ex); - } + Instant now = Instant.now(); + long epochSeconds = now.getEpochSecond(); + return saturatedAdd(TimeUnit.SECONDS.toNanos(epochSeconds), now.getNano()); } } diff --git a/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java b/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java index d88d9bb9eb5..04272034ce9 100644 --- a/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java +++ b/core/src/main/java/io/grpc/internal/TimeProviderResolverFactory.java @@ -23,8 +23,8 @@ final class TimeProviderResolverFactory { static TimeProvider resolveTimeProvider() { try { - Class instantClass = Class.forName("java.time.Instant"); - return new InstantTimeProvider(instantClass); + Class.forName("java.time.Instant"); + return new InstantTimeProvider(); } catch (ClassNotFoundException ex) { return new ConcurrentTimeProvider(); } diff --git a/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java b/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java index ac9a02fa936..6702bc421a5 100644 --- a/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java +++ b/core/src/test/java/io/grpc/internal/InstantTimeProviderTest.java @@ -34,8 +34,7 @@ public class InstantTimeProviderTest { @Test public void testInstantCurrentTimeNanos() throws Exception { - InstantTimeProvider instantTimeProvider = new InstantTimeProvider( - Class.forName("java.time.Instant")); + InstantTimeProvider instantTimeProvider = new InstantTimeProvider(); // Get the current time from the InstantTimeProvider long actualTimeNanos = instantTimeProvider.currentTimeNanos(); From fa77210a65d0e64ba253dcf182aa4b5d58985e77 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Thu, 27 Mar 2025 13:52:30 -0700 Subject: [PATCH 14/27] util: Graceful switch to new LB when leaving CONNECTING Previously it would wait for the new LB to enter READY. However, that prevents there being an upper-bound on how long the old policy will continue to be used. The point of graceful switch is to avoid RPCs seeing increased latency when we swap config. We don't want it to prevent the system from becoming eventually consistent. --- .../grpc/util/GracefulSwitchLoadBalancer.java | 5 ++-- .../util/GracefulSwitchLoadBalancerTest.java | 26 +++++++++++++++---- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java index 72c41886ad4..e36eec1ff25 100644 --- a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java @@ -38,7 +38,8 @@ /** * A load balancer that gracefully swaps to a new lb policy. If the channel is currently in a state * other than READY, the new policy will be swapped into place immediately. Otherwise, the channel - * will keep using the old policy until the new policy reports READY or the old policy exits READY. + * will keep using the old policy until the new policy leaves CONNECTING or the old policy exits + * READY. * *

    The child balancer and configuration is specified using service config. Config objects are * generally created by calling {@link #parseLoadBalancingPolicyConfig(List)} from a @@ -147,7 +148,7 @@ public void updateBalancingState(ConnectivityState newState, SubchannelPicker ne checkState(currentLbIsReady, "there's pending lb while current lb has been out of READY"); pendingState = newState; pendingPicker = newPicker; - if (newState == ConnectivityState.READY) { + if (newState != ConnectivityState.CONNECTING) { swap(); } } else if (lb == currentLb) { diff --git a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java index 9a4f569c144..5192d6a2a64 100644 --- a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java @@ -18,6 +18,7 @@ import static com.google.common.truth.Truth.assertThat; import static io.grpc.ConnectivityState.CONNECTING; +import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; import static io.grpc.util.GracefulSwitchLoadBalancer.BUFFER_PICKER; @@ -32,6 +33,7 @@ import static org.mockito.Mockito.when; import com.google.common.testing.EqualsTester; +import io.grpc.ConnectivityState; import io.grpc.ConnectivityStateInfo; import io.grpc.EquivalentAddressGroup; import io.grpc.LoadBalancer; @@ -363,7 +365,21 @@ public void createSubchannelForwarded() { } @Test - public void updateBalancingStateIsGraceful() { + public void updateBalancingStateIsGraceful_Ready() { + updateBalancingStateIsGraceful(READY); + } + + @Test + public void updateBalancingStateIsGraceful_TransientFailure() { + updateBalancingStateIsGraceful(TRANSIENT_FAILURE); + } + + @Test + public void updateBalancingStateIsGraceful_Idle() { + updateBalancingStateIsGraceful(IDLE); + } + + public void updateBalancingStateIsGraceful(ConnectivityState swapsOnState) { assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() .setLoadBalancingPolicyConfig(createConfig(lbPolicies[0], new Object())) .build())); @@ -392,11 +408,11 @@ public void updateBalancingStateIsGraceful() { helper2.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - // lb2 reports READY + // lb2 reports swapsOnState SubchannelPicker picker2 = mock(SubchannelPicker.class); - helper2.updateBalancingState(READY, picker2); + helper2.updateBalancingState(swapsOnState, picker2); verify(lb0).shutdown(); - verify(mockHelper).updateBalancingState(READY, picker2); + verify(mockHelper).updateBalancingState(swapsOnState, picker2); assertIsOk(gracefulSwitchLb.acceptResolvedAddresses(addressesBuilder() .setLoadBalancingPolicyConfig(createConfig(lbPolicies[3], new Object())) @@ -407,7 +423,7 @@ public void updateBalancingStateIsGraceful() { helper3.updateBalancingState(CONNECTING, picker3); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker3); - // lb2 out of READY + // lb2 out of swapsOnState picker2 = mock(SubchannelPicker.class); helper2.updateBalancingState(CONNECTING, picker2); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker2); From 129e7479de7e252d89a42b346f8032b02fa71ea5 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Fri, 28 Mar 2025 19:49:36 +0000 Subject: [PATCH 15/27] util: Replace BUFFER_PICKER with FixedResultPicker I think at some point there were more usages in the tests. But now it is pretty easy. PriorityLb.ChildLbState.picker is initialized to FixedResultPicker(NoResult). So now that GracefulSwitchLb is using the same picker, equals() is able to de-dup an update. --- .../io/grpc/util/GracefulSwitchLoadBalancer.java | 16 +--------------- .../util/GracefulSwitchLoadBalancerTest.java | 7 +++++-- .../io/grpc/xds/PriorityLoadBalancerTest.java | 9 ++++----- 3 files changed, 10 insertions(+), 22 deletions(-) diff --git a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java index e36eec1ff25..1dc4fb6750a 100644 --- a/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/GracefulSwitchLoadBalancer.java @@ -19,7 +19,6 @@ import static com.google.common.base.Preconditions.checkNotNull; import static com.google.common.base.Preconditions.checkState; -import com.google.common.annotations.VisibleForTesting; import com.google.common.base.MoreObjects; import com.google.common.base.Objects; import io.grpc.ConnectivityState; @@ -66,19 +65,6 @@ public void handleNameResolutionError(final Status error) { public void shutdown() {} }; - @VisibleForTesting - static final SubchannelPicker BUFFER_PICKER = new SubchannelPicker() { - @Override - public PickResult pickSubchannel(PickSubchannelArgs args) { - return PickResult.withNoResult(); - } - - @Override - public String toString() { - return "BUFFER_PICKER"; - } - }; - private final Helper helper; // While the new policy is not fully switched on, the pendingLb is handling new updates from name @@ -128,7 +114,7 @@ private void switchToInternal(LoadBalancer.Factory newBalancerFactory) { pendingLb = defaultBalancer; pendingBalancerFactory = null; pendingState = ConnectivityState.CONNECTING; - pendingPicker = BUFFER_PICKER; + pendingPicker = new FixedResultPicker(PickResult.withNoResult()); if (newBalancerFactory.equals(currentBalancerFactory)) { return; diff --git a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java index 5192d6a2a64..843e16194c5 100644 --- a/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/GracefulSwitchLoadBalancerTest.java @@ -21,7 +21,6 @@ import static io.grpc.ConnectivityState.IDLE; import static io.grpc.ConnectivityState.READY; import static io.grpc.ConnectivityState.TRANSIENT_FAILURE; -import static io.grpc.util.GracefulSwitchLoadBalancer.BUFFER_PICKER; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.inOrder; @@ -521,7 +520,11 @@ public void switchWhileOldPolicyGoesFromReadyToNotReadyWhileNewPolicyStillIdle() helper0.updateBalancingState(CONNECTING, picker); verify(mockHelper, never()).updateBalancingState(CONNECTING, picker); - inOrder.verify(mockHelper).updateBalancingState(CONNECTING, BUFFER_PICKER); + ArgumentCaptor pickerCaptor = ArgumentCaptor.forClass(SubchannelPicker.class); + inOrder.verify(mockHelper).updateBalancingState(eq(CONNECTING), pickerCaptor.capture()); + assertThat(pickerCaptor.getValue().pickSubchannel(mock(PickSubchannelArgs.class)).hasResult()) + .isFalse(); + inOrder.verify(lb0).shutdown(); // shutdown after update picker = mock(SubchannelPicker.class); diff --git a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java index 9823501dcd9..08d4863d194 100644 --- a/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java +++ b/xds/src/test/java/io/grpc/xds/PriorityLoadBalancerTest.java @@ -522,8 +522,7 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); // Nothing important about this verify, other than to provide a baseline - verify(helper, times(2)) - .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); + verify(helper).updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); assertThat(fooBalancers).hasSize(1); assertThat(fooHelpers).hasSize(1); Helper helper0 = Iterables.getOnlyElement(fooHelpers); @@ -539,7 +538,7 @@ public void connectingResetFailOverIfSeenReadyOrIdleSinceTransientFailure() { helper0.updateBalancingState( CONNECTING, EMPTY_PICKER); - verify(helper, times(3)) + verify(helper, times(2)) .updateBalancingState(eq(CONNECTING), pickerReturns(PickResult.withNoResult())); // failover happens @@ -805,7 +804,7 @@ public void raceBetweenShutdownAndChildLbBalancingStateUpdate() { .setAddresses(ImmutableList.of()) .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper, times(2)).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); + verify(helper).updateBalancingState(eq(CONNECTING), isA(SubchannelPicker.class)); // LB shutdown and subchannel state change can happen simultaneously. If shutdown runs first, // any further balancing state update should be ignored. @@ -843,7 +842,7 @@ public void noDuplicateOverallBalancingStateUpdate() { .setLoadBalancingPolicyConfig(priorityLbConfig) .build()); - verify(helper, times(6)).updateBalancingState(any(), any()); + verify(helper, times(4)).updateBalancingState(any(), any()); } private void assertLatestConnectivityState(ConnectivityState expectedState) { From f828a3c74de58ed0a0f924827562320a381b6c87 Mon Sep 17 00:00:00 2001 From: Abhishek Agrawal <81427947+AgraVator@users.noreply.github.com> Date: Tue, 1 Apr 2025 10:57:39 +0000 Subject: [PATCH 16/27] Start 1.73.0 development cycle (#11987) --- MODULE.bazel | 2 +- build.gradle | 2 +- .../src/test/golden/TestDeprecatedService.java.txt | 2 +- compiler/src/test/golden/TestService.java.txt | 2 +- core/src/main/java/io/grpc/internal/GrpcUtil.java | 2 +- examples/MODULE.bazel | 2 +- examples/android/clientcache/app/build.gradle | 10 +++++----- examples/android/helloworld/app/build.gradle | 8 ++++---- examples/android/routeguide/app/build.gradle | 8 ++++---- examples/android/strictmode/app/build.gradle | 8 ++++---- examples/build.gradle | 2 +- examples/example-alts/build.gradle | 2 +- examples/example-debug/build.gradle | 2 +- examples/example-debug/pom.xml | 4 ++-- examples/example-dualstack/build.gradle | 2 +- examples/example-dualstack/pom.xml | 4 ++-- examples/example-gauth/build.gradle | 2 +- examples/example-gauth/pom.xml | 4 ++-- examples/example-gcp-csm-observability/build.gradle | 2 +- examples/example-gcp-observability/build.gradle | 2 +- examples/example-hostname/build.gradle | 2 +- examples/example-hostname/pom.xml | 4 ++-- examples/example-jwt-auth/build.gradle | 2 +- examples/example-jwt-auth/pom.xml | 4 ++-- examples/example-oauth/build.gradle | 2 +- examples/example-oauth/pom.xml | 4 ++-- examples/example-opentelemetry/build.gradle | 2 +- examples/example-orca/build.gradle | 2 +- examples/example-reflection/build.gradle | 2 +- examples/example-servlet/build.gradle | 2 +- examples/example-tls/build.gradle | 2 +- examples/example-tls/pom.xml | 4 ++-- examples/example-xds/build.gradle | 2 +- examples/pom.xml | 4 ++-- 34 files changed, 55 insertions(+), 55 deletions(-) diff --git a/MODULE.bazel b/MODULE.bazel index 88f3a524060..d69bb1927aa 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -2,7 +2,7 @@ module( name = "grpc-java", compatibility_level = 0, repo_name = "io_grpc_grpc_java", - version = "1.72.0-SNAPSHOT", # CURRENT_GRPC_VERSION + version = "1.73.0-SNAPSHOT", # CURRENT_GRPC_VERSION ) # GRPC_DEPS_START diff --git a/build.gradle b/build.gradle index 93ce60054bc..42e16fbfef9 100644 --- a/build.gradle +++ b/build.gradle @@ -21,7 +21,7 @@ subprojects { apply plugin: "net.ltgt.errorprone" group = "io.grpc" - version = "1.72.0-SNAPSHOT" // CURRENT_GRPC_VERSION + version = "1.73.0-SNAPSHOT" // CURRENT_GRPC_VERSION repositories { maven { // The google mirror is less flaky than mavenCentral() diff --git a/compiler/src/test/golden/TestDeprecatedService.java.txt b/compiler/src/test/golden/TestDeprecatedService.java.txt index 82fe6a81d18..a870a5a82e1 100644 --- a/compiler/src/test/golden/TestDeprecatedService.java.txt +++ b/compiler/src/test/golden/TestDeprecatedService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.72.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.73.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated @java.lang.Deprecated diff --git a/compiler/src/test/golden/TestService.java.txt b/compiler/src/test/golden/TestService.java.txt index 912bd50da12..1f62ca26718 100644 --- a/compiler/src/test/golden/TestService.java.txt +++ b/compiler/src/test/golden/TestService.java.txt @@ -8,7 +8,7 @@ import static io.grpc.MethodDescriptor.generateFullMethodName; * */ @javax.annotation.Generated( - value = "by gRPC proto compiler (version 1.72.0-SNAPSHOT)", + value = "by gRPC proto compiler (version 1.73.0-SNAPSHOT)", comments = "Source: grpc/testing/compiler/test.proto") @io.grpc.stub.annotations.GrpcGenerated public final class TestServiceGrpc { diff --git a/core/src/main/java/io/grpc/internal/GrpcUtil.java b/core/src/main/java/io/grpc/internal/GrpcUtil.java index 937854ac3ff..15fcdfb6300 100644 --- a/core/src/main/java/io/grpc/internal/GrpcUtil.java +++ b/core/src/main/java/io/grpc/internal/GrpcUtil.java @@ -219,7 +219,7 @@ public byte[] parseAsciiString(byte[] serialized) { public static final Splitter ACCEPT_ENCODING_SPLITTER = Splitter.on(',').trimResults(); - public static final String IMPLEMENTATION_VERSION = "1.72.0-SNAPSHOT"; // CURRENT_GRPC_VERSION + public static final String IMPLEMENTATION_VERSION = "1.73.0-SNAPSHOT"; // CURRENT_GRPC_VERSION /** * The default timeout in nanos for a keepalive ping request. diff --git a/examples/MODULE.bazel b/examples/MODULE.bazel index 4d72ae9c395..b960555e8c5 100644 --- a/examples/MODULE.bazel +++ b/examples/MODULE.bazel @@ -1,5 +1,5 @@ bazel_dep(name = "googleapis", repo_name = "com_google_googleapis", version = "0.0.0-20240326-1c8d509c5") -bazel_dep(name = "grpc-java", repo_name = "io_grpc_grpc_java", version = "1.72.0-SNAPSHOT") # CURRENT_GRPC_VERSION +bazel_dep(name = "grpc-java", repo_name = "io_grpc_grpc_java", version = "1.73.0-SNAPSHOT") # CURRENT_GRPC_VERSION bazel_dep(name = "grpc-proto", repo_name = "io_grpc_grpc_proto", version = "0.0.0-20240627-ec30f58") bazel_dep(name = "protobuf", repo_name = "com_google_protobuf", version = "23.1") bazel_dep(name = "rules_jvm_external", version = "6.0") diff --git a/examples/android/clientcache/app/build.gradle b/examples/android/clientcache/app/build.gradle index 1f2a17ae6bb..670193167fe 100644 --- a/examples/android/clientcache/app/build.gradle +++ b/examples/android/clientcache/app/build.gradle @@ -34,7 +34,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -54,12 +54,12 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' testImplementation 'junit:junit:4.13.2' testImplementation 'com.google.truth:truth:1.1.5' - testImplementation 'io.grpc:grpc-testing:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + testImplementation 'io.grpc:grpc-testing:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION } diff --git a/examples/android/helloworld/app/build.gradle b/examples/android/helloworld/app/build.gradle index 09b994a4954..81f79c67440 100644 --- a/examples/android/helloworld/app/build.gradle +++ b/examples/android/helloworld/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/routeguide/app/build.gradle b/examples/android/routeguide/app/build.gradle index bdad129845b..12e6430d2ad 100644 --- a/examples/android/routeguide/app/build.gradle +++ b/examples/android/routeguide/app/build.gradle @@ -32,7 +32,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -52,8 +52,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/android/strictmode/app/build.gradle b/examples/android/strictmode/app/build.gradle index f38110a741b..f37e970ed7b 100644 --- a/examples/android/strictmode/app/build.gradle +++ b/examples/android/strictmode/app/build.gradle @@ -33,7 +33,7 @@ android { protobuf { protoc { artifact = 'com.google.protobuf:protoc:3.25.1' } plugins { - grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + grpc { artifact = 'io.grpc:protoc-gen-grpc-java:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION } } generateProtoTasks { @@ -53,8 +53,8 @@ dependencies { implementation 'androidx.appcompat:appcompat:1.0.0' // You need to build grpc-java to obtain these libraries below. - implementation 'io.grpc:grpc-okhttp:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-protobuf-lite:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION - implementation 'io.grpc:grpc-stub:1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-okhttp:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-protobuf-lite:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION + implementation 'io.grpc:grpc-stub:1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION implementation 'org.apache.tomcat:annotations-api:6.0.53' } diff --git a/examples/build.gradle b/examples/build.gradle index e807d09f407..206fd38e0e3 100644 --- a/examples/build.gradle +++ b/examples/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' def protocVersion = protobufVersion diff --git a/examples/example-alts/build.gradle b/examples/example-alts/build.gradle index e7142f3fb5a..c7d804973a5 100644 --- a/examples/example-alts/build.gradle +++ b/examples/example-alts/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/example-debug/build.gradle b/examples/example-debug/build.gradle index 1d07a7cb8ec..c2449833cdc 100644 --- a/examples/example-debug/build.gradle +++ b/examples/example-debug/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' dependencies { diff --git a/examples/example-debug/pom.xml b/examples/example-debug/pom.xml index 00fdecdb6c4..ce46b13c019 100644 --- a/examples/example-debug/pom.xml +++ b/examples/example-debug/pom.xml @@ -6,13 +6,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-debug https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 1.8 diff --git a/examples/example-dualstack/build.gradle b/examples/example-dualstack/build.gradle index b73902095a1..7a37c46b536 100644 --- a/examples/example-dualstack/build.gradle +++ b/examples/example-dualstack/build.gradle @@ -23,7 +23,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' dependencies { diff --git a/examples/example-dualstack/pom.xml b/examples/example-dualstack/pom.xml index 28539851934..4af44beae46 100644 --- a/examples/example-dualstack/pom.xml +++ b/examples/example-dualstack/pom.xml @@ -6,13 +6,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-dualstack https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 1.8 diff --git a/examples/example-gauth/build.gradle b/examples/example-gauth/build.gradle index 0ef0dcaefe2..404cab907de 100644 --- a/examples/example-gauth/build.gradle +++ b/examples/example-gauth/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' def protocVersion = protobufVersion diff --git a/examples/example-gauth/pom.xml b/examples/example-gauth/pom.xml index d2a32578550..b8f8d4d930e 100644 --- a/examples/example-gauth/pom.xml +++ b/examples/example-gauth/pom.xml @@ -6,13 +6,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-gauth https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 1.8 diff --git a/examples/example-gcp-csm-observability/build.gradle b/examples/example-gcp-csm-observability/build.gradle index e16e32e3bc1..dcb1f254263 100644 --- a/examples/example-gcp-csm-observability/build.gradle +++ b/examples/example-gcp-csm-observability/build.gradle @@ -22,7 +22,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' def openTelemetryVersion = '1.40.0' def openTelemetryPrometheusVersion = '1.40.0-alpha' diff --git a/examples/example-gcp-observability/build.gradle b/examples/example-gcp-observability/build.gradle index 1cad10bbb87..673d3f4461d 100644 --- a/examples/example-gcp-observability/build.gradle +++ b/examples/example-gcp-observability/build.gradle @@ -22,7 +22,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/example-hostname/build.gradle b/examples/example-hostname/build.gradle index 86aa42f8ed0..a0b9153785e 100644 --- a/examples/example-hostname/build.gradle +++ b/examples/example-hostname/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' dependencies { diff --git a/examples/example-hostname/pom.xml b/examples/example-hostname/pom.xml index 6993c40a1ac..f0aab1c8edc 100644 --- a/examples/example-hostname/pom.xml +++ b/examples/example-hostname/pom.xml @@ -6,13 +6,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-hostname https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 1.8 diff --git a/examples/example-jwt-auth/build.gradle b/examples/example-jwt-auth/build.gradle index ca12c3f7872..0d6d095bf1d 100644 --- a/examples/example-jwt-auth/build.gradle +++ b/examples/example-jwt-auth/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' def protocVersion = protobufVersion diff --git a/examples/example-jwt-auth/pom.xml b/examples/example-jwt-auth/pom.xml index 7e9a915bfbd..cc7171910df 100644 --- a/examples/example-jwt-auth/pom.xml +++ b/examples/example-jwt-auth/pom.xml @@ -7,13 +7,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-jwt-auth https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 3.25.5 diff --git a/examples/example-oauth/build.gradle b/examples/example-oauth/build.gradle index 6d06f097ccb..9e40c3e33f9 100644 --- a/examples/example-oauth/build.gradle +++ b/examples/example-oauth/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protobufVersion = '3.25.5' def protocVersion = protobufVersion diff --git a/examples/example-oauth/pom.xml b/examples/example-oauth/pom.xml index a9fea928a34..36cf63dd602 100644 --- a/examples/example-oauth/pom.xml +++ b/examples/example-oauth/pom.xml @@ -7,13 +7,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-oauth https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 3.25.5 diff --git a/examples/example-opentelemetry/build.gradle b/examples/example-opentelemetry/build.gradle index f575d24d19b..5f98b32be60 100644 --- a/examples/example-opentelemetry/build.gradle +++ b/examples/example-opentelemetry/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' def openTelemetryVersion = '1.40.0' def openTelemetryPrometheusVersion = '1.40.0-alpha' diff --git a/examples/example-orca/build.gradle b/examples/example-orca/build.gradle index 45235fa1e08..edb28e1573b 100644 --- a/examples/example-orca/build.gradle +++ b/examples/example-orca/build.gradle @@ -16,7 +16,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/example-reflection/build.gradle b/examples/example-reflection/build.gradle index ad68e891436..18157b0eed1 100644 --- a/examples/example-reflection/build.gradle +++ b/examples/example-reflection/build.gradle @@ -16,7 +16,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/example-servlet/build.gradle b/examples/example-servlet/build.gradle index 2176df5afc5..e66fda59e4e 100644 --- a/examples/example-servlet/build.gradle +++ b/examples/example-servlet/build.gradle @@ -15,7 +15,7 @@ java { targetCompatibility = JavaVersion.VERSION_1_8 } -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/example-tls/build.gradle b/examples/example-tls/build.gradle index e741bfe1c3f..9603e04e417 100644 --- a/examples/example-tls/build.gradle +++ b/examples/example-tls/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/example-tls/pom.xml b/examples/example-tls/pom.xml index de9298cf0e1..c8b87f54cd0 100644 --- a/examples/example-tls/pom.xml +++ b/examples/example-tls/pom.xml @@ -6,13 +6,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT example-tls https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 1.8 diff --git a/examples/example-xds/build.gradle b/examples/example-xds/build.gradle index c0159dce258..1e55f182d3a 100644 --- a/examples/example-xds/build.gradle +++ b/examples/example-xds/build.gradle @@ -21,7 +21,7 @@ java { // Feel free to delete the comment at the next line. It is just for safely // updating the version in our release process. -def grpcVersion = '1.72.0-SNAPSHOT' // CURRENT_GRPC_VERSION +def grpcVersion = '1.73.0-SNAPSHOT' // CURRENT_GRPC_VERSION def protocVersion = '3.25.5' dependencies { diff --git a/examples/pom.xml b/examples/pom.xml index edc9c4cda14..ff32936f2a2 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -6,13 +6,13 @@ jar - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT examples https://github.com/grpc/grpc-java UTF-8 - 1.72.0-SNAPSHOT + 1.73.0-SNAPSHOT 3.25.5 3.25.5 From 8c7cf53939a561f0350957fe39df87536cde853f Mon Sep 17 00:00:00 2001 From: Kannan J Date: Wed, 2 Apr 2025 04:40:41 +0000 Subject: [PATCH 17/27] okhttp: Per-rpc call option authority verification (#11754) --- .../io/grpc/internal/CertificateUtils.java | 20 +- .../java/io/grpc/okhttp/NoopSslSocket.java | 117 +++++ .../io/grpc/okhttp/OkHttpChannelBuilder.java | 18 +- .../io/grpc/okhttp/OkHttpClientStream.java | 2 +- .../io/grpc/okhttp/OkHttpClientTransport.java | 271 +++++++++-- .../io/grpc/okhttp/OkHttpServerBuilder.java | 3 +- .../io/grpc/okhttp/OkHttpTlsUpgrader.java | 7 +- .../grpc/okhttp/OkHttpClientStreamTest.java | 6 +- .../okhttp/OkHttpClientTransportTest.java | 361 +++++++++++--- .../src/test/java/io/grpc/okhttp/TlsTest.java | 459 ++++++++++++++++++ 10 files changed, 1141 insertions(+), 123 deletions(-) create mode 100644 okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java diff --git a/core/src/main/java/io/grpc/internal/CertificateUtils.java b/core/src/main/java/io/grpc/internal/CertificateUtils.java index cc26cedb274..91d17de93cb 100644 --- a/core/src/main/java/io/grpc/internal/CertificateUtils.java +++ b/core/src/main/java/io/grpc/internal/CertificateUtils.java @@ -16,6 +16,7 @@ package io.grpc.internal; +import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; import java.security.GeneralSecurityException; @@ -36,8 +37,21 @@ public final class CertificateUtils { /** * Creates X509TrustManagers using the provided CA certs. */ - public static TrustManager[] createTrustManager(InputStream rootCerts) + public static TrustManager[] createTrustManager(byte[] rootCerts) throws GeneralSecurityException { + InputStream rootCertsStream = new ByteArrayInputStream(rootCerts); + try { + return CertificateUtils.createTrustManager(rootCertsStream); + } finally { + GrpcUtil.closeQuietly(rootCertsStream); + } + } + + /** + * Creates X509TrustManagers using the provided input stream of CA certs. + */ + public static TrustManager[] createTrustManager(InputStream rootCerts) + throws GeneralSecurityException { KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); try { ks.load(null, null); @@ -52,13 +66,13 @@ public static TrustManager[] createTrustManager(InputStream rootCerts) } TrustManagerFactory trustManagerFactory = - TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); trustManagerFactory.init(ks); return trustManagerFactory.getTrustManagers(); } private static X509Certificate[] getX509Certificates(InputStream inputStream) - throws CertificateException { + throws CertificateException { CertificateFactory factory = CertificateFactory.getInstance("X.509"); Collection certs = factory.generateCertificates(inputStream); return certs.toArray(new X509Certificate[0]); diff --git a/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java new file mode 100644 index 00000000000..6e6a6f12a39 --- /dev/null +++ b/okhttp/src/main/java/io/grpc/okhttp/NoopSslSocket.java @@ -0,0 +1,117 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.okhttp; + +import java.io.IOException; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; + +/** A no-op ssl socket, to facilitate overriding only the required methods in specific + * implementations. + */ +class NoopSslSocket extends SSLSocket { + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } +} diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java index 7eaaa6fd763..98f764132fe 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpChannelBuilder.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.internal.CertificateUtils.createTrustManager; import static io.grpc.internal.GrpcUtil.DEFAULT_KEEPALIVE_TIMEOUT_NANOS; import static io.grpc.internal.GrpcUtil.KEEPALIVE_TIME_NANOS_DISABLED; @@ -89,6 +90,7 @@ public final class OkHttpChannelBuilder extends ForwardingChannelBuilder2 ERROR_CODE_TO_STATUS = buildErrorCodeToStatusMap(); private static final Logger log = Logger.getLogger(OkHttpClientTransport.class.getName()); + private static final String GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK = + "GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK"; + static boolean enablePerRpcAuthorityCheck = + GrpcUtil.getFlag(GRPC_ENABLE_PER_RPC_AUTHORITY_CHECK, false); + private Socket sock; + private SSLSession sslSession; private static Map buildErrorCodeToStatusMap() { Map errorToStatus = new EnumMap<>(ErrorCode.class); @@ -144,6 +168,26 @@ private static Map buildErrorCodeToStatusMap() { return Collections.unmodifiableMap(errorToStatus); } + private static final Class x509ExtendedTrustManagerClass; + private static final Method checkServerTrustedMethod; + + static { + Class x509ExtendedTrustManagerClass1 = null; + Method checkServerTrustedMethod1 = null; + try { + x509ExtendedTrustManagerClass1 = Class.forName("javax.net.ssl.X509ExtendedTrustManager"); + checkServerTrustedMethod1 = x509ExtendedTrustManagerClass1.getMethod("checkServerTrusted", + X509Certificate[].class, String.class, Socket.class); + } catch (ClassNotFoundException e) { + // Per-rpc authority override via call options will be disallowed. + } catch (NoSuchMethodException e) { + // Should never happen since X509ExtendedTrustManager was introduced in Android API level 24 + // along with checkServerTrusted. + } + x509ExtendedTrustManagerClass = x509ExtendedTrustManagerClass1; + checkServerTrustedMethod = checkServerTrustedMethod1; + } + private final InetSocketAddress address; private final String defaultAuthority; private final String userAgent; @@ -205,6 +249,19 @@ private static Map buildErrorCodeToStatusMap() { private final boolean useGetForSafeMethods; @GuardedBy("lock") private final TransportTracer transportTracer; + private final TrustManager x509TrustManager; + + @SuppressWarnings("serial") + private static class LruCache extends LinkedHashMap { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > 100; + } + } + + @GuardedBy("lock") + private final Map authorityVerificationResults = new LruCache<>(); + @GuardedBy("lock") private final InUseStateAggregator inUseState = new InUseStateAggregator() { @@ -233,13 +290,14 @@ protected void handleNotInUse() { SettableFuture connectedFuture; public OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this( transportFactory, address, @@ -249,19 +307,21 @@ public OkHttpClientTransport( GrpcUtil.STOPWATCH_SUPPLIER, new Http2(), proxiedAddr, - tooManyPingsRunnable); + tooManyPingsRunnable, + channelCredentials); } private OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, - InetSocketAddress address, - String authority, - @Nullable String userAgent, - Attributes eagAttrs, - Supplier stopwatchFactory, - Variant variant, - @Nullable HttpConnectProxiedSocketAddress proxiedAddr, - Runnable tooManyPingsRunnable) { + OkHttpTransportFactory transportFactory, + InetSocketAddress address, + String authority, + @Nullable String userAgent, + Attributes eagAttrs, + Supplier stopwatchFactory, + Variant variant, + @Nullable HttpConnectProxiedSocketAddress proxiedAddr, + Runnable tooManyPingsRunnable, + ChannelCredentials channelCredentials) { this.address = Preconditions.checkNotNull(address, "address"); this.defaultAuthority = authority; this.maxMessageSize = transportFactory.maxMessageSize; @@ -276,7 +336,8 @@ private OkHttpClientTransport( this.socketFactory = transportFactory.socketFactory == null ? SocketFactory.getDefault() : transportFactory.socketFactory; this.sslSocketFactory = transportFactory.sslSocketFactory; - this.hostnameVerifier = transportFactory.hostnameVerifier; + this.hostnameVerifier = transportFactory.hostnameVerifier != null + ? transportFactory.hostnameVerifier : OkHostnameVerifier.INSTANCE; this.connectionSpec = Preconditions.checkNotNull( transportFactory.connectionSpec, "connectionSpec"); this.stopwatchFactory = Preconditions.checkNotNull(stopwatchFactory, "stopwatchFactory"); @@ -292,6 +353,21 @@ private OkHttpClientTransport( .set(GrpcAttributes.ATTR_CLIENT_EAG_ATTRS, eagAttrs).build(); this.useGetForSafeMethods = transportFactory.useGetForSafeMethods; initTransportTracer(); + TrustManager tempX509TrustManager; + if (channelCredentials instanceof TlsChannelCredentials + && x509ExtendedTrustManagerClass != null) { + try { + tempX509TrustManager = getTrustManager( + (TlsChannelCredentials) channelCredentials); + } catch (GeneralSecurityException e) { + tempX509TrustManager = null; + log.log(Level.WARNING, "Obtaining X509ExtendedTrustManager for the transport failed." + + "Per-rpc authority overrides will be disallowed.", e); + } + } else { + tempX509TrustManager = null; + } + x509TrustManager = tempX509TrustManager; } /** @@ -300,7 +376,7 @@ private OkHttpClientTransport( @SuppressWarnings("AddressSelection") // An IP address always returns one address @VisibleForTesting OkHttpClientTransport( - OkHttpChannelBuilder.OkHttpTransportFactory transportFactory, + OkHttpTransportFactory transportFactory, String userAgent, Supplier stopwatchFactory, Variant variant, @@ -316,7 +392,8 @@ private OkHttpClientTransport( stopwatchFactory, variant, null, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); this.connectingCallback = connectingCallback; this.connectedFuture = Preconditions.checkNotNull(connectedFuture, "connectedFuture"); } @@ -396,6 +473,7 @@ public OkHttpClientStream newStream( Preconditions.checkNotNull(headers, "headers"); StatsTraceContext statsTraceContext = StatsTraceContext.newClientContext(tracers, getAttributes(), headers); + // FIXME: it is likely wrong to pass the transportTracer here as it'll exit the lock's scope synchronized (lock) { // to make @GuardedBy linter happy return new OkHttpClientStream( @@ -416,23 +494,116 @@ public OkHttpClientStream newStream( } } + private TrustManager getTrustManager(TlsChannelCredentials tlsCreds) + throws GeneralSecurityException { + TrustManager[] tm; + // Using the same way of creating TrustManager from OkHttpChannelBuilder.sslSocketFactoryFrom() + if (tlsCreds.getTrustManagers() != null) { + tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); + } else if (tlsCreds.getRootCertificates() != null) { + tm = CertificateUtils.createTrustManager(tlsCreds.getRootCertificates()); + } else { // else use system default + TrustManagerFactory tmf = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + tmf.init((KeyStore) null); + tm = tmf.getTrustManagers(); + } + for (TrustManager trustManager: tm) { + if (trustManager instanceof X509TrustManager) { + return trustManager; + } + } + return null; + } + @GuardedBy("lock") - void streamReadyToStart(OkHttpClientStream clientStream) { + void streamReadyToStart(OkHttpClientStream clientStream, String authority) { if (goAwayStatus != null) { clientStream.transportState().transportReportStatus( goAwayStatus, RpcProgress.MISCARRIED, true, new Metadata()); - } else if (streams.size() >= maxConcurrentStreams) { - pendingStreams.add(clientStream); - setInUse(clientStream); } else { - startStream(clientStream); + if (socket instanceof SSLSocket && !authority.equals(defaultAuthority)) { + Status authorityVerificationResult; + if (authorityVerificationResults.containsKey(authority)) { + authorityVerificationResult = authorityVerificationResults.get(authority); + } else { + authorityVerificationResult = verifyAuthority(authority); + authorityVerificationResults.put(authority, authorityVerificationResult); + } + if (!authorityVerificationResult.isOk()) { + if (enablePerRpcAuthorityCheck) { + clientStream.transportState().transportReportStatus( + authorityVerificationResult, RpcProgress.PROCESSED, true, new Metadata()); + return; + } + } + } + if (streams.size() >= maxConcurrentStreams) { + pendingStreams.add(clientStream); + setInUse(clientStream); + } else { + startStream(clientStream); + } } } + private Status verifyAuthority(String authority) { + Status authorityVerificationResult; + if (hostnameVerifier.verify(authority, ((SSLSocket) socket).getSession())) { + authorityVerificationResult = Status.OK; + } else { + authorityVerificationResult = Status.UNAVAILABLE.withDescription(String.format( + "HostNameVerifier verification failed for authority '%s'", + authority)); + } + if (!authorityVerificationResult.isOk() && !enablePerRpcAuthorityCheck) { + log.log(Level.WARNING, String.format("HostNameVerifier verification failed for " + + "authority '%s'. This will be an error in the future.", + authority)); + } + if (authorityVerificationResult.isOk()) { + // The status is trivially assigned in this case, but we are still making use of the + // cache to keep track that a warning log had been logged for the authority when + // enablePerRpcAuthorityCheck is false. When we permanently enable the feature, the + // status won't need to be cached for case when x509TrustManager is null. + if (x509TrustManager == null) { + authorityVerificationResult = Status.UNAVAILABLE.withDescription( + String.format("Could not verify authority '%s' for the rpc with no " + + "X509TrustManager available", + authority)); + } else if (x509ExtendedTrustManagerClass.isInstance(x509TrustManager)) { + try { + Certificate[] peerCertificates = sslSession.getPeerCertificates(); + X509Certificate[] x509PeerCertificates = + new X509Certificate[peerCertificates.length]; + for (int i = 0; i < peerCertificates.length; i++) { + x509PeerCertificates[i] = (X509Certificate) peerCertificates[i]; + } + checkServerTrustedMethod.invoke(x509TrustManager, x509PeerCertificates, + "RSA", new SslSocketWrapper((SSLSocket) socket, authority)); + authorityVerificationResult = Status.OK; + } catch (SSLPeerUnverifiedException | InvocationTargetException + | IllegalAccessException e) { + authorityVerificationResult = Status.UNAVAILABLE.withCause(e).withDescription( + "Peer verification failed"); + } + if (authorityVerificationResult.getCause() != null) { + log.log(Level.WARNING, authorityVerificationResult.getDescription() + + ". This will be an error in the future.", + authorityVerificationResult.getCause()); + } else { + log.log(Level.WARNING, authorityVerificationResult.getDescription() + + ". This will be an error in the future."); + } + } + } + return authorityVerificationResult; + } + @SuppressWarnings("GuardedBy") @GuardedBy("lock") private void startStream(OkHttpClientStream stream) { - Preconditions.checkState( + checkState( stream.transportState().id() == OkHttpClientStream.ABSENT_ID, "StreamId already assigned"); streams.put(nextStreamId, stream); setInUse(stream); @@ -531,8 +702,6 @@ public Timeout timeout() { public void close() { } }); - Socket sock; - SSLSession sslSession = null; try { // This is a hack to make sure the connection preface and initial settings to be sent out // without blocking the start. By doing this essentially prevents potential deadlock when @@ -1460,4 +1629,50 @@ public void alternateService(int streamId, String origin, ByteString protocol, S // TODO(madongfly): Deal with alternateService propagation } } + + /** + * SSLSocket wrapper that provides a fake SSLSession for handshake session. + */ + static final class SslSocketWrapper extends NoopSslSocket { + + private final SSLSession sslSession; + private final SSLSocket sslSocket; + + SslSocketWrapper(SSLSocket sslSocket, String peerHost) { + this.sslSocket = sslSocket; + this.sslSession = new FakeSslSession(peerHost); + } + + @Override + public SSLSession getHandshakeSession() { + return this.sslSession; + } + + @Override + public boolean isConnected() { + return sslSocket.isConnected(); + } + + @Override + public SSLParameters getSSLParameters() { + return sslSocket.getSSLParameters(); + } + } + + /** + * Fake SSLSession instance that provides the peer host name to verify for per-rpc check. + */ + static class FakeSslSession extends NoopSslSession { + + private final String peerHost; + + FakeSslSession(String peerHost) { + this.peerHost = peerHost; + } + + @Override + public String getPeerHost() { + return peerHost; + } + } } diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java index 068474d70bc..8daeed42a8c 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpServerBuilder.java @@ -17,6 +17,7 @@ package io.grpc.okhttp; import static com.google.common.base.Preconditions.checkArgument; +import static io.grpc.internal.CertificateUtils.createTrustManager; import com.google.common.base.Preconditions; import com.google.errorprone.annotations.CanIgnoreReturnValue; @@ -425,7 +426,7 @@ static HandshakerSocketFactoryResult handshakerSocketFactoryFrom(ServerCredentia tm = tlsCreds.getTrustManagers().toArray(new TrustManager[0]); } else if (tlsCreds.getRootCertificates() != null) { try { - tm = OkHttpChannelBuilder.createTrustManager(tlsCreds.getRootCertificates()); + tm = createTrustManager(tlsCreds.getRootCertificates()); } catch (GeneralSecurityException gse) { log.log(Level.FINE, "Exception loading root certificates from credential", gse); return HandshakerSocketFactoryResult.error( diff --git a/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java b/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java index 1004dcd93f9..a8b038c91f4 100644 --- a/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java +++ b/okhttp/src/main/java/io/grpc/okhttp/OkHttpTlsUpgrader.java @@ -19,13 +19,13 @@ import com.google.common.annotations.VisibleForTesting; import com.google.common.base.Preconditions; import io.grpc.okhttp.internal.ConnectionSpec; -import io.grpc.okhttp.internal.OkHostnameVerifier; import io.grpc.okhttp.internal.Protocol; import java.io.IOException; import java.net.Socket; import java.util.Arrays; import java.util.Collections; import java.util.List; +import javax.annotation.Nonnull; import javax.net.ssl.HostnameVerifier; import javax.net.ssl.SSLPeerUnverifiedException; import javax.net.ssl.SSLSocket; @@ -52,7 +52,7 @@ final class OkHttpTlsUpgrader { * @throws RuntimeException if the upgrade negotiation failed. */ public static SSLSocket upgrade(SSLSocketFactory sslSocketFactory, - HostnameVerifier hostnameVerifier, Socket socket, String host, int port, + @Nonnull HostnameVerifier hostnameVerifier, Socket socket, String host, int port, ConnectionSpec spec) throws IOException { Preconditions.checkNotNull(sslSocketFactory, "sslSocketFactory"); Preconditions.checkNotNull(socket, "socket"); @@ -67,9 +67,6 @@ public static SSLSocket upgrade(SSLSocketFactory sslSocketFactory, "Only " + TLS_PROTOCOLS + " are supported, but negotiated protocol is %s", negotiatedProtocol); - if (hostnameVerifier == null) { - hostnameVerifier = OkHostnameVerifier.INSTANCE; - } if (!hostnameVerifier.verify(canonicalizeHost(host), sslSocket.getSession())) { throw new SSLPeerUnverifiedException("Cannot verify hostname: " + host); } diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java index 1f716705968..1c98d6ee30d 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientStreamTest.java @@ -20,6 +20,7 @@ import static io.grpc.internal.ClientStreamListener.RpcProgress.PROCESSED; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.ArgumentMatchers.isA; import static org.mockito.Mockito.times; @@ -244,12 +245,13 @@ public void getUnaryRequest() throws IOException { // GET streams send headers after halfClose is called. verify(mockedFrameWriter, times(0)).synStream( eq(false), eq(false), eq(3), eq(0), headersCaptor.capture()); - verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class)); + verify(transport, times(0)).streamReadyToStart(isA(OkHttpClientStream.class), + isA(String.class)); byte[] msg = "request".getBytes(Charset.forName("UTF-8")); stream.writeMessage(new ByteArrayInputStream(msg)); stream.halfClose(); - verify(transport).streamReadyToStart(eq(stream)); + verify(transport).streamReadyToStart(eq(stream), any(String.class)); stream.transportState().start(3); verify(mockedFrameWriter) diff --git a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java index 826dee8e2b4..99f430be009 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/OkHttpClientTransportTest.java @@ -68,6 +68,7 @@ import io.grpc.Status; import io.grpc.Status.Code; import io.grpc.internal.AbstractStream; +import io.grpc.internal.ClientStream; import io.grpc.internal.ClientStreamListener; import io.grpc.internal.ClientTransport; import io.grpc.internal.FakeClock; @@ -115,6 +116,10 @@ import java.util.logging.Logger; import javax.annotation.Nullable; import javax.net.SocketFactory; +import javax.net.ssl.HandshakeCompletedListener; +import javax.net.ssl.HostnameVerifier; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSocket; import okio.Buffer; import okio.BufferedSink; import okio.BufferedSource; @@ -189,16 +194,24 @@ public void tearDown() { private void initTransport() throws Exception { startTransport( - DEFAULT_START_STREAM_ID, null, true, null); + DEFAULT_START_STREAM_ID, null, true, null, null); } private void initTransport(int startId) throws Exception { - startTransport(startId, null, true, null); + startTransport(startId, null, true, null, null); } private void startTransport(int startId, @Nullable Runnable connectingCallback, - boolean waitingForConnected, String userAgent) - throws Exception { + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier) throws Exception { + startTransport(startId, connectingCallback, waitingForConnected, userAgent, hostnameVerifier, + false); + } + + private void startTransport(int startId, @Nullable Runnable connectingCallback, + boolean waitingForConnected, String userAgent, + HostnameVerifier hostnameVerifier, boolean useSslSocket) + throws Exception { connectedFuture = SettableFuture.create(); final Ticker ticker = new Ticker() { @Override @@ -212,7 +225,11 @@ public Stopwatch get() { return Stopwatch.createUnstarted(ticker); } }; - channelBuilder.socketFactory(new FakeSocketFactory(socket)); + channelBuilder.socketFactory( + new FakeSocketFactory(useSslSocket ? new MockSslSocket(socket) : socket)); + if (hostnameVerifier != null) { + channelBuilder = channelBuilder.hostnameVerifier(hostnameVerifier); + } clientTransport = new OkHttpClientTransport( channelBuilder.buildTransportFactory(), userAgent, @@ -240,7 +257,8 @@ public void testToString() throws Exception { /*userAgent=*/ null, EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String s = clientTransport.toString(); assertTrue("Unexpected: " + s, s.contains("OkHttpClientTransport")); assertTrue("Unexpected: " + s, s.contains(address.toString())); @@ -258,7 +276,8 @@ public void testTransportExecutorWithTooFewThreads() throws Exception { null, EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); ArgumentCaptor statusCaptor = ArgumentCaptor.forClass(Status.class); verify(transportListener, timeout(TIME_OUT_MS)).transportShutdown(statusCaptor.capture()); @@ -299,7 +318,7 @@ public void close() throws SecurityException { assertThat(log.getLevel()).isEqualTo(Level.FINE); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -389,7 +408,7 @@ public void maxMessageSizeShouldBeEnforced() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -442,11 +461,11 @@ public void nextFrameThrowIoException() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -476,7 +495,7 @@ public void nextFrameThrowIoException() throws Exception { public void nextFrameThrowsError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -497,7 +516,7 @@ public void nextFrameThrowsError() throws Exception { public void nextFrameReturnFalse() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -515,7 +534,7 @@ public void readMessages() throws Exception { final int numMessages = 10; final String message = "Hello Client"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(numMessages); @@ -567,7 +586,7 @@ public void receivedDataForInvalidStreamShouldKillConnection() throws Exception public void invalidInboundHeadersCancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -592,7 +611,7 @@ public void invalidInboundHeadersCancelStream() throws Exception { public void invalidInboundTrailersPropagateToMetadata() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -612,7 +631,7 @@ public void invalidInboundTrailersPropagateToMetadata() throws Exception { public void readStatus() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -626,7 +645,7 @@ public void readStatus() throws Exception { public void receiveReset() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -643,7 +662,7 @@ public void receiveReset() throws Exception { public void receiveResetNoError() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertContainStream(3); @@ -664,7 +683,7 @@ public void receiveResetNoError() throws Exception { public void cancelStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.CANCELLED); @@ -679,7 +698,7 @@ public void cancelStream() throws Exception { public void addDefaultUserAgent() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); Header userAgentHeader = new Header(GrpcUtil.USER_AGENT_KEY.name(), @@ -696,9 +715,9 @@ public void addDefaultUserAgent() throws Exception { @Test public void overrideDefaultUserAgent() throws Exception { - startTransport(3, null, true, "fakeUserAgent"); + startTransport(3, null, true, "fakeUserAgent", null); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); List

    expectedHeaders = Arrays.asList(HTTP_SCHEME_HEADER, METHOD_HEADER, @@ -717,7 +736,7 @@ public void overrideDefaultUserAgent() throws Exception { public void cancelStreamForDeadlineExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); getStream(3).cancel(Status.DEADLINE_EXCEEDED); @@ -731,7 +750,7 @@ public void writeMessage() throws Exception { initTransport(); final String message = "Hello Server"; MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); @@ -746,6 +765,65 @@ public void writeMessage() throws Exception { shutdownAndVerify(); } + @Test + public void perRpcAuthoritySpecified_verificationSkippedInPlainTextConnection() + throws Exception { + initTransport(); + final String message = "Hello Server"; + MockStreamListener listener = new MockStreamListener(); + ClientStream stream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + stream.start(listener); + InputStream input = new ByteArrayInputStream(message.getBytes(UTF_8)); + assertEquals(12, input.available()); + stream.writeMessage(input); + stream.flush(); + verify(frameWriter, timeout(TIME_OUT_MS)) + .data(eq(false), eq(3), any(Buffer.class), eq(12 + HEADER_LENGTH)); + Buffer sentFrame = capturedBuffer.poll(); + assertEquals(createMessageFrame(message), sentFrame); + stream.cancel(Status.CANCELLED); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_ignoredForNonSslSocket() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, false); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_successCase() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> true, true); + ClientStream unused = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + shutdownAndVerify(); + } + + @Test + public void perRpcAuthoritySpecified_hostnameVerification_SslSocket_flagDisabled() + throws Exception { + startTransport( + DEFAULT_START_STREAM_ID, null, true, null, + (hostname, session) -> false, true); + ClientStream clientStream = + clientTransport.newStream(method, new Metadata(), + CallOptions.DEFAULT.withAuthority("some-authority"), tracers); + assertThat(clientStream).isInstanceOf(OkHttpClientStream.class); + shutdownAndVerify(); + } + @Test public void transportTracer_windowSizeDefault() throws Exception { initTransport(); @@ -772,12 +850,12 @@ public void windowUpdate() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(2); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(2); @@ -842,7 +920,7 @@ public void windowUpdate() throws Exception { public void windowUpdateWithInboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = INITIAL_WINDOW_SIZE / 2 + 1; @@ -879,7 +957,7 @@ public void windowUpdateWithInboundFlowControl() throws Exception { public void outboundFlowControl() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -925,7 +1003,7 @@ public void outboundFlowControl_smallWindowSize() throws Exception { setInitialWindowSize(initialOutboundWindowSize); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -968,7 +1046,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { frameHandler().windowUpdate(0, 65535); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); @@ -1004,7 +1082,7 @@ public void outboundFlowControl_bigWindowSize() throws Exception { public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1050,7 +1128,7 @@ public void outboundFlowControlWithInitialWindowSizeChange() throws Exception { public void outboundFlowControlWithInitialWindowSizeChangeInMiddleOfStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); int messageLength = 20; @@ -1085,10 +1163,10 @@ public void stopNormally() throws Exception { initTransport(); MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); assertEquals(2, activeStreamCount()); @@ -1115,11 +1193,11 @@ public void receiveGoAway() throws Exception { // start 2 streams. MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); stream1.request(1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); stream2.request(1); @@ -1142,7 +1220,7 @@ public void receiveGoAway() throws Exception { // But stream 1 should be able to send. final String sentMessage = "Should I also go away?"; - OkHttpClientStream stream = getStream(3); + ClientStream stream = getStream(3); InputStream input = new ByteArrayInputStream(sentMessage.getBytes(UTF_8)); assertEquals(22, input.available()); stream.writeMessage(input); @@ -1174,7 +1252,7 @@ public void streamIdExhausted() throws Exception { initTransport(startId); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1210,11 +1288,11 @@ public void pendingStreamSucceed() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); String sentMessage = "hello"; @@ -1247,7 +1325,7 @@ public void pendingStreamCancelled() throws Exception { initTransport(); setMaxConcurrentStreams(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1266,11 +1344,11 @@ public void pendingStreamFailedByGoAway() throws Exception { setMaxConcurrentStreams(1); final MockStreamListener listener1 = new MockStreamListener(); final MockStreamListener listener2 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); @@ -1296,7 +1374,7 @@ public void pendingStreamSucceedAfterShutdown() throws Exception { setMaxConcurrentStreams(0); final MockStreamListener listener = new MockStreamListener(); // The second stream should be pending. - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); waitForStreamPending(1); @@ -1320,15 +1398,15 @@ public void pendingStreamFailedByIdExhausted() throws Exception { final MockStreamListener listener2 = new MockStreamListener(); final MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); // The second and third stream should be pending. - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -1352,7 +1430,7 @@ public void pendingStreamFailedByIdExhausted() throws Exception { public void receivingWindowExceeded() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1404,7 +1482,7 @@ public void duplexStreamingHeadersShouldNotBeFlushed() throws Exception { private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); verify(frameWriter, timeout(TIME_OUT_MS)).synStream( @@ -1421,7 +1499,7 @@ private void shouldHeadersBeFlushed(boolean shouldBeFlushed) throws Exception { public void receiveDataWithoutHeader() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1444,7 +1522,7 @@ public void receiveDataWithoutHeader() throws Exception { public void receiveDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1468,7 +1546,7 @@ public void receiveDataWithoutHeaderAndTrailer() throws Exception { public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.request(1); @@ -1490,7 +1568,7 @@ public void receiveLongEnoughDataWithoutHeaderAndTrailer() throws Exception { public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1519,7 +1597,7 @@ public void receiveDataForUnknownStreamUpdateConnectionWindow() throws Exception public void receiveWindowUpdateForUnknownStream() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); stream.cancel(Status.CANCELLED); @@ -1539,7 +1617,7 @@ public void receiveWindowUpdateForUnknownStream() throws Exception { public void shouldBeInitiallyReady() throws Exception { initTransport(); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1557,7 +1635,7 @@ public void notifyOnReady() throws Exception { AbstractStream.TransportState.DEFAULT_ONREADY_THRESHOLD - HEADER_LENGTH - 1; setInitialWindowSize(0); MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); assertTrue(stream.isReady()); @@ -1704,7 +1782,7 @@ public void shutdownDuringConnecting() throws Exception { DEFAULT_START_STREAM_ID, connectingCallback, false, - null); + null, null); clientTransport.shutdown(SHUTDOWN_REASON); delayed.set(null); shutdownAndVerify(); @@ -1719,7 +1797,8 @@ public void invalidAuthorityPropagates() { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); String host = clientTransport.getOverridenHost(); int port = clientTransport.getOverridenPort(); @@ -1737,7 +1816,8 @@ public void unreachableServer() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1767,7 +1847,8 @@ public void customSocketFactory() throws Exception { "userAgent", EAG_ATTRS, NO_PROXY, - tooManyPingsRunnable); + tooManyPingsRunnable, + null); ManagedClientTransport.Listener listener = mock(ManagedClientTransport.Listener.class); clientTransport.start(listener); @@ -1792,7 +1873,8 @@ public void proxy_200() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1841,7 +1923,8 @@ public void proxy_500() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1889,7 +1972,8 @@ public void proxy_immediateServerClose() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.start(transportListener); Socket sock = serverSocket.accept(); @@ -1920,7 +2004,8 @@ public void proxy_serverHangs() throws Exception { .setTargetAddress(targetAddress) .setProxyAddress(new InetSocketAddress("localhost", serverSocket.getLocalPort())) .build(), - tooManyPingsRunnable); + tooManyPingsRunnable, + null); clientTransport.proxySocketTimeout = 10; clientTransport.start(transportListener); @@ -1987,13 +2072,13 @@ public void goAway_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2027,13 +2112,13 @@ public void reset_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); @@ -2069,13 +2154,13 @@ public void shutdownNow_streamListenerRpcProgress() throws Exception { MockStreamListener listener1 = new MockStreamListener(); MockStreamListener listener2 = new MockStreamListener(); MockStreamListener listener3 = new MockStreamListener(); - OkHttpClientStream stream1 = + ClientStream stream1 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream1.start(listener1); - OkHttpClientStream stream2 = + ClientStream stream2 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream2.start(listener2); - OkHttpClientStream stream3 = + ClientStream stream3 = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream3.start(listener3); waitForStreamPending(1); @@ -2100,11 +2185,11 @@ public void finishedStreamRemovedFromInUseState() throws Exception { initTransport(); setMaxConcurrentStreams(1); final MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); + OkHttpClientStream stream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); - OkHttpClientStream pendingStream = - clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); + OkHttpClientStream pendingStream = clientTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); pendingStream.start(listener); waitForStreamPending(1); clientTransport.finishStream(stream.transportState().id(), Status.OK, PROCESSED, @@ -2144,7 +2229,7 @@ private void waitForStreamPending(int expected) throws Exception { private void assertNewStreamFail() throws Exception { MockStreamListener listener = new MockStreamListener(); - OkHttpClientStream stream = + ClientStream stream = clientTransport.newStream(method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(listener); listener.waitUntilStreamClosed(); @@ -2375,6 +2460,124 @@ public InputStream getInputStream() { } } + private static class MockSslSocket extends SSLSocket { + private Socket delegate; + + MockSslSocket(Socket socket) { + delegate = socket; + } + + @Override + public String[] getSupportedCipherSuites() { + return new String[0]; + } + + @Override + public String[] getEnabledCipherSuites() { + return new String[0]; + } + + @Override + public void setEnabledCipherSuites(String[] suites) { + + } + + @Override + public String[] getSupportedProtocols() { + return new String[0]; + } + + @Override + public String[] getEnabledProtocols() { + return new String[0]; + } + + @Override + public void setEnabledProtocols(String[] protocols) { + + } + + @Override + public SSLSession getSession() { + return null; + } + + @Override + public void addHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void removeHandshakeCompletedListener(HandshakeCompletedListener listener) { + + } + + @Override + public void startHandshake() throws IOException { + + } + + @Override + public void setUseClientMode(boolean mode) { + + } + + @Override + public boolean getUseClientMode() { + return false; + } + + @Override + public void setNeedClientAuth(boolean need) { + + } + + @Override + public boolean getNeedClientAuth() { + return false; + } + + @Override + public void setWantClientAuth(boolean want) { + + } + + @Override + public boolean getWantClientAuth() { + return false; + } + + @Override + public void setEnableSessionCreation(boolean flag) { + + } + + @Override + public boolean getEnableSessionCreation() { + return false; + } + + @Override + public synchronized void close() throws IOException { + delegate.close(); + } + + @Override + public SocketAddress getLocalSocketAddress() { + return delegate.getLocalSocketAddress(); + } + + @Override + public OutputStream getOutputStream() throws IOException { + return delegate.getOutputStream(); + } + + @Override + public InputStream getInputStream() throws IOException { + return delegate.getInputStream(); + } + } + static class PingCallbackImpl implements ClientTransport.PingCallback { int invocationCount; long roundTripTime; diff --git a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java index a21360a89ba..20a2f1a5ca7 100644 --- a/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java +++ b/okhttp/src/test/java/io/grpc/okhttp/TlsTest.java @@ -18,8 +18,10 @@ import static com.google.common.truth.Truth.assertThat; import static com.google.common.truth.Truth.assertWithMessage; +import static org.junit.Assert.fail; import com.google.common.base.Throwables; +import io.grpc.CallOptions; import io.grpc.ChannelCredentials; import io.grpc.ConnectivityState; import io.grpc.ManagedChannel; @@ -32,18 +34,34 @@ import io.grpc.TlsServerCredentials; import io.grpc.internal.testing.TestUtils; import io.grpc.okhttp.internal.Platform; +import io.grpc.stub.ClientCalls; import io.grpc.stub.StreamObserver; import io.grpc.testing.GrpcCleanupRule; import io.grpc.testing.TlsTesting; import io.grpc.testing.protobuf.SimpleRequest; import io.grpc.testing.protobuf.SimpleResponse; import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.util.CertificateUtils; import java.io.IOException; import java.io.InputStream; +import java.net.Socket; +import java.security.GeneralSecurityException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Optional; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSocket; import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import org.codehaus.mojo.animal_sniffer.IgnoreJRERequirement; import org.junit.Assume; import org.junit.Before; import org.junit.Rule; @@ -53,6 +71,7 @@ /** Verify OkHttp's TLS integration. */ @RunWith(JUnit4.class) +@IgnoreJRERequirement public class TlsTest { @Rule public final GrpcCleanupRule grpcCleanupRule = new GrpcCleanupRule(); @@ -92,6 +111,325 @@ public void basicTls_succeeds() throws Exception { SimpleServiceGrpc.newBlockingStub(channel).unaryRpc(SimpleRequest.getDefaultInstance()); } + @Test + public void perRpcAuthorityOverride_hostnameVerifier_goodAuthority_succeeds() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("good.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerifier_badAuthority_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for hostname verifier failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "HostNameVerifier verification failed for authority 'disallowed.name.com'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_hostnameVerifier_badAuthority_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(caCert) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + } + + @Test + public void perRpcAuthorityOverride_noTlsCredentialsUsedToBuildChannel_fails() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + SSLSocketFactory sslSocketFactory = TestUtils.newSslSocketFactoryForCa( + Platform.get().getProvider(), TestUtils.loadCert("ca.pem")); + ManagedChannel channel = grpcCleanupRule.register( + OkHttpChannelBuilder.forAddress("localhost", server.getPort()) + .overrideAuthority(TestUtils.TEST_SERVER_HOST) + .directExecutor() + .sslSocketFactory(sslSocketFactory) + .build()); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bar.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()).isEqualTo( + "Could not verify authority 'bar.test.google.fr' for the rpc with no " + + "X509TrustManager available"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_permitted_succeeds() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("good.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_denied_fails() throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bad.test.google.fr"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getCause().getCause()).isInstanceOf(CertificateException.class); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_trustManager_denied_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509ExtendedTrustManager regularTrustManager = + (X509ExtendedTrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new HostnameCheckingX509ExtendedTrustManager(regularTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("bad.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } + + /** + * This test simulates the absence of X509ExtendedTrustManager while still using the + * real trust manager for the connection handshake to happen. When the TrustManager is not an + * X509ExtendedTrustManager, the per-rpc check ignores the trust manager. However, the + * HostnameVerifier is still used, so only valid authorities are permitted. + */ + @Test + public void perRpcAuthorityOverride_notX509ExtendedTrustManager_goodAuthority_succeeds() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("foo.test.google.fr"), + SimpleRequest.getDefaultInstance()); + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void perRpcAuthorityOverride_notX509ExtendedTrustManager_badAuthority_fails() + throws Exception { + OkHttpClientTransport.enablePerRpcAuthorityCheck = true; + try { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + try { + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + fail("Expected exception for authority verification failure."); + } catch (StatusRuntimeException ex) { + assertThat(ex.getStatus().getCode()).isEqualTo(Status.Code.UNAVAILABLE); + assertThat(ex.getStatus().getDescription()) + .isEqualTo("HostNameVerifier verification failed for authority 'disallowed.name.com'"); + } + } finally { + OkHttpClientTransport.enablePerRpcAuthorityCheck = false; + } + } + + @Test + public void + perRpcAuthorityOverride_notX509ExtendedTrustManager_badAuthority_flagDisabled_succeeds() + throws Exception { + ServerCredentials serverCreds; + try (InputStream serverCert = TlsTesting.loadCert("server1.pem"); + InputStream serverPrivateKey = TlsTesting.loadCert("server1.key")) { + serverCreds = TlsServerCredentials.newBuilder() + .keyManager(serverCert, serverPrivateKey) + .build(); + } + ChannelCredentials channelCreds; + try (InputStream caCert = TlsTesting.loadCert("ca.pem")) { + X509TrustManager x509ExtendedTrustManager = + (X509TrustManager) getX509ExtendedTrustManager(caCert).get(); + channelCreds = TlsChannelCredentials.newBuilder() + .trustManager(new FakeTrustManager(x509ExtendedTrustManager)) + .build(); + } + Server server = grpcCleanupRule.register(server(serverCreds)); + ManagedChannel channel = grpcCleanupRule.register(clientChannel(server, channelCreds)); + + ClientCalls.blockingUnaryCall(channel, SimpleServiceGrpc.getUnaryRpcMethod(), + CallOptions.DEFAULT.withAuthority("disallowed.name.com"), + SimpleRequest.getDefaultInstance()); + } + @Test public void mtls_succeeds() throws Exception { ServerCredentials serverCreds; @@ -282,6 +620,127 @@ public void hostnameVerifierFails_fails() assertThat(status.getCause()).isInstanceOf(SSLPeerUnverifiedException.class); } + /** Used to simulate the case of X509ExtendedTrustManager not present. */ + private static class FakeTrustManager implements X509TrustManager { + + private final X509TrustManager delegate; + + public FakeTrustManager(X509TrustManager x509ExtendedTrustManager) { + this.delegate = x509ExtendedTrustManager; + } + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkClientTrusted(x509Certificates, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + delegate.checkServerTrusted(x509Certificates, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + /** + * Checks against a limited set of hostnames. In production, EndpointIdentificationAlgorithm is + * unset so the default trust manager will not fail based on the hostname. This class is used to + * test user-provided trust managers that may have their own behavior. + */ + private static class HostnameCheckingX509ExtendedTrustManager + extends ForwardingX509ExtendedTrustManager { + public HostnameCheckingX509ExtendedTrustManager(X509ExtendedTrustManager tm) { + super(tm); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + String peer = ((SSLSocket) socket).getHandshakeSession().getPeerHost(); + if (!"foo.test.google.fr".equals(peer) && !"good.test.google.fr".equals(peer)) { + throw new CertificateException("Peer verification failed."); + } + super.checkServerTrusted(chain, authType, socket); + } + } + + @IgnoreJRERequirement + private static class ForwardingX509ExtendedTrustManager extends X509ExtendedTrustManager { + private final X509ExtendedTrustManager delegate; + + private ForwardingX509ExtendedTrustManager(X509ExtendedTrustManager delegate) { + this.delegate = delegate; + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkServerTrusted(chain, authType, engine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkServerTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + delegate.checkClientTrusted(chain, authType, engine); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + delegate.checkClientTrusted(chain, authType); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + delegate.checkClientTrusted(chain, authType, socket); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } + } + + private static Optional getX509ExtendedTrustManager(InputStream rootCerts) + throws GeneralSecurityException { + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + try { + ks.load(null, null); + } catch (IOException ex) { + // Shouldn't really happen, as we're not loading any data. + throw new GeneralSecurityException(ex); + } + X509Certificate[] certs = CertificateUtils.getX509Certificates(rootCerts); + for (X509Certificate cert : certs) { + X500Principal principal = cert.getSubjectX500Principal(); + ks.setCertificateEntry(principal.getName("RFC2253"), cert); + } + + TrustManagerFactory trustManagerFactory = + TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init(ks); + return Arrays.stream(trustManagerFactory.getTrustManagers()) + .filter(trustManager -> trustManager instanceof X509ExtendedTrustManager).findFirst(); + } + private static Server server(ServerCredentials creds) throws IOException { return OkHttpServerBuilder.forPort(0, creds) .directExecutor() From 8879e9442719eb92ef0e0034e646dbb139066b27 Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 2 Apr 2025 03:52:00 -0700 Subject: [PATCH 18/27] core: Delete stale SuppressWarnings("deprecated") for ATTR_LOAD_BALANCING_CONFIG (#11982) ATTR_LOAD_BALANCING_CONFIG was deleted in bf7a42dbd. --- .../io/grpc/internal/AutoConfiguredLoadBalancerFactory.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java index a382227fd6c..a257637de22 100644 --- a/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java +++ b/core/src/main/java/io/grpc/internal/AutoConfiguredLoadBalancerFactory.java @@ -40,8 +40,6 @@ import java.util.Map; import javax.annotation.Nullable; -// TODO(creamsoup) fully deprecate LoadBalancer.ATTR_LOAD_BALANCING_CONFIG -@SuppressWarnings("deprecation") public final class AutoConfiguredLoadBalancerFactory { private final LoadBalancerRegistry registry; From 920c38479a5ac54602f7bba5d6dac6222da63c1b Mon Sep 17 00:00:00 2001 From: Eric Anderson Date: Wed, 2 Apr 2025 03:54:32 -0700 Subject: [PATCH 19/27] core: Delete the long-deprecated GRPC_PROXY_EXP (#11988) "EXP" stood for experimental and all documentation that referenced it made it clear it was experimental. It's been some years since we started logging a message when it was used to say it will be deleted. There's no time like the present to delete it. --- .../io/grpc/internal/ProxyDetectorImpl.java | 46 +------------------ .../grpc/internal/ProxyDetectorImplTest.java | 44 +----------------- 2 files changed, 4 insertions(+), 86 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java index b3f646d6099..58c7803346f 100644 --- a/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java +++ b/core/src/main/java/io/grpc/internal/ProxyDetectorImpl.java @@ -147,18 +147,9 @@ public ProxySelector get() { } }; - /** - * Experimental environment variable name for enabling proxy support. - * - * @deprecated Use the standard Java proxy configuration instead with flags such as: - * -Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT - */ - @Deprecated - private static final String GRPC_PROXY_ENV_VAR = "GRPC_PROXY_EXP"; // Do not hard code a ProxySelector because the global default ProxySelector can change private final Supplier proxySelector; private final AuthenticationProvider authenticationProvider; - private final InetSocketAddress overrideProxyAddress; // We want an HTTPS proxy, which operates on the entire data stream (See IETF rfc2817). static final String PROXY_SCHEME = "https"; @@ -168,21 +159,15 @@ public ProxySelector get() { * {@link ProxyDetectorImpl.AuthenticationProvider} to detect proxy parameters. */ public ProxyDetectorImpl() { - this(DEFAULT_PROXY_SELECTOR, DEFAULT_AUTHENTICATOR, System.getenv(GRPC_PROXY_ENV_VAR)); + this(DEFAULT_PROXY_SELECTOR, DEFAULT_AUTHENTICATOR); } @VisibleForTesting ProxyDetectorImpl( Supplier proxySelector, - AuthenticationProvider authenticationProvider, - @Nullable String proxyEnvString) { + AuthenticationProvider authenticationProvider) { this.proxySelector = checkNotNull(proxySelector); this.authenticationProvider = checkNotNull(authenticationProvider); - if (proxyEnvString != null) { - overrideProxyAddress = overrideProxy(proxyEnvString); - } else { - overrideProxyAddress = null; - } } @Nullable @@ -191,12 +176,6 @@ public ProxiedSocketAddress proxyFor(SocketAddress targetServerAddress) throws I if (!(targetServerAddress instanceof InetSocketAddress)) { return null; } - if (overrideProxyAddress != null) { - return HttpConnectProxiedSocketAddress.newBuilder() - .setProxyAddress(overrideProxyAddress) - .setTargetAddress((InetSocketAddress) targetServerAddress) - .build(); - } return detectProxy((InetSocketAddress) targetServerAddress); } @@ -272,27 +251,6 @@ private ProxiedSocketAddress detectProxy(InetSocketAddress targetAddr) throws IO .build(); } - /** - * GRPC_PROXY_EXP is deprecated but let's maintain compatibility for now. - */ - private static InetSocketAddress overrideProxy(String proxyHostPort) { - if (proxyHostPort == null) { - return null; - } - - String[] parts = proxyHostPort.split(":", 2); - int port = 80; - if (parts.length > 1) { - port = Integer.parseInt(parts[1]); - } - log.warning( - "Detected GRPC_PROXY_EXP and will honor it, but this feature will " - + "be removed in a future release. Use the JVM flags " - + "\"-Dhttps.proxyHost=HOST -Dhttps.proxyPort=PORT\" to set the https proxy for " - + "this JVM."); - return new InetSocketAddress(parts[0], port); - } - /** * This interface makes unit testing easier by avoiding direct calls to static methods. */ diff --git a/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java b/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java index 0432a474ac5..771050f119d 100644 --- a/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java +++ b/core/src/test/java/io/grpc/internal/ProxyDetectorImplTest.java @@ -73,7 +73,7 @@ public ProxySelector get() { return proxySelector; } }; - proxyDetector = new ProxyDetectorImpl(proxySelectorSupplier, authenticator, null); + proxyDetector = new ProxyDetectorImpl(proxySelectorSupplier, authenticator); unresolvedProxy = InetSocketAddress.createUnresolved("10.0.0.1", proxyPort); proxySocketAddress = HttpConnectProxiedSocketAddress.newBuilder() .setTargetAddress(destination) @@ -82,45 +82,6 @@ public ProxySelector get() { .build(); } - @Test - public void override_hostPort() throws Exception { - final String overrideHost = "10.99.99.99"; - final int overridePort = 1234; - final String overrideHostWithPort = overrideHost + ":" + overridePort; - ProxyDetectorImpl proxyDetector = new ProxyDetectorImpl( - proxySelectorSupplier, - authenticator, - overrideHostWithPort); - ProxiedSocketAddress detected = proxyDetector.proxyFor(destination); - assertNotNull(detected); - assertEquals( - HttpConnectProxiedSocketAddress.newBuilder() - .setTargetAddress(destination) - .setProxyAddress( - new InetSocketAddress(InetAddress.getByName(overrideHost), overridePort)) - .build(), - detected); - } - - @Test - public void override_hostOnly() throws Exception { - final String overrideHostWithoutPort = "10.99.99.99"; - final int defaultPort = 80; - ProxyDetectorImpl proxyDetector = new ProxyDetectorImpl( - proxySelectorSupplier, - authenticator, - overrideHostWithoutPort); - ProxiedSocketAddress detected = proxyDetector.proxyFor(destination); - assertNotNull(detected); - assertEquals( - HttpConnectProxiedSocketAddress.newBuilder() - .setTargetAddress(destination) - .setProxyAddress( - new InetSocketAddress(InetAddress.getByName(overrideHostWithoutPort), defaultPort)) - .build(), - detected); - } - @Test public void returnNullWhenNoProxy() throws Exception { when(proxySelector.select(any(URI.class))) @@ -227,8 +188,7 @@ public ProxySelector get() { return null; } }, - authenticator, - null); + authenticator); assertNull(proxyDetector.proxyFor(destination)); } } From 0ab9e11ed808d090804d41b18e10dbefc15f9abd Mon Sep 17 00:00:00 2001 From: MV Shiva Date: Wed, 2 Apr 2025 16:29:55 +0530 Subject: [PATCH 20/27] xds: propagate audience from cluster resource in gcp auth filter (#11972) --- .../io/grpc/xds/GcpAuthenticationFilter.java | 108 +++++-- .../grpc/xds/GcpAuthenticationFilterTest.java | 301 +++++++++++++++++- .../grpc/xds/GrpcXdsClientImplDataTest.java | 15 +- .../test/java/io/grpc/xds/XdsTestUtils.java | 2 +- 4 files changed, 372 insertions(+), 54 deletions(-) diff --git a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java index add885c6416..b5568efe400 100644 --- a/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java +++ b/xds/src/main/java/io/grpc/xds/GcpAuthenticationFilter.java @@ -16,8 +16,13 @@ package io.grpc.xds; +import static com.google.common.base.Preconditions.checkNotNull; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; + import com.google.auth.oauth2.ComputeEngineCredentials; import com.google.auth.oauth2.IdTokenCredentials; +import com.google.common.annotations.VisibleForTesting; import com.google.common.primitives.UnsignedLongs; import com.google.protobuf.Any; import com.google.protobuf.InvalidProtocolBufferException; @@ -34,8 +39,11 @@ import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.Status; +import io.grpc.StatusOr; import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; import io.grpc.xds.MetadataRegistry.MetadataValueParser; +import io.grpc.xds.XdsConfig.XdsClusterConfig; import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; import java.util.LinkedHashMap; import java.util.Map; @@ -52,6 +60,13 @@ final class GcpAuthenticationFilter implements Filter { static final String TYPE_URL = "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig"; + final String filterInstanceName; + + GcpAuthenticationFilter(String name) { + filterInstanceName = checkNotNull(name, "name"); + } + + static final class Provider implements Filter.Provider { @Override public String[] typeUrls() { @@ -65,7 +80,7 @@ public boolean isClientFilter() { @Override public GcpAuthenticationFilter newInstance(String name) { - return new GcpAuthenticationFilter(); + return new GcpAuthenticationFilter(name); } @Override @@ -119,34 +134,57 @@ public ClientInterceptor buildClientInterceptor(FilterConfig config, public ClientCall interceptCall( MethodDescriptor method, CallOptions callOptions, Channel next) { - /*String clusterName = callOptions.getOption(XdsAttributes.ATTR_CLUSTER_NAME); + String clusterName = callOptions.getOption(CLUSTER_SELECTION_KEY); if (clusterName == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s does not contain cluster resource", filterInstanceName))); + } + + if (!clusterName.startsWith("cluster:")) { return next.newCall(method, callOptions); - }*/ - - // TODO: Fetch the CDS resource for the cluster. - // If the CDS resource is not available, fail the RPC with Status.UNAVAILABLE. - - // TODO: Extract the audience from the CDS resource metadata. - // If the audience is not found or is in the wrong format, fail the RPC. - String audience = "TEST_AUDIENCE"; - - try { - CallCredentials existingCallCredentials = callOptions.getCredentials(); - CallCredentials newCallCredentials = - getCallCredentials(callCredentialsCache, audience, credentials); - if (existingCallCredentials != null) { - callOptions = callOptions.withCallCredentials( - new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); - } else { - callOptions = callOptions.withCallCredentials(newCallCredentials); - } } - catch (Exception e) { - // If we fail to attach CallCredentials due to any reason, return a FailingClientCall - return new FailingClientCall<>(Status.UNAUTHENTICATED - .withDescription("Failed to attach CallCredentials.") - .withCause(e)); + XdsConfig xdsConfig = callOptions.getOption(XDS_CONFIG_CALL_OPTION_KEY); + if (xdsConfig == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s does not contain xds configuration", + filterInstanceName, clusterName))); + } + StatusOr xdsCluster = + xdsConfig.getClusters().get(clusterName.substring("cluster:".length())); + if (xdsCluster == null) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format( + "GCP Authn for %s with %s - xds cluster config does not contain xds cluster", + filterInstanceName, clusterName))); + } + if (!xdsCluster.hasValue()) { + return new FailingClientCall<>(xdsCluster.getStatus()); + } + Object audienceObj = + xdsCluster.getValue().getClusterResource().parsedMetadata().get(filterInstanceName); + if (audienceObj == null) { + return next.newCall(method, callOptions); + } + if (!(audienceObj instanceof AudienceWrapper)) { + return new FailingClientCall<>( + Status.UNAVAILABLE.withDescription( + String.format("GCP Authn found wrong type in %s metadata: %s=%s", + clusterName, filterInstanceName, audienceObj.getClass()))); + } + AudienceWrapper audience = (AudienceWrapper) audienceObj; + CallCredentials existingCallCredentials = callOptions.getCredentials(); + CallCredentials newCallCredentials = + getCallCredentials(callCredentialsCache, audience.audience, credentials); + if (existingCallCredentials != null) { + callOptions = callOptions.withCallCredentials( + new CompositeCallCredentials(existingCallCredentials, newCallCredentials)); + } else { + callOptions = callOptions.withCallCredentials(newCallCredentials); } return next.newCall(method, callOptions); } @@ -186,9 +224,11 @@ public String typeUrl() { } /** An implementation of {@link ClientCall} that fails when started. */ - private static final class FailingClientCall extends ClientCall { + @VisibleForTesting + static final class FailingClientCall extends ClientCall { - private final Status error; + @VisibleForTesting + final Status error; public FailingClientCall(Status error) { this.error = error; @@ -235,13 +275,21 @@ V getOrInsert(K key, Function create) { static class AudienceMetadataParser implements MetadataValueParser { + static final class AudienceWrapper { + final String audience; + + AudienceWrapper(String audience) { + this.audience = checkNotNull(audience); + } + } + @Override public String getTypeUrl() { return "type.googleapis.com/envoy.extensions.filters.http.gcp_authn.v3.Audience"; } @Override - public String parse(Any any) throws ResourceInvalidException { + public AudienceWrapper parse(Any any) throws ResourceInvalidException { Audience audience; try { audience = any.unpack(Audience.class); @@ -253,7 +301,7 @@ public String parse(Any any) throws ResourceInvalidException { throw new ResourceInvalidException( "Audience URL is empty. Metadata value must contain a valid URL."); } - return url; + return new AudienceWrapper(url); } } } diff --git a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java index 52efaf9bd7b..a5e142b4094 100644 --- a/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java +++ b/xds/src/test/java/io/grpc/xds/GcpAuthenticationFilterTest.java @@ -17,25 +17,60 @@ package io.grpc.xds; import static com.google.common.truth.Truth.assertThat; +import static io.grpc.xds.XdsNameResolver.CLUSTER_SELECTION_KEY; +import static io.grpc.xds.XdsNameResolver.XDS_CONFIG_CALL_OPTION_KEY; +import static io.grpc.xds.XdsTestUtils.CLUSTER_NAME; +import static io.grpc.xds.XdsTestUtils.EDS_NAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_HOSTNAME; +import static io.grpc.xds.XdsTestUtils.ENDPOINT_PORT; +import static io.grpc.xds.XdsTestUtils.RDS_NAME; +import static io.grpc.xds.XdsTestUtils.buildRouteConfiguration; +import static io.grpc.xds.XdsTestUtils.getWrrLbConfigAsMap; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.protobuf.Any; import com.google.protobuf.Empty; import com.google.protobuf.Message; import com.google.protobuf.UInt64Value; +import io.envoyproxy.envoy.config.route.v3.RouteConfiguration; import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.GcpAuthnFilterConfig; import io.envoyproxy.envoy.extensions.filters.http.gcp_authn.v3.TokenCacheConfig; import io.grpc.CallOptions; import io.grpc.Channel; +import io.grpc.ClientCall; import io.grpc.ClientInterceptor; import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.StatusOr; +import io.grpc.inprocess.InProcessServerBuilder; import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.Endpoints.LbEndpoint; +import io.grpc.xds.Endpoints.LocalityLbEndpoints; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; +import io.grpc.xds.GcpAuthenticationFilter.FailingClientCall; import io.grpc.xds.GcpAuthenticationFilter.GcpAuthenticationConfig; +import io.grpc.xds.XdsClusterResource.CdsUpdate; +import io.grpc.xds.XdsConfig.XdsClusterConfig; +import io.grpc.xds.XdsConfig.XdsClusterConfig.EndpointConfig; +import io.grpc.xds.XdsEndpointResource.EdsUpdate; +import io.grpc.xds.XdsListenerResource.LdsUpdate; +import io.grpc.xds.XdsRouteConfigureResource.RdsUpdate; +import io.grpc.xds.client.Locality; +import io.grpc.xds.client.XdsResourceType; +import io.grpc.xds.client.XdsResourceType.ResourceInvalidException; +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.JUnit4; @@ -46,6 +81,17 @@ public class GcpAuthenticationFilterTest { private static final GcpAuthenticationFilter.Provider FILTER_PROVIDER = new GcpAuthenticationFilter.Provider(); + private static final String serverName = InProcessServerBuilder.generateName(); + private static final LdsUpdate ldsUpdate = getLdsUpdate(); + private static final EdsUpdate edsUpdate = getEdsUpdate(); + private static final RdsUpdate rdsUpdate = getRdsUpdate(); + private static final CdsUpdate cdsUpdate = getCdsUpdate(); + + @Test + public void testNewFilterInstancesPerFilterName() { + assertThat(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1")) + .isNotEqualTo(new GcpAuthenticationFilter("FILTER_INSTANCE_NAME1")); + } @Test public void filterType_clientOnly() { @@ -92,35 +138,258 @@ public void testParseFilterConfig_withInvalidMessageType() { } @Test - public void testClientInterceptor_createsAndReusesCachedCredentials() { + public void testClientInterceptor_success() throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); - GcpAuthenticationFilter filter = new GcpAuthenticationFilter(); - - // Create interceptor + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); + + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); - // Mock channel and capture CallOptions + verify(mockChannel).newCall(eq(methodDescriptor), callOptionsCaptor.capture()); + CallOptions capturedOptions = callOptionsCaptor.getAllValues().get(0); + assertNotNull(capturedOptions.getCredentials()); + } + + @Test + public void testClientInterceptor_createsAndReusesCachedCredentials() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); Channel mockChannel = Mockito.mock(Channel.class); ArgumentCaptor callOptionsCaptor = ArgumentCaptor.forClass(CallOptions.class); - // Execute interception twice to check caching - interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel); - interceptor.interceptCall(methodDescriptor, CallOptions.DEFAULT, mockChannel); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); - // Capture and verify CallOptions for CallCredentials presence - Mockito.verify(mockChannel, Mockito.times(2)) + verify(mockChannel, Mockito.times(2)) .newCall(eq(methodDescriptor), callOptionsCaptor.capture()); - - // Retrieve the CallOptions captured from both calls CallOptions firstCapturedOptions = callOptionsCaptor.getAllValues().get(0); CallOptions secondCapturedOptions = callOptionsCaptor.getAllValues().get(1); - - // Ensure that CallCredentials was added assertNotNull(firstCapturedOptions.getCredentials()); assertNotNull(secondCapturedOptions.getCredentials()); - - // Ensure that the CallCredentials from both calls are the same, indicating caching assertSame(firstCapturedOptions.getCredentials(), secondCapturedOptions.getCredentials()); } + + @Test + public void testClientInterceptor_withoutClusterSelectionKey() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT; + + ClientCall call = interceptor.interceptCall( + methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain cluster resource"); + } + + @Test + public void testClientInterceptor_clusterSelectionKeyWithoutPrefix() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + Channel mockChannel = mock(Channel.class); + + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + verify(mockChannel).newCall(methodDescriptor, callOptionsWithXds); + } + + @Test + public void testClientInterceptor_xdsConfigDoesNotExist() throws Exception { + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0"); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds configuration"); + } + + @Test + public void testClientInterceptor_incorrectClusterName() throws Exception { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + cdsUpdate, + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster("custer0", StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("does not contain xds cluster"); + } + + @Test + public void testClientInterceptor_statusOrError() throws Exception { + StatusOr errorCluster = + StatusOr.fromStatus(Status.NOT_FOUND.withDescription("Cluster resource not found")); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, errorCluster).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("Cluster resource not found"); + } + + @Test + public void testClientInterceptor_notAudienceWrapper() + throws IOException, ResourceInvalidException { + XdsConfig.XdsClusterConfig clusterConfig = new XdsConfig.XdsClusterConfig( + CLUSTER_NAME, + getCdsUpdateWithIncorrectAudienceWrapper(), + new EndpointConfig(StatusOr.fromValue(edsUpdate))); + XdsConfig defaultXdsConfig = new XdsConfig.XdsConfigBuilder() + .setListener(ldsUpdate) + .setRoute(rdsUpdate) + .setVirtualHost(rdsUpdate.virtualHosts.get(0)) + .addCluster(CLUSTER_NAME, StatusOr.fromValue(clusterConfig)).build(); + CallOptions callOptionsWithXds = CallOptions.DEFAULT + .withOption(CLUSTER_SELECTION_KEY, "cluster:cluster0") + .withOption(XDS_CONFIG_CALL_OPTION_KEY, defaultXdsConfig); + GcpAuthenticationConfig config = new GcpAuthenticationConfig(10); + GcpAuthenticationFilter filter = new GcpAuthenticationFilter("FILTER_INSTANCE_NAME"); + ClientInterceptor interceptor = filter.buildClientInterceptor(config, null, null); + MethodDescriptor methodDescriptor = TestMethodDescriptors.voidMethod(); + Channel mockChannel = Mockito.mock(Channel.class); + + ClientCall call = + interceptor.interceptCall(methodDescriptor, callOptionsWithXds, mockChannel); + + assertTrue(call instanceof FailingClientCall); + FailingClientCall clientCall = (FailingClientCall) call; + assertThat(clientCall.error.getDescription()).contains("GCP Authn found wrong type"); + } + + private static LdsUpdate getLdsUpdate() { + Filter.NamedFilterConfig routerFilterConfig = new Filter.NamedFilterConfig( + serverName, RouterFilter.ROUTER_CONFIG); + HttpConnectionManager httpConnectionManager = HttpConnectionManager.forRdsName( + 0L, RDS_NAME, Collections.singletonList(routerFilterConfig)); + return XdsListenerResource.LdsUpdate.forApiListener(httpConnectionManager); + } + + private static RdsUpdate getRdsUpdate() { + RouteConfiguration routeConfiguration = + buildRouteConfiguration(serverName, RDS_NAME, CLUSTER_NAME); + XdsResourceType.Args args = new XdsResourceType.Args(null, "0", "0", null, null, null); + try { + return XdsRouteConfigureResource.getInstance().doParse(args, routeConfiguration); + } catch (ResourceInvalidException ex) { + return null; + } + } + + private static EdsUpdate getEdsUpdate() { + Map lbEndpointsMap = new HashMap<>(); + LbEndpoint lbEndpoint = LbEndpoint.create( + serverName, ENDPOINT_PORT, 0, true, ENDPOINT_HOSTNAME, ImmutableMap.of()); + lbEndpointsMap.put( + Locality.create("", "", ""), + LocalityLbEndpoints.create(ImmutableList.of(lbEndpoint), 10, 0, ImmutableMap.of())); + return new XdsEndpointResource.EdsUpdate(EDS_NAME, lbEndpointsMap, Collections.emptyList()); + } + + private static CdsUpdate getCdsUpdate() { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", new AudienceWrapper("TEST_AUDIENCE")); + try { + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } catch (IOException ex) { + return null; + } + } + + private static CdsUpdate getCdsUpdateWithIncorrectAudienceWrapper() throws IOException { + ImmutableMap.Builder parsedMetadata = ImmutableMap.builder(); + parsedMetadata.put("FILTER_INSTANCE_NAME", "TEST_AUDIENCE"); + CdsUpdate.Builder cdsUpdate = CdsUpdate.forEds( + CLUSTER_NAME, EDS_NAME, null, null, null, null, false) + .lbPolicyConfig(getWrrLbConfigAsMap()); + return cdsUpdate.parsedMetadata(parsedMetadata.build()).build(); + } } diff --git a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java index bfaa17245cf..e53ed9047ca 100644 --- a/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java +++ b/xds/src/test/java/io/grpc/xds/GrpcXdsClientImplDataTest.java @@ -129,6 +129,7 @@ import io.grpc.xds.Endpoints.LbEndpoint; import io.grpc.xds.Endpoints.LocalityLbEndpoints; import io.grpc.xds.Filter.FilterConfig; +import io.grpc.xds.GcpAuthenticationFilter.AudienceMetadataParser.AudienceWrapper; import io.grpc.xds.MetadataRegistry.MetadataValueParser; import io.grpc.xds.RouteLookupServiceClusterSpecifierPlugin.RlsPluginConfig; import io.grpc.xds.VirtualHost.Route; @@ -2417,8 +2418,7 @@ public Object parse(Any value) { } @Test - public void processCluster_parsesAudienceMetadata() - throws ResourceInvalidException, InvalidProtocolBufferException { + public void processCluster_parsesAudienceMetadata() throws Exception { MetadataRegistry.getInstance(); Audience audience = Audience.newBuilder() @@ -2462,7 +2462,10 @@ public void processCluster_parsesAudienceMetadata() "FILTER_METADATA", ImmutableMap.of( "key1", "value1", "key2", 42.0)); - assertThat(update.parsedMetadata()).isEqualTo(expectedParsedMetadata); + assertThat(update.parsedMetadata().get("FILTER_METADATA")) + .isEqualTo(expectedParsedMetadata.get("FILTER_METADATA")); + assertThat(update.parsedMetadata().get("AUDIENCE_METADATA")) + .isInstanceOf(AudienceWrapper.class); } @Test @@ -2519,8 +2522,7 @@ public void processCluster_parsesAddressMetadata() throws Exception { } @Test - public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() - throws ResourceInvalidException, InvalidProtocolBufferException { + public void processCluster_metadataKeyCollision_resolvesToTypedMetadata() throws Exception { MetadataRegistry metadataRegistry = MetadataRegistry.getInstance(); MetadataValueParser testParser = @@ -2575,8 +2577,7 @@ public Object parse(Any value) { } @Test - public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() - throws ResourceInvalidException, InvalidProtocolBufferException { + public void parseNonAggregateCluster_withHttp11ProxyTransportSocket() throws Exception { XdsClusterResource.isEnabledXdsHttpConnect = true; Http11ProxyUpstreamTransport http11ProxyUpstreamTransport = diff --git a/xds/src/test/java/io/grpc/xds/XdsTestUtils.java b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java index 9f90777be3d..52953ef5407 100644 --- a/xds/src/test/java/io/grpc/xds/XdsTestUtils.java +++ b/xds/src/test/java/io/grpc/xds/XdsTestUtils.java @@ -291,7 +291,7 @@ static Map createMinimalLbEndpointsMap(String ser } @SuppressWarnings("unchecked") - private static ImmutableMap getWrrLbConfigAsMap() throws IOException { + static ImmutableMap getWrrLbConfigAsMap() throws IOException { String lbConfigStr = "{\"wrr_locality_experimental\" : " + "{ \"childPolicy\" : [{\"round_robin\" : {}}]}}"; From b7df168b5c83c37e5c408a780ab97bce745f2cba Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Thu, 3 Apr 2025 06:48:37 +0000 Subject: [PATCH 21/27] core: Fixed internal review points --- core/src/main/java/io/grpc/internal/DelayedStream.java | 1 + .../java/io/grpc/internal/DelayedClientTransportTest.java | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 15b45ea5d3b..27dd84733b2 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -129,6 +129,7 @@ final Runnable setStream(ClientStream stream) { boolean cancelOldStream = false; synchronized (this) { + // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { oldStream = realStream; cancelOldStream = listener != null; diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 394f8e2da86..dd506ec1a24 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -279,14 +279,14 @@ public void testNewStreamThenShutDownNow() { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); - assertEquals(1,delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); - assertEquals(0,delayedTransport.getPendingStreamsCount()); + assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } From 58f39fa661d1020a3cdc3a178f485d0049ee5863 Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Wed, 19 Mar 2025 06:46:21 +0000 Subject: [PATCH 22/27] core: Added changes to DelayedStream.setStream() should cancel the provided stream if not using it --- .../grpc/internal/DelayedClientTransport.java | 21 +++-- .../java/io/grpc/internal/DelayedStream.java | 13 ++- .../internal/DelayedClientTransportTest.java | 91 ++++++++++++++----- .../io/grpc/internal/DelayedStreamTest.java | 51 ++++++++++- .../grpc/internal/ManagedChannelImplTest.java | 11 ++- 5 files changed, 148 insertions(+), 39 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java index 8ff755af3eb..e919a47ae2e 100644 --- a/core/src/main/java/io/grpc/internal/DelayedClientTransport.java +++ b/core/src/main/java/io/grpc/internal/DelayedClientTransport.java @@ -178,13 +178,6 @@ private PendingStream createPendingStream(PickSubchannelArgs args, ClientStreamT if (args.getCallOptions().isWaitForReady() && pickResult != null && pickResult.hasResult()) { pendingStream.lastPickStatus = pickResult.getStatus(); } - pendingStreams.add(pendingStream); - if (getPendingStreamsCount() == 1) { - syncContext.executeLater(reportTransportInUse); - } - for (ClientStreamTracer streamTracer : tracers) { - streamTracer.createPendingStream(); - } return pendingStream; } @@ -363,6 +356,20 @@ private PendingStream(PickSubchannelArgs args, ClientStreamTracer[] tracers) { this.tracers = tracers; } + @Override + public void start(ClientStreamListener listener) { + super.start(listener); + synchronized (lock) { + pendingStreams.add(this); + if (getPendingStreamsCount() == 1) { + syncContext.executeLater(reportTransportInUse); + } + for (ClientStreamTracer streamTracer : tracers) { + streamTracer.createPendingStream(); + } + } + } + /** Runnable may be null. */ private Runnable createRealStream(ClientTransport transport, String authorityOverride) { ClientStream realStream; diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 2ca4630d6a1..0bb6372ead0 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -125,11 +125,22 @@ public void appendTimeoutInsight(InsightBuilder insight) { @CheckReturnValue final Runnable setStream(ClientStream stream) { ClientStreamListener savedListener; + ClientStream oldStream = null; + boolean cancelOldStream = false; + synchronized (this) { - // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { + oldStream = realStream; + cancelOldStream = listener != null; + } + if (oldStream != null && !cancelOldStream) { return null; } + + if (cancelOldStream) { + oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream")); + } + setRealStream(checkNotNull(stream, "stream")); savedListener = listener; if (savedListener == null) { diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 902c2835a92..0128f1fdbd1 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -170,6 +170,7 @@ public void uncaughtException(Thread t, Throwable e) { @Test public void newStreamThenAssignTransportThenShutdown() { ClientStream stream = delayedTransport.newStream(method, headers, callOptions, tracers); + stream.start(streamListener); assertEquals(1, delayedTransport.getPendingStreamsCount()); assertTrue(stream instanceof DelayedStream); delayedTransport.reprocess(mockPicker); @@ -177,12 +178,12 @@ public void uncaughtException(Thread t, Throwable e) { delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener).transportTerminated(); + fakeExecutor.runDueTasks(); assertEquals(0, fakeExecutor.runDueTasks()); verify(mockRealTransport).newStream( same(method), same(headers), same(callOptions), ArgumentMatchers.any()); - stream.start(streamListener); - verify(mockRealStream).start(same(streamListener)); + verify(mockRealStream).start(any(ClientStreamListener.class)); } @Test public void transportTerminatedThenAssignTransport() { @@ -271,14 +272,41 @@ public void uncaughtException(Thread t, Throwable e) { verifyNoMoreInteractions(mockRealStream); } + @Test + public void newStreamThenShutDownNow() { + ClientStream stream = delayedTransport.newStream( + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); + assertEquals(1,delayedTransport.getPendingStreamsCount()); + delayedTransport.shutdownNow(Status.UNAVAILABLE); + verify(transportListener).transportShutdown(any(Status.class)); + verify(transportListener).transportTerminated(); + verify(streamListener).closed( + statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertEquals(0,delayedTransport.getPendingStreamsCount()); + assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); + } + + @Test + public void testDelayedClientTransportPendingStreamsOnShutDown() { + ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers); + ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers); + assertEquals(0, delayedTransport.getPendingStreamsCount()); + clientStream.start(streamListener); + clientStream1.start(streamListener); + assertEquals(2, delayedTransport.getPendingStreamsCount()); + delayedTransport.shutdownNow(Status.UNAVAILABLE); + assertEquals(0, delayedTransport.getPendingStreamsCount()); + } + @Test public void newStreamThenShutdownTransportThenCancelStream() { ClientStream stream = delayedTransport.newStream( - method, new Metadata(), CallOptions.DEFAULT, tracers); + method, new Metadata(), CallOptions.DEFAULT, tracers); + stream.start(streamListener); delayedTransport.shutdown(SHUTDOWN_STATUS); verify(transportListener).transportShutdown(same(SHUTDOWN_STATUS)); verify(transportListener, times(0)).transportTerminated(); assertEquals(1, delayedTransport.getPendingStreamsCount()); - stream.start(streamListener); stream.cancel(Status.CANCELLED); verify(transportListener).transportTerminated(); assertEquals(0, delayedTransport.getPendingStreamsCount()); @@ -348,33 +376,39 @@ public void uncaughtException(Thread t, Throwable e) { ff1.start(mock(ClientStreamListener.class)); ff1.halfClose(); PickSubchannelArgsMatcher ff1args = new PickSubchannelArgsMatcher(method, headers, - failFastCallOptions); + failFastCallOptions); + transportListener.transportInUse(true); verify(transportListener).transportInUse(true); DelayedStream ff2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions, tracers); + method2, headers2, failFastCallOptions, tracers); + ff2.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff2args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); DelayedStream ff3 = (DelayedStream) delayedTransport.newStream( - method, headers, failFastCallOptions, tracers); + method, headers, failFastCallOptions, tracers); + ff3.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff3args = new PickSubchannelArgsMatcher(method, headers, failFastCallOptions); DelayedStream ff4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, failFastCallOptions, tracers); + method2, headers2, failFastCallOptions, tracers); + ff4.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher ff4args = new PickSubchannelArgsMatcher(method2, headers2, failFastCallOptions); // Wait-for-ready streams FakeClock wfr3Executor = new FakeClock(); DelayedStream wfr1 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions, tracers); + method, headers, waitForReadyCallOptions, tracers); + wfr1.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr1args = new PickSubchannelArgsMatcher(method, headers, waitForReadyCallOptions); DelayedStream wfr2 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions, tracers); + method2, headers2, waitForReadyCallOptions, tracers); + wfr2.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr2args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); CallOptions wfr3callOptions = waitForReadyCallOptions.withExecutor( - wfr3Executor.getScheduledExecutorService()); + wfr3Executor.getScheduledExecutorService()); DelayedStream wfr3 = (DelayedStream) delayedTransport.newStream( method, headers, wfr3callOptions, tracers); wfr3.start(mock(ClientStreamListener.class)); @@ -382,7 +416,8 @@ public void uncaughtException(Thread t, Throwable e) { PickSubchannelArgsMatcher wfr3args = new PickSubchannelArgsMatcher(method, headers, wfr3callOptions); DelayedStream wfr4 = (DelayedStream) delayedTransport.newStream( - method2, headers2, waitForReadyCallOptions, tracers); + method2, headers2, waitForReadyCallOptions, tracers); + wfr4.start(mock(ClientStreamListener.class)); PickSubchannelArgsMatcher wfr4args = new PickSubchannelArgsMatcher(method2, headers2, waitForReadyCallOptions); @@ -478,7 +513,8 @@ public void uncaughtException(Thread t, Throwable e) { // New streams will use the last picker DelayedStream wfr5 = (DelayedStream) delayedTransport.newStream( - method, headers, waitForReadyCallOptions, tracers); + method, headers, waitForReadyCallOptions, tracers); + wfr5.start(mock(ClientStreamListener.class)); assertNull(wfr5.getRealStream()); inOrder.verify(picker).pickSubchannel( eqPickSubchannelArgs(method, headers, waitForReadyCallOptions)); @@ -626,12 +662,14 @@ public PickResult answer(InvocationOnMock invocation) throws Throwable { verify(picker, never()).pickSubchannel(any(PickSubchannelArgs.class)); Thread sideThread = new Thread("sideThread") { - @Override - public void run() { - // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers, callOptions, tracers); - } - }; + @Override + public void run() { + // Will call pick Subchannel and wait on barrier + ClientStream clientStream = + delayedTransport.newStream(method, headers, callOptions, tracers); + clientStream.start(streamListener); + } + }; sideThread.start(); PickSubchannelArgsMatcher args = new PickSubchannelArgsMatcher(method, headers, callOptions); @@ -659,12 +697,14 @@ public void run() { ////////// Phase 2: reprocess() with a different picker // Create the second stream Thread sideThread2 = new Thread("sideThread2") { - @Override - public void run() { - // Will call pickSubchannel and wait on barrier - delayedTransport.newStream(method, headers2, callOptions, tracers); - } - }; + @Override + public void run() { + // Will call pickSubchannel and wait on barrier + ClientStream clientStream = delayedTransport + .newStream(method, headers2, callOptions, tracers); + clientStream.start(streamListener); + } + }; sideThread2.start(); // The second stream will see the first picker verify(picker, timeout(5000)).pickSubchannel(argThat(args2)); @@ -730,6 +770,7 @@ public void newStream_racesWithReprocessIdleMode() throws Exception { ClientStream stream = delayedTransport.newStream( method, headers, callOptions, tracers); stream.start(streamListener); + transportListener.transportInUse(true); assertTrue(delayedTransport.hasPendingStreams()); verify(transportListener).transportInUse(true); } diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index a47bea9f4ab..2902be027a9 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -46,6 +46,7 @@ import java.io.ByteArrayInputStream; import java.io.InputStream; import java.util.concurrent.TimeUnit; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.runner.RunWith; @@ -84,6 +85,39 @@ public void setStream_setAuthority() { inOrder.verify(realStream).start(any(ClientStreamListener.class)); } + @Test + public void testSetStreamReplaceOldStreamProperly() { + ClientStream oldStream = mock(ClientStream.class); + ClientStream newStream = mock(ClientStream.class); + + // First stream set, but never started + callMeMaybe(stream.setStream(oldStream)); + callMeMaybe(stream.setStream(newStream)); + // Verify old stream was canceled + verify(oldStream,never()).cancel(any(Status.class)); + // Ensure new stream is properly set + verifyNoMoreInteractions(newStream); + } + + @Test + public void testSetStreamStartCancelsOldStreamProperly() { + ClientStream oldStream = mock(ClientStream.class); + ClientStream newStream = mock(ClientStream.class); + + // First stream set, but never started + callMeMaybe(stream.setStream(oldStream)); + stream.start(listener); + try { + callMeMaybe(stream.setStream(newStream)); + } catch (IllegalStateException e) { + assertNotNull(e.getMessage()); + } + // Verify old stream was canceled + verify(oldStream).cancel(any(Status.class)); + // Ensure new stream is properly set + verifyNoMoreInteractions(newStream); + } + @Test(expected = IllegalStateException.class) public void start_afterStart() { stream.start(listener); @@ -333,17 +367,28 @@ public void setStreamTwice() { stream.start(listener); callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - callMeMaybe(stream.setStream(mock(ClientStream.class))); + try { + callMeMaybe(stream.setStream(mock(ClientStream.class))); + } catch (IllegalStateException e) { + assertEquals("realStream already set to realStream",e.getMessage()); + } stream.flush(); verify(realStream).flush(); } @Test public void cancelThenSetStream() { - stream.start(listener); - stream.cancel(Status.CANCELLED); + try { + stream.cancel(Status.CANCELLED); + Assert.fail("Should have thrown"); + } catch (IllegalStateException e) { + assertEquals("May only be called after start", e.getMessage()); + } callMeMaybe(stream.setStream(realStream)); + stream.start(listener); stream.isReady(); + verify(realStream).start(same(listener)); + verify(realStream).isReady(); verifyNoMoreInteractions(realStream); } diff --git a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java index 21ccf1095df..d1bf205205a 100644 --- a/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java +++ b/core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java @@ -2920,8 +2920,13 @@ public void idleMode_resetsDelayedTransportPicker() { // Move channel to idle timer.forwardNanos(TimeUnit.MILLISECONDS.toNanos(idleTimeoutMillis)); + executor.runDueTasks(); assertEquals(IDLE, channel.getState(false)); + //Force transport re-creation explicitly + channel.getState(true); + executor.runDueTasks(); + // This call should be buffered, but will move the channel out of idle ClientCall call2 = channel.newCall(method, CallOptions.DEFAULT); call2.start(mockCallListener2, new Metadata()); @@ -2947,15 +2952,15 @@ public void idleMode_resetsDelayedTransportPicker() { transportListener.transportReady(); when(mockPicker.pickSubchannel(any(PickSubchannelArgs.class))) - .thenReturn(PickResult.withSubchannel(subchannel)); + .thenReturn(PickResult.withSubchannel(subchannel)); updateBalancingStateSafely(helper2, READY, mockPicker); assertEquals(READY, channel.getState(false)); executor.runDueTasks(); // Verify the buffered call was drained verify(mockTransport).newStream( - same(method), any(Metadata.class), any(CallOptions.class), - ArgumentMatchers.any()); + same(method), any(Metadata.class), any(CallOptions.class), + ArgumentMatchers.any()); verify(mockStream).start(any(ClientStreamListener.class)); } From 013948debb707c5b1872c755445c43373428532a Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Thu, 20 Mar 2025 06:23:05 +0000 Subject: [PATCH 23/27] core: Added changes to DelayedStream.setStream() should cancel the provided stream if not using it --- .../internal/DelayedClientTransportTest.java | 13 +++++++++++-- .../io/grpc/internal/DelayedStreamTest.java | 19 ++++++++----------- 2 files changed, 19 insertions(+), 13 deletions(-) diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 0128f1fdbd1..394f8e2da86 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -226,8 +226,10 @@ public void uncaughtException(Thread t, Throwable e) { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); + assertEquals(1, delayedTransport.getPendingStreamsCount()); stream.cancel(Status.CANCELLED); + assertEquals(0, delayedTransport.getPendingStreamsCount()); verify(streamListener).closed( same(Status.CANCELLED), same(RpcProgress.PROCESSED), any(Metadata.class)); @@ -273,7 +275,7 @@ public void uncaughtException(Thread t, Throwable e) { } @Test - public void newStreamThenShutDownNow() { + public void testNewStreamThenShutDownNow() { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); @@ -283,6 +285,7 @@ public void newStreamThenShutDownNow() { verify(transportListener).transportTerminated(); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); + assertEquals(0,delayedTransport.getPendingStreamsCount()); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } @@ -291,11 +294,14 @@ public void newStreamThenShutDownNow() { public void testDelayedClientTransportPendingStreamsOnShutDown() { ClientStream clientStream = delayedTransport.newStream(method, headers, callOptions, tracers); ClientStream clientStream1 = delayedTransport.newStream(method, headers, callOptions, tracers); + assertEquals(0, delayedTransport.getPendingStreamsCount()); clientStream.start(streamListener); clientStream1.start(streamListener); + assertEquals(2, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdownNow(Status.UNAVAILABLE); + assertEquals(0, delayedTransport.getPendingStreamsCount()); } @@ -350,7 +356,9 @@ public void testDelayedClientTransportPendingStreamsOnShutDown() { assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } - @Test public void reprocessSemantics() { + @Test + @SuppressWarnings("DirectInvocationOnMock") + public void reprocessSemantics() { CallOptions failFastCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 1); CallOptions waitForReadyCallOptions = CallOptions.DEFAULT.withOption(SHARD_ID, 2) .withWaitForReady(); @@ -754,6 +762,7 @@ public void reprocess_addOptionalLabelCallsTracer() throws Exception { } @Test + @SuppressWarnings("DirectInvocationOnMock") public void newStream_racesWithReprocessIdleMode() throws Exception { SubchannelPicker picker = new SubchannelPicker() { @Override public PickResult pickSubchannel(PickSubchannelArgs args) { diff --git a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java index 2902be027a9..bcc0b7f8675 100644 --- a/core/src/test/java/io/grpc/internal/DelayedStreamTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedStreamTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -107,11 +108,8 @@ public void testSetStreamStartCancelsOldStreamProperly() { // First stream set, but never started callMeMaybe(stream.setStream(oldStream)); stream.start(listener); - try { - callMeMaybe(stream.setStream(newStream)); - } catch (IllegalStateException e) { - assertNotNull(e.getMessage()); - } + assertThrows(IllegalStateException.class, + () -> callMeMaybe(stream.setStream(mock(ClientStream.class)))); // Verify old stream was canceled verify(oldStream).cancel(any(Status.class)); // Ensure new stream is properly set @@ -363,15 +361,14 @@ public void setStreamThenStartThenCancelled() { } @Test - public void setStreamTwice() { + public void testSetStreamTwice() { stream.start(listener); callMeMaybe(stream.setStream(realStream)); verify(realStream).start(any(ClientStreamListener.class)); - try { - callMeMaybe(stream.setStream(mock(ClientStream.class))); - } catch (IllegalStateException e) { - assertEquals("realStream already set to realStream",e.getMessage()); - } + IllegalStateException e = assertThrows(IllegalStateException.class, () -> + callMeMaybe(stream.setStream(mock(ClientStream.class))) + ); + assertEquals("realStream already set to realStream", e.getMessage()); stream.flush(); verify(realStream).flush(); } From a2220ebbc8600806aa86f9e29f964b847cd19b08 Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Fri, 21 Mar 2025 07:01:18 +0000 Subject: [PATCH 24/27] core: Fixed internal review points --- core/src/main/java/io/grpc/internal/DelayedStream.java | 2 -- 1 file changed, 2 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 0bb6372ead0..15b45ea5d3b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -136,11 +136,9 @@ final Runnable setStream(ClientStream stream) { if (oldStream != null && !cancelOldStream) { return null; } - if (cancelOldStream) { oldStream.cancel(Status.CANCELLED.withDescription("Replaced by a new Stream")); } - setRealStream(checkNotNull(stream, "stream")); savedListener = listener; if (savedListener == null) { From 67b2f3aef0592379e2a0d1b51ff3262d156e2804 Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Thu, 3 Apr 2025 06:48:37 +0000 Subject: [PATCH 25/27] core: Fixed internal review points --- core/src/main/java/io/grpc/internal/DelayedStream.java | 1 + .../java/io/grpc/internal/DelayedClientTransportTest.java | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 15b45ea5d3b..27dd84733b2 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -129,6 +129,7 @@ final Runnable setStream(ClientStream stream) { boolean cancelOldStream = false; synchronized (this) { + // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { oldStream = realStream; cancelOldStream = listener != null; diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index 394f8e2da86..dd506ec1a24 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -279,14 +279,14 @@ public void testNewStreamThenShutDownNow() { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); - assertEquals(1,delayedTransport.getPendingStreamsCount()); + assertEquals(1, delayedTransport.getPendingStreamsCount()); delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); - assertEquals(0,delayedTransport.getPendingStreamsCount()); + assertEquals(0, delayedTransport.getPendingStreamsCount()); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } From 696dd52454676c4b196b2a0f3d50bcaeb956d571 Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Thu, 3 Apr 2025 06:59:30 +0000 Subject: [PATCH 26/27] Revert "core: Fixed internal review points" This reverts commit 67b2f3aef0592379e2a0d1b51ff3262d156e2804. --- core/src/main/java/io/grpc/internal/DelayedStream.java | 1 - .../java/io/grpc/internal/DelayedClientTransportTest.java | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 27dd84733b2..15b45ea5d3b 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -129,7 +129,6 @@ final Runnable setStream(ClientStream stream) { boolean cancelOldStream = false; synchronized (this) { - // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { oldStream = realStream; cancelOldStream = listener != null; diff --git a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java index dd506ec1a24..394f8e2da86 100644 --- a/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java +++ b/core/src/test/java/io/grpc/internal/DelayedClientTransportTest.java @@ -279,14 +279,14 @@ public void testNewStreamThenShutDownNow() { ClientStream stream = delayedTransport.newStream( method, new Metadata(), CallOptions.DEFAULT, tracers); stream.start(streamListener); - assertEquals(1, delayedTransport.getPendingStreamsCount()); + assertEquals(1,delayedTransport.getPendingStreamsCount()); delayedTransport.shutdownNow(Status.UNAVAILABLE); verify(transportListener).transportShutdown(any(Status.class)); verify(transportListener).transportTerminated(); verify(streamListener).closed( statusCaptor.capture(), any(RpcProgress.class), any(Metadata.class)); - assertEquals(0, delayedTransport.getPendingStreamsCount()); + assertEquals(0,delayedTransport.getPendingStreamsCount()); assertEquals(Status.Code.UNAVAILABLE, statusCaptor.getValue().getCode()); } From d1602adbfcec1d8e5fb9738b8fc08d42cb3e738b Mon Sep 17 00:00:00 2001 From: ssangamesh Date: Thu, 3 Apr 2025 08:38:39 +0000 Subject: [PATCH 27/27] core: Fixed internal review points --- core/src/main/java/io/grpc/internal/DelayedStream.java | 1 + 1 file changed, 1 insertion(+) diff --git a/core/src/main/java/io/grpc/internal/DelayedStream.java b/core/src/main/java/io/grpc/internal/DelayedStream.java index 15b45ea5d3b..27dd84733b2 100644 --- a/core/src/main/java/io/grpc/internal/DelayedStream.java +++ b/core/src/main/java/io/grpc/internal/DelayedStream.java @@ -129,6 +129,7 @@ final Runnable setStream(ClientStream stream) { boolean cancelOldStream = false; synchronized (this) { + // If realStream != null, then either setStream() or cancel() has been called. if (realStream != null) { oldStream = realStream; cancelOldStream = listener != null;