Skip to content
1 change: 1 addition & 0 deletions changelog.d/19005.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add experimental support for MSC4360: Sliding Sync Threads Extension.
3 changes: 3 additions & 0 deletions synapse/config/experimental.py
Original file line number Diff line number Diff line change
Expand Up @@ -595,3 +595,6 @@ def read_config(
# MSC4306: Thread Subscriptions
# (and MSC4308: Thread Subscriptions extension to Sliding Sync)
self.msc4306_enabled: bool = experimental.get("msc4306_enabled", False)

# MSC4360: Threads Extension to Sliding Sync
self.msc4360_enabled: bool = experimental.get("msc4360_enabled", False)
2 changes: 0 additions & 2 deletions synapse/handlers/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,8 +109,6 @@ async def get_relations(
) -> JsonDict:
"""Get related events of a event, ordered by topological ordering.

TODO Accept a PaginationConfig instead of individual pagination parameters.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This has been done alread, the comment just wasn't updated.


Args:
requester: The user requesting the relations.
event_id: Fetch events that relate to this event ID.
Expand Down
129 changes: 128 additions & 1 deletion synapse/handlers/sliding_sync/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
Optional,
Sequence,
Set,
Tuple,
cast,
)

from typing_extensions import TypeAlias, assert_never

from synapse.api.constants import AccountDataTypes, EduTypes
from synapse.api.constants import AccountDataTypes, EduTypes, RelationTypes
from synapse.handlers.receipts import ReceiptEventSource
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
Expand Down Expand Up @@ -61,6 +62,7 @@
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand All @@ -76,7 +78,9 @@ def __init__(self, hs: "HomeServer"):
self.event_sources = hs.get_event_sources()
self.device_handler = hs.get_device_handler()
self.push_rules_handler = hs.get_push_rules_handler()
self.relations_handler = hs.get_relations_handler()
self._enable_thread_subscriptions = hs.config.experimental.msc4306_enabled
self._enable_threads_ext = hs.config.experimental.msc4360_enabled

@trace
async def get_extensions_response(
Expand Down Expand Up @@ -177,20 +181,32 @@ async def get_extensions_response(
from_token=from_token,
)

threads_coro = None
if sync_config.extensions.threads is not None and self._enable_threads_ext:
threads_coro = self.get_threads_extension_response(
sync_config=sync_config,
threads_request=sync_config.extensions.threads,
actual_room_response_map=actual_room_response_map,
to_token=to_token,
from_token=from_token,
)

(
to_device_response,
e2ee_response,
account_data_response,
receipts_response,
typing_response,
thread_subs_response,
threads_response,
) = await gather_optional_coroutines(
to_device_coro,
e2ee_coro,
account_data_coro,
receipts_coro,
typing_coro,
thread_subs_coro,
threads_coro,
)

return SlidingSyncResult.Extensions(
Expand All @@ -200,6 +216,7 @@ async def get_extensions_response(
receipts=receipts_response,
typing=typing_response,
thread_subscriptions=thread_subs_response,
threads=threads_response,
)

def find_relevant_room_ids_for_extension(
Expand Down Expand Up @@ -970,3 +987,113 @@ async def get_thread_subscriptions_extension_response(
unsubscribed=unsubscribed_threads,
prev_batch=prev_batch,
)

async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
threads_request: SlidingSyncConfig.Extensions.ThreadsExtension,
actual_room_response_map: Mapping[str, SlidingSyncResult.RoomResult],
to_token: StreamToken,
from_token: Optional[SlidingSyncStreamToken],
) -> Optional[SlidingSyncResult.Extensions.ThreadsExtension]:
"""Handle Threads extension (MSC4360)

Args:
sync_config: Sync configuration.
threads_request: The threads extension from the request.
actual_room_response_map: A map of room ID to room results in the
sliding sync response. Used to determine which threads already have
events in the room timeline.
to_token: The point in the stream to sync up to.
from_token: The point in the stream to sync from.

Returns:
the response (None if empty or threads extension is disabled)
"""
if not threads_request.enabled:
return None

# Fetch thread updates globally across all joined rooms.
# The database layer returns a StreamToken (exclusive) for prev_batch if there
# are more results.
(
all_thread_updates,
prev_batch_token,
) = await self.store.get_thread_updates_for_user(
user_id=sync_config.user.to_string(),
from_token=from_token.stream_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=threads_request.limit,
include_thread_roots=threads_request.include_roots,
)

if len(all_thread_updates) == 0:
return None

# Identify which threads already have events in the room timelines.
# If include_roots=False, we'll omit these threads from the extension response
# since the client already sees the thread activity in the timeline.
# If include_roots=True, we include all threads regardless, because the client
# wants the thread root events.
threads_in_timeline: Set[Tuple[str, str]] = set() # (room_id, thread_id)
if not threads_request.include_roots:
for room_id, room_result in actual_room_response_map.items():
if room_result.timeline_events:
for event in room_result.timeline_events:
# Check if this event is part of a thread
relates_to = event.content.get("m.relates_to")
if not isinstance(relates_to, dict):
continue

rel_type = relates_to.get("rel_type")

# If this is a thread reply, track the thread
if rel_type == RelationTypes.THREAD:
thread_id = relates_to.get("event_id")
if thread_id:
threads_in_timeline.add((room_id, thread_id))

# Collect thread root events and get bundled aggregations.
# Only fetch bundled aggregations if we have thread root events to attach them to.
thread_root_events = [
update.thread_root_event
for update in all_thread_updates
if update.thread_root_event
]
aggregations_map = {}
if thread_root_events:
aggregations_map = await self.relations_handler.get_bundled_aggregations(
thread_root_events,
sync_config.user.to_string(),
)

thread_updates: Dict[str, Dict[str, _ThreadUpdate]] = {}
for update in all_thread_updates:
# Skip this thread if it already has events in the room timeline
# (unless include_roots=True, in which case we always include it)
if (update.room_id, update.thread_id) in threads_in_timeline:
continue

# Only look up bundled aggregations if we have a thread root event
bundled_aggs = (
aggregations_map.get(update.thread_id)
if update.thread_root_event
else None
)

thread_updates.setdefault(update.room_id, {})[update.thread_id] = (
_ThreadUpdate(
thread_root=update.thread_root_event,
prev_batch=update.prev_batch,
bundled_aggregations=bundled_aggs,
)
)

# If after filtering we have no thread updates, return None to omit the extension
if not thread_updates:
return None

return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates,
prev_batch=prev_batch_token,
)
98 changes: 92 additions & 6 deletions synapse/rest/client/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from synapse.api.presence import UserPresenceState
from synapse.api.ratelimiting import Ratelimiter
from synapse.events.utils import (
EventClientSerializer,
SerializeEventConfig,
format_event_for_client_v2_without_room_id,
format_event_raw,
Expand All @@ -56,6 +57,7 @@
from synapse.http.site import SynapseRequest
from synapse.logging.opentracing import log_kv, set_tag, trace_with_opname
from synapse.rest.admin.experimental_features import ExperimentalFeature
from synapse.storage.databases.main import DataStore
from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken
from synapse.types.rest.client import SlidingSyncBody
from synapse.util.caches.lrucache import LruCache
Expand Down Expand Up @@ -648,6 +650,7 @@ class SlidingSyncRestServlet(RestServlet):
- receipts (MSC3960)
- account data (MSC3959)
- thread subscriptions (MSC4308)
- threads (MSC4360)

Request query parameters:
timeout: How long to wait for new events in milliseconds.
Expand Down Expand Up @@ -851,7 +854,10 @@ async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
logger.info("Client has disconnected; not serializing response.")
return 200, {}

response_content = await self.encode_response(requester, sliding_sync_results)
time_now = self.clock.time_msec()
response_content = await self.encode_response(
requester, sliding_sync_results, time_now
)

return 200, response_content

Expand All @@ -860,6 +866,7 @@ async def encode_response(
self,
requester: Requester,
sliding_sync_result: SlidingSyncResult,
time_now: int,
) -> JsonDict:
response: JsonDict = defaultdict(dict)

Expand All @@ -868,10 +875,10 @@ async def encode_response(
if serialized_lists:
response["lists"] = serialized_lists
response["rooms"] = await self.encode_rooms(
requester, sliding_sync_result.rooms
requester, sliding_sync_result.rooms, time_now
)
response["extensions"] = await self.encode_extensions(
requester, sliding_sync_result.extensions
requester, sliding_sync_result.extensions, time_now
)

return response
Expand Down Expand Up @@ -903,9 +910,8 @@ async def encode_rooms(
self,
requester: Requester,
rooms: Dict[str, SlidingSyncResult.RoomResult],
time_now: int,
) -> JsonDict:
time_now = self.clock.time_msec()

serialize_options = SerializeEventConfig(
event_format=format_event_for_client_v2_without_room_id,
requester=requester,
Expand Down Expand Up @@ -1021,7 +1027,10 @@ async def encode_rooms(

@trace_with_opname("sliding_sync.encode_extensions")
async def encode_extensions(
self, requester: Requester, extensions: SlidingSyncResult.Extensions
self,
requester: Requester,
extensions: SlidingSyncResult.Extensions,
time_now: int,
) -> JsonDict:
serialized_extensions: JsonDict = {}

Expand Down Expand Up @@ -1091,6 +1100,17 @@ async def encode_extensions(
_serialise_thread_subscriptions(extensions.thread_subscriptions)
)

# excludes both None and falsy `threads`
if extensions.threads:
serialized_extensions[
"io.element.msc4360.threads"
] = await _serialise_threads(
self.event_serializer,
time_now,
extensions.threads,
self.store,
)

return serialized_extensions


Expand Down Expand Up @@ -1127,6 +1147,72 @@ def _serialise_thread_subscriptions(
return out


async def _serialise_threads(
event_serializer: EventClientSerializer,
time_now: int,
threads: SlidingSyncResult.Extensions.ThreadsExtension,
store: "DataStore",
) -> JsonDict:
"""
Serialize the threads extension response for sliding sync.

Args:
event_serializer: The event serializer to use for serializing thread root events.
time_now: The current time in milliseconds, used for event serialization.
threads: The threads extension data containing thread updates and pagination tokens.
store: The datastore, needed for serializing stream tokens.

Returns:
A JSON-serializable dict containing:
- "updates": A nested dict mapping room_id -> thread_root_id -> thread update.
Each thread update may contain:
- "thread_root": The serialized thread root event (if include_roots was True),
with bundled aggregations including the latest_event in unsigned.m.relations.m.thread.
- "prev_batch": A pagination token for fetching older events in the thread.
- "prev_batch": A pagination token for fetching older thread updates (if available).
"""
out: JsonDict = {}

if threads.updates:
updates_dict: JsonDict = {}
for room_id, thread_updates in threads.updates.items():
room_updates: JsonDict = {}
for thread_root_id, update in thread_updates.items():
# Serialize the update
update_dict: JsonDict = {}

# Serialize the thread_root event if present
if update.thread_root is not None:
# Create a mapping of event_id to bundled_aggregations
bundle_aggs_map = (
{thread_root_id: update.bundled_aggregations}
if update.bundled_aggregations
else None
)
serialized_events = await event_serializer.serialize_events(
[update.thread_root],
time_now,
bundle_aggregations=bundle_aggs_map,
)
if serialized_events:
update_dict["thread_root"] = serialized_events[0]

# Add prev_batch if present
if update.prev_batch is not None:
update_dict["prev_batch"] = await update.prev_batch.to_string(store)

room_updates[thread_root_id] = update_dict

updates_dict[room_id] = room_updates

out["updates"] = updates_dict

if threads.prev_batch:
out["prev_batch"] = await threads.prev_batch.to_string(store)

return out


def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
SyncRestServlet(hs).register(http_server)

Expand Down
Loading
Loading