[SPARK-47553][SS] Add Java support for transformWithState operator APIs
### What changes were proposed in this pull request?
Add Java support for transformWithState operator APIs

### Why are the changes needed?
To add support for using the transformWithState operator from Java
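
For illustration only, a minimal sketch of how the new Java-facing overload is intended to be called. `MyStatefulProcessor` stands in for a user-defined `StatefulProcessor` subclass (a schematic definition is sketched after the `StatefulProcessor` diff below); the call shape mirrors the new tests in `JavaDatasetSuite`:

```java
import java.util.Arrays;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.streaming.TimeoutMode;

public class TransformWithStateJavaExample {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]").appName("twsJavaExample").getOrCreate();

    Dataset<String> ds =
        spark.createDataset(Arrays.asList("a", "foo", "bar"), Encoders.STRING());

    // Group by string length; Java callers pass the key encoder explicitly.
    KeyValueGroupedDataset<Integer, String> grouped =
        ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());

    // The trailing Encoders.STRING() is the explicit output encoder that the
    // new Java-specific overload adds on top of the Scala signature.
    Dataset<String> result = grouped.transformWithState(
        new MyStatefulProcessor(),
        TimeoutMode.NoTimeouts(),
        OutputMode.Append(),
        Encoders.STRING());

    result.show();
    spark.stop();
  }
}
```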

### Does this PR introduce _any_ user-facing change?
Yes. This PR adds Java-callable `transformWithState` overloads on `KeyValueGroupedDataset` that take explicit output (and initial-state) encoders.

### How was this patch tested?
Added unit tests

```
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testGroupByKey() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testCollect() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testKryoEncoderErrorMessageForPrivateClass() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testJavaBeanEncoder() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testTupleEncoder() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testPeriodEncoder() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testRowEncoder() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testNestedTupleEncoder() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testTupleEncoderSchema() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testMappingFunctionWithTestGroupState() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testReduce() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testSelect() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testInitialStateFlatMapGroupsWithState() started
[info] Test test.org.apache.spark.sql.JavaDatasetSuite#testJavaEncoderErrorMessageForPrivateClass() started
[info] Test run finished: 0 failed, 0 ignored, 45 total, 14.73s
[info] Passed: Total 45, Failed 0, Errors 0, Passed 45
[success] Total time: 20 s, completed Mar 28, 2024, 12:37:30 PM
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes #45758 from anishshri-db/task/SPARK-47553.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Apr 3, 2024
1 parent a427a45 commit 1515c56
Showing 7 changed files with 479 additions and 20 deletions.
Changed file: KeyValueGroupedDataset.scala (Spark Connect client)
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.ProductEncoder
import org.apache.spark.sql.connect.common.UdfUtils
import org.apache.spark.sql.expressions.ScalarUserDefinedFunction
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, TimeoutMode}
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode, StatefulProcessor, StatefulProcessorWithInitialState, TimeoutMode}

/**
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
@@ -818,8 +818,7 @@ class KeyValueGroupedDataset[K, V] private[sql] () extends Serializable {
/**
* (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
* API v2. We allow the user to act on per-group set of input rows along with keyed state and
* the user can choose to output/return 0 or more rows. For a static/batch dataset, this
* operator is not supported and will throw an exception. For a streaming dataframe, we will
* the user can choose to output/return 0 or more rows. For a streaming dataframe, we will
* repeatedly invoke the interface methods for new rows in each trigger and the user's
* state/state variables will be stored persistently across invocations. Currently this operator
* is not supported with Spark Connect.
@@ -831,12 +830,103 @@
* @param timeoutMode
* The timeout mode of the stateful processor.
* @param outputMode
* The output mode of the stateful processor. Defaults to APPEND mode.
* The output mode of the stateful processor.
*/
def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
timeoutMode: TimeoutMode,
outputMode: OutputMode = OutputMode.Append()): Dataset[U] = {
outputMode: OutputMode): Dataset[U] = {
throw new UnsupportedOperationException
}

/**
* (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
* v2. We allow the user to act on per-group set of input rows along with keyed state and the
* user can choose to output/return 0 or more rows. For a streaming dataframe, we will
* repeatedly invoke the interface methods for new rows in each trigger and the user's
* state/state variables will be stored persistently across invocations. Currently this operator
* is not supported with Spark Connect.
*
* @tparam U
* The type of the output objects. Must be encodable to Spark SQL types.
* @param statefulProcessor
* Instance of statefulProcessor whose functions will be invoked by the operator.
* @param timeoutMode
* The timeout mode of the stateful processor.
* @param outputMode
* The output mode of the stateful processor.
* @param outputEncoder
* Encoder for the output type.
*/
def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
outputEncoder: Encoder[U]): Dataset[U] = {
throw new UnsupportedOperationException
}

/**
* (Scala-specific) Invokes methods defined in the stateful processor used in arbitrary state
* API v2. Functions as the function above, but with additional initial state. Currently this
* operator is not supported with Spark Connect.
*
* @tparam U
* The type of the output objects. Must be encodable to Spark SQL types.
* @tparam S
* The type of initial state objects. Must be encodable to Spark SQL types.
* @param statefulProcessor
* Instance of statefulProcessor whose functions will be invoked by the operator.
* @param timeoutMode
* The timeout mode of the stateful processor.
* @param outputMode
* The output mode of the stateful processor.
* @param initialState
* User provided initial state that will be used to initiate state for the query in the first
* batch.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
*/
def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
initialState: KeyValueGroupedDataset[K, S]): Dataset[U] = {
throw new UnsupportedOperationException
}

/**
* (Java-specific) Invokes methods defined in the stateful processor used in arbitrary state API
* v2. Functions as the function above, but with additional initial state. Currently this
* operator is not supported with Spark Connect.
*
* @tparam U
* The type of the output objects. Must be encodable to Spark SQL types.
* @tparam S
* The type of initial state objects. Must be encodable to Spark SQL types.
* @param statefulProcessor
* Instance of statefulProcessor whose functions will be invoked by the operator.
* @param timeoutMode
* The timeout mode of the stateful processor.
* @param outputMode
* The output mode of the stateful processor.
* @param initialState
* User provided initial state that will be used to initiate state for the query in the first
* batch.
* @param outputEncoder
* Encoder for the output type.
* @param initialStateEncoder
* Encoder for the initial state type.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
*/
private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
initialState: KeyValueGroupedDataset[K, S],
outputEncoder: Encoder[U],
initialStateEncoder: Encoder[S]): Dataset[U] = {
throw new UnsupportedOperationException
}
}
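For context: on the Spark Connect client above, all four `transformWithState` overloads currently throw `UnsupportedOperationException`; the working implementation is in the core `KeyValueGroupedDataset` changes further down. The Java-specific variants exist because Java callers cannot supply Scala's implicit `Encoder` evidence, so the output and initial-state encoders become ordinary parameters. A sketch of the initial-state overload's call shape, mirroring `testInitialStateForTransformWithState` below (`MyInitialStateProcessor` is a hypothetical `StatefulProcessorWithInitialState` subclass):

```java
import java.util.Arrays;

import org.apache.spark.api.java.function.MapFunction;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.KeyValueGroupedDataset;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.streaming.TimeoutMode;

import scala.Tuple2;

public class InitialStateCallShape {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .master("local[*]").appName("twsInitialState").getOrCreate();

    // Initial state keyed like the input: key 2 starts out with state "pq".
    KeyValueGroupedDataset<Integer, String> initialState = spark
        .createDataset(Arrays.asList(new Tuple2<>(2, "pq")),
            Encoders.tuple(Encoders.INT(), Encoders.STRING()))
        .groupByKey((MapFunction<Tuple2<Integer, String>, Integer>) t -> t._1,
            Encoders.INT())
        .mapValues((MapFunction<Tuple2<Integer, String>, String>) t -> t._2,
            Encoders.STRING());

    KeyValueGroupedDataset<Integer, String> grouped = spark
        .createDataset(Arrays.asList("a", "xy", "foo", "bar"), Encoders.STRING())
        .groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());

    // Both encoders are explicit: output encoder, then initial-state encoder.
    Dataset<String> out = grouped.transformWithState(
        new MyInitialStateProcessor(),
        TimeoutMode.NoTimeouts(),
        OutputMode.Append(),
        initialState,
        Encoders.STRING(),
        Encoders.STRING());

    out.show();
    spark.stop();
  }
}
```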
Changed file: StatefulProcessor.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.errors.ExecutionErrors
*/
@Experimental
@Evolving
private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable {

/**
* Handle to the stateful processor that provides access to the state store and other
@@ -100,7 +100,7 @@ private[sql] trait StatefulProcessor[K, I, O] extends Serializable {
*/
@Experimental
@Evolving
private[sql] trait StatefulProcessorWithInitialState[K, I, O, S]
private[sql] abstract class StatefulProcessorWithInitialState[K, I, O, S]
extends StatefulProcessor[K, I, O] {

/**
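Why these become `abstract class` rather than `trait`: a Java class can extend a Scala abstract class with ordinary inheritance, and the base class can keep concrete members such as the stateful processor handle plumbing. A schematic Java subclass follows; since the abstract members are not visible in this diff, the `init`/`handleInputRows` signatures below are assumptions inferred from the surrounding API, not the authoritative contract:

```java
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.streaming.StatefulProcessor;
import org.apache.spark.sql.streaming.TimerValues;

// Schematic sketch only: the overridden signatures are assumptions; consult
// StatefulProcessor.scala for the real contract.
public class MyStatefulProcessor extends StatefulProcessor<Integer, String, String> {

  @Override
  public void init(OutputMode outputMode) {
    // Typically used to create state variables via the handle exposed by the
    // base class (getHandle()).
  }

  @Override
  public scala.collection.Iterator<String> handleInputRows(
      Integer key,
      scala.collection.Iterator<String> inputRows,
      TimerValues timerValues) {
    // Emit one row per key: the key followed by its concatenated inputs.
    StringBuilder sb = new StringBuilder(String.valueOf(key));
    while (inputRows.hasNext()) {
      sb.append(inputRows.next());
    }
    return scala.jdk.javaapi.CollectionConverters.asScala(
        java.util.List.of(sb.toString()).iterator());
  }
}
```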
Changed file: KeyValueGroupedDataset.scala (core implementation)
@@ -647,7 +647,6 @@ class KeyValueGroupedDataset[K, V] private[sql](
* Invokes methods defined in the stateful processor used in arbitrary state API v2.
* We allow the user to act on per-group set of input rows along with keyed state and the
* user can choose to output/return 0 or more rows.
* For a static/batch dataset, this operator is not supported and will throw an exception.
* For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
* in each trigger and the user's state/state variables will be stored persistently across
* invocations.
@@ -656,13 +655,14 @@
* @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
* operator.
* @param timeoutMode The timeout mode of the stateful processor.
* @param outputMode The output mode of the stateful processor. Defaults to APPEND mode.
* @param outputMode The output mode of the stateful processor.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
*/
private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
timeoutMode: TimeoutMode,
outputMode: OutputMode = OutputMode.Append()): Dataset[U] = {
outputMode: OutputMode): Dataset[U] = {
Dataset[U](
sparkSession,
TransformWithState[K, V, U](
@@ -676,6 +676,32 @@
)
}

/**
* (Java-specific)
* Invokes methods defined in the stateful processor used in arbitrary state API v2.
* We allow the user to act on per-group set of input rows along with keyed state and the
* user can choose to output/return 0 or more rows.
* For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
* in each trigger and the user's state/state variables will be stored persistently across
* invocations.
*
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @param statefulProcessor Instance of statefulProcessor whose functions will be invoked by the
* operator.
* @param timeoutMode The timeout mode of the stateful processor.
* @param outputMode The output mode of the stateful processor.
* @param outputEncoder Encoder for the output type.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
*/
private[sql] def transformWithState[U: Encoder](
statefulProcessor: StatefulProcessor[K, V, U],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
outputEncoder: Encoder[U]): Dataset[U] = {
transformWithState(statefulProcessor, timeoutMode, outputMode)(outputEncoder)
}

/**
* (Scala-specific)
* Invokes methods defined in the stateful processor used in arbitrary state API v2.
@@ -686,10 +712,11 @@
* @param statefulProcessor Instance of statefulProcessor whose functions will
* be invoked by the operator.
* @param timeoutMode The timeout mode of the stateful processor.
* @param outputMode The output mode of the stateful processor. Defaults to APPEND mode.
* @param outputMode The output mode of the stateful processor.
* @param initialState User provided initial state that will be used to initiate state for
* the query in the first batch.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
*/
private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
@@ -712,6 +739,35 @@
)
}

/**
* (Java-specific)
* Invokes methods defined in the stateful processor used in arbitrary state API v2.
* Functions as the function above, but with additional initial state.
*
* @tparam U The type of the output objects. Must be encodable to Spark SQL types.
* @tparam S The type of initial state objects. Must be encodable to Spark SQL types.
* @param statefulProcessor Instance of statefulProcessor whose functions will
* be invoked by the operator.
* @param timeoutMode The timeout mode of the stateful processor.
* @param outputMode The output mode of the stateful processor.
* @param initialState User provided initial state that will be used to initiate state for
* the query in the first batch.
* @param outputEncoder Encoder for the output type.
* @param initialStateEncoder Encoder for the initial state type.
*
* See [[Encoder]] for more details on what types are encodable to Spark SQL.
*/
private[sql] def transformWithState[U: Encoder, S: Encoder](
statefulProcessor: StatefulProcessorWithInitialState[K, V, U, S],
timeoutMode: TimeoutMode,
outputMode: OutputMode,
initialState: KeyValueGroupedDataset[K, S],
outputEncoder: Encoder[U],
initialStateEncoder: Encoder[S]): Dataset[U] = {
transformWithState(statefulProcessor, timeoutMode,
outputMode, initialState)(outputEncoder, initialStateEncoder)
}

/**
* (Scala-specific)
* Reduces the elements of each group of data using the specified binary function.
Changed file: JavaDatasetSuite.java
@@ -27,15 +27,13 @@
import javax.annotation.Nonnull;

import org.apache.spark.api.java.Optional;
import org.apache.spark.sql.streaming.GroupStateTimeout;
import org.apache.spark.sql.streaming.OutputMode;
import org.apache.spark.sql.streaming.*;
import scala.Tuple2;
import scala.Tuple3;
import scala.Tuple4;
import scala.Tuple5;

import com.google.common.base.Objects;
import org.apache.spark.sql.streaming.TestGroupState;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
@@ -185,6 +183,39 @@ public void testReduce() {
Assertions.assertEquals(6, reduced);
}

@Test
public void testInitialStateForTransformWithState() {
List<String> data = Arrays.asList("a", "xy", "foo", "bar");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
Dataset<Tuple2<Integer, String>> initialStateDS = spark.createDataset(
Arrays.asList(new Tuple2<Integer, String>(2, "pq")),
Encoders.tuple(Encoders.INT(), Encoders.STRING())
);

KeyValueGroupedDataset<Integer, Tuple2<Integer, String>> kvInitStateDS =
initialStateDS.groupByKey(
(MapFunction<Tuple2<Integer, String>, Integer>) f -> f._1, Encoders.INT());

KeyValueGroupedDataset<Integer, String> kvInitStateMappedDS = kvInitStateDS.mapValues(
(MapFunction<Tuple2<Integer, String>, String>) f -> f._2,
Encoders.STRING()
);

KeyValueGroupedDataset<Integer, String> grouped =
ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());

Dataset<String> transformWithStateMapped = grouped.transformWithState(
new TestStatefulProcessorWithInitialState(),
TimeoutMode.NoTimeouts(),
OutputMode.Append(),
kvInitStateMappedDS,
Encoders.STRING(),
Encoders.STRING());

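    // Per the expected values: each key (string length) yields the key plus
    // any initial state plus the concatenated inputs, e.g. key 2 combines
    // initial state "pq" with input "xy" into "2pqxy".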
Assertions.assertEquals(asSet("1a", "2pqxy", "3foobar"),
toSet(transformWithStateMapped.collectAsList()));
}

@Test
public void testInitialStateFlatMapGroupsWithState() {
List<String> data = Arrays.asList("a", "foo", "bar");
@@ -320,6 +351,24 @@ public void testMappingFunctionWithTestGroupState() throws Exception {
Assertions.assertFalse(prevState.exists());
}

@Test
public void testTransformWithState() {
List<String> data = Arrays.asList("a", "foo", "bar");
Dataset<String> ds = spark.createDataset(data, Encoders.STRING());
KeyValueGroupedDataset<Integer, String> grouped =
ds.groupByKey((MapFunction<String, Integer>) String::length, Encoders.INT());

StatefulProcessor<Integer, String, String> testStatefulProcessor = new TestStatefulProcessor();
Dataset<String> transformWithStateMapped = grouped.transformWithState(
testStatefulProcessor,
TimeoutMode.NoTimeouts(),
OutputMode.Append(),
Encoders.STRING());

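    // Per the expected values: each key yields the key plus its concatenated
    // inputs, e.g. key 3 combines "foo" and "bar" into "3foobar".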
Assertions.assertEquals(asSet("1a", "3foobar"),
toSet(transformWithStateMapped.collectAsList()));
}

@Test
public void testGroupByKey() {
List<String> data = Arrays.asList("a", "foo", "bar");