diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala index 837a7ab43..1b6e90866 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala @@ -14,8 +14,14 @@ package wvlet.airframe.http.netty import io.netty.buffer.Unpooled -import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, SimpleChannelInboundHandler} +import io.netty.channel.{ChannelFuture, ChannelFutureListener, ChannelHandlerContext, SimpleChannelInboundHandler} import io.netty.handler.codec.http.* +import io.netty.handler.codec.http.websocketx.{ + WebSocketFrameAggregator, + WebSocketServerHandshaker, + WebSocketServerHandshakerFactory +} +import io.netty.util.concurrent.EventExecutorGroup import wvlet.airframe.http.HttpMessage.{Request, Response} import wvlet.airframe.http.internal.{HttpLogs, RPCResponseFilter} import wvlet.airframe.http.{ @@ -27,6 +33,7 @@ import wvlet.airframe.http.{ HttpStatus, RPCException, RPCStatus, + RxHttpFilter, ServerAddress, ServerSentEvent } @@ -38,12 +45,20 @@ import java.util.concurrent.{SynchronousQueue, ThreadFactory, ThreadPoolExecutor import java.util.concurrent.atomic.AtomicInteger import scala.collection.immutable.ListMap import scala.jdk.CollectionConverters.* +import scala.util.control.NonFatal import NettyRequestHandler.toNettyResponse import java.io.ByteArrayOutputStream -class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Filter, httpStreamLogger: HttpLogger) - extends SimpleChannelInboundHandler[FullHttpRequest] +class NettyRequestHandler( + config: NettyServerConfig, + dispatcher: NettyBackend.Filter, + httpStreamLogger: HttpLogger, + // Filter applied to the WebSocket upgrade request (e.g. installs the RPC context) before the route's own filter + wsUpgradeFilter: RxHttpFilter, + // Executor for offloading WebSocket frame callbacks off the event loop, when handlerExecutorThreads is configured + wsHandlerExecutor: Option[EventExecutorGroup] +) extends SimpleChannelInboundHandler[FullHttpRequest] with LogSupport { override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { @@ -62,135 +77,235 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi override def channelRead0(ctx: ChannelHandlerContext, msg: FullHttpRequest): Unit = { try { - var req: wvlet.airframe.http.HttpMessage.Request = msg.method().name().toUpperCase match { - case HttpMethod.GET => Http.GET(msg.uri()) - case HttpMethod.POST => Http.POST(msg.uri()) - case HttpMethod.PUT => Http.PUT(msg.uri()) - case HttpMethod.DELETE => Http.DELETE(msg.uri()) - case HttpMethod.PATCH => Http.PATCH(msg.uri()) - case HttpMethod.TRACE => Http.request(wvlet.airframe.http.HttpMethod.TRACE, msg.uri()) - case HttpMethod.OPTIONS => Http.request(wvlet.airframe.http.HttpMethod.OPTIONS, msg.uri()) - case HttpMethod.HEAD => Http.request(wvlet.airframe.http.HttpMethod.HEAD, msg.uri()) - case _ => - throw RPCStatus.INVALID_REQUEST_U1.newException(s"Unsupported HTTP method: ${msg.method()}") + val req = NettyRequestHandler.toAirframeRequest(ctx, msg) + webSocketRouteFor(req, msg) match { + case Some(route) => + handleWebSocketUpgrade(ctx, msg, req, route) + case None => + dispatchHttp(ctx, msg, req) } + } catch { + case e: RPCException => + writeResponse(msg, ctx, toNettyResponse(e.toResponse)) + } finally { + // Need to clean up the TLS in case the same thread is reused for the next request + NettyBackend.clearThreadLocal() + } + } - // Set remote address for logging purpose - ctx.channel().remoteAddress() match { - case x: InetSocketAddress => - // TODO This address might be IPv6 - req = req.withRemoteAddress(ServerAddress(s"${x.getHostString}:${x.getPort}")) - case _ => - } + /** + * Find a registered WebSocket route matching this request, only if it is a WebSocket upgrade request. + */ + private def webSocketRouteFor(req: Request, msg: FullHttpRequest): Option[WebSocketRoute] = { + if (config.webSocketRoutes.isEmpty || !NettyRequestHandler.isWebSocketUpgrade(msg)) { + None + } else { + config.webSocketRoutes.find(_.path == req.path) + } + } - // Read request headers - msg.headers().names().asScala.map { x => - req = req.withHeader(x, msg.headers().get(x)) + /** + * Run the upgrade request through the route's filter chain and, if it is allowed, perform the WebSocket handshake. + * The filter chain lets auth/logging/metrics filters apply to the handshake; a non-2xx response rejects the upgrade. + */ + private def handleWebSocketUpgrade( + ctx: ChannelHandlerContext, + msg: FullHttpRequest, + req: Request, + route: WebSocketRoute + ): Unit = { + // Build the filter chain before retaining the request, so a synchronous failure here does not leak the buffer. + // The chain terminates in a 200 marker that signals "upgrade allowed". + val filtered: Rx[Response] = + try { + wsUpgradeFilter + .andThen(route.filter).andThen { (_: Request) => Rx.single(Http.response(HttpStatus.Ok_200)) } + .apply(req) + } catch { + case NonFatal(ex) => + writeResponse( + msg, + ctx, + toNettyResponse(RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse) + ) + return } - // Read request body - var bodyBuf: ByteArrayOutputStream = null - val requestBody = msg.content() - while (requestBody.isReadable) { - // the returned size is greater than 0 when isReadable = true - val size = requestBody.readableBytes() - if (bodyBuf == null) { - bodyBuf = new ByteArrayOutputStream(size) + // The handshake may run later (async filter and/or event loop hop), so keep the request alive + msg.retain() + // Act on the first terminal outcome only: guards against a filter that emits multiple responses (double + // handshake/release) or completes without emitting one (leaking the retained request) + val handled = new java.util.concurrent.atomic.AtomicBoolean(false) + RxRunner.run(filtered) { + case OnNext(v) => + if (handled.compareAndSet(false, true)) { + val resp = v.asInstanceOf[Response] + if (resp.status.isSuccessful) { + doWebSocketHandshake(ctx, msg, req, route) + } else { + // A filter rejected the upgrade (e.g. failed auth); return its response without upgrading + writeResponse(msg, ctx, toNettyResponse(resp)) + msg.release() + } } - requestBody.readBytes(bodyBuf, size) - } - if (bodyBuf != null && bodyBuf.size() > 0) { - req = req.withContent(bodyBuf.toByteArray) - } + case OnError(ex) => + if (handled.compareAndSet(false, true)) { + writeResponse( + msg, + ctx, + toNettyResponse(RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse) + ) + msg.release() + } + case OnCompletion => + // The filter chain completed without emitting a response: release the retained request and close the + // connection so the client does not hang waiting for a handshake response that will never come. + if (handled.compareAndSet(false, true)) { + msg.release() + ctx.close() + } + } + } - // Dispatch the request and get an async response, Rx[Response] - val rxResponse: Rx[Response] = dispatcher.apply( - req, - NettyBackend.newContext { (request: Request) => - Rx.single(Http.response(HttpStatus.NotFound_404)) + /** + * Perform the Netty WebSocket handshake on the channel's event loop and wire the connection to a fresh + * [[NettyWebSocketHandler]]. Pipeline mutation must happen on the event loop thread. + */ + private def doWebSocketHandshake( + ctx: ChannelHandlerContext, + msg: FullHttpRequest, + req: Request, + route: WebSocketRoute + ): Unit = { + ctx.channel().eventLoop().execute { () => + try { + val location = NettyRequestHandler.webSocketLocation(ctx, msg) + val wsFactory = new WebSocketServerHandshakerFactory(location, null, true, config.webSocketMaxFrameSize) + val handshaker: WebSocketServerHandshaker = wsFactory.newHandshaker(msg) + if (handshaker == null) { + WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()) + } else { + val wsContext = new NettyWebSocketContext(ctx.channel(), req, handshaker) + val userHandler = route.handlerFactory(req) + val wsHandler = new NettyWebSocketHandler(userHandler, wsContext) + val future = handshaker.handshake(ctx.channel(), msg) + // The handshake reconfigures the pipeline with WebSocket codecs and drops the HTTP aggregator/codec. + // Add a frame aggregator so fragmented (continuation) frames are coalesced into whole text/binary messages, + // then install the frame handler and remove this HTTP request handler. Offload frame callbacks to the + // handler executor when configured, so blocking user callbacks do not stall the event loop (as for HTTP). + val pipeline = ctx.pipeline() + pipeline.addLast("wsAggregator", new WebSocketFrameAggregator(config.webSocketMaxFrameSize)) + wsHandlerExecutor match { + case Some(executor) => pipeline.addLast(executor, "wsHandler", wsHandler) + case None => pipeline.addLast("wsHandler", wsHandler) + } + pipeline.remove(NettyRequestHandler.this) + // Notify onOpen so it always precedes any inbound frame. Run it on the frame handler's own executor (the + // handler executor when configured, else the event loop); Netty serializes tasks per handler, so onOpen is + // enqueued before any frame dispatch and never blocks the event loop. Outbound writes from onOpen are still + // queued after the handshake response, so frame ordering on the wire is preserved. + val handlerExecutor = pipeline.context("wsHandler").executor() + handlerExecutor.execute { () => wsHandler.notifyOpen() } + future.addListener { (f: ChannelFuture) => + // A failed handshake-response write is typically a benign client disconnect; close the channel and let + // channelInactive deliver onClose rather than surfacing a noisy onError. + if (!f.isSuccess) { + ctx.close() + } + } } - ) + } catch { + case NonFatal(e) => + warn("Failed to perform WebSocket handshake", e) + ctx.close() + } finally { + msg.release() + } + } + } - RxRunner.run(rxResponse) { - case OnNext(v) => - val resp = v.asInstanceOf[Response] - val nettyResponse = toNettyResponse(resp) - writeResponse(msg, ctx, nettyResponse) + private def dispatchHttp(ctx: ChannelHandlerContext, msg: FullHttpRequest, req: Request): Unit = { + // Dispatch the request and get an async response, Rx[Response] + val rxResponse: Rx[Response] = dispatcher.apply( + req, + NettyBackend.newContext { (request: Request) => + Rx.single(Http.response(HttpStatus.NotFound_404)) + } + ) - if (resp.isContentTypeEventStream && resp.message.isEmpty) { - // Capture request context and timing before handing off to SSE executor thread - val streamStartTime = System.currentTimeMillis() - val streamStartNano = System.nanoTime() - val requestMethod = req.method.toString - val requestPath = req.path - val requestUri = req.uri - val remoteAddr = req.remoteAddress.map(_.hostAndPort) - val responseStatus = resp.status - val eventCounter = new AtomicInteger(0) + RxRunner.run(rxResponse) { + case OnNext(v) => + val resp = v.asInstanceOf[Response] + val nettyResponse = toNettyResponse(resp) + writeResponse(msg, ctx, nettyResponse) - def writeStreamLog(status: HttpStatus, error: Option[Throwable]): Unit = { - val m = ListMap.newBuilder[String, Any] - m ++= HttpLogs.unixTimeLogs(streamStartTime) - m += "method" -> requestMethod - m += "path" -> requestPath - m += "uri" -> requestUri - remoteAddr.foreach(a => m += "remote_address" -> a) - m ++= HttpLogs.durationLogs(streamStartTime, streamStartNano) - m += "event_count" -> eventCounter.get() - m += "status_code" -> status.code - m += "status_code_name" -> status.reason - error.foreach { e => - m += "error_message" -> e.getMessage - m += "exception" -> e - } - httpStreamLogger.write(m.result()) + if (resp.isContentTypeEventStream && resp.message.isEmpty) { + // Capture request context and timing before handing off to SSE executor thread + val streamStartTime = System.currentTimeMillis() + val streamStartNano = System.nanoTime() + val requestMethod = req.method.toString + val requestPath = req.path + val requestUri = req.uri + val remoteAddr = req.remoteAddress.map(_.hostAndPort) + val responseStatus = resp.status + val eventCounter = new AtomicInteger(0) + + def writeStreamLog(status: HttpStatus, error: Option[Throwable]): Unit = { + val m = ListMap.newBuilder[String, Any] + m ++= HttpLogs.unixTimeLogs(streamStartTime) + m += "method" -> requestMethod + m += "path" -> requestPath + m += "uri" -> requestUri + remoteAddr.foreach(a => m += "remote_address" -> a) + m ++= HttpLogs.durationLogs(streamStartTime, streamStartNano) + m += "event_count" -> eventCounter.get() + m += "status_code" -> status.code + m += "status_code_name" -> status.reason + error.foreach { e => + m += "error_message" -> e.getMessage + m += "exception" -> e } + httpStreamLogger.write(m.result()) + } - // Run SSE stream consumption in a separate thread to avoid blocking the Netty worker. - // ctx.writeAndFlush() is thread-safe in Netty and can be called from any thread. - try { - NettyRequestHandler.sseExecutor.execute { () => - RxRunner.run(resp.events) { - case OnNext(e: ServerSentEvent) => - eventCounter.incrementAndGet() - val event = e.toContent - val buf = Unpooled.copiedBuffer(event.getBytes("UTF-8")) - ctx.writeAndFlush(new DefaultHttpContent(buf)) - case OnError(e) => - writeStreamLog(HttpStatus.InternalServerError_500, Some(e)) - if (ctx.channel().isActive) { - ctx - .writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) - .addListener(ChannelFutureListener.CLOSE) - } - case _ => - writeStreamLog(responseStatus, None) - val f = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) - f.addListener(ChannelFutureListener.CLOSE) - } + // Run SSE stream consumption in a separate thread to avoid blocking the Netty worker. + // ctx.writeAndFlush() is thread-safe in Netty and can be called from any thread. + try { + NettyRequestHandler.sseExecutor.execute { () => + RxRunner.run(resp.events) { + case OnNext(e: ServerSentEvent) => + eventCounter.incrementAndGet() + val event = e.toContent + val buf = Unpooled.copiedBuffer(event.getBytes("UTF-8")) + ctx.writeAndFlush(new DefaultHttpContent(buf)) + case OnError(e) => + writeStreamLog(HttpStatus.InternalServerError_500, Some(e)) + if (ctx.channel().isActive) { + ctx + .writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + .addListener(ChannelFutureListener.CLOSE) + } + case _ => + writeStreamLog(responseStatus, None) + val f = ctx.writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + f.addListener(ChannelFutureListener.CLOSE) } - } catch { - case e: java.util.concurrent.RejectedExecutionException => - warn(s"SSE executor is saturated; closing stream", e) - writeStreamLog(HttpStatus.ServiceUnavailable_503, Some(e)) - ctx - .writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) - .addListener(ChannelFutureListener.CLOSE) } + } catch { + case e: java.util.concurrent.RejectedExecutionException => + warn(s"SSE executor is saturated; closing stream", e) + writeStreamLog(HttpStatus.ServiceUnavailable_503, Some(e)) + ctx + .writeAndFlush(LastHttpContent.EMPTY_LAST_CONTENT) + .addListener(ChannelFutureListener.CLOSE) } - case OnError(ex) => - // This path manages unhandled exceptions - val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse - val nettyResponse = toNettyResponse(resp) - writeResponse(msg, ctx, nettyResponse) - case OnCompletion => - } - } catch { - case e: RPCException => - writeResponse(msg, ctx, toNettyResponse(e.toResponse)) - } finally { - // Need to clean up the TLS in case the same thread is reused for the next request - NettyBackend.clearThreadLocal() + } + case OnError(ex) => + // This path manages unhandled exceptions + val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse + val nettyResponse = toNettyResponse(resp) + writeResponse(msg, ctx, nettyResponse) + case OnCompletion => } } @@ -220,6 +335,75 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi object NettyRequestHandler extends LogSupport { + /** + * Convert a Netty [[FullHttpRequest]] into an Airframe [[Request]]. Shared by the HTTP dispatch and WebSocket + * upgrade paths. + */ + private[netty] def toAirframeRequest(ctx: ChannelHandlerContext, msg: FullHttpRequest): Request = { + var req: Request = msg.method().name().toUpperCase match { + case HttpMethod.GET => Http.GET(msg.uri()) + case HttpMethod.POST => Http.POST(msg.uri()) + case HttpMethod.PUT => Http.PUT(msg.uri()) + case HttpMethod.DELETE => Http.DELETE(msg.uri()) + case HttpMethod.PATCH => Http.PATCH(msg.uri()) + case HttpMethod.TRACE => Http.request(wvlet.airframe.http.HttpMethod.TRACE, msg.uri()) + case HttpMethod.OPTIONS => Http.request(wvlet.airframe.http.HttpMethod.OPTIONS, msg.uri()) + case HttpMethod.HEAD => Http.request(wvlet.airframe.http.HttpMethod.HEAD, msg.uri()) + case _ => + throw RPCStatus.INVALID_REQUEST_U1.newException(s"Unsupported HTTP method: ${msg.method()}") + } + + // Set remote address for logging purpose + ctx.channel().remoteAddress() match { + case x: InetSocketAddress => + // TODO This address might be IPv6 + req = req.withRemoteAddress(ServerAddress(s"${x.getHostString}:${x.getPort}")) + case _ => + } + + // Read request headers + msg.headers().names().asScala.foreach { x => + req = req.withHeader(x, msg.headers().get(x)) + } + + // Read request body + var bodyBuf: ByteArrayOutputStream = null + val requestBody = msg.content() + while (requestBody.isReadable) { + // the returned size is greater than 0 when isReadable = true + val size = requestBody.readableBytes() + if (bodyBuf == null) { + bodyBuf = new ByteArrayOutputStream(size) + } + requestBody.readBytes(bodyBuf, size) + } + if (bodyBuf != null && bodyBuf.size() > 0) { + req = req.withContent(bodyBuf.toByteArray) + } + req + } + + /** + * Check whether the request is a WebSocket upgrade request (Connection: Upgrade and Upgrade: websocket). + */ + private[netty] def isWebSocketUpgrade(msg: HttpRequest): Boolean = { + val headers = msg.headers() + // containsValue with ignoreCase handles comma-separated lists (e.g. "Connection: keep-alive, Upgrade"), + // so each token is matched individually - do not replace with an exact-match comparison. + headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true) && + headers.containsValue(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true) + } + + /** + * Build the WebSocket URL used by Netty's handshaker, using wss:// when the channel is TLS-secured (an SslHandler is + * present in the pipeline) and ws:// otherwise. + */ + private[netty] def webSocketLocation(ctx: ChannelHandlerContext, req: HttpRequest): String = { + val scheme = if (ctx.pipeline().get(classOf[io.netty.handler.ssl.SslHandler]) != null) "wss" else "ws" + val host = Option(req.headers().get(HttpHeaderNames.HOST)).getOrElse("localhost") + s"${scheme}://${host}${req.uri()}" + } + // Thread pool for SSE stream consumption to avoid blocking Netty worker threads. // Bounded to 64 threads to prevent unbounded creation under high SSE load. // Idle threads are reclaimed after 60 seconds. Daemon threads don't prevent JVM shutdown. diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala index 4bc6dadff..42c6d3916 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala @@ -79,7 +79,12 @@ case class NettyServerConfig( // slow or blocking request handlers from starving the event loop and blocking other // connections (e.g., health check endpoints). // None (default) = handlers run on Netty worker threads (event loop). - handlerExecutorThreads: Option[Int] = None + handlerExecutorThreads: Option[Int] = None, + // WebSocket routes. An incoming HTTP upgrade request whose path matches a route is + // upgraded to a WebSocket connection after passing through the route's filter. + webSocketRoutes: List[WebSocketRoute] = Nil, + // Maximum WebSocket frame payload size in bytes (default: 1MB) + webSocketMaxFrameSize: Int = 1024 * 1024 ) { lazy val port = serverPort.getOrElse(IOUtil.unusedPort) @@ -101,6 +106,32 @@ case class NettyServerConfig( def withRouter(rxRouter: RxRouter): NettyServerConfig = { this.copy(router = Router.fromRxRouter(rxRouter)) } + + /** + * Register a WebSocket route at the given path. An incoming HTTP upgrade request whose path matches will be upgraded + * to a WebSocket connection, and a new [[WebSocketHandler]] is created per connection via the given factory. + */ + def withWebSocketRoute(path: String)(handlerFactory: HttpMessage.Request => WebSocketHandler): NettyServerConfig = { + withWebSocketRoute(path, RxHttpFilter.identity)(handlerFactory) + } + + /** + * Register a WebSocket route at the given path with a filter applied to the upgrade request. The filter can enforce + * auth, logging, or metrics on the handshake. If the filter returns a non-2xx response, the upgrade is rejected with + * that response and no WebSocket connection is established. + */ + def withWebSocketRoute(path: String, filter: RxHttpFilter)( + handlerFactory: HttpMessage.Request => WebSocketHandler + ): NettyServerConfig = { + this.copy(webSocketRoutes = webSocketRoutes :+ WebSocketRoute(path, handlerFactory, filter)) + } + + /** + * Set the maximum WebSocket frame payload size in bytes (default: 1MB). + */ + def withWebSocketMaxFrameSize(sizeInBytes: Int): NettyServerConfig = { + this.copy(webSocketMaxFrameSize = sizeInBytes) + } def withHttpLoggerConfig(f: HttpLoggerConfig => HttpLoggerConfig): NettyServerConfig = { this.copy(httpLoggerConfig = f(httpLoggerConfig)) } @@ -412,7 +443,8 @@ class NettyServer(config: NettyServerConfig, session: Session) extends HttpServe pipeline.addLast(new HttpContentCompressor()) pipeline.addLast(new HttpServerExpectContinueHandler) pipeline.addLast(new ChunkedWriteHandler()) - val handler = new NettyRequestHandler(config, dispatcher, httpStreamLogger) + val handler = + new NettyRequestHandler(config, dispatcher, httpStreamLogger, attachContextFilter, handlerExecutorGroup) handlerExecutorGroup match { case Some(executor) => // Offload request handling to a separate thread pool so that diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyWebSocketHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyWebSocketHandler.scala new file mode 100644 index 000000000..64c897bc7 --- /dev/null +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyWebSocketHandler.scala @@ -0,0 +1,146 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.netty + +import io.netty.buffer.{ByteBufUtil, Unpooled} +import io.netty.channel.{Channel, ChannelHandlerContext, SimpleChannelInboundHandler} +import io.netty.handler.codec.http.websocketx.* +import wvlet.airframe.http.HttpMessage.Request +import wvlet.airframe.http.{RxHttpFilter, WebSocketContext, WebSocketHandler} +import wvlet.log.LogSupport + +import java.util.concurrent.atomic.AtomicBoolean +import scala.util.control.NonFatal + +/** + * A WebSocket route registered on the Netty server. An incoming HTTP upgrade request whose path matches [[path]] is + * upgraded to a WebSocket connection, after passing through [[filter]] (so auth/logging/metrics filters apply to the + * handshake). A fresh [[WebSocketHandler]] is created per connection via [[handlerFactory]]. + */ +case class WebSocketRoute( + path: String, + handlerFactory: Request => WebSocketHandler, + filter: RxHttpFilter = RxHttpFilter.identity +) + +/** + * A [[WebSocketContext]] backed by a Netty [[Channel]]. Netty's `writeAndFlush` is thread-safe, so send/close may be + * called from any thread. + */ +private[netty] class NettyWebSocketContext( + channel: Channel, + override val request: Request, + private[netty] val handshaker: WebSocketServerHandshaker +) extends WebSocketContext { + + override def send(text: String): Unit = { + channel.writeAndFlush(new TextWebSocketFrame(text)) + } + + override def send(data: Array[Byte]): Unit = { + channel.writeAndFlush(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data))) + } + + override def close(): Unit = { + close(WebSocketCloseStatus.NORMAL_CLOSURE.code(), WebSocketCloseStatus.NORMAL_CLOSURE.reasonText()) + } + + override def close(statusCode: Int, reason: String): Unit = { + handshaker.close(channel, new CloseWebSocketFrame(statusCode, reason)) + } +} + +/** + * A Netty inbound handler that bridges WebSocket frames to a user-provided [[WebSocketHandler]]. It is installed into + * the channel pipeline after a successful handshake (see [[NettyRequestHandler]]). + * + * Control frames (close/ping/pong) are handled here rather than by Netty's `WebSocketServerProtocolHandler` on + * purpose: that handler performs its own handshake for a fixed path and cannot gate the upgrade on an + * [[RxHttpFilter]], which is the whole point of the filter-aware upgrade in [[NettyRequestHandler]]. + */ +private[netty] class NettyWebSocketHandler(handler: WebSocketHandler, wsContext: NettyWebSocketContext) + extends SimpleChannelInboundHandler[WebSocketFrame] + with LogSupport { + + // Ensure onClose is delivered exactly once (either via a Close frame or channel inactivation) + private val closeNotified = new AtomicBoolean(false) + + /** + * Notify the handler that the connection is open. Called by [[NettyRequestHandler]] once the handshake completes. + */ + private[netty] def notifyOpen(): Unit = { + safeInvoke(handler.onOpen(wsContext)) + } + + override def channelRead0(ctx: ChannelHandlerContext, frame: WebSocketFrame): Unit = { + frame match { + case t: TextWebSocketFrame => + safeInvoke(handler.onTextMessage(wsContext, t.text())) + case b: BinaryWebSocketFrame => + safeInvoke(handler.onBinaryMessage(wsContext, ByteBufUtil.getBytes(b.content()))) + case c: CloseWebSocketFrame => + notifyClose() + // Echo the close frame back and close the connection (retain to balance the auto-release) + wsContext.handshaker.close(ctx.channel(), c.retain()) + case _: PingWebSocketFrame => + // Respond to ping with a pong carrying the same payload + ctx.writeAndFlush(new PongWebSocketFrame(frame.content().retain())) + case _: PongWebSocketFrame => + // Ignore unsolicited pongs + case other => + debug(s"Ignoring unsupported WebSocket frame: ${other.getClass.getName}") + } + } + + override def channelInactive(ctx: ChannelHandlerContext): Unit = { + notifyClose() + super.channelInactive(ctx) + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + if (NettyRequestHandler.isBenignIOException(cause)) { + debug(cause) + } else { + safeOnError(cause) + } + ctx.close() + } + + private def notifyClose(): Unit = { + if (closeNotified.compareAndSet(false, true)) { + try { + handler.onClose(wsContext) + } catch { + case NonFatal(e) => warn(e) + } + } + } + + // Invoke a user callback, routing any non-fatal exception to onError + private def safeInvoke(body: => Unit): Unit = { + try { + body + } catch { + case NonFatal(e) => safeOnError(e) + } + } + + private def safeOnError(e: Throwable): Unit = { + try { + handler.onError(wsContext, e) + } catch { + case NonFatal(e2) => warn(e2) + } + } +} diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/WebSocketTest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/WebSocketTest.scala new file mode 100644 index 000000000..98a977d5d --- /dev/null +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/WebSocketTest.scala @@ -0,0 +1,285 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.netty + +import wvlet.airframe.http.HttpMessage.{Request, Response} +import wvlet.airframe.http.{ + Endpoint, + Http, + HttpStatus, + RxHttpEndpoint, + RxHttpFilter, + RxRouter, + WebSocketContext, + WebSocketHandler +} +import wvlet.airframe.rx.Rx +import wvlet.airspec.AirSpec + +import java.io.ByteArrayOutputStream +import java.net.URI +import java.net.http.{HttpClient, WebSocket} +import java.nio.ByteBuffer +import java.util.concurrent.{CompletionStage, CountDownLatch, LinkedBlockingQueue, TimeUnit} + +class WebSocketTestApi { + @Endpoint(path = "/v1/hello") + def hello: String = "hello" +} + +/** + * A JDK11 WebSocket listener that collects received text/binary messages into blocking queues and exposes latches for + * lifecycle assertions. + */ +private class CollectingListener extends WebSocket.Listener { + private val textBuffer = new StringBuilder + private val binaryBuffer = new ByteArrayOutputStream() + val textMessages = new LinkedBlockingQueue[String]() + val binaryMessages = new LinkedBlockingQueue[Array[Byte]]() + val openLatch = new CountDownLatch(1) + val closeLatch = new CountDownLatch(1) + + override def onOpen(ws: WebSocket): Unit = { + openLatch.countDown() + ws.request(Long.MaxValue) + } + + override def onText(ws: WebSocket, data: CharSequence, last: Boolean): CompletionStage[?] = { + textBuffer.append(data) + if (last) { + textMessages.put(textBuffer.toString) + textBuffer.setLength(0) + } + null + } + + override def onBinary(ws: WebSocket, data: ByteBuffer, last: Boolean): CompletionStage[?] = { + val bytes = new Array[Byte](data.remaining()) + data.get(bytes) + binaryBuffer.write(bytes) + if (last) { + binaryMessages.put(binaryBuffer.toByteArray) + binaryBuffer.reset() + } + null + } + + override def onClose(ws: WebSocket, statusCode: Int, reason: String): CompletionStage[?] = { + closeLatch.countDown() + null + } + + def nextText: String = textMessages.poll(10, TimeUnit.SECONDS) + def nextBinary: Array[Byte] = binaryMessages.poll(10, TimeUnit.SECONDS) +} + +class WebSocketTest extends AirSpec { + + private def connect(server: NettyServer, path: String, listener: WebSocket.Listener): WebSocket = { + val client = HttpClient.newHttpClient() + val uri = URI.create(s"ws://${server.localAddress}${path}") + client.newWebSocketBuilder().buildAsync(uri, listener).get(10, TimeUnit.SECONDS) + } + + test("echo text messages") { + Netty.server + .withWebSocketRoute("/ws/echo") { _ => + new WebSocketHandler { + override def onTextMessage(ctx: WebSocketContext, message: String): Unit = { + ctx.send(s"echo:${message}") + } + } + } + .noLogging + .design + .build[NettyServer] { server => + val listener = new CollectingListener + val ws = connect(server, "/ws/echo", listener) + try { + ws.sendText("hello", true) + listener.nextText shouldBe "echo:hello" + ws.sendText("world", true) + listener.nextText shouldBe "echo:world" + } finally { + ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye") + } + } + } + + test("echo binary messages") { + Netty.server + .withWebSocketRoute("/ws/echo") { _ => + new WebSocketHandler { + override def onBinaryMessage(ctx: WebSocketContext, message: Array[Byte]): Unit = { + ctx.send(message) + } + } + } + .noLogging + .design + .build[NettyServer] { server => + val listener = new CollectingListener + val ws = connect(server, "/ws/echo", listener) + val payload = Array[Byte](1, 2, 3, 4, 5) + try { + ws.sendBinary(ByteBuffer.wrap(payload), true) + listener.nextBinary shouldBe payload + } finally { + ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye") + } + } + } + + test("fragmented text messages are aggregated") { + Netty.server + .withWebSocketRoute("/ws/echo") { _ => + new WebSocketHandler { + override def onTextMessage(ctx: WebSocketContext, message: String): Unit = ctx.send(s"echo:${message}") + } + } + .noLogging + .design + .build[NettyServer] { server => + val listener = new CollectingListener + val ws = connect(server, "/ws/echo", listener) + try { + // Send the message in two fragments (last=false, then last=true); the server should see one whole message + ws.sendText("foo", false).get(10, TimeUnit.SECONDS) + ws.sendText("bar", true).get(10, TimeUnit.SECONDS) + listener.nextText shouldBe "echo:foobar" + } finally { + ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye") + } + } + } + + test("server can push messages on open") { + Netty.server + .withWebSocketRoute("/ws/push") { _ => + new WebSocketHandler { + override def onOpen(ctx: WebSocketContext): Unit = { + ctx.send("welcome") + } + } + } + .noLogging + .design + .build[NettyServer] { server => + val listener = new CollectingListener + val ws = connect(server, "/ws/push", listener) + try { + listener.nextText shouldBe "welcome" + } finally { + ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye") + } + } + } + + test("onOpen and onClose lifecycle callbacks fire") { + val openLatch = new CountDownLatch(1) + val closeLatch = new CountDownLatch(1) + Netty.server + .withWebSocketRoute("/ws/lifecycle") { _ => + new WebSocketHandler { + override def onOpen(ctx: WebSocketContext): Unit = openLatch.countDown() + override def onClose(ctx: WebSocketContext): Unit = closeLatch.countDown() + } + } + .noLogging + .design + .build[NettyServer] { server => + val listener = new CollectingListener + val ws = connect(server, "/ws/lifecycle", listener) + openLatch.await(10, TimeUnit.SECONDS) shouldBe true + ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye") + closeLatch.await(10, TimeUnit.SECONDS) shouldBe true + } + } + + test("filter can reject the upgrade") { + val denyFilter = new RxHttpFilter { + override def apply(request: Request, next: RxHttpEndpoint): Rx[Response] = { + Rx.single(Http.response(HttpStatus.Forbidden_403)) + } + } + Netty.server + .withWebSocketRoute("/ws/secure", denyFilter) { _ => + new WebSocketHandler {} + } + .noLogging + .design + .build[NettyServer] { server => + // The handshake should fail because the filter returns a non-2xx response + intercept[Exception] { + connect(server, "/ws/secure", new CollectingListener) + } + } + } + + test("filter returning an empty response rejects the upgrade") { + val emptyFilter = new RxHttpFilter { + override def apply(request: Request, next: RxHttpEndpoint): Rx[Response] = Rx.empty + } + Netty.server + .withWebSocketRoute("/ws/empty", emptyFilter) { _ => + new WebSocketHandler {} + } + .noLogging + .design + .build[NettyServer] { server => + intercept[Exception] { + connect(server, "/ws/empty", new CollectingListener) + } + } + } + + test("WebSocket works with a handler executor thread pool") { + Netty.server + .withHandlerExecutorThreads(4) + .withWebSocketRoute("/ws/echo") { _ => + new WebSocketHandler { + override def onTextMessage(ctx: WebSocketContext, message: String): Unit = ctx.send(s"echo:${message}") + } + } + .noLogging + .design + .build[NettyServer] { server => + val listener = new CollectingListener + val ws = connect(server, "/ws/echo", listener) + try { + ws.sendText("hello", true) + listener.nextText shouldBe "echo:hello" + } finally { + ws.sendClose(WebSocket.NORMAL_CLOSURE, "bye") + } + } + } + + test("normal HTTP endpoints still work alongside WebSocket routes") { + Netty.server + .withRouter(RxRouter.of[WebSocketTestApi]) + .withWebSocketRoute("/ws/echo") { _ => + new WebSocketHandler { + override def onTextMessage(ctx: WebSocketContext, message: String): Unit = ctx.send(message) + } + } + .noLogging + .designWithSyncClient + .build[wvlet.airframe.http.client.SyncClient] { client => + val resp = client.send(Http.GET("/v1/hello")) + resp.status shouldBe HttpStatus.Ok_200 + resp.contentString shouldBe "hello" + } + } +} diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/WebSocket.scala b/airframe-http/src/main/scala/wvlet/airframe/http/WebSocket.scala new file mode 100644 index 000000000..a9ffafbbe --- /dev/null +++ b/airframe-http/src/main/scala/wvlet/airframe/http/WebSocket.scala @@ -0,0 +1,89 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http + +import wvlet.airframe.http.HttpMessage.Request + +/** + * A per-connection handle for interacting with an established WebSocket connection. It is passed to the callbacks of + * [[WebSocketHandler]] so that the handler can send outbound frames or close the connection. + * + * Implementations (e.g. the Netty backend) are thread-safe: `send` and `close` may be called from any thread. + */ +trait WebSocketContext { + + /** + * The original HTTP upgrade request that initiated this WebSocket connection. Useful for reading headers, query + * parameters, or attachments set by upstream filters (e.g. an authenticated user). + */ + def request: Request + + /** + * Send a UTF-8 text frame to the client. + */ + def send(text: String): Unit + + /** + * Send a binary frame to the client. + */ + def send(data: Array[Byte]): Unit + + /** + * Close the connection with a normal closure status (1000). + */ + def close(): Unit + + /** + * Close the connection with the given WebSocket close status code and reason. + */ + def close(statusCode: Int, reason: String): Unit +} + +/** + * A callback interface for handling the lifecycle of a server-side WebSocket connection. + * + * All callbacks have no-op defaults, so an implementation only needs to override the events it cares about. A new + * handler instance is created per connection (see `withWebSocketRoute`), so handlers may hold per-connection mutable + * state. + * + * Note: callbacks are invoked on the server's I/O thread (or the configured handler executor thread pool). Avoid + * blocking on the I/O thread; sending via [[WebSocketContext]] is thread-safe and may be done from any thread. + */ +trait WebSocketHandler { + + /** + * Called once after the WebSocket handshake completes and the connection is ready to send/receive frames. + */ + def onOpen(ctx: WebSocketContext): Unit = {} + + /** + * Called when a text frame is received from the client. + */ + def onTextMessage(ctx: WebSocketContext, message: String): Unit = {} + + /** + * Called when a binary frame is received from the client. + */ + def onBinaryMessage(ctx: WebSocketContext, message: Array[Byte]): Unit = {} + + /** + * Called once when the connection is closed, either by the client or the server. + */ + def onClose(ctx: WebSocketContext): Unit = {} + + /** + * Called when an error occurs while processing the connection. The connection is typically closed afterward. + */ + def onError(ctx: WebSocketContext, e: Throwable): Unit = {} +} diff --git a/docs/airframe-http.md b/docs/airframe-http.md index ded5d3d47..eaea19519 100644 --- a/docs/airframe-http.md +++ b/docs/airframe-http.md @@ -437,6 +437,84 @@ val router = RxRouter ``` +## WebSocket + +The Netty backend (`airframe-http-netty`) can serve WebSocket connections. Register a WebSocket route by +path with `Netty.server.withWebSocketRoute(...)`. A fresh `WebSocketHandler` is created per connection, so +your handler can hold per-connection state: + +```scala +import wvlet.airframe.http.* +import wvlet.airframe.http.netty.Netty + +Netty.server + .withPort(8080) + .withRouter(RxRouter.of[MyApi]) // [optional] regular HTTP/RPC routes can coexist + .withWebSocketRoute("/ws/echo") { request => + new WebSocketHandler { + // Called once after the handshake completes + override def onOpen(ctx: WebSocketContext): Unit = { + ctx.send("welcome") + } + // Called for each text frame from the client + override def onTextMessage(ctx: WebSocketContext, message: String): Unit = { + ctx.send(s"echo:${message}") + } + // Called for each binary frame from the client + override def onBinaryMessage(ctx: WebSocketContext, message: Array[Byte]): Unit = { + ctx.send(message) + } + // Called once when the connection is closed (by either side) + override def onClose(ctx: WebSocketContext): Unit = { + // release per-connection resources here + } + override def onError(ctx: WebSocketContext, e: Throwable): Unit = { + warn(e) + } + } + } + .start { server => + server.awaitTermination + } +``` + +`WebSocketContext` is the per-connection handle for interacting with the client. It is thread-safe, so you +can call `send`/`close` from any thread (e.g. to push messages from a background task): + +- `send(text: String)` / `send(data: Array[Byte])` — send a text or binary frame +- `close()` / `close(statusCode, reason)` — close the connection +- `request` — the original HTTP upgrade request (headers, query parameters, attachments set by filters) + +All `WebSocketHandler` callbacks have no-op defaults, so override only the events you need. Callbacks run on +the server's I/O thread (or the configured handler executor thread pool), so avoid blocking in them; offload +long-running work to a separate thread. + +### Applying filters to the handshake + +WebSocket upgrades flow through the same [`RxHttpFilter`](#filters) chain as HTTP requests, so you can reuse +auth/logging/metrics filters on the handshake. Pass a filter as the second argument; if it returns a non-2xx +response, the upgrade is rejected with that response and no WebSocket connection is established: + +```scala +val authFilter = new RxHttpFilter { + override def apply(request: Request, next: RxHttpEndpoint): Rx[Response] = { + if (isAuthorized(request)) next(request) + else Rx.single(Http.response(HttpStatus.Unauthorized_401)) + } +} + +Netty.server + .withWebSocketRoute("/ws/secure", authFilter) { request => + new WebSocketHandler { + override def onTextMessage(ctx: WebSocketContext, message: String): Unit = ctx.send(message) + } + } +``` + +Fragmented frames are aggregated into whole text/binary messages before reaching your handler, and the +maximum frame payload size (default 1MB) can be tuned with `.withWebSocketMaxFrameSize(sizeInBytes)`. + + ## Access Logs airframe-http stores HTTP access logs at `log/http-server.json` by default in JSON format. When the log file becomes large, it will be compressed with gz and rotated automatically.