Skip to content

grpc-pb: Add gRPC Exceptions and skip unknown fields #437

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

package kotlinx.rpc.grpc.internal

import kotlinx.rpc.grpc.pb.ProtobufDecodingException
import kotlinx.rpc.grpc.pb.WireDecoder

internal expect fun WireDecoder.pushLimit(byteLen: Int): Int
Expand All @@ -13,20 +14,15 @@ internal expect fun WireDecoder.bytesUntilLimit(): Int
internal inline fun <T : Any> WireDecoder.readPackedVarInternal(
crossinline size: () -> Long,
crossinline readFn: () -> T,
crossinline withError: () -> Unit,
crossinline hadError: () -> Boolean,
): List<T> {
val byteLen = readInt32()
if (hadError()) {
return emptyList()
}
if (byteLen < 0) {
return emptyList<T>().apply { withError() }
throw ProtobufDecodingException.negativeSize()
}
val size = size()
// no size check on jvm
if (size != -1L && size < byteLen) {
return emptyList<T>().apply { withError() }
throw ProtobufDecodingException.truncatedMessage()
}
if (byteLen == 0) {
return emptyList() // actually an empty list (no error)
Expand All @@ -38,9 +34,6 @@ internal inline fun <T : Any> WireDecoder.readPackedVarInternal(

while (bytesUntilLimit() > 0) {
val elem = readFn()
if (hadError()) {
break
}
result.add(elem)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,14 @@ internal fun KTag.toRawKTag(): UInt {
return (fieldNr.toUInt() shl KTag.Companion.K_TAG_TYPE_BITS) or wireType.ordinal.toUInt()
}

internal fun KTag.Companion.fromOrNull(rawKTag: UInt): KTag? {
internal fun KTag.Companion.from(rawKTag: UInt): KTag {
val type = (rawKTag and K_TAG_TYPE_MASK).toInt()
val field = (rawKTag shr K_TAG_TYPE_BITS).toInt()
if (!isValidFieldNr(field)) {
return null
throw ProtobufDecodingException("Invalid field number: $field")
}
if (type >= WireType.entries.size) {
return null
throw ProtobufDecodingException("Invalid wire type: $type")
}
return KTag(field, WireType.entries[type])
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/*
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
*/

package kotlinx.rpc.grpc.pb

public sealed class ProtobufException : RuntimeException {
protected constructor(message: String, cause: Throwable? = null) : super(message, cause)
}


public class ProtobufDecodingException : ProtobufException {
public constructor(message: String, cause: Throwable? = null) : super(message, cause)

public companion object Companion {
internal fun missingRequiredField(messageName: String, fieldName: String) =
ProtobufDecodingException("Message '$messageName' is missing a required field: $fieldName")

internal fun negativeSize() = ProtobufDecodingException(
"Decoder encountered an embedded string or message which claimed to have negative size."
)

internal fun invalidTag() = ProtobufDecodingException(
"Protocol message contained an invalid tag (zero)."
)

internal fun truncatedMessage() = ProtobufDecodingException(
("While parsing a protocol message, the input ended unexpectedly "
+ "in the middle of a field. This could mean either that the "
+ "input has been truncated or that an embedded message "
+ "misreported its own length.")
)

internal fun genericParsingError() = ProtobufDecodingException("Failed to parse the message.")
}
}

public class ProtobufEncodingException : ProtobufException {
public constructor(message: String, cause: Throwable? = null) : super(message, cause)
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000
* This decoder is used by first calling [readTag], than looking up the field based on the field number in the returned,
* tag and then calling the actual `read*()` method to read the value to the corresponding field.
*
* [hadError] indicates an error during decoding. While calling `read*()` is safe, the returned values
* are meaningless if [hadError] returns `true`.
* All `read*()` methods will throw an exception if the expected value couldn't be decoded.
* Because of optimization reasons, the exception is platform-dependent. To unify them
* wrap the decoding in a [checkForPlatformDecodeException] call, which turn platform-specific exceptions
* into a [ProtobufDecodingException].
*
* NOTE: If the [hadError] after a call to `read*()` returns `false`, it doesn't mean that the
* NOTE: If a call to `read*()` doesn't throw an error, it doesn't mean that the
* value is correctly decoded. E.g., the following test will pass:
* ```kt
* val fieldNr = 1
Expand All @@ -32,17 +34,17 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000
* assertTrue(encoder.writeInt32(fieldNr, 12312))
* encoder.flush()
*
* WireDecoder(buffer).use { decoder ->
* decoder.readTag()
* decoder.readBool()
* assertFalse(decoder.hasError())
* checkForPlatformDecodeException {
* WireDecoder(buffer).use { decoder ->
* decoder.readTag()
* decoder.readBool()
* assertFalse(decoder.hasError())
* }
* }
* ```
*/
@InternalRpcApi
public interface WireDecoder : AutoCloseable {
public fun hadError(): Boolean

/**
* When the read tag is null, it indicates EOF and the parser may stop at this point.
*/
Expand Down Expand Up @@ -79,17 +81,32 @@ public interface WireDecoder : AutoCloseable {
public fun readPackedDouble(): List<Double>
public fun readPackedEnum(): List<Int>

// TODO: Throw error instead of just returning
public fun <T : InternalMessage> readMessage(msg: T, decoder: (T, WireDecoder) -> Unit) {
val len = readInt32()
if (hadError()) return
if (len <= 0) return
if (len < 0) throw ProtobufDecodingException.negativeSize()
val limit = pushLimit(len)
decoder(msg, this)
popLimit(limit)
}

public fun skipValue(writeType: WireType) {
when (writeType) {
WireType.VARINT -> readInt64()
WireType.FIXED32 -> readFixed32()
WireType.FIXED64 -> readFixed64()
WireType.LENGTH_DELIMITED -> readBytes()
WireType.START_GROUP -> throw ProtobufDecodingException("Unexpected START_GROUP wire type (KRPC-193)")
WireType.END_GROUP -> {} // nothing to do
}
}
}

/**
* Turns exceptions thrown by different platforms during decoding into [ProtobufDecodingException].
*/
@InternalRpcApi
public expect inline fun checkForPlatformDecodeException(block: () -> Unit)

/**
* Creates a platform-specific [WireDecoder].
*
Expand All @@ -100,4 +117,5 @@ public interface WireDecoder : AutoCloseable {
*
* @param source The buffer containing the encoded wire-format data.
*/
internal expect fun WireDecoder(source: Buffer): WireDecoder
@InternalRpcApi
public expect fun WireDecoder(source: Buffer): WireDecoder
Original file line number Diff line number Diff line change
Expand Up @@ -10,54 +10,63 @@ import kotlinx.rpc.internal.utils.InternalRpcApi
/**
* A platform-specific class that encodes values into protobuf's wire format.
*
* If one `write*()` method returns false, the encoding of the value failed
* and no further encodings can be performed on this [WireEncoder].
* If one `write*()` method fails to encode the value in the buffer,
* it will throw a platform-specific exception.
*
* Wrap the encoding of a message with [checkForPlatformEncodeException] to
* turn all thrown platform-specific exceptions into [ProtobufEncodingException]s.
*
* [flush] must be called to ensure that all data is written to the [Sink].
*/
@InternalRpcApi
@OptIn(ExperimentalUnsignedTypes::class)
public interface WireEncoder {
public fun flush()
public fun writeBool(fieldNr: Int, value: Boolean): Boolean
public fun writeInt32(fieldNr: Int, value: Int): Boolean
public fun writeInt64(fieldNr: Int, value: Long): Boolean
public fun writeUInt32(fieldNr: Int, value: UInt): Boolean
public fun writeUInt64(fieldNr: Int, value: ULong): Boolean
public fun writeSInt32(fieldNr: Int, value: Int): Boolean
public fun writeSInt64(fieldNr: Int, value: Long): Boolean
public fun writeFixed32(fieldNr: Int, value: UInt): Boolean
public fun writeFixed64(fieldNr: Int, value: ULong): Boolean
public fun writeSFixed32(fieldNr: Int, value: Int): Boolean
public fun writeSFixed64(fieldNr: Int, value: Long): Boolean
public fun writeFloat(fieldNr: Int, value: Float): Boolean
public fun writeDouble(fieldNr: Int, value: Double): Boolean
public fun writeEnum(fieldNr: Int, value: Int): Boolean
public fun writeBytes(fieldNr: Int, value: ByteArray): Boolean
public fun writeString(fieldNr: Int, value: String): Boolean
public fun writePackedBool(fieldNr: Int, value: List<Boolean>, fieldSize: Int): Boolean
public fun writePackedInt32(fieldNr: Int, value: List<Int>, fieldSize: Int): Boolean
public fun writePackedInt64(fieldNr: Int, value: List<Long>, fieldSize: Int): Boolean
public fun writePackedUInt32(fieldNr: Int, value: List<UInt>, fieldSize: Int): Boolean
public fun writePackedUInt64(fieldNr: Int, value: List<ULong>, fieldSize: Int): Boolean
public fun writePackedSInt32(fieldNr: Int, value: List<Int>, fieldSize: Int): Boolean
public fun writePackedSInt64(fieldNr: Int, value: List<Long>, fieldSize: Int): Boolean
public fun writePackedFixed32(fieldNr: Int, value: List<UInt>): Boolean
public fun writePackedFixed64(fieldNr: Int, value: List<ULong>): Boolean
public fun writePackedSFixed32(fieldNr: Int, value: List<Int>): Boolean
public fun writePackedSFixed64(fieldNr: Int, value: List<Long>): Boolean
public fun writePackedFloat(fieldNr: Int, value: List<Float>): Boolean
public fun writePackedDouble(fieldNr: Int, value: List<Double>): Boolean
public fun writePackedEnum(fieldNr: Int, value: List<Int>, fieldSize: Int): Boolean =
public fun writeBool(fieldNr: Int, value: Boolean)
public fun writeInt32(fieldNr: Int, value: Int)
public fun writeInt64(fieldNr: Int, value: Long)
public fun writeUInt32(fieldNr: Int, value: UInt)
public fun writeUInt64(fieldNr: Int, value: ULong)
public fun writeSInt32(fieldNr: Int, value: Int)
public fun writeSInt64(fieldNr: Int, value: Long)
public fun writeFixed32(fieldNr: Int, value: UInt)
public fun writeFixed64(fieldNr: Int, value: ULong)
public fun writeSFixed32(fieldNr: Int, value: Int)
public fun writeSFixed64(fieldNr: Int, value: Long)
public fun writeFloat(fieldNr: Int, value: Float)
public fun writeDouble(fieldNr: Int, value: Double)
public fun writeEnum(fieldNr: Int, value: Int)
public fun writeBytes(fieldNr: Int, value: ByteArray)
public fun writeString(fieldNr: Int, value: String)
public fun writePackedBool(fieldNr: Int, value: List<Boolean>, fieldSize: Int)
public fun writePackedInt32(fieldNr: Int, value: List<Int>, fieldSize: Int)
public fun writePackedInt64(fieldNr: Int, value: List<Long>, fieldSize: Int)
public fun writePackedUInt32(fieldNr: Int, value: List<UInt>, fieldSize: Int)
public fun writePackedUInt64(fieldNr: Int, value: List<ULong>, fieldSize: Int)
public fun writePackedSInt32(fieldNr: Int, value: List<Int>, fieldSize: Int)
public fun writePackedSInt64(fieldNr: Int, value: List<Long>, fieldSize: Int)
public fun writePackedFixed32(fieldNr: Int, value: List<UInt>)
public fun writePackedFixed64(fieldNr: Int, value: List<ULong>)
public fun writePackedSFixed32(fieldNr: Int, value: List<Int>)
public fun writePackedSFixed64(fieldNr: Int, value: List<Long>)
public fun writePackedFloat(fieldNr: Int, value: List<Float>)
public fun writePackedDouble(fieldNr: Int, value: List<Double>)
public fun writePackedEnum(fieldNr: Int, value: List<Int>, fieldSize: Int): Unit =
writePackedInt32(fieldNr, value, fieldSize)

public fun <T : InternalMessage> writeMessage(
fieldNr: Int,
value: T,
encode: T.(WireEncoder) -> Unit
encode: T.(WireEncoder) -> Unit,
)

}

/**
* Turns exceptions thrown by different platforms during encoding into [ProtobufEncodingException].
*/
@InternalRpcApi
public expect inline fun checkForPlatformEncodeException(block: () -> Unit)

internal expect fun WireEncoder(sink: Sink): WireEncoder
@InternalRpcApi
public expect fun WireEncoder(sink: Sink): WireEncoder
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,24 @@ class ProtosTest {
return codec.decode(source)
}

@Test
fun testUnknownFieldsDontCrash() {
val buffer = Buffer()
val encoder = WireEncoder(buffer)
// optional sint32 sint32 = 7
encoder.writeSInt32(7, 12)
// optional sint64 sint64 = 8; (unknown as wrong wire-type)
encoder.writeFloat(8, 2f)
// optional fixed32 fixed32 = 9;
encoder.writeFixed32(9, 1234u)
encoder.flush()

val decoded = AllPrimitivesInternal.CODEC.decode(buffer)
assertEquals(12, decoded.sint32)
assertNull(decoded.sint64)
assertEquals(1234u, decoded.fixed32)
}

@Test
fun testAllPrimitiveProto() {
val msg = AllPrimitives {
Expand Down Expand Up @@ -86,7 +104,7 @@ class ProtosTest {

@Test
fun testRepeatedWithRequiredSubField() {
assertFailsWith<IllegalStateException> {
assertFailsWith<ProtobufDecodingException> {
RepeatedWithRequired {
// we construct the message using the internal class,
// so it is not invoking the checkRequired method on construction
Expand All @@ -98,7 +116,7 @@ class ProtosTest {
@Test
fun testPresenceCheckProto() {
// Check a missing required field in a user-constructed message
assertFailsWith<IllegalStateException> {
assertFailsWith<ProtobufDecodingException> {
PresenceCheck {}
}

Expand All @@ -108,7 +126,7 @@ class ProtosTest {
encoder.writeFloat(2, 1f)
encoder.flush()

assertFailsWith<IllegalStateException> {
assertFailsWith<ProtobufDecodingException> {
PresenceCheckInternal.CODEC.decode(buffer)
}
}
Expand Down Expand Up @@ -227,7 +245,7 @@ class ProtosTest {

@Test
fun testOneOfRequiredSubField() {
assertFailsWith<IllegalStateException> {
assertFailsWith<ProtobufDecodingException> {
OneOfWithRequired {
// we construct the message using the internal class,
// so it is not invoking the checkRequired method on construction
Expand Down Expand Up @@ -258,7 +276,7 @@ class ProtosTest {

@Test
fun testRecursiveReqNotSet() {
assertFailsWith<IllegalStateException> {
assertFailsWith<ProtobufDecodingException> {
val msg = RecursiveReq {
rec = RecursiveReq {
rec = RecursiveReq {
Expand Down Expand Up @@ -379,7 +397,7 @@ class ProtosTest {
// we use the internal constructor to avoid a "missing required field" error during object construction
val missingRequiredMessage = PresenceCheckInternal()

assertFailsWith<IllegalStateException> {
assertFailsWith<ProtobufDecodingException> {
val msg = TestMap {
messages = mapOf(
2 to missingRequiredMessage
Expand Down
Loading
Loading