Skip to content

Commit a2bcdc7

Browse files
authored
Grpc services followup (#435)
* Fix KRPC-173 for gRPC * Added stream tests * Support zero-parameter functions * Make MessageCodecResolver able to return null and add operator plus * Add checkers for Grpc and WithCodec annotations * Added box tests for Grpc
1 parent 49ff159 commit a2bcdc7

File tree

36 files changed

+2288
-292
lines changed

36 files changed

+2288
-292
lines changed

compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcIrContext.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,8 @@ internal class RpcIrContext(
255255
grpcServiceDescriptor.namedFunction("delegate")
256256
}
257257

258-
val grpcMessageCodecResolverResolve by lazy {
259-
grpcMessageCodecResolver.namedFunction("resolve")
258+
val grpcMessageCodecResolverResolveOrNull by lazy {
259+
grpcMessageCodecResolver.namedFunction("resolveOrNull")
260260
}
261261

262262
private fun IrClassSymbol.namedFunction(name: String): IrSimpleFunction {

compiler-plugin/compiler-plugin-backend/src/main/kotlin/kotlinx/rpc/codegen/extension/RpcStubGenerator.kt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,11 +1155,11 @@ internal class RpcStubGenerator(
11551155
"Only methods are allowed here"
11561156
}
11571157

1158-
check(callable.arguments.size == 1) {
1159-
"Only single argument methods are allowed here"
1158+
check(callable.arguments.size <= 1) {
1159+
"Only single or none argument methods are allowed here"
11601160
}
11611161

1162-
val requestParameterType = callable.arguments[0].type
1162+
val requestParameterType = callable.arguments.getOrNull(0)?.type ?: ctx.irBuiltIns.unitType
11631163
val responseParameterType = callable.function.returnType
11641164

11651165
val requestType: IrType = requestParameterType.unwrapFlow()
@@ -1262,7 +1262,7 @@ internal class RpcStubGenerator(
12621262
startOffset = UNDEFINED_OFFSET,
12631263
endOffset = UNDEFINED_OFFSET,
12641264
type = ctx.grpcMessageCodec.typeWith(type),
1265-
symbol = ctx.functions.grpcMessageCodecResolverResolve.symbol,
1265+
symbol = ctx.functions.grpcMessageCodecResolverResolveOrNull.symbol,
12661266
typeArgumentsCount = 0,
12671267
valueArgumentsCount = 1,
12681268
)

compiler-plugin/compiler-plugin-common/src/main/kotlin/kotlinx/rpc/codegen/common/Names.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import org.jetbrains.kotlin.name.Name
1111
object RpcClassId {
1212
val rpcAnnotation = ClassId(FqName("kotlinx.rpc.annotations"), Name.identifier("Rpc"))
1313
val grpcAnnotation = ClassId(FqName("kotlinx.rpc.grpc.annotations"), Name.identifier("Grpc"))
14+
val withCodecAnnotation = ClassId(FqName("kotlinx.rpc.grpc.codec"), Name.identifier("WithCodec"))
1415
val checkedTypeAnnotation = ClassId(FqName("kotlinx.rpc.annotations"), Name.identifier("CheckedTypeAnnotation"))
1516

1617
val serializableAnnotation = ClassId(FqName("kotlinx.serialization"), Name.identifier("Serializable"))
@@ -19,6 +20,8 @@ object RpcClassId {
1920
val flow = ClassId(FqName("kotlinx.coroutines.flow"), Name.identifier("Flow"))
2021
val sharedFlow = ClassId(FqName("kotlinx.coroutines.flow"), Name.identifier("SharedFlow"))
2122
val stateFlow = ClassId(FqName("kotlinx.coroutines.flow"), Name.identifier("StateFlow"))
23+
24+
val messageCodec = ClassId(FqName("kotlinx.rpc.grpc.codec"), Name.identifier("MessageCodec"))
2225
}
2326

2427
object RpcNames {

compiler-plugin/compiler-plugin-k2/src/main/kotlin/kotlinx/rpc/codegen/FirRpcAdditionalCheckers.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ class FirRpcAdditionalCheckers(
2323
) : FirAdditionalCheckersExtension(session) {
2424
override fun FirDeclarationPredicateRegistrar.registerPredicates() {
2525
register(FirRpcPredicates.rpc)
26+
register(FirRpcPredicates.grpc)
27+
register(FirRpcPredicates.withCodec)
2628
register(FirRpcPredicates.checkedAnnotationMeta)
2729
}
2830

compiler-plugin/compiler-plugin-k2/src/main/kotlin/kotlinx/rpc/codegen/FirRpcPredicates.kt

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,14 @@ object FirRpcPredicates {
1212
metaAnnotated(RpcClassId.rpcAnnotation.asSingleFqName(), includeItself = true) // @Rpc
1313
}
1414

15+
internal val grpc = DeclarationPredicate.create {
16+
metaAnnotated(RpcClassId.grpcAnnotation.asSingleFqName(), includeItself = true)
17+
}
18+
19+
internal val withCodec = DeclarationPredicate.create {
20+
annotated(RpcClassId.withCodecAnnotation.asSingleFqName())
21+
}
22+
1523
internal val checkedAnnotationMeta = DeclarationPredicate.create {
1624
metaAnnotated(RpcClassId.checkedTypeAnnotation.asSingleFqName(), includeItself = false)
1725
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright 2023-2025 JetBrains s.r.o and contributors. Use of this source code is governed by the Apache 2.0 license.
3+
*/
4+
5+
package kotlinx.rpc.codegen.checkers
6+
7+
import kotlinx.rpc.codegen.FirRpcPredicates
8+
import kotlinx.rpc.codegen.checkers.diagnostics.FirGrpcDiagnostics
9+
import kotlinx.rpc.codegen.checkers.util.functionParametersRecursionCheck
10+
import kotlinx.rpc.codegen.common.RpcClassId
11+
import kotlinx.rpc.codegen.vsApi
12+
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
13+
import org.jetbrains.kotlin.diagnostics.reportOn
14+
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
15+
import org.jetbrains.kotlin.fir.declarations.FirRegularClass
16+
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
17+
import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
18+
import org.jetbrains.kotlin.fir.types.isMarkedNullable
19+
20+
object FirGrpcServiceDeclarationChecker {
21+
fun check(
22+
declaration: FirRegularClass,
23+
context: CheckerContext,
24+
reporter: DiagnosticReporter,
25+
) {
26+
if (!context.session.predicateBasedProvider.matches(FirRpcPredicates.grpc, declaration)) {
27+
return
28+
}
29+
30+
vsApi {
31+
declaration
32+
.declarationsVS(context.session)
33+
.filterIsInstance<FirNamedFunctionSymbol>()
34+
}.onEach { function ->
35+
if (function.valueParameterSymbols.size > 1) {
36+
reporter.reportOn(
37+
source = function.source,
38+
factory = FirGrpcDiagnostics.MULTIPLE_PARAMETERS_IN_GRPC_SERVICE,
39+
context = context,
40+
)
41+
}
42+
43+
if (function.valueParameterSymbols.size == 1) {
44+
val parameterSymbol = function.valueParameterSymbols[0]
45+
if (parameterSymbol.resolvedReturnType.isMarkedNullable) {
46+
reporter.reportOn(
47+
source = parameterSymbol.source,
48+
factory = FirGrpcDiagnostics.NULLABLE_PARAMETER_IN_GRPC_SERVICE,
49+
context = context,
50+
)
51+
}
52+
53+
functionParametersRecursionCheck(
54+
function = function,
55+
context = context,
56+
) { source, symbol, parents ->
57+
if (symbol.classId == RpcClassId.flow && parents.isNotEmpty()) {
58+
reporter.reportOn(
59+
source = source,
60+
factory = FirGrpcDiagnostics.NON_TOP_LEVEL_CLIENT_STREAMING_IN_RPC_SERVICE,
61+
context = context,
62+
)
63+
}
64+
}
65+
}
66+
67+
if (function.resolvedReturnType.isMarkedNullable) {
68+
reporter.reportOn(
69+
source = function.resolvedReturnTypeRef.source,
70+
factory = FirGrpcDiagnostics.NULLABLE_RETURN_TYPE_IN_GRPC_SERVICE,
71+
context = context,
72+
)
73+
}
74+
}
75+
}
76+
}

compiler-plugin/compiler-plugin-k2/src/main/kotlin/kotlinx/rpc/codegen/checkers/FirRpcCheckers.kt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class FirRpcDeclarationCheckers(ctx: FirCheckersContext) : DeclarationCheckers()
1818
FirRpcAnnotationCheckerVS(),
1919
FirRpcStrictModeClassCheckerVS(),
2020
FirRpcServiceDeclarationCheckerVS(ctx),
21+
FirGrpcServiceDeclarationCheckerVS(),
22+
FirWithCodecDeclarationCheckerVS(),
2123
)
2224

2325
override val classCheckers: Set<FirClassChecker> = setOf(

compiler-plugin/compiler-plugin-k2/src/main/kotlin/kotlinx/rpc/codegen/checkers/FirRpcStrictModeClassChecker.kt

Lines changed: 29 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -6,26 +6,19 @@ package kotlinx.rpc.codegen.checkers
66

77
import kotlinx.rpc.codegen.FirRpcPredicates
88
import kotlinx.rpc.codegen.checkers.diagnostics.FirRpcStrictModeDiagnostics
9+
import kotlinx.rpc.codegen.checkers.util.functionParametersRecursionCheck
910
import kotlinx.rpc.codegen.common.RpcClassId
1011
import kotlinx.rpc.codegen.vsApi
1112
import org.jetbrains.kotlin.KtSourceElement
1213
import org.jetbrains.kotlin.diagnostics.DiagnosticReporter
1314
import org.jetbrains.kotlin.diagnostics.KtDiagnosticFactory0
1415
import org.jetbrains.kotlin.diagnostics.reportOn
1516
import org.jetbrains.kotlin.fir.analysis.checkers.context.CheckerContext
16-
import org.jetbrains.kotlin.fir.analysis.checkers.extractArgumentsTypeRefAndSource
17-
import org.jetbrains.kotlin.fir.analysis.checkers.toClassLikeSymbol
1817
import org.jetbrains.kotlin.fir.declarations.FirRegularClass
1918
import org.jetbrains.kotlin.fir.declarations.utils.isSuspend
2019
import org.jetbrains.kotlin.fir.extensions.predicateBasedProvider
21-
import org.jetbrains.kotlin.fir.scopes.impl.toConeType
22-
import org.jetbrains.kotlin.fir.symbols.impl.FirClassLikeSymbol
23-
import org.jetbrains.kotlin.fir.symbols.impl.FirClassSymbol
2420
import org.jetbrains.kotlin.fir.symbols.impl.FirNamedFunctionSymbol
2521
import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol
26-
import org.jetbrains.kotlin.fir.types.FirTypeRef
27-
import org.jetbrains.kotlin.utils.memoryOptimizedMap
28-
import org.jetbrains.kotlin.utils.memoryOptimizedPlus
2922

3023
object FirRpcStrictModeClassChecker {
3124
fun check(
@@ -37,124 +30,51 @@ object FirRpcStrictModeClassChecker {
3730
return
3831
}
3932

40-
val serializablePropertiesProvider = context.session.serializablePropertiesProvider
4133
vsApi { declaration.declarationsVS(context.session) }.forEach { declaration ->
4234
when (declaration) {
4335
is FirPropertySymbol -> {
4436
reporter.reportOn(declaration.source, FirRpcStrictModeDiagnostics.FIELD_IN_RPC_SERVICE, context)
4537
}
4638

4739
is FirNamedFunctionSymbol -> {
48-
checkFunction(declaration, context, reporter, serializablePropertiesProvider)
49-
}
50-
51-
else -> {}
52-
}
53-
}
54-
}
55-
56-
private fun checkFunction(
57-
function: FirNamedFunctionSymbol,
58-
context: CheckerContext,
59-
reporter: DiagnosticReporter,
60-
serializablePropertiesProvider: FirSerializablePropertiesProvider,
61-
) {
62-
fun reportOn(element: KtSourceElement?, checker: FirRpcStrictModeDiagnostics.() -> KtDiagnosticFactory0?) {
63-
reporter.reportOn(element, FirRpcStrictModeDiagnostics.checker() ?: return, context)
64-
}
65-
66-
val returnClassSymbol = vsApi {
67-
function.resolvedReturnTypeRef.coneTypeVS.toClassSymbolVS(context.session)
68-
}
69-
70-
val types = function.valueParameterSymbols.memoryOptimizedMap { parameter ->
71-
parameter.source to vsApi {
72-
parameter.resolvedReturnTypeRef
73-
}
74-
} memoryOptimizedPlus (function.resolvedReturnTypeRef.source to function.resolvedReturnTypeRef)
40+
fun reportOn(element: KtSourceElement?, checker: FirRpcStrictModeDiagnostics.() -> KtDiagnosticFactory0) {
41+
reporter.reportOn(element, FirRpcStrictModeDiagnostics.checker(), context)
42+
}
7543

76-
types.forEach { (source, symbol) ->
77-
checkSerializableTypes(
78-
context = context,
79-
typeRef = symbol,
80-
serializablePropertiesProvider = serializablePropertiesProvider,
81-
) { symbol, parents ->
82-
when (symbol.classId) {
83-
RpcClassId.stateFlow -> {
84-
reportOn(source) { STATE_FLOW_IN_RPC_SERVICE }
44+
val returnClassSymbol = vsApi {
45+
declaration.resolvedReturnTypeRef.coneTypeVS.toClassSymbolVS(context.session)
8546
}
8647

87-
RpcClassId.sharedFlow -> {
88-
reportOn(source) { SHARED_FLOW_IN_RPC_SERVICE }
48+
if (returnClassSymbol?.classId == RpcClassId.flow && declaration.isSuspend) {
49+
reportOn(declaration.source) { SUSPENDING_SERVER_STREAMING_IN_RPC_SERVICE }
8950
}
9051

91-
RpcClassId.flow -> {
92-
if (parents.any { it.classId == RpcClassId.flow }) {
93-
reportOn(source) { NESTED_STREAMING_IN_RPC_SERVICE }
94-
} else if (parents.isNotEmpty() && parents[0] == returnClassSymbol) {
95-
reportOn(source) { NON_TOP_LEVEL_SERVER_STREAMING_IN_RPC_SERVICE }
52+
functionParametersRecursionCheck(
53+
function = declaration,
54+
context = context,
55+
) { source, symbol, parents ->
56+
when (symbol.classId) {
57+
RpcClassId.stateFlow -> {
58+
reportOn(source) { STATE_FLOW_IN_RPC_SERVICE }
59+
}
60+
61+
RpcClassId.sharedFlow -> {
62+
reportOn(source) { SHARED_FLOW_IN_RPC_SERVICE }
63+
}
64+
65+
RpcClassId.flow -> {
66+
if (parents.any { it.classId == RpcClassId.flow }) {
67+
reportOn(source) { NESTED_STREAMING_IN_RPC_SERVICE }
68+
} else if (parents.isNotEmpty() && parents[0] == returnClassSymbol) {
69+
reportOn(source) { NON_TOP_LEVEL_SERVER_STREAMING_IN_RPC_SERVICE }
70+
}
71+
}
9672
}
9773
}
9874
}
99-
}
100-
}
10175

102-
if (returnClassSymbol?.classId == RpcClassId.flow && function.isSuspend) {
103-
reportOn(function.source) { SUSPENDING_SERVER_STREAMING_IN_RPC_SERVICE }
104-
}
105-
}
106-
107-
private fun checkSerializableTypes(
108-
context: CheckerContext,
109-
typeRef: FirTypeRef,
110-
serializablePropertiesProvider: FirSerializablePropertiesProvider,
111-
parentContext: List<FirClassLikeSymbol<*>> = emptyList(),
112-
checker: (FirClassLikeSymbol<*>, List<FirClassLikeSymbol<*>>) -> Unit,
113-
) {
114-
val symbol = typeRef.toClassLikeSymbol(context.session) ?: return
115-
116-
checker(symbol, parentContext)
117-
118-
if (symbol !is FirClassSymbol<*>) {
119-
return
120-
}
121-
122-
val nextContext = parentContext memoryOptimizedPlus symbol
123-
124-
if (symbol in parentContext && symbol.typeParameterSymbols.isEmpty()) {
125-
return
126-
}
127-
128-
val typeParameters = extractArgumentsTypeRefAndSource(typeRef)
129-
.orEmpty()
130-
.withIndex()
131-
.associate { (i, refSource) ->
132-
symbol.typeParameterSymbols[i].toConeType() to refSource.typeRef
76+
else -> {}
13377
}
134-
135-
val flowProps: List<FirTypeRef> = if (symbol.classId == RpcClassId.flow) {
136-
listOf(typeParameters.values.toList()[0]!!)
137-
} else {
138-
emptyList()
13978
}
140-
141-
serializablePropertiesProvider.getSerializablePropertiesForClass(symbol)
142-
.mapNotNull { property ->
143-
val resolvedTypeRef = property.resolvedReturnTypeRef
144-
if (resolvedTypeRef.toClassLikeSymbol(context.session) != null) {
145-
resolvedTypeRef
146-
} else {
147-
typeParameters[property.resolvedReturnType]
148-
}
149-
}.memoryOptimizedPlus(flowProps)
150-
.forEach { symbol ->
151-
checkSerializableTypes(
152-
context = context,
153-
typeRef = symbol,
154-
serializablePropertiesProvider = serializablePropertiesProvider,
155-
parentContext = nextContext,
156-
checker = checker,
157-
)
158-
}
15979
}
16080
}

0 commit comments

Comments
 (0)