Skip to content

Commit

Permalink
Prevent concurrent execution of the same mutable request object (#3197)
Browse files Browse the repository at this point in the history
Motivation:

Our `HttpRequestMetaData` object is mutable, and we expect users to create a new request every time they need to make a new call. Sequential retries are acceptable, but concurrent execution can corrupt internal state. While these expectations are more clear for HTTP users, with gRPC it gets less obvious that they can not subscribe to the same returned `Single<Message>` concurrently.

Modifications:
- Enhance `FilterableClientToClient` to protect users from concurrent execution of the same request, while still allowing sequential retries.
- Verify concurrent execution is not allowed for HTTP and gRPC.

Result:

Users get `RejectedSubscribeException` if they subscribe to the same Single that shares underlying meta-data object concurrently. This is the best effort to let users know they misused the client.

Risk:

The change may unexpectedly break existing use-cases. To give users some time to adjust their code, we temporarily introduce a system property to opt-out from this new behavior: `-Dio.servicetalk.http.netty.skipConcurrentRequestCheck=true`.
  • Loading branch information
idelpivnitskiy authored Feb 28, 2025
1 parent 57fac32 commit 7a7341a
Show file tree
Hide file tree
Showing 6 changed files with 627 additions and 21 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
/*
* Copyright © 2025 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.grpc.netty;

import io.servicetalk.concurrent.CompletableSource;
import io.servicetalk.concurrent.api.Executor;
import io.servicetalk.concurrent.api.ExecutorExtension;
import io.servicetalk.concurrent.api.Processors;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.SourceAdapters;
import io.servicetalk.concurrent.internal.RejectedSubscribeException;
import io.servicetalk.grpc.api.DefaultGrpcClientMetadata;
import io.servicetalk.grpc.api.GrpcClientMetadata;
import io.servicetalk.grpc.api.GrpcServiceContext;
import io.servicetalk.grpc.api.GrpcStatusException;
import io.servicetalk.grpc.netty.TesterProto.TestRequest;
import io.servicetalk.grpc.netty.TesterProto.TestResponse;
import io.servicetalk.grpc.netty.TesterProto.Tester.ClientFactory;
import io.servicetalk.grpc.netty.TesterProto.Tester.TesterClient;
import io.servicetalk.grpc.netty.TesterProto.Tester.TesterService;
import io.servicetalk.http.api.HttpServiceContext;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.api.StreamingHttpResponseFactory;
import io.servicetalk.http.api.StreamingHttpServiceFilter;
import io.servicetalk.transport.api.ServerContext;

import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.api.ExecutorExtension.withCachedExecutor;
import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress;
import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertThrows;

class ConcurrentGrpcRequestTest {

private enum AsyncVariant {
TEST,
TEST_REQUEST_STREAM,
TEST_RESPONSE_STREAM,
TEST_BI_DI_STREAM,
BLOCKING_TEST,
BLOCKING_TEST_REQUEST_STREAM,
BLOCKING_TEST_RESPONSE_STREAM,
BLOCKING_TEST_BI_DI_STREAM
}

@RegisterExtension
static final ExecutorExtension<Executor> executorExtension = withCachedExecutor().setClassLevel(true);

private final CountDownLatch receivedFirstRequest = new CountDownLatch(1);
private final AtomicInteger receivedRequests = new AtomicInteger();
private final CompletableSource.Processor responseProcessor = Processors.newCompletableProcessor();
private final ServerContext serverCtx;

ConcurrentGrpcRequestTest() throws Exception {
serverCtx = GrpcServers.forAddress(localAddress(0))
.initializeHttp(builder -> builder.appendServiceFilter(s -> new StreamingHttpServiceFilter(s) {
@Override
public Single<StreamingHttpResponse> handle(HttpServiceContext ctx,
StreamingHttpRequest request,
StreamingHttpResponseFactory responseFactory) {
receivedFirstRequest.countDown();
Single<StreamingHttpResponse> response = delegate().handle(ctx, request, responseFactory);
if (receivedRequests.incrementAndGet() == 1) {
return response.concat(SourceAdapters.fromSource(responseProcessor));
}
return response;
}
}))
.listenAndAwait(new TesterService() {

@Override
public Single<TestResponse> test(GrpcServiceContext ctx, TestRequest request) {
return newResponse();
}

@Override
public Single<TestResponse> testRequestStream(GrpcServiceContext ctx,
Publisher<TestRequest> request) {
return newResponse();
}

@Override
public Publisher<TestResponse> testResponseStream(GrpcServiceContext ctx, TestRequest request) {
return newResponse().toPublisher();
}

@Override
public Publisher<TestResponse> testBiDiStream(GrpcServiceContext ctx,
Publisher<TestRequest> request) {
return newResponse().toPublisher();
}

private Single<TestResponse> newResponse() {
return Single.succeeded(TestResponse.newBuilder().setMessage("msg").build());
}
});
}

@AfterEach
void tearDown() throws Exception {
serverCtx.close();
}

private static List<Arguments> asyncVariants() {
List<Arguments> arguments = new ArrayList<>();
for (AsyncVariant variant : AsyncVariant.values()) {
arguments.add(Arguments.of(true, variant));
// Blocking calls without metadata always create a new underlying request, there is no risk
if (!variant.name().startsWith("BLOCKING")) {
arguments.add(Arguments.of(false, variant));
}
}
return arguments;
}

@ParameterizedTest(name = "{displayName} [{index}] withMetadata={0} variant={1}")
@MethodSource("asyncVariants")
void test(boolean withMetadata, AsyncVariant variant) throws Exception {
GrpcClientMetadata metadata = withMetadata ? new DefaultGrpcClientMetadata() : null;
try (TesterClient client = GrpcClients.forAddress(serverHostAndPort(serverCtx)).build(new ClientFactory())) {
Single<TestResponse> firstSingle = newSingle(variant, client, metadata);
Future<TestResponse> first = firstSingle.toFuture();
receivedFirstRequest.await();
Future<TestResponse> firstConcurrent = firstSingle.toFuture();
Future<TestResponse> secondConcurrent = newSingle(variant, client, metadata).toFuture();

assertRejected(firstConcurrent);
if (metadata != null) {
assertRejected(secondConcurrent);
} else {
// Requests are independent when metadata is not shared between them
assertResponse(secondConcurrent);
}
responseProcessor.onComplete();
assertResponse(first);

// Sequential requests should be successful:
assertResponse(firstSingle.toFuture());
assertResponse(newSingle(variant, client, metadata).toFuture());
}
assertThat(receivedRequests.get(), is(metadata != null ? 3 : 4));
}

private static Single<TestResponse> newSingle(AsyncVariant variant, TesterClient client,
@Nullable GrpcClientMetadata metadata) {
switch (variant) {
case TEST:
return metadata == null ?
client.test(newRequest()) :
client.test(metadata, newRequest());
case TEST_REQUEST_STREAM:
return metadata == null ?
client.testRequestStream(newStreamingRequest()) :
client.testRequestStream(metadata, newStreamingRequest());
case TEST_RESPONSE_STREAM:
return (metadata == null ?
client.testResponseStream(newRequest()) :
client.testResponseStream(metadata, newRequest()))
.firstOrError();
case TEST_BI_DI_STREAM:
return (metadata == null ?
client.testBiDiStream(newStreamingRequest()) :
client.testBiDiStream(metadata, newStreamingRequest()))
.firstOrError();
case BLOCKING_TEST:
return executorExtension.executor().submit(() -> metadata == null ?
client.asBlockingClient().test(newRequest()) :
client.asBlockingClient().test(metadata, newRequest()));
case BLOCKING_TEST_REQUEST_STREAM:
return executorExtension.executor().submit(() -> metadata == null ?
client.asBlockingClient().testRequestStream(newIterableRequest()) :
client.asBlockingClient().testRequestStream(metadata, newIterableRequest()));
case BLOCKING_TEST_RESPONSE_STREAM:
return executorExtension.executor().submit(() -> (metadata == null ?
client.asBlockingClient().testResponseStream(newRequest()) :
client.asBlockingClient().testResponseStream(metadata, newRequest()))
.iterator().next());
case BLOCKING_TEST_BI_DI_STREAM:
return executorExtension.executor().submit(() -> (metadata == null ?
client.asBlockingClient().testBiDiStream(newIterableRequest()) :
client.asBlockingClient().testBiDiStream(metadata, newIterableRequest()))
.iterator().next());
default:
throw new AssertionError("Unexpected variant: " + variant);
}
}

private static TestRequest newRequest() {
return TestRequest.newBuilder().setName("foo").build();
}

private static Publisher<TestRequest> newStreamingRequest() {
return Publisher.from(newRequest());
}

private static Iterable<TestRequest> newIterableRequest() {
return Collections.singletonList(newRequest());
}

private static void assertRejected(Future<?> future) {
ExecutionException ee = assertThrows(ExecutionException.class, future::get);
assertThat(ee.getCause(), is(instanceOf(GrpcStatusException.class)));
GrpcStatusException gse = (GrpcStatusException) ee.getCause();
assertThat(gse.getCause(), is(instanceOf(RejectedSubscribeException.class)));
}

private static void assertResponse(Future<TestResponse> future) throws Exception {
assertThat(future.get().getMessage(), is("msg"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ final ContextMap context0() {
@Override
public final ContextMap context() {
if (context == null) {
// If this implementation ever changes to a concurrent one, remove external synchronization from
// FilterableClientToClient.executeRequest(...) and make it consistent with DefaultGrpcMetadata.
context = new DefaultContextMap();
}
return context;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import io.servicetalk.concurrent.api.Completable;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.internal.RejectedSubscribeException;
import io.servicetalk.context.api.ContextMap;
import io.servicetalk.http.api.BlockingHttpClient;
import io.servicetalk.http.api.BlockingStreamingHttpClient;
import io.servicetalk.http.api.FilterableStreamingHttpClient;
Expand All @@ -34,18 +36,29 @@
import io.servicetalk.http.api.ReservedStreamingHttpConnection;
import io.servicetalk.http.api.StreamingHttpClient;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpRequester;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.http.api.StreamingHttpResponseFactory;

import static io.servicetalk.context.api.ContextMap.Key.newKey;
import static io.servicetalk.http.api.HttpApiConversions.toBlockingClient;
import static io.servicetalk.http.api.HttpApiConversions.toBlockingStreamingClient;
import static io.servicetalk.http.api.HttpApiConversions.toClient;
import static io.servicetalk.http.api.HttpApiConversions.toReservedBlockingConnection;
import static io.servicetalk.http.api.HttpApiConversions.toReservedBlockingStreamingConnection;
import static io.servicetalk.http.api.HttpApiConversions.toReservedConnection;
import static io.servicetalk.http.api.HttpContextKeys.HTTP_EXECUTION_STRATEGY_KEY;
import static java.lang.Boolean.getBoolean;

final class FilterableClientToClient implements StreamingHttpClient {

// FIXME: 0.43 - remove this temporary system property
private static final boolean SKIP_CONCURRENT_REQUEST_CHECK =
getBoolean("io.servicetalk.http.netty.skipConcurrentRequestCheck");
private static final ContextMap.Key<Object> HTTP_IN_FLIGHT_REQUEST =
newKey("HTTP_IN_FLIGHT_REQUEST", Object.class);

private final Object lock = new Object();
private final FilterableStreamingHttpClient client;
private final HttpExecutionContext executionContext;

Expand All @@ -71,17 +84,14 @@ public BlockingHttpClient asBlockingClient() {

@Override
public Single<StreamingHttpResponse> request(final StreamingHttpRequest request) {
return Single.defer(() -> {
request.context().putIfAbsent(HTTP_EXECUTION_STRATEGY_KEY, executionContext().executionStrategy());
return client.request(request).shareContextOnSubscribe();
});
return executeRequest(client, request, executionContext().executionStrategy(), lock);
}

@Override
public Single<ReservedStreamingHttpConnection> reserveConnection(final HttpRequestMetaData metaData) {
return Single.defer(() -> {
HttpExecutionStrategy clientstrategy = executionContext().executionStrategy();
metaData.context().putIfAbsent(HTTP_EXECUTION_STRATEGY_KEY, clientstrategy);
setExecutionStrategy(metaData, clientstrategy);
return client.reserveConnection(metaData).map(rc -> new ReservedStreamingHttpConnection() {
@Override
public ReservedHttpConnection asConnection() {
Expand All @@ -108,10 +118,7 @@ public Single<StreamingHttpResponse> request(final StreamingHttpRequest request)
// Use the strategy from the client as the underlying ReservedStreamingHttpConnection may be user
// created and hence could have an incorrect default strategy. Doing this makes sure we never call
// the method without strategy just as we do for the regular connection.
return Single.defer(() -> {
request.context().putIfAbsent(HTTP_EXECUTION_STRATEGY_KEY, clientstrategy);
return rc.request(request).shareContextOnSubscribe();
});
return executeRequest(rc, request, clientstrategy, lock);
}

@Override
Expand Down Expand Up @@ -196,4 +203,64 @@ public Completable closeAsyncGracefully() {
public StreamingHttpRequest newRequest(final HttpRequestMethod method, final String requestTarget) {
return client.newRequest(method, requestTarget);
}

private static Single<StreamingHttpResponse> executeRequest(final StreamingHttpRequester requester,
final StreamingHttpRequest request,
final HttpExecutionStrategy strategy,
final Object lock) {
return Single.defer(() -> {
if (SKIP_CONCURRENT_REQUEST_CHECK) {
return setStrategyAndExecute(requester, request, strategy).shareContextOnSubscribe();
}

// Prevent concurrent execution of the same request through the same layer.
// In general, we do not expect users to execute the same mutable request concurrently. Therefore, the cost
// of synchronized block should be negligible for most requests, unless they messed up with reactive streams
// chain and accidentally subscribed to the same request concurrently. This protection helps them avoid
// ambiguous runtime behavior caused by a corrupted mutable request state.
final Object inFlight;
// Note that because request.context() may lazily allocate a new ContextMap, there is a risk that
// synchronization will happen on two different contexts. However, this is acceptable compromise because:
// - Most likely users subscribe 2+ times to the same request from the same thread.
// - This is the best effort protection, giving users at least one rejection should be enough to let them
// know their code is incorrect and should be rewritten.
final ContextMap context = request.context();
synchronized (context) {
// We do not override lock because other layers may already set their own one.
inFlight = context.putIfAbsent(HTTP_IN_FLIGHT_REQUEST, lock);
}
if (lock.equals(inFlight)) {
return Single.<StreamingHttpResponse>failed(new RejectedSubscribeException(
"Concurrent execution is detected for the same mutable request. Only a single execution is " +
"allowed at any point of time. Otherwise, request data structures can be corrupted. " +
"To avoid this error, create a new request for every execution or wrap every request " +
"creation with Single.defer() operator."))
.shareContextOnSubscribe();
}

Single<StreamingHttpResponse> response = setStrategyAndExecute(requester, request, strategy);
if (inFlight == null) {
// Remove only if we are the one who set the lock.
response = response.beforeFinally(() -> {
synchronized (context) {
final Object removedLock = context.remove(HTTP_IN_FLIGHT_REQUEST);
assert removedLock == lock;
}
});
}
return response.shareContextOnSubscribe();
});
}

private static Single<StreamingHttpResponse> setStrategyAndExecute(final StreamingHttpRequester requester,
final StreamingHttpRequest request,
final HttpExecutionStrategy strategy) {
setExecutionStrategy(request, strategy);
return requester.request(request);
}

private static void setExecutionStrategy(final HttpRequestMetaData request, final HttpExecutionStrategy strategy) {
// We do not override HttpExecutionStrategy because users may prefer to use their own.
request.context().putIfAbsent(HTTP_EXECUTION_STRATEGY_KEY, strategy);
}
}
Loading

0 comments on commit 7a7341a

Please sign in to comment.