diff --git a/changelog.d/18960.bugfix b/changelog.d/18960.bugfix
new file mode 100644
index 00000000000..909089f8092
--- /dev/null
+++ b/changelog.d/18960.bugfix
@@ -0,0 +1 @@
+Fix a bug in the database function for fetching state deltas that could result in unnecessarily long query times.
\ No newline at end of file
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index ad90a1be13a..c67921ca484 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -689,7 +689,7 @@ async def get_current_state_deltas(
         # https://github.com/matrix-org/synapse/issues/13008
 
         return await self.stores.main.get_partial_current_state_deltas(
-            prev_stream_id, max_stream_id
+            prev_stream_id, max_stream_id, limit=100
        )
 
     @trace
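For context, callers page through the stream by feeding the returned stream id back in as the next `prev_stream_id`. A minimal sketch of that pattern follows; the `drain_state_deltas` helper and its loop body are invented for illustration and are not part of this patch:

async def drain_state_deltas(store, prev_stream_id: int, max_stream_id: int) -> None:
    # Page through current_state_delta_stream until we are caught up.
    while prev_stream_id < max_stream_id:
        clipped, deltas = await store.get_partial_current_state_deltas(
            prev_stream_id, max_stream_id, limit=100
        )
        if clipped == prev_stream_id and not deltas:
            # No progress: the first stream_id group alone exceeded the
            # limit (see the TODO in state_deltas.py below), so bail out
            # rather than spin forever.
            break
        for delta in deltas:
            ...  # process each StateDelta row
        prev_stream_id = clipped

A return value of `(prev_stream_id, [])` is exactly the no-progress case that the TODO in the next file warns about.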
diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py
index 303b232d7b5..13acbeeebb6 100644
--- a/synapse/storage/databases/main/state_deltas.py
+++ b/synapse/storage/databases/main/state_deltas.py
@@ -78,27 +78,40 @@ def __init__(
         )
 
     async def get_partial_current_state_deltas(
-        self, prev_stream_id: int, max_stream_id: int
+        self, prev_stream_id: int, max_stream_id: int, limit: int = 100
     ) -> Tuple[int, List[StateDelta]]:
-        """Fetch a list of room state changes since the given stream id
+        """Fetch a list of room state changes since the given stream id.
 
         This may be the partial state if we're lazy joining the room.
 
+        This method takes care to handle state deltas that share the same
+        `stream_id`. That can happen when persisting state in a batch,
+        potentially as the result of state resolution (both adding new state and
+        undoing previous state).
+
+        State deltas are grouped by `stream_id`. When hitting the given `limit`
+        would return only part of a "group" of state deltas, that entire group
+        is omitted. Thus, this function may return *up to* `limit` state deltas.
+
         Args:
             prev_stream_id: point to get changes since (exclusive)
             max_stream_id: the point that we know has been correctly persisted
                - ie, an upper limit to return changes from.
+            limit: the maximum number of rows to return.
 
         Returns:
             A tuple consisting of:
                - the stream id which these results go up to
                - list of current_state_delta_stream rows. If it is empty, we are
                  up to date.
-
-        A maximum of 100 rows will be returned.
         """
         prev_stream_id = int(prev_stream_id)
 
+        if limit <= 0:
+            raise ValueError(
+                "Invalid `limit` passed to `get_partial_current_state_deltas`"
+            )
+
         # check we're not going backwards
         assert prev_stream_id <= max_stream_id, (
             f"New stream id {max_stream_id} is smaller than prev stream id {prev_stream_id}"
@@ -115,45 +128,71 @@ async def get_partial_current_state_deltas(
         def get_current_state_deltas_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, List[StateDelta]]:
-            # First we calculate the max stream id that will give us less than
-            # N results.
-            # We arbitrarily limit to 100 stream_id entries to ensure we don't
-            # select toooo many.
-            sql = """
-                SELECT stream_id, count(*)
+            # First we group state deltas by `stream_id` and calculate the
+            # stream id that will give us the most state deltas (under the
+            # provided `limit`) without splitting any group up.
+
+            # 1) Figure out which stream_id groups fit within `limit`
+            #    and whether we consumed everything up to max_stream_id.
+            sql_meta = """
+                WITH grouped AS (
+                    SELECT stream_id, COUNT(*) AS c
                 FROM current_state_delta_stream
                 WHERE stream_id > ? AND stream_id <= ?
                 GROUP BY stream_id
-                ORDER BY stream_id ASC
-                LIMIT 100
+                    ORDER BY stream_id
+                    LIMIT ?
+                ),
+                accum AS (
+                    SELECT
+                        stream_id,
+                        c,
+                        SUM(c) OVER (ORDER BY stream_id) AS running
+                    FROM grouped
+                ),
+                included AS (
+                    SELECT stream_id, running
+                    FROM accum
+                    WHERE running <= ?
+                )
+                SELECT
+                    COALESCE((SELECT SUM(c) FROM grouped), 0) AS total_rows,
+                    COALESCE((SELECT MAX(running) FROM included), 0) AS included_rows,
+                    COALESCE((SELECT MAX(stream_id) FROM included), ?) AS last_included_sid
             """
-            txn.execute(sql, (prev_stream_id, max_stream_id))
-
-            total = 0
-
-            for stream_id, count in txn:
-                total += count
-                if total > 100:
-                    # We arbitrarily limit to 100 entries to ensure we don't
-                    # select toooo many.
-                    logger.debug(
-                        "Clipping current_state_delta_stream rows to stream_id %i",
-                        stream_id,
-                    )
-                    clipped_stream_id = stream_id
-                    break
-            else:
-                # if there's no problem, we may as well go right up to the max_stream_id
-                clipped_stream_id = max_stream_id
-
-            # Now actually get the deltas
-            sql = """
-                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
-                FROM current_state_delta_stream
-                WHERE ? < stream_id AND stream_id <= ?
-                ORDER BY stream_id ASC
+            txn.execute(
+                sql_meta, (prev_stream_id, max_stream_id, limit, limit, prev_stream_id)
+            )
+            total_rows, included_rows, last_included_sid = txn.fetchone()  # type: ignore
+
+            if total_rows == 0:
+                # Nothing to return in the range; we are up to date through max_stream_id.
+                return max_stream_id, []
+
+            if included_rows == 0:
+                # The first group itself would exceed the limit. Return nothing
+                # and do not advance beyond prev_stream_id.
+                #
+                # TODO: In this case, we should return *more* than the given `limit`.
+                # Otherwise we'll either deadlock the caller (they'll keep calling us
+                # with the same prev_stream_id) or make the caller think there's no
+                # more rows to consume (when there are).
+                return prev_stream_id, []
+
+            # If we included every row up to max_stream_id, we can safely report progress to max_stream_id.
+            consumed_all = included_rows == total_rows
+            clipped_stream_id = max_stream_id if consumed_all else last_included_sid
+
+            # 2) Fetch the actual rows for only the included stream_id groups.
+            sql_rows = """
+                SELECT stream_id, room_id, type, state_key, event_id, prev_event_id
+                FROM current_state_delta_stream
+                WHERE ? < stream_id AND stream_id <= ?
+                ORDER BY stream_id ASC
             """
-            txn.execute(sql, (prev_stream_id, clipped_stream_id))
+            txn.execute(sql_rows, (prev_stream_id, clipped_stream_id))
+            rows = txn.fetchall()
+
             return clipped_stream_id, [
                 StateDelta(
                     stream_id=row[0],
@@ -163,7 +202,7 @@ def get_current_state_deltas_txn(
                     event_id=row[4],
                     prev_event_id=row[5],
                 )
-                for row in txn.fetchall()
+                for row in rows
             ]
 
         return await self.db_pool.runInteraction(
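The group-aware clipping implemented by `sql_meta` above can be expressed as a short, self-contained Python sketch. The `clip_to_whole_groups` helper and its sample rows are invented for illustration; it mirrors the query's behaviour of only ever including whole `stream_id` groups:

from itertools import groupby
from typing import List, Tuple

def clip_to_whole_groups(
    rows: List[Tuple[int, str]], prev_stream_id: int, max_stream_id: int, limit: int
) -> Tuple[int, List[Tuple[int, str]]]:
    # Return up to `limit` (stream_id, event_id) rows, never splitting a
    # group of rows that share a stream_id. `rows` must be sorted by stream_id.
    included: List[Tuple[int, str]] = []
    clipped = prev_stream_id
    for stream_id, group in groupby(rows, key=lambda r: r[0]):
        group_rows = list(group)
        if len(included) + len(group_rows) > limit:
            # Taking this whole group would overshoot the limit: stop before
            # it, reporting progress only up to the previous group.
            return clipped, included
        included.extend(group_rows)
        clipped = stream_id
    # Everything fit, so we can report progress all the way to max_stream_id.
    return max_stream_id, included

# One delta at stream 2 (a create event), four deltas sharing stream 3.
rows = [(2, "$create")] + [(3, f"$name{i}") for i in range(4)]
assert clip_to_whole_groups(rows, 0, 3, limit=4) == (2, [(2, "$create")])
assert clip_to_whole_groups(rows[1:], 2, 3, limit=4) == (3, rows[1:])

The two asserts correspond to the two calls exercised by the new regression test below.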
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bf6da715493..73d7d68a826 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -19,6 +19,7 @@
 #
 #
 
+import json
 import logging
 from typing import List, Tuple, cast
 
@@ -33,6 +34,7 @@
 from synapse.types import JsonDict, RoomID, StateMap, UserID
 from synapse.types.state import StateFilter
 from synapse.util.clock import Clock
+from synapse.util.stringutils import random_string
 
 from tests.unittest import HomeserverTestCase
 
@@ -643,3 +645,217 @@ def test_batched_state_group_storing(self) -> None:
                 ),
             )
             self.assertEqual(context.state_group_before_event, groups[0][0])
+
+
+class CurrentStateDeltaStreamTestCase(HomeserverTestCase):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        super().prepare(reactor, clock, hs)
+        self.store = hs.get_datastores().main
+        self.storage = hs.get_storage_controllers()
+        self.state_datastore = self.storage.state.stores.state
+        self.event_creation_handler = hs.get_event_creation_handler()
+        self.event_builder_factory = hs.get_event_builder_factory()
+
+        # Create a made-up room and a user.
+        self.alice_user_id = UserID.from_string("@alice:test")
+        self.room = RoomID.from_string("!abc1234:test")
+
+        self.get_success(
+            self.store.store_room(
+                self.room.to_string(),
+                room_creator_user_id="@creator:text",
+                is_public=True,
+                room_version=RoomVersions.V1,
+            )
+        )
+
+    def inject_state_event(
+        self, room: RoomID, sender: UserID, typ: str, state_key: str, content: JsonDict
+    ) -> EventBase:
+        builder = self.event_builder_factory.for_room_version(
+            RoomVersions.V1,
+            {
+                "type": typ,
+                "sender": sender.to_string(),
+                "state_key": state_key,
+                "room_id": room.to_string(),
+                "content": content,
+            },
+        )
+
+        event, unpersisted_context = self.get_success(
+            self.event_creation_handler.create_new_client_event(builder)
+        )
+
+        context = self.get_success(unpersisted_context.persist(event))
+
+        assert self.storage.persistence is not None
+        self.get_success(self.storage.persistence.persist_event(event, context))
+
+        return event
+
+    def test_get_partial_current_state_deltas_limit(self) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` honours `limit` and returns at most that many rows.
+
+        Regression test for https://github.com/element-hq/synapse/pull/18960.
+        """
+        # Inject a create event which other events can auth with.
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        limit = 2
+
+        # Make `limit * 2` state changes in the room, resulting in `limit * 2 + 1`
+        # total state events (including the create event) in the room.
+        for i in range(limit * 2):
+            self.inject_state_event(
+                self.room,
+                self.alice_user_id,
+                EventTypes.Name,
+                "",
+                {"name": f"rename #{i}"},
+            )
+
+        # Call the function under test. This must return <= `limit` rows.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=0,
+                max_stream_id=max_stream_id,
+                limit=limit,
+            )
+        )
+
+        self.assertLessEqual(
+            len(deltas), limit, f"Returned {len(deltas)} rows, expected at most {limit}"
+        )
+
+        # Advancing from the clipped point should eventually drain the remainder.
+        # Make sure we make progress and don't get stuck.
+        if deltas:
+            next_prev = clipped_stream_id
+            next_clipped, next_deltas = self.get_success(
+                self.store.get_partial_current_state_deltas(
+                    prev_stream_id=next_prev, max_stream_id=max_stream_id, limit=limit
+                )
+            )
+            self.assertNotEqual(
+                next_clipped, clipped_stream_id, "Did not advance clipped_stream_id"
+            )
+            # Still should respect the limit.
+            self.assertLessEqual(len(next_deltas), limit)
+
+    def test_non_unique_stream_ids_in_current_state_delta_stream(self) -> None:
+        """
+        Tests that `get_partial_current_state_deltas` always returns entire
+        groups of state deltas (grouped by `stream_id`), and never part of one.
+
+        We check by passing a `limit` to the function that, if followed
+        blindly, would split a group of state deltas that share a `stream_id`.
+        The test passes if that group is not returned at all (because doing so
+        would overshoot the limit of returned state deltas).
+
+        Regression test for https://github.com/element-hq/synapse/pull/18960.
+        """
+        # Inject a create event to start with.
+        self.inject_state_event(
+            self.room, self.alice_user_id, EventTypes.Create, "", {}
+        )
+
+        # Then inject one "real" m.room.name event. This will give us a stream_id that
+        # we can create some more (fake) events with.
+        self.inject_state_event(
+            self.room,
+            self.alice_user_id,
+            EventTypes.Name,
+            "",
+            {"name": "rename #1"},
+        )
+
+        # Get the stream_id of the last-inserted event.
+        max_stream_id = self.store.get_room_max_stream_ordering()
+
+        # Make 3 more state changes in the room, resulting in 5 total state
+        # events (including the create event, and the first name update) in
+        # the room.
+        #
+        # All of these state deltas have the same `stream_id`. Do so by editing
+        # the table directly as that's the simplest way to have all share the
+        # same `stream_id`.
+        self.get_success(
+            self.store.db_pool.simple_insert_many(
+                "current_state_delta_stream",
+                keys=(
+                    "stream_id",
+                    "room_id",
+                    "type",
+                    "state_key",
+                    "event_id",
+                    "prev_event_id",
+                    "instance_name",
+                ),
+                values=[
+                    (
+                        max_stream_id,
+                        self.room.to_string(),
+                        "m.room.name",
+                        "",
+                        f"${random_string(5)}:test",
+                        json.dumps({"name": f"rename #{i}"}),
+                        "master",
+                    )
+                    for i in range(3)
+                ],
+                desc="inject_room_name_state_events",
+            )
+        )
+
+        # Call the function under test with a limit of 4. Without the limit, we would return
+        # 5 state deltas:
+        #
+        # C N N N N
+        # 1 2 3 4 5
+        #
+        # C = m.room.create
+        # N = m.room.name
+        #
+        # With the limit, we should return only the create event, as returning 4
+        # state deltas would result in splitting a group:
+        #
+        # C N N N N
+        # 1 2 3 4 X
+
+        clipped_stream_id, deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=0,
+                max_stream_id=max_stream_id,
+                limit=4,
+            )
+        )
+
+        # 2 is the stream ID of the m.room.create event.
+        self.assertEqual(clipped_stream_id, 2)
+        self.assertEqual(
+            len(deltas),
+            1,
+            f"Returned {len(deltas)} rows, expected only one (the create event): {deltas}",
+        )
+
+        # Advance once more with our limit of 4. We should now get all 4
+        # `m.room.name` state deltas as they can fit under the limit.
+        clipped_stream_id, next_deltas = self.get_success(
+            self.store.get_partial_current_state_deltas(
+                prev_stream_id=clipped_stream_id, max_stream_id=max_stream_id, limit=4
+            )
+        )
+        self.assertEqual(
+            clipped_stream_id, 3
+        )  # The stream ID of the 4 m.room.name events.
+
+        self.assertEqual(
+            len(next_deltas),
+            4,
+            f"Returned {len(next_deltas)} rows, expected all 4 m.room.name events: {next_deltas}",
+        )
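The running total in the `accum` CTE relies on a SQL window function, available in PostgreSQL and in SQLite 3.25 or newer (the two database engines Synapse supports). A self-contained demonstration of the cumulative-count technique, using an invented `deltas` table:

import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE deltas (stream_id INTEGER)")
# Two deltas at stream 1, one at stream 2, three sharing stream 3.
conn.executemany(
    "INSERT INTO deltas VALUES (?)", [(1,), (1,), (2,), (3,), (3,), (3,)]
)

rows = conn.execute(
    """
    WITH grouped AS (
        SELECT stream_id, COUNT(*) AS c FROM deltas GROUP BY stream_id
    )
    SELECT stream_id, c, SUM(c) OVER (ORDER BY stream_id) AS running
    FROM grouped
    """
).fetchall()
print(rows)  # [(1, 2, 2), (2, 1, 3), (3, 3, 6)]

With a limit of 3, only the groups whose running total stays at or below 3 (streams 1 and 2) would be included, which is the same clipping behaviour the tests above exercise.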