Skip to content

Commit

Permalink
chore: Tweak withAttribuets in Flow (#1658)
Browse files Browse the repository at this point in the history
  • Loading branch information
He-Pin authored Feb 16, 2025
1 parent 96f70c4 commit 4252382
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ import scala.util.control.NoStackTrace

import org.apache.pekko
import pekko.Done
import pekko.stream.AbruptStageTerminationException
import pekko.stream.ActorAttributes
import pekko.stream.ActorMaterializer
import pekko.stream.Supervision
import pekko.stream.{ AbruptStageTerminationException, ActorAttributes, ActorMaterializer, ClosedShape, Supervision }
import pekko.stream.testkit.StreamSpec
import pekko.stream.testkit.TestSubscriber
import pekko.stream.testkit.Utils.TE
Expand Down Expand Up @@ -434,4 +431,28 @@ class FlowStatefulMapSpec extends StreamSpec {
closedCounter.get() shouldBe 1
}
}

"support junction output ports" in {
val source = Source(List((1, 1), (2, 2)))
val g = RunnableGraph.fromGraph(GraphDSL.createGraph(TestSink.probe[(Int, Int)]) { implicit b => sink =>
import GraphDSL.Implicits._
val unzip = b.add(Unzip[Int, Int]())
val zip = b.add(Zip[Int, Int]())
val s = b.add(source)
// format: OFF
s ~> unzip.in
unzip.out0 ~> zip.in0
unzip.out1 ~> zip.in1
zip.out.statefulMap(() => None)((_, elem) => (None, elem), _ => None) ~> sink.in
// format: ON

ClosedShape
})
g.run()
.request(2)
.expectNext((1, 1))
.expectNext((2, 2))
.expectComplete()
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,23 @@

package org.apache.pekko.stream

import org.apache.pekko.actor.typed.ActorSystem
import org.apache.pekko.actor.typed.scaladsl.Behaviors
import org.apache.pekko.stream.scaladsl.{ Flow, FlowWithContext, Keep, Sink, Source, SourceWithContext }
import org.apache.pekko
import pekko.actor.typed.ActorSystem
import pekko.actor.typed.scaladsl.Behaviors
import pekko.stream.scaladsl.{
Flow,
FlowWithContext,
GraphDSL,
Keep,
RunnableGraph,
Sink,
Source,
SourceWithContext,
Unzip,
Zip
}
import pekko.stream.testkit.scaladsl.TestSink

import org.scalacheck.{ Arbitrary, Gen }
import org.scalatest.BeforeAndAfterAll
import org.scalatest.concurrent.ScalaFutures
Expand All @@ -29,6 +43,7 @@ import org.scalatestplus.scalacheck.ScalaCheckDrivenPropertyChecks

import java.time.Instant
import java.util.concurrent.Executors

import scala.annotation.nowarn
import scala.concurrent.duration.{ DurationInt, FiniteDuration }
import scala.concurrent.{ blocking, ExecutionContext, Future }
Expand Down Expand Up @@ -439,6 +454,52 @@ class MapAsyncPartitionedSpec
.futureValue shouldBe Seq(1 -> "A")
}

it should "support junction output ports with mapAsyncPartitioned" in {
val source = Source(List((1, 1), (2, 2)))
val g = RunnableGraph.fromGraph(GraphDSL.createGraph(TestSink.probe[(Int, Int)](system.classicSystem)) {
implicit b => sink =>
import GraphDSL.Implicits._
val unzip = b.add(Unzip[Int, Int]())
val zip = b.add(Zip[Int, Int]())
val s = b.add(source)
// format: OFF
s ~> unzip.in
unzip.out0 ~> zip.in0
unzip.out1 ~> zip.in1
zip.out.mapAsyncPartitioned(1)(_ => 1)((elem, _) => Future.successful(elem)) ~> sink.in
// format: ON
ClosedShape
})
g.run()
.request(2)
.expectNext((1, 1))
.expectNext((2, 2))
.expectComplete()
}

it should "support junction output ports with mapAsyncPartitionedUnordered" in {
val source = Source(List((1, 1), (2, 2)))
val g = RunnableGraph.fromGraph(GraphDSL.createGraph(TestSink.probe[(Int, Int)](system.classicSystem)) {
implicit b => sink =>
import GraphDSL.Implicits._
val unzip = b.add(Unzip[Int, Int]())
val zip = b.add(Zip[Int, Int]())
val s = b.add(source)
// format: OFF
s ~> unzip.in
unzip.out0 ~> zip.in0
unzip.out1 ~> zip.in1
zip.out.mapAsyncPartitionedUnordered(1)(_ => 1)((elem, _) => Future.successful(elem)) ~> sink.in
// format: ON
ClosedShape
})
g.run()
.request(2)
.expectNext((1, 1))
.expectNext((2, 2))
.expectComplete()
}

private implicit class MapWrapper[K, V](map: Map[K, V]) {
@nowarn("msg=deprecated")
def mapValues2[W](f: V => W) = map.mapValues(f)
Expand Down
26 changes: 14 additions & 12 deletions stream/src/main/scala/org/apache/pekko/stream/scaladsl/Flow.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,8 @@ trait FlowOps[+Out, +Mat] {
* @param onComplete a function that transforms the ongoing state into an optional output element
*/
def statefulMap[S, T](create: () => S)(f: (S, Out) => (S, T), onComplete: S => Option[T]): Repr[T] =
via(new StatefulMap[S, Out, T](create, f, onComplete).withAttributes(DefaultAttributes.statefulMap))
via(new StatefulMap[S, Out, T](create, f, onComplete)
.withAttributes(DefaultAttributes.statefulMap and SourceLocation.forLambda(f)))

/**
* Transform each stream element with the help of a resource.
Expand Down Expand Up @@ -1358,12 +1359,12 @@ trait FlowOps[+Out, +Mat] {
def mapAsyncPartitioned[T, P](parallelism: Int)(
partitioner: Out => P)(
f: (Out, P) => Future[T]): Repr[T] = {
(if (parallelism == 1) {
via(MapAsyncUnordered(1, elem => f(elem, partitioner(elem))))
} else {
via(new MapAsyncPartitioned(parallelism, orderedOutput = true, partitioner, f))
})
.withAttributes(DefaultAttributes.mapAsyncPartition and SourceLocation.forLambda(f))
val graph: Graph[FlowShape[Out, T], _] = if (parallelism == 1) {
MapAsyncUnordered(1, elem => f(elem, partitioner(elem)))
} else {
new MapAsyncPartitioned(parallelism, orderedOutput = true, partitioner, f)
}
via(graph.withAttributes(DefaultAttributes.mapAsyncPartition and SourceLocation.forLambda(f)))
}

/**
Expand Down Expand Up @@ -1396,11 +1397,12 @@ trait FlowOps[+Out, +Mat] {
def mapAsyncPartitionedUnordered[T, P](parallelism: Int)(
partitioner: Out => P)(
f: (Out, P) => Future[T]): Repr[T] = {
(if (parallelism == 1) {
via(MapAsyncUnordered(1, elem => f(elem, partitioner(elem))))
} else {
via(new MapAsyncPartitioned(parallelism, orderedOutput = false, partitioner, f))
}).withAttributes(DefaultAttributes.mapAsyncPartitionUnordered and SourceLocation.forLambda(f))
val graph: Graph[FlowShape[Out, T], _] = if (parallelism == 1) {
MapAsyncUnordered(1, elem => f(elem, partitioner(elem)))
} else {
new MapAsyncPartitioned(parallelism, orderedOutput = false, partitioner, f)
}
via(graph.withAttributes(DefaultAttributes.mapAsyncPartition and SourceLocation.forLambda(f)))
}

/**
Expand Down

0 comments on commit 4252382

Please sign in to comment.