From 54553cb82b1fb479d178cf89499f17a8cdc8ebe1 Mon Sep 17 00:00:00 2001 From: Charmander <~@charmander.me> Date: Thu, 21 Aug 2025 19:27:29 -0700 Subject: [PATCH 1/2] test: Add failing test for parser reader cleanup --- packages/pg-protocol/src/inbound-parser.test.ts | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/packages/pg-protocol/src/inbound-parser.test.ts b/packages/pg-protocol/src/inbound-parser.test.ts index 0575993df..285f4bf2b 100644 --- a/packages/pg-protocol/src/inbound-parser.test.ts +++ b/packages/pg-protocol/src/inbound-parser.test.ts @@ -4,6 +4,7 @@ import { parse } from '.' import assert from 'assert' import { PassThrough } from 'stream' import { BackendMessage } from './messages' +import { Parser } from './parser' const authOkBuffer = buffers.authenticationOk() const paramStatusBuffer = buffers.parameterStatus('client_encoding', 'UTF8') @@ -565,4 +566,10 @@ describe('PgPacketStream', function () { }) }) }) + + it('cleans up the reader after handling a packet', function () { + const parser = new Parser() + parser.parse(oneFieldBuf, () => {}) + assert.strictEqual((parser as any).reader.buffer.byteLength, 0) + }) }) From 715c9dcb09f2e0e93c1fb08cbbec038da813f801 Mon Sep 17 00:00:00 2001 From: Charmander <~@charmander.me> Date: Thu, 21 Aug 2025 19:33:00 -0700 Subject: [PATCH 2/2] fix: Avoid retaining buffer for latest parse in reader MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The buffer can be arbitrarily large, and the parser shouldn’t keep it around while waiting on (and potentially also buffering) the next complete packet. --- packages/pg-protocol/src/parser.ts | 392 +++++++++++++++-------------- 1 file changed, 208 insertions(+), 184 deletions(-) diff --git a/packages/pg-protocol/src/parser.ts b/packages/pg-protocol/src/parser.ts index f7313f235..998077a00 100644 --- a/packages/pg-protocol/src/parser.ts +++ b/packages/pg-protocol/src/parser.ts @@ -36,6 +36,9 @@ const LEN_LENGTH = 4 const HEADER_LENGTH = CODE_LENGTH + LEN_LENGTH +// A placeholder for a `BackendMessage`’s length value that will be set after construction. +const LATEINIT_LENGTH = -1 + export type Packet = { code: number packet: Buffer @@ -152,238 +155,259 @@ export class Parser { } private handlePacket(offset: number, code: number, length: number, bytes: Buffer): BackendMessage { + const { reader } = this + + // NOTE: This undesirably retains the buffer in `this.reader` if the `parse*Message` calls below throw. However, those should only throw in the case of a protocol error, which normally results in the reader being discarded. + reader.setBuffer(offset, bytes) + + let message: BackendMessage + switch (code) { case MessageCodes.BindComplete: - return bindComplete + message = bindComplete + break case MessageCodes.ParseComplete: - return parseComplete + message = parseComplete + break case MessageCodes.CloseComplete: - return closeComplete + message = closeComplete + break case MessageCodes.NoData: - return noData + message = noData + break case MessageCodes.PortalSuspended: - return portalSuspended + message = portalSuspended + break case MessageCodes.CopyDone: - return copyDone + message = copyDone + break case MessageCodes.ReplicationStart: - return replicationStart + message = replicationStart + break case MessageCodes.EmptyQuery: - return emptyQuery + message = emptyQuery + break case MessageCodes.DataRow: - return this.parseDataRowMessage(offset, length, bytes) + message = parseDataRowMessage(reader) + break case MessageCodes.CommandComplete: - return this.parseCommandCompleteMessage(offset, length, bytes) + message = parseCommandCompleteMessage(reader) + break case MessageCodes.ReadyForQuery: - return this.parseReadyForQueryMessage(offset, length, bytes) + message = parseReadyForQueryMessage(reader) + break case MessageCodes.NotificationResponse: - return this.parseNotificationMessage(offset, length, bytes) + message = parseNotificationMessage(reader) + break case MessageCodes.AuthenticationResponse: - return this.parseAuthenticationResponse(offset, length, bytes) + message = parseAuthenticationResponse(reader, length) + break case MessageCodes.ParameterStatus: - return this.parseParameterStatusMessage(offset, length, bytes) + message = parseParameterStatusMessage(reader) + break case MessageCodes.BackendKeyData: - return this.parseBackendKeyData(offset, length, bytes) + message = parseBackendKeyData(reader) + break case MessageCodes.ErrorMessage: - return this.parseErrorMessage(offset, length, bytes, 'error') + message = parseErrorMessage(reader, 'error') + break case MessageCodes.NoticeMessage: - return this.parseErrorMessage(offset, length, bytes, 'notice') + message = parseErrorMessage(reader, 'notice') + break case MessageCodes.RowDescriptionMessage: - return this.parseRowDescriptionMessage(offset, length, bytes) + message = parseRowDescriptionMessage(reader) + break case MessageCodes.ParameterDescriptionMessage: - return this.parseParameterDescriptionMessage(offset, length, bytes) + message = parseParameterDescriptionMessage(reader) + break case MessageCodes.CopyIn: - return this.parseCopyInMessage(offset, length, bytes) + message = parseCopyInMessage(reader) + break case MessageCodes.CopyOut: - return this.parseCopyOutMessage(offset, length, bytes) + message = parseCopyOutMessage(reader) + break case MessageCodes.CopyData: - return this.parseCopyData(offset, length, bytes) + message = parseCopyData(reader, length) + break default: return new DatabaseError('received invalid response: ' + code.toString(16), length, 'error') } - } - private parseReadyForQueryMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const status = this.reader.string(1) - return new ReadyForQueryMessage(length, status) - } + reader.setBuffer(0, emptyBuffer) - private parseCommandCompleteMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const text = this.reader.cstring() - return new CommandCompleteMessage(length, text) + message.length = length + return message } +} - private parseCopyData(offset: number, length: number, bytes: Buffer) { - const chunk = bytes.slice(offset, offset + (length - 4)) - return new CopyDataMessage(length, chunk) - } +const parseReadyForQueryMessage = (reader: BufferReader) => { + const status = reader.string(1) + return new ReadyForQueryMessage(LATEINIT_LENGTH, status) +} - private parseCopyInMessage(offset: number, length: number, bytes: Buffer) { - return this.parseCopyMessage(offset, length, bytes, 'copyInResponse') - } +const parseCommandCompleteMessage = (reader: BufferReader) => { + const text = reader.cstring() + return new CommandCompleteMessage(LATEINIT_LENGTH, text) +} - private parseCopyOutMessage(offset: number, length: number, bytes: Buffer) { - return this.parseCopyMessage(offset, length, bytes, 'copyOutResponse') - } +const parseCopyData = (reader: BufferReader, length: number) => { + const chunk = reader.bytes(length - 4) + return new CopyDataMessage(LATEINIT_LENGTH, chunk) +} - private parseCopyMessage(offset: number, length: number, bytes: Buffer, messageName: MessageName) { - this.reader.setBuffer(offset, bytes) - const isBinary = this.reader.byte() !== 0 - const columnCount = this.reader.int16() - const message = new CopyResponse(length, messageName, isBinary, columnCount) - for (let i = 0; i < columnCount; i++) { - message.columnTypes[i] = this.reader.int16() - } - return message - } +const parseCopyInMessage = (reader: BufferReader) => parseCopyMessage(reader, 'copyInResponse') - private parseNotificationMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const processId = this.reader.int32() - const channel = this.reader.cstring() - const payload = this.reader.cstring() - return new NotificationResponseMessage(length, processId, channel, payload) - } +const parseCopyOutMessage = (reader: BufferReader) => parseCopyMessage(reader, 'copyOutResponse') - private parseRowDescriptionMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const fieldCount = this.reader.int16() - const message = new RowDescriptionMessage(length, fieldCount) - for (let i = 0; i < fieldCount; i++) { - message.fields[i] = this.parseField() - } - return message +const parseCopyMessage = (reader: BufferReader, messageName: MessageName) => { + const isBinary = reader.byte() !== 0 + const columnCount = reader.int16() + const message = new CopyResponse(LATEINIT_LENGTH, messageName, isBinary, columnCount) + for (let i = 0; i < columnCount; i++) { + message.columnTypes[i] = reader.int16() } + return message +} - private parseField(): Field { - const name = this.reader.cstring() - const tableID = this.reader.uint32() - const columnID = this.reader.int16() - const dataTypeID = this.reader.uint32() - const dataTypeSize = this.reader.int16() - const dataTypeModifier = this.reader.int32() - const mode = this.reader.int16() === 0 ? 'text' : 'binary' - return new Field(name, tableID, columnID, dataTypeID, dataTypeSize, dataTypeModifier, mode) - } +const parseNotificationMessage = (reader: BufferReader) => { + const processId = reader.int32() + const channel = reader.cstring() + const payload = reader.cstring() + return new NotificationResponseMessage(LATEINIT_LENGTH, processId, channel, payload) +} - private parseParameterDescriptionMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const parameterCount = this.reader.int16() - const message = new ParameterDescriptionMessage(length, parameterCount) - for (let i = 0; i < parameterCount; i++) { - message.dataTypeIDs[i] = this.reader.int32() - } - return message +const parseRowDescriptionMessage = (reader: BufferReader) => { + const fieldCount = reader.int16() + const message = new RowDescriptionMessage(LATEINIT_LENGTH, fieldCount) + for (let i = 0; i < fieldCount; i++) { + message.fields[i] = parseField(reader) } + return message +} - private parseDataRowMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const fieldCount = this.reader.int16() - const fields: any[] = new Array(fieldCount) - for (let i = 0; i < fieldCount; i++) { - const len = this.reader.int32() - // a -1 for length means the value of the field is null - fields[i] = len === -1 ? null : this.reader.string(len) - } - return new DataRowMessage(length, fields) - } +const parseField = (reader: BufferReader) => { + const name = reader.cstring() + const tableID = reader.uint32() + const columnID = reader.int16() + const dataTypeID = reader.uint32() + const dataTypeSize = reader.int16() + const dataTypeModifier = reader.int32() + const mode = reader.int16() === 0 ? 'text' : 'binary' + return new Field(name, tableID, columnID, dataTypeID, dataTypeSize, dataTypeModifier, mode) +} - private parseParameterStatusMessage(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const name = this.reader.cstring() - const value = this.reader.cstring() - return new ParameterStatusMessage(length, name, value) +const parseParameterDescriptionMessage = (reader: BufferReader) => { + const parameterCount = reader.int16() + const message = new ParameterDescriptionMessage(LATEINIT_LENGTH, parameterCount) + for (let i = 0; i < parameterCount; i++) { + message.dataTypeIDs[i] = reader.int32() } + return message +} - private parseBackendKeyData(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const processID = this.reader.int32() - const secretKey = this.reader.int32() - return new BackendKeyDataMessage(length, processID, secretKey) +const parseDataRowMessage = (reader: BufferReader) => { + const fieldCount = reader.int16() + const fields: any[] = new Array(fieldCount) + for (let i = 0; i < fieldCount; i++) { + const len = reader.int32() + // a -1 for length means the value of the field is null + fields[i] = len === -1 ? null : reader.string(len) } + return new DataRowMessage(LATEINIT_LENGTH, fields) +} - public parseAuthenticationResponse(offset: number, length: number, bytes: Buffer) { - this.reader.setBuffer(offset, bytes) - const code = this.reader.int32() - // TODO(bmc): maybe better types here - const message: BackendMessage & any = { - name: 'authenticationOk', - length, - } +const parseParameterStatusMessage = (reader: BufferReader) => { + const name = reader.cstring() + const value = reader.cstring() + return new ParameterStatusMessage(LATEINIT_LENGTH, name, value) +} - switch (code) { - case 0: // AuthenticationOk - break - case 3: // AuthenticationCleartextPassword - if (message.length === 8) { - message.name = 'authenticationCleartextPassword' - } - break - case 5: // AuthenticationMD5Password - if (message.length === 12) { - message.name = 'authenticationMD5Password' - const salt = this.reader.bytes(4) - return new AuthenticationMD5Password(length, salt) - } - break - case 10: // AuthenticationSASL - { - message.name = 'authenticationSASL' - message.mechanisms = [] - let mechanism: string - do { - mechanism = this.reader.cstring() - if (mechanism) { - message.mechanisms.push(mechanism) - } - } while (mechanism) - } - break - case 11: // AuthenticationSASLContinue - message.name = 'authenticationSASLContinue' - message.data = this.reader.string(length - 8) - break - case 12: // AuthenticationSASLFinal - message.name = 'authenticationSASLFinal' - message.data = this.reader.string(length - 8) - break - default: - throw new Error('Unknown authenticationOk message type ' + code) - } - return message +const parseBackendKeyData = (reader: BufferReader) => { + const processID = reader.int32() + const secretKey = reader.int32() + return new BackendKeyDataMessage(LATEINIT_LENGTH, processID, secretKey) +} + +const parseAuthenticationResponse = (reader: BufferReader, length: number) => { + const code = reader.int32() + // TODO(bmc): maybe better types here + const message: BackendMessage & any = { + name: 'authenticationOk', + length, } - private parseErrorMessage(offset: number, length: number, bytes: Buffer, name: MessageName) { - this.reader.setBuffer(offset, bytes) - const fields: Record = {} - let fieldType = this.reader.string(1) - while (fieldType !== '\0') { - fields[fieldType] = this.reader.cstring() - fieldType = this.reader.string(1) - } + switch (code) { + case 0: // AuthenticationOk + break + case 3: // AuthenticationCleartextPassword + if (message.length === 8) { + message.name = 'authenticationCleartextPassword' + } + break + case 5: // AuthenticationMD5Password + if (message.length === 12) { + message.name = 'authenticationMD5Password' + const salt = reader.bytes(4) + return new AuthenticationMD5Password(LATEINIT_LENGTH, salt) + } + break + case 10: // AuthenticationSASL + { + message.name = 'authenticationSASL' + message.mechanisms = [] + let mechanism: string + do { + mechanism = reader.cstring() + if (mechanism) { + message.mechanisms.push(mechanism) + } + } while (mechanism) + } + break + case 11: // AuthenticationSASLContinue + message.name = 'authenticationSASLContinue' + message.data = reader.string(length - 8) + break + case 12: // AuthenticationSASLFinal + message.name = 'authenticationSASLFinal' + message.data = reader.string(length - 8) + break + default: + throw new Error('Unknown authenticationOk message type ' + code) + } + return message +} - const messageValue = fields.M - - const message = - name === 'notice' ? new NoticeMessage(length, messageValue) : new DatabaseError(messageValue, length, name) - - message.severity = fields.S - message.code = fields.C - message.detail = fields.D - message.hint = fields.H - message.position = fields.P - message.internalPosition = fields.p - message.internalQuery = fields.q - message.where = fields.W - message.schema = fields.s - message.table = fields.t - message.column = fields.c - message.dataType = fields.d - message.constraint = fields.n - message.file = fields.F - message.line = fields.L - message.routine = fields.R - return message +const parseErrorMessage = (reader: BufferReader, name: MessageName) => { + const fields: Record = {} + let fieldType = reader.string(1) + while (fieldType !== '\0') { + fields[fieldType] = reader.cstring() + fieldType = reader.string(1) } + + const messageValue = fields.M + + const message = + name === 'notice' + ? new NoticeMessage(LATEINIT_LENGTH, messageValue) + : new DatabaseError(messageValue, LATEINIT_LENGTH, name) + + message.severity = fields.S + message.code = fields.C + message.detail = fields.D + message.hint = fields.H + message.position = fields.P + message.internalPosition = fields.p + message.internalQuery = fields.q + message.where = fields.W + message.schema = fields.s + message.table = fields.t + message.column = fields.c + message.dataType = fields.d + message.constraint = fields.n + message.file = fields.F + message.line = fields.L + message.routine = fields.R + return message }