diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ToStH1Utils.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ToStH1Utils.java index 9b8aaa4874..d4e050b139 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ToStH1Utils.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/H2ToStH1Utils.java @@ -159,16 +159,20 @@ static Http2Headers h1HeadersToH2Headers(HttpHeaders h1Headers) { Iterator connectionItr = h1Headers.valuesIterator(CONNECTION); if (connectionItr.hasNext()) { do { - String connectionHeader = connectionItr.next().toString(); + CharSequence connectionHeader = connectionItr.next(); connectionItr.remove(); - int i = connectionHeader.indexOf(','); + int i = indexOf(connectionHeader, ',', 0); if (i != -1) { int start = 0; do { - h1Headers.remove(connectionHeader.substring(start, i)); + h1Headers.remove(connectionHeader.subSequence(start, i)); start = i + 1; - } while (start < connectionHeader.length() && (i = connectionHeader.indexOf(',', start)) != -1); - h1Headers.remove(connectionHeader.substring(start)); + // Skip OWS + if (start < connectionHeader.length() && connectionHeader.charAt(start) == ' ') { + ++start; + } + } while (start < connectionHeader.length() && (i = indexOf(connectionHeader, ',', start)) != -1); + h1Headers.remove(connectionHeader.subSequence(start, connectionHeader.length())); } else { h1Headers.remove(connectionHeader); } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java index c12c230050..23eb50ef1e 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/H2PriorKnowledgeFeatureParityTest.java @@ -197,9 +197,14 @@ import static org.junit.jupiter.api.Assumptions.assumeTrue; class H2PriorKnowledgeFeatureParityTest { - private static final CharSequence[] PROHIBITED_HEADERS = {CONNECTION, KEEP_ALIVE, TRANSFER_ENCODING, UPGRADE, PROXY_CONNECTION}; + private static final String CONNECTION_HEADER1 = "conn1"; + private static final String CONNECTION_HEADER2 = "conn2"; + private static final String CONNECTION_HEADER3 = "conn3"; + private static final String CONNECTION_HEADER4 = "conn4"; + private static final CharSequence[] CONNECTION_HEADERS = {CONNECTION_HEADER1, CONNECTION_HEADER2, + CONNECTION_HEADER3, CONNECTION_HEADER4}; private static final String EXPECT_FAIL_HEADER = "please_fail_expect"; private static final ContextMap.Key K1 = newKey("k1", String.class); private static final ContextMap.Key K2 = newKey("k2", String.class); @@ -1439,7 +1444,7 @@ protected void initChannel(final Channel ch) { }, __ -> { }, identity()); InetSocketAddress serverAddress = (InetSocketAddress) serverAcceptorChannel.localAddress(); try (BlockingHttpClient client = forSingleAddress(HostAndPort.of(serverAddress)) - .protocols(HttpProtocol.HTTP_2.config) + .protocols(HttpProtocol.HTTP_2.configOtherHeaderFactory) .enableWireLogging("servicetalk-tests-wire-logger", LogLevel.TRACE, () -> true) .executionStrategy(clientExecutionStrategy) .buildBlocking()) { @@ -1453,7 +1458,7 @@ protected void initChannel(final Channel ch) { void h2LayerFiltersOutProhibitedH1HeadersOnServerSide() throws Exception { setUp(DEFAULT, true); try (ServerContext serverContext = HttpServers.forAddress(localAddress(0)) - .protocols(HttpProtocol.HTTP_2.config) + .protocols(HttpProtocol.HTTP_2.configOtherHeaderFactory) .enableWireLogging("servicetalk-tests-wire-logger", LogLevel.TRACE, () -> true) .listenBlockingAndAwait((ctx, request, responseFactory) -> addProhibitedHeaders(responseFactory.ok())); BlockingHttpClient client = forSingleAddress(serverHostAndPort(serverContext)) @@ -1467,11 +1472,21 @@ void h2LayerFiltersOutProhibitedH1HeadersOnServerSide() throws Exception { assertThat("Unexpected headerName: " + headerName, response.headers().contains(headerName), is(false)); } + for (CharSequence headerName : CONNECTION_HEADERS) { + assertThat("Unexpected headerName: " + headerName, + response.headers().contains(headerName), is(false)); + } } } private static T addProhibitedHeaders(T metaData) { metaData.addHeader(CONNECTION, UPGRADE) + .addHeader(CONNECTION, CONNECTION_HEADER1 + "," + CONNECTION_HEADER2) + .addHeader(CONNECTION, CONNECTION_HEADER3 + ", " + CONNECTION_HEADER4) + .addHeader(CONNECTION_HEADER1, "foo") + .addHeader(CONNECTION_HEADER2, "bar") + .addHeader(CONNECTION_HEADER3, "baz") + .addHeader(CONNECTION_HEADER4, "boo") .addHeader(KEEP_ALIVE, "timeout=5") .addHeader(TRANSFER_ENCODING, CHUNKED) .addHeader(UPGRADE, "foo/2") @@ -2072,7 +2087,17 @@ private static boolean allHeadersSanitized(Http2Headers headers) { return !headers.contains(HttpHeaderNames.CONNECTION) && !headers.contains(HttpHeaderNames.KEEP_ALIVE) && !headers.contains(HttpHeaderNames.TRANSFER_ENCODING) && !headers.contains(HttpHeaderNames.UPGRADE) - && !headers.contains(HttpHeaderNames.PROXY_CONNECTION); + && !headers.contains(HttpHeaderNames.PROXY_CONNECTION) && + allConnHeadersSanitized(headers); + } + + private static boolean allConnHeadersSanitized(Http2Headers headers) { + for (CharSequence headerName : CONNECTION_HEADERS) { + if (headers.contains(headerName)) { + return false; + } + } + return true; } } diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpProtocol.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpProtocol.java index d31e1b5558..4e8556fe8f 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpProtocol.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/HttpProtocol.java @@ -21,21 +21,25 @@ import java.util.Arrays; import java.util.Collection; +import static io.servicetalk.http.api.DefaultHttpHeadersFactory.INSTANCE; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_2_0; +import static io.servicetalk.http.netty.HttpProtocolConfigs.h1; import static io.servicetalk.http.netty.HttpProtocolConfigs.h1Default; import static io.servicetalk.http.netty.HttpProtocolConfigs.h2; import static io.servicetalk.logging.api.LogLevel.TRACE; enum HttpProtocol { - HTTP_1(h1Default(), HTTP_1_1), - HTTP_2(h2().enableFrameLogging("servicetalk-tests-h2-frame-logger", TRACE, () -> true).build(), HTTP_2_0); + HTTP_1(h1Default(), h1().headersFactory(new H2HeadersFactory(true, true, false)).build(), HTTP_1_1), + HTTP_2(applyFrameLogger(h2()).build(), applyFrameLogger(h2()).headersFactory(INSTANCE).build(), HTTP_2_0); + final HttpProtocolConfig configOtherHeaderFactory; final HttpProtocolConfig config; final HttpProtocolVersion version; - HttpProtocol(HttpProtocolConfig config, HttpProtocolVersion version) { + HttpProtocol(HttpProtocolConfig config, HttpProtocolConfig configOtherHeaderFactory, HttpProtocolVersion version) { this.config = config; + this.configOtherHeaderFactory = configOtherHeaderFactory; this.version = version; } @@ -46,4 +50,8 @@ static HttpProtocolConfig[] toConfigs(Collection protocols) { static HttpProtocolConfig[] toConfigs(HttpProtocol[] protocols) { return Arrays.stream(protocols).map(p -> p.config).toArray(HttpProtocolConfig[]::new); } + + private static H2ProtocolConfigBuilder applyFrameLogger(H2ProtocolConfigBuilder builder) { + return builder.enableFrameLogging("servicetalk-tests-h2-frame-logger", TRACE, () -> true); + } }