Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog.d/19041.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add companion endpoint for MSC4360: Sliding Sync Threads Extension.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

403 changes: 400 additions & 3 deletions synapse/handlers/relations.py

Large diffs are not rendered by default.

208 changes: 16 additions & 192 deletions synapse/handlers/sliding_sync/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

import itertools
import logging
from collections import defaultdict
from typing import (
TYPE_CHECKING,
AbstractSet,
Expand All @@ -31,7 +30,6 @@
AccountDataTypes,
EduTypes,
EventContentFields,
Membership,
MRelatesToFields,
RelationTypes,
)
Expand All @@ -40,15 +38,12 @@
from synapse.handlers.sliding_sync.room_lists import RoomsForUserType
from synapse.logging.opentracing import trace
from synapse.storage.databases.main.receipts import ReceiptInRoom
from synapse.storage.databases.main.relations import ThreadUpdateInfo
from synapse.types import (
DeviceListUpdates,
JsonMapping,
MultiWriterStreamToken,
RoomStreamToken,
SlidingSyncStreamToken,
StrCollection,
StreamKeyType,
StreamToken,
ThreadSubscriptionsToken,
)
Expand All @@ -64,15 +59,13 @@
concurrently_execute,
gather_optional_coroutines,
)
from synapse.visibility import filter_events_for_client

_ThreadSubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadSubscription
)
_ThreadUnsubscription: TypeAlias = (
SlidingSyncResult.Extensions.ThreadSubscriptionsExtension.ThreadUnsubscription
)
_ThreadUpdate: TypeAlias = SlidingSyncResult.Extensions.ThreadsExtension.ThreadUpdate

if TYPE_CHECKING:
from synapse.server import HomeServer
Expand Down Expand Up @@ -1040,42 +1033,6 @@ def _find_threads_in_timeline(
threads_in_timeline.add(thread_id)
return threads_in_timeline

def _merge_prev_batch_token(
    self,
    current_token: StreamToken | None,
    new_token: StreamToken | None,
) -> StreamToken | None:
    """Combine two prev_batch tokens, keeping the one furthest along the room stream.

    Only the `room_key.stream` position of each token is compared.

    Args:
        current_token: The prev_batch token accumulated so far, if any.
        new_token: A freshly obtained prev_batch token to fold in, if any.

    Returns:
        `new_token` when its room stream position is strictly greater than
        that of `current_token`; otherwise `current_token`. If one side is
        None the other side is returned unchanged (so None when both are None).
    """
    # A missing side contributes nothing: hand back the other one as-is.
    if new_token is None or current_token is None:
        return current_token if new_token is None else new_token

    candidate_pos = new_token.room_key.stream
    existing_pos = current_token.room_key.stream

    # Ties keep the token we already had.
    return new_token if candidate_pos > existing_pos else current_token

def _merge_thread_updates(
    self,
    target: dict[str, list[ThreadUpdateInfo]],
    source: dict[str, list[ThreadUpdateInfo]],
) -> None:
    """Fold the per-thread update lists of `source` into `target`.

    For every thread ID in `source`, its updates are appended to the list
    already held by `target` for that thread; a fresh list is created in
    `target` when the thread ID was not present yet.

    Args:
        target: Mapping of thread ID to updates; mutated in place.
        source: Mapping of thread ID to updates to append; its lists are
            read but not stored in `target`.
    """
    for root_id, incoming in source.items():
        # Start a new bucket for threads we have not seen before, so that
        # `target` never aliases a list owned by `source`.
        if root_id not in target:
            target[root_id] = []
        target[root_id].extend(incoming)

async def get_threads_extension_response(
self,
sync_config: SlidingSyncConfig,
Expand Down Expand Up @@ -1117,159 +1074,26 @@ async def get_threads_extension_response(
actual_room_response_map
)

# Separate rooms into groups based on membership status.
# For LEAVE/BAN rooms, we need to bound the to_token to prevent leaking events
# that occurred after the user left/was banned.
leave_ban_rooms: set[str] = set()
other_rooms: set[str] = set()

for room_id in actual_room_ids:
membership_info = room_membership_for_user_at_to_token_map.get(room_id)
if membership_info and membership_info.membership in (
Membership.LEAVE,
Membership.BAN,
):
leave_ban_rooms.add(room_id)
else:
other_rooms.add(room_id)

# Fetch thread updates, handling LEAVE/BAN rooms separately to avoid data leaks.
all_thread_updates: dict[str, list[ThreadUpdateInfo]] = {}
prev_batch_token: StreamToken | None = None
remaining_limit = threads_request.limit

# Query for rooms where the user has left or been banned, using their leave/ban
# event position as the upper bound to prevent seeing events after they left.
if leave_ban_rooms:
for room_id in leave_ban_rooms:
if remaining_limit <= 0:
# We've already fetched enough updates, but we still need to set
# prev_batch to indicate there are more results.
prev_batch_token = to_token
break

membership_info = room_membership_for_user_at_to_token_map[room_id]
bounded_to_token = membership_info.event_pos.to_room_stream_token()

(
room_thread_updates,
room_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids={room_id},
from_token=from_token.stream_token.room_key if from_token else None,
to_token=bounded_to_token,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)

# Count how many updates we fetched and reduce the remaining limit
num_updates = sum(
len(updates) for updates in room_thread_updates.values()
)
remaining_limit -= num_updates

self._merge_thread_updates(all_thread_updates, room_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, room_prev_batch
)

# Query for rooms where the user is joined, invited, or knocking, using the
# normal to_token as the upper bound.
if other_rooms and remaining_limit > 0:
(
other_thread_updates,
other_prev_batch,
) = await self.store.get_thread_updates_for_rooms(
room_ids=other_rooms,
from_token=from_token.stream_token.room_key if from_token else None,
to_token=to_token.room_key,
limit=remaining_limit,
exclude_thread_ids=threads_to_exclude,
)

self._merge_thread_updates(all_thread_updates, other_thread_updates)
prev_batch_token = self._merge_prev_batch_token(
prev_batch_token, other_prev_batch
)

if len(all_thread_updates) == 0:
return None

# Build a mapping of event_id -> (thread_id, update) for efficient lookup
# during visibility filtering.
event_to_thread_map: dict[str, tuple[str, ThreadUpdateInfo]] = {}
for thread_id, updates in all_thread_updates.items():
for update in updates:
event_to_thread_map[update.event_id] = (thread_id, update)

# Fetch and filter events for visibility
all_events = await self.store.get_events_as_list(event_to_thread_map.keys())
filtered_events = await filter_events_for_client(
self._storage_controllers, sync_config.user.to_string(), all_events
# Get thread updates using unified helper
user_id = sync_config.user.to_string()
(
thread_updates_response,
prev_batch_token,
) = await self.relations_handler.get_thread_updates_for_rooms(
room_ids=frozenset(actual_room_ids),
room_membership_map=room_membership_for_user_at_to_token_map,
user_id=user_id,
from_token=from_token.stream_token if from_token else None,
to_token=to_token,
limit=threads_request.limit,
include_roots=threads_request.include_roots,
exclude_thread_ids=threads_to_exclude,
)

# Rebuild thread updates from filtered events
filtered_updates: dict[str, list[ThreadUpdateInfo]] = defaultdict(list)
for event in filtered_events:
if event.event_id in event_to_thread_map:
thread_id, update = event_to_thread_map[event.event_id]
filtered_updates[thread_id].append(update)

if not filtered_updates:
if not thread_updates_response:
return None

# Note: Updates are already sorted by stream_ordering DESC from the database query,
# and filter_events_for_client preserves order, so updates[0] is guaranteed to be
# the latest event for each thread.

# Optionally fetch thread root events and their bundled aggregations
thread_root_event_map = {}
aggregations_map = {}
if threads_request.include_roots:
thread_root_events = await self.store.get_events_as_list(
filtered_updates.keys()
)
thread_root_event_map = {e.event_id: e for e in thread_root_events}

if thread_root_event_map:
aggregations_map = (
await self.relations_handler.get_bundled_aggregations(
thread_root_event_map.values(),
sync_config.user.to_string(),
)
)

thread_updates: dict[str, dict[str, _ThreadUpdate]] = {}
for thread_root, updates in filtered_updates.items():
# We only care about the latest update for the thread.
# Updates arrive already sorted by stream_ordering DESC (see the note above), so updates[0] is guaranteed to be the latest (highest stream_ordering).
latest_update = updates[0]

# Generate per-thread prev_batch token if this thread has multiple visible updates.
# When we hit the global limit, we generate prev_batch tokens for all threads, even if
# we only saw 1 update for them. This is to cover the case where we only saw
# a single update for a given thread, but the global limit prevents us from
# obtaining other updates which would have otherwise been included in the
# range.
per_thread_prev_batch = None
if len(updates) > 1 or prev_batch_token is not None:
# Create a token pointing to one position before the latest event's stream position.
# This makes it exclusive - /relations with dir=b won't return the latest event again.
# Use StreamToken.START as base (all other streams at 0) since only room position matters.
per_thread_prev_batch = StreamToken.START.copy_and_replace(
StreamKeyType.ROOM,
RoomStreamToken(stream=latest_update.stream_ordering - 1),
)

thread_updates.setdefault(latest_update.room_id, {})[thread_root] = (
_ThreadUpdate(
thread_root=thread_root_event_map.get(thread_root),
prev_batch=per_thread_prev_batch,
bundled_aggregations=aggregations_map.get(thread_root),
)
)

return SlidingSyncResult.Extensions.ThreadsExtension(
updates=thread_updates,
updates=thread_updates_response,
prev_batch=prev_batch_token,
)
Loading
Loading