Skip to content

Commit eeb039c

Browse files
Fix RequestCaptureProxy usage of coroutines
1 parent b8efcba commit eeb039c

File tree

6 files changed

+43
-32
lines changed

6 files changed

+43
-32
lines changed

client-kotlin/src/main/kotlin/dev/restate/client/kotlin/ingress.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ internal constructor(
429429
* @return a [KClientRequest] with the correct response type
430430
*/
431431
@Suppress("UNCHECKED_CAST")
432-
fun <Res> request(block: suspend SVC.() -> Res): KClientRequest<Any?, Res> {
432+
suspend fun <Res> request(block: suspend SVC.() -> Res): KClientRequest<Any?, Res> {
433433
return KClientRequestImpl(
434434
client,
435435
RequestCaptureProxy(clazz, key).capture(block as suspend SVC.() -> Any?).toRequest(),

common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/RequestCaptureProxy.kt

Lines changed: 10 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,6 @@ package dev.restate.common.reflection.kotlin
1010

1111
import dev.restate.common.reflections.ProxySupport
1212
import dev.restate.common.reflections.ReflectionUtils
13-
import kotlin.coroutines.Continuation
14-
import kotlin.coroutines.EmptyCoroutineContext
15-
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
16-
import kotlin.coroutines.startCoroutine
1713

1814
/**
1915
* Captures method invocations on a proxy to extract invocation information.
@@ -37,35 +33,20 @@ class RequestCaptureProxy<SVC : Any>(private val clazz: Class<SVC>, private val
3733
* @param block the suspend lambda that invokes a method on the service proxy
3834
* @return the captured invocation information
3935
*/
40-
fun capture(block: suspend SVC.() -> Any?): CapturedInvocation {
41-
var capturedInvocation: CapturedInvocation? = null
42-
36+
suspend fun capture(block: suspend SVC.() -> Any?): CapturedInvocation {
4337
val proxy =
4438
ProxySupport.createProxy(clazz) { invocation ->
45-
capturedInvocation = invocation.captureInvocation(serviceName, key)
46-
47-
// Return COROUTINE_SUSPENDED to prevent actual execution
48-
COROUTINE_SUSPENDED
49-
}
50-
51-
// Invoke the block with the proxy to capture the method call.
52-
// Since the proxy returns COROUTINE_SUSPENDED, we use startCoroutine
53-
// which starts but doesn't block waiting for completion.
54-
val capturingContinuation =
55-
object : Continuation<Any?> {
56-
override val context = EmptyCoroutineContext
57-
58-
override fun resumeWith(result: Result<Any?>) {
59-
// Do nothing - we're just capturing, the coroutine suspends immediately
60-
}
39+
throw invocation.captureInvocation(serviceName, key)
6140
}
6241

63-
val suspendBlock: suspend () -> Any? = { proxy.block() }
64-
suspendBlock.startCoroutine(capturingContinuation)
42+
try {
43+
proxy.block()
44+
} catch (e: CapturedInvocation) {
45+
return e
46+
}
6547

66-
return capturedInvocation
67-
?: error(
68-
"Method invocation was not captured. Make sure to call ONLY a method of the service proxy."
69-
)
48+
error(
49+
"Method invocation was not captured. Make sure to call ONLY a method of the service proxy."
50+
)
7051
}
7152
}

common-kotlin/src/main/kotlin/dev/restate/common/reflection/kotlin/reflections.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ data class CapturedInvocation(
3636
val inputTypeTag: TypeTag<*>,
3737
val outputTypeTag: TypeTag<*>,
3838
val input: Any?,
39-
) {
39+
) : Exception("CapturedInvocation message should not be used", null, false, false) {
4040
@Suppress("UNCHECKED_CAST")
4141
fun toRequest(): Request<*, *> {
4242
return Request.of(target, inputTypeTag as TypeTag<Any?>, outputTypeTag as TypeTag<Any?>, input)

sdk-api-kotlin/src/main/kotlin/dev/restate/sdk/kotlin/api.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1293,7 +1293,7 @@ internal constructor(
12931293
* @return a [KRequest] with the correct response type
12941294
*/
12951295
@Suppress("UNCHECKED_CAST")
1296-
fun <Res> request(block: suspend SVC.() -> Res): KRequest<Any?, Res> {
1296+
suspend fun <Res> request(block: suspend SVC.() -> Res): KRequest<Any?, Res> {
12971297
return KRequestImpl(
12981298
RequestCaptureProxy(clazz, key).capture(block as suspend SVC.() -> Any?).toRequest()
12991299
)

sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/ReflectionTest.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,6 +220,24 @@ class ReflectionTest : TestDefinitions.TestSuite {
220220
outputCmd(),
221221
END_MESSAGE,
222222
),
223+
testInvocation({ CornerCases() }, "callSuspendWithinProxy")
224+
.withInput(startMessage(1, "mykey"), inputCmd())
225+
.onlyBidiStream()
226+
.expectingOutput(
227+
oneWayCallCmd(
228+
1,
229+
Target.virtualObject(
230+
"CornerCases",
231+
"mykey",
232+
"callSuspendWithinProxy",
233+
),
234+
null,
235+
null,
236+
Slice.EMPTY,
237+
),
238+
outputCmd(),
239+
END_MESSAGE,
240+
),
223241
testInvocation({ CustomSerdeService() }, "echo")
224242
.withInput(startMessage(1), inputCmd(byteArrayOf(1)))
225243
.onlyBidiStream()

sdk-core/src/test/kotlin/dev/restate/sdk/core/kotlinapi/reflections/testClasses.kt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ import dev.restate.serde.SerdeFactory
1515
import dev.restate.serde.TypeRef
1616
import dev.restate.serde.TypeTag
1717
import dev.restate.serde.kotlinx.KotlinSerializationSerdeFactory
18+
import kotlinx.coroutines.delay
1819
import kotlinx.serialization.Serializable
1920

2021
@Service
@@ -112,6 +113,17 @@ open class CornerCases {
112113
open suspend fun badReturnTypeInferred(): Unit {
113114
toVirtualObject<CornerCases>(objectKey()).request { badReturnTypeInferred() }.send()
114115
}
116+
117+
@Exclusive
118+
open suspend fun callSuspendWithinProxy() {
119+
toVirtualObject<CornerCases>(objectKey())
120+
.request {
121+
// Doing a suspend call within the proxy
122+
delay(1)
123+
callSuspendWithinProxy()
124+
}
125+
.send()
126+
}
115127
}
116128

117129
@Service

0 commit comments

Comments
 (0)