
Optimize group by for single partition topics #836

Merged 2 commits on May 5, 2025
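For context, here is a minimal sketch of the scenario this PR optimizes, written against the public quixstreams API as exercised by the tests below (`Application`, `TopicConfig`, `StreamingDataFrame.group_by()`). The broker address and topic name are placeholders, not taken from the PR:

```python
from quixstreams import Application
from quixstreams.models import TopicConfig

app = Application(broker_address="localhost:9092")  # placeholder broker

# A single-partition input topic: with this change, group_by() re-keys messages
# in place instead of routing them through an internal repartition topic.
users_topic = app.topic(
    "users",  # placeholder topic name
    value_deserializer="json",
    config=TopicConfig(num_partitions=1, replication_factor=1),
)

sdf = app.dataframe(topic=users_topic)
by_user = sdf.group_by("user")  # re-keys by value["user"]; no repartition topic needed

app.run()
```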
41 changes: 38 additions & 3 deletions quixstreams/dataframe/dataframe.py
@@ -92,7 +92,7 @@ class StreamingDataFrame:
What it Does:

- Builds a data processing pipeline, declaratively (not executed immediately)
- Executes this pipeline on inputs at runtime (Kafka message values)
- Provides functions/interface similar to Pandas Dataframes/Series
- Enables stateful processing (and manages everything related to it)

@@ -135,15 +135,20 @@ def __init__(
registry: DataFrameRegistry,
processing_context: ProcessingContext,
stream: Optional[Stream] = None,
stream_id: Optional[str] = None,
):
if not topics:
raise ValueError("At least one Topic must be passed")

self._stream: Stream = stream or Stream()
# Implicitly deduplicate Topic objects into a tuple and sort them by name
self._topics: tuple[Topic, ...] = tuple(
sorted({t.name: t for t in topics}.values(), key=attrgetter("name"))
)

self._stream: Stream = stream or Stream()
self._stream_id: str = stream_id or topic_manager.stream_id_from_topics(
self.topics
)
self._topic_manager = topic_manager
self._registry = registry
self._processing_context = processing_context
@@ -174,7 +179,7 @@ def stream_id(self) -> str:

By default, a topic name or a combination of topic names are used as `stream_id`.
"""
return self._topic_manager.stream_id_from_topics(self.topics)
return self._stream_id

@property
def topics(self) -> tuple[Topic, ...]:
@@ -591,6 +596,11 @@ def func(d: dict, state: State):
# Generate a config for the new repartition topic based on the underlying topics
repartition_config = self._topic_manager.derive_topic_config(self._topics)

# If the topic has only one partition, we don't need a repartition topic
# we can directly change the message keys, as they all go to the same partition.
if repartition_config.num_partitions == 1:
return self._single_partition_groupby(operation, key)

groupby_topic = self._topic_manager.repartition_topic(
operation=operation,
stream_id=self.stream_id,
@@ -606,6 +616,29 @@ def func(d: dict, state: State):
self._registry.register_groupby(source_sdf=self, new_sdf=groupby_sdf)
return groupby_sdf

def _single_partition_groupby(
self, operation: str, key: Union[str, Callable[[Any], Any]]
) -> "StreamingDataFrame":
if isinstance(key, str):

def _callback(value, _, timestamp, headers):
return value, value[key], timestamp, headers
else:

def _callback(value, _, timestamp, headers):
return value, key(value), timestamp, headers

stream = self.stream.add_transform(_callback, expand=False)

groupby_sdf = self.__dataframe_clone__(
stream=stream, stream_id=f"{self.stream_id}--groupby--{operation}"
)
self._registry.register_groupby(
source_sdf=self, new_sdf=groupby_sdf, register_new_root=False
)

return groupby_sdf

def contains(self, keys: Union[str, list[str]]) -> StreamingSeries:
"""
Check if keys are present in the Row value.
@@ -1679,6 +1712,7 @@ def __dataframe_clone__(
self,
*topics: Topic,
stream: Optional[Stream] = None,
stream_id: Optional[str] = None,
) -> "StreamingDataFrame":
"""
Clone the StreamingDataFrame with a new `stream`, `topics`,
@@ -1692,6 +1726,7 @@ def __dataframe_clone__(
clone = self.__class__(
*(topics or self._topics),
stream=stream,
stream_id=stream_id,
processing_context=self._processing_context,
topic_manager=self._topic_manager,
registry=self._registry,
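Usage note (a sketch, not part of the diff): `group_by()` accepts either a column name or a callable, and both forms take the new `_single_partition_groupby()` path when the source topic has a single partition. Continuing the sketch above (these calls would go before `app.run()`), the existing `name=` argument is what keeps the derived stream_id, roughly `f"{stream_id}--groupby--{name}"`, distinct when the key is a callable:

```python
# A second group-by branching off the same source SDF; branching is allowed,
# only nesting (calling group_by on an already-grouped SDF) is not.
by_account = sdf.group_by(
    lambda value: value["account"],  # callable key: new message key = value["account"]
    name="account",  # used in the derived "<stream_id>--groupby--account" stream_id
)
```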
24 changes: 19 additions & 5 deletions quixstreams/dataframe/registry.py
@@ -70,27 +70,41 @@ def register_root(
self._registry[topic.name] = dataframe.stream

def register_groupby(
self, source_sdf: "StreamingDataFrame", new_sdf: "StreamingDataFrame"
self,
source_sdf: "StreamingDataFrame",
new_sdf: "StreamingDataFrame",
register_new_root: bool = True,
):
"""
Register a "groupby" SDF, which is one generated with `SDF.group_by()`.
:param source_sdf: the SDF used by `sdf.group_by()`
:param new_sdf: the SDF generated by `sdf.group_by()`.
:param register_new_root: whether to register the new SDF as a root SDF.
"""
if source_sdf.stream_id in self._repartition_origins:
raise GroupByNestingLimit(
"Subsequent (nested) `SDF.group_by()` operations are not allowed."
)
try:
self.register_root(new_sdf)
except StreamingDataFrameDuplicate:

if new_sdf.stream_id in self._repartition_origins:
raise GroupByDuplicate(
"A `SDF.group_by()` operation appears to be the same as another, "
"An `SDF.group_by()` operation appears to be the same as another, "
"either from using the same column or name parameter; "
"adjust by setting a unique name with `SDF.group_by(name=<NAME>)` "
)

self._repartition_origins.add(new_sdf.stream_id)

if register_new_root:
try:
self.register_root(new_sdf)
except StreamingDataFrameDuplicate:
raise GroupByDuplicate(
"An `SDF.group_by()` operation appears to be the same as another, "
"either from using the same column or name parameter; "
"adjust by setting a unique name with `SDF.group_by(name=<NAME>)` "
)

def compose_all(
self, sink: Optional[VoidExecutor] = None
) -> dict[str, VoidExecutor]:
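A sketch of how the registry guards behave on the optimized path (assuming `GroupByNestingLimit` and `GroupByDuplicate` are importable from `quixstreams.dataframe.exceptions`): the optimized SDF's stream_id is still added to `_repartition_origins`, so nested and duplicate `group_by()` calls are rejected exactly as before; only the `register_root()` call is skipped via `register_new_root=False`.

```python
from quixstreams.dataframe.exceptions import GroupByDuplicate, GroupByNestingLimit

# Continuing the sketch above, where by_user = sdf.group_by("user"):

try:
    by_user.group_by("account")  # nested group_by on an already-grouped SDF
except GroupByNestingLimit:
    pass  # raised whether or not the single-partition optimization applied

try:
    sdf.group_by("user")  # same column, no distinct name -> same stream_id
except GroupByDuplicate:
    pass  # duplicate group-by is rejected on both code paths
```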
47 changes: 37 additions & 10 deletions tests/test_quixstreams/test_app.py
@@ -513,12 +513,14 @@ def test_state_dir_env(self):
assert app.config.state_dir == Path("/path/to/other")


@pytest.mark.parametrize("number_of_partitions", [1, 2])
class TestAppGroupBy:
def test_group_by(
self,
app_factory,
internal_consumer_factory,
executor,
number_of_partitions,
):
"""
Test that StreamingDataFrame processes 6 messages from Kafka and groups them
@@ -540,8 +542,12 @@ def on_message_processed(*_):
timestamp_ms = int(time.time() * 1000)
user_id = "abc123"
value_in = {"user": user_id}
expected_message_count = 1
total_messages = expected_message_count * 2 # groupby reproduces each message

if number_of_partitions == 1:
total_messages = 1 # groupby optimisation for 1 partition
else:
total_messages = 2 # groupby reproduces each message

app = app_factory(
auto_offset_reset="earliest",
on_message_processed=on_message_processed,
@@ -551,6 +557,9 @@ def on_message_processed(*_):
str(uuid.uuid4()),
value_deserializer="json",
value_serializer="json",
config=TopicConfig(
num_partitions=number_of_partitions, replication_factor=1
),
)
app_topic_out = app.topic(
str(uuid.uuid4()),
@@ -607,6 +616,7 @@ def test_group_by_with_window(
internal_consumer_factory,
executor,
processing_guarantee,
number_of_partitions,
):
"""
Test that StreamingDataFrame processes 6 messages from Kafka and groups them
@@ -631,8 +641,12 @@ def on_message_processed(*_):
timestamp_ms = timestamp_ms - (timestamp_ms % window_duration_ms)
user_id = "abc123"
value_in = {"user": user_id}
expected_message_count = 1
total_messages = expected_message_count * 2 # groupby reproduces each message

if number_of_partitions == 1:
total_messages = 1 # groupby optimisation for 1 partition
else:
total_messages = 2 # groupby reproduces each message

app = app_factory(
auto_offset_reset="earliest",
on_message_processed=on_message_processed,
Expand All @@ -643,6 +657,9 @@ def on_message_processed(*_):
str(uuid.uuid4()),
value_deserializer="json",
value_serializer="json",
config=TopicConfig(
num_partitions=number_of_partitions, replication_factor=1
),
)
app_topic_out = app.topic(
str(uuid.uuid4()),
@@ -2380,11 +2397,9 @@ def on_message_processed(topic_, partition, offset):
assert row.timestamp == timestamp_ms
assert row.headers == headers

@pytest.mark.parametrize("number_of_partitions", [1, 2])
def test_group_by(
self,
app_factory,
internal_consumer_factory,
executor,
self, app_factory, internal_consumer_factory, executor, number_of_partitions
):
"""
Test that StreamingDataFrame processes 6 messages from Kafka and groups them
@@ -2411,11 +2426,17 @@ def on_message_processed(*_):
str(uuid.uuid4()),
value_deserializer="json",
value_serializer="json",
config=TopicConfig(
num_partitions=number_of_partitions, replication_factor=1
),
)
input_topic_b = app.topic(
str(uuid.uuid4()),
value_deserializer="json",
value_serializer="json",
config=TopicConfig(
num_partitions=number_of_partitions, replication_factor=1
),
)
input_topics = [input_topic_a, input_topic_b]
output_topic_user = app.topic(
@@ -2433,8 +2454,14 @@ def on_message_processed(*_):
user_id = "abc123"
account_id = "def456"
value_in = {"user": user_id, "account": account_id}
# expected_processed = 1 (input msg per SDF) * 3 (2 groupbys, each reprocesses input) * 2 SDFs
expected_processed = 6

if number_of_partitions == 1:
# expected_processed = 1 (input msg per SDF) * 1 (the 2 optimized groupbys don't reprocess input) * 2 SDFs
expected_processed = 2
else:
# expected_processed = 1 (input msg per SDF) * 3 (2 groupbys, each reprocesses input) * 2 SDFs
expected_processed = 6

expected_output_topic_count = 2

sdf_a = app.dataframe(topic=input_topic_a)
Expand Down
Loading