Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package wvlet.airframe.http.netty

import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, ChannelInboundHandlerAdapter}
import io.netty.handler.codec.http.{HttpContent, HttpRequest, HttpUtil, LastHttpContent}
import io.netty.util.{AttributeKey, ReferenceCountUtil}
import wvlet.airframe.http.{Http, HttpMethod, InputStreamMessage, RPCException, RPCStatus, ServerAddress}
import wvlet.log.LogSupport

import java.io.{ByteArrayOutputStream, File, FileInputStream, FileOutputStream}
import java.net.InetSocketAddress
import java.nio.file.Files
import scala.jdk.CollectionConverters.*

/**
* A Netty channel handler that replaces HttpObjectAggregator. It assembles incoming HTTP chunks into an
* HttpMessage.Request, buffering small bodies in memory and spilling large bodies to a temp file to reduce heap usage.
*
* @param bodyBufferThresholdBytes
* Bodies larger than this threshold are written to a temp file instead of held in memory.
*/
class NettyBodyHandler(bodyBufferThresholdBytes: Long) extends ChannelInboundHandlerAdapter with LogSupport {
import NettyBodyHandler.*

override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
msg match {
case httpRequest: HttpRequest =>
try {
handleHttpRequest(ctx, httpRequest)
} catch {
case e: RPCException =>
ReferenceCountUtil.release(msg)
val resp = NettyRequestHandler.toNettyResponse(e.toResponse)
ctx.writeAndFlush(resp).addListener(ChannelFutureListener.CLOSE)
}

case content: HttpContent =>
handleContent(ctx, content)

case other =>
ctx.fireChannelRead(other)
}
}

private def handleHttpRequest(ctx: ChannelHandlerContext, httpRequest: HttpRequest): Unit = {
// Start of a new request: build the airframe Request with method, URI, headers, remote address
var req = httpRequest.method().name().toUpperCase match {
case HttpMethod.GET => Http.GET(httpRequest.uri())
case HttpMethod.POST => Http.POST(httpRequest.uri())
case HttpMethod.PUT => Http.PUT(httpRequest.uri())
case HttpMethod.DELETE => Http.DELETE(httpRequest.uri())
case HttpMethod.PATCH => Http.PATCH(httpRequest.uri())
case HttpMethod.TRACE => Http.request(wvlet.airframe.http.HttpMethod.TRACE, httpRequest.uri())
case HttpMethod.OPTIONS => Http.request(wvlet.airframe.http.HttpMethod.OPTIONS, httpRequest.uri())
case HttpMethod.HEAD => Http.request(wvlet.airframe.http.HttpMethod.HEAD, httpRequest.uri())
case _ =>
throw RPCStatus.INVALID_REQUEST_U1.newException(s"Unsupported HTTP method: ${httpRequest.method()}")
}

// Set remote address
ctx.channel().remoteAddress() match {
case x: InetSocketAddress =>
req = req.withRemoteAddress(ServerAddress(s"${x.getHostString}:${x.getPort}"))
case _ =>
}

// Read request headers
httpRequest.headers().names().asScala.foreach { name =>
req = req.withHeader(name, httpRequest.headers().get(name))
}

// Determine buffering strategy based on Content-Length header
val contentLength = HttpUtil.getContentLength(httpRequest, -1L)
val useFile = contentLength > bodyBufferThresholdBytes

val state = if (useFile) {
val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
new RequestState(req, bodyBuf = null, fileBuf = new FileOutputStream(tmpFile), tmpFile = Some(tmpFile))
} else {
new RequestState(req, bodyBuf = null, fileBuf = null, tmpFile = None)
}
Comment on lines +88 to +99

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a potential resource leak here. Files.createTempFile creates a file on disk. If the subsequent new FileOutputStream(tmpFile) constructor throws an exception (e.g., due to permissions or other I/O errors), the created tmpFile will not be deleted. The cleanup logic in exceptionCaught will not handle this case because the RequestState has not yet been associated with the channel. This can lead to an accumulation of temporary files on the server. To prevent this, you should wrap the resource allocation in a try-catch block and ensure the temporary file is deleted on failure.

    val state = if (useFile) {
      val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
      try {
        RequestState(req, bodyBuf = null, fileBuf = new FileOutputStream(tmpFile), tmpFile = Some(tmpFile))
      } catch {
        case e: Throwable =>
          tmpFile.delete()
          throw e
      }
    } else {
      RequestState(req, bodyBuf = null, fileBuf = null, tmpFile = None)
    }

Comment on lines +88 to +99

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a potential resource leak here. If new FileOutputStream(tmpFile) throws an exception, the tmpFile that was just created will not be deleted. To prevent this file leak, you should wrap the resource allocation in a try-catch block to ensure the temporary file is deleted on failure.

Suggested change
val state = if (useFile) {
val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
new RequestState(req, bodyBuf = null, fileBuf = new FileOutputStream(tmpFile), tmpFile = Some(tmpFile))
} else {
new RequestState(req, bodyBuf = null, fileBuf = null, tmpFile = None)
}
val state = if (useFile) {
val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
try {
new RequestState(req, bodyBuf = null, fileBuf = new FileOutputStream(tmpFile), tmpFile = Some(tmpFile))
} catch {
case e: Throwable =>
// Ensure the temp file is deleted if FileOutputStream creation fails
tmpFile.delete()
throw e
}
} else {
new RequestState(req, bodyBuf = null, fileBuf = null, tmpFile = None)
}


ctx.channel().attr(REQUEST_STATE_KEY).set(state)

// If this HttpRequest also carries content (e.g., non-chunked), process the content part
if (httpRequest.isInstanceOf[HttpContent]) {
handleContent(ctx, httpRequest.asInstanceOf[HttpContent])
}
}

private def handleContent(ctx: ChannelHandlerContext, content: HttpContent): Unit = {
val state = ctx.channel().attr(REQUEST_STATE_KEY).get()
if (state == null) {
ReferenceCountUtil.release(content)
return
}

try {
val buf = content.content()
if (buf.isReadable) {
val size = buf.readableBytes()

if (state.fileBuf != null) {
// File-backed path: write directly to file using Netty's internal pooled buffer
buf.readBytes(state.fileBuf, size)
} else {
// In-memory path
if (state.bodyBuf == null) {
state.bodyBuf = new ByteArrayOutputStream(size)
}
buf.readBytes(state.bodyBuf, size)

// Check if we should spill to file (Content-Length was unknown or chunked transfer)
if (state.bodyBuf.size() > bodyBufferThresholdBytes) {
val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
val fos = new FileOutputStream(tmpFile)
state.bodyBuf.writeTo(fos)
state.fileBuf = fos
state.tmpFile = Some(tmpFile)
state.bodyBuf = null // Release in-memory buffer

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

This block has a potential resource leak. If new FileOutputStream(tmpFile) or state.bodyBuf.writeTo(fos) throws an exception, the created tmpFile and fos stream will not be cleaned up. This should be wrapped in a try-catch block to ensure resources are released on failure.

            val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
            var fos: FileOutputStream = null
            try {
              fos = new FileOutputStream(tmpFile)
              state.bodyBuf.writeTo(fos)
              state.fileBuf = fos
              state.tmpFile = Some(tmpFile)
              state.bodyBuf = null // Release in-memory buffer
            } catch {
              case e: Throwable =>
                if (fos != null) {
                  try { fos.close() } catch { case _: Throwable => /* ignore */ }
                }
                tmpFile.delete()
                throw e
            }

}
Comment on lines +132 to +145

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

There is a potential resource leak when spilling the in-memory buffer to a file. If new FileOutputStream(tmpFile) throws an exception, the tmpFile created by Files.createTempFile will be leaked because the exception is not caught here and the cleanup logic will not be aware of this partially-initialized state. You should use a try-catch block to ensure the file is deleted if an error occurs during the spill operation.

          if (state.bodyBuf.size() > bodyBufferThresholdBytes) {
            val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile
            try {
              val fos     = new FileOutputStream(tmpFile)
              state.bodyBuf.writeTo(fos)
              state.fileBuf = fos
              state.tmpFile = Some(tmpFile)
              state.bodyBuf = null // Release in-memory buffer
            } catch {
              case e: Throwable =>
                tmpFile.delete()
                throw e
            }
          }

}
}

if (content.isInstanceOf[LastHttpContent]) {
// Request is complete — build final HttpMessage.Request and fire downstream
var req = state.request

if (state.fileBuf != null) {
state.fileBuf.close()
state.tmpFile.foreach { tmpFile =>
req = req.withContent(new InputStreamMessage(new FileInputStream(tmpFile)))
}
} else if (state.bodyBuf != null && state.bodyBuf.size() > 0) {
req = req.withContent(state.bodyBuf.toByteArray)
}

// Store the temp file path for cleanup after response
state.tmpFile.foreach { tmpFile =>
ctx.channel().attr(TEMP_FILE_KEY).set(tmpFile)
}

ctx.channel().attr(REQUEST_STATE_KEY).set(null)
ctx.fireChannelRead(req)
}
} finally {
ReferenceCountUtil.release(content)
}
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
// Clean up any in-progress state
val state = ctx.channel().attr(REQUEST_STATE_KEY).getAndSet(null)
if (state != null) {
if (state.fileBuf != null) {
try { state.fileBuf.close() }
catch { case _: Exception => }

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It's good practice to log exceptions that are caught and ignored, even in cleanup code, as this can help debug rare issues. Consider logging the exception at a warn or debug level.

Suggested change
catch { case _: Exception => }
catch { case e: Exception => warn("Failed to close file buffer in exception handler", e) }

}
state.tmpFile.foreach(_.delete())
}
ctx.fireExceptionCaught(cause)
}
}

object NettyBodyHandler {
private[netty] val REQUEST_STATE_KEY: AttributeKey[RequestState] =
AttributeKey.valueOf("airframe.requestState")

private[netty] val TEMP_FILE_KEY: AttributeKey[File] =
AttributeKey.valueOf("airframe.tempFile")

/**
* Delete the temp file associated with the channel, if any.
*/
def cleanupTempFile(ctx: ChannelHandlerContext): Unit = {
val tmpFile = ctx.channel().attr(TEMP_FILE_KEY).getAndSet(null)
if (tmpFile != null) {
tmpFile.delete()
}
}

private[netty] class RequestState(
val request: wvlet.airframe.http.HttpMessage.Request,
var bodyBuf: ByteArrayOutputStream,
var fileBuf: FileOutputStream,
var tmpFile: Option[File]
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,36 +14,25 @@
package wvlet.airframe.http.netty

import io.netty.buffer.Unpooled
import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, ChannelInboundHandlerAdapter}
import io.netty.handler.codec.http.*
import wvlet.airframe.http.HttpMessage.{Request, Response}
import wvlet.airframe.http.internal.{HttpLogs, RPCResponseFilter}
import wvlet.airframe.http.{
Http,
HttpHeader,
HttpLogger,
HttpMethod,
HttpServerException,
HttpStatus,
RPCException,
RPCStatus,
ServerAddress,
ServerSentEvent
}
import wvlet.airframe.rx.{Cancelable, OnCompletion, OnError, OnNext, Rx, RxRunner}
import wvlet.airframe.http.internal.HttpLogs
import wvlet.airframe.http.{Http, HttpHeader, HttpLogger, HttpStatus, RPCException, RPCStatus, ServerSentEvent}
import wvlet.airframe.rx.{OnCompletion, OnError, OnNext, Rx, RxRunner}
import wvlet.log.LogSupport

import java.net.InetSocketAddress
import java.util.concurrent.{SynchronousQueue, ThreadFactory, ThreadPoolExecutor, TimeUnit}
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.immutable.ListMap
import scala.jdk.CollectionConverters.*
import NettyRequestHandler.toNettyResponse

import java.io.ByteArrayOutputStream

/**
* Handles fully-assembled HttpMessage.Request objects from NettyBodyHandler. The request body is already available as
* a Message (either ByteArrayMessage for small bodies or InputStreamMessage for large bodies backed by a temp file).
*/
class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Filter, httpStreamLogger: HttpLogger)
extends SimpleChannelInboundHandler[FullHttpRequest]
extends ChannelInboundHandlerAdapter
with LogSupport {

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
Expand All @@ -60,49 +49,17 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi
ctx.flush()
}

override def channelRead0(ctx: ChannelHandlerContext, msg: FullHttpRequest): Unit = {
try {
var req: wvlet.airframe.http.HttpMessage.Request = msg.method().name().toUpperCase match {
case HttpMethod.GET => Http.GET(msg.uri())
case HttpMethod.POST => Http.POST(msg.uri())
case HttpMethod.PUT => Http.PUT(msg.uri())
case HttpMethod.DELETE => Http.DELETE(msg.uri())
case HttpMethod.PATCH => Http.PATCH(msg.uri())
case HttpMethod.TRACE => Http.request(wvlet.airframe.http.HttpMethod.TRACE, msg.uri())
case HttpMethod.OPTIONS => Http.request(wvlet.airframe.http.HttpMethod.OPTIONS, msg.uri())
case HttpMethod.HEAD => Http.request(wvlet.airframe.http.HttpMethod.HEAD, msg.uri())
case _ =>
throw RPCStatus.INVALID_REQUEST_U1.newException(s"Unsupported HTTP method: ${msg.method()}")
}

// Set remote address for logging purpose
ctx.channel().remoteAddress() match {
case x: InetSocketAddress =>
// TODO This address might be IPv6
req = req.withRemoteAddress(ServerAddress(s"${x.getHostString}:${x.getPort}"))
case _ =>
}

// Read request headers
msg.headers().names().asScala.map { x =>
req = req.withHeader(x, msg.headers().get(x))
}

// Read request body
var bodyBuf: ByteArrayOutputStream = null
val requestBody = msg.content()
while (requestBody.isReadable) {
// the returned size is greater than 0 when isReadable = true
val size = requestBody.readableBytes()
if (bodyBuf == null) {
bodyBuf = new ByteArrayOutputStream(size)
}
requestBody.readBytes(bodyBuf, size)
}
if (bodyBuf != null && bodyBuf.size() > 0) {
req = req.withContent(bodyBuf.toByteArray)
}
override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = {
msg match {
case req: Request =>
handleRequest(ctx, req)
case _ =>
ctx.fireChannelRead(msg)
}
}

private def handleRequest(ctx: ChannelHandlerContext, req: Request): Unit = {
try {
// Dispatch the request and get an async response, Rx[Response]
val rxResponse: Rx[Response] = dispatcher.apply(
req,
Expand All @@ -111,11 +68,16 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi
}
)

// HTTP/1.1 defaults to keep-alive; only close if client explicitly sends "Connection: close"
val requestKeepAlive = !req.header
.get(HttpHeader.Connection)
.exists(_.toLowerCase().contains("close"))

RxRunner.run(rxResponse) {
case OnNext(v) =>
val resp = v.asInstanceOf[Response]
val nettyResponse = toNettyResponse(resp)
writeResponse(msg, ctx, nettyResponse)
writeResponse(req, requestKeepAlive, ctx, nettyResponse)

if (resp.isContentTypeEventStream && resp.message.isEmpty) {
// Capture request context and timing before handing off to SSE executor thread
Expand Down Expand Up @@ -182,32 +144,36 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi
// This path manages unhandled exceptions
val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse
val nettyResponse = toNettyResponse(resp)
writeResponse(msg, ctx, nettyResponse)
writeResponse(req, requestKeepAlive, ctx, nettyResponse)
case OnCompletion =>
}
} catch {
case e: RPCException =>
writeResponse(msg, ctx, toNettyResponse(e.toResponse))
writeResponse(req, requestKeepAlive = false, ctx, toNettyResponse(e.toResponse))
} finally {
// Clean up temp file if body was file-backed
NettyBodyHandler.cleanupTempFile(ctx)
// Need to clean up the TLS in case the same thread is reused for the next request
NettyBackend.clearThreadLocal()
}
Comment on lines 153 to 158

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The finally block cleans up the temporary file used for large request bodies. However, RxRunner.run is non-blocking and dispatches the request for asynchronous processing. This means the finally block will execute immediately, deleting the temporary file before the request handler has a chance to read it. This is a critical bug that will cause I/O exceptions for any request with a body large enough to be spooled to disk.

The cleanup logic must be deferred until after the response has been fully sent to the client. A robust way to achieve this is by attaching a listener to the ChannelFuture of the final write operation for the response.

Comment on lines 153 to 158

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The temp file is deleted in the finally block of handleRequest. Since RxRunner.run is used to dispatch the request, the processing may be asynchronous. If the controller or any downstream filter processes the InputStream asynchronously (e.g., offloaded to another thread pool), the temp file will be deleted before the stream is fully read, resulting in an IOException. Cleanup should be deferred until the response is fully sent or the InputStream is closed.

}

private def writeResponse(req: HttpRequest, ctx: ChannelHandlerContext, resp: DefaultHttpResponse): Unit = {
private def writeResponse(
req: Request,
requestKeepAlive: Boolean,
ctx: ChannelHandlerContext,
resp: DefaultHttpResponse
): Unit = {
val isEventStream =
Option(resp.headers())
.flatMap(h => Option(h.get(HttpHeader.ContentType)))
.exists(_.contains("text/event-stream"))

// Respect client's Connection header. HTTP/1.1 defaults to keep-alive; HTTP/1.0 defaults to close.
val keepAlive: Boolean =
HttpStatus.ofCode(resp.status().code()).isSuccessful && (HttpUtil.isKeepAlive(req) || isEventStream)
HttpStatus.ofCode(resp.status().code()).isSuccessful && (requestKeepAlive || isEventStream)

if (keepAlive) {
if (!req.protocolVersion().isKeepAliveDefault) {
resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE)
}
} else {
if (!keepAlive) {
resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE)
}
val f = ctx.writeAndFlush(resp)
Expand Down
Loading
Loading