From 84eb709b99401a174fef3388f6bff9cced654484 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Thu, 7 Aug 2025 15:40:15 +0200 Subject: [PATCH 1/9] grpc-pb: Skip unknown fields Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 11 +++++++++++ .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 18 ++++++++++++++++++ .../protobuf/ModelToKotlinCommonGenerator.kt | 5 ++++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index 168017cc7..7ad0be745 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -88,6 +88,17 @@ public interface WireDecoder : AutoCloseable { decoder(msg, this) popLimit(limit) } + + public fun skipValue(writeType: WireType) { + when (writeType) { + WireType.VARINT -> readInt64() + WireType.FIXED32 -> readFixed32() + WireType.FIXED64 -> readFixed64() + WireType.LENGTH_DELIMITED -> readBytes() + WireType.START_GROUP -> error("Unexpected START_GROUP wire type") + WireType.END_GROUP -> {} // nothing to do + } + } } /** diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 001e0a2a8..905794814 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -34,6 +34,24 @@ class ProtosTest { return codec.decode(source) } + @Test + fun testUnknownFieldsDontCrash() { + val buffer = Buffer() + val encoder = WireEncoder(buffer) + // optional sint32 sint32 = 7 + encoder.writeSInt32(7, 12) + // optional sint64 sint64 = 8; (unknown as wrong wire-type) + encoder.writeFloat(8, 2f) + // optional fixed32 fixed32 = 9; + encoder.writeFixed32(9, 1234u) + encoder.flush() + + val decoded = AllPrimitivesInternal.CODEC.decode(buffer) + assertEquals(12, decoded.sint32) + assertNull(decoded.sint64) + assertEquals(1234u, decoded.fixed32) + } + @Test fun testAllPrimitiveProto() { val msg = AllPrimitives { diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 598701b7e..a1afdf082 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -288,7 +288,10 @@ class ModelToKotlinCommonGenerator( code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") whenBlock { declaration.fields().forEach { (_, field) -> readMatchCase(field) } - whenCase("else") { code("TODO(\"Handle unknown fields: \$tag\")") } + whenCase("else") { + code("// we are currently just skipping unknown fields (KRPC-191)") + code("decoder.skipValue(tag.wireType)") + } } } ifBranch( From cc8505881992f2d690ee3de07b8bf9b0c9d37f2d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Thu, 7 Aug 2025 16:06:48 +0200 Subject: [PATCH 2/9] grpc-pb: Add InvalidProtobufError class Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/GrpcError.kt | 19 +++++++++++++++++++ .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 13 +++++++------ .../protobuf/ModelToKotlinCommonGenerator.kt | 2 +- 3 files changed, 27 insertions(+), 7 deletions(-) create mode 100644 grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt new file mode 100644 index 000000000..ceed22130 --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt @@ -0,0 +1,19 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc + +public sealed class GrpcError : RuntimeException { + protected constructor(message: String) : super(message) + protected constructor(message: String, cause: Throwable) : super(message, cause) +} + + +public class InvalidProtobufError(message: String) : GrpcError(message) { + public companion object { + internal fun missingRequiredField(messageName: String, fieldName: String) = + InvalidProtobufError("Message '$messageName' is missing a required field: $fieldName") + } +} + diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 905794814..cafcf7f45 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -13,6 +13,7 @@ import asInternal import encodeWith import invoke import kotlinx.io.Buffer +import kotlinx.rpc.grpc.InvalidProtobufError import kotlinx.rpc.grpc.codec.MessageCodec import kotlinx.rpc.grpc.test.* import kotlinx.rpc.grpc.test.common.* @@ -104,7 +105,7 @@ class ProtosTest { @Test fun testRepeatedWithRequiredSubField() { - assertFailsWith { + assertFailsWith { RepeatedWithRequired { // we construct the message using the internal class, // so it is not invoking the checkRequired method on construction @@ -116,7 +117,7 @@ class ProtosTest { @Test fun testPresenceCheckProto() { // Check a missing required field in a user-constructed message - assertFailsWith { + assertFailsWith { PresenceCheck {} } @@ -126,7 +127,7 @@ class ProtosTest { encoder.writeFloat(2, 1f) encoder.flush() - assertFailsWith { + assertFailsWith { PresenceCheckInternal.CODEC.decode(buffer) } } @@ -245,7 +246,7 @@ class ProtosTest { @Test fun testOneOfRequiredSubField() { - assertFailsWith { + assertFailsWith { OneOfWithRequired { // we construct the message using the internal class, // so it is not invoking the checkRequired method on construction @@ -276,7 +277,7 @@ class ProtosTest { @Test fun testRecursiveReqNotSet() { - assertFailsWith { + assertFailsWith { val msg = RecursiveReq { rec = RecursiveReq { rec = RecursiveReq { @@ -397,7 +398,7 @@ class ProtosTest { // we use the internal constructor to avoid a "missing required field" error during object construction val missingRequiredMessage = PresenceCheckInternal() - assertFailsWith { + assertFailsWith { val msg = TestMap { messages = mapOf( 2 to missingRequiredMessage diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index a1afdf082..ddb9aa655 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -586,7 +586,7 @@ class ModelToKotlinCommonGenerator( requiredFields.forEach { field -> ifBranch(condition = "!presenceMask[${field.presenceIdx}]", ifBlock = { - code("error(\"${declaration.name.simpleName} is missing required field: ${field.name}\")") + code("throw kotlinx.rpc.grpc.InvalidProtobufError.missingRequiredField(\"${declaration.name.simpleName}\", \"${field.name}\")") }) } From 35b4e944fd25e43d892f62291333d7158a913f2d Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 8 Aug 2025 10:21:13 +0200 Subject: [PATCH 3/9] grpc-pb: Throw error in decoder instead of hadError() check Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/GrpcError.kt | 19 ------ .../kotlin/kotlinx/rpc/grpc/GrpcException.kt | 40 +++++++++++++ .../kotlinx/rpc/grpc/internal/readPacked.kt | 13 +---- .../kotlin/kotlinx/rpc/grpc/pb/KTag.kt | 6 +- .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 2 +- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 14 ++--- .../kotlinx/rpc/grpc/pb/WireCodecTest.kt | 12 ++++ .../kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt | 58 ++++++++++++------- .../kotlinx/rpc/grpc/pb/WireDecoder.native.kt | 44 ++++++-------- grpc/grpcpp-c/include/protowire.h | 1 + grpc/grpcpp-c/src/protowire.cpp | 4 ++ .../protobuf/ModelToKotlinCommonGenerator.kt | 2 +- 12 files changed, 127 insertions(+), 88 deletions(-) delete mode 100644 grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt create mode 100644 grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt deleted file mode 100644 index ceed22130..000000000 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcError.kt +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. - */ - -package kotlinx.rpc.grpc - -public sealed class GrpcError : RuntimeException { - protected constructor(message: String) : super(message) - protected constructor(message: String, cause: Throwable) : super(message, cause) -} - - -public class InvalidProtobufError(message: String) : GrpcError(message) { - public companion object { - internal fun missingRequiredField(messageName: String, fieldName: String) = - InvalidProtobufError("Message '$messageName' is missing a required field: $fieldName") - } -} - diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt new file mode 100644 index 000000000..a3fabfc5b --- /dev/null +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt @@ -0,0 +1,40 @@ +/* + * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. + */ + +package kotlinx.rpc.grpc + +public sealed class GrpcException : RuntimeException { + protected constructor(message: String, cause: Throwable? = null) : super(message, cause) +} + + +public class ProtobufDecodingException : GrpcException { + internal constructor(message: String, cause: Throwable? = null) : super(message, cause) + + public companion object Companion { + internal fun missingRequiredField(messageName: String, fieldName: String) = + ProtobufDecodingException("Message '$messageName' is missing a required field: $fieldName") + + internal fun negativeSize() = ProtobufDecodingException( + "CodedInputStream encountered an embedded string or message which claimed to have negative size." + ) + + internal fun invalidTag() = ProtobufDecodingException( + "Protocol message contained an invalid tag (zero)." + ) + + internal fun truncatedMessage() = ProtobufDecodingException( + ("While parsing a protocol message, the input ended unexpectedly " + + "in the middle of a field. This could mean either that the " + + "input has been truncated or that an embedded message " + + "misreported its own length.") + ) + + internal fun genericParsingError() = ProtobufDecodingException("Failed to parse the message.") + } +} + +public class ProtobufEncodingException : GrpcException { + internal constructor(message: String, cause: Throwable? = null) : super(message, cause) +} \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt index a7c62e0a1..e43f6b20a 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt @@ -4,6 +4,7 @@ package kotlinx.rpc.grpc.internal +import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.pb.WireDecoder internal expect fun WireDecoder.pushLimit(byteLen: Int): Int @@ -13,20 +14,15 @@ internal expect fun WireDecoder.bytesUntilLimit(): Int internal inline fun WireDecoder.readPackedVarInternal( crossinline size: () -> Long, crossinline readFn: () -> T, - crossinline withError: () -> Unit, - crossinline hadError: () -> Boolean, ): List { val byteLen = readInt32() - if (hadError()) { - return emptyList() - } if (byteLen < 0) { - return emptyList().apply { withError() } + throw ProtobufDecodingException.negativeSize() } val size = size() // no size check on jvm if (size != -1L && size < byteLen) { - return emptyList().apply { withError() } + throw ProtobufDecodingException.truncatedMessage() } if (byteLen == 0) { return emptyList() // actually an empty list (no error) @@ -38,9 +34,6 @@ internal inline fun WireDecoder.readPackedVarInternal( while (bytesUntilLimit() > 0) { val elem = readFn() - if (hadError()) { - break - } result.add(elem) } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt index d60f0f6a8..cc9134288 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt @@ -36,14 +36,14 @@ internal fun KTag.toRawKTag(): UInt { return (fieldNr.toUInt() shl KTag.Companion.K_TAG_TYPE_BITS) or wireType.ordinal.toUInt() } -internal fun KTag.Companion.fromOrNull(rawKTag: UInt): KTag? { +internal fun KTag.Companion.from(rawKTag: UInt): KTag { val type = (rawKTag and K_TAG_TYPE_MASK).toInt() val field = (rawKTag shr K_TAG_TYPE_BITS).toInt() if (!isValidFieldNr(field)) { - return null + error("Invalid field number: $field") } if (type >= WireType.entries.size) { - return null + error("Invalid wire type: $type") } return KTag(field, WireType.entries[type]) } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index 7ad0be745..b8b12e296 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -95,7 +95,7 @@ public interface WireDecoder : AutoCloseable { WireType.FIXED32 -> readFixed32() WireType.FIXED64 -> readFixed64() WireType.LENGTH_DELIMITED -> readBytes() - WireType.START_GROUP -> error("Unexpected START_GROUP wire type") + WireType.START_GROUP -> error("Unexpected START_GROUP wire type (KRPC-193)") WireType.END_GROUP -> {} // nothing to do } } diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index cafcf7f45..9032f9f44 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -13,7 +13,7 @@ import asInternal import encodeWith import invoke import kotlinx.io.Buffer -import kotlinx.rpc.grpc.InvalidProtobufError +import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.codec.MessageCodec import kotlinx.rpc.grpc.test.* import kotlinx.rpc.grpc.test.common.* @@ -105,7 +105,7 @@ class ProtosTest { @Test fun testRepeatedWithRequiredSubField() { - assertFailsWith { + assertFailsWith { RepeatedWithRequired { // we construct the message using the internal class, // so it is not invoking the checkRequired method on construction @@ -117,7 +117,7 @@ class ProtosTest { @Test fun testPresenceCheckProto() { // Check a missing required field in a user-constructed message - assertFailsWith { + assertFailsWith { PresenceCheck {} } @@ -127,7 +127,7 @@ class ProtosTest { encoder.writeFloat(2, 1f) encoder.flush() - assertFailsWith { + assertFailsWith { PresenceCheckInternal.CODEC.decode(buffer) } } @@ -246,7 +246,7 @@ class ProtosTest { @Test fun testOneOfRequiredSubField() { - assertFailsWith { + assertFailsWith { OneOfWithRequired { // we construct the message using the internal class, // so it is not invoking the checkRequired method on construction @@ -277,7 +277,7 @@ class ProtosTest { @Test fun testRecursiveReqNotSet() { - assertFailsWith { + assertFailsWith { val msg = RecursiveReq { rec = RecursiveReq { rec = RecursiveReq { @@ -398,7 +398,7 @@ class ProtosTest { // we use the internal constructor to avoid a "missing required field" error during object construction val missingRequiredMessage = PresenceCheckInternal() - assertFailsWith { + assertFailsWith { val msg = TestMap { messages = mapOf( 2 to missingRequiredMessage diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt index 16f56dea7..194fcada4 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt @@ -5,6 +5,7 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer +import kotlinx.rpc.grpc.ProtobufDecodingException import kotlin.test.* enum class TestPlatform { @@ -808,4 +809,15 @@ class WireCodecTest { WireDecoder::readPackedEnum ) + + @Test + fun testInvalidTag() { + val buffer = Buffer() + buffer.writeByte(0) + + assertFailsWith { + WireDecoder(buffer).readTag() + } + } + } diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt index 1d27ead8c..8c1ebae1d 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt @@ -5,8 +5,10 @@ package kotlinx.rpc.grpc.pb import com.google.protobuf.CodedInputStream +import com.google.protobuf.InvalidProtocolBufferException import kotlinx.io.Buffer import kotlinx.io.asInputStream +import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.internal.readPackedVarInternal internal class WireDecoderJvm(source: Buffer) : WireDecoder { @@ -18,80 +20,79 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { return false } - override fun readTag(): KTag? { + override fun readTag(): KTag? = checked { val tag = codedInputStream.readTag().toUInt() if (tag == 0u) { return null } - - return KTag.fromOrNull(tag) + return KTag.from(tag) } - override fun readBool(): Boolean { + override fun readBool(): Boolean = checked { return codedInputStream.readBool() } - override fun readInt32(): Int { + override fun readInt32(): Int = checked { return codedInputStream.readInt32() } - override fun readInt64(): Long { + override fun readInt64(): Long = checked { return codedInputStream.readInt64() } - override fun readUInt32(): UInt { + override fun readUInt32(): UInt = checked { // todo check java unsigned types return codedInputStream.readUInt32().toUInt() } - override fun readUInt64(): ULong { + override fun readUInt64(): ULong = checked { // todo check java unsigned types return codedInputStream.readUInt64().toULong() } - override fun readSInt32(): Int { + override fun readSInt32(): Int = checked { return codedInputStream.readSInt32() } - override fun readSInt64(): Long { + override fun readSInt64(): Long = checked { return codedInputStream.readSInt64() } - override fun readFixed32(): UInt { + override fun readFixed32(): UInt = checked { // todo check java unsigned types return codedInputStream.readFixed32().toUInt() } - override fun readFixed64(): ULong { + override fun readFixed64(): ULong = checked { // todo check java unsigned types return codedInputStream.readFixed64().toULong() } - override fun readSFixed32(): Int { + override fun readSFixed32(): Int = checked { return codedInputStream.readSFixed32() } - override fun readSFixed64(): Long { + override fun readSFixed64(): Long = checked { return codedInputStream.readSFixed64() } - override fun readFloat(): Float { + override fun readFloat(): Float = checked { return codedInputStream.readFloat() } - override fun readDouble(): Double { + override fun readDouble(): Double = checked { return codedInputStream.readDouble() } - override fun readEnum(): Int { + override fun readEnum(): Int = checked { return codedInputStream.readEnum() } - override fun readString(): String { + override fun readString(): String = checked { return codedInputStream.readStringRequireUtf8() } - override fun readBytes(): ByteArray { + override fun readBytes(): ByteArray = checked { return codedInputStream.readByteArray() } @@ -114,12 +115,25 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { private fun readPackedInternal(read: () -> T) = readPackedVarInternal( size = { -1 }, - readFn = read, - withError = { }, - hadError = { false }, + readFn = read ) } internal actual fun WireDecoder(source: Buffer): WireDecoder { return WireDecoderJvm(source) } + +/** + * Turns a [InvalidProtocolBufferException] into our own [ProtobufDecodingException]. + */ +private inline fun checked(block: () -> T): T { + try { + return block() + } catch (e: InvalidProtocolBufferException) { + throw e.toDecodingException() + } +} + +private fun InvalidProtocolBufferException.toDecodingException(): ProtobufDecodingException { + return ProtobufDecodingException(message ?: "Failed to decode protobuf message.", cause) +} diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt index c5b22c918..e24729aee 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt @@ -7,6 +7,7 @@ package kotlinx.rpc.grpc.pb import kotlinx.cinterop.* import kotlinx.collections.immutable.persistentListOf import kotlinx.io.Buffer +import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.internal.ZeroCopyInputSource import kotlinx.rpc.grpc.internal.readPackedVarInternal import libprotowire.* @@ -17,8 +18,6 @@ import kotlin.native.ref.createCleaner @OptIn(ExperimentalForeignApi::class, ExperimentalNativeApi::class) internal class WireDecoderNative(private val source: Buffer) : WireDecoder { - private var hadError = false; - // wraps the source in a class that allows to pass data from the source buffer to the C++ encoder // without copying it to an intermediate byte array. private val zeroCopyInput = StableRef.create(ZeroCopyInputSource(source)) @@ -62,12 +61,18 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { } override fun hadError(): Boolean { - return hadError; + return false } override fun readTag(): KTag? { val tag = pw_decoder_read_tag(raw) - return KTag.fromOrNull(tag) + if (tag == 0u) { + if (!pw_decoder_consumed_entire_msg(raw)) { + throw ProtobufDecodingException.invalidTag() + } + return null + } + return KTag.from(tag) } override fun readBool(): Boolean = memScoped { @@ -159,8 +164,8 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { val str = alloc>() pw_decoder_read_string(raw, str.ptr).checkError() try { - if (hadError) return "" - return pw_string_c_str(str.value)?.toKString() ?: "".also { hadError = true } + return pw_string_c_str(str.value)?.toKString() + ?: throw ProtobufDecodingException.genericParsingError() } finally { pw_string_delete(str.value) } @@ -169,17 +174,15 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { // TODO: Should readBytes return a buffer, to prevent allocation of large contiguous memory blocks ? KRPC-182 override fun readBytes(): ByteArray { val length = readInt32() - if (hadError) return ByteArray(0) - if (length < 0) return ByteArray(0).withError() + if (length < 0) throw ProtobufDecodingException.negativeSize() // check if the remaining buffer size is less than the set length, // we can early abort, without allocating unnecessary memory - if (source.size < length) return ByteArray(0).withError() + if (source.size < length) throw ProtobufDecodingException.truncatedMessage() if (length == 0) return ByteArray(0) // actually an empty array (no error) val bytes = ByteArray(length) bytes.usePinned { pw_decoder_read_raw_bytes(raw, it.addressOf(0), length).checkError() } - if (hadError) return ByteArray(0) return bytes } @@ -236,9 +239,7 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { private fun readPackedVarInternal(read: () -> T) = readPackedVarInternal( size = { source.size }, - readFn = read, - withError = { hadError = true }, - hadError = { hadError }, + readFn = read ) /* @@ -254,14 +255,13 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { sizeBytes: Int, crossinline createArray: (Int) -> R, crossinline getAddress: Pinned.(Int) -> COpaquePointer, - crossinline asList: (R) -> List + crossinline asList: (R) -> List, ): List { // fetch the size of the packed repeated field var byteLen = readInt32() - if (hadError) return emptyList() - if (byteLen < 0) return emptyList().withError() - if (source.size < byteLen) return emptyList().withError() - if (byteLen % sizeBytes != 0) return emptyList().withError() + if (byteLen < 0) throw ProtobufDecodingException.negativeSize() + if (source.size < byteLen) throw ProtobufDecodingException.truncatedMessage() + if (byteLen % sizeBytes != 0) throw ProtobufDecodingException.truncatedMessage() if (byteLen == 0) return emptyList() // actually an empty list (no error) // allocate the buffer array (has at most MAX_PACKED_BULK_SIZE bytes) @@ -284,7 +284,6 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { // copy data into the buffer. val copySize = min(bufByteLen, byteLen) pw_decoder_read_raw_bytes(raw, bufAddr, copySize).checkError() - if (hadError) return emptyList() // add buffer to the chunked list chunkedList = if (copySize == bufByteLen) { @@ -302,12 +301,7 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { } private fun Boolean.checkError() { - hadError = !this || hadError; - } - - private fun T.withError(): T { - hadError = true - return this + if (!this) throw ProtobufDecodingException.genericParsingError() } } diff --git a/grpc/grpcpp-c/include/protowire.h b/grpc/grpcpp-c/include/protowire.h index e53d3f41a..87435e488 100644 --- a/grpc/grpcpp-c/include/protowire.h +++ b/grpc/grpcpp-c/include/protowire.h @@ -95,6 +95,7 @@ extern "C" { void pw_decoder_close(pw_decoder_t *self); uint32_t pw_decoder_read_tag(pw_decoder_t *self); + bool pw_decoder_consumed_entire_msg(pw_decoder_t *self); bool pw_decoder_read_bool(pw_decoder_t *self, bool *value); bool pw_decoder_read_int32(pw_decoder_t *self, int32_t *value); bool pw_decoder_read_int64(pw_decoder_t *self, int64_t *value); diff --git a/grpc/grpcpp-c/src/protowire.cpp b/grpc/grpcpp-c/src/protowire.cpp index 50700b47d..1a9f35ffa 100644 --- a/grpc/grpcpp-c/src/protowire.cpp +++ b/grpc/grpcpp-c/src/protowire.cpp @@ -230,6 +230,10 @@ extern "C" { return self->codedInputStream.ReadTag(); } + bool pw_decoder_consumed_entire_msg(pw_decoder_t *self) { + return self->codedInputStream.ConsumedEntireMessage(); + } + #define READ_VAL_FUNC( funcSuffix, wireTy, cTy) \ bool pw_decoder_read_##funcSuffix(pw_decoder_t *self, cTy *value_ref) { \ return WireFormatLite::ReadPrimitive(&self->codedInputStream, value_ref); \ diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index ddb9aa655..b1e0744f3 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -586,7 +586,7 @@ class ModelToKotlinCommonGenerator( requiredFields.forEach { field -> ifBranch(condition = "!presenceMask[${field.presenceIdx}]", ifBlock = { - code("throw kotlinx.rpc.grpc.InvalidProtobufError.missingRequiredField(\"${declaration.name.simpleName}\", \"${field.name}\")") + code("throw kotlinx.rpc.grpc.ProtobufDecodingException.missingRequiredField(\"${declaration.name.simpleName}\", \"${field.name}\")") }) } From b7113d4f6608e9fe42e466b6db25aaa2f8969e86 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 8 Aug 2025 11:15:11 +0200 Subject: [PATCH 4/9] grpc-pb: Fix string encoding bug Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index d66cef946..6e3c56d0e 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -112,7 +112,8 @@ internal class WireEncoderNative(private val sink: Sink) : WireEncoder { return pw_encoder_write_string(raw, fieldNr, null, 0) } val cStr = value.cstr - return pw_encoder_write_string(raw, fieldNr, cStr.ptr, cStr.size) + val len = cStr.size - 1 // minus 1 as it also counts the null terminator + return pw_encoder_write_string(raw, fieldNr, cStr.ptr, len) } override fun writeBytes(fieldNr: Int, value: ByteArray): Boolean { @@ -166,7 +167,7 @@ internal class WireEncoderNative(private val sink: Sink) : WireEncoder { override fun writeMessage( fieldNr: Int, value: T, - encode: T.(WireEncoder) -> Unit + encode: T.(WireEncoder) -> Unit, ) { pw_encoder_write_tag(raw, fieldNr, WireType.LENGTH_DELIMITED.ordinal) pw_encoder_write_int32_no_tag(raw, value._size) @@ -184,7 +185,7 @@ private inline fun WireEncoderNative.writePackedInternal( fieldNr: Int, value: List, fieldSize: Int, - crossinline writer: (CValuesRef?, T) -> Boolean + crossinline writer: (CValuesRef?, T) -> Boolean, ): Boolean { pw_encoder_write_tag(raw, fieldNr, WireType.LENGTH_DELIMITED.ordinal) // write the field size of the packed field From 634774f3a2f374905009c27caa17bb7f943aeea7 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 8 Aug 2025 11:21:38 +0200 Subject: [PATCH 5/9] grpc-pb: Remove hadError() method Signed-off-by: Johannes Zottele --- .../commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt | 2 +- .../commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 7 ++----- .../commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt | 1 - .../jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt | 5 ----- .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt | 4 ---- .../kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt | 6 +----- 6 files changed, 4 insertions(+), 21 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt index a3fabfc5b..19f5b8a9c 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt @@ -17,7 +17,7 @@ public class ProtobufDecodingException : GrpcException { ProtobufDecodingException("Message '$messageName' is missing a required field: $fieldName") internal fun negativeSize() = ProtobufDecodingException( - "CodedInputStream encountered an embedded string or message which claimed to have negative size." + "Decoder encountered an embedded string or message which claimed to have negative size." ) internal fun invalidTag() = ProtobufDecodingException( diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index b8b12e296..44fe5b45c 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -5,6 +5,7 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer +import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.internal.popLimit import kotlinx.rpc.grpc.internal.pushLimit import kotlinx.rpc.internal.utils.InternalRpcApi @@ -41,8 +42,6 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000 */ @InternalRpcApi public interface WireDecoder : AutoCloseable { - public fun hadError(): Boolean - /** * When the read tag is null, it indicates EOF and the parser may stop at this point. */ @@ -79,11 +78,9 @@ public interface WireDecoder : AutoCloseable { public fun readPackedDouble(): List public fun readPackedEnum(): List - // TODO: Throw error instead of just returning public fun readMessage(msg: T, decoder: (T, WireDecoder) -> Unit) { val len = readInt32() - if (hadError()) return - if (len <= 0) return + if (len < 0) throw ProtobufDecodingException.negativeSize() val limit = pushLimit(len) decoder(msg, this) popLimit(limit) diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt index 194fcada4..b3d5b9991 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt @@ -28,7 +28,6 @@ class WireCodecTest { val decoder = WireDecoder(buffer) val tag = decoder.readTag() - assertFalse(decoder.hadError()) assertNotNull(tag) assertEquals(WireType.VARINT, tag.wireType) assertEquals(fieldNr, tag.fieldNr) diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt index 8c1ebae1d..23d5f0653 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt @@ -15,11 +15,6 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { // there is no way to omit coping here internal val codedInputStream: CodedInputStream = CodedInputStream.newInstance(source.asInputStream()) - // errors in jvm are exceptions - override fun hadError(): Boolean { - return false - } - override fun readTag(): KTag? = checked { val tag = codedInputStream.readTag().toUInt() if (tag == 0u) { diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt index e24729aee..432cd3d76 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt @@ -60,10 +60,6 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { zeroCopyInput.dispose() } - override fun hadError(): Boolean { - return false - } - override fun readTag(): KTag? { val tag = pw_decoder_read_tag(raw) if (tag == 0u) { diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index b1e0744f3..08e06f44d 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -284,7 +284,7 @@ class ModelToKotlinCommonGenerator( annotations = listOf("@$INTERNAL_RPC_API_ANNO"), contextReceiver = "${declaration.internalClassFullName()}.Companion" ) { - whileBlock("!decoder.hadError()") { + whileBlock("true") { code("val tag = decoder.readTag() ?: break // EOF, we read the whole message") whenBlock { declaration.fields().forEach { (_, field) -> readMatchCase(field) } @@ -294,10 +294,6 @@ class ModelToKotlinCommonGenerator( } } } - ifBranch( - condition = "decoder.hadError()", - ifBlock = { code("error(\"Error during decoding of ${declaration.name.simpleName}\")") } - ) // TODO: Make lists and maps immutable (KRPC-190) } From 7c1f4c0fdaf326ea8b7a10f9bc6911d515e8dc78 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 8 Aug 2025 11:42:12 +0200 Subject: [PATCH 6/9] grpc-pb: Throw exception instead of Boolean when encoding value Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt | 62 +++++------ .../kotlinx/rpc/grpc/pb/WireCodecTest.kt | 68 ++++++------ .../kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt | 105 ++++++++---------- .../kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 92 ++++++++------- 4 files changed, 161 insertions(+), 166 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt index d5fd1374c..6a7db1507 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt @@ -19,42 +19,42 @@ import kotlinx.rpc.internal.utils.InternalRpcApi @OptIn(ExperimentalUnsignedTypes::class) public interface WireEncoder { public fun flush() - public fun writeBool(fieldNr: Int, value: Boolean): Boolean - public fun writeInt32(fieldNr: Int, value: Int): Boolean - public fun writeInt64(fieldNr: Int, value: Long): Boolean - public fun writeUInt32(fieldNr: Int, value: UInt): Boolean - public fun writeUInt64(fieldNr: Int, value: ULong): Boolean - public fun writeSInt32(fieldNr: Int, value: Int): Boolean - public fun writeSInt64(fieldNr: Int, value: Long): Boolean - public fun writeFixed32(fieldNr: Int, value: UInt): Boolean - public fun writeFixed64(fieldNr: Int, value: ULong): Boolean - public fun writeSFixed32(fieldNr: Int, value: Int): Boolean - public fun writeSFixed64(fieldNr: Int, value: Long): Boolean - public fun writeFloat(fieldNr: Int, value: Float): Boolean - public fun writeDouble(fieldNr: Int, value: Double): Boolean - public fun writeEnum(fieldNr: Int, value: Int): Boolean - public fun writeBytes(fieldNr: Int, value: ByteArray): Boolean - public fun writeString(fieldNr: Int, value: String): Boolean - public fun writePackedBool(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedInt32(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedInt64(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedUInt32(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedUInt64(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedSInt32(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedSInt64(fieldNr: Int, value: List, fieldSize: Int): Boolean - public fun writePackedFixed32(fieldNr: Int, value: List): Boolean - public fun writePackedFixed64(fieldNr: Int, value: List): Boolean - public fun writePackedSFixed32(fieldNr: Int, value: List): Boolean - public fun writePackedSFixed64(fieldNr: Int, value: List): Boolean - public fun writePackedFloat(fieldNr: Int, value: List): Boolean - public fun writePackedDouble(fieldNr: Int, value: List): Boolean - public fun writePackedEnum(fieldNr: Int, value: List, fieldSize: Int): Boolean = + public fun writeBool(fieldNr: Int, value: Boolean) + public fun writeInt32(fieldNr: Int, value: Int) + public fun writeInt64(fieldNr: Int, value: Long) + public fun writeUInt32(fieldNr: Int, value: UInt) + public fun writeUInt64(fieldNr: Int, value: ULong) + public fun writeSInt32(fieldNr: Int, value: Int) + public fun writeSInt64(fieldNr: Int, value: Long) + public fun writeFixed32(fieldNr: Int, value: UInt) + public fun writeFixed64(fieldNr: Int, value: ULong) + public fun writeSFixed32(fieldNr: Int, value: Int) + public fun writeSFixed64(fieldNr: Int, value: Long) + public fun writeFloat(fieldNr: Int, value: Float) + public fun writeDouble(fieldNr: Int, value: Double) + public fun writeEnum(fieldNr: Int, value: Int) + public fun writeBytes(fieldNr: Int, value: ByteArray) + public fun writeString(fieldNr: Int, value: String) + public fun writePackedBool(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedInt32(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedInt64(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedUInt32(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedUInt64(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedSInt32(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedSInt64(fieldNr: Int, value: List, fieldSize: Int) + public fun writePackedFixed32(fieldNr: Int, value: List) + public fun writePackedFixed64(fieldNr: Int, value: List) + public fun writePackedSFixed32(fieldNr: Int, value: List) + public fun writePackedSFixed64(fieldNr: Int, value: List) + public fun writePackedFloat(fieldNr: Int, value: List) + public fun writePackedDouble(fieldNr: Int, value: List) + public fun writePackedEnum(fieldNr: Int, value: List, fieldSize: Int): Unit = writePackedInt32(fieldNr, value, fieldSize) public fun writeMessage( fieldNr: Int, value: T, - encode: T.(WireEncoder) -> Unit + encode: T.(WireEncoder) -> Unit, ) } diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt index b3d5b9991..5170167a3 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt @@ -22,7 +22,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeBool(fieldNr, true)) + encoder.writeBool(fieldNr, true) encoder.flush() val decoder = WireDecoder(buffer) @@ -47,7 +47,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeInt32(fieldNr, testValue)) + encoder.writeInt32(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -72,7 +72,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeInt64(fieldNr, testValue)) + encoder.writeInt64(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -97,7 +97,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeUInt32(fieldNr, testValue)) + encoder.writeUInt32(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -122,7 +122,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeUInt64(fieldNr, testValue)) + encoder.writeUInt64(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -147,7 +147,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeSInt32(fieldNr, testValue)) + encoder.writeSInt32(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -172,7 +172,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeSInt64(fieldNr, testValue)) + encoder.writeSInt64(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -197,7 +197,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeFixed32(fieldNr, testValue)) + encoder.writeFixed32(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -222,7 +222,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeFixed64(fieldNr, testValue)) + encoder.writeFixed64(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -247,7 +247,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeSFixed32(fieldNr, testValue)) + encoder.writeSFixed32(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -272,7 +272,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeSFixed64(fieldNr, testValue)) + encoder.writeSFixed64(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -297,7 +297,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeEnum(fieldNr, testValue)) + encoder.writeEnum(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -322,7 +322,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeString(fieldNr, testValue)) + encoder.writeString(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -369,10 +369,10 @@ class WireCodecTest { val encoder = WireEncoder(buffer) // Write multiple fields of different types - assertTrue(encoder.writeBool(1, true)) - assertTrue(encoder.writeInt32(2, 42)) - assertTrue(encoder.writeString(3, "Hello")) - assertTrue(encoder.writeFixed64(4, 123456789uL)) + encoder.writeBool(1, true) + encoder.writeInt32(2, 42) + encoder.writeString(3, "Hello") + encoder.writeFixed64(4, 123456789uL) encoder.flush() val decoder = WireDecoder(buffer) @@ -428,7 +428,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeBool(fieldNr, true)) + encoder.writeBool(fieldNr, true) encoder.flush() val decoder = WireDecoder(buffer) @@ -448,11 +448,11 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeBool(1, true)) + encoder.writeBool(1, true) encoder.flush() // Writing after flush should still work - assertTrue(encoder.writeInt32(2, 42)) + encoder.writeInt32(2, 42) encoder.flush() val decoder = WireDecoder(buffer) @@ -483,7 +483,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeString(fieldNr, testValue)) + encoder.writeString(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -512,8 +512,8 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeBool(fieldNr, true)) - assertTrue(encoder.writeBool(fieldNr + 1, true)) + encoder.writeBool(fieldNr, true) + encoder.writeBool(fieldNr + 1, true) encoder.flush() WireDecoder(buffer).use { decoder -> @@ -533,8 +533,8 @@ class WireCodecTest { val field2Str = "b".repeat(1000000) val encoder = WireEncoder(buffer) - assertTrue(encoder.writeString(field1Nr, field1Str)) - assertTrue(encoder.writeString(field2Nr, field2Str)) + encoder.writeString(field1Nr, field1Str) + encoder.writeString(field2Nr, field2Str) encoder.flush() WireDecoder(buffer).use { decoder -> @@ -557,7 +557,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeString(1, "")) + encoder.writeString(1, "") encoder.flush() val decoder = WireDecoder(buffer) @@ -577,7 +577,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeBytes(1, ByteArray(0))) + encoder.writeBytes(1, ByteArray(0)) encoder.flush() val decoder = WireDecoder(buffer) @@ -599,7 +599,7 @@ class WireCodecTest { val bytes = ByteArray(1000000) { it.toByte() } - assertTrue(encoder.writeBytes(1, bytes)) + encoder.writeBytes(1, bytes) encoder.flush() val decoder = WireDecoder(buffer) @@ -624,7 +624,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeDouble(fieldNr, testValue)) + encoder.writeDouble(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -649,7 +649,7 @@ class WireCodecTest { val buffer = Buffer() val encoder = WireEncoder(buffer) - assertTrue(encoder.writeFloat(fieldNr, testValue)) + encoder.writeFloat(fieldNr, testValue) encoder.flush() val decoder = WireDecoder(buffer) @@ -669,12 +669,12 @@ class WireCodecTest { private fun runPackedFixedTest( list: List, - write: WireEncoder.(Int, List) -> Boolean, + write: WireEncoder.(Int, List) -> Unit, read: WireDecoder.() -> List?, ) { val buf = Buffer() with(WireEncoder(buf)) { - assertTrue(write(1, list)) + write(1, list) flush() } WireDecoder(buf).use { dec -> @@ -733,12 +733,12 @@ class WireCodecTest { private fun runPackedVarTest( list: List, sizeFn: (List) -> Int, - write: WireEncoder.(Int, List, Int) -> Boolean, + write: WireEncoder.(Int, List, Int) -> Unit, read: WireDecoder.() -> List?, ) { val buf = Buffer() with(WireEncoder(buf)) { - assertTrue(write(1, list, sizeFn(list))) + write(1, list, sizeFn(list)) flush() } WireDecoder(buf).use { dec -> diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt index 3c40a2263..041a467b0 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt @@ -5,8 +5,10 @@ package kotlinx.rpc.grpc.pb import com.google.protobuf.CodedOutputStream +import kotlinx.io.IOException import kotlinx.io.Sink import kotlinx.io.asOutputStream +import kotlinx.rpc.grpc.ProtobufEncodingException private class WireEncoderJvm(sink: Sink) : WireEncoder { private val codedOutputStream = CodedOutputStream.newInstance(sink.asOutputStream()) @@ -15,195 +17,166 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { codedOutputStream.flush() } - override fun writeBool(fieldNr: Int, value: Boolean): Boolean { + override fun writeBool(fieldNr: Int, value: Boolean) = checked { codedOutputStream.writeBool(fieldNr, value) - return true } - override fun writeInt32(fieldNr: Int, value: Int): Boolean { + override fun writeInt32(fieldNr: Int, value: Int) = checked { codedOutputStream.writeInt32(fieldNr, value) - return true } - override fun writeInt64(fieldNr: Int, value: Long): Boolean { + override fun writeInt64(fieldNr: Int, value: Long) = checked { codedOutputStream.writeInt64(fieldNr, value) - return true } - override fun writeUInt32(fieldNr: Int, value: UInt): Boolean { + override fun writeUInt32(fieldNr: Int, value: UInt) = checked { // todo check java unsigned types codedOutputStream.writeUInt32(fieldNr, value.toInt()) - return true } - override fun writeUInt64(fieldNr: Int, value: ULong): Boolean { + override fun writeUInt64(fieldNr: Int, value: ULong) = checked { // todo check java unsigned types codedOutputStream.writeUInt64(fieldNr, value.toLong()) - return true } - override fun writeSInt32(fieldNr: Int, value: Int): Boolean { + override fun writeSInt32(fieldNr: Int, value: Int) = checked { codedOutputStream.writeSInt32(fieldNr, value) - return true } - override fun writeSInt64(fieldNr: Int, value: Long): Boolean { + override fun writeSInt64(fieldNr: Int, value: Long) = checked { codedOutputStream.writeSInt64(fieldNr, value) - return true } - override fun writeFixed32(fieldNr: Int, value: UInt): Boolean { + override fun writeFixed32(fieldNr: Int, value: UInt) = checked { // todo check java unsigned types codedOutputStream.writeFixed32(fieldNr, value.toInt()) - return true } - override fun writeFixed64(fieldNr: Int, value: ULong): Boolean { + override fun writeFixed64(fieldNr: Int, value: ULong) = checked { // todo check java unsigned types codedOutputStream.writeFixed64(fieldNr, value.toLong()) - return true } - override fun writeSFixed32(fieldNr: Int, value: Int): Boolean { + override fun writeSFixed32(fieldNr: Int, value: Int) = checked { codedOutputStream.writeSFixed32(fieldNr, value) - return true } - override fun writeSFixed64(fieldNr: Int, value: Long): Boolean { + override fun writeSFixed64(fieldNr: Int, value: Long) = checked { codedOutputStream.writeSFixed64(fieldNr, value) - return true } - override fun writeFloat(fieldNr: Int, value: Float): Boolean { + override fun writeFloat(fieldNr: Int, value: Float) = checked { codedOutputStream.writeFloat(fieldNr, value) - return true } - override fun writeDouble(fieldNr: Int, value: Double): Boolean { + override fun writeDouble(fieldNr: Int, value: Double) = checked { codedOutputStream.writeDouble(fieldNr, value) - return true } - override fun writeEnum(fieldNr: Int, value: Int): Boolean { + override fun writeEnum(fieldNr: Int, value: Int) = checked { codedOutputStream.writeEnum(fieldNr, value) - return true } - override fun writeBytes(fieldNr: Int, value: ByteArray): Boolean { + override fun writeBytes(fieldNr: Int, value: ByteArray) = checked { codedOutputStream.writeByteArray(fieldNr, value) - return true } - override fun writeString(fieldNr: Int, value: String): Boolean { + override fun writeString(fieldNr: Int, value: String) = checked { codedOutputStream.writeString(fieldNr, value) - return true } override fun writePackedBool( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeBoolNoTag) - return true } override fun writePackedInt32( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeInt32NoTag) - return true } override fun writePackedInt64( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeInt64NoTag) - return true } override fun writePackedUInt32( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize) { codedOutputStream, v -> codedOutputStream.writeUInt32NoTag(v.toInt()) } - return true } override fun writePackedUInt64( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize) { codedOutputStream, v -> codedOutputStream.writeUInt64NoTag(v.toLong()) } - return true } override fun writePackedSInt32( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeSInt32NoTag) - return true } override fun writePackedSInt64( fieldNr: Int, value: List, fieldSize: Int, - ): Boolean { + ) = checked { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeSInt64NoTag) - return true } - override fun writePackedFixed32(fieldNr: Int, value: List): Boolean { + override fun writePackedFixed32(fieldNr: Int, value: List) = checked { writePackedInternal(fieldNr, value, value.size * UInt.SIZE_BYTES) { codedOutputStream, v -> codedOutputStream.writeFixed32NoTag(v.toInt()) } - return true } - override fun writePackedFixed64(fieldNr: Int, value: List): Boolean { + override fun writePackedFixed64(fieldNr: Int, value: List) = checked { writePackedInternal(fieldNr, value, value.size * ULong.SIZE_BYTES) { codedOutputStream, v -> codedOutputStream.writeFixed64NoTag(v.toLong()) } - return true } - override fun writePackedSFixed32(fieldNr: Int, value: List): Boolean { + override fun writePackedSFixed32(fieldNr: Int, value: List) = checked { writePackedInternal(fieldNr, value, value.size * Int.SIZE_BYTES, CodedOutputStream::writeSFixed32NoTag) - return true } - override fun writePackedSFixed64(fieldNr: Int, value: List): Boolean { + override fun writePackedSFixed64(fieldNr: Int, value: List) = checked { writePackedInternal(fieldNr, value, value.size * Long.SIZE_BYTES, CodedOutputStream::writeSFixed64NoTag) - return true } - override fun writePackedFloat(fieldNr: Int, value: List): Boolean { + override fun writePackedFloat(fieldNr: Int, value: List) = checked { writePackedInternal(fieldNr, value, value.size * Float.SIZE_BYTES, CodedOutputStream::writeFloatNoTag) - return true } - override fun writePackedDouble(fieldNr: Int, value: List): Boolean { + override fun writePackedDouble(fieldNr: Int, value: List) = checked { writePackedInternal(fieldNr, value, value.size * Double.SIZE_BYTES, CodedOutputStream::writeDoubleNoTag) - return true } override fun writeMessage( fieldNr: Int, value: T, - encode: T.(WireEncoder) -> Unit + encode: T.(WireEncoder) -> Unit, ) { codedOutputStream.writeTag(fieldNr, WireType.LENGTH_DELIMITED.ordinal) codedOutputStream.writeInt32NoTag(value._size) @@ -215,17 +188,27 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { value: List, fieldSize: Int, crossinline writer: (CodedOutputStream, T) -> Unit, - ): Boolean { + ) = checked { codedOutputStream.writeTag(fieldNr, WireType.LENGTH_DELIMITED.ordinal) // write the field size of the packed field codedOutputStream.writeInt32NoTag(fieldSize) for (v in value) { writer(codedOutputStream, v) } - return true } } internal actual fun WireEncoder(sink: Sink): WireEncoder { return WireEncoderJvm(sink) } + +/** + * Wraps a [kotlinx.io.IOException] in our own [kotlinx.rpc.grpc.ProtobufEncodingException]. + */ +private inline fun checked(crossinline block: () -> Unit) { + try { + return block() + } catch (e: IOException) { + throw ProtobufEncodingException("Failed to encode protobuf message.", e) + } +} diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index 6e3c56d0e..6f0efe797 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -6,6 +6,7 @@ package kotlinx.rpc.grpc.pb import kotlinx.cinterop.* import kotlinx.io.Sink +import kotlinx.rpc.grpc.ProtobufEncodingException import kotlinx.rpc.grpc.internal.writeFully import libprotowire.* import kotlin.experimental.ExperimentalNativeApi @@ -51,76 +52,78 @@ internal class WireEncoderNative(private val sink: Sink) : WireEncoder { pw_encoder_flush(raw) } - override fun writeBool(fieldNr: Int, value: Boolean): Boolean { - return pw_encoder_write_bool(raw, fieldNr, value) + override fun writeBool(fieldNr: Int, value: Boolean) = checked { + pw_encoder_write_bool(raw, fieldNr, value) } - override fun writeInt32(fieldNr: Int, value: Int): Boolean { - return pw_encoder_write_int32(raw, fieldNr, value) + override fun writeInt32(fieldNr: Int, value: Int) = checked { + pw_encoder_write_int32(raw, fieldNr, value) } - override fun writeInt64(fieldNr: Int, value: Long): Boolean { - return pw_encoder_write_int64(raw, fieldNr, value) + override fun writeInt64(fieldNr: Int, value: Long) = checked { + pw_encoder_write_int64(raw, fieldNr, value) } - override fun writeUInt32(fieldNr: Int, value: UInt): Boolean { - return pw_encoder_write_uint32(raw, fieldNr, value) + override fun writeUInt32(fieldNr: Int, value: UInt) = checked { + pw_encoder_write_uint32(raw, fieldNr, value) } - override fun writeUInt64(fieldNr: Int, value: ULong): Boolean { - return pw_encoder_write_uint64(raw, fieldNr, value) + override fun writeUInt64(fieldNr: Int, value: ULong) = checked { + pw_encoder_write_uint64(raw, fieldNr, value) } - override fun writeSInt32(fieldNr: Int, value: Int): Boolean { - return pw_encoder_write_sint32(raw, fieldNr, value) + override fun writeSInt32(fieldNr: Int, value: Int) = checked { + pw_encoder_write_sint32(raw, fieldNr, value) } - override fun writeSInt64(fieldNr: Int, value: Long): Boolean { - return pw_encoder_write_sint64(raw, fieldNr, value) + override fun writeSInt64(fieldNr: Int, value: Long) = checked { + pw_encoder_write_sint64(raw, fieldNr, value) } - override fun writeFixed32(fieldNr: Int, value: UInt): Boolean { - return pw_encoder_write_fixed32(raw, fieldNr, value) + override fun writeFixed32(fieldNr: Int, value: UInt) = checked { + pw_encoder_write_fixed32(raw, fieldNr, value) } - override fun writeFixed64(fieldNr: Int, value: ULong): Boolean { - return pw_encoder_write_fixed64(raw, fieldNr, value) + override fun writeFixed64(fieldNr: Int, value: ULong) = checked { + pw_encoder_write_fixed64(raw, fieldNr, value) } - override fun writeSFixed32(fieldNr: Int, value: Int): Boolean { - return pw_encoder_write_sfixed32(raw, fieldNr, value) + override fun writeSFixed32(fieldNr: Int, value: Int) = checked { + pw_encoder_write_sfixed32(raw, fieldNr, value) } - override fun writeSFixed64(fieldNr: Int, value: Long): Boolean { - return pw_encoder_write_sfixed64(raw, fieldNr, value) + override fun writeSFixed64(fieldNr: Int, value: Long) = checked { + pw_encoder_write_sfixed64(raw, fieldNr, value) } - override fun writeFloat(fieldNr: Int, value: Float): Boolean { - return pw_encoder_write_float(raw, fieldNr, value) + override fun writeFloat(fieldNr: Int, value: Float) = checked { + pw_encoder_write_float(raw, fieldNr, value) } - override fun writeDouble(fieldNr: Int, value: Double): Boolean { - return pw_encoder_write_double(raw, fieldNr, value) + override fun writeDouble(fieldNr: Int, value: Double) = checked { + pw_encoder_write_double(raw, fieldNr, value) } - override fun writeEnum(fieldNr: Int, value: Int): Boolean { - return pw_encoder_write_enum(raw, fieldNr, value) + override fun writeEnum(fieldNr: Int, value: Int) = checked { + pw_encoder_write_enum(raw, fieldNr, value) } - override fun writeString(fieldNr: Int, value: String): Boolean = memScoped { - if (value.isEmpty()) { - return pw_encoder_write_string(raw, fieldNr, null, 0) + override fun writeString(fieldNr: Int, value: String) = checked { + memScoped { + if (value.isEmpty()) { + return@checked pw_encoder_write_string(raw, fieldNr, null, 0) + } + val cStr = value.cstr + val len = cStr.size - 1 // minus 1 as it also counts the null terminator + return@checked pw_encoder_write_string(raw, fieldNr, cStr.ptr, len) } - val cStr = value.cstr - val len = cStr.size - 1 // minus 1 as it also counts the null terminator - return pw_encoder_write_string(raw, fieldNr, cStr.ptr, len) } - override fun writeBytes(fieldNr: Int, value: ByteArray): Boolean { + override fun writeBytes(fieldNr: Int, value: ByteArray) = checked { if (value.isEmpty()) { - return pw_encoder_write_bytes(raw, fieldNr, null, 0) + return@checked pw_encoder_write_bytes(raw, fieldNr, null, 0) } - return value.usePinned { + return@checked value.usePinned { pw_encoder_write_bytes(raw, fieldNr, it.addressOf(0), value.size) } } @@ -186,14 +189,23 @@ private inline fun WireEncoderNative.writePackedInternal( value: List, fieldSize: Int, crossinline writer: (CValuesRef?, T) -> Boolean, -): Boolean { +) = checked { pw_encoder_write_tag(raw, fieldNr, WireType.LENGTH_DELIMITED.ordinal) // write the field size of the packed field pw_encoder_write_int32_no_tag(raw, fieldSize) for (v in value) { if (!writer(raw, v)) { - return false + return@checked false } } - return true + return@checked true +} + +/** + * Checks the block's return value and throws an [ProtobufEncodingException] if its `false`. + */ +private inline fun checked(crossinline block: () -> Boolean) { + if (!block()) { + throw ProtobufEncodingException("Failed to encode protobuf message.") + } } From 09a09a6255c7cf213fd409802cfbde86f00a7129 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 8 Aug 2025 12:31:31 +0200 Subject: [PATCH 7/9] grpc-pb: Move JVM exception check to CODEC to avoid performance overhead Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 4 +- .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt | 4 +- .../kotlinx/rpc/grpc/pb/WireCodecTest.kt | 4 +- .../kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt | 46 ++++++------- .../kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt | 69 +++++++++---------- .../kotlinx/rpc/grpc/pb/WireDecoder.native.kt | 6 +- .../kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 6 +- .../protobuf/ModelToKotlinCommonGenerator.kt | 12 ++-- 8 files changed, 81 insertions(+), 70 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index 44fe5b45c..1aecb80e7 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -98,6 +98,8 @@ public interface WireDecoder : AutoCloseable { } } +public expect fun checkForPlatformDecodeException(block: () -> Unit) + /** * Creates a platform-specific [WireDecoder]. * @@ -108,4 +110,4 @@ public interface WireDecoder : AutoCloseable { * * @param source The buffer containing the encoded wire-format data. */ -internal expect fun WireDecoder(source: Buffer): WireDecoder \ No newline at end of file +public expect fun WireDecoder(source: Buffer): WireDecoder \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt index 6a7db1507..f1e4c7957 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt @@ -60,4 +60,6 @@ public interface WireEncoder { } -internal expect fun WireEncoder(sink: Sink): WireEncoder \ No newline at end of file +public expect fun checkForPlatformEncodeException(block: () -> Unit) + +public expect fun WireEncoder(sink: Sink): WireEncoder \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt index 5170167a3..bbf253b5f 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt @@ -815,7 +815,9 @@ class WireCodecTest { buffer.writeByte(0) assertFailsWith { - WireDecoder(buffer).readTag() + checkForPlatformDecodeException { + WireDecoder(buffer).readTag() + } } } diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt index 23d5f0653..2ee7097e9 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt @@ -15,7 +15,7 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { // there is no way to omit coping here internal val codedInputStream: CodedInputStream = CodedInputStream.newInstance(source.asInputStream()) - override fun readTag(): KTag? = checked { + override fun readTag(): KTag? { val tag = codedInputStream.readTag().toUInt() if (tag == 0u) { return null @@ -23,71 +23,71 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { return KTag.from(tag) } - override fun readBool(): Boolean = checked { + override fun readBool(): Boolean { return codedInputStream.readBool() } - override fun readInt32(): Int = checked { + override fun readInt32(): Int { return codedInputStream.readInt32() } - override fun readInt64(): Long = checked { + override fun readInt64(): Long { return codedInputStream.readInt64() } - override fun readUInt32(): UInt = checked { + override fun readUInt32(): UInt { // todo check java unsigned types return codedInputStream.readUInt32().toUInt() } - override fun readUInt64(): ULong = checked { + override fun readUInt64(): ULong { // todo check java unsigned types return codedInputStream.readUInt64().toULong() } - override fun readSInt32(): Int = checked { + override fun readSInt32(): Int { return codedInputStream.readSInt32() } - override fun readSInt64(): Long = checked { + override fun readSInt64(): Long { return codedInputStream.readSInt64() } - override fun readFixed32(): UInt = checked { + override fun readFixed32(): UInt { // todo check java unsigned types return codedInputStream.readFixed32().toUInt() } - override fun readFixed64(): ULong = checked { + override fun readFixed64(): ULong { // todo check java unsigned types return codedInputStream.readFixed64().toULong() } - override fun readSFixed32(): Int = checked { + override fun readSFixed32(): Int { return codedInputStream.readSFixed32() } - override fun readSFixed64(): Long = checked { + override fun readSFixed64(): Long { return codedInputStream.readSFixed64() } - override fun readFloat(): Float = checked { + override fun readFloat(): Float { return codedInputStream.readFloat() } - override fun readDouble(): Double = checked { + override fun readDouble(): Double { return codedInputStream.readDouble() } - override fun readEnum(): Int = checked { + override fun readEnum(): Int { return codedInputStream.readEnum() } - override fun readString(): String = checked { + override fun readString(): String { return codedInputStream.readStringRequireUtf8() } - override fun readBytes(): ByteArray = checked { + override fun readBytes(): ByteArray { return codedInputStream.readByteArray() } @@ -114,14 +114,8 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { ) } -internal actual fun WireDecoder(source: Buffer): WireDecoder { - return WireDecoderJvm(source) -} -/** - * Turns a [InvalidProtocolBufferException] into our own [ProtobufDecodingException]. - */ -private inline fun checked(block: () -> T): T { +public actual fun checkForPlatformDecodeException(block: () -> Unit) { try { return block() } catch (e: InvalidProtocolBufferException) { @@ -129,6 +123,10 @@ private inline fun checked(block: () -> T): T { } } +public actual fun WireDecoder(source: Buffer): WireDecoder { + return WireDecoderJvm(source) +} + private fun InvalidProtocolBufferException.toDecodingException(): ProtobufDecodingException { return ProtobufDecodingException(message ?: "Failed to decode protobuf message.", cause) } diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt index 041a467b0..1dbc66205 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt @@ -17,71 +17,71 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { codedOutputStream.flush() } - override fun writeBool(fieldNr: Int, value: Boolean) = checked { + override fun writeBool(fieldNr: Int, value: Boolean) { codedOutputStream.writeBool(fieldNr, value) } - override fun writeInt32(fieldNr: Int, value: Int) = checked { + override fun writeInt32(fieldNr: Int, value: Int) { codedOutputStream.writeInt32(fieldNr, value) } - override fun writeInt64(fieldNr: Int, value: Long) = checked { + override fun writeInt64(fieldNr: Int, value: Long) { codedOutputStream.writeInt64(fieldNr, value) } - override fun writeUInt32(fieldNr: Int, value: UInt) = checked { + override fun writeUInt32(fieldNr: Int, value: UInt) { // todo check java unsigned types codedOutputStream.writeUInt32(fieldNr, value.toInt()) } - override fun writeUInt64(fieldNr: Int, value: ULong) = checked { + override fun writeUInt64(fieldNr: Int, value: ULong) { // todo check java unsigned types codedOutputStream.writeUInt64(fieldNr, value.toLong()) } - override fun writeSInt32(fieldNr: Int, value: Int) = checked { + override fun writeSInt32(fieldNr: Int, value: Int) { codedOutputStream.writeSInt32(fieldNr, value) } - override fun writeSInt64(fieldNr: Int, value: Long) = checked { + override fun writeSInt64(fieldNr: Int, value: Long) { codedOutputStream.writeSInt64(fieldNr, value) } - override fun writeFixed32(fieldNr: Int, value: UInt) = checked { + override fun writeFixed32(fieldNr: Int, value: UInt) { // todo check java unsigned types codedOutputStream.writeFixed32(fieldNr, value.toInt()) } - override fun writeFixed64(fieldNr: Int, value: ULong) = checked { + override fun writeFixed64(fieldNr: Int, value: ULong) { // todo check java unsigned types codedOutputStream.writeFixed64(fieldNr, value.toLong()) } - override fun writeSFixed32(fieldNr: Int, value: Int) = checked { + override fun writeSFixed32(fieldNr: Int, value: Int) { codedOutputStream.writeSFixed32(fieldNr, value) } - override fun writeSFixed64(fieldNr: Int, value: Long) = checked { + override fun writeSFixed64(fieldNr: Int, value: Long) { codedOutputStream.writeSFixed64(fieldNr, value) } - override fun writeFloat(fieldNr: Int, value: Float) = checked { + override fun writeFloat(fieldNr: Int, value: Float) { codedOutputStream.writeFloat(fieldNr, value) } - override fun writeDouble(fieldNr: Int, value: Double) = checked { + override fun writeDouble(fieldNr: Int, value: Double) { codedOutputStream.writeDouble(fieldNr, value) } - override fun writeEnum(fieldNr: Int, value: Int) = checked { + override fun writeEnum(fieldNr: Int, value: Int) { codedOutputStream.writeEnum(fieldNr, value) } - override fun writeBytes(fieldNr: Int, value: ByteArray) = checked { + override fun writeBytes(fieldNr: Int, value: ByteArray) { codedOutputStream.writeByteArray(fieldNr, value) } - override fun writeString(fieldNr: Int, value: String) = checked { + override fun writeString(fieldNr: Int, value: String) { codedOutputStream.writeString(fieldNr, value) } @@ -89,7 +89,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeBoolNoTag) } @@ -97,7 +97,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeInt32NoTag) } @@ -105,7 +105,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeInt64NoTag) } @@ -113,7 +113,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize) { codedOutputStream, v -> codedOutputStream.writeUInt32NoTag(v.toInt()) } @@ -123,7 +123,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize) { codedOutputStream, v -> codedOutputStream.writeUInt64NoTag(v.toLong()) } @@ -133,7 +133,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeSInt32NoTag) } @@ -141,35 +141,35 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { fieldNr: Int, value: List, fieldSize: Int, - ) = checked { + ) { writePackedInternal(fieldNr, value, fieldSize, CodedOutputStream::writeSInt64NoTag) } - override fun writePackedFixed32(fieldNr: Int, value: List) = checked { + override fun writePackedFixed32(fieldNr: Int, value: List) { writePackedInternal(fieldNr, value, value.size * UInt.SIZE_BYTES) { codedOutputStream, v -> codedOutputStream.writeFixed32NoTag(v.toInt()) } } - override fun writePackedFixed64(fieldNr: Int, value: List) = checked { + override fun writePackedFixed64(fieldNr: Int, value: List) { writePackedInternal(fieldNr, value, value.size * ULong.SIZE_BYTES) { codedOutputStream, v -> codedOutputStream.writeFixed64NoTag(v.toLong()) } } - override fun writePackedSFixed32(fieldNr: Int, value: List) = checked { + override fun writePackedSFixed32(fieldNr: Int, value: List) { writePackedInternal(fieldNr, value, value.size * Int.SIZE_BYTES, CodedOutputStream::writeSFixed32NoTag) } - override fun writePackedSFixed64(fieldNr: Int, value: List) = checked { + override fun writePackedSFixed64(fieldNr: Int, value: List) { writePackedInternal(fieldNr, value, value.size * Long.SIZE_BYTES, CodedOutputStream::writeSFixed64NoTag) } - override fun writePackedFloat(fieldNr: Int, value: List) = checked { + override fun writePackedFloat(fieldNr: Int, value: List) { writePackedInternal(fieldNr, value, value.size * Float.SIZE_BYTES, CodedOutputStream::writeFloatNoTag) } - override fun writePackedDouble(fieldNr: Int, value: List) = checked { + override fun writePackedDouble(fieldNr: Int, value: List) { writePackedInternal(fieldNr, value, value.size * Double.SIZE_BYTES, CodedOutputStream::writeDoubleNoTag) } @@ -188,7 +188,7 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { value: List, fieldSize: Int, crossinline writer: (CodedOutputStream, T) -> Unit, - ) = checked { + ) { codedOutputStream.writeTag(fieldNr, WireType.LENGTH_DELIMITED.ordinal) // write the field size of the packed field codedOutputStream.writeInt32NoTag(fieldSize) @@ -198,17 +198,14 @@ private class WireEncoderJvm(sink: Sink) : WireEncoder { } } -internal actual fun WireEncoder(sink: Sink): WireEncoder { +public actual fun WireEncoder(sink: Sink): WireEncoder { return WireEncoderJvm(sink) } -/** - * Wraps a [kotlinx.io.IOException] in our own [kotlinx.rpc.grpc.ProtobufEncodingException]. - */ -private inline fun checked(crossinline block: () -> Unit) { +public actual fun checkForPlatformEncodeException(block: () -> Unit) { try { return block() } catch (e: IOException) { throw ProtobufEncodingException("Failed to encode protobuf message.", e) } -} +} \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt index 432cd3d76..023910755 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt @@ -301,4 +301,8 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { } } -internal actual fun WireDecoder(source: Buffer): WireDecoder = WireDecoderNative(source) +public actual fun WireDecoder(source: Buffer): WireDecoder = WireDecoderNative(source) + +public actual fun checkForPlatformDecodeException(block: () -> Unit) { + block() +} \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index 6f0efe797..0e9063b0d 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -178,7 +178,7 @@ internal class WireEncoderNative(private val sink: Sink) : WireEncoder { } } -internal actual fun WireEncoder(sink: Sink): WireEncoder = WireEncoderNative(sink) +public actual fun WireEncoder(sink: Sink): WireEncoder = WireEncoderNative(sink) // the current implementation is slow, as it iterates through the list, to write each element individually, @@ -209,3 +209,7 @@ private inline fun checked(crossinline block: () -> Boolean) { throw ProtobufEncodingException("Failed to encode protobuf message.") } } + +public actual fun checkForPlatformEncodeException(block: () -> Unit) { + block() // nothing to check for on native +} \ No newline at end of file diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index 08e06f44d..d7df29067 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -137,6 +137,7 @@ class ModelToKotlinCommonGenerator( clazz( name = declaration.name.simpleName, declarationType = DeclarationType.Interface, + annotations = listOf("@$WITH_CODEC_ANNO(${declaration.internalClassFullName()}.CODEC::class)") ) { declaration.fields().forEach { (fieldDeclaration, _) -> code("val $fieldDeclaration") @@ -167,9 +168,6 @@ class ModelToKotlinCommonGenerator( val annotations = buildList { add("@$INTERNAL_RPC_API_ANNO") - if (declaration.isUserFacing) { - add("@$WITH_CODEC_ANNO($internalClassName.CODEC::class)") - } } val superTypes = buildList { if (declaration.isUserFacing) { @@ -244,7 +242,9 @@ class ModelToKotlinCommonGenerator( function("encode", modifiers = "override", args = "value: $msgFqName", returnType = sourceFqName) { code("val buffer = $bufferFqName()") code("val encoder = $PB_PKG.WireEncoder(buffer)") - code("value.asInternal().encodeWith(encoder)") + scope("$PB_PKG.checkForPlatformEncodeException", nlAfterClosed = false) { + code("value.asInternal().encodeWith(encoder)") + } code("encoder.flush()") code("return buffer") } @@ -252,7 +252,9 @@ class ModelToKotlinCommonGenerator( function("decode", modifiers = "override", args = "stream: $sourceFqName", returnType = msgFqName) { scope("$PB_PKG.WireDecoder(stream as $bufferFqName).use") { code("val msg = ${declaration.internalClassFullName()}()") - code("${declaration.internalClassFullName()}.decodeWith(msg, it)") + scope("$PB_PKG.checkForPlatformDecodeException", nlAfterClosed = false) { + code("${declaration.internalClassFullName()}.decodeWith(msg, it)") + } code("msg.checkRequiredFields()") code("return msg") } From 9e890cdba6805f8ffa5f1d256e7c0ed2127b0318 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Fri, 8 Aug 2025 13:07:05 +0200 Subject: [PATCH 8/9] grpc-pb: Address PR comments Signed-off-by: Johannes Zottele --- .../kotlinx/rpc/grpc/internal/readPacked.kt | 2 +- .../kotlin/kotlinx/rpc/grpc/pb/KTag.kt | 4 +-- .../ProtobufException.kt} | 8 +++--- .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 26 ++++++++++++------- .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt | 13 +++++++--- .../kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt | 1 - .../kotlinx/rpc/grpc/pb/WireCodecTest.kt | 1 - .../kotlin/kotlinx/rpc/grpc/Server.jvm.kt | 2 +- .../kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt | 7 +---- .../kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt | 1 - .../kotlinx/rpc/grpc/pb/WireDecoder.native.kt | 1 - .../kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 1 - .../protobuf/ModelToKotlinCommonGenerator.kt | 2 +- 13 files changed, 37 insertions(+), 32 deletions(-) rename grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/{GrpcException.kt => pb/ProtobufException.kt} (87%) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt index e43f6b20a..c20bd0e79 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/internal/readPacked.kt @@ -4,7 +4,7 @@ package kotlinx.rpc.grpc.internal -import kotlinx.rpc.grpc.ProtobufDecodingException +import kotlinx.rpc.grpc.pb.ProtobufDecodingException import kotlinx.rpc.grpc.pb.WireDecoder internal expect fun WireDecoder.pushLimit(byteLen: Int): Int diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt index cc9134288..c6b9ffe2e 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/KTag.kt @@ -40,10 +40,10 @@ internal fun KTag.Companion.from(rawKTag: UInt): KTag { val type = (rawKTag and K_TAG_TYPE_MASK).toInt() val field = (rawKTag shr K_TAG_TYPE_BITS).toInt() if (!isValidFieldNr(field)) { - error("Invalid field number: $field") + throw ProtobufDecodingException("Invalid field number: $field") } if (type >= WireType.entries.size) { - error("Invalid wire type: $type") + throw ProtobufDecodingException("Invalid wire type: $type") } return KTag(field, WireType.entries[type]) } diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt similarity index 87% rename from grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt rename to grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt index 19f5b8a9c..8e10402e8 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/GrpcException.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt @@ -2,14 +2,14 @@ * Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license. */ -package kotlinx.rpc.grpc +package kotlinx.rpc.grpc.pb -public sealed class GrpcException : RuntimeException { +public sealed class ProtobufException : RuntimeException { protected constructor(message: String, cause: Throwable? = null) : super(message, cause) } -public class ProtobufDecodingException : GrpcException { +public class ProtobufDecodingException : ProtobufException { internal constructor(message: String, cause: Throwable? = null) : super(message, cause) public companion object Companion { @@ -35,6 +35,6 @@ public class ProtobufDecodingException : GrpcException { } } -public class ProtobufEncodingException : GrpcException { +public class ProtobufEncodingException : ProtobufException { internal constructor(message: String, cause: Throwable? = null) : super(message, cause) } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index 1aecb80e7..50541ce9d 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -5,7 +5,6 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer -import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.internal.popLimit import kotlinx.rpc.grpc.internal.pushLimit import kotlinx.rpc.internal.utils.InternalRpcApi @@ -20,10 +19,12 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000 * This decoder is used by first calling [readTag], than looking up the field based on the field number in the returned, * tag and then calling the actual `read*()` method to read the value to the corresponding field. * - * [hadError] indicates an error during decoding. While calling `read*()` is safe, the returned values - * are meaningless if [hadError] returns `true`. + * All `read*()` methods will throw an exception if the expected value couldn't be decoded. + * Because of optimization reasons, the exception is platform-dependent. To unify them + * wrap the decoding in a [checkForPlatformDecodeException] call, which turn platform-specific exceptions + * into a [ProtobufDecodingException]. * - * NOTE: If the [hadError] after a call to `read*()` returns `false`, it doesn't mean that the + * NOTE: If a call to `read*()` doesn't throw an error, it doesn't mean that the * value is correctly decoded. E.g., the following test will pass: * ```kt * val fieldNr = 1 @@ -33,10 +34,12 @@ internal const val MAX_PACKED_BULK_SIZE: Int = 1_000_000 * assertTrue(encoder.writeInt32(fieldNr, 12312)) * encoder.flush() * - * WireDecoder(buffer).use { decoder -> - * decoder.readTag() - * decoder.readBool() - * assertFalse(decoder.hasError()) + * checkForPlatformDecodeException { + * WireDecoder(buffer).use { decoder -> + * decoder.readTag() + * decoder.readBool() + * assertFalse(decoder.hasError()) + * } * } * ``` */ @@ -92,12 +95,16 @@ public interface WireDecoder : AutoCloseable { WireType.FIXED32 -> readFixed32() WireType.FIXED64 -> readFixed64() WireType.LENGTH_DELIMITED -> readBytes() - WireType.START_GROUP -> error("Unexpected START_GROUP wire type (KRPC-193)") + WireType.START_GROUP -> throw ProtobufDecodingException("Unexpected START_GROUP wire type (KRPC-193)") WireType.END_GROUP -> {} // nothing to do } } } +/** + * Turns exceptions thrown by different platforms during decoding into [ProtobufDecodingException]. + */ +@InternalRpcApi public expect fun checkForPlatformDecodeException(block: () -> Unit) /** @@ -110,4 +117,5 @@ public expect fun checkForPlatformDecodeException(block: () -> Unit) * * @param source The buffer containing the encoded wire-format data. */ +@InternalRpcApi public expect fun WireDecoder(source: Buffer): WireDecoder \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt index f1e4c7957..132bc4c23 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt @@ -10,8 +10,11 @@ import kotlinx.rpc.internal.utils.InternalRpcApi /** * A platform-specific class that encodes values into protobuf's wire format. * - * If one `write*()` method returns false, the encoding of the value failed - * and no further encodings can be performed on this [WireEncoder]. + * If one `write*()` method fails to encode the value in the buffer, + * it will throw a platform-specific exception. + * + * Wrap the encoding of a message with [checkForPlatformEncodeException] to + * turn all thrown platform-specific exceptions into [ProtobufEncodingException]s. * * [flush] must be called to ensure that all data is written to the [Sink]. */ @@ -59,7 +62,11 @@ public interface WireEncoder { } - +/** + * Turns exceptions thrown by different platforms during encoding into [ProtobufEncodingException]. + */ +@InternalRpcApi public expect fun checkForPlatformEncodeException(block: () -> Unit) +@InternalRpcApi public expect fun WireEncoder(sink: Sink): WireEncoder \ No newline at end of file diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt index 9032f9f44..49c349ebc 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/ProtosTest.kt @@ -13,7 +13,6 @@ import asInternal import encodeWith import invoke import kotlinx.io.Buffer -import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.codec.MessageCodec import kotlinx.rpc.grpc.test.* import kotlinx.rpc.grpc.test.common.* diff --git a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt index bbf253b5f..e3a41b112 100644 --- a/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt +++ b/grpc/grpc-core/src/commonTest/kotlin/kotlinx/rpc/grpc/pb/WireCodecTest.kt @@ -5,7 +5,6 @@ package kotlinx.rpc.grpc.pb import kotlinx.io.Buffer -import kotlinx.rpc.grpc.ProtobufDecodingException import kotlin.test.* enum class TestPlatform { diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/Server.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/Server.jvm.kt index 4a01ca51b..cf6fce68a 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/Server.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/Server.jvm.kt @@ -33,7 +33,7 @@ private fun io.grpc.Server.toKotlin(): Server { override val isTerminated: Boolean get() = this@toKotlin.isTerminated - override fun start() : Server { + override fun start(): Server { this@toKotlin.start() return this } diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt index 2ee7097e9..50e9cd916 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt @@ -8,7 +8,6 @@ import com.google.protobuf.CodedInputStream import com.google.protobuf.InvalidProtocolBufferException import kotlinx.io.Buffer import kotlinx.io.asInputStream -import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.internal.readPackedVarInternal internal class WireDecoderJvm(source: Buffer) : WireDecoder { @@ -119,14 +118,10 @@ public actual fun checkForPlatformDecodeException(block: () -> Unit) { try { return block() } catch (e: InvalidProtocolBufferException) { - throw e.toDecodingException() + throw ProtobufDecodingException(e.message ?: "Failed to decode protobuf message.", e) } } public actual fun WireDecoder(source: Buffer): WireDecoder { return WireDecoderJvm(source) } - -private fun InvalidProtocolBufferException.toDecodingException(): ProtobufDecodingException { - return ProtobufDecodingException(message ?: "Failed to decode protobuf message.", cause) -} diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt index 1dbc66205..3cb7ff17b 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt @@ -8,7 +8,6 @@ import com.google.protobuf.CodedOutputStream import kotlinx.io.IOException import kotlinx.io.Sink import kotlinx.io.asOutputStream -import kotlinx.rpc.grpc.ProtobufEncodingException private class WireEncoderJvm(sink: Sink) : WireEncoder { private val codedOutputStream = CodedOutputStream.newInstance(sink.asOutputStream()) diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt index 023910755..a46099528 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt @@ -7,7 +7,6 @@ package kotlinx.rpc.grpc.pb import kotlinx.cinterop.* import kotlinx.collections.immutable.persistentListOf import kotlinx.io.Buffer -import kotlinx.rpc.grpc.ProtobufDecodingException import kotlinx.rpc.grpc.internal.ZeroCopyInputSource import kotlinx.rpc.grpc.internal.readPackedVarInternal import libprotowire.* diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index 0e9063b0d..618a0a5df 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -6,7 +6,6 @@ package kotlinx.rpc.grpc.pb import kotlinx.cinterop.* import kotlinx.io.Sink -import kotlinx.rpc.grpc.ProtobufEncodingException import kotlinx.rpc.grpc.internal.writeFully import libprotowire.* import kotlin.experimental.ExperimentalNativeApi diff --git a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt index d7df29067..fb13b0873 100644 --- a/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt +++ b/protoc-gen/src/main/kotlin/kotlinx/rpc/protobuf/ModelToKotlinCommonGenerator.kt @@ -584,7 +584,7 @@ class ModelToKotlinCommonGenerator( requiredFields.forEach { field -> ifBranch(condition = "!presenceMask[${field.presenceIdx}]", ifBlock = { - code("throw kotlinx.rpc.grpc.ProtobufDecodingException.missingRequiredField(\"${declaration.name.simpleName}\", \"${field.name}\")") + code("throw $PB_PKG.ProtobufDecodingException.missingRequiredField(\"${declaration.name.simpleName}\", \"${field.name}\")") }) } From a8629928f5f352d4ffb58f490e974551275031f0 Mon Sep 17 00:00:00 2001 From: Johannes Zottele Date: Mon, 11 Aug 2025 18:47:47 +0200 Subject: [PATCH 9/9] grpc-pb: Address PR Comments Signed-off-by: Johannes Zottele --- .../kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt | 4 ++-- .../src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt | 2 +- .../src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt | 2 +- .../src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt | 2 +- .../src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt | 2 +- .../kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt | 2 +- .../kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt | 2 +- 7 files changed, 8 insertions(+), 8 deletions(-) diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt index 8e10402e8..3830c6f66 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/ProtobufException.kt @@ -10,7 +10,7 @@ public sealed class ProtobufException : RuntimeException { public class ProtobufDecodingException : ProtobufException { - internal constructor(message: String, cause: Throwable? = null) : super(message, cause) + public constructor(message: String, cause: Throwable? = null) : super(message, cause) public companion object Companion { internal fun missingRequiredField(messageName: String, fieldName: String) = @@ -36,5 +36,5 @@ public class ProtobufDecodingException : ProtobufException { } public class ProtobufEncodingException : ProtobufException { - internal constructor(message: String, cause: Throwable? = null) : super(message, cause) + public constructor(message: String, cause: Throwable? = null) : super(message, cause) } \ No newline at end of file diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt index 50541ce9d..2961558c9 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.kt @@ -105,7 +105,7 @@ public interface WireDecoder : AutoCloseable { * Turns exceptions thrown by different platforms during decoding into [ProtobufDecodingException]. */ @InternalRpcApi -public expect fun checkForPlatformDecodeException(block: () -> Unit) +public expect inline fun checkForPlatformDecodeException(block: () -> Unit) /** * Creates a platform-specific [WireDecoder]. diff --git a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt index 132bc4c23..dc6177d8a 100644 --- a/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt +++ b/grpc/grpc-core/src/commonMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.kt @@ -66,7 +66,7 @@ public interface WireEncoder { * Turns exceptions thrown by different platforms during encoding into [ProtobufEncodingException]. */ @InternalRpcApi -public expect fun checkForPlatformEncodeException(block: () -> Unit) +public expect inline fun checkForPlatformEncodeException(block: () -> Unit) @InternalRpcApi public expect fun WireEncoder(sink: Sink): WireEncoder \ No newline at end of file diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt index 50e9cd916..77095c84d 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.jvm.kt @@ -114,7 +114,7 @@ internal class WireDecoderJvm(source: Buffer) : WireDecoder { } -public actual fun checkForPlatformDecodeException(block: () -> Unit) { +public actual inline fun checkForPlatformDecodeException(block: () -> Unit) { try { return block() } catch (e: InvalidProtocolBufferException) { diff --git a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt index 3cb7ff17b..ffa4b2f3c 100644 --- a/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt +++ b/grpc/grpc-core/src/jvmMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.jvm.kt @@ -201,7 +201,7 @@ public actual fun WireEncoder(sink: Sink): WireEncoder { return WireEncoderJvm(sink) } -public actual fun checkForPlatformEncodeException(block: () -> Unit) { +public actual inline fun checkForPlatformEncodeException(block: () -> Unit) { try { return block() } catch (e: IOException) { diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt index a46099528..f26b9da9a 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireDecoder.native.kt @@ -302,6 +302,6 @@ internal class WireDecoderNative(private val source: Buffer) : WireDecoder { public actual fun WireDecoder(source: Buffer): WireDecoder = WireDecoderNative(source) -public actual fun checkForPlatformDecodeException(block: () -> Unit) { +public actual inline fun checkForPlatformDecodeException(block: () -> Unit) { block() } \ No newline at end of file diff --git a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt index 618a0a5df..de5266f30 100644 --- a/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt +++ b/grpc/grpc-core/src/nativeMain/kotlin/kotlinx/rpc/grpc/pb/WireEncoder.native.kt @@ -209,6 +209,6 @@ private inline fun checked(crossinline block: () -> Boolean) { } } -public actual fun checkForPlatformEncodeException(block: () -> Unit) { +public actual inline fun checkForPlatformEncodeException(block: () -> Unit) { block() // nothing to check for on native } \ No newline at end of file