diff --git a/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java b/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java index 96450e0008a..99c012e2a4d 100644 --- a/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java +++ b/src/main/java/com/linecorp/armeria/server/HttpServerHandler.java @@ -126,32 +126,38 @@ final class HttpServerHandler extends ChannelInboundHandlerAdapter { @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { - isReading = true; + isReading = true; // Cleared in channelReadComplete() if (msg instanceof Http2Settings) { - logger.debug("{} HTTP/2 settings: {}", ctx.channel(), msg); + handleHttp2Settings(ctx, (Http2Settings) msg); + } else { + handleRequest(ctx, (FullHttpRequest) msg); + } + } - useHeadOfLineBlocking = false; - switch (sessionProtocol) { - case H1: - sessionProtocol = SessionProtocol.H2; - break; - case H1C: - sessionProtocol = SessionProtocol.H2C; - break; - } + private void handleHttp2Settings(ChannelHandlerContext ctx, Http2Settings h2settings) { + logger.debug("{} HTTP/2 settings: {}", ctx.channel(), h2settings); + + useHeadOfLineBlocking = false; + switch (sessionProtocol) { + case H1: + sessionProtocol = SessionProtocol.H2; + break; + case H1C: + sessionProtocol = SessionProtocol.H2C; + break; + } + } + + private void handleRequest(ChannelHandlerContext ctx, FullHttpRequest req) throws Exception { + // Ignore the request received after the last request, + // because we are going to close the connection after sending the last response. + if (handledLastRequest) { return; } - final FullHttpRequest req = (FullHttpRequest) msg; boolean invoked = false; try { - // Ignore the request received after the last request, - // because we are going to close the connection after sending the last response. - if (handledLastRequest) { - return; - } - // If we received the message with keep-alive disabled, // we should not accept a request anymore. if (!HttpUtil.isKeepAlive(req)) { @@ -177,11 +183,12 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception // Find the service that matches the path. final PathMapped mapped = host.findServiceConfig(path); if (!mapped.isPresent()) { - respond(ctx, reqSeq, req, HttpResponseStatus.NOT_FOUND); + // No services matched the path. + handleNonExistentMapping(ctx, reqSeq, req, host, path); return; } - // Decode the request and create a new invocation context from it. + // Decode the request and create a new invocation context from it to perform an invocation. final String mappedPath = mapped.mappedPath(); final ServiceConfig serviceCfg = mapped.value(); final Service service = serviceCfg.service(); @@ -193,55 +200,22 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception switch (decodeResult.type()) { case SUCCESS: { + // A successful decode; perform the invocation. final ServiceInvocationContext iCtx = decodeResult.invocationContext(); - final long timeoutMillis = config.requestTimeoutPolicy().timeout(iCtx); - - // Perform the actual invocation. invoke(iCtx, service.handler(), promise); invoked = true; - if (promise.isDone()) { - // If the invocation has been finished immediately, - // there's no need to schedule a timeout nor to add a listener to the promise. - handleInvocationResult(ctx, reqSeq, req, serviceCfg, iCtx, codec, promise, null); - } else { - final ScheduledFuture timeoutFuture; - if (timeoutMillis > 0) { - timeoutFuture = ctx.executor().schedule( - () -> promise.tryFailure(new RequestTimeoutException( - "request timed out after " + timeoutMillis + "ms: " + iCtx)), - timeoutMillis, TimeUnit.MILLISECONDS); - } else { - timeoutFuture = null; - } - - promise.addListener((Future future) -> { - try { - handleInvocationResult(ctx, reqSeq, req, serviceCfg, iCtx, codec, future, timeoutFuture); - } catch (Exception e) { - respond(ctx, reqSeq, req, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); - } - }); - } + // Do the post-invocation tasks such as scheduling a timeout. + handleInvocationPromise(ctx, reqSeq, req, codec, iCtx, promise); break; } case FAILURE: { // Could not create an invocation context. - final Object errorResponse = decodeResult.errorResponse(); - if (errorResponse instanceof FullHttpResponse) { - FullHttpResponse httpResponse = (FullHttpResponse) errorResponse; - promise.tryFailure(new RequestDecodeException( - decodeResult.cause(), httpResponse.content().readableBytes())); - respond(ctx, reqSeq, req, (FullHttpResponse) errorResponse); - } else { - ReferenceCountUtil.safeRelease(errorResponse); - promise.tryFailure(new RequestDecodeException(decodeResult.cause(), 0)); - respond(ctx, reqSeq, req, HttpResponseStatus.BAD_REQUEST, decodeResult.cause()); - } + handleDecodeFailure(ctx, reqSeq, req, decodeResult, promise); break; } case NOT_FOUND: - // Turned out that the request wasn't accepted by the service. + // Turned out that the request wasn't accepted by the matching service. promise.tryFailure(SERVICE_NOT_FOUND); respond(ctx, reqSeq, req, HttpResponseStatus.NOT_FOUND); break; @@ -255,6 +229,27 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception } } + private void handleNonExistentMapping(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, + VirtualHost host, String path) { + + if (path.charAt(path.length() - 1) != '/') { + // Handle the case where /path doesn't exist but /path/ exists. + final String pathWithSlash = path + '/'; + if (host.findServiceConfig(pathWithSlash).isPresent()) { + final String location; + if (path.length() == req.uri().length()) { + location = pathWithSlash; + } else { + location = pathWithSlash + req.uri().substring(path.length()); + } + redirect(ctx, reqSeq, req, location); + return; + } + } + + respond(ctx, reqSeq, req, HttpResponseStatus.NOT_FOUND); + } + private void invoke(ServiceInvocationContext iCtx, ServiceInvocationHandler handler, Promise promise) { @@ -270,28 +265,38 @@ private void invoke(ServiceInvocationContext iCtx, ServiceInvocationHandler hand } } - private static String hostname(FullHttpRequest req) { - final String hostname = req.headers().getAsString(HttpHeaderNames.HOST); - if (hostname == null) { - return ""; - } + private void handleInvocationPromise(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, + ServiceCodec codec, ServiceInvocationContext iCtx, + Promise promise) throws Exception { + if (promise.isDone()) { + // If the invocation has been finished immediately, + // there's no need to schedule a timeout nor to add a listener to the promise. + handleInvocationResult(ctx, reqSeq, req, iCtx, codec, promise, null); + } else { + final long timeoutMillis = config.requestTimeoutPolicy().timeout(iCtx); + final ScheduledFuture timeoutFuture; + if (timeoutMillis > 0) { + timeoutFuture = ctx.executor().schedule( + () -> promise.tryFailure(new RequestTimeoutException( + "request timed out after " + timeoutMillis + "ms: " + iCtx)), + timeoutMillis, TimeUnit.MILLISECONDS); + } else { + timeoutFuture = null; + } - final int hostnameColonIdx = hostname.lastIndexOf(':'); - if (hostnameColonIdx < 0) { - return hostname; + promise.addListener((Future future) -> { + try { + handleInvocationResult(ctx, reqSeq, req, iCtx, codec, future, timeoutFuture); + } catch (Exception e) { + respond(ctx, reqSeq, req, HttpResponseStatus.INTERNAL_SERVER_ERROR, e); + } + }); } - - return hostname.substring(0, hostnameColonIdx); - } - - private static String stripQuery(String uri) { - final int queryStart = uri.indexOf('?'); - return queryStart < 0 ? uri : uri.substring(0, queryStart); } private void handleInvocationResult( ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, - ServiceConfig sCfg, ServiceInvocationContext iCtx, ServiceCodec codec, Future future, + ServiceInvocationContext iCtx, ServiceCodec codec, Future future, ScheduledFuture timeoutFuture) throws Exception { // Release the original request which was retained before the invocation. @@ -321,6 +326,40 @@ private void handleInvocationResult( } } + private void handleDecodeFailure(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, + DecodeResult decodeResult, Promise promise) { + final Object errorResponse = decodeResult.errorResponse(); + if (errorResponse instanceof FullHttpResponse) { + FullHttpResponse httpResponse = (FullHttpResponse) errorResponse; + promise.tryFailure(new RequestDecodeException( + decodeResult.cause(), httpResponse.content().readableBytes())); + respond(ctx, reqSeq, req, (FullHttpResponse) errorResponse); + } else { + ReferenceCountUtil.safeRelease(errorResponse); + promise.tryFailure(new RequestDecodeException(decodeResult.cause(), 0)); + respond(ctx, reqSeq, req, HttpResponseStatus.BAD_REQUEST, decodeResult.cause()); + } + } + + private static String hostname(FullHttpRequest req) { + final String hostname = req.headers().getAsString(HttpHeaderNames.HOST); + if (hostname == null) { + return ""; + } + + final int hostnameColonIdx = hostname.lastIndexOf(':'); + if (hostnameColonIdx < 0) { + return hostname; + } + + return hostname.substring(0, hostnameColonIdx); + } + + private static String stripQuery(String uri) { + final int queryStart = uri.indexOf('?'); + return queryStart < 0 ? uri : uri.substring(0, queryStart); + } + private static HttpResponseStatus toHttpResponseStatus(Throwable cause) { if (cause instanceof RequestTimeoutException) { return HttpResponseStatus.SERVICE_UNAVAILABLE; @@ -342,7 +381,9 @@ private void respond(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, respond(ctx, reqSeq, req, new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, content)); } - private void respond(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, HttpResponseStatus status) { + private void respond(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, + HttpResponseStatus status) { + if (status.code() < 400) { respond(ctx, reqSeq, req, status, Unpooled.EMPTY_BUFFER); } else { @@ -367,6 +408,14 @@ private void respond(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, respond(ctx, reqSeq, req, res); } + private void redirect(ChannelHandlerContext ctx, int reqSeq, FullHttpRequest req, String location) { + final DefaultFullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.TEMPORARY_REDIRECT, Unpooled.EMPTY_BUFFER); + res.headers().set(HttpHeaderNames.LOCATION, location); + + respond(ctx, reqSeq, req, res); + } + private static String errorMessage(HttpResponseStatus status) { String reasonPhrase = status.reasonPhrase(); StringBuilder buf = new StringBuilder(reasonPhrase.length() + 4);