From 1ae852b181d14251d752680fe8c041e0e7778570 Mon Sep 17 00:00:00 2001
From: Quentin Dawans
Date: Mon, 14 Apr 2025 17:08:26 +0200
Subject: [PATCH 1/2] Optimize groupby for single-partition topics

Group by operations on topics with a single partition are now optimized to
avoid creating a repartition topic. Instead, the messages are directly
transformed to use the new key, as all messages go to the same partition.

---
 quixstreams/dataframe/dataframe.py   |  41 ++-
 quixstreams/dataframe/registry.py    |  21 +-
 tests/test_quixstreams/test_app.py   |  47 ++-
 .../test_dataframe/test_dataframe.py | 278 ++++++++++++------
 4 files changed, 281 insertions(+), 106 deletions(-)

diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py
index 31b7aaf8e..e15d9b0d8 100644
--- a/quixstreams/dataframe/dataframe.py
+++ b/quixstreams/dataframe/dataframe.py
@@ -92,7 +92,7 @@ class StreamingDataFrame:
     What it Does:
 
     - Builds a data processing pipeline, declaratively (not executed immediately)
-    - Executes this pipeline on inputs at runtime (Kafka message values) 
+    - Executes this pipeline on inputs at runtime (Kafka message values)
     - Provides functions/interface similar to Pandas Dataframes/Series
     - Enables stateful processing (and manages everything related to it)
 
@@ -135,15 +135,20 @@ def __init__(
         registry: DataFrameRegistry,
         processing_context: ProcessingContext,
         stream: Optional[Stream] = None,
+        stream_id: Optional[str] = None,
     ):
         if not topics:
             raise ValueError("At least one Topic must be passed")
 
-        self._stream: Stream = stream or Stream()
         # Implicitly deduplicate Topic objects into a tuple and sort them by name
         self._topics: tuple[Topic, ...] = tuple(
             sorted({t.name: t for t in topics}.values(), key=attrgetter("name"))
         )
+
+        self._stream: Stream = stream or Stream()
+        self._stream_id: str = stream_id or topic_manager.stream_id_from_topics(
+            self.topics
+        )
         self._topic_manager = topic_manager
         self._registry = registry
         self._processing_context = processing_context
@@ -174,7 +179,7 @@ def stream_id(self) -> str:
 
         By default, a topic name or a combination of topic names are used as `stream_id`.
         """
-        return self._topic_manager.stream_id_from_topics(self.topics)
+        return self._stream_id
 
     @property
     def topics(self) -> tuple[Topic, ...]:
@@ -591,6 +596,11 @@ def func(d: dict, state: State):
         # Generate a config for the new repartition topic based on the underlying topics
         repartition_config = self._topic_manager.derive_topic_config(self._topics)
 
+        # If the topic has only one partition, we don't need a repartition topic;
+        # we can directly change the message key, as all messages go to the same partition.
+ if repartition_config.num_partitions == 1: + return self._single_partition_groupby(operation, key) + groupby_topic = self._topic_manager.repartition_topic( operation=operation, stream_id=self.stream_id, @@ -606,6 +616,29 @@ def func(d: dict, state: State): self._registry.register_groupby(source_sdf=self, new_sdf=groupby_sdf) return groupby_sdf + def _single_partition_groupby( + self, operation: str, key: Union[str, Callable[[Any], Any]] + ) -> "StreamingDataFrame": + if isinstance(key, str): + + def _callback(value, _, timestamp, headers): + return value, value[key], timestamp, headers + else: + + def _callback(value, _, timestamp, headers): + return value, key(value), timestamp, headers + + stream = self.stream.add_transform(_callback, expand=False) + + groupby_sdf = self.__dataframe_clone__( + stream=stream, stream_id=f"{self.stream_id}--groupby--{operation}" + ) + self._registry.register_groupby( + source_sdf=self, new_sdf=groupby_sdf, register_new_root=False + ) + + return groupby_sdf + def contains(self, keys: Union[str, list[str]]) -> StreamingSeries: """ Check if keys are present in the Row value. @@ -1679,6 +1712,7 @@ def __dataframe_clone__( self, *topics: Topic, stream: Optional[Stream] = None, + stream_id: Optional[str] = None, ) -> "StreamingDataFrame": """ Clone the StreamingDataFrame with a new `stream`, `topics`, @@ -1692,6 +1726,7 @@ def __dataframe_clone__( clone = self.__class__( *(topics or self._topics), stream=stream, + stream_id=stream_id, processing_context=self._processing_context, topic_manager=self._topic_manager, registry=self._registry, diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py index faf8750e2..1d534336c 100644 --- a/quixstreams/dataframe/registry.py +++ b/quixstreams/dataframe/registry.py @@ -70,7 +70,10 @@ def register_root( self._registry[topic.name] = dataframe.stream def register_groupby( - self, source_sdf: "StreamingDataFrame", new_sdf: "StreamingDataFrame" + self, + source_sdf: "StreamingDataFrame", + new_sdf: "StreamingDataFrame", + register_new_root: bool = True, ): """ Register a "groupby" SDF, which is one generated with `SDF.group_by()`. @@ -81,16 +84,26 @@ def register_groupby( raise GroupByNestingLimit( "Subsequent (nested) `SDF.group_by()` operations are not allowed." 
) - try: - self.register_root(new_sdf) - except StreamingDataFrameDuplicate: + + if new_sdf.stream_id in self._repartition_origins: raise GroupByDuplicate( "A `SDF.group_by()` operation appears to be the same as another, " "either from using the same column or name parameter; " "adjust by setting a unique name with `SDF.group_by(name=)` " ) + self._repartition_origins.add(new_sdf.stream_id) + if register_new_root: + try: + self.register_root(new_sdf) + except StreamingDataFrameDuplicate: + raise GroupByDuplicate( + "A `SDF.group_by()` operation appears to be the same as another, " + "either from using the same column or name parameter; " + "adjust by setting a unique name with `SDF.group_by(name=)` " + ) + def compose_all( self, sink: Optional[VoidExecutor] = None ) -> dict[str, VoidExecutor]: diff --git a/tests/test_quixstreams/test_app.py b/tests/test_quixstreams/test_app.py index d5a33533b..329d4367f 100644 --- a/tests/test_quixstreams/test_app.py +++ b/tests/test_quixstreams/test_app.py @@ -513,12 +513,14 @@ def test_state_dir_env(self): assert app.config.state_dir == Path("/path/to/other") +@pytest.mark.parametrize("number_of_partitions", [1, 2]) class TestAppGroupBy: def test_group_by( self, app_factory, internal_consumer_factory, executor, + number_of_partitions, ): """ Test that StreamingDataFrame processes 6 messages from Kafka and groups them @@ -540,8 +542,12 @@ def on_message_processed(*_): timestamp_ms = int(time.time() * 1000) user_id = "abc123" value_in = {"user": user_id} - expected_message_count = 1 - total_messages = expected_message_count * 2 # groupby reproduces each message + + if number_of_partitions == 1: + total_messages = 1 # groupby optimisation for 1 partition + else: + total_messages = 2 # groupby reproduces each message + app = app_factory( auto_offset_reset="earliest", on_message_processed=on_message_processed, @@ -551,6 +557,9 @@ def on_message_processed(*_): str(uuid.uuid4()), value_deserializer="json", value_serializer="json", + config=TopicConfig( + num_partitions=number_of_partitions, replication_factor=1 + ), ) app_topic_out = app.topic( str(uuid.uuid4()), @@ -607,6 +616,7 @@ def test_group_by_with_window( internal_consumer_factory, executor, processing_guarantee, + number_of_partitions, ): """ Test that StreamingDataFrame processes 6 messages from Kafka and groups them @@ -631,8 +641,12 @@ def on_message_processed(*_): timestamp_ms = timestamp_ms - (timestamp_ms % window_duration_ms) user_id = "abc123" value_in = {"user": user_id} - expected_message_count = 1 - total_messages = expected_message_count * 2 # groupby reproduces each message + + if number_of_partitions == 1: + total_messages = 1 # groupby optimisation for 1 partition + else: + total_messages = 2 # groupby reproduces each message + app = app_factory( auto_offset_reset="earliest", on_message_processed=on_message_processed, @@ -643,6 +657,9 @@ def on_message_processed(*_): str(uuid.uuid4()), value_deserializer="json", value_serializer="json", + config=TopicConfig( + num_partitions=number_of_partitions, replication_factor=1 + ), ) app_topic_out = app.topic( str(uuid.uuid4()), @@ -2380,11 +2397,9 @@ def on_message_processed(topic_, partition, offset): assert row.timestamp == timestamp_ms assert row.headers == headers + @pytest.mark.parametrize("number_of_partitions", [1, 2]) def test_group_by( - self, - app_factory, - internal_consumer_factory, - executor, + self, app_factory, internal_consumer_factory, executor, number_of_partitions ): """ Test that StreamingDataFrame processes 6 messages from 
Kafka and groups them @@ -2411,11 +2426,17 @@ def on_message_processed(*_): str(uuid.uuid4()), value_deserializer="json", value_serializer="json", + config=TopicConfig( + num_partitions=number_of_partitions, replication_factor=1 + ), ) input_topic_b = app.topic( str(uuid.uuid4()), value_deserializer="json", value_serializer="json", + config=TopicConfig( + num_partitions=number_of_partitions, replication_factor=1 + ), ) input_topics = [input_topic_a, input_topic_b] output_topic_user = app.topic( @@ -2433,8 +2454,14 @@ def on_message_processed(*_): user_id = "abc123" account_id = "def456" value_in = {"user": user_id, "account": account_id} - # expected_processed = 1 (input msg per SDF) * 3 (2 groupbys, each reprocesses input) * 2 SDFs - expected_processed = 6 + + if number_of_partitions == 1: + # expected_processed = 1 (input msg per SDF) * 1 (2 optimized groupbys that don't reprocesses input) * 2 SDFs + expected_processed = 2 + else: + # expected_processed = 1 (input msg per SDF) * 3 (2 groupbys, each reprocesses input) * 2 SDFs + expected_processed = 6 + expected_output_topic_count = 2 sdf_a = app.dataframe(topic=input_topic_a) diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index 168f8b79d..7041dca1a 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -1646,6 +1646,7 @@ def on_late( ] +@pytest.mark.parametrize("num_partitions", [1, 2]) class TestStreamingDataFrameGroupBy: def test_group_by_column( self, @@ -1654,10 +1655,16 @@ def test_group_by_column( internal_producer_factory, internal_consumer_factory, message_context_factory, + num_partitions, ): """GroupBy can accept a string (column name) as its grouping method.""" topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) producer = internal_producer_factory() col = "column_A" @@ -1676,8 +1683,11 @@ def test_group_by_column( sdf[col] = col_update groupby_topic = sdf.topics[0] - assert sdf_registry.consumer_topics == [topic, groupby_topic] - assert groupby_topic.name.startswith("repartition__") + if num_partitions == 1: + assert sdf_registry.consumer_topics == [topic] + else: + assert sdf_registry.consumer_topics == [topic, groupby_topic] + assert groupby_topic.name.startswith("repartition__") with producer: pre_groupby_branch_result = sdf.test( @@ -1689,31 +1699,35 @@ def test_group_by_column( ctx=message_context_factory(topic=topic.name), ) - with internal_consumer_factory(auto_offset_reset="earliest") as consumer: - consumer.subscribe([groupby_topic]) - consumed_row = consumer.poll_row(timeout=5.0) + if num_partitions == 1: + post_groupby_branch_result = pre_groupby_branch_result + else: + with internal_producer_factory(auto_offset_reset="earliest") as consumer: + consumer.subscribe([groupby_topic]) + consumed_row = consumer.poll_row(timeout=5.0) + + assert consumed_row + assert consumed_row.topic == groupby_topic.name + assert consumed_row.key == new_key + assert consumed_row.timestamp == orig_timestamp_ms + assert consumed_row.value == value + assert consumed_row.headers == headers + assert pre_groupby_branch_result[0] == ( + value, + orig_key, + orig_timestamp_ms, + headers, + ) - assert consumed_row - assert consumed_row.topic == groupby_topic.name - assert consumed_row.key == new_key - assert 
consumed_row.timestamp == orig_timestamp_ms - assert consumed_row.value == value - assert consumed_row.headers == headers - assert pre_groupby_branch_result[0] == ( - value, - orig_key, - orig_timestamp_ms, - headers, - ) + # Check that the value is updated after record passed the groupby + post_groupby_branch_result = sdf.test( + value=value, + key=new_key, + timestamp=orig_timestamp_ms, + headers=headers, + ctx=message_context_factory(topic=groupby_topic.name), + ) - # Check that the value is updated after record passed the groupby - post_groupby_branch_result = sdf.test( - value=value, - key=new_key, - timestamp=orig_timestamp_ms, - headers=headers, - ctx=message_context_factory(topic=groupby_topic.name), - ) assert post_groupby_branch_result[0] == ( {col: col_update}, new_key, @@ -1728,13 +1742,19 @@ def test_group_by_column_with_name( internal_producer_factory, internal_consumer_factory, message_context_factory, + num_partitions, ): """ GroupBy can accept a string (column name) as its grouping method and use a custom name for it (instead of the column name) """ topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) producer = internal_producer_factory() col = "column_A" @@ -1754,8 +1774,11 @@ def test_group_by_column_with_name( sdf[col] = col_update groupby_topic = sdf.topics[0] - assert sdf_registry.consumer_topics == [topic, groupby_topic] - assert groupby_topic.name.startswith("repartition__") + if num_partitions == 1: + assert sdf_registry.consumer_topics == [topic] + else: + assert sdf_registry.consumer_topics == [topic, groupby_topic] + assert groupby_topic.name.startswith("repartition__") with producer: pre_groupby_branch_result = sdf.test( @@ -1767,31 +1790,35 @@ def test_group_by_column_with_name( ctx=message_context_factory(topic=topic.name), ) - with internal_consumer_factory(auto_offset_reset="earliest") as consumer: - consumer.subscribe([groupby_topic]) - consumed_row = consumer.poll_row(timeout=5.0) + if num_partitions == 1: + post_groupby_branch_result = pre_groupby_branch_result + else: + with internal_consumer_factory(auto_offset_reset="earliest") as consumer: + consumer.subscribe([groupby_topic]) + consumed_row = consumer.poll_row(timeout=5.0) + + assert consumed_row + assert consumed_row.topic == groupby_topic.name + assert consumed_row.key == new_key + assert consumed_row.timestamp == orig_timestamp_ms + assert consumed_row.value == value + assert consumed_row.headers == headers + assert pre_groupby_branch_result[0] == ( + value, + orig_key, + orig_timestamp_ms, + headers, + ) - assert consumed_row - assert consumed_row.topic == groupby_topic.name - assert consumed_row.key == new_key - assert consumed_row.timestamp == orig_timestamp_ms - assert consumed_row.value == value - assert consumed_row.headers == headers - assert pre_groupby_branch_result[0] == ( - value, - orig_key, - orig_timestamp_ms, - headers, - ) + # Check that the value is updated after record passed the groupby + post_groupby_branch_result = sdf.test( + value=value, + key=new_key, + timestamp=orig_timestamp_ms, + headers=headers, + ctx=message_context_factory(topic=groupby_topic.name), + ) - # Check that the value is updated after record passed the groupby - post_groupby_branch_result = sdf.test( - value=value, - key=new_key, - timestamp=orig_timestamp_ms, - headers=headers, - 
ctx=message_context_factory(topic=groupby_topic.name), - ) assert post_groupby_branch_result[0] == ( {col: col_update}, new_key, @@ -1806,12 +1833,18 @@ def test_group_by_func( internal_producer_factory, internal_consumer_factory, message_context_factory, + num_partitions, ): """ GroupBy can accept a Callable as its grouping method (requires a name too). """ topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) producer = internal_producer_factory() col = "column_A" @@ -1832,7 +1865,11 @@ def test_group_by_func( sdf[col] = col_update groupby_topic = sdf.topics[0] - assert sdf_registry.consumer_topics == [topic, groupby_topic] + if num_partitions == 1: + assert sdf_registry.consumer_topics == [topic] + else: + assert sdf_registry.consumer_topics == [topic, groupby_topic] + assert groupby_topic.name.startswith("repartition__") with producer: pre_groupby_branch_result = sdf.test( @@ -1844,31 +1881,35 @@ def test_group_by_func( ctx=message_context_factory(topic=topic.name), ) - with internal_consumer_factory(auto_offset_reset="earliest") as consumer: - consumer.subscribe([groupby_topic]) - consumed_row = consumer.poll_row(timeout=5.0) + if num_partitions == 1: + post_groupby_branch_result = pre_groupby_branch_result + else: + with internal_consumer_factory(auto_offset_reset="earliest") as consumer: + consumer.subscribe([groupby_topic]) + consumed_row = consumer.poll_row(timeout=5.0) + + assert consumed_row + assert consumed_row.topic == groupby_topic.name + assert consumed_row.key == new_key + assert consumed_row.timestamp == orig_timestamp_ms + assert consumed_row.value == value + assert consumed_row.headers == headers + assert pre_groupby_branch_result[0] == ( + value, + orig_key, + orig_timestamp_ms, + headers, + ) - assert consumed_row - assert consumed_row.topic == groupby_topic.name - assert consumed_row.key == new_key - assert consumed_row.timestamp == orig_timestamp_ms - assert consumed_row.value == value - assert consumed_row.headers == headers - assert pre_groupby_branch_result[0] == ( - value, - orig_key, - orig_timestamp_ms, - headers, - ) + # Check that the value is updated after record passed the groupby + post_groupby_branch_result = sdf.test( + value=value, + key=new_key, + timestamp=orig_timestamp_ms, + headers=headers, + ctx=message_context_factory(topic=groupby_topic.name), + ) - # Check that the value is updated after record passed the groupby - post_groupby_branch_result = sdf.test( - value=value, - key=new_key, - timestamp=orig_timestamp_ms, - headers=headers, - ctx=message_context_factory(topic=groupby_topic.name), - ) assert post_groupby_branch_result[0] == ( {col: col_update}, new_key, @@ -1876,46 +1917,93 @@ def test_group_by_func( headers, ) - def test_group_by_func_name_missing(self, dataframe_factory, topic_manager_factory): + def test_group_by_func_name_missing( + self, dataframe_factory, topic_manager_factory, num_partitions + ): """Using a Callable for groupby requires giving a name""" topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) sdf = dataframe_factory(topic, topic_manager=topic_manager) with pytest.raises(ValueError): sdf.group_by(lambda v: "do_stuff") - def test_group_by_key_empty_fails(self, 
dataframe_factory, topic_manager_factory): + def test_group_by_key_empty_fails( + self, dataframe_factory, topic_manager_factory, num_partitions + ): """Using a Callable for groupby requires giving a name""" topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) sdf = dataframe_factory(topic, topic_manager=topic_manager) with pytest.raises(ValueError, match='Parameter "key" cannot be empty'): sdf.group_by(key="") - def test_group_by_invalid_key_func(self, dataframe_factory, topic_manager_factory): + def test_group_by_invalid_key_func( + self, dataframe_factory, topic_manager_factory, num_partitions + ): """GroupBy can only use a string (column name) or Callable to group with""" topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) sdf = dataframe_factory(topic, topic_manager=topic_manager) with pytest.raises(ValueError): sdf.group_by({"um": "what is this"}) - def test_group_by_limit_exceeded(self, dataframe_factory, topic_manager_factory): + def test_group_by_limit_exceeded( + self, dataframe_factory, topic_manager_factory, num_partitions + ): """ Only 1 GroupBy depth per SDF (no nesting of them). """ topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) sdf = dataframe_factory(topic, topic_manager=topic_manager) sdf = sdf.group_by("col_a") with pytest.raises(GroupByNestingLimit): sdf.group_by("col_b") - def test_group_by_name_clash(self, dataframe_factory, topic_manager_factory): + def test_group_by_branching( + self, dataframe_factory, topic_manager_factory, num_partitions + ): + """ + GroupBy can be branched. + """ + topic_manager = topic_manager_factory() + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) + sdf = dataframe_factory(topic, topic_manager=topic_manager) + sdf.group_by("col_a") + sdf.group_by("col_b") + + def test_group_by_name_clash( + self, dataframe_factory, topic_manager_factory, num_partitions + ): """ Each groupby operation per SDF instance (or, what appears to end users as a "single" SDF instance) should be uniquely named. @@ -1923,19 +2011,31 @@ def test_group_by_name_clash(self, dataframe_factory, topic_manager_factory): Most likely to encounter this if group by is used with the same column name. """ topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) sdf = dataframe_factory(topic, topic_manager=topic_manager) sdf.group_by("col_a") with pytest.raises(GroupByDuplicate): sdf.group_by("col_a") - def test_sink_cannot_be_added_to(self, dataframe_factory, topic_manager_factory): + def test_sink_cannot_be_added_to( + self, dataframe_factory, topic_manager_factory, num_partitions + ): """ A sink cannot be added to or branched. 
""" topic_manager = topic_manager_factory() - topic = topic_manager.topic(str(uuid.uuid4())) + topic = topic_manager.topic( + str(uuid.uuid4()), + create_config=TopicConfig( + num_partitions=num_partitions, replication_factor=1 + ), + ) sdf = dataframe_factory(topic, topic_manager=topic_manager) assert len(sdf.stream.children) == 0 sdf.sink(DummySink()) From dd1a36b1211c417c83f6d1ef16cf2e23fddc103a Mon Sep 17 00:00:00 2001 From: Quentin Dawans Date: Thu, 17 Apr 2025 12:33:41 +0200 Subject: [PATCH 2/2] requested by review --- quixstreams/dataframe/registry.py | 5 +++-- .../test_quixstreams/test_dataframe/test_dataframe.py | 11 +++++++---- 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/quixstreams/dataframe/registry.py b/quixstreams/dataframe/registry.py index 1d534336c..dd7138e0b 100644 --- a/quixstreams/dataframe/registry.py +++ b/quixstreams/dataframe/registry.py @@ -79,6 +79,7 @@ def register_groupby( Register a "groupby" SDF, which is one generated with `SDF.group_by()`. :param source_sdf: the SDF used by `sdf.group_by()` :param new_sdf: the SDF generated by `sdf.group_by()`. + :param register_new_root: whether to register the new SDF as a root SDF. """ if source_sdf.stream_id in self._repartition_origins: raise GroupByNestingLimit( @@ -87,7 +88,7 @@ def register_groupby( if new_sdf.stream_id in self._repartition_origins: raise GroupByDuplicate( - "A `SDF.group_by()` operation appears to be the same as another, " + "An `SDF.group_by()` operation appears to be the same as another, " "either from using the same column or name parameter; " "adjust by setting a unique name with `SDF.group_by(name=)` " ) @@ -99,7 +100,7 @@ def register_groupby( self.register_root(new_sdf) except StreamingDataFrameDuplicate: raise GroupByDuplicate( - "A `SDF.group_by()` operation appears to be the same as another, " + "An `SDF.group_by()` operation appears to be the same as another, " "either from using the same column or name parameter; " "adjust by setting a unique name with `SDF.group_by(name=)` " ) diff --git a/tests/test_quixstreams/test_dataframe/test_dataframe.py b/tests/test_quixstreams/test_dataframe/test_dataframe.py index 7041dca1a..709cbe610 100644 --- a/tests/test_quixstreams/test_dataframe/test_dataframe.py +++ b/tests/test_quixstreams/test_dataframe/test_dataframe.py @@ -1682,10 +1682,11 @@ def test_group_by_column( sdf = sdf.group_by(col) sdf[col] = col_update - groupby_topic = sdf.topics[0] if num_partitions == 1: + groupby_topic = topic assert sdf_registry.consumer_topics == [topic] else: + groupby_topic = sdf.topics[0] assert sdf_registry.consumer_topics == [topic, groupby_topic] assert groupby_topic.name.startswith("repartition__") @@ -1702,7 +1703,7 @@ def test_group_by_column( if num_partitions == 1: post_groupby_branch_result = pre_groupby_branch_result else: - with internal_producer_factory(auto_offset_reset="earliest") as consumer: + with internal_consumer_factory(auto_offset_reset="earliest") as consumer: consumer.subscribe([groupby_topic]) consumed_row = consumer.poll_row(timeout=5.0) @@ -1773,10 +1774,11 @@ def test_group_by_column_with_name( sdf = sdf.group_by(col, name=op_name) sdf[col] = col_update - groupby_topic = sdf.topics[0] if num_partitions == 1: + groupby_topic = topic assert sdf_registry.consumer_topics == [topic] else: + groupby_topic = sdf.topics[0] assert sdf_registry.consumer_topics == [topic, groupby_topic] assert groupby_topic.name.startswith("repartition__") @@ -1864,10 +1866,11 @@ def test_group_by_func( sdf = sdf.group_by(lambda v: v[col], 
name=op_name) sdf[col] = col_update - groupby_topic = sdf.topics[0] if num_partitions == 1: + groupby_topic = topic assert sdf_registry.consumer_topics == [topic] else: + groupby_topic = sdf.topics[0] assert sdf_registry.consumer_topics == [topic, groupby_topic] assert groupby_topic.name.startswith("repartition__")
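
Usage sketch (illustrative only, not part of the patch): with a single-partition input
topic, `group_by()` now re-keys records in place instead of routing them through a
`repartition__*` topic. The broker address, topic name, and "user_id" column below are
placeholders, and the exact `Application.run()` signature may vary between quixstreams
versions.

    from quixstreams import Application
    from quixstreams.models import TopicConfig

    app = Application(broker_address="localhost:9092")

    # Single-partition topic: group_by() takes the optimized path and does not
    # create a repartition topic, since every message lands on the same partition.
    orders = app.topic(
        "orders",
        value_deserializer="json",
        config=TopicConfig(num_partitions=1, replication_factor=1),
    )

    sdf = app.dataframe(topic=orders)
    sdf = sdf.group_by("user_id")  # records are re-keyed by the "user_id" column in place
    sdf = sdf.update(lambda value: print(value))  # downstream processing as usual

    if __name__ == "__main__":
        app.run()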