diff --git a/changelog.d/19005.feature b/changelog.d/19005.feature new file mode 100644 index 00000000000..811d2e31af8 --- /dev/null +++ b/changelog.d/19005.feature @@ -0,0 +1 @@ +Add experimental support for MSC4360: Sliding Sync Threads Extension. diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index d7a3d675583..8546e2fc400 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -595,3 +595,6 @@ def read_config( # MSC4306: Thread Subscriptions # (and MSC4308: Thread Subscriptions extension to Sliding Sync) self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False) + + # MSC4360: Threads Extension to Sliding Sync + self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False) diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py index b1158ee77d5..54901b56452 100644 --- a/synapse/handlers/relations.py +++ b/synapse/handlers/relations.py @@ -109,8 +109,6 @@ async def get_relations( ) -> JsonDict: """Get related events of a event, ordered by topological ordering. - TODO Accept a PaginationConfig instead of individual pagination parameters. - Args: requester: The user requesting the relations. event_id: Fetch events that relate to this event ID. 
diff --git a/synapse/handlers/sliding_sync/extensions.py b/synapse/handlers/sliding_sync/extensions.py index 25ee954b7fd..2c39838fa8b 100644 --- a/synapse/handlers/sliding_sync/extensions.py +++ b/synapse/handlers/sliding_sync/extensions.py @@ -24,12 +24,13 @@ Optional, Sequence, Set, + Tuple, cast, ) from typing_extensions import TypeAlias, assert_never -from synapse.api.constants import AccountDataTypes, EduTypes +from synapse.api.constants import AccountDataTypes, EduTypes, RelationTypes from synapse.handlers.receipts import ReceiptEventSource from synapse.logging.opentracing import trace from synapse.storage.databases.main.receipts import ReceiptInRoom @@ -61,6 +62,7 @@ _ThreadUnsubscription: TypeAlias = ( SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription ) +_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate if TYPE_CHECKING: from synapse.server import HomeServer @@ -76,7 +78,9 @@ def __init__(self, hs: "HomeServer"): self.event_sources = hs.get_event_sources() self.device_handler = hs.get_device_handler() self.push_rules_handler = hs.get_push_rules_handler() + self.relations_handler = hs.get_relations_handler() self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled + self._enable_threads_ext = hs.config.experimental.msc4360_enabled @trace async def get_extensions_response( @@ -177,6 +181,16 @@ async def get_extensions_response( from_token=from_token, ) + threads_coro = None + if sync_config.extensions.threads is not None and self._enable_threads_ext: + threads_coro = self.get_threads_extension_response( + sync_config=sync_config, + threads_request=sync_config.extensions.threads, + actual_room_response_map=actual_room_response_map, + to_token=to_token, + from_token=from_token, + ) + ( to_device_response, e2ee_response, @@ -184,6 +198,7 @@ async def get_extensions_response( receipts_response, typing_response, thread_subs_response, + threads_response, ) = await 
async def get_threads_extension_response(
    self,
    sync_config: SlidingSyncConfig,
    threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
    actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
    to_token: StreamToken,
    from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]:
    """Handle Threads extension (MSC4360)

    Args:
        sync_config: Sync configuration.
        threads_request: The threads extension from the request.
        actual_room_response_map: A map of room ID to room results in the
            sliding sync response. Used to determine which threads already have
            events in the room timeline.
        to_token: The point in the stream to sync up to.
        from_token: The point in the stream to sync from.

    Returns:
        The extension response, or None when the extension is not enabled in
        the request, when the store reports no thread updates in the window, or
        when every update was dropped because its thread is already visible in
        a returned room timeline.
    """
    if not threads_request.enabled:
        return None

    user_id = sync_config.user.to_string()

    # Fetch thread updates globally across all joined rooms.
    # The database layer returns a StreamToken (exclusive) for prev_batch if there
    # are more results.
    (
        all_thread_updates,
        prev_batch_token,
    ) = await self.store.get_thread_updates_for_user(
        user_id=user_id,
        from_token=from_token.stream_token.room_key if from_token else None,
        to_token=to_token.room_key,
        limit=threads_request.limit,
        include_thread_roots=threads_request.include_roots,
    )

    if not all_thread_updates:
        return None

    # Identify which threads already have events in the room timelines.
    # If include_roots=False, we omit these threads from the extension response
    # since the client already sees the thread activity in the timeline.
    # If include_roots=True, the set stays empty and nothing is filtered, because
    # the client wants the thread root events.
    threads_in_timeline: Set[Tuple[str, str]] = set()  # (room_id, thread_id)
    if not threads_request.include_roots:
        for room_id, room_result in actual_room_response_map.items():
            for event in room_result.timeline_events or ():
                relates_to = event.content.get("m.relates_to")
                if not isinstance(relates_to, dict):
                    continue

                if relates_to.get("rel_type") != RelationTypes.THREAD:
                    continue

                # Event content is arbitrary JSON: only accept a non-empty
                # string here. An unhashable value (list/dict) would otherwise
                # make the set insertion raise and fail the whole sync.
                thread_id = relates_to.get("event_id")
                if isinstance(thread_id, str) and thread_id:
                    threads_in_timeline.add((room_id, thread_id))

    # Collect thread root events and get bundled aggregations.
    # Only fetch bundled aggregations if we have thread root events to attach
    # them to.
    thread_root_events = [
        update.thread_root_event
        for update in all_thread_updates
        if update.thread_root_event
    ]
    aggregations_map = {}
    if thread_root_events:
        aggregations_map = await self.relations_handler.get_bundled_aggregations(
            thread_root_events,
            user_id,
        )

    thread_updates: Dict[str, Dict[str, _ThreadUpdate]] = {}
    for update in all_thread_updates:
        # Skip threads whose replies are already in a returned room timeline.
        # (The set is only populated when include_roots=False.)
        if (update.room_id, update.thread_id) in threads_in_timeline:
            continue

        # Only look up bundled aggregations if we have a thread root event
        bundled_aggs = (
            aggregations_map.get(update.thread_id)
            if update.thread_root_event
            else None
        )

        thread_updates.setdefault(update.room_id, {})[update.thread_id] = (
            _ThreadUpdate(
                thread_root=update.thread_root_event,
                prev_batch=update.prev_batch,
                bundled_aggregations=bundled_aggs,
            )
        )

    # If after filtering we have no thread updates, return None to omit the
    # extension from the response.
    if not thread_updates:
        return None

    return SlidingSyncResult.Extensions.ThreadsExtension(
        updates=thread_updates,
        prev_batch=prev_batch_token,
    )
SlidingSyncStreamToken, StreamToken from synapse.types.rest.client import SlidingSyncBody from synapse.util.caches.lrucache import LruCache @@ -648,6 +650,7 @@ class SlidingSyncRestServlet(RestServlet): - receipts (MSC3960) - account data (MSC3959) - thread subscriptions (MSC4308) + - threads (MSC4360) Request query parameters: timeout: How long to wait for new events in milliseconds. @@ -851,7 +854,10 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: logger.info("Client has disconnected; not serializing response.") return 200, {} - response_content = await self.encode_response(requester, sliding_sync_results) + time_now = self.clock.time_msec() + response_content = await self.encode_response( + requester, sliding_sync_results, time_now + ) return 200, response_content @@ -860,6 +866,7 @@ async def encode_response( self, requester: Requester, sliding_sync_result: SlidingSyncResult, + time_now: int, ) -> JsonDict: response: JsonDict = defaultdict(dict) @@ -868,10 +875,10 @@ async def encode_response( if serialized_lists: response["lists"] = serialized_lists response["rooms"] = await self.encode_rooms( - requester, sliding_sync_result.rooms + requester, sliding_sync_result.rooms, time_now ) response["extensions"] = await self.encode_extensions( - requester, sliding_sync_result.extensions + requester, sliding_sync_result.extensions, time_now ) return response @@ -903,9 +910,8 @@ async def encode_rooms( self, requester: Requester, rooms: Dict[str, SlidingSyncResult.RoomResult], + time_now: int, ) -> JsonDict: - time_now = self.clock.time_msec() - serialize_options = SerializeEventConfig( event_format=format_event_for_client_v2_without_room_id, requester=requester, @@ -1021,7 +1027,10 @@ async def encode_rooms( @trace_with_opname("sliding_sync.encode_extensions") async def encode_extensions( - self, requester: Requester, extensions: SlidingSyncResult.Extensions + self, + requester: Requester, + extensions: SlidingSyncResult.Extensions, + 
async def _serialise_threads(
    event_serializer: EventClientSerializer,
    time_now: int,
    threads: SlidingSyncResult.Extensions.ThreadsExtension,
    store: "DataStore",
) -> JsonDict:
    """
    Turn the threads extension result into its JSON response form.

    Args:
        event_serializer: Serializer used for the thread root events.
        time_now: Current time in milliseconds, passed through to event
            serialization.
        threads: The extension result holding per-room thread updates and the
            overall pagination token.
        store: The datastore, required to stringify stream tokens.

    Returns:
        A JSON-serializable dict with up to two keys:
            - "updates": room_id -> thread_root_id -> update dict. An update dict
              may hold "thread_root" (the serialized root event, with bundled
              aggregations passed to the serializer when there are any) and
              "prev_batch" (token for older events in that thread). It is an
              empty dict when the update carries neither.
            - "prev_batch": token for older thread updates, when one is set.
    """
    serialized: JsonDict = {}

    if threads.updates:
        rooms_out: JsonDict = {}
        for room_id, per_thread in threads.updates.items():
            threads_out: JsonDict = {}
            for thread_root_id, update in per_thread.items():
                entry: JsonDict = {}

                root_event = update.thread_root
                if root_event is not None:
                    # The serializer takes aggregations keyed by event ID; pass
                    # None rather than a map holding an empty aggregation.
                    aggregations = None
                    if update.bundled_aggregations:
                        aggregations = {thread_root_id: update.bundled_aggregations}

                    events_out = await event_serializer.serialize_events(
                        [root_event],
                        time_now,
                        bundle_aggregations=aggregations,
                    )
                    if events_out:
                        entry["thread_root"] = events_out[0]

                if update.prev_batch is not None:
                    entry["prev_batch"] = await update.prev_batch.to_string(store)

                threads_out[thread_root_id] = entry

            rooms_out[room_id] = threads_out

        serialized["updates"] = rooms_out

    if threads.prev_batch:
        serialized["prev_batch"] = await threads.prev_batch.to_string(store)

    return serialized
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadUpdateInfo:
    """
    Information about a thread update for the sliding sync threads extension.

    Attributes:
        thread_id: The event ID of the thread root event (the event that started
            the thread).
        room_id: The room ID where this thread exists.
        thread_root_event: The EventBase for the thread root event, or None when
            it was not fetched (e.g. thread roots were not requested).
        prev_batch: A pagination token (exclusive) for fetching older events in
            this thread, or None when the producer found at most one thread
            reply inside the requested window. Intended for use with the
            /relations endpoint with dir=b, so that paginating backwards does
            not re-deliver the latest event.
    """

    # Event ID of the thread root.
    thread_id: str
    # Room that contains the thread.
    room_id: str
    # Thread root event, when it was fetched.
    thread_root_event: Optional[EventBase]
    # Exclusive token for back-paginating this thread, when there is more to fetch.
    prev_batch: Optional[StreamToken]
return { - original_event_id: edits.get(edit_ids.get(original_event_id)) + original_event_id: ( + edits.get(edit_id) + if (edit_id := edit_ids.get(original_event_id)) + else None + ) for original_event_id in event_ids } @@ -706,7 +738,7 @@ def _get_thread_summaries_txn( "get_thread_summaries", _get_thread_summaries_txn ) - latest_events = await self.get_events(latest_event_ids.values()) # type: ignore[attr-defined] + latest_events = await self.get_events(latest_event_ids.values()) # Map to the event IDs to the thread summary. # @@ -1118,6 +1150,170 @@ def _get_related_thread_id(txn: LoggingTransaction) -> str: "get_related_thread_id", _get_related_thread_id ) + async def get_thread_updates_for_user( + self, + *, + user_id: str, + from_token: Optional[RoomStreamToken] = None, + to_token: Optional[RoomStreamToken] = None, + limit: int = 5, + include_thread_roots: bool = False, + ) -> Tuple[Sequence[ThreadUpdateInfo], Optional[StreamToken]]: + """Get a list of updated threads, ordered by stream ordering of their + latest reply, filtered to only include threads in rooms where the user + is currently joined. + + Args: + user_id: The user ID to fetch thread updates for. Only threads in rooms + where this user is currently joined will be returned. + from_token: The lower bound (exclusive) for thread updates. If None, + fetch from the start of the room timeline. + to_token: The upper bound (inclusive) for thread updates. If None, + fetch up to the current position in the room timeline. + limit: Maximum number of thread updates to return. + include_thread_roots: If True, fetch and return the thread root EventBase + objects. If False, return None for the thread_root_event field. + + Returns: + A tuple of: + A list of ThreadUpdateInfo objects containing thread update information, + ordered by stream_ordering descending (most recent first). + A prev_batch StreamToken (exclusive) if there are more results available, + None otherwise. 
+ """ + # Ensure bad limits aren't being passed in. + assert limit > 0 + + # Generate the pagination clause, if necessary. + # + # Find any threads where the latest reply is between the stream ordering bounds. + pagination_clause = "" + pagination_args: List[str] = [] + if from_token: + from_bound = from_token.stream + pagination_clause += " AND stream_ordering > ?" + pagination_args.append(str(from_bound)) + + if to_token: + to_bound = to_token.stream + pagination_clause += " AND stream_ordering <= ?" + pagination_args.append(str(to_bound)) + + # Build the update count clause - count events in the thread within the sync window + update_count_clause = "" + update_count_args: List[str] = [] + update_count_clause = f""" + (SELECT COUNT(*) + FROM event_relations AS er + INNER JOIN events AS e ON er.event_id = e.event_id + WHERE er.relates_to_id = threads.thread_id + AND er.relation_type = '{RelationTypes.THREAD}'""" + if from_token: + update_count_clause += " AND e.stream_ordering > ?" + update_count_args.append(str(from_token.stream)) + if to_token: + update_count_clause += " AND e.stream_ordering <= ?" + update_count_args.append(str(to_token.stream)) + update_count_clause += ")" + + # Filter threads to only those in rooms where the user is currently joined. + sql = f""" + SELECT thread_id, room_id, stream_ordering, {update_count_clause} AS update_count + FROM threads + WHERE EXISTS ( + SELECT 1 + FROM local_current_membership AS lcm + WHERE lcm.room_id = threads.room_id + AND lcm.user_id = ? + AND lcm.membership = ? + ) + {pagination_clause} + ORDER BY stream_ordering DESC + LIMIT ? + """ + + def _get_thread_updates_for_user_txn( + txn: LoggingTransaction, + ) -> Tuple[List[Tuple[str, str, int, int]], Optional[int]]: + # Add 1 to the limit as a free way of determining if there are more results + # than the limit amount. If `limit + 1` results are returned, then there are + # more results. 
Otherwise we would need to do a separate query to determine + # if this was true when exactly `limit` results are returned. + txn.execute( + sql, + ( + *update_count_args, + user_id, + Membership.JOIN, + *pagination_args, + limit + 1, + ), + ) + + # SQL returns: thread_id, room_id, stream_ordering, update_count + rows = cast(List[Tuple[str, str, int, int]], txn.fetchall()) + + # If there are more events, generate the next pagination key from the + # last thread which will be returned. + next_token = None + if len(rows) > limit: + # Set the next_token to be the second last row in the result set since + # that will be the last row we return from this function. + # This works as an exclusive bound that can be backpaginated from. + # Use the stream_ordering field (index 2 in original rows) + next_token = rows[-2][2] + + return rows[:limit], next_token + + thread_infos, next_token_int = await self.db_pool.runInteraction( + "get_thread_updates_for_user", _get_thread_updates_for_user_txn + ) + + # Convert the next_token int (stream ordering) to a StreamToken. + # Use StreamToken.START as base (all other streams at 0) since only room + # position matters. + # Subtract 1 to make it exclusive - the client can paginate from this point without + # receiving the last thread update that was already returned. + next_token = None + if next_token_int is not None: + next_token = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, RoomStreamToken(stream=next_token_int - 1) + ) + + # Optionally fetch thread root events + event_map = {} + if include_thread_roots and thread_infos: + thread_root_ids = [thread_id for thread_id, _, _, _ in thread_infos] + thread_root_events = await self.get_events_as_list(thread_root_ids) + event_map = {e.event_id: e for e in thread_root_events} + + # Build ThreadUpdateInfo objects with per-thread prev_batch tokens. 
+ thread_update_infos = [] + for thread_id, room_id, stream_ordering, update_count in thread_infos: + # Generate prev_batch token if this thread has more than one update. + per_thread_prev_batch = None + if update_count > 1: + # Create a token pointing to one position before the latest event's + # stream position. + # This makes it exclusive - /relations with dir=b won't return the + # latest event again. + # Use StreamToken.START as base (all other streams at 0) since only room + # position matters. + per_thread_prev_batch = StreamToken.START.copy_and_replace( + StreamKeyType.ROOM, RoomStreamToken(stream=stream_ordering - 1) + ) + + thread_update_infos.append( + ThreadUpdateInfo( + thread_id=thread_id, + room_id=room_id, + thread_root_event=event_map.get(thread_id), + prev_batch=per_thread_prev_batch, + ) + ) + + return (thread_update_infos, next_token) + class RelationsStore(RelationsWorkerStore): pass diff --git a/synapse/storage/schema/main/delta/92/10_thread_updates_indexes.sql b/synapse/storage/schema/main/delta/92/10_thread_updates_indexes.sql new file mode 100644 index 00000000000..b566c6e5461 --- /dev/null +++ b/synapse/storage/schema/main/delta/92/10_thread_updates_indexes.sql @@ -0,0 +1,33 @@ +-- +-- This file is licensed under the Affero General Public License (AGPL) version 3. +-- +-- Copyright (C) 2025 New Vector, Ltd +-- +-- This program is free software: you can redistribute it and/or modify +-- it under the terms of the GNU Affero General Public License as +-- published by the Free Software Foundation, either version 3 of the +-- License, or (at your option) any later version. +-- +-- See the GNU Affero General Public License for more details: +-- . + +-- Add indexes to improve performance of the thread_updates endpoint and +-- sliding sync threads extension (MSC4360). + +-- Index for efficiently finding all events that relate to a specific event +-- (e.g., all replies to a thread root). 
def extract_stream_token(token_str: str) -> str:
    """
    Extract the StreamToken portion from a token string.

    Handles both:
    - StreamToken format: "s123_456_..." (returned unchanged)
    - SlidingSyncStreamToken format: "5/s123_456_..." (returns the part after
      the first "/")

    This allows clients using sliding sync to use their pos tokens
    with endpoints like /relations and /messages.

    Args:
        token_str: The raw token string taken from the request.

    Returns:
        Everything after the first "/" if the string contains one, otherwise
        the string unchanged. No validation is done here; the result is still
        expected to be parsed as a StreamToken by the caller.
    """
    # partition() splits on the first "/" only and reports whether a separator
    # was found, so no length check on the pieces is needed.
    _connection_position, separator, stream_token = token_str.partition("/")
    return stream_token if separator else token_str
@attr.s(slots=True, frozen=True, auto_attribs=True)
class ThreadsExtension:
    """The Threads extension (MSC4360)

    Carries thread updates for the response, grouped by room.

    Attributes:
        updates: Nested mapping of room_id -> thread_root_id -> ThreadUpdate,
            or None. Each ThreadUpdate describes one thread, optionally with its
            root event and a token for paginating that thread.
        prev_batch: Pagination token for fetching further thread updates, or
            None when there is nothing more to page through.
    """

    @attr.s(slots=True, frozen=True, auto_attribs=True)
    class ThreadUpdate:
        """A single thread's entry in the extension response.

        Attributes:
            thread_root: The event that started the thread, when it is being
                returned; otherwise None.
            prev_batch: Pagination token (exclusive) for older events in this
                thread, for use with the /relations endpoint and dir=b.
                None when no token was generated for the thread.
            bundled_aggregations: Bundled aggregations for the thread root event
                (serialized under unsigned.m.relations). Defaults to None.
        """

        thread_root: Optional[EventBase]
        prev_batch: Optional[StreamToken]
        bundled_aggregations: Optional["BundledAggregations"] = None

        def __bool__(self) -> bool:
            # Truthy when there is a root event or a pagination token to send.
            if self.thread_root:
                return True
            return bool(self.prev_batch)

    updates: Optional[Mapping[str, Mapping[str, ThreadUpdate]]]
    prev_batch: Optional[StreamToken]

    def __bool__(self) -> bool:
        # Truthy when there are updates or a pagination token to send.
        if self.updates:
            return True
        return bool(self.prev_batch)
class ThreadsExtension(RequestBodyModel):
    """Request body for the Threads extension (MSC4360).

    Attributes:
        enabled: Turns the extension on for this request. Defaults to False.
        include_roots: When True, thread root events are included in the
            extension response. Defaults to False.
        limit: Upper bound on the number of thread updates returned across all
            joined rooms. Defaults to 100. Only the type is enforced here
            (strict int); no range check is declared on this model.
    """

    # Off unless the client opts in.
    enabled: Optional[StrictBool] = False
    # Whether to return thread root events alongside the updates.
    include_roots: StrictBool = False
    # Maximum number of thread updates to return.
    limit: StrictInt = 100
+@overload +async def gather_optional_coroutines( + *coroutines: Unpack[ + Tuple[ + Optional[Coroutine[Any, Any, T1]], + Optional[Coroutine[Any, Any, T2]], + Optional[Coroutine[Any, Any, T3]], + Optional[Coroutine[Any, Any, T4]], + Optional[Coroutine[Any, Any, T5]], + Optional[Coroutine[Any, Any, T6]], + Optional[Coroutine[Any, Any, T7]], + ] + ], +) -> Tuple[ + Optional[T1], + Optional[T2], + Optional[T3], + Optional[T4], + Optional[T5], + Optional[T6], + Optional[T7], +]: ... + + async def gather_optional_coroutines( *coroutines: Unpack[Tuple[Optional[Coroutine[Any, Any, T1]], ...]], ) -> Tuple[Optional[T1], ...]: diff --git a/tests/rest/client/sliding_sync/test_extension_threads.py b/tests/rest/client/sliding_sync/test_extension_threads.py new file mode 100644 index 00000000000..3855b036721 --- /dev/null +++ b/tests/rest/client/sliding_sync/test_extension_threads.py @@ -0,0 +1,836 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# . +# +import logging + +from twisted.test.proto_helpers import MemoryReactor + +import synapse.rest.admin +from synapse.api.constants import RelationTypes +from synapse.rest.client import login, relations, room, sync +from synapse.server import HomeServer +from synapse.types import JsonDict +from synapse.util.clock import Clock + +from tests.rest.client.sliding_sync.test_sliding_sync import SlidingSyncBase + +logger = logging.getLogger(__name__) + + +# The name of the extension. Currently unstable-prefixed. 
EXT_NAME = "io.element.msc4360.threads"


class SlidingSyncThreadsExtensionTestCase(SlidingSyncBase):
    """
    Test the threads extension (MSC4360) in the Sliding Sync API.
    """

    maxDiff = None

    servlets = [
        synapse.rest.admin.register_servlets,
        login.register_servlets,
        room.register_servlets,
        sync.register_servlets,
        relations.register_servlets,
    ]

    def default_config(self) -> JsonDict:
        """Enable the experimental MSC4360 flag on top of the default test config."""
        config = super().default_config()
        config["experimental_features"] = {"msc4360_enabled": True}
        return config

    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
        self.store = hs.get_datastores().main
        self.storage_controllers = hs.get_storage_controllers()
        super().prepare(reactor, clock, hs)

    def _threads_sync_body(
        self, *, with_room_list: bool = False, **extension_params: bool
    ) -> JsonDict:
        """
        Build a sliding sync request body with the threads extension enabled.

        Args:
            with_room_list: If True, also request a room list ("foo-list") covering
                the first two rooms with a timeline limit of 5, so that room
                timelines are returned alongside the extension.
            extension_params: Extra fields for the extension request
                (e.g. `include_roots=True`).

        Returns:
            The JSON request body to pass to `do_sync`.
        """
        sync_body: JsonDict = {}
        if with_room_list:
            sync_body["lists"] = {
                "foo-list": {
                    "ranges": [[0, 1]],
                    "required_state": [],
                    "timeline_limit": 5,
                }
            }
        sync_body["extensions"] = {EXT_NAME: {"enabled": True, **extension_params}}
        return sync_body

    def _send_thread_root(
        self, room_id: str, tok: str, body: str = "Thread root"
    ) -> str:
        """Send a plain message that will act as a thread root; returns its event ID."""
        return self.helper.send(room_id, body=body, tok=tok)["event_id"]

    def _send_thread_reply(
        self, room_id: str, thread_root_id: str, body: str, tok: str
    ) -> str:
        """
        Send an `m.room.message` related to `thread_root_id` via an `m.thread`
        relation; returns the event ID of the reply.
        """
        response = self.helper.send_event(
            room_id,
            type="m.room.message",
            content={
                "msgtype": "m.text",
                "body": body,
                "m.relates_to": {
                    "rel_type": RelationTypes.THREAD,
                    "event_id": thread_root_id,
                },
            },
            tok=tok,
        )
        return response["event_id"]

    def test_no_data_initial_sync(self) -> None:
        """
        Test enabling threads extension during initial sync with no data.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")

        # Sync
        response_body, _ = self.do_sync(self._threads_sync_body(), tok=user1_tok)

        # Assert
        self.assertNotIn(EXT_NAME, response_body["extensions"])

    def test_no_data_incremental_sync(self) -> None:
        """
        Test enabling threads extension during incremental sync with no data.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        initial_sync_body: JsonDict = {}

        # Initial sync (extension not requested)
        _, sync_pos = self.do_sync(initial_sync_body, tok=user1_tok)

        # Incremental sync with extension enabled
        response_body, _ = self.do_sync(
            self._threads_sync_body(), tok=user1_tok, since=sync_pos
        )

        # Assert
        self.assertNotIn(
            EXT_NAME,
            response_body["extensions"],
            response_body,
        )

    def test_threads_initial_sync(self) -> None:
        """
        Test threads appear in initial sync response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        thread_root_id = self._send_thread_root(room_id, user1_tok)
        self._send_thread_reply(room_id, thread_root_id, user1_id, user1_tok)

        # Sync
        response_body, _ = self.do_sync(self._threads_sync_body(), tok=user1_tok)

        # Assert
        self.assertEqual(
            response_body["extensions"][EXT_NAME],
            {"updates": {room_id: {thread_root_id: {}}}},
        )

    def test_threads_incremental_sync(self) -> None:
        """
        Test new thread updates appear in incremental sync response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        sync_body = self._threads_sync_body()
        thread_root_id = self._send_thread_root(room_id, user1_tok)

        # Initial sync
        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)
        logger.info("Synced to: %r, now replying in thread", sync_pos)

        # Reply in the thread after the initial sync
        self._send_thread_reply(room_id, thread_root_id, user1_id, user1_tok)

        # Incremental sync
        response_body, sync_pos = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)
        logger.info("Synced to: %r", sync_pos)

        # Assert
        self.assertEqual(
            response_body["extensions"][EXT_NAME],
            {"updates": {room_id: {thread_root_id: {}}}},
        )

    def test_threads_only_from_joined_rooms(self) -> None:
        """
        Test that thread updates are only returned for rooms the user is joined to
        at the time of the thread update.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # User1 creates two rooms
        room_a_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        room_b_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # User2 joins only Room A
        self.helper.join(room_a_id, user2_id, tok=user2_tok)

        # Create threads in both rooms
        thread_a_root = self._send_thread_root(room_a_id, user1_tok, body="Thread A")
        thread_b_root = self._send_thread_root(room_b_id, user1_tok, body="Thread B")

        # Add replies to both threads
        self._send_thread_reply(room_a_id, thread_a_root, "Reply to A", user1_tok)
        self._send_thread_reply(room_b_id, thread_b_root, "Reply to B", user1_tok)

        # User2 syncs with threads extension enabled
        response_body, _ = self.do_sync(self._threads_sync_body(), tok=user2_tok)

        # Assert: User2 should only see thread from Room A (where they are joined)
        self.assertEqual(
            response_body["extensions"][EXT_NAME],
            {"updates": {room_a_id: {thread_a_root: {}}}},
            "User2 should only see threads from Room A where they are joined, not Room B",
        )

    def test_threads_not_returned_after_leaving_room(self) -> None:
        """
        Test that thread updates are not returned after a user leaves the room,
        even if the thread was updated while they were joined.

        This tests the known limitation: if a thread has multiple updates and the
        user leaves between them, they won't see any updates (even earlier ones
        while joined).
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        user2_id = self.register_user("user2", "pass")
        user2_tok = self.login(user2_id, "pass")

        # Create room and both users join
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
        self.helper.join(room_id, user2_id, tok=user2_tok)

        # Create thread
        thread_root = self._send_thread_root(room_id, user1_tok)

        # Initial sync for user2
        sync_body = self._threads_sync_body()
        _, sync_pos = self.do_sync(sync_body, tok=user2_tok)

        # Reply in thread while user2 is joined, but after initial sync
        self._send_thread_reply(
            room_id, thread_root, "Reply 1 while user2 joined", user1_tok
        )

        # User2 leaves the room
        self.helper.leave(room_id, user2_id, tok=user2_tok)

        # Another reply after user2 left
        self._send_thread_reply(
            room_id, thread_root, "Reply 2 after user2 left", user1_tok
        )

        # User2 incremental sync
        response_body, _ = self.do_sync(sync_body, tok=user2_tok, since=sync_pos)

        # Assert: User2 should NOT see the thread update (they left before latest update)
        # Note: This also demonstrates that only currently joined rooms are returned - user2
        # won't see the thread even though there was an update while they were joined (Reply 1)
        self.assertNotIn(
            EXT_NAME,
            response_body["extensions"],
            "User2 should not see thread updates after leaving the room",
        )

    def test_threads_with_include_roots_true(self) -> None:
        """
        Test that include_roots=True returns thread root events with latest_event
        in the unsigned field.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root and add a reply to it
        thread_root_id = self._send_thread_root(room_id, user1_tok)
        latest_event_id = self._send_thread_reply(
            room_id, thread_root_id, "Latest reply", user1_tok
        )

        # Sync with include_roots=True
        response_body, _ = self.do_sync(
            self._threads_sync_body(include_roots=True), tok=user1_tok
        )

        # Assert thread root is present
        thread_root = response_body["extensions"][EXT_NAME]["updates"][room_id][
            thread_root_id
        ]["thread_root"]

        # Verify it's the correct event
        self.assertEqual(thread_root["event_id"], thread_root_id)
        self.assertEqual(thread_root["content"]["body"], "Thread root")

        # Verify latest_event is in unsigned.m.relations.m.thread
        latest_event = thread_root["unsigned"]["m.relations"]["m.thread"][
            "latest_event"
        ]
        self.assertEqual(latest_event["event_id"], latest_event_id)
        self.assertEqual(latest_event["content"]["body"], "Latest reply")

    def test_threads_with_include_roots_false(self) -> None:
        """
        Test that include_roots=False (or omitted) does not return thread root events.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread and add a reply
        thread_root_id = self._send_thread_root(room_id, user1_tok)
        self._send_thread_reply(room_id, thread_root_id, "Reply", user1_tok)

        # Sync with include_roots=False (explicitly)
        response_body, _ = self.do_sync(
            self._threads_sync_body(include_roots=False), tok=user1_tok
        )

        # Assert thread update exists but has no thread_root
        thread_update = response_body["extensions"][EXT_NAME]["updates"][room_id][
            thread_root_id
        ]
        self.assertNotIn("thread_root", thread_update)

        # Also test with include_roots omitted (should behave the same)
        response_body_no_param, _ = self.do_sync(
            self._threads_sync_body(), tok=user1_tok
        )

        thread_update_no_param = response_body_no_param["extensions"][EXT_NAME][
            "updates"
        ][room_id][thread_root_id]
        self.assertNotIn("thread_root", thread_update_no_param)

    def test_per_thread_prev_batch_single_update(self) -> None:
        """
        Test that threads with only a single update do NOT get a prev_batch token.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_id = self._send_thread_root(room_id, user1_tok)

        # Initial sync to establish baseline
        sync_body = self._threads_sync_body()
        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)

        # Add ONE reply to thread
        self._send_thread_reply(room_id, thread_root_id, "Single reply", user1_tok)

        # Incremental sync
        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)

        # Assert: Thread update should NOT have prev_batch (only 1 update)
        thread_update = response_body["extensions"][EXT_NAME]["updates"][room_id][
            thread_root_id
        ]
        self.assertNotIn(
            "prev_batch",
            thread_update,
            "Threads with single update should not have prev_batch",
        )

    def test_per_thread_prev_batch_multiple_updates(self) -> None:
        """
        Test that threads with multiple updates get a prev_batch token that can be
        used with /relations endpoint to paginate backwards.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_id = self._send_thread_root(room_id, user1_tok)

        # Initial sync to establish baseline
        sync_body = self._threads_sync_body()
        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)

        # Add MULTIPLE replies to thread
        reply1_id = self._send_thread_reply(
            room_id, thread_root_id, "First reply", user1_tok
        )
        reply2_id = self._send_thread_reply(
            room_id, thread_root_id, "Second reply", user1_tok
        )
        reply3_id = self._send_thread_reply(
            room_id, thread_root_id, "Third reply", user1_tok
        )

        # Incremental sync
        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)

        # Assert: Thread update SHOULD have prev_batch (3 updates)
        prev_batch = response_body["extensions"][EXT_NAME]["updates"][room_id][
            thread_root_id
        ]["prev_batch"]
        self.assertIsNotNone(prev_batch, "prev_batch should not be None")

        # Now use the prev_batch token with /relations endpoint to paginate backwards
        channel = self.make_request(
            "GET",
            f"/_matrix/client/v1/rooms/{room_id}/relations/{thread_root_id}?from={prev_batch}&to={sync_pos}&dir=b",
            access_token=user1_tok,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        returned_event_ids = [
            event["event_id"] for event in channel.json_body["chunk"]
        ]

        # Assert: Only the older replies should be returned (not the latest one we already saw)
        # The prev_batch token should be exclusive, pointing just before the latest event
        self.assertIn(
            reply1_id,
            returned_event_ids,
            "First reply should be in relations response",
        )
        self.assertIn(
            reply2_id,
            returned_event_ids,
            "Second reply should be in relations response",
        )
        self.assertNotIn(
            reply3_id,
            returned_event_ids,
            "Third reply (latest) should NOT be in relations response - already returned in sliding sync",
        )

    def test_per_thread_prev_batch_on_initial_sync(self) -> None:
        """
        Test that threads with multiple updates get prev_batch tokens on initial sync
        so clients can paginate through the full thread history.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread with multiple replies BEFORE any sync
        thread_root_id = self._send_thread_root(room_id, user1_tok)
        reply1_id = self._send_thread_reply(
            room_id, thread_root_id, "Reply 1", user1_tok
        )
        reply2_id = self._send_thread_reply(
            room_id, thread_root_id, "Reply 2", user1_tok
        )

        # Initial sync (no from_token)
        response_body, _ = self.do_sync(self._threads_sync_body(), tok=user1_tok)

        # Assert: Thread update SHOULD have prev_batch on initial sync (2+ updates exist)
        prev_batch = response_body["extensions"][EXT_NAME]["updates"][room_id][
            thread_root_id
        ]["prev_batch"]
        self.assertIsNotNone(prev_batch)

        # Use prev_batch with /relations to fetch the thread history
        channel = self.make_request(
            "GET",
            f"/_matrix/client/v1/rooms/{room_id}/relations/{thread_root_id}?from={prev_batch}&dir=b",
            access_token=user1_tok,
        )
        self.assertEqual(channel.code, 200, channel.json_body)

        returned_event_ids = [
            event["event_id"] for event in channel.json_body["chunk"]
        ]

        # Assert: Only the older reply should be returned (not the latest one we already saw)
        # The prev_batch token should be exclusive, pointing just before the latest event
        self.assertIn(
            reply1_id,
            returned_event_ids,
            "First reply should be in relations response",
        )
        self.assertNotIn(
            reply2_id,
            returned_event_ids,
            "Second reply (latest) should NOT be in relations response - already returned in sliding sync",
        )

    def test_thread_in_timeline_omitted_without_include_roots(self) -> None:
        """
        Test that threads with events in the room timeline are omitted from the
        extension response when include_roots=False. When all threads are filtered out,
        the entire extension should be omitted from the response.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_id = self._send_thread_root(room_id, user1_tok)

        # Initial sync to establish baseline
        sync_body = self._threads_sync_body(with_room_list=True, include_roots=False)
        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)

        # Send a reply to the thread
        reply_id = self._send_thread_reply(
            room_id, thread_root_id, "Reply 1", user1_tok
        )

        # Incremental sync - the reply should be in the timeline
        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)

        # Assert: The thread reply really is in the room timeline. Without this
        # check the "omitted" assertion below could pass for the wrong reason.
        room_response = response_body["rooms"][room_id]
        timeline_event_ids = [event["event_id"] for event in room_response["timeline"]]
        self.assertIn(
            reply_id,
            timeline_event_ids,
            "Thread reply should be in the room timeline",
        )

        # Assert: Extension should be omitted entirely since the only thread with updates
        # is already visible in the timeline (include_roots=False)
        self.assertNotIn(
            EXT_NAME,
            response_body.get("extensions", {}),
            "Extension should be omitted when all threads are filtered out (in timeline with include_roots=False)",
        )

    def test_thread_in_timeline_included_with_include_roots(self) -> None:
        """
        Test that threads with events in the room timeline are still included in the
        extension response when include_roots=True, because the client wants the root event.
        """
        user1_id = self.register_user("user1", "pass")
        user1_tok = self.login(user1_id, "pass")
        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)

        # Create thread root
        thread_root_id = self._send_thread_root(room_id, user1_tok)

        # Initial sync to establish baseline
        sync_body = self._threads_sync_body(with_room_list=True, include_roots=True)
        _, sync_pos = self.do_sync(sync_body, tok=user1_tok)

        # Send a reply to the thread
        reply_id = self._send_thread_reply(
            room_id, thread_root_id, "Reply 1", user1_tok
        )

        # Incremental sync - the reply should be in the timeline
        response_body, _ = self.do_sync(sync_body, tok=user1_tok, since=sync_pos)

        # Assert: The thread reply should be in the room timeline
        room_response = response_body["rooms"][room_id]
        timeline_event_ids = [event["event_id"] for event in room_response["timeline"]]
        self.assertIn(
            reply_id,
            timeline_event_ids,
            "Thread reply should be in the room timeline",
        )

        # Assert: Thread SHOULD be in extension (include_roots=True)
        thread_updates = response_body["extensions"][EXT_NAME]["updates"][room_id]
        self.assertIn(
            thread_root_id,
            thread_updates,
            "Thread should be included in extension when include_roots=True, even if in timeline",
        )
        # Verify the thread root event is present
        self.assertIn("thread_root", thread_updates[thread_root_id])