diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java index 428a2902cc..dfa9b89601 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java @@ -60,6 +60,7 @@ import io.servicetalk.transport.api.HostAndPort; import io.servicetalk.transport.api.ServerContext; +import io.netty.bootstrap.Bootstrap; import io.netty.bootstrap.ServerBootstrap; import io.netty.channel.Channel; import io.netty.channel.ChannelDuplexHandler; @@ -69,24 +70,32 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; import io.netty.handler.codec.http.HttpHeaderNames; import io.netty.handler.codec.http2.DefaultHttp2DataFrame; +import io.netty.handler.codec.http2.DefaultHttp2GoAwayFrame; import io.netty.handler.codec.http2.DefaultHttp2Headers; import io.netty.handler.codec.http2.DefaultHttp2HeadersFrame; import io.netty.handler.codec.http2.DefaultHttp2ResetFrame; import io.netty.handler.codec.http2.DefaultHttp2SettingsFrame; import io.netty.handler.codec.http2.Http2DataFrame; +import io.netty.handler.codec.http2.Http2Frame; import io.netty.handler.codec.http2.Http2FrameCodecBuilder; import io.netty.handler.codec.http2.Http2Headers; import io.netty.handler.codec.http2.Http2HeadersFrame; import io.netty.handler.codec.http2.Http2MultiplexHandler; +import io.netty.handler.codec.http2.Http2ResetFrame; import io.netty.handler.codec.http2.Http2Settings; import io.netty.handler.codec.http2.Http2SettingsAckFrame; +import io.netty.handler.codec.http2.Http2StreamChannel; +import io.netty.handler.codec.http2.Http2StreamChannelBootstrap; +import io.netty.handler.codec.http2.Http2StreamFrame; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.CsvSource; import org.junit.jupiter.params.provider.MethodSource; import java.io.IOException; @@ -101,6 +110,7 @@ import java.util.concurrent.BlockingQueue; import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingDeque; import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -151,16 +161,19 @@ import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; import static io.servicetalk.transport.netty.internal.BuilderUtils.serverChannel; +import static io.servicetalk.transport.netty.internal.BuilderUtils.socketChannel; import static io.servicetalk.transport.netty.internal.NettyIoExecutors.createIoExecutor; import static java.lang.String.valueOf; import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.function.UnaryOperator.identity; import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.emptyIterable; import static org.hamcrest.Matchers.emptyString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.hasItems; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; @@ -1330,6 +1343,94 @@ private static T addProhibitedHeaders(T metaData) { return metaData; } + @ParameterizedTest(name = "{displayName} [{index}] headerName={0}, headerValue={1}") + @CsvSource({"connection,upgrade", + "keep-alive,timeout=5", + "transfer-encoding,chunked", + "upgrade,foo/2", + "proxy-connection,close"}) + void h2FailsRequestsWithMalformedHeaders(String headerName, String headerValue) throws Exception { + setUp(DEFAULT, true); + try (ServerContext serverContext = HttpServers.forAddress(localAddress(0)) + .protocols(HttpProtocol.HTTP_2.config) + .enableWireLogging("servicetalk-tests-wire-logger", LogLevel.TRACE, () -> true) + .listenBlockingAndAwait((ctx, request, responseFactory) -> responseFactory.ok())) { + + Channel channel = null; + Http2StreamChannel stream = null; + try { + Bootstrap b = new Bootstrap(); + b.group(serverEventLoopGroup); + b.channel(socketChannel(serverEventLoopGroup, InetSocketAddress.class)); + b.remoteAddress(serverContext.listenAddress()); + b.handler(new ChannelInitializer() { + @Override + protected void initChannel(final Channel ch) { + Http2FrameCodecBuilder builder = Http2FrameCodecBuilder.forClient(); + builder.initialSettings().pushEnabled(false).maxConcurrentStreams(0L); + ch.pipeline().addLast(builder.build(), + new Http2MultiplexHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + ctx.writeAndFlush(new DefaultHttp2GoAwayFrame(PROTOCOL_ERROR)); + } + }), + new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Http2Frame msg) { + // ignore all frames on the parent channel + } + }); + } + }); + channel = b.connect().sync().channel(); + + BlockingQueue frames = new LinkedBlockingDeque<>(); + Http2StreamChannelBootstrap bs = new Http2StreamChannelBootstrap(channel); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(final Http2StreamChannel ch) { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Http2StreamFrame frame) { + frames.add(frame); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof Http2StreamFrame) { + frames.add((Http2StreamFrame) evt); + } + ctx.fireUserEventTriggered(evt); + } + }); + } + }); + stream = bs.open().sync().get(); + + Http2Headers headers = new DefaultHttp2Headers() + .method("POST") + .path("/") + .scheme("http") + .authority("localhost") + .add(headerName, headerValue); + stream.writeAndFlush(new DefaultHttp2HeadersFrame(headers)).sync(); + + Http2StreamFrame resetFrame = frames.take(); + assertThat(resetFrame, instanceOf(Http2ResetFrame.class)); + assertThat(((Http2ResetFrame) resetFrame).errorCode(), is(PROTOCOL_ERROR.code())); + assertThat("Received unexpected frames", frames, empty()); + } finally { + if (stream != null) { + stream.close().await(); + } + if (channel != null) { + channel.close().await(); + } + } + } + } + @ParameterizedTest(name = "{displayName} [{index}] client={0}, h2PriorKnowledge={1}") @MethodSource("clientExecutors") void clientRespectsSettingsFrame(HttpTestExecutionStrategy strategy,