diff --git a/changelog.d/19041.feature b/changelog.d/19041.feature new file mode 100644 index 00000000000..e04b0189f8c --- /dev/null +++ b/changelog.d/19041.feature @@ -0,0 +1 @@ +Add companion endpoint for MSC4360: Sliding Sync Threads Extension. diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index 8513e897115..3cd964886b2 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -20,6 +20,7 @@ # import enum import logging +from collections import defaultdict from typing import ( TYPE_CHECKING, Collection, @@ -30,24 +31,59 @@ import attr -from synapse.api.constants import Direction, EventTypes, RelationTypes +from synapse.api.constants import Direction, EventTypes, Membership, RelationTypes from synapse.api.errors import SynapseError from synapse.events import EventBase, relation_from_event from synapse.events.utils import SerializeEventConfig from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import trace -from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent +from synapse.storage.databases.main.relations import ( + ThreadsNextBatch, + ThreadUpdateInfo, + _RelatedEvent, +) from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict, Requester, UserID +from synapse.types import ( + JsonDict, + Requester, + RoomStreamToken, + StreamKeyType, + StreamToken, + UserID, +) from synapse.util.async_helpers import gather_results from synapse.visibility import filter_events_for_client if TYPE_CHECKING: + from synapse.events.utils import EventClientSerializer + from synapse.handlers.sliding_sync.room_lists import RoomsForUserType from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +# Type aliases for thread update processing +ThreadUpdatesMap = dict[str, list[ThreadUpdateInfo]] +ThreadRootsMap = dict[str, EventBase] +AggregationsMap = dict[str, "BundledAggregations"] + + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class ThreadUpdate: + """ + Data for a single thread update. + + Attributes: + thread_root: The thread root event, or None if not requested/not visible + prev_batch: Per-thread pagination token for fetching older events in this thread + bundled_aggregations: Bundled aggregations for the thread root event + """ + + thread_root: EventBase | None + prev_batch: StreamToken | None + bundled_aggregations: "BundledAggregations | None" = None + class ThreadsListInclude(str, enum.Enum): """Valid values for the 'include' flag of /threads.""" @@ -544,6 +580,367 @@ async def _fetch_edits() -> None: return results + async def _filter_thread_updates_for_user( + self, + all_thread_updates: ThreadUpdatesMap, + user_id: str, + ) -> ThreadUpdatesMap: + """Process thread updates by filtering for visibility. + + Takes raw thread updates from storage and filters them based on whether the + user can see the events. Preserves the ordering of updates within each thread. + + Args: + all_thread_updates: Map of thread_id to list of ThreadUpdateInfo objects + user_id: The user ID to filter events for + + Returns: + Filtered map of thread_id to list of ThreadUpdateInfo objects, containing + only updates for events the user can see. + """ + # Build a mapping of event_id -> (thread_id, update) for efficient lookup + # during visibility filtering. + event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {} + for thread_id, updates in all_thread_updates.items(): + for update in updates: + event_to_thread_map[update.event_id] = (thread_id, update) + + # Fetch and filter events for visibility + all_events = await self._main_store.get_events_as_list( + event_to_thread_map.keys() + ) + filtered_events = await filter_events_for_client( + self._storage_controllers, user_id, all_events + ) + + # Rebuild thread updates from filtered events + filtered_updates: ThreadUpdatesMap = defaultdict(list) + for event in filtered_events: + if event.event_id in event_to_thread_map: + thread_id, update = event_to_thread_map[event.event_id] + filtered_updates[thread_id].append(update) + + return filtered_updates + + def _build_thread_updates_response( + self, + filtered_updates: ThreadUpdatesMap, + thread_root_event_map: ThreadRootsMap, + aggregations_map: AggregationsMap, + global_prev_batch_token: StreamToken | None, + ) -> dict[str, dict[str, ThreadUpdate]]: + """Build thread update response structure with per-thread prev_batch tokens. + + Args: + filtered_updates: Map of thread_root_id to list of ThreadUpdateInfo + thread_root_event_map: Map of thread_root_id to EventBase + aggregations_map: Map of thread_root_id to BundledAggregations + global_prev_batch_token: Global pagination token, or None if no more results + + Returns: + Map of room_id to thread_root_id to ThreadUpdate + """ + thread_updates: dict[str, dict[str, ThreadUpdate]] = {} + + for thread_root_id, updates in filtered_updates.items(): + # We only care about the latest update for the thread + # Updates are already sorted by stream_ordering DESC from the database query, + # and filter_events_for_client preserves order, so updates[0] is guaranteed to be + # the latest event for each thread. + latest_update = updates[0] + room_id = latest_update.room_id + + # Generate per-thread prev_batch token if this thread has multiple visible updates + # or if we hit the global limit. + # When we hit the global limit, we generate prev_batch tokens for all threads, even if + # we only saw 1 update for them. This is to cover the case where we only saw + # a single update for a given thread, but the global limit prevents us from + # obtaining other updates which would have otherwise been included in the range. + per_thread_prev_batch = None + if len(updates) > 1 or global_prev_batch_token is not None: + # Create a token pointing to one position before the latest event's stream position. + # This makes it exclusive - /relations with dir=b won't return the latest event again. + # Use StreamToken.START as base (all other streams at 0) since only room position matters. + per_thread_prev_batch = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, + RoomStreamToken(stream=latest_update.stream_ordering - 1), + ) + + if room_id not in thread_updates: + thread_updates[room_id] = {} + + thread_updates[room_id][thread_root_id] = ThreadUpdate( + thread_root=thread_root_event_map.get(thread_root_id), + prev_batch=per_thread_prev_batch, + bundled_aggregations=aggregations_map.get(thread_root_id), + ) + + return thread_updates + + async def _fetch_thread_updates( + self, + room_ids: frozenset[str], + room_membership_map: Mapping[str, "RoomsForUserType"], + from_token: StreamToken | None, + to_token: StreamToken, + limit: int, + exclude_thread_ids: set[str] | None = None, + ) -> tuple[ThreadUpdatesMap, StreamToken | None]: + """Fetch thread updates across multiple rooms, handling membership states properly. + + This method separates rooms based on membership status (LEAVE/BAN vs others) + and queries them appropriately to prevent data leaks. For rooms where the user + has left or been banned, we bound the query to their leave/ban event position. + + Args: + room_ids: The set of room IDs to fetch thread updates for + room_membership_map: Map of room_id to RoomsForUserType containing membership info + from_token: Lower bound (exclusive) for the query, or None for no lower bound + to_token: Upper bound for the query (for joined/invited/knocking rooms) + limit: Maximum number of thread updates to return across all rooms + exclude_thread_ids: Optional set of thread IDs to exclude from results + + Returns: + A tuple of: + - Map of thread_id to list of ThreadUpdateInfo objects + - Global prev_batch token if there are more results, None otherwise + """ + # Separate rooms based on membership to handle LEAVE/BAN rooms specially + leave_ban_rooms: set[str] = set() + other_rooms: set[str] = set() + + for room_id in room_ids: + membership_info = room_membership_map.get(room_id) + if membership_info and membership_info.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + leave_ban_rooms.add(room_id) + else: + other_rooms.add(room_id) + + # Fetch thread updates from storage, handling LEAVE/BAN rooms separately + all_thread_updates: ThreadUpdatesMap = {} + prev_batch_token: StreamToken | None = None + remaining_limit = limit + + # Query LEAVE/BAN rooms with bounded to_token to prevent data leaks + if leave_ban_rooms: + for room_id in leave_ban_rooms: + if remaining_limit <= 0: + # We've hit the limit, set prev_batch to indicate more results + prev_batch_token = to_token + break + + membership_info = room_membership_map[room_id] + bounded_to_token = membership_info.event_pos.to_room_stream_token() + + ( + room_thread_updates, + room_prev_batch, + ) = await self._main_store.get_thread_updates_for_rooms( + room_ids={room_id}, + from_token=from_token.room_key if from_token else None, + to_token=bounded_to_token, + limit=remaining_limit, + exclude_thread_ids=exclude_thread_ids, + ) + + # Count updates and reduce remaining limit + num_updates = sum( + len(updates) for updates in room_thread_updates.values() + ) + remaining_limit -= num_updates + + # Merge updates + for thread_id, updates in room_thread_updates.items(): + all_thread_updates.setdefault(thread_id, []).extend(updates) + + # Merge prev_batch tokens (take the maximum for backward pagination) + if room_prev_batch is not None: + if prev_batch_token is None: + prev_batch_token = room_prev_batch + elif ( + room_prev_batch.room_key.stream + > prev_batch_token.room_key.stream + ): + prev_batch_token = room_prev_batch + + # Query other rooms (joined/invited/knocking) with normal to_token + if other_rooms and remaining_limit > 0: + ( + other_thread_updates, + other_prev_batch, + ) = await self._main_store.get_thread_updates_for_rooms( + room_ids=other_rooms, + from_token=from_token.room_key if from_token else None, + to_token=to_token.room_key, + limit=remaining_limit, + exclude_thread_ids=exclude_thread_ids, + ) + + # Merge updates + for thread_id, updates in other_thread_updates.items(): + all_thread_updates.setdefault(thread_id, []).extend(updates) + + # Merge prev_batch tokens + if other_prev_batch is not None: + if prev_batch_token is None: + prev_batch_token = other_prev_batch + elif ( + other_prev_batch.room_key.stream > prev_batch_token.room_key.stream + ): + prev_batch_token = other_prev_batch + + return all_thread_updates, prev_batch_token + + async def get_thread_updates_for_rooms( + self, + room_ids: frozenset[str], + room_membership_map: Mapping[str, "RoomsForUserType"], + user_id: str, + from_token: StreamToken | None, + to_token: StreamToken, + limit: int, + include_roots: bool = False, + exclude_thread_ids: set[str] | None = None, + ) -> tuple[dict[str, dict[str, ThreadUpdate]], StreamToken | None]: + """Get thread updates across multiple rooms with full processing pipeline. + + This is the main entry point for fetching thread updates. It handles: + - Fetching updates with membership-based security + - Filtering for visibility + - Optionally fetching thread roots and aggregations + - Building the response structure + + Args: + room_ids: The set of room IDs to fetch updates for + room_membership_map: Map of room_id to RoomsForUserType for membership info + user_id: The user requesting the updates + from_token: Lower bound (exclusive) for the query + to_token: Upper bound for the query + limit: Maximum number of updates to return + include_roots: Whether to fetch and include thread root events (default: False) + exclude_thread_ids: Optional set of thread IDs to exclude + + Returns: + A tuple of: + - Map of room_id to thread_root_id to ThreadUpdate + - Global prev_batch token if there are more results, None otherwise + """ + # Fetch thread updates with membership handling + all_thread_updates, prev_batch_token = await self._fetch_thread_updates( + room_ids=room_ids, + room_membership_map=room_membership_map, + from_token=from_token, + to_token=to_token, + limit=limit, + exclude_thread_ids=exclude_thread_ids, + ) + + if not all_thread_updates: + return {}, prev_batch_token + + # Filter thread updates for visibility + filtered_updates = await self._filter_thread_updates_for_user( + all_thread_updates, user_id + ) + + if not filtered_updates: + return {}, prev_batch_token + + # Optionally fetch thread root events and their bundled aggregations + thread_root_event_map: ThreadRootsMap = {} + aggregations_map: AggregationsMap = {} + if include_roots: + # Fetch thread root events + thread_root_events = await self._main_store.get_events_as_list( + filtered_updates.keys() + ) + thread_root_event_map = {e.event_id: e for e in thread_root_events} + + # Fetch bundled aggregations for the thread roots + if thread_root_event_map: + aggregations_map = await self.get_bundled_aggregations( + thread_root_event_map.values(), + user_id, + ) + + # Build response structure with per-thread prev_batch tokens + thread_updates = self._build_thread_updates_response( + filtered_updates=filtered_updates, + thread_root_event_map=thread_root_event_map, + aggregations_map=aggregations_map, + global_prev_batch_token=prev_batch_token, + ) + + return thread_updates, prev_batch_token + + @staticmethod + async def serialize_thread_updates( + thread_updates: Mapping[str, Mapping[str, ThreadUpdate]], + prev_batch_token: StreamToken | None, + event_serializer: "EventClientSerializer", + time_now: int, + store: "DataStore", + serialize_options: SerializeEventConfig, + ) -> JsonDict: + """ + Serialize thread updates to JSON format. + + This helper handles serialization of ThreadUpdate objects for both the + companion endpoint and the sliding sync extension. + + Args: + thread_updates: Map of room_id to thread_root_id to ThreadUpdate + prev_batch_token: Global pagination token for fetching more updates + event_serializer: The event serializer to use + time_now: Current time in milliseconds for event serialization + store: Datastore for serializing stream tokens + serialize_options: Serialization config + + Returns: + JSON-serializable dict with "updates" and optionally "prev_batch" + """ + updates_dict: JsonDict = {} + + for room_id, room_threads in thread_updates.items(): + room_updates: JsonDict = {} + for thread_root_id, update in room_threads.items(): + update_dict: JsonDict = {} + + # Serialize thread_root event if present + if update.thread_root is not None: + bundle_aggs_map = ( + {thread_root_id: update.bundled_aggregations} + if update.bundled_aggregations is not None + else None + ) + serialized_events = await event_serializer.serialize_events( + [update.thread_root], + time_now, + config=serialize_options, + bundle_aggregations=bundle_aggs_map, + ) + if serialized_events: + update_dict["thread_root"] = serialized_events[0] + + # Add per-thread prev_batch if present + if update.prev_batch is not None: + update_dict["prev_batch"] = await update.prev_batch.to_string(store) + + room_updates[thread_root_id] = update_dict + + updates_dict[room_id] = room_updates + + result: JsonDict = {"updates": updates_dict} + + # Add global prev_batch token if present + if prev_batch_token is not None: + result["prev_batch"] = await prev_batch_token.to_string(store) + + return result + async def get_threads( self, requester: Requester, diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index d62f2d675f6..914ef1f8cc9 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -14,7 +14,6 @@ import itertools import logging -from collections import defaultdict from typing import ( TYPE_CHECKING, AbstractSet, @@ -31,7 +30,6 @@ AccountDataTypes, EduTypes, EventContentFields, - Membership, MRelatesToFields, RelationTypes, ) @@ -40,15 +38,12 @@ from synapse.handlers.sliding_sync.room_lists import RoomsForUserType from synapse.logging.opentracing import trace from synapse.storage.databases.main.receipts import ReceiptInRoom -from synapse.storage.databases.main.relations import ThreadUpdateInfo from synapse.types import ( DeviceListUpdates, JsonMapping, MultiWriterStreamToken, - RoomStreamToken, SlidingSyncStreamToken, StrCollection, - StreamKeyType, StreamToken, ThreadSubscriptionsToken, ) @@ -64,7 +59,6 @@ concurrently_execute, gather_optional_coroutines, ) -from synapse.visibility import filter_events_for_client _ThreadSubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription @@ -72,7 +66,6 @@ _ThreadUnsubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription ) -_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate if TYPE_CHECKING: from synapse.server import HomeServer @@ -1040,42 +1033,6 @@ def _find_threads_in_timeline( threads_in_timeline.add(thread_id) return threads_in_timeline - def _merge_prev_batch_token( - self, - current_token: StreamToken | None, - new_token: StreamToken | None, - ) -> StreamToken | None: - """Merge two prev_batch tokens, taking the maximum (latest) for backwards pagination. - - Args: - current_token: The current prev_batch token (may be None) - new_token: The new prev_batch token to merge (may be None) - - Returns: - The merged token (maximum of the two, or None if both are None) - """ - if new_token is None: - return current_token - if current_token is None: - return new_token - if new_token.room_key.stream > current_token.room_key.stream: - return new_token - return current_token - - def _merge_thread_updates( - self, - target: dict[str, list[ThreadUpdateInfo]], - source: dict[str, list[ThreadUpdateInfo]], - ) -> None: - """Merge thread updates from source into target. - - Args: - target: The target dict to merge into (modified in place) - source: The source dict to merge from - """ - for thread_id, updates in source.items(): - target.setdefault(thread_id, []).extend(updates) - async def get_threads_extension_response( self, sync_config: SlidingSyncConfig, @@ -1117,159 +1074,26 @@ async def get_threads_extension_response( actual_room_response_map ) - # Separate rooms into groups based on membership status. - # For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events - # that occurred after the user left/was banned. - leave_ban_rooms: set[str] = set() - other_rooms: set[str] = set() - - for room_id in actual_room_ids: - membership_info = room_membership_for_user_at_to_token_map.get(room_id) - if membership_info and membership_info.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - leave_ban_rooms.add(room_id) - else: - other_rooms.add(room_id) - - # Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks. - all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {} - prev_batch_token: StreamToken | None = None - remaining_limit = threads_request.limit - - # Query for rooms where the user has left or been banned, using their leave/ban - # event position as the upper bound to prevent seeing events after they left. - if leave_ban_rooms: - for room_id in leave_ban_rooms: - if remaining_limit <= 0: - # We've already fetched enough updates, but we still need to set - # prev_batch to indicate there are more results. - prev_batch_token = to_token - break - - membership_info = room_membership_for_user_at_to_token_map[room_id] - bounded_to_token = membership_info.event_pos.to_room_stream_token() - - ( - room_thread_updates, - room_prev_batch, - ) = await self.store.get_thread_updates_for_rooms( - room_ids={room_id}, - from_token=from_token.stream_token.room_key if from_token else None, - to_token=bounded_to_token, - limit=remaining_limit, - exclude_thread_ids=threads_to_exclude, - ) - - # Count how many updates we fetched and reduce the remaining limit - num_updates = sum( - len(updates) for updates in room_thread_updates.values() - ) - remaining_limit -= num_updates - - self._merge_thread_updates(all_thread_updates, room_thread_updates) - prev_batch_token = self._merge_prev_batch_token( - prev_batch_token, room_prev_batch - ) - - # Query for rooms where the user is joined, invited, or knocking, using the - # normal to_token as the upper bound. - if other_rooms and remaining_limit > 0: - ( - other_thread_updates, - other_prev_batch, - ) = await self.store.get_thread_updates_for_rooms( - room_ids=other_rooms, - from_token=from_token.stream_token.room_key if from_token else None, - to_token=to_token.room_key, - limit=remaining_limit, - exclude_thread_ids=threads_to_exclude, - ) - - self._merge_thread_updates(all_thread_updates, other_thread_updates) - prev_batch_token = self._merge_prev_batch_token( - prev_batch_token, other_prev_batch - ) - - if len(all_thread_updates) == 0: - return None - - # Build a mapping of event_id -> (thread_id, update) for efficient lookup - # during visibility filtering. - event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {} - for thread_id, updates in all_thread_updates.items(): - for update in updates: - event_to_thread_map[update.event_id] = (thread_id, update) - - # Fetch and filter events for visibility - all_events = await self.store.get_events_as_list(event_to_thread_map.keys()) - filtered_events = await filter_events_for_client( - self._storage_controllers, sync_config.user.to_string(), all_events + # Get thread updates using unified helper + user_id = sync_config.user.to_string() + ( + thread_updates_response, + prev_batch_token, + ) = await self.relations_handler.get_thread_updates_for_rooms( + room_ids=frozenset(actual_room_ids), + room_membership_map=room_membership_for_user_at_to_token_map, + user_id=user_id, + from_token=from_token.stream_token if from_token else None, + to_token=to_token, + limit=threads_request.limit, + include_roots=threads_request.include_roots, + exclude_thread_ids=threads_to_exclude, ) - # Rebuild thread updates from filtered events - filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list) - for event in filtered_events: - if event.event_id in event_to_thread_map: - thread_id, update = event_to_thread_map[event.event_id] - filtered_updates[thread_id].append(update) - - if not filtered_updates: + if not thread_updates_response: return None - # Note: Updates are already sorted by stream_ordering DESC from the database query, - # and filter_events_for_client preserves order, so updates[0] is guaranteed to be - # the latest event for each thread. - - # Optionally fetch thread root events and their bundled aggregations - thread_root_event_map = {} - aggregations_map = {} - if threads_request.include_roots: - thread_root_events = await self.store.get_events_as_list( - filtered_updates.keys() - ) - thread_root_event_map = {e.event_id: e for e in thread_root_events} - - if thread_root_event_map: - aggregations_map = ( - await self.relations_handler.get_bundled_aggregations( - thread_root_event_map.values(), - sync_config.user.to_string(), - ) - ) - - thread_updates: dict[str, dict[str, _ThreadUpdate]] = {} - for thread_root, updates in filtered_updates.items(): - # We only care about the latest update for the thread. - # After sorting above, updates[0] is guaranteed to be the latest (highest stream_ordering). - latest_update = updates[0] - - # Generate per-thread prev_batch token if this thread has multiple visible updates. - # When we hit the global limit, we generate prev_batch tokens for all threads, even if - # we only saw 1 update for them. This is to cover the case where we only saw - # a single update for a given thread, but the global limit prevents us from - # obtaining other updates which would have otherwise been included in the - # range. - per_thread_prev_batch = None - if len(updates) > 1 or prev_batch_token is not None: - # Create a token pointing to one position before the latest event's stream position. - # This makes it exclusive - /relations with dir=b won't return the latest event again. - # Use StreamToken.START as base (all other streams at 0) since only room position matters. - per_thread_prev_batch = StreamToken.START.copy_and_replace( - StreamKeyType.ROOM, - RoomStreamToken(stream=latest_update.stream_ordering - 1), - ) - - thread_updates.setdefault(latest_update.room_id, {})[thread_root] = ( - _ThreadUpdate( - thread_root=thread_root_event_map.get(thread_root), - prev_batch=per_thread_prev_batch, - bundled_aggregations=aggregations_map.get(thread_root), - ) - ) - return SlidingSyncResult.Extensions.ThreadsExtension( - updates=thread_updates, + updates=thread_updates_response, prev_batch=prev_batch_token, ) diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index c913bc69707..bab109f21e4 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -20,17 +20,33 @@ import logging import re -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Annotated -from synapse.api.constants import Direction +from pydantic import StrictBool, StrictStr +from pydantic.types import StringConstraints + +from synapse.api.constants import Direction, Membership +from synapse.api.errors import SynapseError +from synapse.events.utils import SerializeEventConfig from synapse.handlers.relations import ThreadsListInclude from synapse.http.server import HttpServer -from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string +from synapse.http.servlet import ( + RestServlet, + parse_and_validate_json_object_from_request, + parse_boolean, + parse_integer, + parse_string, +) from synapse.http.site import SynapseRequest from synapse.rest.client._base import client_patterns from synapse.storage.databases.main.relations import ThreadsNextBatch -from synapse.streams.config import PaginationConfig -from synapse.types import JsonDict +from synapse.streams.config import ( + PaginationConfig, + extract_stream_token_from_pagination_token, +) +from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID +from synapse.types.handlers.sliding_sync import PerConnectionState, SlidingSyncConfig +from synapse.types.rest.client import RequestBodyModel, SlidingSyncBody if TYPE_CHECKING: from synapse.server import HomeServer @@ -38,6 +54,39 @@ logger = logging.getLogger(__name__) +class ThreadUpdatesBody(RequestBodyModel): + """ + Thread updates companion endpoint request body (MSC4360). + + Allows paginating thread updates using the same room selection as a sliding sync + request. This enables clients to fetch thread updates for the same set of rooms + that were included in their sliding sync response. + + Attributes: + lists: Sliding window API lists, using the same structure as SlidingSyncBody.lists. + If provided along with room_subscriptions, the union of rooms from both will + be used. + room_subscriptions: Room subscription API rooms, using the same structure as + SlidingSyncBody.room_subscriptions. If provided along with lists, the union + of rooms from both will be used. + include_roots: Whether to include the thread root events in the response. + Defaults to False. + + If neither lists nor room_subscriptions are provided, thread updates from all + joined rooms are returned. + """ + + lists: ( + dict[ + Annotated[str, StringConstraints(max_length=64, strict=True)], + SlidingSyncBody.SlidingSyncList, + ] + | None + ) = None + room_subscriptions: dict[StrictStr, SlidingSyncBody.RoomSubscription] | None = None + include_roots: StrictBool = False + + class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -133,6 +182,167 @@ async def on_GET( return 200, result +class ThreadUpdatesServlet(RestServlet): + """ + Companion endpoint to the Sliding Sync threads extension (MSC4360). + Allows clients to bulk fetch thread updates across all joined rooms. + """ + + PATTERNS = client_patterns( + "/io.element.msc4360/thread_updates$", + unstable=True, + releases=(), + ) + CATEGORY = "Client API requests" + + def __init__(self, hs: "HomeServer"): + super().__init__() + self.clock = hs.get_clock() + self.auth = hs.get_auth() + self.store = hs.get_datastores().main + self.relations_handler = hs.get_relations_handler() + self.event_serializer = hs.get_event_client_serializer() + self._storage_controllers = hs.get_storage_controllers() + self.sliding_sync_handler = hs.get_sliding_sync_handler() + + async def on_POST(self, request: SynapseRequest) -> tuple[int, JsonDict]: + requester = await self.auth.get_user_by_req(request) + + # Parse request body + body = parse_and_validate_json_object_from_request(request, ThreadUpdatesBody) + + # Parse query parameters + dir_str = parse_string(request, "dir", default="b") + if dir_str != "b": + raise SynapseError( + 400, + "The 'dir' parameter must be 'b' (backward). Forward pagination is not supported.", + ) + + limit = parse_integer(request, "limit", default=100) + if limit <= 0: + raise SynapseError(400, "The 'limit' parameter must be positive.") + + from_token_str = parse_string(request, "from") + to_token_str = parse_string(request, "to") + + # Parse pagination tokens + from_token: StreamToken | None = None + to_token: StreamToken | None = None + + if from_token_str: + try: + stream_token_str = extract_stream_token_from_pagination_token( + from_token_str + ) + from_token = await StreamToken.from_string(self.store, stream_token_str) + except Exception as e: + logger.exception("Error parsing 'from' token: %s", from_token_str) + raise SynapseError(400, "'from' parameter is invalid") from e + + if to_token_str: + try: + stream_token_str = extract_stream_token_from_pagination_token( + to_token_str + ) + to_token = await StreamToken.from_string(self.store, stream_token_str) + except Exception: + raise SynapseError(400, "'to' parameter is invalid") + + # Get the list of rooms to fetch thread updates for + user_id = requester.user.to_string() + user = UserID.from_string(user_id) + + # Get the current stream token for membership lookup + if from_token is None: + max_stream_ordering = self.store.get_room_max_stream_ordering() + current_token = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, RoomStreamToken(stream=max_stream_ordering) + ) + else: + current_token = from_token + + # Get room membership information to properly handle LEAVE/BAN rooms + ( + room_membership_for_user_at_to_token_map, + _, + _, + ) = await self.sliding_sync_handler.room_lists.get_room_membership_for_user_at_to_token( + user=user, + to_token=current_token, + from_token=None, + ) + + # Determine which rooms to fetch updates for based on lists/room_subscriptions + if body.lists is not None or body.room_subscriptions is not None: + # Use sliding sync room selection logic + sync_config = SlidingSyncConfig( + user=user, + requester=requester, + lists=body.lists, + room_subscriptions=body.room_subscriptions, + ) + + # Use the sliding sync room list handler to get the same set of rooms + interested_rooms = ( + await self.sliding_sync_handler.room_lists.compute_interested_rooms( + sync_config=sync_config, + previous_connection_state=PerConnectionState(), + to_token=current_token, + from_token=None, + ) + ) + + room_ids = frozenset(interested_rooms.relevant_room_map.keys()) + else: + # No lists or room_subscriptions, use only joined rooms + room_ids = frozenset( + room_id + for room_id, membership_info in room_membership_for_user_at_to_token_map.items() + if membership_info.membership == Membership.JOIN + ) + + # Get thread updates using unified helper + ( + thread_updates, + prev_batch_token, + ) = await self.relations_handler.get_thread_updates_for_rooms( + room_ids=room_ids, + room_membership_map=room_membership_for_user_at_to_token_map, + user_id=user_id, + from_token=to_token, + to_token=from_token if from_token else current_token, + limit=limit, + include_roots=body.include_roots, + ) + + if not thread_updates: + return 200, {"chunk": {}} + + # Serialize thread updates using shared helper + time_now = self.clock.time_msec() + serialize_options = SerializeEventConfig(requester=requester) + + serialized = await self.relations_handler.serialize_thread_updates( + thread_updates=thread_updates, + prev_batch_token=prev_batch_token, + event_serializer=self.event_serializer, + time_now=time_now, + store=self.store, + serialize_options=serialize_options, + ) + + # Build response with "chunk" wrapper and "next_batch" key + # (companion endpoint uses different key names than sliding sync) + response: JsonDict = {"chunk": serialized["updates"]} + if "prev_batch" in serialized: + response["next_batch"] = serialized["prev_batch"] + + return 200, response + + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: RelationPaginationServlet(hs).register(http_server) ThreadsServlet(hs).register(http_server) + if hs.config.experimental.msc4360_enabled: + ThreadUpdatesServlet(hs).register(http_server) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b02ac8e4e18..e8f7417f508 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -37,6 +37,7 @@ format_event_raw, ) from synapse.handlers.presence import format_user_presence_state +from synapse.handlers.relations import RelationsHandler from synapse.handlers.sliding_sync import SlidingSyncConfig, SlidingSyncResult from synapse.handlers.sync import ( ArchivedSyncResult, @@ -1107,6 +1108,7 @@ async def encode_extensions( time_now, extensions.threads, self.store, + requester, ) return serialized_extensions @@ -1150,6 +1152,7 @@ async def _serialise_threads( time_now: int, threads: SlidingSyncResult.Extensions.ThreadsExtension, store: "DataStore", + requester: Requester, ) -> JsonDict: """ Serialize the threads extension response for sliding sync. @@ -1159,6 +1162,7 @@ async def _serialise_threads( time_now: The current time in milliseconds, used for event serialization. threads: The threads extension data containing thread updates and pagination tokens. store: The datastore, needed for serializing stream tokens. + requester: The user making the request, used for transaction_id inclusion. Returns: A JSON-serializable dict containing: @@ -1169,46 +1173,24 @@ async def _serialise_threads( - "prev_batch": A pagination token for fetching older events in the thread. - "prev_batch": A pagination token for fetching older thread updates (if available). """ - out: JsonDict = {} - - if threads.updates: - updates_dict: JsonDict = {} - for room_id, thread_updates in threads.updates.items(): - room_updates: JsonDict = {} - for thread_root_id, update in thread_updates.items(): - # Serialize the update - update_dict: JsonDict = {} - - # Serialize the thread_root event if present - if update.thread_root is not None: - # Create a mapping of event_id to bundled_aggregations - bundle_aggs_map = ( - {thread_root_id: update.bundled_aggregations} - if update.bundled_aggregations - else None - ) - serialized_events = await event_serializer.serialize_events( - [update.thread_root], - time_now, - bundle_aggregations=bundle_aggs_map, - ) - if serialized_events: - update_dict["thread_root"] = serialized_events[0] - - # Add prev_batch if present - if update.prev_batch is not None: - update_dict["prev_batch"] = await update.prev_batch.to_string(store) - - room_updates[thread_root_id] = update_dict - - updates_dict[room_id] = room_updates - - out["updates"] = updates_dict - - if threads.prev_batch: - out["prev_batch"] = await threads.prev_batch.to_string(store) - - return out + if not threads.updates: + out: JsonDict = {} + if threads.prev_batch: + out["prev_batch"] = await threads.prev_batch.to_string(store) + return out + + # Create serialization config to include transaction_id for requester's events + serialize_options = SerializeEventConfig(requester=requester) + + # Use shared serialization helper (static method) + return await RelationsHandler.serialize_thread_updates( + thread_updates=threads.updates, + prev_batch_token=threads.prev_batch, + event_serializer=event_serializer, + time_now=time_now, + store=store, + serialize_options=serialize_options, + ) def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/streams/config.py b/synapse/streams/config.py index 5b59dd6145d..4f7f44ac65e 100644 --- a/synapse/streams/config.py +++ b/synapse/streams/config.py @@ -35,6 +35,32 @@ MAX_LIMIT = 1000 +def extract_stream_token_from_pagination_token(token_str: str) -> str: + """ + Extract the StreamToken portion from a pagination token string. + + Handles both: + - StreamToken format: "s123_456_..." + - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /) + + This allows clients using sliding sync to use their pos tokens + with endpoints like /relations and /messages. + + Args: + token_str: The token string to parse + + Returns: + The StreamToken portion of the token + """ + if "/" in token_str: + # SlidingSyncStreamToken format: "connection_position/stream_token" + # Split and return just the stream_token part + parts = token_str.split("/", 1) + if len(parts) == 2: + return parts[1] + return token_str + + @attr.s(slots=True, auto_attribs=True) class PaginationConfig: """A configuration object which stores pagination parameters.""" @@ -57,32 +83,14 @@ async def from_request( from_tok_str = parse_string(request, "from") to_tok_str = parse_string(request, "to") - # Helper function to extract StreamToken from either StreamToken or SlidingSyncStreamToken format - def extract_stream_token(token_str: str) -> str: - """ - Extract the StreamToken portion from a token string. - - Handles both: - - StreamToken format: "s123_456_..." - - SlidingSyncStreamToken format: "5/s123_456_..." (extracts part after /) - - This allows clients using sliding sync to use their pos tokens - with endpoints like /relations and /messages. - """ - if "/" in token_str: - # SlidingSyncStreamToken format: "connection_position/stream_token" - # Split and return just the stream_token part - parts = token_str.split("/", 1) - if len(parts) == 2: - return parts[1] - return token_str - try: from_tok = None if from_tok_str == "END": from_tok = None # For backwards compat. elif from_tok_str: - stream_token_str = extract_stream_token(from_tok_str) + stream_token_str = extract_stream_token_from_pagination_token( + from_tok_str + ) from_tok = await StreamToken.from_string(store, stream_token_str) except Exception: raise SynapseError(400, "'from' parameter is invalid") @@ -90,7 +98,9 @@ def extract_stream_token(token_str: str) -> str: try: to_tok = None if to_tok_str: - stream_token_str = extract_stream_token(to_tok_str) + stream_token_str = extract_stream_token_from_pagination_token( + to_tok_str + ) to_tok = await StreamToken.from_string(store, stream_token_str) except Exception: raise SynapseError(400, "'to' parameter is invalid") diff --git a/synapse/types/handlers/sliding_sync.py b/synapse/types/handlers/sliding_sync.py index a5d90252b76..def7a709fa5 100644 --- a/synapse/types/handlers/sliding_sync.py +++ b/synapse/types/handlers/sliding_sync.py @@ -37,7 +37,8 @@ from synapse.events import EventBase if TYPE_CHECKING: - from synapse.handlers.relations import BundledAggregations + from synapse.handlers.relations import BundledAggregations, ThreadUpdate + from synapse.types import ( DeviceListUpdates, JsonDict, @@ -409,30 +410,7 @@ class ThreadsExtension: to paginate through older thread updates. """ - @attr.s(slots=True, frozen=True, auto_attribs=True) - class ThreadUpdate: - """Information about a single thread that has new activity. - - Attributes: - thread_root: The thread root event, if requested via include_roots in the - request. This is the event that started the thread. - prev_batch: A pagination token (exclusive) for fetching older events in this - specific thread. Only present if the thread has multiple updates in the - sync window. This token can be used with the /relations endpoint with - dir=b to paginate backwards through the thread's history. - bundled_aggregations: Bundled aggregations for the thread root event, - including the latest_event in the thread (found in - unsigned.m.relations.m.thread). Only present if thread_root is included. - """ - - thread_root: EventBase | None - prev_batch: StreamToken | None - bundled_aggregations: "BundledAggregations | None" = None - - def __bool__(self) -> bool: - return bool(self.thread_root) or bool(self.prev_batch) - - updates: Mapping[str, Mapping[str, ThreadUpdate]] | None + updates: Mapping[str, Mapping[str, "ThreadUpdate"]] | None prev_batch: StreamToken | None def __bool__(self) -> bool: diff --git a/tests/rest/client/sliding_sync/test_extension_threads.py b/tests/rest/client/sliding_sync/test_extension_threads.py index cfbc3a2155b..f4d2204870a 100644 --- a/tests/rest/client/sliding_sync/test_extension_threads.py +++ b/tests/rest/client/sliding_sync/test_extension_threads.py @@ -924,6 +924,250 @@ def test_thread_in_timeline_included_with_include_roots(self) -> None: # Verify the thread root event is present self.assertIn("thread_root", thread_updates[thread_root_id]) + def test_thread_updates_initial_sync(self) -> None: + """ + Test that prev_batch from the threads extension response can be used + with the /thread_updates endpoint to get additional thread updates during + initial sync. This verifies: + 1. The from parameter boundary is exclusive (no duplicates) + 2. Using prev_batch as 'from' provides complete coverage (no gaps) + 3. Works correctly with different numbers of threads + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create 5 thread roots + thread_ids = [] + for i in range(5): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + thread_ids.append(thread_root_id) + + # Add reply to each thread + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Do initial sync with threads extension enabled and limit=2 + sync_body = { + "lists": { + "all-rooms": { + "ranges": [[0, 10]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + "limit": 2, + } + }, + } + response_body, _ = self.do_sync(sync_body, tok=user1_tok) + + # Should get 2 thread updates + thread_updates = response_body["extensions"][EXT_NAME]["updates"][room_id] + self.assertEqual(len(thread_updates), 2) + first_sync_threads = set(thread_updates.keys()) + + # Get the top-level prev_batch token from the extension + self.assertIn("prev_batch", response_body["extensions"][EXT_NAME]) + prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"] + + # Use prev_batch with /thread_updates endpoint to get remaining updates + # Note: prev_batch should be used as 'from' parameter (upper bound for backward pagination) + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}", + access_token=user1_tok, + content={}, + ) + self.assertEqual(channel.code, 200) + + # Should get the remaining 3 thread updates + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + self.assertEqual(len(chunk[room_id]), 3) + + thread_updates_response_threads = set(chunk[room_id].keys()) + + # Verify no overlap - the from parameter boundary should be exclusive + self.assertEqual( + len(first_sync_threads & thread_updates_response_threads), + 0, + "from parameter boundary should be exclusive - no thread should appear in both responses", + ) + + # Verify no gaps - all threads should be accounted for + all_threads = set(thread_ids) + combined_threads = first_sync_threads | thread_updates_response_threads + self.assertEqual( + combined_threads, + all_threads, + "Combined responses should include all thread updates with no gaps", + ) + + def test_thread_updates_incremental_sync(self) -> None: + """ + Test the intended usage pattern from MSC4360: using prev_batch as 'from' + and a previous sync pos as 'to' with /thread_updates to fill gaps between + syncs. This verifies that using both bounds together provides complete + coverage with no gaps or duplicates. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create 3 threads initially + initial_thread_ids = [] + for i in range(3): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + initial_thread_ids.append(thread_root_id) + + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # First sync + sync_body = { + "lists": { + "all-rooms": { + "ranges": [[0, 10]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + } + }, + } + response_body, pos1 = self.do_sync(sync_body, tok=user1_tok) + + # Should get 3 thread updates + first_sync_threads = set( + response_body["extensions"][EXT_NAME]["updates"][room_id].keys() + ) + self.assertEqual(len(first_sync_threads), 3) + + # Create 3 more threads after the first sync + new_thread_ids = [] + for i in range(3, 6): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + new_thread_ids.append(thread_root_id) + + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Second sync with limit=1 to get only some of the new threads + sync_body_with_limit = { + "lists": { + "all-rooms": { + "ranges": [[0, 10]], + "required_state": [], + "timeline_limit": 0, + } + }, + "extensions": { + EXT_NAME: { + "enabled": True, + "limit": 1, + } + }, + } + response_body, pos2 = self.do_sync( + sync_body_with_limit, tok=user1_tok, since=pos1 + ) + + # Should get 1 thread update + second_sync_threads = set( + response_body["extensions"][EXT_NAME]["updates"][room_id].keys() + ) + self.assertEqual(len(second_sync_threads), 1) + + # Get prev_batch from the extension + self.assertIn("prev_batch", response_body["extensions"][EXT_NAME]) + prev_batch = response_body["extensions"][EXT_NAME]["prev_batch"] + + # Now use /thread_updates with from=prev_batch and to=pos1 + # This should get the 2 remaining new threads (created after pos1, not returned in second sync) + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from={prev_batch}&to={pos1}", + access_token=user1_tok, + content={}, + ) + self.assertEqual(channel.code, 200) + + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + thread_updates_threads = set(chunk[room_id].keys()) + + # Should get exactly 2 threads + self.assertEqual(len(thread_updates_threads), 2) + + # Verify no overlap with second sync + self.assertEqual( + len(second_sync_threads & thread_updates_threads), + 0, + "No thread should appear in both second sync and thread_updates responses", + ) + + # Verify no overlap with first sync (to=pos1 should exclude those) + self.assertEqual( + len(first_sync_threads & thread_updates_threads), + 0, + "Threads from first sync should not appear in thread_updates (to=pos1 excludes them)", + ) + + # Verify no gaps - all new threads should be accounted for + all_new_threads = set(new_thread_ids) + combined_new_threads = second_sync_threads | thread_updates_threads + self.assertEqual( + combined_new_threads, + all_new_threads, + "Combined responses should include all new thread updates with no gaps", + ) + def test_threads_only_from_rooms_in_list(self) -> None: """ Test that thread updates are only returned for rooms that are in the diff --git a/tests/rest/client/test_thread_updates.py b/tests/rest/client/test_thread_updates.py new file mode 100644 index 00000000000..1418a35dd5e --- /dev/null +++ b/tests/rest/client/test_thread_updates.py @@ -0,0 +1,957 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +import logging + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.api.constants import RelationTypes +from synapse.rest.client import login, relations, room +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.clock import Clock + +from tests import unittest + +logger = logging.getLogger(__name__) + + +class ThreadUpdatesTestCase(unittest.HomeserverTestCase): + """ + Test the /thread_updates companion endpoint (MSC4360). + """ + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + room.register_servlets, + relations.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + config["experimental_features"] = {"msc4360_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + + def test_no_updates_for_new_user(self) -> None: + """ + Test that a user with no thread updates gets an empty response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Request thread updates + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + content={"include_roots": True}, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert empty chunk and no next_batch + self.assertEqual(channel.json_body["chunk"], {}) + self.assertNotIn("next_batch", channel.json_body) + + def test_single_thread_update(self) -> None: + """ + Test that a single thread with one reply appears in the response. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_resp = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root_id = thread_root_resp["event_id"] + + # Add reply to thread + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert thread is present + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + self.assertIn(thread_root_id, chunk[room_id]) + + # Assert thread root is included + thread_update = chunk[room_id][thread_root_id] + self.assertIn("thread_root", thread_update) + self.assertEqual(thread_update["thread_root"]["event_id"], thread_root_id) + + # Assert prev_batch is NOT present (only 1 update - the reply) + self.assertNotIn("prev_batch", thread_update) + + def test_multiple_threads_single_room(self) -> None: + """ + Test that multiple threads in the same room are grouped correctly. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create two threads + thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[ + "event_id" + ] + + # Add replies to both threads + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread1_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 2", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread2_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert both threads are in the same room + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + self.assertEqual(len(chunk), 1, "Should only have one room") + self.assertEqual(len(chunk[room_id]), 2, "Should have two threads") + self.assertIn(thread1_root_id, chunk[room_id]) + self.assertIn(thread2_root_id, chunk[room_id]) + + def test_threads_across_multiple_rooms(self) -> None: + """ + Test that threads from different rooms are grouped by room_id. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in both rooms + thread_a_root_id = self.helper.send(room_a_id, body="Thread A", tok=user1_tok)[ + "event_id" + ] + thread_b_root_id = self.helper.send(room_b_id, body="Thread B", tok=user1_tok)[ + "event_id" + ] + + # Add replies + self.helper.send_event( + room_a_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to A", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_a_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_b_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to B", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_b_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert both rooms are present with their threads + chunk = channel.json_body["chunk"] + self.assertEqual(len(chunk), 2, "Should have two rooms") + self.assertIn(room_a_id, chunk) + self.assertIn(room_b_id, chunk) + self.assertIn(thread_a_root_id, chunk[room_a_id]) + self.assertIn(thread_b_root_id, chunk[room_b_id]) + + def test_pagination_with_from_token(self) -> None: + """ + Test that pagination works using the next_batch token. + This verifies that multiple calls to /thread_updates return all thread + updates with no duplicates and no gaps. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create many threads (more than default limit) + thread_ids = [] + for i in range(5): + thread_root_id = self.helper.send( + room_id, body=f"Thread {i}", tok=user1_tok + )["event_id"] + thread_ids.append(thread_root_id) + + # Add reply + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply to thread {i}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request first page with small limit + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Should have 2 threads and a next_batch token + first_page_threads = set(channel.json_body["chunk"][room_id].keys()) + self.assertEqual(len(first_page_threads), 2) + self.assertIn("next_batch", channel.json_body) + + next_batch = channel.json_body["next_batch"] + + # Request second page + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch}", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + second_page_threads = set(channel.json_body["chunk"][room_id].keys()) + self.assertEqual(len(second_page_threads), 2) + + # Verify no overlap + self.assertEqual( + len(first_page_threads & second_page_threads), + 0, + "Pages should not have overlapping threads", + ) + + # Request third page to get the remaining thread + self.assertIn("next_batch", channel.json_body) + next_batch_2 = channel.json_body["next_batch"] + + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=2&from={next_batch_2}", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + third_page_threads = set(channel.json_body["chunk"][room_id].keys()) + self.assertEqual(len(third_page_threads), 1) + + # Verify no overlap between any pages + self.assertEqual(len(first_page_threads & third_page_threads), 0) + self.assertEqual(len(second_page_threads & third_page_threads), 0) + + # Verify no gaps - all threads should be accounted for across all pages + all_threads = set(thread_ids) + combined_threads = first_page_threads | second_page_threads | third_page_threads + self.assertEqual( + combined_threads, + all_threads, + "Combined pages should include all thread updates with no gaps", + ) + + def test_invalid_dir_parameter(self) -> None: + """ + Test that forward pagination (dir=f) is rejected with an error. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Request with forward direction should fail + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=f", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 400) + + def test_invalid_limit_parameter(self) -> None: + """ + Test that invalid limit values are rejected. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Zero limit should fail + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=0", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 400) + + # Negative limit should fail + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=-5", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 400) + + def test_invalid_pagination_tokens(self) -> None: + """ + Test that invalid from/to tokens are rejected with appropriate errors. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Invalid from token + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&from=invalid_token", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 400) + + # Invalid to token + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to=invalid_token", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 400) + + def test_to_token_filtering(self) -> None: + """ + Test that the to_token parameter correctly limits pagination to updates + newer than the to_token (since we paginate backwards from newest to oldest). + This also verifies the to_token boundary is exclusive - updates at exactly + the to_token position should not be included (as they were already returned + in a previous response that synced up to that position). + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create two thread roots + thread1_root_id = self.helper.send(room_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room_id, body="Thread 2", tok=user1_tok)[ + "event_id" + ] + + # Send replies to both threads + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread1_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 2", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread2_root_id, + }, + }, + tok=user1_tok, + ) + + # Request with limit=1 to get only the latest thread update + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=1", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200) + self.assertIn("next_batch", channel.json_body) + + # next_batch points to before the update we just received + next_batch = channel.json_body["next_batch"] + first_response_threads = set(channel.json_body["chunk"][room_id].keys()) + + # Request again with to=next_batch (lower bound for backward pagination) and no + # limit. + # This should get only the same thread updates as before, not the additional + # update. + channel = self.make_request( + "POST", + f"/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&to={next_batch}", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200) + + chunk = channel.json_body["chunk"] + self.assertIn(room_id, chunk) + # Should have exactly one thread update + self.assertEqual(len(chunk[room_id]), 1) + + second_response_threads = set(chunk[room_id].keys()) + + # Verify no overlap - the from parameter boundary should be exclusive + self.assertEqual( + first_response_threads, + second_response_threads, + "to parameter boundary should be exclusive - both responses should be identical", + ) + + def test_bundled_aggregations_on_thread_roots(self) -> None: + """ + Test that thread root events include bundled aggregations with latest thread event. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create thread root + thread_root_id = self.helper.send(room_id, body="Thread root", tok=user1_tok)[ + "event_id" + ] + + # Send replies to create bundled aggregation data + for i in range(2): + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": f"Reply {i + 1}", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200) + + # Check that thread root has bundled aggregations with latest event + chunk = channel.json_body["chunk"] + thread_update = chunk[room_id][thread_root_id] + thread_root_event = thread_update["thread_root"] + + # Should have unsigned data with latest thread event content + self.assertIn("unsigned", thread_root_event) + self.assertIn("m.relations", thread_root_event["unsigned"]) + relations = thread_root_event["unsigned"]["m.relations"] + self.assertIn(RelationTypes.THREAD, relations) + + # Check latest event is present in bundled aggregations + thread_summary = relations[RelationTypes.THREAD] + self.assertIn("latest_event", thread_summary) + latest_event = thread_summary["latest_event"] + self.assertEqual(latest_event["content"]["body"], "Reply 2") + + def test_only_joined_rooms(self) -> None: + """ + Test that thread updates only include rooms where the user is currently joined. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create two rooms, user1 joins both + room1_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room2_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room2_id, user1_id, tok=user1_tok) + + # Create threads in both rooms + thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user2_tok)[ + "event_id" + ] + + # Add replies to both threads + self.helper.send_event( + room1_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 1", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread1_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + room2_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply to thread 2", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread2_root_id, + }, + }, + tok=user2_tok, + ) + + # User1 leaves room2 + self.helper.leave(room2_id, user1_id, tok=user1_tok) + + # Request thread updates for user1 - should only get room1 + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={"include_roots": True}, + ) + self.assertEqual(channel.code, 200) + + chunk = channel.json_body["chunk"] + # Should only have room1, not room2 + self.assertIn(room1_id, chunk) + self.assertNotIn(room2_id, chunk) + self.assertIn(thread1_root_id, chunk[room1_id]) + + def test_room_filtering_with_lists(self) -> None: + """ + Test that room filtering works correctly using the lists parameter. + This verifies that thread updates are only returned for rooms matching + the provided filters. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create an encrypted room and an unencrypted room + encrypted_room_id = self.helper.create_room_as( + user1_id, + tok=user1_tok, + extra_content={ + "initial_state": [ + { + "type": "m.room.encryption", + "state_key": "", + "content": {"algorithm": "m.megolm.v1.aes-sha2"}, + } + ] + }, + ) + unencrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in both rooms + enc_thread_root_id = self.helper.send( + encrypted_room_id, body="Encrypted thread", tok=user1_tok + )["event_id"] + unenc_thread_root_id = self.helper.send( + unencrypted_room_id, body="Unencrypted thread", tok=user1_tok + )["event_id"] + + # Add replies to both threads + self.helper.send_event( + encrypted_room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply in encrypted room", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": enc_thread_root_id, + }, + }, + tok=user1_tok, + ) + self.helper.send_event( + unencrypted_room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply in unencrypted room", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": unenc_thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates with filter for encrypted rooms only + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={ + "lists": { + "encrypted_list": { + "ranges": [[0, 99]], + "required_state": [["m.room.encryption", ""]], + "timeline_limit": 10, + "filters": {"is_encrypted": True}, + } + } + }, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + # Should only include the encrypted room + self.assertIn(encrypted_room_id, chunk) + self.assertNotIn(unencrypted_room_id, chunk) + self.assertIn(enc_thread_root_id, chunk[encrypted_room_id]) + + def test_room_filtering_with_room_subscriptions(self) -> None: + """ + Test that room filtering works correctly using the room_subscriptions parameter. + This verifies that thread updates are only returned for explicitly subscribed rooms. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create three rooms + room1_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room2_id = self.helper.create_room_as(user1_id, tok=user1_tok) + room3_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in all three rooms + thread1_root_id = self.helper.send(room1_id, body="Thread 1", tok=user1_tok)[ + "event_id" + ] + thread2_root_id = self.helper.send(room2_id, body="Thread 2", tok=user1_tok)[ + "event_id" + ] + thread3_root_id = self.helper.send(room3_id, body="Thread 3", tok=user1_tok)[ + "event_id" + ] + + # Add replies to all threads + for room_id, thread_root_id in [ + (room1_id, thread1_root_id), + (room2_id, thread2_root_id), + (room3_id, thread3_root_id), + ]: + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates with subscription to only room1 and room2 + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={ + "room_subscriptions": { + room1_id: { + "required_state": [["m.room.name", ""]], + "timeline_limit": 10, + }, + room2_id: { + "required_state": [["m.room.name", ""]], + "timeline_limit": 10, + }, + } + }, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + # Should only include room1 and room2, not room3 + self.assertIn(room1_id, chunk) + self.assertIn(room2_id, chunk) + self.assertNotIn(room3_id, chunk) + self.assertIn(thread1_root_id, chunk[room1_id]) + self.assertIn(thread2_root_id, chunk[room2_id]) + + def test_room_filtering_with_lists_and_room_subscriptions(self) -> None: + """ + Test that room filtering works correctly when both lists and room_subscriptions + are provided. The union of rooms from both should be included. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create an encrypted room and two unencrypted rooms + encrypted_room_id = self.helper.create_room_as( + user1_id, + tok=user1_tok, + extra_content={ + "initial_state": [ + { + "type": "m.room.encryption", + "state_key": "", + "content": {"algorithm": "m.megolm.v1.aes-sha2"}, + } + ] + }, + ) + unencrypted_room1_id = self.helper.create_room_as(user1_id, tok=user1_tok) + unencrypted_room2_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + # Create threads in all three rooms + enc_thread_root_id = self.helper.send( + encrypted_room_id, body="Encrypted thread", tok=user1_tok + )["event_id"] + unenc1_thread_root_id = self.helper.send( + unencrypted_room1_id, body="Unencrypted thread 1", tok=user1_tok + )["event_id"] + unenc2_thread_root_id = self.helper.send( + unencrypted_room2_id, body="Unencrypted thread 2", tok=user1_tok + )["event_id"] + + # Add replies to all threads + for room_id, thread_root_id in [ + (encrypted_room_id, enc_thread_root_id), + (unencrypted_room1_id, unenc1_thread_root_id), + (unencrypted_room2_id, unenc2_thread_root_id), + ]: + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root_id, + }, + }, + tok=user1_tok, + ) + + # Request thread updates with: + # - lists: filter for encrypted rooms + # - room_subscriptions: explicitly subscribe to unencrypted_room1_id + # Expected: should get both encrypted_room_id (from list) and unencrypted_room1_id + # (from subscription), but NOT unencrypted_room2_id + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b", + access_token=user1_tok, + content={ + "lists": { + "encrypted_list": { + "ranges": [[0, 99]], + "required_state": [["m.room.encryption", ""]], + "timeline_limit": 10, + "filters": {"is_encrypted": True}, + } + }, + "room_subscriptions": { + unencrypted_room1_id: { + "required_state": [["m.room.name", ""]], + "timeline_limit": 10, + } + }, + }, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + chunk = channel.json_body["chunk"] + # Should include encrypted_room_id (from list filter) and unencrypted_room1_id + # (from subscription), but not unencrypted_room2_id + self.assertIn(encrypted_room_id, chunk) + self.assertIn(unencrypted_room1_id, chunk) + self.assertNotIn(unencrypted_room2_id, chunk) + self.assertIn(enc_thread_root_id, chunk[encrypted_room_id]) + self.assertIn(unenc1_thread_root_id, chunk[unencrypted_room1_id]) + + def test_threads_not_returned_after_leaving_room(self) -> None: + """ + Test that thread updates are properly bounded when a user leaves a room. + + Users should see thread updates that occurred up to the point they left, + but NOT updates that occurred after they left. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create room and both users join + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + self.helper.join(room_id, user2_id, tok=user2_tok) + + # Create thread + res = self.helper.send(room_id, body="Thread root", tok=user1_tok) + thread_root = res["event_id"] + + # Reply in thread while user2 is joined + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 1 while user2 joined", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root, + }, + }, + tok=user1_tok, + ) + + # User2 leaves the room + self.helper.leave(room_id, user2_id, tok=user2_tok) + + # Another reply after user2 left + self.helper.send_event( + room_id, + type="m.room.message", + content={ + "msgtype": "m.text", + "body": "Reply 2 after user2 left", + "m.relates_to": { + "rel_type": RelationTypes.THREAD, + "event_id": thread_root, + }, + }, + tok=user1_tok, + ) + + # User2 gets thread updates with an explicit room subscription + # (We need to explicitly subscribe to the room to include it after leaving; + # otherwise only joined rooms are returned) + channel = self.make_request( + "POST", + "/_matrix/client/unstable/io.element.msc4360/thread_updates?dir=b&limit=100", + { + "room_subscriptions": { + room_id: { + "required_state": [], + "timeline_limit": 0, + } + } + }, + access_token=user2_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Assert: User2 SHOULD see Reply 1 (happened while joined) but NOT Reply 2 (after leaving) + chunk = channel.json_body["chunk"] + self.assertIn( + room_id, + chunk, + "Thread updates should include the room user2 left", + ) + self.assertIn( + thread_root, + chunk[room_id], + "Thread root should be in the updates", + ) + + # Verify that only a single update was seen (Reply 1) by checking that there's + # no prev_batch token. If Reply 2 was also included, there would be multiple + # updates and a prev_batch token would be present. + thread_update = chunk[room_id][thread_root] + self.assertNotIn( + "prev_batch", + thread_update, + "No prev_batch should be present since only one update (Reply 1) is visible", + )