diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBodyHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBodyHandler.scala new file mode 100644 index 0000000000..39c161dad9 --- /dev/null +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyBodyHandler.scala @@ -0,0 +1,212 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.netty + +import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, ChannelInboundHandlerAdapter} +import io.netty.handler.codec.http.{HttpContent, HttpRequest, HttpUtil, LastHttpContent} +import io.netty.util.{AttributeKey, ReferenceCountUtil} +import wvlet.airframe.http.{Http, HttpMethod, InputStreamMessage, RPCException, RPCStatus, ServerAddress} +import wvlet.log.LogSupport + +import java.io.{ByteArrayOutputStream, File, FileInputStream, FileOutputStream} +import java.net.InetSocketAddress +import java.nio.file.Files +import scala.jdk.CollectionConverters.* + +/** + * A Netty channel handler that replaces HttpObjectAggregator. It assembles incoming HTTP chunks into an + * HttpMessage.Request, buffering small bodies in memory and spilling large bodies to a temp file to reduce heap usage. + * + * @param bodyBufferThresholdBytes + * Bodies larger than this threshold are written to a temp file instead of held in memory. + */ +class NettyBodyHandler(bodyBufferThresholdBytes: Long) extends ChannelInboundHandlerAdapter with LogSupport { + import NettyBodyHandler.* + + override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { + msg match { + case httpRequest: HttpRequest => + try { + handleHttpRequest(ctx, httpRequest) + } catch { + case e: RPCException => + ReferenceCountUtil.release(msg) + val resp = NettyRequestHandler.toNettyResponse(e.toResponse) + ctx.writeAndFlush(resp).addListener(ChannelFutureListener.CLOSE) + } + + case content: HttpContent => + handleContent(ctx, content) + + case other => + ctx.fireChannelRead(other) + } + } + + private def handleHttpRequest(ctx: ChannelHandlerContext, httpRequest: HttpRequest): Unit = { + // Start of a new request: build the airframe Request with method, URI, headers, remote address + var req = httpRequest.method().name().toUpperCase match { + case HttpMethod.GET => Http.GET(httpRequest.uri()) + case HttpMethod.POST => Http.POST(httpRequest.uri()) + case HttpMethod.PUT => Http.PUT(httpRequest.uri()) + case HttpMethod.DELETE => Http.DELETE(httpRequest.uri()) + case HttpMethod.PATCH => Http.PATCH(httpRequest.uri()) + case HttpMethod.TRACE => Http.request(wvlet.airframe.http.HttpMethod.TRACE, httpRequest.uri()) + case HttpMethod.OPTIONS => Http.request(wvlet.airframe.http.HttpMethod.OPTIONS, httpRequest.uri()) + case HttpMethod.HEAD => Http.request(wvlet.airframe.http.HttpMethod.HEAD, httpRequest.uri()) + case _ => + throw RPCStatus.INVALID_REQUEST_U1.newException(s"Unsupported HTTP method: ${httpRequest.method()}") + } + + // Set remote address + ctx.channel().remoteAddress() match { + case x: InetSocketAddress => + req = req.withRemoteAddress(ServerAddress(s"${x.getHostString}:${x.getPort}")) + case _ => + } + + // Read request headers + httpRequest.headers().names().asScala.foreach { name => + req = req.withHeader(name, httpRequest.headers().get(name)) + } + + // Determine buffering strategy based on Content-Length header + val contentLength = HttpUtil.getContentLength(httpRequest, -1L) + val useFile = contentLength > bodyBufferThresholdBytes + + val state = if (useFile) { + val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile + try { + new RequestState(req, bodyBuf = null, fileBuf = new FileOutputStream(tmpFile), tmpFile = Some(tmpFile)) + } catch { + case e: Exception => + tmpFile.delete() + throw e + } + } else { + new RequestState(req, bodyBuf = null, fileBuf = null, tmpFile = None) + } + + ctx.channel().attr(REQUEST_STATE_KEY).set(state) + + // If this HttpRequest also carries content (e.g., non-chunked), process the content part + if (httpRequest.isInstanceOf[HttpContent]) { + handleContent(ctx, httpRequest.asInstanceOf[HttpContent]) + } + } + + private def handleContent(ctx: ChannelHandlerContext, content: HttpContent): Unit = { + val state = ctx.channel().attr(REQUEST_STATE_KEY).get() + if (state == null) { + ReferenceCountUtil.release(content) + return + } + + try { + val buf = content.content() + if (buf.isReadable) { + val size = buf.readableBytes() + + if (state.fileBuf != null) { + // File-backed path: write directly to file using Netty's internal pooled buffer + buf.readBytes(state.fileBuf, size) + } else { + // In-memory path + if (state.bodyBuf == null) { + state.bodyBuf = new ByteArrayOutputStream(size) + } + buf.readBytes(state.bodyBuf, size) + + // Check if we should spill to file (Content-Length was unknown or chunked transfer) + if (state.bodyBuf.size() > bodyBufferThresholdBytes) { + val tmpFile = Files.createTempFile("airframe-body-", ".tmp").toFile + try { + val fos = new FileOutputStream(tmpFile) + state.bodyBuf.writeTo(fos) + state.fileBuf = fos + state.tmpFile = Some(tmpFile) + state.bodyBuf = null // Release in-memory buffer + } catch { + case e: Exception => + tmpFile.delete() + throw e + } + } + } + } + + if (content.isInstanceOf[LastHttpContent]) { + // Request is complete — build final HttpMessage.Request and fire downstream + var req = state.request + + if (state.fileBuf != null) { + state.fileBuf.close() + state.tmpFile.foreach { tmpFile => + req = req.withContent(new InputStreamMessage(new FileInputStream(tmpFile))) + } + } else if (state.bodyBuf != null && state.bodyBuf.size() > 0) { + req = req.withContent(state.bodyBuf.toByteArray) + } + + // Store the temp file path for cleanup after response + state.tmpFile.foreach { tmpFile => + ctx.channel().attr(TEMP_FILE_KEY).set(tmpFile) + } + + ctx.channel().attr(REQUEST_STATE_KEY).set(null) + ctx.fireChannelRead(req) + } + } finally { + ReferenceCountUtil.release(content) + } + } + + override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { + // Clean up any in-progress state + val state = ctx.channel().attr(REQUEST_STATE_KEY).getAndSet(null) + if (state != null) { + if (state.fileBuf != null) { + try { state.fileBuf.close() } + catch { case _: Exception => } + } + state.tmpFile.foreach(_.delete()) + } + ctx.fireExceptionCaught(cause) + } +} + +object NettyBodyHandler { + private[netty] val REQUEST_STATE_KEY: AttributeKey[RequestState] = + AttributeKey.valueOf("airframe.requestState") + + private[netty] val TEMP_FILE_KEY: AttributeKey[File] = + AttributeKey.valueOf("airframe.tempFile") + + /** + * Delete the temp file associated with the channel, if any. + */ + def cleanupTempFile(ctx: ChannelHandlerContext): Unit = { + val tmpFile = ctx.channel().attr(TEMP_FILE_KEY).getAndSet(null) + if (tmpFile != null) { + tmpFile.delete() + } + } + + private[netty] class RequestState( + val request: wvlet.airframe.http.HttpMessage.Request, + var bodyBuf: ByteArrayOutputStream, + var fileBuf: FileOutputStream, + var tmpFile: Option[File] + ) +} diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala index 837a7ab432..b685345bde 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyRequestHandler.scala @@ -14,36 +14,25 @@ package wvlet.airframe.http.netty import io.netty.buffer.Unpooled -import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, SimpleChannelInboundHandler} +import io.netty.channel.{ChannelFutureListener, ChannelHandlerContext, ChannelInboundHandlerAdapter} import io.netty.handler.codec.http.* import wvlet.airframe.http.HttpMessage.{Request, Response} -import wvlet.airframe.http.internal.{HttpLogs, RPCResponseFilter} -import wvlet.airframe.http.{ - Http, - HttpHeader, - HttpLogger, - HttpMethod, - HttpServerException, - HttpStatus, - RPCException, - RPCStatus, - ServerAddress, - ServerSentEvent -} -import wvlet.airframe.rx.{Cancelable, OnCompletion, OnError, OnNext, Rx, RxRunner} +import wvlet.airframe.http.internal.HttpLogs +import wvlet.airframe.http.{Http, HttpHeader, HttpLogger, HttpStatus, RPCException, RPCStatus, ServerSentEvent} +import wvlet.airframe.rx.{OnCompletion, OnError, OnNext, Rx, RxRunner} import wvlet.log.LogSupport -import java.net.InetSocketAddress import java.util.concurrent.{SynchronousQueue, ThreadFactory, ThreadPoolExecutor, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import scala.collection.immutable.ListMap -import scala.jdk.CollectionConverters.* import NettyRequestHandler.toNettyResponse -import java.io.ByteArrayOutputStream - +/** + * Handles fully-assembled HttpMessage.Request objects from NettyBodyHandler. The request body is already available as + * a Message (either ByteArrayMessage for small bodies or InputStreamMessage for large bodies backed by a temp file). + */ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Filter, httpStreamLogger: HttpLogger) - extends SimpleChannelInboundHandler[FullHttpRequest] + extends ChannelInboundHandlerAdapter with LogSupport { override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = { @@ -60,49 +49,17 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi ctx.flush() } - override def channelRead0(ctx: ChannelHandlerContext, msg: FullHttpRequest): Unit = { - try { - var req: wvlet.airframe.http.HttpMessage.Request = msg.method().name().toUpperCase match { - case HttpMethod.GET => Http.GET(msg.uri()) - case HttpMethod.POST => Http.POST(msg.uri()) - case HttpMethod.PUT => Http.PUT(msg.uri()) - case HttpMethod.DELETE => Http.DELETE(msg.uri()) - case HttpMethod.PATCH => Http.PATCH(msg.uri()) - case HttpMethod.TRACE => Http.request(wvlet.airframe.http.HttpMethod.TRACE, msg.uri()) - case HttpMethod.OPTIONS => Http.request(wvlet.airframe.http.HttpMethod.OPTIONS, msg.uri()) - case HttpMethod.HEAD => Http.request(wvlet.airframe.http.HttpMethod.HEAD, msg.uri()) - case _ => - throw RPCStatus.INVALID_REQUEST_U1.newException(s"Unsupported HTTP method: ${msg.method()}") - } - - // Set remote address for logging purpose - ctx.channel().remoteAddress() match { - case x: InetSocketAddress => - // TODO This address might be IPv6 - req = req.withRemoteAddress(ServerAddress(s"${x.getHostString}:${x.getPort}")) - case _ => - } - - // Read request headers - msg.headers().names().asScala.map { x => - req = req.withHeader(x, msg.headers().get(x)) - } - - // Read request body - var bodyBuf: ByteArrayOutputStream = null - val requestBody = msg.content() - while (requestBody.isReadable) { - // the returned size is greater than 0 when isReadable = true - val size = requestBody.readableBytes() - if (bodyBuf == null) { - bodyBuf = new ByteArrayOutputStream(size) - } - requestBody.readBytes(bodyBuf, size) - } - if (bodyBuf != null && bodyBuf.size() > 0) { - req = req.withContent(bodyBuf.toByteArray) - } + override def channelRead(ctx: ChannelHandlerContext, msg: Any): Unit = { + msg match { + case req: Request => + handleRequest(ctx, req) + case _ => + ctx.fireChannelRead(msg) + } + } + private def handleRequest(ctx: ChannelHandlerContext, req: Request): Unit = { + try { // Dispatch the request and get an async response, Rx[Response] val rxResponse: Rx[Response] = dispatcher.apply( req, @@ -111,11 +68,16 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi } ) + // HTTP/1.1 defaults to keep-alive; only close if client explicitly sends "Connection: close" + val requestKeepAlive = !req.header + .get(HttpHeader.Connection) + .exists(_.toLowerCase().contains("close")) + RxRunner.run(rxResponse) { case OnNext(v) => val resp = v.asInstanceOf[Response] val nettyResponse = toNettyResponse(resp) - writeResponse(msg, ctx, nettyResponse) + writeResponse(req, requestKeepAlive, ctx, nettyResponse) if (resp.isContentTypeEventStream && resp.message.isEmpty) { // Capture request context and timing before handing off to SSE executor thread @@ -182,32 +144,36 @@ class NettyRequestHandler(config: NettyServerConfig, dispatcher: NettyBackend.Fi // This path manages unhandled exceptions val resp = RPCStatus.INTERNAL_ERROR_I0.newException(ex.getMessage, ex).toResponse val nettyResponse = toNettyResponse(resp) - writeResponse(msg, ctx, nettyResponse) + writeResponse(req, requestKeepAlive, ctx, nettyResponse) case OnCompletion => } } catch { case e: RPCException => - writeResponse(msg, ctx, toNettyResponse(e.toResponse)) + writeResponse(req, requestKeepAlive = false, ctx, toNettyResponse(e.toResponse)) } finally { + // Clean up temp file if body was file-backed + NettyBodyHandler.cleanupTempFile(ctx) // Need to clean up the TLS in case the same thread is reused for the next request NettyBackend.clearThreadLocal() } } - private def writeResponse(req: HttpRequest, ctx: ChannelHandlerContext, resp: DefaultHttpResponse): Unit = { + private def writeResponse( + req: Request, + requestKeepAlive: Boolean, + ctx: ChannelHandlerContext, + resp: DefaultHttpResponse + ): Unit = { val isEventStream = Option(resp.headers()) .flatMap(h => Option(h.get(HttpHeader.ContentType))) .exists(_.contains("text/event-stream")) + // Respect client's Connection header. HTTP/1.1 defaults to keep-alive; HTTP/1.0 defaults to close. val keepAlive: Boolean = - HttpStatus.ofCode(resp.status().code()).isSuccessful && (HttpUtil.isKeepAlive(req) || isEventStream) + HttpStatus.ofCode(resp.status().code()).isSuccessful && (requestKeepAlive || isEventStream) - if (keepAlive) { - if (!req.protocolVersion().isKeepAliveDefault) { - resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE) - } - } else { + if (!keepAlive) { resp.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE) } val f = ctx.writeAndFlush(resp) diff --git a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala index 4bc6dadffd..eb09ccea14 100644 --- a/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala +++ b/airframe-http-netty/src/main/scala/wvlet/airframe/http/netty/NettyServer.scala @@ -79,7 +79,10 @@ case class NettyServerConfig( // slow or blocking request handlers from starving the event loop and blocking other // connections (e.g., health check endpoints). // None (default) = handlers run on Netty worker threads (event loop). - handlerExecutorThreads: Option[Int] = None + handlerExecutorThreads: Option[Int] = None, + // Threshold in bytes for file-backed body buffering. Request bodies larger than this + // are written to a temp file instead of held in memory. Default: 8MB. + bodyBufferThresholdBytes: Long = 8L * 1024 * 1024 ) { lazy val port = serverPort.getOrElse(IOUtil.unusedPort) @@ -201,6 +204,18 @@ case class NettyServerConfig( this.copy(handlerExecutorThreads = Some(threads)) } + /** + * Set the threshold in bytes for file-backed body buffering. Request bodies larger than this are written to a temp + * file instead of held in memory. + * + * @param bytes + * threshold in bytes (default: 8MB) + */ + def withBodyBufferThreshold(bytes: Long): NettyServerConfig = { + require(bytes > 0, "bodyBufferThresholdBytes must be positive") + this.copy(bodyBufferThresholdBytes = bytes) + } + def newServer(session: Session): NettyServer = { val s = new NettyServer(this, session) s.start @@ -408,9 +423,9 @@ class NettyServer(config: NettyServerConfig, session: Session) extends HttpServe ) ) pipeline.addLast(new HttpServerKeepAliveHandler()) - pipeline.addLast(new HttpObjectAggregator(Int.MaxValue)) pipeline.addLast(new HttpContentCompressor()) pipeline.addLast(new HttpServerExpectContinueHandler) + pipeline.addLast(new NettyBodyHandler(config.bodyBufferThresholdBytes)) pipeline.addLast(new ChunkedWriteHandler()) val handler = new NettyRequestHandler(config, dispatcher, httpStreamLogger) handlerExecutorGroup match { diff --git a/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/InputStreamEndpointTest.scala b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/InputStreamEndpointTest.scala new file mode 100644 index 0000000000..d27807f368 --- /dev/null +++ b/airframe-http-netty/src/test/scala/wvlet/airframe/http/netty/InputStreamEndpointTest.scala @@ -0,0 +1,79 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http.netty + +import wvlet.airframe.http.* +import wvlet.airframe.http.client.SyncClient +import wvlet.airspec.AirSpec + +import java.io.InputStream + +object InputStreamEndpointTest { + + class InputStreamApi { + @Endpoint(method = HttpMethod.POST, path = "/upload") + def upload(body: InputStream): String = { + val bytes = body.readAllBytes() + s"received ${bytes.length} bytes" + } + + @Endpoint(method = HttpMethod.POST, path = "/echo") + def echo(body: InputStream): HttpMessage.Response = { + val bytes = body.readAllBytes() + Http.response(HttpStatus.Ok_200).withContent(bytes) + } + } +} + +class InputStreamEndpointTest extends AirSpec { + import InputStreamEndpointTest.* + + initDesign { + _.add( + Netty.server + .withRouter(RxRouter.of[InputStreamApi]) + .designWithSyncClient + ) + } + + test("receive small binary body via InputStream") { (client: SyncClient) => + val data = "hello world".getBytes("UTF-8") + val request = Http.POST("/upload").withContent(data) + val resp = client.send(request) + resp.contentString shouldBe s"received ${data.length} bytes" + } + + test("echo binary content via InputStream") { (client: SyncClient) => + val data = new Array[Byte](1024) + scala.util.Random.nextBytes(data) + val request = Http.POST("/echo").withContent(data) + val resp = client.send(request) + resp.contentBytes shouldBe data + } + + test("handle large body via InputStream") { (client: SyncClient) => + // Create a body larger than default 8MB threshold to trigger file-backed buffering + val data = new Array[Byte](9 * 1024 * 1024) + scala.util.Random.nextBytes(data) + val request = Http.POST("/upload").withContent(data) + val resp = client.send(request) + resp.contentString shouldBe s"received ${data.length} bytes" + } + + test("handle empty body via InputStream") { (client: SyncClient) => + val request = Http.POST("/upload") + val resp = client.send(request) + resp.contentString shouldBe "received 0 bytes" + } +} diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/HttpAccessLogWriter.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/HttpAccessLogWriter.scala index c970178e7b..32ca64a81a 100644 --- a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/HttpAccessLogWriter.scala +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/HttpAccessLogWriter.scala @@ -180,6 +180,9 @@ object HttpAccessLogWriter { ListMap.empty case c: HttpContext[_, _, _] => ListMap.empty + case _: java.io.InputStream => + // InputStream parameters are not serializable for logging + ListMap.empty case _ if p.isSecret => ListMap.empty case u: ULID => diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/InputStreamMessage.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/InputStreamMessage.scala new file mode 100644 index 0000000000..805ad1c7d4 --- /dev/null +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/InputStreamMessage.scala @@ -0,0 +1,99 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.nio.charset.StandardCharsets + +/** + * A Message implementation backed by an InputStream. This allows handlers to read large request bodies without holding + * the entire content in memory. + * + * The InputStream can only be consumed once via getInputStream. If toContentBytes or toContentString is called, the + * stream is fully read and cached. If getInputStream is called first, a tee-reading wrapper is returned that caches + * bytes as they are read, so toContentBytes can still work afterwards. + */ +class InputStreamMessage(inputStream: InputStream) extends HttpMessage.Message { + private var cachedBytes: Array[Byte] = null + + private def ensureCached(): Array[Byte] = synchronized { + if (cachedBytes == null) { + cachedBytes = inputStream.readAllBytes() + } + cachedBytes + } + + /** + * Returns an InputStream for reading the body content. If the content has already been materialized via + * toContentBytes/toContentString, returns a ByteArrayInputStream over the cached data. Otherwise, returns a wrapper + * that caches bytes as they are read. + */ + def getInputStream: InputStream = synchronized { + if (cachedBytes != null) { + new ByteArrayInputStream(cachedBytes) + } else { + new TeeInputStream(inputStream, this) + } + } + + private[http] def setCachedBytes(bytes: Array[Byte]): Unit = synchronized { + if (cachedBytes == null) { + cachedBytes = bytes + } + } + + override def isEmpty: Boolean = false + override def toContentBytes: Array[Byte] = ensureCached() + override def toContentString: String = new String(ensureCached(), StandardCharsets.UTF_8) +} + +/** + * An InputStream wrapper that caches all bytes read from the underlying stream into a buffer. When the stream is fully + * consumed, the cached bytes are stored back into the parent InputStreamMessage for subsequent access via + * toContentBytes. + */ +private[http] class TeeInputStream(underlying: InputStream, parent: InputStreamMessage) extends InputStream { + private val buffer = new ByteArrayOutputStream() + + override def read(): Int = { + val b = underlying.read() + if (b >= 0) { + buffer.write(b) + } else { + parent.setCachedBytes(buffer.toByteArray) + } + b + } + + override def read(b: Array[Byte], off: Int, len: Int): Int = { + val n = underlying.read(b, off, len) + if (n > 0) { + buffer.write(b, off, n) + } else if (n < 0) { + parent.setCachedBytes(buffer.toByteArray) + } + n + } + + override def available(): Int = underlying.available() + + override def close(): Unit = { + try { + // Cache whatever has been read so far (do not drain remaining bytes to avoid memory issues) + parent.setCachedBytes(buffer.toByteArray) + } finally { + underlying.close() + } + } +} diff --git a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpRequestMapper.scala b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpRequestMapper.scala index c2e292c91a..2d0bf2f728 100644 --- a/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpRequestMapper.scala +++ b/airframe-http/.jvm/src/main/scala/wvlet/airframe/http/router/HttpRequestMapper.scala @@ -76,6 +76,15 @@ object HttpRequestMapper extends LogSupport { case cl if classOf[HttpContext[Req, Resp, F]].isAssignableFrom(cl) => // Bind HttpContext Some(context) + case cl if classOf[java.io.InputStream].isAssignableFrom(cl) => + // Bind the request body as an InputStream for streaming large bodies + val msg = adapter.messageOf(request) + msg match { + case ism: InputStreamMessage => + Some(ism.getInputStream) + case _ => + Some(new java.io.ByteArrayInputStream(msg.toContentBytes)) + } case _ => // Pass the String parameter to the method argument val argCodec = codecFactory.of(argSurface) diff --git a/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/InputStreamMessageTest.scala b/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/InputStreamMessageTest.scala new file mode 100644 index 0000000000..454d91a405 --- /dev/null +++ b/airframe-http/.jvm/src/test/scala/wvlet/airframe/http/InputStreamMessageTest.scala @@ -0,0 +1,64 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package wvlet.airframe.http + +import wvlet.airspec.AirSpec + +import java.io.ByteArrayInputStream + +object InputStreamMessageTest extends AirSpec { + + test("read content bytes") { + val data = "hello world".getBytes("UTF-8") + val msg = new InputStreamMessage(new ByteArrayInputStream(data)) + msg.toContentBytes shouldBe data + } + + test("read content string") { + val msg = new InputStreamMessage(new ByteArrayInputStream("test string".getBytes("UTF-8"))) + msg.toContentString shouldBe "test string" + } + + test("getInputStream returns readable stream") { + val data = "stream data".getBytes("UTF-8") + val msg = new InputStreamMessage(new ByteArrayInputStream(data)) + val is = msg.getInputStream + is.readAllBytes() shouldBe data + } + + test("getInputStream after toContentBytes returns cached data") { + val data = "cached".getBytes("UTF-8") + val msg = new InputStreamMessage(new ByteArrayInputStream(data)) + val bytes = msg.toContentBytes + bytes shouldBe data + // After toContentBytes, getInputStream should return a ByteArrayInputStream over cached data + val is = msg.getInputStream + is.readAllBytes() shouldBe data + } + + test("toContentBytes after getInputStream returns cached data") { + val data = "reverse order".getBytes("UTF-8") + val msg = new InputStreamMessage(new ByteArrayInputStream(data)) + // Consume the stream first + val is = msg.getInputStream + is.readAllBytes() shouldBe data + // toContentBytes should still work via cache + msg.toContentBytes shouldBe data + } + + test("isEmpty returns false") { + val msg = new InputStreamMessage(new ByteArrayInputStream(Array.empty[Byte])) + msg.isEmpty shouldBe false + } +} diff --git a/airframe-http/src/main/scala/wvlet/airframe/http/internal/HttpLogs.scala b/airframe-http/src/main/scala/wvlet/airframe/http/internal/HttpLogs.scala index 50e7e4d594..5e197f8c7d 100644 --- a/airframe-http/src/main/scala/wvlet/airframe/http/internal/HttpLogs.scala +++ b/airframe-http/src/main/scala/wvlet/airframe/http/internal/HttpLogs.scala @@ -230,6 +230,9 @@ object HttpLogs extends LogSupport { ListMap.empty case c: HttpContext[_, _, _] => ListMap.empty + case _ if p.surface.fullName == "java.io.InputStream" => + // InputStream parameters are not serializable for logging + ListMap.empty case _ if p.isSecret => ListMap.empty case u: ULID =>