Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,34 +9,33 @@ class StreamingClientCallListener[Res](
prefetch: Option[Int],
runtime: Runtime[Any],
call: ZClientCall[?, Res],
queue: Queue[ResponseFrame[Res]],
buffered: Ref[Int]
queue: Queue[ResponseFrame[Res]]
) extends ClientCall.Listener[Res] {
private val increment = if (prefetch.isDefined) buffered.update(_ + 1) else ZIO.unit
private val fetchOne = if (prefetch.isDefined) ZIO.unit else call.request(1)
private val fetchMore = prefetch match {
case None => ZIO.unit
case Some(n) => buffered.get.flatMap(b => call.request(n - b).when(n > b))
}
private val fetchOne =
ZIO.whenDiscard(prefetch.isEmpty)(call.request(1))

private def fetchMore(n: Int) =
ZIO.whenDiscard(prefetch.isDefined)(call.request(n))

private def unsafeRun(task: IO[Any, Unit]): Unit =
Unsafe.unsafe(implicit u => runtime.unsafe.run(task).getOrThrowFiberFailure())

private def handle(promise: Promise[StatusException, Unit])(
chunk: Chunk[ResponseFrame[Res]]
) = (chunk.lastOption match {
case Some(ResponseFrame.Trailers(status, trailers)) =>
val exit = if (status.isOk) Exit.unit else Exit.fail(new StatusException(status, trailers))
promise.done(exit) *> queue.shutdown
case _ =>
buffered.update(_ - chunk.size) *> fetchMore
}).as(chunk)
private def handle(promise: Promise[StatusException, Unit])(chunk: Chunk[ResponseFrame[Res]]) =
ZIO.unlessDiscard(chunk.isEmpty)(chunk.last match {
case ResponseFrame.Trailers(status, trailers) =>
val exit =
if (status.isOk) Exit.unit
else Exit.fail(new StatusException(status, trailers))
promise.done(exit) *> queue.shutdown
case _ =>
fetchMore(chunk.size)
})

override def onHeaders(headers: Metadata): Unit =
unsafeRun(queue.offer(ResponseFrame.Headers(headers)) *> increment)
unsafeRun(queue.offer(ResponseFrame.Headers(headers)).unit)

override def onMessage(message: Res): Unit =
unsafeRun(queue.offer(ResponseFrame.Message(message)) *> increment *> fetchOne)
unsafeRun(queue.offer(ResponseFrame.Message(message)) *> fetchOne)

override def onClose(status: Status, trailers: Metadata): Unit =
unsafeRun(queue.offer(ResponseFrame.Trailers(status, trailers)).unit)
Expand All @@ -45,15 +44,14 @@ class StreamingClientCallListener[Res](
ZStream.fromZIO(Promise.make[StatusException, Unit]).flatMap { promise =>
ZStream
.fromQueue(queue, prefetch.getOrElse(ZStream.DefaultChunkSize))
.mapChunksZIO(handle(promise))
.tapChunks(handle(promise))
.concat(ZStream.execute(promise.await))
}
}

object StreamingClientCallListener {
def make[Res](call: ZClientCall[?, Res], prefetch: Option[Int]): UIO[StreamingClientCallListener[Res]] = for {
runtime <- ZIO.runtime[Any]
queue <- Queue.unbounded[ResponseFrame[Res]]
buffered <- Ref.make(0)
} yield new StreamingClientCallListener(prefetch, runtime, call, queue, buffered)
runtime <- ZIO.runtime[Any]
queue <- Queue.unbounded[ResponseFrame[Res]]
} yield new StreamingClientCallListener(prefetch, runtime, call, queue)
}
41 changes: 20 additions & 21 deletions e2e/protos/src/main/protobuf/testservice.proto
Original file line number Diff line number Diff line change
Expand Up @@ -5,38 +5,37 @@ package scalapb.zio_grpc;
import "scalapb/scalapb.proto";

message Request {
enum Scenario {
OK = 0;
ERROR_NOW = 1; // fail with an error
ERROR_AFTER = 2; // for server streaming, error after two responses
DELAY = 3; // do not return a response. for testing cancellations
DIE = 4; // fail
UNAVAILABLE = 5; // fail with UNAVAILABLE, to test client retries
}
Scenario scenario = 1;
int32 in = 2;
enum Scenario {
OK = 0;
ERROR_NOW = 1; // fail with an error
ERROR_AFTER = 2; // for server streaming, error after two responses
DELAY = 3; // do not return a response, to test cancellations
LARGE_STREAM = 4; // stream of large elements, to test backpressure
DIE = 5; // fail
UNAVAILABLE = 6; // fail with UNAVAILABLE, to test client retries
}
Scenario scenario = 1;
int32 in = 2;
}

message Response {
string out = 1;
}
message Response { string out = 1; }

message ResponseTypeMapped {
option (scalapb.message).type = "scalapb.zio_grpc.WrappedString";
option (scalapb.message).type = "scalapb.zio_grpc.WrappedString";

string out = 1;
string out = 1;
}

service TestService {
rpc Unary(Request) returns (Response);
rpc Unary(Request) returns (Response);

rpc UnaryTypeMapped(Request) returns (ResponseTypeMapped);
rpc UnaryTypeMapped(Request) returns (ResponseTypeMapped);

rpc ServerStreaming(Request) returns (stream Response);
rpc ServerStreaming(Request) returns (stream Response);

rpc ServerStreamingTypeMapped(Request) returns (stream ResponseTypeMapped);
rpc ServerStreamingTypeMapped(Request) returns (stream ResponseTypeMapped);

rpc ClientStreaming(stream Request) returns (Response);
rpc ClientStreaming(stream Request) returns (Response);

rpc BidiStreaming(stream Request) returns (stream Response);
rpc BidiStreaming(stream Request) returns (stream Response);
}
209 changes: 111 additions & 98 deletions e2e/src/main/scalajvm/scalapb/zio_grpc/TestServiceImpl.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
package scalapb.zio_grpc

import scalapb.zio_grpc.testservice.Request
import zio.{Clock, Console, Exit, Promise, ZIO, ZLayer}
import zio._
import scalapb.zio_grpc.testservice.Response
import io.grpc.{Status, StatusException}
import scalapb.zio_grpc.testservice.Request.Scenario
import zio.stream.{Stream, ZStream}
import zio.ZEnvironment

import java.util.concurrent.atomic.AtomicInteger

package object server {

Expand All @@ -19,15 +16,19 @@ package object server {
object TestServiceImpl {

class Service(
requestReceived: zio.Promise[Nothing, Unit],
delayReceived: zio.Promise[Nothing, Unit],
exit: zio.Promise[Nothing, Exit[StatusException, Response]]
requestReceived: Promise[Nothing, Unit],
delayReceived: Promise[Nothing, Unit],
exit: Promise[Nothing, Exit[StatusException, Response]],
responseCounter: Ref[Int],
rpcCounter: Ref[Int]
)(clock: Clock, console: Console)
extends testservice.ZioTestservice.TestService {
val rpcRunsCounter: AtomicInteger = new AtomicInteger(0)

def unary(request: Request): ZIO[Any, StatusException, Response] =
(requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet()) *> (request.scenario match {
// A response of size 100KB to saturate the byte buffers and observe backpressure.
private val largeResponse = Response("*" * 100000)

def unary(request: Request): ZIO[Any, StatusException, Response] = {
def run(request: Request) = request.scenario match {
case Scenario.OK =>
ZIO.succeed(Response(out = "Res" + request.in.toString))
case Scenario.ERROR_NOW =>
Expand All @@ -37,119 +38,128 @@ package object server {
case Scenario.DIE =>
ZIO.die(new RuntimeException("FOO"))
case Scenario.UNAVAILABLE =>
ZIO.fail(Status.UNAVAILABLE.withDescription(rpcRunsCounter.get().toString).asException())
rpcCounter.get.flatMap[Any, StatusException, Nothing] { n =>
ZIO.fail(Status.UNAVAILABLE.withDescription(n.toString).asException())
}
case _ =>
ZIO.fail(Status.UNKNOWN.asException())
})).onExit(exit.succeed(_))
}

(requestReceived.succeed(()) *> rpcCounter.incrementAndGet *> run(request) <* responseCounter.incrementAndGet)
.onExit(exit.succeed(_))
}

def unaryTypeMapped(request: Request): ZIO[Any, StatusException, WrappedString] =
unary(request).map(r => WrappedString(r.out))

def serverStreaming(
request: Request
): ZStream[Any, StatusException, Response] =
ZStream
.acquireReleaseExitWith(requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet())) {
(_, ex) =>
ex.foldExit(
failed =>
if (failed.isInterrupted || failed.isInterruptedOnly)
exit.succeed(Exit.fail(Status.CANCELLED.asException()))
else exit.succeed(Exit.fail(Status.UNKNOWN.asException())),
_ => exit.succeed(Exit.succeed(Response()))
)
}
.flatMap { _ =>
request.scenario match {
case Scenario.OK =>
ZStream(Response(out = "X1"), Response(out = "X2"))
case Scenario.ERROR_NOW =>
ZStream.fail(Status.INTERNAL.withDescription("FOO!").asException())
case Scenario.ERROR_AFTER =>
ZStream(Response(out = "X1"), Response(out = "X2")) ++ ZStream
.fail(
Status.INTERNAL.withDescription("FOO!").asException()
)
case Scenario.DELAY =>
ZStream(
Response(out = "X1"),
Response(out = "X2")
) ++ ZStream.never
case Scenario.DIE => ZStream.die(new RuntimeException("FOO"))
case _ => ZStream.fail(Status.UNKNOWN.asException())
}
}
def serverStreaming(request: Request): ZStream[Any, StatusException, Response] = {
def run(request: Request) = request.scenario match {
case Scenario.OK =>
ZStream(Response(out = "X1"), Response(out = "X2"))
case Scenario.ERROR_NOW =>
ZStream.fail(Status.INTERNAL.withDescription("FOO!").asException())
case Scenario.ERROR_AFTER =>
ZStream(Response(out = "X1"), Response(out = "X2")) ++
ZStream.fail(Status.INTERNAL.withDescription("FOO!").asException())
case Scenario.DELAY =>
ZStream(Response(out = "X1"), Response(out = "X2")) ++ ZStream.never
case Scenario.LARGE_STREAM =>
ZStream.fromIterator(Iterator.fill(100)(largeResponse), 1).orDie
case Scenario.DIE =>
ZStream.die(new RuntimeException("FOO"))
case _ =>
ZStream.fail(Status.UNKNOWN.asException())
}

ZStream.acquireReleaseExitWith(requestReceived.succeed(()) *> rpcCounter.incrementAndGet) { (_, ex) =>
ex.foldExit(
{ failed =>
val status = if (failed.isInterrupted) Status.CANCELLED else Status.UNKNOWN
exit.succeed(Exit.fail(status.asException))
},
_ => exit.succeed(Exit.succeed(Response()))
)
} *> run(request).tapChunks(chunk => responseCounter.update(_ + chunk.size))
}

def serverStreamingTypeMapped(request: Request): ZStream[Any, StatusException, WrappedString] =
serverStreaming(request).map(r => WrappedString(r.out))

def clientStreaming(
request: Stream[StatusException, Request]
): ZIO[Any, StatusException, Response] =
requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet()) *>
request
.runFoldZIO(0)((state, req) =>
req.scenario match {
case Scenario.OK => ZIO.succeed(state + req.in)
case Scenario.DELAY => delayReceived.succeed(()) *> ZIO.never
case Scenario.DIE => ZIO.die(new RuntimeException("foo"))
case Scenario.ERROR_NOW => ZIO.fail((Status.INTERNAL.withDescription("InternalError").asException()))
case _: Scenario => ZIO.fail(Status.UNKNOWN.asException())
}
)
.map(r => Response(r.toString))
.onExit(exit.succeed(_))
def clientStreaming(request: Stream[StatusException, Request]): ZIO[Any, StatusException, Response] = {
def run(state: Int, request: Request) = request.scenario match {
case Scenario.OK =>
ZIO.succeed(state + request.in)
case Scenario.DELAY =>
delayReceived.succeed(()) *> ZIO.never
case Scenario.DIE =>
ZIO.die(new RuntimeException("foo"))
case Scenario.ERROR_NOW =>
ZIO.fail(Status.INTERNAL.withDescription("InternalError").asException())
case _: Scenario =>
ZIO.fail(Status.UNKNOWN.asException())
}

requestReceived.succeed(()) *> rpcCounter.incrementAndGet *> request
.runFoldZIO(0)(run)
.map(r => Response(r.toString))
.zipLeft(responseCounter.incrementAndGet)
.onExit(exit.succeed(_))
}

def bidiStreaming(
request: Stream[StatusException, Request]
): Stream[StatusException, Response] =
(ZStream.fromZIO(requestReceived.succeed(()) *> ZIO.succeed(rpcRunsCounter.incrementAndGet())).drain ++
(request.flatMap { r =>
r.scenario match {
case Scenario.OK =>
ZStream(Response(r.in.toString))
.repeat(Schedule.recurs(r.in - 1))
case Scenario.DELAY => ZStream.never
case Scenario.DIE => ZStream.die(new RuntimeException("FOO"))
case Scenario.ERROR_NOW =>
ZStream.fail(Status.INTERNAL.withDescription("Intentional error").asException())
case _ =>
ZStream.fail(
Status.INVALID_ARGUMENT.withDescription(s"Got request: ${r.toProtoString}").asException()
)
}
} ++ ZStream(Response("DONE")))
.ensuring(exit.succeed(Exit.succeed(Response()))))
.provideEnvironment(ZEnvironment(clock, console))
): Stream[StatusException, Response] = {
def run(request: Request) = request.scenario match {
case Scenario.OK =>
ZStream(Response(request.in.toString)).repeat(Schedule.recurs(request.in - 1))
case Scenario.DELAY =>
ZStream.never
case Scenario.DIE =>
ZStream.die(new RuntimeException("FOO"))
case Scenario.ERROR_NOW =>
ZStream.fail(Status.INTERNAL.withDescription("Intentional error").asException())
case _ =>
ZStream.fail(
Status.INVALID_ARGUMENT.withDescription(s"Got request: ${request.toProtoString}").asException()
)
}

def awaitReceived = requestReceived.await
ZStream.execute(requestReceived.succeed(()) *> rpcCounter.incrementAndGet) ++ request
.flatMap(run)
.concat(ZStream(Response("DONE")))
.tapChunks(chunk => responseCounter.update(_ + chunk.size))
.ensuring(exit.succeed(Exit.succeed(Response())))
.provideEnvironment(ZEnvironment(clock, console))
}

def awaitReceived = requestReceived.await
def awaitDelayReceived = delayReceived.await

def awaitExit = exit.await
def awaitExit = exit.await
def responsesSent = responseCounter.get
}

def make(
clock: Clock,
console: Console
): zio.IO[Nothing, TestServiceImpl.Service] =
for {
p1 <- Promise.make[Nothing, Unit]
p2 <- Promise.make[Nothing, Unit]
p3 <- Promise.make[Nothing, Exit[StatusException, Response]]
} yield new Service(p1, p2, p3)(clock, console)

def makeFromEnv: ZIO[Any, Nothing, Service] =
for {
clock <- ZIO.clock
console <- ZIO.console
service <- make(clock, console)
} yield service
): IO[Nothing, TestServiceImpl.Service] = for {
p1 <- Promise.make[Nothing, Unit]
p2 <- Promise.make[Nothing, Unit]
p3 <- Promise.make[Nothing, Exit[StatusException, Response]]
c1 <- Ref.make(0)
c2 <- Ref.make(0)
} yield new Service(p1, p2, p3, c1, c2)(clock, console)

def makeFromEnv: ZIO[Any, Nothing, Service] = for {
clock <- ZIO.clock
console <- ZIO.console
service <- make(clock, console)
} yield service

val live: ZLayer[Any, Nothing, TestServiceImpl] =
ZLayer.scoped(makeFromEnv)

val any: ZLayer[TestServiceImpl, Nothing, TestServiceImpl] = ZLayer.environment
val any: ZLayer[TestServiceImpl, Nothing, TestServiceImpl] =
ZLayer.environment

def awaitReceived: ZIO[TestServiceImpl, Nothing, Unit] =
ZIO.environmentWithZIO(_.get.awaitReceived)
Expand All @@ -159,5 +169,8 @@ package object server {

def awaitExit: ZIO[TestServiceImpl, Nothing, Exit[StatusException, Response]] =
ZIO.environmentWithZIO(_.get.awaitExit)

def responsesSent: ZIO[TestServiceImpl, Nothing, Int] =
ZIO.environmentWithZIO(_.get.responsesSent)
}
}
Loading
Loading