diff --git a/src/awslambda.ts b/src/awslambda.ts new file mode 100644 index 00000000..23ae6aae --- /dev/null +++ b/src/awslambda.ts @@ -0,0 +1,30 @@ +import type { Writable } from 'stream'; +import type { Handler, Context } from 'aws-lambda'; + +const anyGlobal: any = global; + +export namespace awslambda { + export type HttpMetadata = { + statusCode: number; + headers: Record; + cookies?: string[]; + }; + + export namespace HttpResponseStream { + export function from(writable: Writable, metadata: HttpMetadata): Writable { + return anyGlobal.awslambda.HttpResponseStream.from(writable, metadata); + } + } + + export type StreamHandler = ( + event: Event, + responseStream: Writable, + context: Context, + ) => void; + + export function streamifyResponse( + handler: StreamHandler, + ): Handler { + return anyGlobal.awslambda.streamifyResponse(handler); + } +} diff --git a/src/lambdaHandler.ts b/src/lambdaHandler.ts index 997373e6..ea3e2b40 100644 --- a/src/lambdaHandler.ts +++ b/src/lambdaHandler.ts @@ -5,42 +5,49 @@ import type { } from '@apollo/server'; import type { WithRequired } from '@apollo/utils.withrequired'; import type { Context, Handler } from 'aws-lambda'; -import type { LambdaResponse, MiddlewareFn } from './middleware'; -import type { - RequestHandler, - RequestHandlerEvent, - RequestHandlerResult, +import { + runMiddleware, + type LambdaResponse, + type MiddlewareFn, +} from './middleware'; +import { + isStreamRequestHandler, + type RequestHandler, + type RequestHandlerEvent, + type RequestHandlerResult, + type StreamRequestHandler, } from './request-handlers/_create'; +import { awslambda } from './awslambda'; +import type { Writable } from 'stream'; export interface LambdaContextFunctionArgument< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, > { - event: RH extends RequestHandler ? EventType : never; + event: RequestHandlerEvent; context: Context; } export interface LambdaHandlerOptions< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, TContext extends BaseContext, > { middleware?: Array>; context?: ContextFunction<[LambdaContextFunctionArgument], TContext>; } -export type LambdaHandler> = Handler< - RequestHandlerEvent, - RequestHandlerResult ->; +export type LambdaHandler< + RH extends RequestHandler | StreamRequestHandler, +> = Handler, RequestHandlerResult>; export function startServerAndCreateLambdaHandler< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, >( server: ApolloServer, handler: RH, options?: LambdaHandlerOptions, ): LambdaHandler; export function startServerAndCreateLambdaHandler< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, TContext extends BaseContext, >( server: ApolloServer, @@ -48,7 +55,7 @@ export function startServerAndCreateLambdaHandler< options: WithRequired, 'context'>, ): LambdaHandler; export function startServerAndCreateLambdaHandler< - RH extends RequestHandler, + RH extends RequestHandler | StreamRequestHandler, TContext extends BaseContext, >( server: ApolloServer, @@ -70,24 +77,95 @@ export function startServerAndCreateLambdaHandler< TContext > = options?.context ?? defaultContext; + if (isStreamRequestHandler(handler)) { + return awslambda.streamifyResponse>( + async (event, responseStream, context) => { + let resultMiddlewareFns: Array< + LambdaResponse> + > = []; + let httpResponseStream: Writable | undefined; + try { + const middlewareResult = await runMiddleware( + event, + options?.middleware ?? [], + handler, + ); + if (middlewareResult.status === 'result') { + httpResponseStream = awslambda.HttpResponseStream.from( + responseStream, + middlewareResult.result, + ); + httpResponseStream.end(); + return; + } + resultMiddlewareFns = middlewareResult.middleware; + + const httpGraphQLRequest = handler.fromEvent(event); + + const response = await server.executeHTTPGraphQLRequest({ + httpGraphQLRequest, + context: () => { + return contextFunction({ + event, + context, + }); + }, + }); + + const metadata = await handler.buildHTTPMetadata(response); + + httpResponseStream = awslambda.HttpResponseStream.from( + responseStream, + metadata, + ); + + if (response.body.kind === 'complete') { + httpResponseStream.write(response.body.string); + httpResponseStream.end(); + return; + } + + for await (const chunk of response.body.asyncIterator) { + httpResponseStream.write(chunk); + } + httpResponseStream.end(); + } catch (e) { + const { metadata, body } = await handler.toErrorResult(e); + + if (httpResponseStream) { + httpResponseStream.write(body); + httpResponseStream.end(); + return; + } + + for (const resultMiddlewareFn of resultMiddlewareFns) { + await resultMiddlewareFn(metadata as any); + } + + httpResponseStream = awslambda.HttpResponseStream.from( + responseStream, + metadata, + ); + httpResponseStream.write(body); + httpResponseStream.end(); + } + }, + ); + } + return async function (event, context) { - const resultMiddlewareFns: Array>> = + let resultMiddlewareFns: Array>> = []; try { - for (const middlewareFn of options?.middleware ?? []) { - const middlewareReturnValue = await middlewareFn(event); - // If the middleware returns an object, we assume it's a LambdaResponse - if ( - typeof middlewareReturnValue === 'object' && - middlewareReturnValue !== null - ) { - return middlewareReturnValue; - } - // If the middleware returns a function, we assume it's a result callback - if (middlewareReturnValue) { - resultMiddlewareFns.push(middlewareReturnValue); - } + const middlewareResult = await runMiddleware( + event, + options?.middleware ?? [], + handler, + ); + if (middlewareResult.status === 'result') { + return middlewareResult.result; } + resultMiddlewareFns = middlewareResult.middleware; const httpGraphQLRequest = handler.fromEvent(event); diff --git a/src/middleware.ts b/src/middleware.ts index 7ba042b7..7bf2697a 100644 --- a/src/middleware.ts +++ b/src/middleware.ts @@ -1,4 +1,9 @@ -import type { RequestHandler } from './request-handlers/_create'; +import type { + RequestHandler, + RequestHandlerEvent, + RequestHandlerResult, + StreamRequestHandler, +} from './request-handlers/_create'; export type LambdaResponse = (result: ResultType) => Promise; @@ -6,7 +11,57 @@ export type LambdaRequest = ( event: EventType, ) => Promise | ResultType | void>; -export type MiddlewareFn> = - RH extends RequestHandler - ? LambdaRequest - : never; +export type MiddlewareFn< + RH extends RequestHandler | StreamRequestHandler, +> = LambdaRequest, RequestHandlerResult>; + +export async function runMiddleware< + RH extends RequestHandler | StreamRequestHandler, +>( + event: RequestHandlerEvent, + middleware: Array>, + handler: RH, +): Promise< + | { + status: 'result'; + result: RequestHandlerResult; + } + | { + status: 'continue'; + middleware: Array>>; + } +> { + const resultMiddlewareFns: Array>> = + []; + try { + for (const middlewareFn of middleware) { + const middlewareReturnValue = await middlewareFn(event); + // If the middleware returns an object, we assume it's a LambdaResponse + if ( + typeof middlewareReturnValue === 'object' && + middlewareReturnValue !== null + ) { + return middlewareReturnValue; + } + // If the middleware returns a function, we assume it's a result callback + if (middlewareReturnValue) { + resultMiddlewareFns.push(middlewareReturnValue); + } + } + return { + status: 'continue', + middleware: resultMiddlewareFns, + }; + } catch (e) { + const result = handler.toErrorResult(e); + + for (const resultMiddlewareFn of resultMiddlewareFns) { + await resultMiddlewareFn(result); + } + + return { + status: 'result', + result, + }; + } +} diff --git a/src/request-handlers/APIGatewayProxyEventV2StreamRequestHandler.ts b/src/request-handlers/APIGatewayProxyEventV2StreamRequestHandler.ts new file mode 100644 index 00000000..1b2ba83f --- /dev/null +++ b/src/request-handlers/APIGatewayProxyEventV2StreamRequestHandler.ts @@ -0,0 +1,38 @@ +import type { APIGatewayProxyEventV2 } from 'aws-lambda'; +import { createStreamRequestHandler } from './_create'; +import { HeaderMap } from '@apollo/server'; + +export const createAPIGatewayProxyEventV2StreamRequestHandler = < + Event extends APIGatewayProxyEventV2 = APIGatewayProxyEventV2, +>() => { + return createStreamRequestHandler({ + parseHttpMethod(event) { + return event.requestContext.http.method; + }, + parseHeaders(event) { + const headerMap = new HeaderMap(); + for (const [key, value] of Object.entries(event.headers ?? {})) { + headerMap.set(key, value ?? ''); + } + return headerMap; + }, + parseBody(event, headers) { + if (event.body) { + const contentType = headers.get('content-type'); + const parsedBody = event.isBase64Encoded + ? Buffer.from(event.body, 'base64').toString('utf8') + : event.body; + if (contentType?.startsWith('application/json')) { + return JSON.parse(parsedBody); + } + if (contentType?.startsWith('text/plain')) { + return parsedBody; + } + } + return ''; + }, + parseQueryParams(event) { + return event.rawQueryString; + }, + }); +}; diff --git a/src/request-handlers/_create.ts b/src/request-handlers/_create.ts index 945b97e3..64f415d8 100644 --- a/src/request-handlers/_create.ts +++ b/src/request-handlers/_create.ts @@ -3,6 +3,7 @@ import type { HTTPGraphQLRequest, HTTPGraphQLResponse, } from '@apollo/server'; +import type { awslambda } from '../awslambda'; export interface RequestHandler { fromEvent: (event: EventType) => HTTPGraphQLRequest; @@ -10,11 +11,33 @@ export interface RequestHandler { toErrorResult: (error: unknown) => ResultType; } -export type RequestHandlerEvent> = - T extends RequestHandler ? EventType : never; +export interface StreamRequestHandler { + type: 'stream'; + fromEvent: (event: EventType) => HTTPGraphQLRequest; + buildHTTPMetadata: ( + response: HTTPGraphQLResponse, + ) => Promise; + toErrorResult: (error: unknown) => Promise<{ + metadata: awslambda.HttpMetadata; + body: string; + }>; +} -export type RequestHandlerResult> = - T extends RequestHandler ? ResultType : never; +export type RequestHandlerEvent< + T extends RequestHandler | StreamRequestHandler, +> = T extends StreamRequestHandler + ? EventType + : T extends RequestHandler + ? EventType + : never; + +export type RequestHandlerResult< + T extends RequestHandler | StreamRequestHandler, +> = T extends StreamRequestHandler + ? awslambda.HttpMetadata + : T extends RequestHandler + ? ResultType + : never; export type EventParser = | { @@ -51,3 +74,48 @@ export function createRequestHandler( toErrorResult: resultGenerator.error, }; } + +export function createStreamRequestHandler( + eventParser: EventParser, +): StreamRequestHandler { + return { + type: 'stream', + fromEvent(event) { + if (typeof eventParser === 'function') { + return eventParser(event); + } + const headers = eventParser.parseHeaders(event); + return { + method: eventParser.parseHttpMethod(event), + headers, + search: eventParser.parseQueryParams(event), + body: eventParser.parseBody(event, headers), + }; + }, + buildHTTPMetadata: async (response) => { + const { headers, status } = response; + + return { + statusCode: status ?? 200, + headers: { + ...Object.fromEntries(headers), + }, + }; + }, + toErrorResult: async (error) => { + return { + metadata: { + statusCode: 400, + headers: {}, + }, + body: (error as Error).message, + }; + }, + }; +} + +export function isStreamRequestHandler( + handler: RequestHandler | StreamRequestHandler, +): handler is StreamRequestHandler { + return 'type' in handler && handler.type === 'stream'; +} diff --git a/src/request-handlers/_index.ts b/src/request-handlers/_index.ts index 9372f10f..901fb1f1 100644 --- a/src/request-handlers/_index.ts +++ b/src/request-handlers/_index.ts @@ -1,4 +1,5 @@ export { createALBEventRequestHandler } from './ALBEventRequestHandler'; export { createAPIGatewayProxyEventRequestHandler } from './APIGatewayProxyEventRequestHandler'; export { createAPIGatewayProxyEventV2RequestHandler } from './APIGatewayProxyEventV2RequestHandler'; +export { createAPIGatewayProxyEventV2StreamRequestHandler } from './APIGatewayProxyEventV2StreamRequestHandler'; export * from './_create';