From 4e91230f055587f0d5b768a96b426e06defaee6d Mon Sep 17 00:00:00 2001
From: Andrew Morgan
Date: Tue, 23 Sep 2025 11:38:53 +0100
Subject: [PATCH 1/6] Clip if the total is 100 rows, not only >100

This fixes the case where every state delta in
`current_state_delta_stream` has a count of 1, meaning `total` will be
100. Before this change, that would result in
`clipped_stream_id = max_stream_id`, meaning we'd potentially pull out
millions of rows.
---
 synapse/storage/databases/main/state_deltas.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 303b232d7b5..0c04527f4bf 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -133,7 +133,8 @@ def get_current_state_deltas_txn(
 
             for stream_id, count in txn:
                 total += count
-                if total > 100:
+
+                if total >= 100:
                     # We arbitrarily limit to 100 entries to ensure we don't
                     # select toooo many.
                     logger.debug(

From 6557703762768b2d0520ffe5c46b1a1618ed0f0d Mon Sep 17 00:00:00 2001
From: Andrew Morgan
Date: Tue, 23 Sep 2025 11:45:10 +0100
Subject: [PATCH 2/6] newsfile

---
 changelog.d/18960.bugfix | 1 +
 1 file changed, 1 insertion(+)
 create mode 100644 changelog.d/18960.bugfix

diff --git a/changelog.d/18960.bugfix b/changelog.d/18960.bugfix
new file mode 100644
index 00000000000..909089f8092
--- /dev/null
+++ b/changelog.d/18960.bugfix
@@ -0,0 +1 @@
+Fix a bug in the database function for fetching state deltas that could result in unnecessarily long query times.
\ No newline at end of file

From e77135130cea7d0b03c1b078d12646d06bd2e090 Mon Sep 17 00:00:00 2001
From: Andrew Morgan
Date: Tue, 23 Sep 2025 12:55:59 +0100
Subject: [PATCH 3/6] Add regression unit test

---
 tests/storage/test_state.py | 98 +++++++++++++++++++++++++++++++++++++
 1 file changed, 98 insertions(+)

diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bf6da715493..52f26e653e5 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -643,3 +643,101 @@ def test_batched_state_group_storing(self) -> None:
             ),
         )
         self.assertEqual(context.state_group_before_event, groups[0][0])
+
+
+class CurrentStateDeltaStreamTestCase(HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
+        self.store = hs.get_datastores().main
+        self.storage = hs.get_storage_controllers()
+        self.state_datastore = self.storage.state.stores.state
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.event_builder_factory = hs.get_event_builder_factory()
+
+        # Create a made-up room and a user.
+        self.alice_user_id = UserID.from_string("@alice:test")
+        self.room = RoomID.from_string("!abc1234:test")
+
+        self.get_success(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
+        )
+
+    def inject_state_event(
+        self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+    ) -> EventBase:
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": typ,
+                "sender": sender.to_string(),
+                "state_key": state_key,
+                "room_id": room.to_string(),
+                "content": content,
+            },
+        )
+
+        event, unpersisted_context = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+
+        context = self.get_success(unpersisted_context.persist(event))
+
+        assert self.storage.persistence is not None
+        self.get_success(self.storage.persistence.persist_event(event, context))
+
+        return event
+
+    def test_current_state_delta_stream_is_limited_to_100(self) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` returns max 100 rows.
+
+        Regression test for https://github.com/element-hq/synapse/pull/18960.
+        """
+        # Fetch the current state group to pass into store_state_deltas_for_batched
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        # Make >100 state changes in the room.
+        for i in range(101):
+            self.inject_state_event(
+                self.room,
+                self.alice_user_id,
+                EventTypes.Name,
+                "",
+                {"name": f"rename #{i}"},
+            )
+
+        # Sanity check: count how many rows exist between prev=0 and max, it should be >100.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+
+        # Call the function under test. With the >= 100 clipping fix, this must return <= 100 rows.
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=0, max_stream_id=max_stream_id
+            )
+        )
+
+        self.assertLessEqual(
+            len(deltas), 100, f"Returned {len(deltas)} rows, expected at most 100"
+        )
+
+        # Advancing from the clipped point should eventually drain the remainder.
+        # Make sure we make progress and don’t get stuck.
+        if deltas:
+            next_prev = clipped_stream_id
+            next_clipped, next_deltas = self.get_success(
+                self.store.get_partial_current_state_deltas(
+                    prev_stream_id=next_prev, max_stream_id=max_stream_id
+                )
+            )
+            self.assertNotEqual(
+                next_clipped, clipped_stream_id, "Did not advance clipped_stream_id"
+            )
+            # Still should respect the 100-row limit.
+            self.assertLessEqual(len(next_deltas), 100)

From 24ff9f1b512492a32b12505193d80f305de8b0bf Mon Sep 17 00:00:00 2001
From: Andrew Morgan
Date: Thu, 25 Sep 2025 16:41:49 +0100
Subject: [PATCH 4/6] Add a `limit` parameter to speed up unit tests

Credit to @reivilibre for the idea.
---
 synapse/storage/controllers/state.py          |  2 +-
 .../storage/databases/main/state_deltas.py    | 17 ++++++------
 tests/storage/test_state.py                   | 27 ++++++++++---------
 3 files changed, 24 insertions(+), 22 deletions(-)

diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index ad90a1be13a..c67921ca484 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -689,7 +689,7 @@ async def get_current_state_deltas(
         # https://github.com/matrix-org/synapse/issues/13008
 
         return await self.stores.main.get_partial_current_state_deltas(
-            prev_stream_id, max_stream_id
+            prev_stream_id, max_stream_id, limit=100
         )
 
     @trace
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 0c04527f4bf..ef32b962726 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -78,7 +78,7 @@ def __init__(
         )
 
     async def get_partial_current_state_deltas(
-        self, prev_stream_id: int, max_stream_id: int
+        self, prev_stream_id: int, max_stream_id: int, limit: int = 100
     ) -> Tuple[int, List[StateDelta]]:
         """Fetch a list of room state changes since the given stream id
 
@@ -88,14 +88,13 @@ async def get_partial_current_state_deltas(
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
+            limit: the maximum number of rows to return.
 
         Returns:
             A tuple consisting of:
                 - the stream id which these results go up to
                 - list of current_state_delta_stream rows. If it is empty, we are
                   up to date.
-
-        A maximum of 100 rows will be returned.
         """
         prev_stream_id = int(prev_stream_id)
 
@@ -117,25 +116,25 @@ def get_current_state_deltas_txn(
         ) -> Tuple[int, List[StateDelta]]:
             # First we calculate the max stream id that will give us less than
             # N results.
-            # We arbitrarily limit to 100 stream_id entries to ensure we don't
-            # select toooo many.
+            # We limit the number of returned stream_id entries to ensure we
+            # don't select toooo many.
             sql = """
                 SELECT stream_id, count(*)
                 FROM current_state_delta_stream
                 WHERE stream_id > ? AND stream_id <= ?
                 GROUP BY stream_id
                 ORDER BY stream_id ASC
-                LIMIT 100
+                LIMIT ?
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id))
+            txn.execute(sql, (prev_stream_id, max_stream_id, limit))
 
             total = 0
 
             for stream_id, count in txn:
                 total += count
 
-                if total >= 100:
-                    # We arbitrarily limit to 100 entries to ensure we don't
-                    # select toooo many.
+                if total >= limit:
+                    # We limit the number of returned entries to ensure we don't
+                    # select toooo many.
                     logger.debug(
                         "Clipping current_state_delta_stream rows to stream_id %i",
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 52f26e653e5..737118d482b 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -694,17 +694,20 @@ def inject_state_event(
 
         return event
 
     def test_current_state_delta_stream_is_limited_to_100(self) -> None:
         """
-        Tests that `get_partial_current_state_deltas` returns max 100 rows.
+        Tests that `get_partial_current_state_deltas` returns at most `limit` rows.
 
         Regression test for https://github.com/element-hq/synapse/pull/18960.
         """
-        # Fetch the current state group to pass into store_state_deltas_for_batched
+        # Inject a create event which other events can auth with.
         self.inject_state_event(
             self.room, self.alice_user_id, EventTypes.Create, "", {}
         )
 
-        # Make >100 state changes in the room.
-        for i in range(101):
+        limit = 2
+
+        # Make `limit * 2` state changes in the room, resulting in `limit * 2 + 1`
+        # total state events (including the create event) in the room.
+        for i in range(limit * 2):
             self.inject_state_event(
                 self.room,
                 self.alice_user_id,
@@ -713,18 +716,18 @@ def test_current_state_delta_stream_is_limited_to_100(self) -> None:
                 {"name": f"rename #{i}"},
             )
 
-        # Sanity check: count how many rows exist between prev=0 and max, it should be >100.
+        # Call the function under test. This must return <= `limit` rows.
         max_stream_id = self.store.get_room_max_stream_ordering()
-
-        # Call the function under test. With the >= 100 clipping fix, this must return <= 100 rows.
         clipped_stream_id, deltas = self.get_success(
             self.store.get_partial_current_state_deltas(
-                prev_stream_id=0, max_stream_id=max_stream_id
+                prev_stream_id=0,
+                max_stream_id=max_stream_id,
+                limit=limit,
             )
         )
 
         self.assertLessEqual(
-            len(deltas), 100, f"Returned {len(deltas)} rows, expected at most 100"
+            len(deltas), limit, f"Returned {len(deltas)} rows, expected at most {limit}"
         )
 
         # Advancing from the clipped point should eventually drain the remainder.
@@ -733,11 +736,11 @@ def test_current_state_delta_stream_is_limited_to_100(self) -> None:
             next_prev = clipped_stream_id
             next_clipped, next_deltas = self.get_success(
                 self.store.get_partial_current_state_deltas(
-                    prev_stream_id=next_prev, max_stream_id=max_stream_id
+                    prev_stream_id=next_prev, max_stream_id=max_stream_id, limit=limit
                 )
             )
             self.assertNotEqual(
                 next_clipped, clipped_stream_id, "Did not advance clipped_stream_id"
             )
-            # Still should respect the 100-row limit.
-            self.assertLessEqual(len(next_deltas), 100)
+            # Still should respect the limit.
+            self.assertLessEqual(len(next_deltas), limit)

From 0e40bf4589c90294d45489bef9d72744bcbe6c7f Mon Sep 17 00:00:00 2001
From: Andrew Morgan
Date: Thu, 25 Sep 2025 17:11:38 +0100
Subject: [PATCH 5/6] Just use `LIMIT` on the DB

Simplify the queries to just a single one.
---
 .../storage/databases/main/state_deltas.py | 50 ++++++-------------
 1 file changed, 16 insertions(+), 34 deletions(-)

diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index ef32b962726..8dcb0ca49e5 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -98,6 +98,11 @@ async def get_partial_current_state_deltas(
         """
         prev_stream_id = int(prev_stream_id)
 
+        if limit <= 0:
+            raise ValueError(
+                "Invalid `limit` passed to `get_partial_current_state_deltas`"
+            )
+
         # check we're not going backwards
         assert prev_stream_id <= max_stream_id, (
             f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
         )
@@ -114,46 +119,23 @@ def get_current_state_deltas_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, List[StateDelta]]:
-            # First we calculate the max stream id that will give us less than
-            # N results.
-            # We limit the number of returned stream_id entries to ensure we
-            # don't select toooo many.
             sql = """
-                SELECT stream_id, count(*)
+                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
                 FROM current_state_delta_stream
-                WHERE stream_id > ? AND stream_id <= ?
-                GROUP BY stream_id
+                WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
             txn.execute(sql, (prev_stream_id, max_stream_id, limit))
+            rows = txn.fetchall()
+
+            # In the case that we hit the given `limit` rather than fetching the
+            # most recent rows, return the `stream_id` of the last row.
+            #
+            # With this, the caller knows from what stream_id to call this
+            # function again with.
+            clipped_stream_id = rows[-1][0]
-
-            total = 0
-
-            for stream_id, count in txn:
-                total += count
-
-                if total >= limit:
-                    # We limit the number of returned entries to ensure we don't
-                    # select toooo many.
-                    logger.debug(
-                        "Clipping current_state_delta_stream rows to stream_id %i",
-                        stream_id,
-                    )
-                    clipped_stream_id = stream_id
-                    break
-            else:
-                # if there's no problem, we may as well go right up to the max_stream_id
-                clipped_stream_id = max_stream_id
-
-            # Now actually get the deltas
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC
-            """
-            txn.execute(sql, (prev_stream_id, clipped_stream_id))
+
             return clipped_stream_id, [
                 StateDelta(
                     stream_id=row[0],
@@ -163,7 +145,7 @@ def get_current_state_deltas_txn(
                     event_id=row[4],
                     prev_event_id=row[5],
                 )
-                for row in txn.fetchall()
+                for row in rows
             ]
 
         return await self.db_pool.runInteraction(

From db124034f02fe8240768b7352409ee14d6a11ee0 Mon Sep 17 00:00:00 2001
From: Andrew Morgan
Date: Fri, 26 Sep 2025 18:58:32 +0100
Subject: [PATCH 6/6] Rework query to get max stream ID to avoid using a
 python loop

We now:

1. Group state deltas by `stream_id` and get their count
2. Count (in the DB) until we go over our limit.

We then get the `stream_id` we could naively go up to, as well as the
clamped `stream_id` which would keep us under our limit.

The second query fetches rows up to and including the clamped
`stream_id`.

We also add a unit test that injects multiple state deltas with the same
`stream_id`, which correctly failed when tested against the previous
implementation in this PR.
---
 .../storage/databases/main/state_deltas.py |  83 +++++++++++--
 tests/storage/test_state.py                | 117 +++++++++++++++++-
 2 files changed, 186 insertions(+), 14 deletions(-)

diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 8dcb0ca49e5..13acbeeebb6 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -80,10 +80,19 @@ def __init__(
         )
 
     async def get_partial_current_state_deltas(
         self, prev_stream_id: int, max_stream_id: int, limit: int = 100
     ) -> Tuple[int, List[StateDelta]]:
-        """Fetch a list of room state changes since the given stream id
+        """Fetch a list of room state changes since the given stream id.
 
         This may be the partial state if we're lazy joining the room.
 
+        This method takes care to handle state deltas that share the same
+        `stream_id`. That can happen when persisting state in a batch,
+        potentially as the result of state resolution (both adding new state and
+        undoing previous state).
+
+        State deltas are grouped by `stream_id`. When hitting the given `limit`
+        would return only part of a "group" of state deltas, that entire group
+        is omitted. Thus, this function may return *up to* `limit` state deltas.
+
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
             limit: the maximum number of rows to return.
 
         Returns:
             A tuple consisting of:
                 - the stream id which these results go up to
                 - list of current_state_delta_stream rows. If it is empty, we are
                   up to date.
         """
         prev_stream_id = int(prev_stream_id)
@@ -119,23 +128,71 @@ def get_current_state_deltas_txn(
         def get_current_state_deltas_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, List[StateDelta]]:
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+            # First we group state deltas by `stream_id` and calculate the
+            # stream id that will give us the largest number of state deltas
+            # (under the provided `limit`) without splitting any group up.
+
+            # 1) Figure out which stream_id groups fit within `limit`
+            # and whether we consumed everything up to max_stream_id.
+            sql_meta = """
+                WITH grouped AS (
+                SELECT stream_id, COUNT(*) AS c
                 FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC
+                WHERE stream_id > ? AND stream_id <= ?
+                GROUP BY stream_id
+                ORDER BY stream_id
                 LIMIT ?
+                ),
+                accum AS (
+                    SELECT
+                        stream_id,
+                        c,
+                        SUM(c) OVER (ORDER BY stream_id) AS running
+                    FROM grouped
+                ),
+                included AS (
+                    SELECT stream_id, running
+                    FROM accum
+                    WHERE running <= ?
+                )
+                SELECT
+                    COALESCE((SELECT SUM(c) FROM grouped), 0) AS total_rows,
+                    COALESCE((SELECT MAX(running) FROM included), 0) AS included_rows,
+                    COALESCE((SELECT MAX(stream_id) FROM included), ?) AS last_included_sid
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id, limit))
+            txn.execute(
+                sql_meta, (prev_stream_id, max_stream_id, limit, limit, prev_stream_id)
+            )
+            total_rows, included_rows, last_included_sid = txn.fetchone()  # type: ignore
+
+            if total_rows == 0:
+                # Nothing to return in the range; we are up to date through max_stream_id.
+                return max_stream_id, []
+
+            if included_rows == 0:
+                # The first group itself would exceed the limit. Return nothing
+                # and do not advance beyond prev_stream_id.
+                #
+                # TODO: In this case, we should return *more* than the given `limit`.
+                # Otherwise we'll either deadlock the caller (they'll keep calling us
+                # with the same prev_stream_id) or make the caller think there are no
+                # more rows to consume (when there are).
+                return prev_stream_id, []
+
+            # If we included every row in the range, and the `grouped` CTE was not
+            # truncated by its LIMIT (guaranteed when total_rows < limit), we can
+            # safely report progress all the way to max_stream_id.
+            consumed_all = included_rows == total_rows and total_rows < limit
+            clipped_stream_id = max_stream_id if consumed_all else last_included_sid
+
+            # 2) Fetch the actual rows for only the included stream_id groups.
+            sql_rows = """
+                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
+            """
+            txn.execute(sql_rows, (prev_stream_id, clipped_stream_id))
             rows = txn.fetchall()
 
-            # In the case that we hit the given `limit` rather than fetching the
-            # most recent rows, return the `stream_id` of the last row.
-            #
-            # With this, the caller knows from what stream_id to call this
-            # function again with.
-            clipped_stream_id = rows[-1][0]
-
             return clipped_stream_id, [
                 StateDelta(
                     stream_id=row[0],
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index 737118d482b..73d7d68a826 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -19,6 +19,7 @@
 #
 #
 
+import json
 import logging
 from typing import List, Tuple, cast
 
@@ -33,6 +34,7 @@
 from synapse.types import JsonDict, RoomID, StateMap, UserID
 from synapse.types.state import StateFilter
 from synapse.util.clock import Clock
+from synapse.util.stringutils import random_string
 
 from tests.unittest import HomeserverTestCase
 
@@ -692,7 +694,7 @@ def inject_state_event(
 
         return event
 
-    def test_current_state_delta_stream_is_limited_to_100(self) -> None:
+    def test_get_partial_current_state_deltas_limit(self) -> None:
         """
         Tests that `get_partial_current_state_deltas` returns at most `limit` rows.
 
         Regression test for https://github.com/element-hq/synapse/pull/18960.
@@ -744,3 +746,116 @@ def test_get_partial_current_state_deltas_limit(self) -> None:
             )
             # Still should respect the limit.
             self.assertLessEqual(len(next_deltas), limit)
+
+    def test_non_unique_stream_ids_in_current_state_delta_stream(self) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` always returns entire
+        groups of state deltas (grouped by `stream_id`), and never part of one.
+
+        We check by passing a `limit` to the function that, if followed
+        blindly, would split a group of state deltas that share a `stream_id`.
+        The test passes if that group is not returned at all (because doing so
+        would overshoot the limit of returned state deltas).
+
+        Regression test for https://github.com/element-hq/synapse/pull/18960.
+        """
+        # Inject a create event to start with.
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        # Then inject one "real" m.room.name event. This will give us a stream_id
+        # that we can create some more (fake) events with.
+        self.inject_state_event(
+            self.room,
+            self.alice_user_id,
+            EventTypes.Name,
+            "",
+            {"name": "rename #1"},
+        )
+
+        # Get the stream_id of the last-inserted event.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+
+        # Make 3 more state changes in the room, resulting in 5 total state
+        # events (including the create event, and the first name update) in
+        # the room.
+        #
+        # All of these state deltas have the same `stream_id`. Do so by editing
+        # the table directly as that's the simplest way to have all share the
+        # same `stream_id`.
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "current_state_delta_stream",
+                keys=(
+                    "stream_id",
+                    "room_id",
+                    "type",
+                    "state_key",
+                    "event_id",
+                    "prev_event_id",
+                    "instance_name",
+                ),
+                values=[
+                    (
+                        max_stream_id,
+                        self.room.to_string(),
+                        "m.room.name",
+                        "",
+                        f"${random_string(5)}:test",
+                        json.dumps({"name": f"rename #{i}"}),
+                        "master",
+                    )
+                    for i in range(3)
+                ],
+                desc="inject_room_name_state_events",
+            )
+        )
+
+        # Call the function under test with a limit of 4. Without the limit, we
+        # would return 5 state deltas:
+        #
+        # C N N N N
+        # 1 2 3 4 5
+        #
+        # C = m.room.create
+        # N = m.room.name
+        #
+        # With the limit, we should return only the create event, as returning 4
+        # state deltas would result in splitting a group:
+        #
+        # C N N N N
+        # 1 2 3 4 X
+
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=0,
+                max_stream_id=max_stream_id,
+                limit=4,
+            )
+        )
+
+        # 2 is the stream ID of the m.room.create event.
+        self.assertEqual(clipped_stream_id, 2)
+        self.assertEqual(
+            len(deltas),
+            1,
+            f"Returned {len(deltas)} rows, expected only one (the create event): {deltas}",
+        )
+
+        # Advance once more with our limit of 4. We should now get all 4
+        # `m.room.name` state deltas as they can fit under the limit.
+        clipped_stream_id, next_deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=clipped_stream_id, max_stream_id=max_stream_id, limit=4
+            )
+        )
+        self.assertEqual(
+            clipped_stream_id, 3
+        )  # The stream ID of the 4 m.room.name events.
+
+        self.assertEqual(
+            len(next_deltas),
+            4,
+            f"Returned {len(next_deltas)} rows, expected all 4 m.room.name events: {next_deltas}",
+        )
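
A note for reviewers on the grouping semantics introduced in PATCH 6/6: the
function now only ever returns whole `stream_id` groups, up to `limit` rows in
total. The sketch below is a minimal, self-contained Python model of that
behaviour (`clip_deltas` and its `(stream_id, row)` pair representation are
hypothetical illustrations, not Synapse APIs), which may help when reasoning
about the unit tests above:

    from itertools import groupby
    from typing import List, Tuple

    def clip_deltas(
        deltas: List[Tuple[int, str]], limit: int
    ) -> Tuple[List[Tuple[int, str]], bool]:
        """Return the longest prefix of `deltas` (assumed sorted by stream_id)
        that fits within `limit` without splitting a stream_id group, plus
        whether every delta was consumed."""
        included: List[Tuple[int, str]] = []
        for _stream_id, group in groupby(deltas, key=lambda d: d[0]):
            rows = list(group)
            if len(included) + len(rows) > limit:
                # Taking only part of this group would split it: stop here.
                break
            included.extend(rows)
        return included, len(included) == len(deltas)

    # Mirrors test_non_unique_stream_ids_in_current_state_delta_stream: one
    # create event at stream_id 2, then four m.room.name deltas all sharing
    # stream_id 3, clipped with limit=4. Only the create event is returned,
    # since returning 4 rows would split the stream_id-3 group.
    deltas = [(2, "create")] + [(3, f"name-{i}") for i in range(4)]
    rows, consumed_all = clip_deltas(deltas, limit=4)
    assert [r[1] for r in rows] == ["create"]
    assert not consumed_all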