diff --git a/changelog.d/18828.feature b/changelog.d/18828.feature
new file mode 100644
index 00000000000..e7f3541de43
--- /dev/null
+++ b/changelog.d/18828.feature
@@ -0,0 +1 @@
+Cleanly shutdown `SynapseHomeServer` object.
diff --git a/scripts-dev/mypy_synapse_plugin.py b/scripts-dev/mypy_synapse_plugin.py
index e170aabdae3..0b854cdba52 100644
--- a/scripts-dev/mypy_synapse_plugin.py
+++ b/scripts-dev/mypy_synapse_plugin.py
@@ -68,18 +68,42 @@
     category="per-homeserver-tenant-metrics",
 )
 
+PREFER_SYNAPSE_CLOCK_CALL_LATER = ErrorCode(
+    "call-later-not-tracked",
+    "Prefer using `synapse.util.Clock.call_later` instead of `reactor.callLater`",
+    category="synapse-reactor-clock",
+)
+
+PREFER_SYNAPSE_CLOCK_LOOPING_CALL = ErrorCode(
+    "prefer-synapse-clock-looping-call",
+    "Prefer using `synapse.util.Clock.looping_call` instead of `task.LoopingCall`",
+    category="synapse-reactor-clock",
+)
+
 PREFER_SYNAPSE_CLOCK_CALL_WHEN_RUNNING = ErrorCode(
     "prefer-synapse-clock-call-when-running",
-    "`synapse.util.Clock.call_when_running` should be used instead of `reactor.callWhenRunning`",
+    "Prefer using `synapse.util.Clock.call_when_running` instead of `reactor.callWhenRunning`",
     category="synapse-reactor-clock",
 )
 
 PREFER_SYNAPSE_CLOCK_ADD_SYSTEM_EVENT_TRIGGER = ErrorCode(
     "prefer-synapse-clock-add-system-event-trigger",
-    "`synapse.util.Clock.add_system_event_trigger` should be used instead of `reactor.addSystemEventTrigger`",
+    "Prefer using `synapse.util.Clock.add_system_event_trigger` instead of `reactor.addSystemEventTrigger`",
     category="synapse-reactor-clock",
 )
 
+MULTIPLE_INTERNAL_CLOCKS_CREATED = ErrorCode(
+    "multiple-internal-clocks",
+    "Only one instance of `clock.Clock` should be created",
+    category="synapse-reactor-clock",
+)
+
+UNTRACKED_BACKGROUND_PROCESS = ErrorCode(
+    "untracked-background-process",
+    "Prefer using `HomeServer.run_as_background_process` method over the bare `run_as_background_process`",
+    category="synapse-tracked-calls",
+)
+
 
 class Sentinel(enum.Enum):
     # defining a sentinel in this way allows mypy to correctly handle the
@@ -222,6 +246,18 @@ def get_function_signature_hook(
         # callback, let's just pass it in while we have it.
         return lambda ctx: check_prometheus_metric_instantiation(ctx, fullname)
 
+        if fullname == "twisted.internet.task.LoopingCall":
+            return check_looping_call
+
+        if fullname == "synapse.util.clock.Clock":
+            return check_clock_creation
+
+        if (
+            fullname
+            == "synapse.metrics.background_process_metrics.run_as_background_process"
+        ):
+            return check_background_process
+
         return None
 
     def get_method_signature_hook(
@@ -241,6 +277,13 @@
         ):
             return check_is_cacheable_wrapper
 
+        if fullname in (
+            "twisted.internet.interfaces.IReactorTime.callLater",
+            "synapse.types.ISynapseThreadlessReactor.callLater",
+            "synapse.types.ISynapseReactor.callLater",
+        ):
+            return check_call_later
+
         if fullname in (
             "twisted.internet.interfaces.IReactorCore.callWhenRunning",
             "synapse.types.ISynapseThreadlessReactor.callWhenRunning",
@@ -258,6 +301,78 @@
         return None
 
 
+def check_clock_creation(ctx: FunctionSigContext) -> CallableType:
+    """
+    Ensure that the only `clock.Clock` instance is the one used by the `HomeServer`.
+    This is so that the `HomeServer` can cancel any tracked delayed or looping calls
+    during server shutdown.
+
+    Args:
+        ctx: The `FunctionSigContext` from mypy.
+ """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Expected the only `clock.Clock` instance to be the one used by the `HomeServer`. " + "This is so that the `HomeServer` can cancel any tracked delayed or looping calls " + "during server shutdown", + ctx.context, + code=MULTIPLE_INTERNAL_CLOCKS_CREATED, + ) + + return signature + + +def check_call_later(ctx: MethodSigContext) -> CallableType: + """ + Ensure that the `reactor.callLater` callsites aren't used. + + `synapse.util.Clock.call_later` should always be used instead of `reactor.callLater`. + This is because the `synapse.util.Clock` tracks delayed calls in order to cancel any + outstanding calls during server shutdown. Delayed calls which are either short lived + (<~60s) or frequently called and can be tracked via other means could be candidates for + using `synapse.util.Clock.call_later` with `call_later_cancel_on_shutdown` set to + `False`. There shouldn't be a need to use `reactor.callLater` outside of tests or the + `Clock` class itself. If a need arises, you can use a type ignore comment to disable the + check, e.g. `# type: ignore[call-later-not-tracked]`. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Expected all `reactor.callLater` calls to use `synapse.util.Clock.call_later` " + "instead. This is so that long lived calls can be tracked for cancellation during " + "server shutdown", + ctx.context, + code=PREFER_SYNAPSE_CLOCK_CALL_LATER, + ) + + return signature + + +def check_looping_call(ctx: FunctionSigContext) -> CallableType: + """ + Ensure that the `task.LoopingCall` callsites aren't used. + + `synapse.util.Clock.looping_call` should always be used instead of `task.LoopingCall`. + `synapse.util.Clock` tracks looping calls in order to cancel any outstanding calls + during server shutdown. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Expected all `task.LoopingCall` instances to use `synapse.util.Clock.looping_call` " + "instead. This is so that long lived calls can be tracked for cancellation during " + "server shutdown", + ctx.context, + code=PREFER_SYNAPSE_CLOCK_LOOPING_CALL, + ) + + return signature + + def check_call_when_running(ctx: MethodSigContext) -> CallableType: """ Ensure that the `reactor.callWhenRunning` callsites aren't used. @@ -312,6 +427,27 @@ def check_add_system_event_trigger(ctx: MethodSigContext) -> CallableType: return signature +def check_background_process(ctx: FunctionSigContext) -> CallableType: + """ + Ensure that calls to `run_as_background_process` use the `HomeServer` method. + This is so that the `HomeServer` can cancel any running background processes during + server shutdown. + + Args: + ctx: The `FunctionSigContext` from mypy. + """ + signature: CallableType = ctx.default_signature + ctx.api.fail( + "Prefer using `HomeServer.run_as_background_process` method over the bare " + "`run_as_background_process`. 
This is so that the `HomeServer` can cancel " + "any background processes during server shutdown", + ctx.context, + code=UNTRACKED_BACKGROUND_PROCESS, + ) + + return signature + + def analyze_prometheus_metric_classes(ctx: ClassDefContext) -> None: """ Cross-check the list of Prometheus metric classes against the diff --git a/synapse/_scripts/generate_workers_map.py b/synapse/_scripts/generate_workers_map.py index 8878e364e2e..f66c01040cc 100755 --- a/synapse/_scripts/generate_workers_map.py +++ b/synapse/_scripts/generate_workers_map.py @@ -157,7 +157,12 @@ def get_registered_paths_for_default( # TODO We only do this to avoid an error, but don't need the database etc hs.setup() registered_paths = get_registered_paths_for_hs(hs) - hs.cleanup() + # NOTE: a more robust implementation would properly shutdown/cleanup each server + # to avoid resource buildup. + # However, the call to `shutdown` is `async` so it would require additional complexity here. + # We are intentionally skipping this cleanup because this is a short-lived, one-off + # utility script where the simpler approach is sufficient and we shouldn't run into + # any resource buildup issues. return registered_paths diff --git a/synapse/_scripts/update_synapse_database.py b/synapse/_scripts/update_synapse_database.py index caaecda1617..ad02f0ed887 100644 --- a/synapse/_scripts/update_synapse_database.py +++ b/synapse/_scripts/update_synapse_database.py @@ -28,7 +28,6 @@ from twisted.internet import defer, reactor as reactor_ from synapse.config.homeserver import HomeServerConfig -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.server import HomeServer from synapse.storage import DataStore from synapse.types import ISynapseReactor @@ -53,7 +52,6 @@ def __init__(self, config: HomeServerConfig): def run_background_updates(hs: HomeServer) -> None: - server_name = hs.hostname main = hs.get_datastores().main state = hs.get_datastores().state @@ -67,9 +65,8 @@ async def run_background_updates() -> None: def run() -> None: # Apply all background updates on the database. defer.ensureDeferred( - run_as_background_process( + hs.run_as_background_process( "background_updates", - server_name, run_background_updates, ) ) diff --git a/synapse/app/_base.py b/synapse/app/_base.py index 56387248960..655f684ecf0 100644 --- a/synapse/app/_base.py +++ b/synapse/app/_base.py @@ -28,6 +28,7 @@ import traceback import warnings from textwrap import indent +from threading import Thread from typing import ( TYPE_CHECKING, Any, @@ -40,6 +41,7 @@ Tuple, cast, ) +from wsgiref.simple_server import WSGIServer from cryptography.utils import CryptographyDeprecationWarning from typing_extensions import ParamSpec @@ -97,22 +99,47 @@ logger = logging.getLogger(__name__) -# list of tuples of function, args list, kwargs dict -_sighup_callbacks: List[ - Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]] -] = [] +_instance_id_to_sighup_callbacks_map: Dict[ + str, List[Tuple[Callable[..., None], Tuple[object, ...], Dict[str, object]]] +] = {} +""" +Map from homeserver instance_id to a list of callbacks. + +We use `instance_id` instead of `server_name` because it's possible to have multiple +workers running in the same process with the same `server_name`. 
+""" P = ParamSpec("P") -def register_sighup(func: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: +def register_sighup( + homeserver_instance_id: str, + func: Callable[P, None], + *args: P.args, + **kwargs: P.kwargs, +) -> None: """ Register a function to be called when a SIGHUP occurs. Args: + homeserver_instance_id: The unique ID for this Synapse process instance + (`hs.get_instance_id()`) that this hook is associated with. func: Function to be called when sent a SIGHUP signal. *args, **kwargs: args and kwargs to be passed to the target function. """ - _sighup_callbacks.append((func, args, kwargs)) + + _instance_id_to_sighup_callbacks_map.setdefault(homeserver_instance_id, []).append( + (func, args, kwargs) + ) + + +def unregister_sighups(instance_id: str) -> None: + """ + Unregister all sighup functions associated with this Synapse instance. + + Args: + instance_id: Unique ID for this Synapse process instance. + """ + _instance_id_to_sighup_callbacks_map.pop(instance_id, []) def start_worker_reactor( @@ -281,7 +308,9 @@ async def wrapper() -> None: clock.call_when_running(lambda: defer.ensureDeferred(wrapper())) -def listen_metrics(bind_addresses: StrCollection, port: int) -> None: +def listen_metrics( + bind_addresses: StrCollection, port: int +) -> List[Tuple[WSGIServer, Thread]]: """ Start Prometheus metrics server. @@ -294,14 +323,22 @@ def listen_metrics(bind_addresses: StrCollection, port: int) -> None: bytecode at a time), this still works because the metrics thread can preempt the Twisted reactor thread between bytecode boundaries and the metrics thread gets scheduled with roughly equal priority to the Twisted reactor thread. + + Returns: + List of WSGIServer with the thread they are running on. """ from prometheus_client import start_http_server as start_http_server_prometheus from synapse.metrics import RegistryProxy + servers: List[Tuple[WSGIServer, Thread]] = [] for host in bind_addresses: logger.info("Starting metrics listener on %s:%d", host, port) - start_http_server_prometheus(port, addr=host, registry=RegistryProxy) + server, thread = start_http_server_prometheus( + port, addr=host, registry=RegistryProxy + ) + servers.append((server, thread)) + return servers def listen_manhole( @@ -309,7 +346,7 @@ def listen_manhole( port: int, manhole_settings: ManholeConfig, manhole_globals: dict, -) -> None: +) -> List[Port]: # twisted.conch.manhole 21.1.0 uses "int_from_bytes", which produces a confusing # warning. It's fixed by https://github.com/twisted/twisted/pull/1522), so # suppress the warning for now. @@ -321,7 +358,7 @@ def listen_manhole( from synapse.util.manhole import manhole - listen_tcp( + return listen_tcp( bind_addresses, port, manhole(settings=manhole_settings, globals=manhole_globals), @@ -498,7 +535,7 @@ def refresh_certificate(hs: "HomeServer") -> None: logger.info("Context factories updated.") -async def start(hs: "HomeServer") -> None: +async def start(hs: "HomeServer", freeze: bool = True) -> None: """ Start a Synapse server or worker. @@ -509,6 +546,11 @@ async def start(hs: "HomeServer") -> None: Args: hs: homeserver instance + freeze: whether to freeze the homeserver base objects in the garbage collector. + May improve garbage collection performance by marking objects with an effectively + static lifetime as frozen so they don't need to be considered for cleanup. + If you ever want to `shutdown` the homeserver, this needs to be + False otherwise the homeserver cannot be garbage collected after `shutdown`. 
""" server_name = hs.hostname reactor = hs.get_reactor() @@ -541,12 +583,17 @@ async def _handle_sighup(*args: Any, **kwargs: Any) -> None: # we're not using systemd. sdnotify(b"RELOADING=1") - for i, args, kwargs in _sighup_callbacks: - i(*args, **kwargs) + for sighup_callbacks in _instance_id_to_sighup_callbacks_map.values(): + for func, args, kwargs in sighup_callbacks: + func(*args, **kwargs) sdnotify(b"READY=1") - return run_as_background_process( + # It's okay to ignore the linter error here and call + # `run_as_background_process` directly because `_handle_sighup` operates + # outside of the scope of a specific `HomeServer` instance and holds no + # references to it which would prevent a clean shutdown. + return run_as_background_process( # type: ignore[untracked-background-process] "sighup", server_name, _handle_sighup, @@ -564,8 +611,8 @@ def run_sighup(*args: Any, **kwargs: Any) -> None: signal.signal(signal.SIGHUP, run_sighup) - register_sighup(refresh_certificate, hs) - register_sighup(reload_cache_config, hs.config) + register_sighup(hs.get_instance_id(), refresh_certificate, hs) + register_sighup(hs.get_instance_id(), reload_cache_config, hs.config) # Apply the cache config. hs.config.caches.resize_all_caches() @@ -603,7 +650,11 @@ def log_shutdown() -> None: logger.info("Shutting down...") # Log when we start the shut down process. - hs.get_clock().add_system_event_trigger("before", "shutdown", log_shutdown) + hs.register_sync_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=log_shutdown, + ) setup_sentry(hs) setup_sdnotify(hs) @@ -632,18 +683,24 @@ def log_shutdown() -> None: # `REQUIRED_ON_BACKGROUND_TASK_STARTUP` start_phone_stats_home(hs) - # We now freeze all allocated objects in the hopes that (almost) - # everything currently allocated are things that will be used for the - # rest of time. Doing so means less work each GC (hopefully). - # - # PyPy does not (yet?) implement gc.freeze() - if hasattr(gc, "freeze"): - gc.collect() - gc.freeze() - - # Speed up shutdowns by freezing all allocated objects. This moves everything - # into the permanent generation and excludes them from the final GC. - atexit.register(gc.freeze) + if freeze: + # We now freeze all allocated objects in the hopes that (almost) + # everything currently allocated are things that will be used for the + # rest of time. Doing so means less work each GC (hopefully). + # + # Note that freezing the homeserver object means that it won't be able to be + # garbage collected in the case of attempting an in-memory `shutdown`. This only + # needs to be considered if such a case is desirable. Exiting the entire Python + # process will function expectedly either way. + # + # PyPy does not (yet?) implement gc.freeze() + if hasattr(gc, "freeze"): + gc.collect() + gc.freeze() + + # Speed up process exit by freezing all allocated objects. This moves everything + # into the permanent generation and excludes them from the final GC. 
+            atexit.register(gc.freeze)
 
 
 def reload_cache_config(config: HomeServerConfig) -> None:
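To make the freeze/shutdown interaction above concrete, a small illustration of the CPython behaviour the comments rely on (`HomeServerLike` is an invented stand-in):

```python
import gc


class HomeServerLike:
    pass


hs = HomeServerLike()

gc.collect()  # clean up first so only long-lived objects remain tracked
gc.freeze()   # move every surviving object into the permanent generation

del hs
gc.collect()
# `hs` was frozen, so the collector skips it: an in-memory "shutdown" can
# never reclaim a frozen homeserver. gc.unfreeze() would move frozen objects
# back into the oldest generation, making them collectable again.
```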
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index 51b8adaa278..7e8b47c20a3 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -278,11 +278,13 @@ def start_listening(self) -> None:
                 self._listen_http(listener)
             elif listener.type == "manhole":
                 if isinstance(listener, TCPListenerConfig):
-                    _base.listen_manhole(
-                        listener.bind_addresses,
-                        listener.port,
-                        manhole_settings=self.config.server.manhole_settings,
-                        manhole_globals={"hs": self},
+                    self._listening_services.extend(
+                        _base.listen_manhole(
+                            listener.bind_addresses,
+                            listener.port,
+                            manhole_settings=self.config.server.manhole_settings,
+                            manhole_globals={"hs": self},
+                        )
                     )
                 else:
                     raise ConfigError(
@@ -296,9 +298,11 @@ def start_listening(self) -> None:
                 )
             else:
                 if isinstance(listener, TCPListenerConfig):
-                    _base.listen_metrics(
-                        listener.bind_addresses,
-                        listener.port,
+                    self._metrics_listeners.extend(
+                        _base.listen_metrics(
+                            listener.bind_addresses,
+                            listener.port,
+                        )
                     )
                 else:
                     raise ConfigError(
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index 35d633d5270..3c691906ca1 100644
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -22,7 +22,7 @@
 import logging
 import os
 import sys
-from typing import Dict, Iterable, List
+from typing import Dict, Iterable, List, Optional
 
 from twisted.internet.tcp import Port
 from twisted.web.resource import EncodingResourceWrapper, Resource
@@ -70,6 +70,7 @@
 from synapse.rest.well_known import well_known_resource
 from synapse.server import HomeServer
 from synapse.storage import DataStore
+from synapse.types import ISynapseReactor
 from synapse.util.check_dependencies import VERSION, check_requirements
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.module_loader import load_module
@@ -277,11 +278,13 @@ def start_listening(self) -> None:
             )
         elif listener.type == "manhole":
             if isinstance(listener, TCPListenerConfig):
-                _base.listen_manhole(
-                    listener.bind_addresses,
-                    listener.port,
-                    manhole_settings=self.config.server.manhole_settings,
-                    manhole_globals={"hs": self},
+                self._listening_services.extend(
+                    _base.listen_manhole(
+                        listener.bind_addresses,
+                        listener.port,
+                        manhole_settings=self.config.server.manhole_settings,
+                        manhole_globals={"hs": self},
+                    )
                 )
             else:
                 raise ConfigError(
@@ -294,9 +297,11 @@ def start_listening(self) -> None:
             )
         else:
             if isinstance(listener, TCPListenerConfig):
-                _base.listen_metrics(
-                    listener.bind_addresses,
-                    listener.port,
+                self._metrics_listeners.extend(
+                    _base.listen_metrics(
+                        listener.bind_addresses,
+                        listener.port,
+                    )
                 )
             else:
                 raise ConfigError(
@@ -340,12 +345,23 @@ def load_or_generate_config(argv_options: List[str]) -> HomeServerConfig:
     return config
 
 
-def setup(config: HomeServerConfig) -> SynapseHomeServer:
+def setup(
+    config: HomeServerConfig,
+    reactor: Optional[ISynapseReactor] = None,
+    freeze: bool = True,
+) -> SynapseHomeServer:
     """
     Create and setup a Synapse homeserver instance given a configuration.
 
     Args:
         config: The configuration for the homeserver.
+        reactor: Optionally provide a reactor to use. Can be useful in different
+            scenarios where you want control over the reactor, such as tests.
+        freeze: whether to freeze the homeserver base objects in the garbage collector.
+            May improve garbage collection performance by marking objects with an
+            effectively static lifetime as frozen so they don't need to be considered
+            for cleanup. If you ever want to `shutdown` the homeserver, this needs to
+            be False, otherwise the homeserver cannot be garbage collected after
+            `shutdown`.
 
     Returns:
         A homeserver instance.
@@ -384,6 +400,7 @@ def setup(config: HomeServerConfig) -> SynapseHomeServer:
         config.server.server_name,
         config=config,
         version_string=f"Synapse/{VERSION}",
+        reactor=reactor,
     )
 
     setup_logging(hs, config, use_worker_options=False)
@@ -405,7 +422,7 @@ async def start() -> None:
             # Loading the provider metadata also ensures the provider config is valid.
             await oidc.load_metadata()
 
-        await _base.start(hs)
+        await _base.start(hs, freeze)
 
         hs.get_datastores().main.db_pool.updates.start_doing_background_updates()
 
diff --git a/synapse/app/phone_stats_home.py b/synapse/app/phone_stats_home.py
index 7b8e7fe7001..4bbc33cba28 100644
--- a/synapse/app/phone_stats_home.py
+++ b/synapse/app/phone_stats_home.py
@@ -29,9 +29,6 @@
 from twisted.internet import defer
 
 from synapse.metrics import SERVER_NAME_LABEL
-from synapse.metrics.background_process_metrics import (
-    run_as_background_process,
-)
 from synapse.types import JsonDict
 from synapse.util.constants import (
     MILLISECONDS_PER_SECOND,
@@ -87,8 +84,6 @@ def phone_stats_home(
     stats: JsonDict,
     stats_process: List[Tuple[int, "resource.struct_rusage"]] = _stats_process,
 ) -> "defer.Deferred[None]":
-    server_name = hs.hostname
-
     async def _phone_stats_home(
         hs: "HomeServer",
         stats: JsonDict,
@@ -202,8 +197,8 @@ async def _phone_stats_home(
         except Exception as e:
             logger.warning("Error reporting stats: %s", e)
 
-    return run_as_background_process(
-        "phone_stats_home", server_name, _phone_stats_home, hs, stats, stats_process
+    return hs.run_as_background_process(
+        "phone_stats_home", _phone_stats_home, hs, stats, stats_process
     )
 
 
@@ -265,9 +260,8 @@ async def _generate_monthly_active_users() -> None:
                 float(hs.config.server.max_mau_value)
             )
 
-        return run_as_background_process(
+        return hs.run_as_background_process(
             "generate_monthly_active_users",
-            server_name,
             _generate_monthly_active_users,
         )
 
@@ -287,10 +281,16 @@ async def _generate_monthly_active_users() -> None:
 
     # We need to defer this init for the cases that we daemonize
     # otherwise the process ID we get is that of the non-daemon process
-    clock.call_later(0, performance_stats_init)
+    clock.call_later(
+        0,
+        performance_stats_init,
+    )
 
     # We wait 5 minutes to send the first set of stats as the server can
     # be quite busy the first few minutes
     clock.call_later(
-        INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS, phone_stats_home, hs, stats
+        INITIAL_DELAY_BEFORE_FIRST_PHONE_HOME_SECONDS,
+        phone_stats_home,
+        hs,
+        stats,
     )
diff --git a/synapse/appservice/__init__.py b/synapse/appservice/__init__.py
index 2d8d382e68c..1d0735ca1d6 100644
--- a/synapse/appservice/__init__.py
+++ b/synapse/appservice/__init__.py
@@ -23,15 +23,33 @@
 import logging
 import re
 from enum import Enum
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Pattern, Sequence
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Pattern,
+    Sequence,
+    cast,
+)
 
 import attr
 from netaddr import IPSet
 
+from twisted.internet import reactor
+
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, UserID
+from synapse.types import (
+    DeviceListUpdates,
+    ISynapseThreadlessReactor,
+    JsonDict,
+    JsonMapping,
+    UserID,
+)
 from synapse.util.caches.descriptors import _CacheContext, cached
+from synapse.util.clock import Clock
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationServiceApi
@@ -98,6 +116,15 @@ def __init__(
         self.sender = sender
         # The application service user should be part of the server's domain.
         self.server_name = sender.domain  # nb must be called this for @cached
+
+        # Ideally we would require passing in the `HomeServer` `Clock` instance.
+        # However this is not currently possible as there are places which use
+        # `@cached` that aren't aware of the `HomeServer` instance.
+        # nb must be called this for @cached
+        self.clock = Clock(
+            cast(ISynapseThreadlessReactor, reactor), server_name=self.server_name
+        )  # type: ignore[multiple-internal-clocks]
+
         self.namespaces = self._check_namespaces(namespaces)
         self.id = id
         self.ip_range_whitelist = ip_range_whitelist
diff --git a/synapse/appservice/scheduler.py b/synapse/appservice/scheduler.py
index c8678406a14..b4de759b675 100644
--- a/synapse/appservice/scheduler.py
+++ b/synapse/appservice/scheduler.py
@@ -81,7 +81,6 @@
 from synapse.appservice.api import ApplicationServiceApi
 from synapse.events import EventBase
 from synapse.logging.context import run_in_background
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.databases.main import DataStore
 from synapse.types import DeviceListUpdates, JsonMapping
 from synapse.util.clock import Clock
@@ -200,6 +199,7 @@ def __init__(self, txn_ctrl: "_TransactionController", hs: "HomeServer"):
         )
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
+        self.hs = hs
         self._store = hs.get_datastores().main
 
     def start_background_request(self, service: ApplicationService) -> None:
@@ -207,9 +207,7 @@ def start_background_request(self, service: ApplicationService) -> None:
         if service.id in self.requests_in_flight:
             return
 
-        run_as_background_process(
-            "as-sender", self.server_name, self._send_request, service
-        )
+        self.hs.run_as_background_process("as-sender", self._send_request, service)
 
     async def _send_request(self, service: ApplicationService) -> None:
         # sanity-check: we shouldn't get here if this service already has a sender
@@ -361,6 +359,7 @@ class _TransactionController:
     def __init__(self, hs: "HomeServer"):
         self.server_name = hs.hostname
         self.clock = hs.get_clock()
+        self.hs = hs
         self.store = hs.get_datastores().main
         self.as_api = hs.get_application_service_api()
@@ -448,6 +447,7 @@ def start_recoverer(self, service: ApplicationService) -> None:
         recoverer = self.RECOVERER_CLASS(
             self.server_name,
             self.clock,
+            self.hs,
             self.store,
             self.as_api,
             service,
@@ -494,6 +494,7 @@ def __init__(
         self,
         server_name: str,
         clock: Clock,
+        hs: "HomeServer",
         store: DataStore,
         as_api: ApplicationServiceApi,
         service: ApplicationService,
@@ -501,6 +502,7 @@ def __init__(
     ):
         self.server_name = server_name
         self.clock = clock
+        self.hs = hs
         self.store = store
         self.as_api = as_api
         self.service = service
@@ -513,9 +515,8 @@ def recover(self) -> None:
             logger.info("Scheduling retries on %s in %fs", self.service.id, delay)
             self.scheduled_recovery = self.clock.call_later(
                 delay,
-                run_as_background_process,
+                self.hs.run_as_background_process,
                 "as-recoverer",
-                self.server_name,
                 self.retry,
             )
 
@@ -535,9 +536,8 @@ def force_retry(self) -> None:
         if self.scheduled_recovery:
             self.clock.cancel_call_later(self.scheduled_recovery)
         # Run a retry, which will reschedule a recovery if it fails.
-        run_as_background_process(
+        self.hs.run_as_background_process(
             "retry",
-            self.server_name,
             self.retry,
         )
 
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index 0531ae78756..9dde4c4003f 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -345,7 +345,9 @@ def setup_logging(
     # Add a SIGHUP handler to reload the logging configuration, if one is available.
     from synapse.app import _base as appbase
 
-    appbase.register_sighup(_reload_logging_config, log_config_path)
+    appbase.register_sighup(
+        hs.get_instance_id(), _reload_logging_config, log_config_path
+    )
 
     # Log immediately so we can grep backwards.
     logger.warning("***** STARTING SERVER *****")
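Taken together with the `_base.py` changes earlier in this diff, per-instance registration and teardown look roughly like this (usage sketch; the `unregister_sighups` call site during shutdown is assumed):

```python
from synapse.app._base import register_sighup, unregister_sighups

# Hooks are now keyed by the owning instance...
register_sighup(hs.get_instance_id(), refresh_certificate, hs)
register_sighup(hs.get_instance_id(), reload_cache_config, hs.config)

# ...so a single HomeServer can drop all of its hooks on shutdown without
# touching callbacks registered by other workers in the same process.
unregister_sighups(hs.get_instance_id())
```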
+ """ + self._queue.shutdown() + async def get_keys( self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int ) -> Dict[str, FetchKeyResult]: diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 41595043d11..8c91336dbc1 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -148,6 +148,7 @@ def __init__(self, hs: "HomeServer"): self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache( cache_name="get_pdu_cache", server_name=self.server_name, + hs=self.hs, clock=self._clock, max_len=1000, expiry_ms=120 * 1000, @@ -167,6 +168,7 @@ def __init__(self, hs: "HomeServer"): ] = ExpiringCache( cache_name="get_room_hierarchy_cache", server_name=self.server_name, + hs=self.hs, clock=self._clock, max_len=1000, expiry_ms=5 * 60 * 1000, diff --git a/synapse/federation/send_queue.py b/synapse/federation/send_queue.py index 2fdee9ac549..759df9836b9 100644 --- a/synapse/federation/send_queue.py +++ b/synapse/federation/send_queue.py @@ -144,6 +144,9 @@ def register(queue_name: QueueNames, queue: Sized) -> None: self.clock.looping_call(self._clear_queue, 30 * 1000) + def shutdown(self) -> None: + """Stops this federation sender instance from sending further transactions.""" + def _next_pos(self) -> int: pos = self.pos self.pos += 1 diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index 8e3619d1bca..4410ffc5c56 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -168,7 +168,6 @@ events_processed_counter, ) from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.types import ( @@ -232,6 +231,11 @@ class AbstractFederationSender(metaclass=abc.ABCMeta): + @abc.abstractmethod + def shutdown(self) -> None: + """Stops this federation sender instance from sending further transactions.""" + raise NotImplementedError() + @abc.abstractmethod def notify_new_events(self, max_token: RoomStreamToken) -> None: """This gets called when we have some new events we might want to @@ -326,6 +330,7 @@ class _DestinationWakeupQueue: _MAX_TIME_IN_QUEUE = 30.0 sender: "FederationSender" = attr.ib() + hs: "HomeServer" = attr.ib() server_name: str = attr.ib() """ Our homeserver name (used to label metrics) (`hs.hostname`). @@ -453,18 +458,30 @@ def __init__(self, hs: "HomeServer"): 1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second ) self._destination_wakeup_queue = _DestinationWakeupQueue( - self, self.server_name, self.clock, max_delay_s=rr_txn_interval_per_room_s + self, + hs, + self.server_name, + self.clock, + max_delay_s=rr_txn_interval_per_room_s, ) + # It is important for `_is_shutdown` to be instantiated before the looping call + # for `wake_destinations_needing_catchup`. 
diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py
index 8e3619d1bca..4410ffc5c56 100644
--- a/synapse/federation/sender/__init__.py
+++ b/synapse/federation/sender/__init__.py
@@ -168,7 +168,6 @@
     events_processed_counter,
 )
 from synapse.metrics.background_process_metrics import (
-    run_as_background_process,
     wrap_as_background_process,
 )
 from synapse.types import (
@@ -232,6 +231,11 @@ class AbstractFederationSender(metaclass=abc.ABCMeta):
+    @abc.abstractmethod
+    def shutdown(self) -> None:
+        """Stops this federation sender instance from sending further transactions."""
+        raise NotImplementedError()
+
     @abc.abstractmethod
     def notify_new_events(self, max_token: RoomStreamToken) -> None:
         """This gets called when we have some new events we might want to
@@ -326,6 +330,7 @@ class _DestinationWakeupQueue:
     _MAX_TIME_IN_QUEUE = 30.0
 
     sender: "FederationSender" = attr.ib()
+    hs: "HomeServer" = attr.ib()
     server_name: str = attr.ib()
     """
     Our homeserver name (used to label metrics) (`hs.hostname`).
     """
@@ -453,18 +458,30 @@ def __init__(self, hs: "HomeServer"):
             1.0 / hs.config.ratelimiting.federation_rr_transactions_per_room_per_second
         )
         self._destination_wakeup_queue = _DestinationWakeupQueue(
-            self, self.server_name, self.clock, max_delay_s=rr_txn_interval_per_room_s
+            self,
+            hs,
+            self.server_name,
+            self.clock,
+            max_delay_s=rr_txn_interval_per_room_s,
         )
 
+        # It is important for `_is_shutdown` to be instantiated before the looping call
+        # for `wake_destinations_needing_catchup`.
+        self._is_shutdown = False
+
         # Regularly wake up destinations that have outstanding PDUs to be caught up
         self.clock.looping_call_now(
-            run_as_background_process,
+            self.hs.run_as_background_process,
             WAKEUP_RETRY_PERIOD_SEC * 1000.0,
             "wake_destinations_needing_catchup",
-            self.server_name,
             self._wake_destinations_needing_catchup,
         )
 
+    def shutdown(self) -> None:
+        self._is_shutdown = True
+        for queue in self._per_destination_queues.values():
+            queue.shutdown()
+
     def _get_per_destination_queue(
         self, destination: str
     ) -> Optional[PerDestinationQueue]:
@@ -503,16 +520,15 @@ def notify_new_events(self, max_token: RoomStreamToken) -> None:
             return
 
         # fire off a processing loop in the background
-        run_as_background_process(
+        self.hs.run_as_background_process(
             "process_event_queue_for_federation",
-            self.server_name,
             self._process_event_queue_loop,
         )
 
     async def _process_event_queue_loop(self) -> None:
         try:
             self._is_processing = True
-            while True:
+            while not self._is_shutdown:
                 last_token = await self.store.get_federation_out_pos("events")
                 (
                     next_token,
@@ -1123,7 +1139,7 @@ async def _wake_destinations_needing_catchup(self) -> None:
 
         last_processed: Optional[str] = None
 
-        while True:
+        while not self._is_shutdown:
             destinations_to_wake = (
                 await self.store.get_catch_up_outstanding_destinations(last_processed)
             )
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 4c844d403a2..845af92facf 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -28,6 +28,8 @@
 import attr
 from prometheus_client import Counter
 
+from twisted.internet import defer
+
 from synapse.api.constants import EduTypes
 from synapse.api.errors import (
     FederationDeniedError,
@@ -41,7 +43,6 @@
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import SynapseTags, set_tag
 from synapse.metrics import SERVER_NAME_LABEL, sent_transactions_counter
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import JsonDict, ReadReceipt
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.visibility import filter_events_for_server
@@ -79,6 +80,7 @@ class PerDestinationQueue:
     """
     Manages the per-destination transmission queues.
+    Runs until `shutdown()` is called on the queue.
 
     Args:
         hs
@@ -94,6 +96,7 @@ def __init__(
         destination: str,
     ):
         self.server_name = hs.hostname
+        self._hs = hs
         self._clock = hs.get_clock()
         self._storage_controllers = hs.get_storage_controllers()
         self._store = hs.get_datastores().main
@@ -117,6 +120,8 @@ def __init__(
         self._destination = destination
         self.transmission_loop_running = False
+        self._transmission_loop_enabled = True
+        self.active_transmission_loop: Optional[defer.Deferred] = None
 
         # Flag to signal to any running transmission loop that there is new data
         # queued up to be sent.
@@ -171,6 +176,20 @@ def __init__(
     def __str__(self) -> str:
         return "PerDestinationQueue[%s]" % self._destination
 
+    def shutdown(self) -> None:
+        """Instruct the queue to stop processing any further requests"""
+        self._transmission_loop_enabled = False
+        # The transaction manager must be shutdown before cancelling the active
+        # transmission loop. Otherwise the transmission loop can enter a new cycle of
+        # sleeping before retrying since the shutdown flag of the _transaction_manager
+        # hasn't been set yet.
+        self._transaction_manager.shutdown()
+        try:
+            if self.active_transmission_loop is not None:
+                self.active_transmission_loop.cancel()
+        except Exception:
+            pass
+
     def pending_pdu_count(self) -> int:
         return len(self._pending_pdus)
 
@@ -309,11 +328,14 @@ def attempt_new_transaction(self) -> None:
             )
             return
 
+        if not self._transmission_loop_enabled:
+            logger.warning("Shutdown has been requested. Not sending transaction")
+            return
+
         logger.debug("TX [%s] Starting transaction loop", self._destination)
 
-        run_as_background_process(
+        self.active_transmission_loop = self._hs.run_as_background_process(
             "federation_transaction_transmission_loop",
-            self.server_name,
             self._transaction_transmission_loop,
         )
 
@@ -321,13 +343,13 @@ async def _transaction_transmission_loop(self) -> None:
         pending_pdus: List[EventBase] = []
         try:
             self.transmission_loop_running = True
-
             # This will throw if we wouldn't retry. We do this here so we fail
             # quickly, but we will later check this again in the http client,
             # hence why we throw the result away.
             await get_retry_limiter(
                 destination=self._destination,
                 our_server_name=self.server_name,
+                hs=self._hs,
                 clock=self._clock,
                 store=self._store,
             )
@@ -339,7 +361,7 @@ async def _transaction_transmission_loop(self) -> None:
                 # not caught up yet
                 return
 
-            while True:
+            while self._transmission_loop_enabled:
                 self._new_data_to_send = False
 
                 async with _TransactionQueueManager(self) as (
@@ -352,8 +374,8 @@ async def _transaction_transmission_loop(self) -> None:
                     # If we've gotten told about new things to send during
                     # checking for things to send, we try looking again.
                     # Otherwise new PDUs or EDUs might arrive in the meantime,
-                    # but not get sent because we hold the
-                    # `transmission_loop_running` flag.
+                    # but not get sent because we currently have an
+                    # `_active_transmission_loop` running.
                     if self._new_data_to_send:
                         continue
                     else:
@@ -442,6 +464,7 @@ async def _transaction_transmission_loop(self) -> None:
             )
         finally:
             # We want to be *very* sure we clear this after we stop processing
+            self.active_transmission_loop = None
             self.transmission_loop_running = False
 
     async def _catch_up_transmission_loop(self) -> None:
@@ -469,7 +492,7 @@ async def _catch_up_transmission_loop(self) -> None:
         last_successful_stream_ordering: int = _tmp_last_successful_stream_ordering
 
         # get at most 50 catchup room/PDUs
-        while True:
+        while self._transmission_loop_enabled:
             event_ids = await self._store.get_catch_up_room_event_ids(
                 self._destination, last_successful_stream_ordering
             )
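Illustrative only: what shutdown-time cancellation looks like from inside the loop. Cancelling `active_transmission_loop` raises `CancelledError` at the current await point, and the `finally` block still clears the bookkeeping state (`_send_next_transaction` is an invented stand-in for the real loop body):

```python
from twisted.internet import defer


async def _transaction_transmission_loop(self) -> None:
    try:
        while self._transmission_loop_enabled:
            await self._send_next_transaction()
    except defer.CancelledError:
        pass  # shutdown() cancelled us mid-send; nothing more to do
    finally:
        # Runs on success, failure, and cancellation alike.
        self.active_transmission_loop = None
        self.transmission_loop_running = False
```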
diff --git a/synapse/federation/sender/transaction_manager.py b/synapse/federation/sender/transaction_manager.py
index b548d9ed70c..f47c0114873 100644
--- a/synapse/federation/sender/transaction_manager.py
+++ b/synapse/federation/sender/transaction_manager.py
@@ -72,6 +72,12 @@ def __init__(self, hs: "synapse.server.HomeServer"):
         # HACK to get unique tx id
         self._next_txn_id = int(self.clock.time_msec())
 
+        self._is_shutdown = False
+
+    def shutdown(self) -> None:
+        self._is_shutdown = True
+        self._transport_layer.shutdown()
+
     @measure_func("_send_new_transaction")
     async def send_new_transaction(
         self,
@@ -86,6 +92,12 @@ async def send_new_transaction(
             edus: List of EDUs to send
         """
 
+        if self._is_shutdown:
+            logger.warning(
+                "TransactionManager has been shutdown, not sending transaction"
+            )
+            return
+
         # Make a transaction-sending opentracing span. This span follows on from
         # all the edus in that transaction. This needs to be done since there is
         # no active span here, so if the edus were not received by the remote the
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 5a5dc45f108..02e56e8e278 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -70,6 +70,9 @@ def __init__(self, hs: "HomeServer"):
         self.client = hs.get_federation_http_client()
         self._is_mine_server_name = hs.is_mine_server_name
 
+    def shutdown(self) -> None:
+        self.client.shutdown()
+
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 39a22b8cbb1..eed50ef69a7 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -37,10 +37,8 @@
 
 class AccountValidityHandler:
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
-        self.server_name = (
-            hs.hostname
-        )  # nb must be called this for @wrap_as_background_process
+        self.hs = hs  # nb must be called this for @wrap_as_background_process
+        self.server_name = hs.hostname
         self.config = hs.config
         self.store = hs.get_datastores().main
         self.send_email_handler = hs.get_send_email_handler()
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index bf36cf39a19..6536d9fe510 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -47,7 +47,6 @@
     event_processing_loop_room_count,
 )
 from synapse.metrics.background_process_metrics import (
-    run_as_background_process,
     wrap_as_background_process,
 )
 from synapse.storage.databases.main.directory import RoomAliasMapping
@@ -76,9 +75,8 @@
 
 class ApplicationServicesHandler:
     def __init__(self, hs: "HomeServer"):
-        self.server_name = (
-            hs.hostname
-        )  # nb must be called this for @wrap_as_background_process
+        self.server_name = hs.hostname
+        self.hs = hs  # nb must be called this for @wrap_as_background_process
         self.store = hs.get_datastores().main
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
@@ -171,8 +169,8 @@ async def start_scheduler() -> None:
                     except Exception:
                         logger.error("Application Services Failure")
 
-                run_as_background_process(
-                    "as_scheduler", self.server_name, start_scheduler
-                )
+                self.hs.run_as_background_process(
+                    "as_scheduler", start_scheduler
+                )
                 self.started_scheduler = True
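A rough sketch of the attribute-naming convention the "nb must be called this" comments enforce. This is a hypothetical simplification, NOT the real decorator in `synapse.metrics.background_process_metrics`: the point is only that a decorator like `@wrap_as_background_process` resolves the owning `HomeServer` through a fixed attribute name at call time.

```python
import functools
from typing import Any, Awaitable, Callable


def wrap_as_background_process(desc: str) -> Callable:
    def decorator(func: Callable[..., Awaitable[Any]]) -> Callable:
        @functools.wraps(func)
        def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any:
            # Reaches for `self.hs` by name, which is why the classes above
            # must bind exactly that attribute in __init__.
            return self.hs.run_as_background_process(
                desc, func, self, *args, **kwargs
            )

        return wrapper

    return decorator
```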
""" if not self._user_parter_running: - run_as_background_process( - "user_parter_loop", self.server_name, self._user_parter_loop + self.hs.run_as_background_process( + "user_parter_loop", self._user_parter_loop ) async def _user_parter_loop(self) -> None: diff --git a/synapse/handlers/delayed_events.py b/synapse/handlers/delayed_events.py index d47e3fd2634..79dd3e84165 100644 --- a/synapse/handlers/delayed_events.py +++ b/synapse/handlers/delayed_events.py @@ -24,9 +24,6 @@ from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import set_tag from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions -from synapse.metrics.background_process_metrics import ( - run_as_background_process, -) from synapse.replication.http.delayed_events import ( ReplicationAddedDelayedEventRestServlet, ) @@ -58,6 +55,7 @@ class DelayedEventsHandler: def __init__(self, hs: "HomeServer"): + self.hs = hs self.server_name = hs.hostname self._store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() @@ -94,7 +92,10 @@ async def _schedule_db_events() -> None: hs.get_notifier().add_replication_callback(self.notify_new_event) # Kick off again (without blocking) to catch any missed notifications # that may have fired before the callback was added. - self._clock.call_later(0, self.notify_new_event) + self._clock.call_later( + 0, + self.notify_new_event, + ) # Delayed events that are already marked as processed on startup might not have been # sent properly on the last run of the server, so unmark them to send them again. @@ -112,15 +113,14 @@ async def _schedule_db_events() -> None: self._schedule_next_at(next_send_ts) # Can send the events in background after having awaited on marking them as processed - run_as_background_process( + self.hs.run_as_background_process( "_send_events", - self.server_name, self._send_events, events, ) - self._initialized_from_db = run_as_background_process( - "_schedule_db_events", self.server_name, _schedule_db_events + self._initialized_from_db = self.hs.run_as_background_process( + "_schedule_db_events", _schedule_db_events ) else: self._repl_client = ReplicationAddedDelayedEventRestServlet.make_client(hs) @@ -145,9 +145,7 @@ async def process() -> None: finally: self._event_processing = False - run_as_background_process( - "delayed_events.notify_new_event", self.server_name, process - ) + self.hs.run_as_background_process("delayed_events.notify_new_event", process) async def _unsafe_process_new_event(self) -> None: # We purposefully fetch the current max room stream ordering before @@ -542,9 +540,8 @@ def _schedule_next_at(self, next_send_ts: Timestamp) -> None: if self._next_delayed_event_call is None: self._next_delayed_event_call = self._clock.call_later( delay_sec, - run_as_background_process, + self.hs.run_as_background_process, "_send_on_timeout", - self.server_name, self._send_on_timeout, ) else: diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index 9509ac422ec..c6024597b74 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -47,7 +47,6 @@ ) from synapse.logging.opentracing import log_kv, set_tag, trace from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.replication.http.devices import ( @@ -125,7 +124,7 @@ class DeviceHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname # nb must be called this for @measure_func self.clock = hs.get_clock() # nb must 
-        self.hs = hs
+        self.hs = hs  # nb must be called this for @wrap_as_background_process
         self.store = cast("GenericWorkerStore", hs.get_datastores().main)
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
@@ -191,10 +190,9 @@ def __init__(self, hs: "HomeServer"):
             and self._delete_stale_devices_after is not None
         ):
             self.clock.looping_call(
-                run_as_background_process,
+                self.hs.run_as_background_process,
                 DELETE_STALE_DEVICES_INTERVAL_MS,
                 desc="delete_stale_devices",
-                server_name=self.server_name,
                 func=self._delete_stale_devices,
             )
 
@@ -963,10 +961,9 @@ class DeviceWriterHandler(DeviceHandler):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
+        self.server_name = hs.hostname  # nb must be called this for @measure_func
+        self.hs = hs  # nb must be called this for @wrap_as_background_process
 
-        self.server_name = (
-            hs.hostname
-        )  # nb must be called this for @measure_func and @wrap_as_background_process
         # We only need to poke the federation sender explicitly if its on the
         # same instance. Other federation sender instances will get notified by
         # `synapse.app.generic_worker.FederationSenderHandler` when it sees it
@@ -1444,7 +1441,7 @@ class DeviceListUpdater(DeviceListWorkerUpdater):
     def __init__(self, hs: "HomeServer", device_handler: DeviceWriterHandler):
         super().__init__(hs)
 
-        self.server_name = hs.hostname
+        self.hs = hs
         self.federation = hs.get_federation_client()
         self.server_name = hs.hostname  # nb must be called this for @measure_func
         self.clock = hs.get_clock()  # nb must be called this for @measure_func
@@ -1468,6 +1465,7 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceWriterHandler):
         self._seen_updates: ExpiringCache[str, Set[str]] = ExpiringCache(
             cache_name="device_update_edu",
             server_name=self.server_name,
+            hs=self.hs,
             clock=self.clock,
             max_len=10000,
             expiry_ms=30 * 60 * 1000,
@@ -1477,9 +1475,8 @@ def __init__(self, hs: "HomeServer", device_handler: DeviceWriterHandler):
         # Attempt to resync out of sync device lists every 30s.
         self._resync_retry_lock = Lock()
         self.clock.looping_call(
-            run_as_background_process,
+            self.hs.run_as_background_process,
             30 * 1000,
-            server_name=self.server_name,
             func=self._maybe_retry_device_resync,
             desc="_maybe_retry_device_resync",
         )
@@ -1599,9 +1596,8 @@ async def _handle_device_updates(self, user_id: str) -> None:
                 if resync:
                     # We mark as stale up front in case we get restarted.
                     await self.store.mark_remote_users_device_caches_as_stale([user_id])
-                    run_as_background_process(
+                    self.hs.run_as_background_process(
                         "_maybe_retry_device_resync",
-                        self.server_name,
                         self.multi_user_device_resync,
                         [user_id],
                         False,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 41fb3076c36..adc20f4ad02 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -72,7 +72,6 @@
 from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.metrics import SERVER_NAME_LABEL
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import NOT_SPAM
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.invite_rule import InviteRule
@@ -188,9 +187,8 @@ def __init__(self, hs: "HomeServer"):
         # any partial-state-resync operations which were in flight when we
         # were shut down.
         if not hs.config.worker.worker_app:
-            run_as_background_process(
+            self.hs.run_as_background_process(
                 "resume_sync_partial_state_room",
-                self.server_name,
                 self._resume_partial_state_room_sync,
             )
 
@@ -318,9 +316,8 @@ async def _maybe_backfill_inner(
             logger.debug(
                 "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
             )
-            run_as_background_process(
+            self.hs.run_as_background_process(
                 "_maybe_backfill_inner_anyway_with_max_depth",
-                self.server_name,
                 self.maybe_backfill,
                 room_id=room_id,
                 # We use `MAX_DEPTH` so that we find all backfill points next
@@ -802,9 +799,8 @@ async def do_invite_join(
         # lots of requests for missing prev_events which we do actually
         # have. Hence we fire off the background task, but don't wait for it.
 
-        run_as_background_process(
+        self.hs.run_as_background_process(
             "handle_queued_pdus",
-            self.server_name,
             self._handle_queued_pdus,
             room_queue,
         )
@@ -1877,9 +1873,8 @@ async def _sync_partial_state_room_wrapper() -> None:
                 room_id=room_id,
             )
 
-        run_as_background_process(
+        self.hs.run_as_background_process(
             desc="sync_partial_state_room",
-            server_name=self.server_name,
             func=_sync_partial_state_room_wrapper,
         )
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 59886f04c40..d6390b79c7b 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -81,7 +81,6 @@
     trace,
 )
 from synapse.metrics import SERVER_NAME_LABEL
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
@@ -153,6 +152,7 @@ class FederationEventHandler:
 
     def __init__(self, hs: "HomeServer"):
         self.server_name = hs.hostname
+        self.hs = hs
         self._clock = hs.get_clock()
         self._store = hs.get_datastores().main
         self._state_store = hs.get_datastores().state
@@ -175,6 +175,7 @@ def __init__(self, hs: "HomeServer"):
         )
 
         self._notifier = hs.get_notifier()
+        self._server_name = hs.hostname
 
         self._is_mine_id = hs.is_mine_id
         self._is_mine_server_name = hs.is_mine_server_name
         self._instance_name = hs.get_instance_name()
@@ -974,9 +975,8 @@ async def _process_new_pulled_events(new_events: Collection[EventBase]) -> None:
         # Process previously failed backfill events in the background to not waste
         # time on something that is likely to fail again.
         if len(events_with_failed_pull_attempts) > 0:
-            run_as_background_process(
+            self.hs.run_as_background_process(
                 "_process_new_pulled_events_with_failed_pull_attempts",
-                self.server_name,
                 _process_new_pulled_events,
                 events_with_failed_pull_attempts,
             )
@@ -1568,9 +1568,8 @@ async def _process_received_pdu(
                 resync = True
 
         if resync:
-            run_as_background_process(
+            self.hs.run_as_background_process(
                 "resync_device_due_to_pdu",
-                self.server_name,
                 self._resync_device,
                 event.sender,
             )
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 4ff8b3704b9..e874b600008 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -67,7 +67,6 @@
 from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
 from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
@@ -99,6 +98,7 @@
 class MessageHandler:
     def __init__(self, hs: "HomeServer"):
         self.server_name = hs.hostname
+        self.hs = hs
         self.auth = hs.get_auth()
         self.clock = hs.get_clock()
         self.state = hs.get_state_handler()
@@ -113,8 +113,8 @@ def __init__(self, hs: "HomeServer"):
         self._scheduled_expiry: Optional[IDelayedCall] = None
 
         if not hs.config.worker.worker_app:
-            run_as_background_process(
-                "_schedule_next_expiry", self.server_name, self._schedule_next_expiry
+            self.hs.run_as_background_process(
+                "_schedule_next_expiry", self._schedule_next_expiry
             )
 
     async def get_room_data(
@@ -444,9 +444,8 @@ def _schedule_expiry_for_event(self, event_id: str, expiry_ts: int) -> None:
 
         self._scheduled_expiry = self.clock.call_later(
             delay,
-            run_as_background_process,
+            self.hs.run_as_background_process,
             "_expire_event",
-            self.server_name,
             self._expire_event,
             event_id,
         )
@@ -548,9 +547,8 @@ def __init__(self, hs: "HomeServer"):
             and self.config.server.cleanup_extremities_with_dummy_events
         ):
             self.clock.looping_call(
-                lambda: run_as_background_process(
+                lambda: self.hs.run_as_background_process(
                     "send_dummy_events_to_fill_extremities",
-                    self.server_name,
                     self._send_dummy_events_to_fill_extremities,
                 ),
                 5 * 60 * 1000,
@@ -570,6 +568,7 @@ def __init__(self, hs: "HomeServer"):
         self._external_cache_joined_hosts_updates = ExpiringCache(
             cache_name="_external_cache_joined_hosts_updates",
             server_name=self.server_name,
+            hs=self.hs,
             clock=self.clock,
             expiry_ms=30 * 60 * 1000,
         )
@@ -2113,9 +2112,8 @@ async def persist_and_notify_client_events(
             if event.type == EventTypes.Message:
                 # We don't want to block sending messages on any presence code. This
                 # matters as sometimes presence code can take a while.
-                run_as_background_process(
+                self.hs.run_as_background_process(
                     "bump_presence_active_time",
-                    self.server_name,
                     self._bump_active_time,
                     requester.user,
                     requester.device_id,
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index df1a7e714ce..02a67581e75 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -29,7 +29,6 @@
 from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME
 from synapse.logging.opentracing import trace
-from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.admin._base import assert_user_is_admin
 from synapse.streams.config import PaginationConfig
 from synapse.types import (
@@ -116,10 +115,9 @@ def __init__(self, hs: "HomeServer"):
                 logger.info("Setting up purge job with config: %s", job)
 
                 self.clock.looping_call(
-                    run_as_background_process,
+                    self.hs.run_as_background_process,
                     job.interval,
                     "purge_history_for_rooms_in_range",
-                    self.server_name,
                     self.purge_history_for_rooms_in_range,
                     job.shortest_max_lifetime,
                     job.longest_max_lifetime,
@@ -244,9 +242,8 @@ async def purge_history_for_rooms_in_range(
         # We want to purge everything, including local events, and to run the purge in
         # the background so that it's not blocking any other operation apart from
         # other purges in the same room.
-        run_as_background_process(
+        self.hs.run_as_background_process(
             PURGE_HISTORY_ACTION_NAME,
-            self.server_name,
             self.purge_history,
             room_id,
             token,
@@ -604,9 +601,8 @@ async def get_messages(
             # Otherwise, we can backfill in the background for eventual
             # consistency's sake but we don't need to block the client waiting
             # for a costly federation call and processing.
-            run_as_background_process(
+            self.hs.run_as_background_process(
                 "maybe_backfill_in_the_background",
-                self.server_name,
                 self.hs.get_federation_handler().maybe_backfill,
                 room_id,
                 curr_topo,
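The presence changes below route shutdown work through the two registration helpers this PR introduces; a usage recap with the signatures exactly as they appear elsewhere in this diff:

```python
# Synchronous handlers (e.g. logging) go through the sync variant...
hs.register_sync_shutdown_handler(
    phase="before",
    eventType="shutdown",
    shutdown_func=log_shutdown,
)

# ...while coroutines are registered via the async variant and decorated
# with @wrap_as_background_process so they are tracked like any other
# background process.
hs.register_async_shutdown_handler(
    phase="before",
    eventType="shutdown",
    shutdown_func=self._on_shutdown,
)
```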
@wrap_as_background_process + self.server_name = hs.hostname self.wheel_timer: WheelTimer[str] = WheelTimer() self.notifier = hs.get_notifier() @@ -842,13 +835,10 @@ def __init__(self, hs: "HomeServer"): # have not yet been persisted self.unpersisted_users_changes: Set[str] = set() - hs.get_clock().add_system_event_trigger( - "before", - "shutdown", - run_as_background_process, - "presence.on_shutdown", - self.server_name, - self._on_shutdown, + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._on_shutdown, ) # Keeps track of the number of *ongoing* syncs on this process. While @@ -881,7 +871,10 @@ def __init__(self, hs: "HomeServer"): # The initial delay is to allow disconnected clients a chance to # reconnect before we treat them as offline. self.clock.call_later( - 30, self.clock.looping_call, self._handle_timeouts, 5000 + 30, + self.clock.looping_call, + self._handle_timeouts, + 5000, ) # Presence information is persisted, whether or not it is being tracked @@ -908,6 +901,7 @@ def __init__(self, hs: "HomeServer"): self._event_pos = self.store.get_room_max_stream_ordering() self._event_processing = False + @wrap_as_background_process("PresenceHandler._on_shutdown") async def _on_shutdown(self) -> None: """Gets called when shutting down. This lets us persist any updates that we haven't yet persisted, e.g. updates that only changes some internal @@ -1539,8 +1533,8 @@ async def _process_presence() -> None: finally: self._event_processing = False - run_as_background_process( - "presence.notify_new_event", self.server_name, _process_presence + self.hs.run_as_background_process( + "presence.notify_new_event", _process_presence ) async def _unsafe_process(self) -> None: diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index dbff28e7fb5..9dda89d85bb 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -56,8 +56,8 @@ class ProfileHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname # nb must be called this for @cached + self.clock = hs.get_clock() # nb must be called this for @cached self.store = hs.get_datastores().main - self.clock = hs.get_clock() self.hs = hs self.federation = hs.get_federation_client() diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py index 5761a7f70b3..c3ff0cfaf81 100644 --- a/synapse/handlers/register.py +++ b/synapse/handlers/register.py @@ -23,7 +23,14 @@ """Contains functions for registering clients.""" import logging -from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, TypedDict +from typing import ( + TYPE_CHECKING, + Iterable, + List, + Optional, + Tuple, + TypedDict, +) from prometheus_client import Counter diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py index 623823acb02..2ab9b70f8c5 100644 --- a/synapse/handlers/room_member.py +++ b/synapse/handlers/room_member.py @@ -50,7 +50,6 @@ from synapse.handlers.worker_lock import NEW_EVENT_DURING_PURGE_LOCK_NAME from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.http.push import ReplicationCopyPusherRestServlet from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.invite_rule import InviteRule @@ -2190,7 +2189,10 @@ def __init__(self, hs: "HomeServer"): self._notifier.add_replication_callback(self.notify_new_event) # We kick this off 
to pick up outstanding work from before the last restart. - self._clock.call_later(0, self.notify_new_event) + self._clock.call_later( + 0, + self.notify_new_event, + ) def notify_new_event(self) -> None: """Called when there may be more deltas to process""" @@ -2205,9 +2207,7 @@ async def process() -> None: finally: self._is_processing = False - run_as_background_process( - "room_forgetter.notify_new_event", self.server_name, process - ) + self._hs.run_as_background_process("room_forgetter.notify_new_event", process) async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index eec420cbb17..735cfa0a0f8 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -224,7 +224,7 @@ def __init__(self, hs: "HomeServer"): ) # a lock on the mappings - self._mapping_lock = Linearizer(name="sso_user_mapping", clock=hs.get_clock()) + self._mapping_lock = Linearizer(clock=hs.get_clock(), name="sso_user_mapping") # a map from session id to session data self._username_mapping_sessions: Dict[str, UsernameMappingSession] = {} diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py index a2602ea818e..5b4a2cc62dc 100644 --- a/synapse/handlers/stats.py +++ b/synapse/handlers/stats.py @@ -33,7 +33,6 @@ from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.metrics import SERVER_NAME_LABEL, event_processing_positions -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.state_deltas import StateDelta from synapse.types import JsonDict from synapse.util.events import get_plain_text_topic_from_event_content @@ -75,7 +74,10 @@ def __init__(self, hs: "HomeServer"): # We kick this off so that we don't have to wait for a change before # we start populating stats - self.clock.call_later(0, self.notify_new_event) + self.clock.call_later( + 0, + self.notify_new_event, + ) def notify_new_event(self) -> None: """Called when there may be more deltas to process""" @@ -90,7 +92,7 @@ async def process() -> None: finally: self._is_processing = False - run_as_background_process("stats.notify_new_event", self.server_name, process) + self.hs.run_as_background_process("stats.notify_new_event", process) async def _unsafe_process(self) -> None: # If self.pos is None then means we haven't fetched it from DB diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index c0341c56541..6f0522d5bba 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -323,6 +323,7 @@ def __init__(self, hs: "HomeServer"): ] = ExpiringCache( cache_name="lazy_loaded_members_cache", server_name=self.server_name, + hs=hs, clock=self.clock, max_len=0, expiry_ms=LAZY_LOADED_MEMBERS_CACHE_MAX_AGE, @@ -982,6 +983,7 @@ def get_lazy_loaded_members_cache( logger.debug("creating LruCache for %r", cache_key) cache = LruCache( max_size=LAZY_LOADED_MEMBERS_CACHE_MAX_SIZE, + clock=self.clock, server_name=self.server_name, ) self.lazy_loaded_members_cache[cache_key] = cache diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py index 6a7b36ea0c8..77c5b747c3f 100644 --- a/synapse/handlers/typing.py +++ b/synapse/handlers/typing.py @@ -28,7 +28,6 @@ from synapse.api.errors import AuthError, ShadowBanError, SynapseError from synapse.appservice import ApplicationService from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from 
synapse.replication.tcp.streams import TypingStream @@ -78,11 +77,10 @@ class FollowerTypingHandler: """ def __init__(self, hs: "HomeServer"): + self.hs = hs # nb must be called this for @wrap_as_background_process self.store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.clock = hs.get_clock() self.is_mine_id = hs.is_mine_id self.is_mine_server_name = hs.is_mine_server_name @@ -144,9 +142,8 @@ def _handle_timeout_for_member(self, now: int, member: RoomMember) -> None: if self.federation and self.is_mine_id(member.user_id): last_fed_poke = self._member_last_federation_poke.get(member, None) if not last_fed_poke or last_fed_poke + FEDERATION_PING_INTERVAL <= now: - run_as_background_process( + self.hs.run_as_background_process( "typing._push_remote", - self.server_name, self._push_remote, member=member, typing=True, @@ -220,9 +217,8 @@ def process_replication_rows( self._rooms_updated.add(row.room_id) if self.federation: - run_as_background_process( + self.hs.run_as_background_process( "_send_changes_in_typing_to_remotes", - self.server_name, self._send_changes_in_typing_to_remotes, row.room_id, prev_typing, @@ -384,9 +380,8 @@ def _stopped_typing(self, member: RoomMember) -> None: def _push_update(self, member: RoomMember, typing: bool) -> None: if self.hs.is_mine_id(member.user_id): # Only send updates for changes to our own users. - run_as_background_process( + self.hs.run_as_background_process( "typing._push_remote", - self.server_name, self._push_remote, member, typing, diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py index 130099a2390..28961f5925f 100644 --- a/synapse/handlers/user_directory.py +++ b/synapse/handlers/user_directory.py @@ -36,7 +36,6 @@ from synapse.api.errors import Codes, SynapseError from synapse.handlers.state_deltas import MatchChange, StateDeltasHandler from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.state_deltas import StateDelta from synapse.storage.databases.main.user_directory import SearchResult from synapse.storage.roommember import ProfileInfo @@ -137,11 +136,15 @@ def __init__(self, hs: "HomeServer"): # We kick this off so that we don't have to wait for a change before # we start populating the user directory - self.clock.call_later(0, self.notify_new_event) + self.clock.call_later( + 0, + self.notify_new_event, + ) # Kick off the profile refresh process on startup self._refresh_remote_profiles_call_later = self.clock.call_later( - 10, self.kick_off_remote_profile_refresh_process + 10, + self.kick_off_remote_profile_refresh_process, ) async def search_users( @@ -193,9 +196,7 @@ async def process() -> None: self._is_processing = False self._is_processing = True - run_as_background_process( - "user_directory.notify_new_event", self.server_name, process - ) + self._hs.run_as_background_process("user_directory.notify_new_event", process) async def handle_local_profile_change( self, user_id: str, profile: ProfileInfo @@ -609,8 +610,8 @@ async def process() -> None: self._is_refreshing_remote_profiles = False self._is_refreshing_remote_profiles = True - run_as_background_process( - "user_directory.refresh_remote_profiles", self.server_name, process + self._hs.run_as_background_process( + "user_directory.refresh_remote_profiles", process ) 
 async def _unsafe_refresh_remote_profiles(self) -> None:
@@ -655,8 +656,9 @@ async def _unsafe_refresh_remote_profiles(self) -> None:
         if not users:
             return
         _, _, next_try_at_ts = users[0]
+        delay = ((next_try_at_ts - self.clock.time_msec()) // 1000) + 2
         self._refresh_remote_profiles_call_later = self.clock.call_later(
-            ((next_try_at_ts - self.clock.time_msec()) // 1000) + 2,
+            delay,
             self.kick_off_remote_profile_refresh_process,
         )
@@ -692,9 +694,8 @@ async def process() -> None:
                 self._is_refreshing_remote_profiles_for_servers.remove(server_name)
 
         self._is_refreshing_remote_profiles_for_servers.add(server_name)
-        run_as_background_process(
+        self._hs.run_as_background_process(
             "user_directory.refresh_remote_profiles_for_remote_server",
-            self.server_name,
             process,
         )
diff --git a/synapse/handlers/worker_lock.py b/synapse/handlers/worker_lock.py
index 0b375790dd9..ca1e2b166c3 100644
--- a/synapse/handlers/worker_lock.py
+++ b/synapse/handlers/worker_lock.py
@@ -37,13 +37,13 @@
 import attr
 
 from twisted.internet import defer
-from twisted.internet.interfaces import IReactorTime
 
 from synapse.logging.context import PreserveLoggingContext
 from synapse.logging.opentracing import start_active_span
 from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.storage.databases.main.lock import Lock, LockStore
 from synapse.util.async_helpers import timeout_deferred
+from synapse.util.clock import Clock
 from synapse.util.constants import ONE_MINUTE_SECONDS
 
 if TYPE_CHECKING:
@@ -66,10 +66,7 @@ class WorkerLocksHandler:
     """
 
     def __init__(self, hs: "HomeServer") -> None:
-        self.server_name = (
-            hs.hostname
-        )  # nb must be called this for @wrap_as_background_process
-        self._reactor = hs.get_reactor()
+        self.hs = hs  # nb must be called this for @wrap_as_background_process
+        self._clock = hs.get_clock()
         self._store = hs.get_datastores().main
-        self._clock = hs.get_clock()
         self._notifier = hs.get_notifier()
@@ -98,7 +96,7 @@ def acquire_lock(self, lock_name: str, lock_key: str) -> "WaitingLock":
         """
 
         lock = WaitingLock(
-            reactor=self._reactor,
+            clock=self._clock,
             store=self._store,
             handler=self,
             lock_name=lock_name,
@@ -129,7 +127,7 @@ def acquire_read_write_lock(
         """
 
         lock = WaitingLock(
-            reactor=self._reactor,
+            clock=self._clock,
             store=self._store,
             handler=self,
             lock_name=lock_name,
@@ -160,7 +158,7 @@ def acquire_multi_read_write_lock(
         lock = WaitingMultiLock(
             lock_names=lock_names,
             write=write,
-            reactor=self._reactor,
+            clock=self._clock,
             store=self._store,
             handler=self,
         )
@@ -197,7 +195,11 @@ def _wake_all_locks(
             if not deferred.called:
                 deferred.callback(None)
 
-        self._clock.call_later(0, _wake_all_locks, locks)
+        self._clock.call_later(
+            0,
+            _wake_all_locks,
+            locks,
+        )
 
     @wrap_as_background_process("_cleanup_locks")
     async def _cleanup_locks(self) -> None:
@@ -207,7 +209,7 @@ async def _cleanup_locks(self) -> None:
 
 @attr.s(auto_attribs=True, eq=False)
 class WaitingLock:
-    reactor: IReactorTime
+    clock: Clock
     store: LockStore
     handler: WorkerLocksHandler
     lock_name: str
@@ -246,10 +248,11 @@ async def __aenter__(self) -> None:
                     # periodically wake up in case the lock was released but we
                     # weren't notified.
with PreserveLoggingContext(): + timeout = self._get_next_retry_interval() await timeout_deferred( deferred=self.deferred, - timeout=self._get_next_retry_interval(), - reactor=self.reactor, + timeout=timeout, + clock=self.clock, ) except Exception: pass @@ -290,7 +293,7 @@ class WaitingMultiLock: write: bool - reactor: IReactorTime + clock: Clock store: LockStore handler: WorkerLocksHandler @@ -323,10 +326,11 @@ async def __aenter__(self) -> None: # periodically wake up in case the lock was released but we # weren't notified. with PreserveLoggingContext(): + timeout = self._get_next_retry_interval() await timeout_deferred( deferred=self.deferred, - timeout=self._get_next_retry_interval(), - reactor=self.reactor, + timeout=timeout, + clock=self.clock, ) except Exception: pass diff --git a/synapse/http/client.py b/synapse/http/client.py index bbb0efe8b52..370cdc3568b 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -54,7 +54,6 @@ IOpenSSLContextFactory, IReactorCore, IReactorPluggableNameResolver, - IReactorTime, IResolutionReceiver, ITCPTransport, ) @@ -88,6 +87,7 @@ from synapse.metrics import SERVER_NAME_LABEL from synapse.types import ISynapseReactor, StrSequence from synapse.util.async_helpers import timeout_deferred +from synapse.util.clock import Clock from synapse.util.json import json_decoder if TYPE_CHECKING: @@ -165,16 +165,17 @@ def _is_ip_blocked( _EPSILON = 0.00000001 -def _make_scheduler( - reactor: IReactorTime, -) -> Callable[[Callable[[], object]], IDelayedCall]: +def _make_scheduler(clock: Clock) -> Callable[[Callable[[], object]], IDelayedCall]: """Makes a schedular suitable for a Cooperator using the given reactor. (This is effectively just a copy from `twisted.internet.task`) """ def _scheduler(x: Callable[[], object]) -> IDelayedCall: - return reactor.callLater(_EPSILON, x) + return clock.call_later( + _EPSILON, + x, + ) return _scheduler @@ -367,7 +368,7 @@ def __init__( # We use this for our body producers to ensure that they use the correct # reactor. - self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_reactor())) + self._cooperator = Cooperator(scheduler=_make_scheduler(hs.get_clock())) async def request( self, @@ -436,9 +437,9 @@ async def request( # we use our own timeout mechanism rather than treq's as a workaround # for https://twistedmatrix.com/trac/ticket/9534. request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), + deferred=request_deferred, + timeout=60, + clock=self.hs.get_clock(), ) # turn timeouts into RequestTimedOutErrors @@ -763,7 +764,11 @@ async def get_file( d = read_body_with_max_size(response, output_stream, max_size) # Ensure that the body is not read forever. - d = timeout_deferred(d, 30, self.hs.get_reactor()) + d = timeout_deferred( + deferred=d, + timeout=30, + clock=self.hs.get_clock(), + ) length = await make_deferred_yieldable(d) except BodyExceededMaxSize: @@ -957,9 +962,9 @@ async def request( # for https://twistedmatrix.com/trac/ticket/9534. 
# (Updated url https://github.com/twisted/twisted/issues/9534) request_deferred = timeout_deferred( - request_deferred, - 60, - self.hs.get_reactor(), + deferred=request_deferred, + timeout=60, + clock=self.hs.get_clock(), ) # turn timeouts into RequestTimedOutErrors diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py index 98826c91711..9d87514be00 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py @@ -67,6 +67,9 @@ class MatrixFederationAgent: Args: reactor: twisted reactor to use for underlying requests + clock: Internal `HomeServer` clock used to track delayed and looping calls. + Should be obtained from `hs.get_clock()`. + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. @@ -97,6 +100,7 @@ def __init__( *, server_name: str, reactor: ISynapseReactor, + clock: Clock, tls_client_options_factory: Optional[FederationPolicyForHTTPS], user_agent: bytes, ip_allowlist: Optional[IPSet], @@ -109,6 +113,7 @@ def __init__( Args: server_name: Our homeserver name (used to label metrics) (`hs.hostname`). reactor + clock: Should be the `hs` clock from `hs.get_clock()` tls_client_options_factory user_agent ip_allowlist @@ -124,7 +129,6 @@ def __init__( # addresses, to prevent DNS rebinding. reactor = BlocklistingReactorWrapper(reactor, ip_allowlist, ip_blocklist) - self._clock = Clock(reactor, server_name=server_name) self._pool = HTTPConnectionPool(reactor) self._pool.retryAutomatically = False self._pool.maxPersistentPerHost = 5 @@ -147,6 +151,7 @@ def __init__( _well_known_resolver = WellKnownResolver( server_name=server_name, reactor=reactor, + clock=clock, agent=BlocklistingAgentWrapper( ProxyAgent( reactor=reactor, diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py index 97bba8231ac..2f52abcc035 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py @@ -90,6 +90,7 @@ def __init__( self, server_name: str, reactor: ISynapseThreadlessReactor, + clock: Clock, agent: IAgent, user_agent: bytes, well_known_cache: Optional[TTLCache[bytes, Optional[bytes]]] = None, @@ -99,6 +100,7 @@ def __init__( Args: server_name: Our homeserver name (used to label metrics) (`hs.hostname`). 
reactor + clock: Should be the `hs` clock from `hs.get_clock()` agent user_agent well_known_cache @@ -107,7 +109,7 @@ def __init__( self.server_name = server_name self._reactor = reactor - self._clock = Clock(reactor, server_name=server_name) + self._clock = clock if well_known_cache is None: well_known_cache = TTLCache( diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index c264bae6e51..4d72c72d018 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -90,6 +90,7 @@ from synapse.metrics import SERVER_NAME_LABEL from synapse.types import JsonDict from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred +from synapse.util.clock import Clock from synapse.util.json import json_decoder from synapse.util.metrics import Measure from synapse.util.stringutils import parse_and_validate_server_name @@ -270,6 +271,7 @@ def _validate(v: Any) -> bool: async def _handle_response( + clock: Clock, reactor: IReactorTime, timeout_sec: float, request: MatrixFederationRequest, @@ -299,7 +301,11 @@ async def _handle_response( check_content_type_is(response.headers, parser.CONTENT_TYPE) d = read_body_with_max_size(response, parser, max_response_size) - d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) + d = timeout_deferred( + deferred=d, + timeout=timeout_sec, + clock=clock, + ) length = await make_deferred_yieldable(d) @@ -411,6 +417,7 @@ def __init__( self.server_name = hs.hostname self.reactor = hs.get_reactor() + self.clock = hs.get_clock() user_agent = hs.version_string if hs.config.server.user_agent_suffix: @@ -424,6 +431,7 @@ def __init__( federation_agent: IAgent = MatrixFederationAgent( server_name=self.server_name, reactor=self.reactor, + clock=self.clock, tls_client_options_factory=tls_client_options_factory, user_agent=user_agent.encode("ascii"), ip_allowlist=hs.config.server.federation_ip_range_allowlist, @@ -457,7 +465,6 @@ def __init__( ip_blocklist=hs.config.server.federation_ip_range_blocklist, ) - self.clock = hs.get_clock() self._store = hs.get_datastores().main self.version_string_bytes = hs.version_string.encode("ascii") self.default_timeout_seconds = hs.config.federation.client_timeout_ms / 1000 @@ -470,9 +477,9 @@ def __init__( self.max_long_retries = hs.config.federation.max_long_retries self.max_short_retries = hs.config.federation.max_short_retries - self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor)) + self._cooperator = Cooperator(scheduler=_make_scheduler(self.clock)) - self._sleeper = AwakenableSleeper(self.reactor) + self._sleeper = AwakenableSleeper(self.clock) self._simple_http_client = SimpleHttpClient( hs, @@ -484,6 +491,10 @@ def __init__( self.remote_download_linearizer = Linearizer( name="remote_download_linearizer", max_count=6, clock=self.clock ) + self._is_shutdown = False + + def shutdown(self) -> None: + self._is_shutdown = True def wake_destination(self, destination: str) -> None: """Called when the remote server may have come back online.""" @@ -629,6 +640,7 @@ async def _send_request( limiter = await synapse.util.retryutils.get_retry_limiter( destination=request.destination, our_server_name=self.server_name, + hs=self.hs, clock=self.clock, store=self._store, backoff_on_404=backoff_on_404, @@ -675,7 +687,7 @@ async def _send_request( (b"", b"", path_bytes, None, query_bytes, b"") ) - while True: + while not self._is_shutdown: try: json = request.get_json() if json: @@ -733,9 +745,9 @@ async def _send_request( 
bodyProducer=producer, ) request_deferred = timeout_deferred( - request_deferred, + deferred=request_deferred, timeout=_sec_timeout, - reactor=self.reactor, + clock=self.clock, ) response = await make_deferred_yieldable(request_deferred) @@ -793,7 +805,9 @@ async def _send_request( # Update transactions table? d = treq.content(response) d = timeout_deferred( - d, timeout=_sec_timeout, reactor=self.reactor + deferred=d, + timeout=_sec_timeout, + clock=self.clock, ) try: @@ -862,6 +876,15 @@ async def _send_request( delay_seconds, ) + if self._is_shutdown: + # Immediately fail sending the request instead of starting a + # potentially long sleep after the server has requested + # shutdown. + # This is the code path followed when the + # `federation_transaction_transmission_loop` has been + # cancelled. + raise + # Sleep for the calculated delay, or wake up immediately # if we get notified that the server is back up. await self._sleeper.sleep( @@ -1074,6 +1097,7 @@ async def put_json( parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( + self.clock, self.reactor, _sec_timeout, request, @@ -1152,7 +1176,13 @@ async def post_json( _sec_timeout = self.default_timeout_seconds body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.clock, + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=JsonParser(), ) return body @@ -1358,6 +1388,7 @@ async def get_json_with_headers( parser = cast(ByteParser[T], JsonParser()) body = await _handle_response( + self.clock, self.reactor, _sec_timeout, request, @@ -1431,7 +1462,13 @@ async def delete_json( _sec_timeout = self.default_timeout_seconds body = await _handle_response( - self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() + self.clock, + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=JsonParser(), ) return body diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py index 9b044f3b0ae..fa17432984a 100644 --- a/synapse/http/proxy.py +++ b/synapse/http/proxy.py @@ -161,12 +161,12 @@ async def _async_render(self, request: "SynapseRequest") -> Tuple[int, Any]: bodyProducer=QuieterFileBodyProducer(request.content), ) request_deferred = timeout_deferred( - request_deferred, + deferred=request_deferred, # This should be set longer than the timeout in `MatrixFederationHttpClient` # so that it has enough time to complete and pass us the data before we give # up. timeout=90, - reactor=self.reactor, + clock=self._clock, ) response = await make_deferred_yieldable(request_deferred) diff --git a/synapse/http/server.py b/synapse/http/server.py index ce9d5630df2..d5af8758ac2 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -420,7 +420,14 @@ def __init__( """ if clock is None: - clock = Clock( + # Ideally we wouldn't ignore the linter error here and instead enforce a + # required `Clock` be passed into the `__init__` function. + # However, this would change the function signature which is currently being + # exported to the module api. Since we don't want to break that api, we have + # to settle with ignoring the linter error here. + # As of the time of writing this, all Synapse internal usages of + # `DirectServeJsonResource` pass in the existing homeserver clock instance. 
+ clock = Clock( # type: ignore[multiple-internal-clocks] cast(ISynapseThreadlessReactor, reactor), server_name="synapse_module_running_from_unknown_server", ) @@ -608,7 +615,14 @@ def __init__( Only optional for the Module API. """ if clock is None: - clock = Clock( + # Ideally we wouldn't ignore the linter error here and instead enforce a + # required `Clock` be passed into the `__init__` function. + # However, this would change the function signature which is currently being + # exported to the module api. Since we don't want to break that api, we have + # to settle with ignoring the linter error here. + # As of the time of writing this, all Synapse internal usages of + # `DirectServeHtmlResource` pass in the existing homeserver clock instance. + clock = Clock( # type: ignore[multiple-internal-clocks] cast(ISynapseThreadlessReactor, reactor), server_name="synapse_module_running_from_unknown_server", ) diff --git a/synapse/http/site.py b/synapse/http/site.py index 2c0c301c03f..f4f326cfde4 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py @@ -22,7 +22,7 @@ import logging import time from http import HTTPStatus -from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union +from typing import TYPE_CHECKING, Any, Generator, List, Optional, Tuple, Union import attr from zope.interface import implementer @@ -30,6 +30,7 @@ from twisted.internet.address import UNIXAddress from twisted.internet.defer import Deferred from twisted.internet.interfaces import IAddress +from twisted.internet.protocol import Protocol from twisted.python.failure import Failure from twisted.web.http import HTTPChannel from twisted.web.resource import IResource, Resource @@ -660,6 +661,70 @@ class _XForwardedForAddress: host: str +class SynapseProtocol(HTTPChannel): + """ + Synapse-specific twisted http Protocol. + + This is a small wrapper around the twisted HTTPChannel so we can track active + connections in order to close any outstanding connections on shutdown. + """ + + def __init__( + self, + site: "SynapseSite", + our_server_name: str, + max_request_body_size: int, + request_id_header: Optional[str], + request_class: type, + ): + super().__init__() + self.factory: SynapseSite = site + self.site = site + self.our_server_name = our_server_name + self.max_request_body_size = max_request_body_size + self.request_id_header = request_id_header + self.request_class = request_class + + def connectionMade(self) -> None: + """ + Called when a connection is made. + + This may be considered the initializer of the protocol, because + it is called when the connection is completed. + + Add the connection to the factory's connection list when it's established. + """ + super().connectionMade() + self.factory.addConnection(self) + + def connectionLost(self, reason: Failure) -> None: # type: ignore[override] + """ + Called when the connection is shut down. + + Clear any circular references here, and any external references to this + Protocol. The connection has been closed. In our case, we need to remove the + connection from the factory's connection list, when it's lost. + """ + super().connectionLost(reason) + self.factory.removeConnection(self) + + def requestFactory(self, http_channel: HTTPChannel, queued: bool) -> SynapseRequest: # type: ignore[override] + """ + A callable used to build `twisted.web.iweb.IRequest` objects. + + Use our own custom SynapseRequest type instead of the regular + twisted.web.server.Request. 
+ """ + return self.request_class( + self, + self.factory, + our_server_name=self.our_server_name, + max_request_body_size=self.max_request_body_size, + queued=queued, + request_id_header=self.request_id_header, + ) + + class SynapseSite(ProxySite): """ Synapse-specific twisted http Site @@ -710,23 +775,44 @@ def __init__( assert config.http_options is not None proxied = config.http_options.x_forwarded - request_class = XForwardedForRequest if proxied else SynapseRequest - - request_id_header = config.http_options.request_id_header - - def request_factory(channel: HTTPChannel, queued: bool) -> Request: - return request_class( - channel, - self, - our_server_name=self.server_name, - max_request_body_size=max_request_body_size, - queued=queued, - request_id_header=request_id_header, - ) + self.request_class = XForwardedForRequest if proxied else SynapseRequest + + self.request_id_header = config.http_options.request_id_header + self.max_request_body_size = max_request_body_size - self.requestFactory = request_factory # type: ignore self.access_logger = logging.getLogger(logger_name) self.server_version_string = server_version_string.encode("ascii") + self.connections: List[Protocol] = [] + + def buildProtocol(self, addr: IAddress) -> SynapseProtocol: + protocol = SynapseProtocol( + self, + self.server_name, + self.max_request_body_size, + self.request_id_header, + self.request_class, + ) + return protocol + + def addConnection(self, protocol: Protocol) -> None: + self.connections.append(protocol) + + def removeConnection(self, protocol: Protocol) -> None: + if protocol in self.connections: + self.connections.remove(protocol) + + def stopFactory(self) -> None: + super().stopFactory() + + # Shutdown any connections which are still active. + # These can be long lived HTTP connections which wouldn't normally be closed + # when calling `shutdown` on the respective `Port`. + # Closing the connections here is required for us to fully shutdown the + # `SynapseHomeServer` in order for it to be garbage collected. + for protocol in self.connections[:]: + if protocol.transport is not None: + protocol.transport.loseConnection() + self.connections.clear() def log(self, request: SynapseRequest) -> None: # type: ignore[override] pass diff --git a/synapse/media/_base.py b/synapse/media/_base.py index 15b28074fd3..d3a9a66f5a9 100644 --- a/synapse/media/_base.py +++ b/synapse/media/_base.py @@ -704,6 +704,7 @@ class ThreadedFileSender: def __init__(self, hs: "HomeServer") -> None: self.reactor = hs.get_reactor() + self.clock = hs.get_clock() self.thread_pool = hs.get_media_sender_thread_pool() self.file: Optional[BinaryIO] = None @@ -712,7 +713,7 @@ def __init__(self, hs: "HomeServer") -> None: # Signals if the thread should keep reading/sending data. Set means # continue, clear means pause. - self.wakeup_event = DeferredEvent(self.reactor) + self.wakeup_event = DeferredEvent(self.clock) # Signals if the thread should terminate, e.g. because the consumer has # gone away. 
diff --git a/synapse/media/media_repository.py b/synapse/media/media_repository.py index 436d9b7e35f..238dc6cb2f3 100644 --- a/synapse/media/media_repository.py +++ b/synapse/media/media_repository.py @@ -67,7 +67,6 @@ from synapse.media.storage_provider import StorageProviderWrapper from synapse.media.thumbnailer import Thumbnailer, ThumbnailError from synapse.media.url_previewer import UrlPreviewer -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia from synapse.types import UserID from synapse.util.async_helpers import Linearizer @@ -187,16 +186,14 @@ def __init__(self, hs: "HomeServer"): self.media_repository_callbacks = hs.get_module_api_callbacks().media_repository def _start_update_recently_accessed(self) -> Deferred: - return run_as_background_process( + return self.hs.run_as_background_process( "update_recently_accessed_media", - self.server_name, self._update_recently_accessed, ) def _start_apply_media_retention_rules(self) -> Deferred: - return run_as_background_process( + return self.hs.run_as_background_process( "apply_media_retention_rules", - self.server_name, self._apply_media_retention_rules, ) diff --git a/synapse/media/url_previewer.py b/synapse/media/url_previewer.py index 81204913f76..1a82cc46e3e 100644 --- a/synapse/media/url_previewer.py +++ b/synapse/media/url_previewer.py @@ -44,7 +44,6 @@ from synapse.media.media_storage import MediaStorage, SHA256TransparentIOWriter from synapse.media.oembed import OEmbedProvider from synapse.media.preview_html import decode_body, parse_html_to_open_graph -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import JsonDict, UserID from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.expiringcache import ExpiringCache @@ -167,6 +166,7 @@ def __init__( media_storage: MediaStorage, ): self.clock = hs.get_clock() + self.hs = hs self.filepaths = media_repo.filepaths self.max_spider_size = hs.config.media.max_spider_size self.server_name = hs.hostname @@ -201,15 +201,14 @@ def __init__( self._cache: ExpiringCache[str, ObservableDeferred] = ExpiringCache( cache_name="url_previews", server_name=self.server_name, + hs=self.hs, clock=self.clock, # don't spider URLs more often than once an hour expiry_ms=ONE_HOUR, ) if self._worker_run_media_background_jobs: - self._cleaner_loop = self.clock.looping_call( - self._start_expire_url_cache_data, 10 * 1000 - ) + self.clock.looping_call(self._start_expire_url_cache_data, 10 * 1000) async def preview(self, url: str, user: UserID, ts: int) -> bytes: # the in-memory cache: @@ -739,8 +738,8 @@ async def _handle_oembed_response( return open_graph_result, oembed_response.author_name, expiration_ms def _start_expire_url_cache_data(self) -> Deferred: - return run_as_background_process( - "expire_url_cache_data", self.server_name, self._expire_url_cache_data + return self.hs.run_as_background_process( + "expire_url_cache_data", self._expire_url_cache_data ) async def _expire_url_cache_data(self) -> None: diff --git a/synapse/metrics/_gc.py b/synapse/metrics/_gc.py index e7783b05e6d..1da871f18ff 100644 --- a/synapse/metrics/_gc.py +++ b/synapse/metrics/_gc.py @@ -138,7 +138,9 @@ def _maybe_gc() -> None: gc_time.labels(i).observe(end - start) gc_unreachable.labels(i).set(unreachable) - gc_task = task.LoopingCall(_maybe_gc) + # We can ignore the lint here since this looping call does not hold a `HomeServer` + # 
reference so can be cleaned up by other means on shutdown.
+    gc_task = task.LoopingCall(_maybe_gc)  # type: ignore[prefer-synapse-clock-looping-call]
     gc_task.start(0.1)
diff --git a/synapse/metrics/background_process_metrics.py b/synapse/metrics/background_process_metrics.py
index 93345b0e9d4..6dc2cbe1322 100644
--- a/synapse/metrics/background_process_metrics.py
+++ b/synapse/metrics/background_process_metrics.py
@@ -66,6 +66,8 @@
     # Old versions don't have `LiteralString`
     from typing_extensions import LiteralString
 
+    from synapse.server import HomeServer
+
 
 logger = logging.getLogger(__name__)
 
@@ -397,11 +399,11 @@ def combined_context_manager() -> Generator[None, None, None]:
 P = ParamSpec("P")
 
 
-class HasServerName(Protocol):
-    server_name: str
+class HasHomeServer(Protocol):
+    hs: "HomeServer"
     """
-    The homeserver name that this cache is associated with (used to label the metric)
-    (`hs.hostname`).
+    The homeserver that this cache is associated with (used to label the metric and
+    track background processes for clean shutdown).
     """
 
 
@@ -431,27 +433,22 @@ def func(*args): ...
     """
 
     def wrapper(
-        func: Callable[Concatenate[HasServerName, P], Awaitable[Optional[R]]],
+        func: Callable[Concatenate[HasHomeServer, P], Awaitable[Optional[R]]],
     ) -> Callable[P, "defer.Deferred[Optional[R]]"]:
         @wraps(func)
         def wrapped_func(
-            self: HasServerName, *args: P.args, **kwargs: P.kwargs
+            self: HasHomeServer, *args: P.args, **kwargs: P.kwargs
         ) -> "defer.Deferred[Optional[R]]":
-            assert self.server_name is not None, (
-                "The `server_name` attribute must be set on the object where `@wrap_as_background_process` decorator is used."
+            assert self.hs is not None, (
+                "The `hs` attribute must be set on the object where `@wrap_as_background_process` decorator is used."
             )
-            return run_as_background_process(
+            return self.hs.run_as_background_process(
                 desc,
-                self.server_name,
                 func,
                 self,
                 *args,
-                # type-ignore: mypy is confusing kwargs with the bg_start_span kwarg.
-                # Argument 4 to "run_as_background_process" has incompatible type
-                # "**P.kwargs"; expected "bool"
-                # See https://github.com/python/mypy/issues/8862
-                **kwargs,  # type: ignore[arg-type]
+                **kwargs,
             )
 
     # There are some shenanigans here, because we're decorating a method but
diff --git a/synapse/metrics/common_usage_metrics.py b/synapse/metrics/common_usage_metrics.py
index cd1c3c86499..43e0913d279 100644
--- a/synapse/metrics/common_usage_metrics.py
+++ b/synapse/metrics/common_usage_metrics.py
@@ -23,7 +23,6 @@
 import attr
 
 from synapse.metrics import SERVER_NAME_LABEL
-from synapse.metrics.background_process_metrics import run_as_background_process
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -52,6 +51,7 @@ def __init__(self, hs: "HomeServer") -> None:
         self.server_name = hs.hostname
         self._store = hs.get_datastores().main
         self._clock = hs.get_clock()
+        self._hs = hs
 
     async def get_metrics(self) -> CommonUsageMetrics:
         """Get the CommonUsageMetrics object.
        If no collection has happened yet, do it
@@ -64,16 +64,14 @@ async def get_metrics(self) -> CommonUsageMetrics:
 
     async def setup(self) -> None:
         """Keep the gauges for common usage metrics up to date."""
-        run_as_background_process(
+        self._hs.run_as_background_process(
             desc="common_usage_metrics_update_gauges",
-            server_name=self.server_name,
             func=self._update_gauges,
         )
         self._clock.looping_call(
-            run_as_background_process,
+            self._hs.run_as_background_process,
             5 * 60 * 1000,
             desc="common_usage_metrics_update_gauges",
-            server_name=self.server_name,
             func=self._update_gauges,
         )
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 7a419145e0f..12a31dd2abb 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -275,7 +275,15 @@ def run_as_background_process(
     # function instead.
     stub_server_name = "synapse_module_running_from_unknown_server"
 
-    return _run_as_background_process(
+    # Ignore the linter error here. Since this is leveraging the
+    # `run_as_background_process` function directly and we don't want to break the
+    # module api, we need to keep the function signature the same. This means we don't
+    # have access to the running `HomeServer` and cannot track this background process
+    # for cleanup during shutdown.
+    # This is not an issue at runtime and is only potentially problematic if the
+    # application cares about being able to garbage collect `HomeServer` instances
+    # while the process is still running.
+    return _run_as_background_process(  # type: ignore[untracked-background-process]
         desc,
         stub_server_name,
         func,
@@ -1402,7 +1410,7 @@ def looping_background_call(
 
         if self._hs.config.worker.run_background_tasks or run_on_all_instances:
             self._clock.looping_call(
-                self.run_as_background_process,
+                self._hs.run_as_background_process,
                 msec,
                 desc,
                 lambda: maybe_awaitable(f(*args, **kwargs)),
@@ -1460,7 +1468,7 @@ def delayed_background_call(
         return self._clock.call_later(
             # convert ms to seconds as needed by call_later.
             msec * 0.001,
-            self.run_as_background_process,
+            self._hs.run_as_background_process,
             desc,
             lambda: maybe_awaitable(f(*args, **kwargs)),
         )
@@ -1701,8 +1709,8 @@ def run_as_background_process(
 
         Note that the returned Deferred does not follow the synapse logcontext rules.
         """
-        return _run_as_background_process(
-            desc, self.server_name, func, *args, bg_start_span=bg_start_span, **kwargs
+        return self._hs.run_as_background_process(
+            desc, func, *args, bg_start_span=bg_start_span, **kwargs
         )
 
     async def defer_to_thread(
diff --git a/synapse/notifier.py b/synapse/notifier.py
index e684df4866b..9169f50c4dd 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -676,9 +676,16 @@ async def wait_for_events(
                     # is a new token.
                     listener = user_stream.new_listener(prev_token)
                     listener = timeout_deferred(
-                        listener,
-                        (end_time - now) / 1000.0,
-                        self.hs.get_reactor(),
+                        deferred=listener,
+                        timeout=(end_time - now) / 1000.0,
+                        # We don't track these calls since they are constantly being
+                        # overridden by new calls to /sync and they don't hold the
+                        # `HomeServer` in memory on shutdown. It is safe to let them
+                        # time out of their own accord after shutting down since it
+                        # won't delay shutdown and there won't be any adverse
+                        # behaviour.
+ cancel_on_shutdown=False, + clock=self.hs.get_clock(), ) log_kv( diff --git a/synapse/push/emailpusher.py b/synapse/push/emailpusher.py index 09ca14584a2..1484bc8fc01 100644 --- a/synapse/push/emailpusher.py +++ b/synapse/push/emailpusher.py @@ -25,7 +25,6 @@ from twisted.internet.error import AlreadyCalled, AlreadyCancelled from twisted.internet.interfaces import IDelayedCall -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException, ThrottleParams from synapse.push.mailer import Mailer from synapse.push.push_types import EmailReason @@ -118,7 +117,7 @@ def _start_processing(self) -> None: if self._is_processing: return - run_as_background_process("emailpush.process", self.server_name, self._process) + self.hs.run_as_background_process("emailpush.process", self._process) def _pause_processing(self) -> None: """Used by tests to temporarily pause processing of events. @@ -228,8 +227,10 @@ async def _unsafe_process(self) -> None: self.timed_call = None if soonest_due_at is not None: - self.timed_call = self.hs.get_reactor().callLater( - self.seconds_until(soonest_due_at), self.on_timer + delay = self.seconds_until(soonest_due_at) + self.timed_call = self.hs.get_clock().call_later( + delay, + self.on_timer, ) async def save_last_stream_ordering_and_success( diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py index 5946a6e9724..5cac5de8cb4 100644 --- a/synapse/push/httppusher.py +++ b/synapse/push/httppusher.py @@ -32,7 +32,6 @@ from synapse.events import EventBase from synapse.logging import opentracing from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.push import Pusher, PusherConfig, PusherConfigException from synapse.storage.databases.main.event_push_actions import HttpPushAction from synapse.types import JsonDict, JsonMapping @@ -182,8 +181,8 @@ def on_new_receipts(self) -> None: # We could check the receipts are actually m.read receipts here, # but currently that's the only type of receipt anyway... 
- run_as_background_process( - "http_pusher.on_new_receipts", self.server_name, self._update_badge + self.hs.run_as_background_process( + "http_pusher.on_new_receipts", self._update_badge ) async def _update_badge(self) -> None: @@ -219,7 +218,7 @@ def _start_processing(self) -> None: if self.failing_since and self.timed_call and self.timed_call.active(): return - run_as_background_process("httppush.process", self.server_name, self._process) + self.hs.run_as_background_process("httppush.process", self._process) async def _process(self) -> None: # we should never get here if we are already processing @@ -336,8 +335,9 @@ async def _unsafe_process(self) -> None: ) else: logger.info("Push failed: delaying for %ds", self.backoff_delay) - self.timed_call = self.hs.get_reactor().callLater( - self.backoff_delay, self.on_timer + self.timed_call = self.hs.get_clock().call_later( + self.backoff_delay, + self.on_timer, ) self.backoff_delay = min( self.backoff_delay * 2, self.MAX_BACKOFF_SEC diff --git a/synapse/push/pusherpool.py b/synapse/push/pusherpool.py index d1f79ec9995..977c55b6836 100644 --- a/synapse/push/pusherpool.py +++ b/synapse/push/pusherpool.py @@ -27,7 +27,6 @@ from synapse.api.errors import Codes, SynapseError from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.push import Pusher, PusherConfig, PusherConfigException @@ -70,10 +69,8 @@ class PusherPool: """ def __init__(self, hs: "HomeServer"): - self.hs = hs - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.pusher_factory = PusherFactory(hs) self.store = self.hs.get_datastores().main self.clock = self.hs.get_clock() @@ -112,9 +109,7 @@ def start(self) -> None: if not self._should_start_pushers: logger.info("Not starting pushers because they are disabled in the config") return - run_as_background_process( - "start_pushers", self.server_name, self._start_pushers - ) + self.hs.run_as_background_process("start_pushers", self._start_pushers) async def add_or_update_pusher( self, diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index d96f5541f19..f2561bc0c52 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -32,7 +32,6 @@ from synapse.federation import send_queue from synapse.federation.sender import FederationSender from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.streams import ( AccountDataStream, DeviceListsStream, @@ -344,7 +343,9 @@ async def wait_for_stream_position( # to wedge here forever. deferred: "Deferred[None]" = Deferred() deferred = timeout_deferred( - deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor + deferred=deferred, + timeout=_WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, + clock=self._clock, ) waiting_list = self._streams_to_waiters.setdefault( @@ -513,8 +514,8 @@ async def update_token(self, token: int) -> None: # no need to queue up another task. 
return - run_as_background_process( - "_save_and_send_ack", self.server_name, self._save_and_send_ack + self._hs.run_as_background_process( + "_save_and_send_ack", self._save_and_send_ack ) async def _save_and_send_ack(self) -> None: diff --git a/synapse/replication/tcp/handler.py b/synapse/replication/tcp/handler.py index dd7e38dd781..4d0d3d44abc 100644 --- a/synapse/replication/tcp/handler.py +++ b/synapse/replication/tcp/handler.py @@ -41,7 +41,6 @@ from twisted.internet.protocol import ReconnectingClientFactory from synapse.metrics import SERVER_NAME_LABEL, LaterGauge -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import ( ClearUserSyncsCommand, Command, @@ -132,6 +131,7 @@ class ReplicationCommandHandler: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname + self.hs = hs self._replication_data_handler = hs.get_replication_data_handler() self._presence_handler = hs.get_presence_handler() self._store = hs.get_datastores().main @@ -361,9 +361,8 @@ def _add_command_to_stream_queue( return # fire off a background process to start processing the queue. - run_as_background_process( + self.hs.run_as_background_process( "process-replication-data", - self.server_name, self._unsafe_process_queue, stream_name, ) diff --git a/synapse/replication/tcp/protocol.py b/synapse/replication/tcp/protocol.py index 25a7868cd7b..bcfc65c2c0d 100644 --- a/synapse/replication/tcp/protocol.py +++ b/synapse/replication/tcp/protocol.py @@ -42,7 +42,6 @@ from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, - run_as_background_process, ) from synapse.replication.tcp.commands import ( VALID_CLIENT_COMMANDS, @@ -140,9 +139,14 @@ class BaseReplicationStreamProtocol(LineOnlyReceiver): max_line_buffer = 10000 def __init__( - self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + self, + hs: "HomeServer", + server_name: str, + clock: Clock, + handler: "ReplicationCommandHandler", ): self.server_name = server_name + self.hs = hs self.clock = clock self.command_handler = handler @@ -290,9 +294,8 @@ def handle_command(self, cmd: Command) -> None: # if so. 
if isawaitable(res): - run_as_background_process( + self.hs.run_as_background_process( "replication-" + cmd.get_logcontext_id(), - self.server_name, lambda: res, ) @@ -470,9 +473,13 @@ class ServerReplicationStreamProtocol(BaseReplicationStreamProtocol): VALID_OUTBOUND_COMMANDS = VALID_SERVER_COMMANDS def __init__( - self, server_name: str, clock: Clock, handler: "ReplicationCommandHandler" + self, + hs: "HomeServer", + server_name: str, + clock: Clock, + handler: "ReplicationCommandHandler", ): - super().__init__(server_name, clock, handler) + super().__init__(hs, server_name, clock, handler) self.server_name = server_name @@ -497,7 +504,7 @@ def __init__( clock: Clock, command_handler: "ReplicationCommandHandler", ): - super().__init__(server_name, clock, command_handler) + super().__init__(hs, server_name, clock, command_handler) self.client_name = client_name self.server_name = server_name diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py index 0b1be033b11..caffb2913ea 100644 --- a/synapse/replication/tcp/redis.py +++ b/synapse/replication/tcp/redis.py @@ -40,7 +40,6 @@ from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( BackgroundProcessLoggingContext, - run_as_background_process, wrap_as_background_process, ) from synapse.replication.tcp.commands import ( @@ -109,6 +108,7 @@ class RedisSubscriber(SubscriberProtocol): """ server_name: str + hs: "HomeServer" synapse_handler: "ReplicationCommandHandler" synapse_stream_prefix: str synapse_channel_names: List[str] @@ -146,9 +146,7 @@ def _get_logging_context(self) -> BackgroundProcessLoggingContext: def connectionMade(self) -> None: logger.info("Connected to redis") super().connectionMade() - run_as_background_process( - "subscribe-replication", self.server_name, self._send_subscribe - ) + self.hs.run_as_background_process("subscribe-replication", self._send_subscribe) async def _send_subscribe(self) -> None: # it's important to make sure that we only send the REPLICATE command once we @@ -223,8 +221,8 @@ def handle_command(self, cmd: Command) -> None: # if so. if isawaitable(res): - run_as_background_process( - "replication-" + cmd.get_logcontext_id(), self.server_name, lambda: res + self.hs.run_as_background_process( + "replication-" + cmd.get_logcontext_id(), lambda: res ) def connectionLost(self, reason: Failure) -> None: # type: ignore[override] @@ -245,9 +243,8 @@ def send_command(self, cmd: Command) -> None: Args: cmd: The command to send """ - run_as_background_process( + self.hs.run_as_background_process( "send-cmd", - self.server_name, self._async_send_command, cmd, # We originally started tracing background processes to avoid `There was no @@ -317,9 +314,8 @@ def __init__( convertNumbers=convertNumbers, ) - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname hs.get_clock().looping_call(self._send_ping, 30 * 1000) @@ -397,6 +393,7 @@ def __init__( ) self.server_name = hs.hostname + self.hs = hs self.synapse_handler = hs.get_replication_command_handler() self.synapse_stream_prefix = hs.hostname self.synapse_channel_names = channel_names @@ -412,6 +409,7 @@ def buildProtocol(self, addr: IAddress) -> RedisSubscriber: # the base method does some other things than just instantiating the # protocol. 
p.server_name = self.server_name + p.hs = self.hs p.synapse_handler = self.synapse_handler p.synapse_outbound_redis_connection = self.synapse_outbound_redis_connection p.synapse_stream_prefix = self.synapse_stream_prefix diff --git a/synapse/replication/tcp/resource.py b/synapse/replication/tcp/resource.py index d800cfe6f60..ef72a0a5325 100644 --- a/synapse/replication/tcp/resource.py +++ b/synapse/replication/tcp/resource.py @@ -30,7 +30,6 @@ from twisted.internet.protocol import ServerFactory from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.replication.tcp.commands import PositionCommand from synapse.replication.tcp.protocol import ServerReplicationStreamProtocol from synapse.replication.tcp.streams import EventsStream @@ -55,6 +54,7 @@ class ReplicationStreamProtocolFactory(ServerFactory): def __init__(self, hs: "HomeServer"): self.command_handler = hs.get_replication_command_handler() self.clock = hs.get_clock() + self.hs = hs self.server_name = hs.config.server.server_name # If we've created a `ReplicationStreamProtocolFactory` then we're @@ -69,7 +69,7 @@ def __init__(self, hs: "HomeServer"): def buildProtocol(self, addr: IAddress) -> ServerReplicationStreamProtocol: return ServerReplicationStreamProtocol( - self.server_name, self.clock, self.command_handler + self.hs, self.server_name, self.clock, self.command_handler ) @@ -82,6 +82,7 @@ class ReplicationStreamer: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname + self.hs = hs self.store = hs.get_datastores().main self.clock = hs.get_clock() self.notifier = hs.get_notifier() @@ -147,8 +148,8 @@ def on_notifier_poke(self) -> None: logger.debug("Notifier poke loop already running") return - run_as_background_process( - "replication_notifier", self.server_name, self._run_notifier_loop + self.hs.run_as_background_process( + "replication_notifier", self._run_notifier_loop ) async def _run_notifier_loop(self) -> None: diff --git a/synapse/replication/tcp/streams/__init__.py b/synapse/replication/tcp/streams/__init__.py index 25c15e5d486..87ac0a5ae17 100644 --- a/synapse/replication/tcp/streams/__init__.py +++ b/synapse/replication/tcp/streams/__init__.py @@ -77,6 +77,7 @@ __all__ = [ "STREAMS_MAP", "Stream", + "EventsStream", "BackfillStream", "PresenceStream", "PresenceFederationStream", @@ -87,6 +88,7 @@ "CachesStream", "DeviceListsStream", "ToDeviceStream", + "FederationStream", "AccountDataStream", "ThreadSubscriptionsStream", "UnPartialStatedRoomStream", diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 64deae76507..1084139df0f 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -66,7 +66,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client._base import client_patterns from synapse.rest.client.transactions import HttpTransactionCache from synapse.state import CREATE_KEY, POWER_KEY @@ -1225,6 +1224,7 @@ class RoomRedactEventRestServlet(TransactionRestServlet): def __init__(self, hs: "HomeServer"): super().__init__(hs) self.server_name = hs.hostname + self.hs = hs self.event_creation_handler = hs.get_event_creation_handler() self.auth = hs.get_auth() self._store = hs.get_datastores().main @@ -1307,9 +1307,8 @@ async def _do( ) if with_relations: 
- run_as_background_process( + self.hs.run_as_background_process( "redact_related_events", - self.server_name, self._relation_handler.redact_events_related_to, requester=requester, event_id=event_id, diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index bb63b51599b..0f3cc84dcc9 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -126,6 +126,7 @@ def __init__(self, hs: "HomeServer"): self._json_filter_cache: LruCache[str, bool] = LruCache( max_size=1000, + clock=self.clock, cache_name="sync_valid_filter", server_name=self.server_name, ) diff --git a/synapse/rest/client/transactions.py b/synapse/rest/client/transactions.py index 1a57996aecf..571ba2fa623 100644 --- a/synapse/rest/client/transactions.py +++ b/synapse/rest/client/transactions.py @@ -56,7 +56,7 @@ def __init__(self, hs: "HomeServer"): ] = {} # Try to clean entries every 30 mins. This means entries will exist # for at *LEAST* 30 mins, and at *MOST* 60 mins. - self.cleaner = self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) + self.clock.looping_call(self._cleanup, CLEANUP_PERIOD_MS) def _get_transaction_key(self, request: IRequest, requester: Requester) -> Hashable: """A helper function which returns a transaction key that can be used diff --git a/synapse/server.py b/synapse/server.py index edcab19d727..cc0d3a427b4 100644 --- a/synapse/server.py +++ b/synapse/server.py @@ -28,10 +28,27 @@ import abc import functools import logging -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Type, TypeVar, cast +from threading import Thread +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + List, + Optional, + Tuple, + Type, + TypeVar, + cast, +) +from wsgiref.simple_server import WSGIServer +from attr import dataclass from typing_extensions import TypeAlias +from twisted.internet import defer +from twisted.internet.base import _SystemEventID from twisted.internet.interfaces import IOpenSSLContextFactory from twisted.internet.tcp import Port from twisted.python.threadpool import ThreadPool @@ -44,6 +61,7 @@ from synapse.api.auth_blocking import AuthBlocking from synapse.api.filtering import Filtering from synapse.api.ratelimiting import Ratelimiter, RequestRatelimiter +from synapse.app._base import unregister_sighups from synapse.appservice.api import ApplicationServiceApi from synapse.appservice.scheduler import ApplicationServiceScheduler from synapse.config.homeserver import HomeServerConfig @@ -133,6 +151,7 @@ all_later_gauges_to_clean_up_on_shutdown, register_threadpool, ) +from synapse.metrics.background_process_metrics import run_as_background_process from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager from synapse.module_api import ModuleApi from synapse.module_api.callbacks import ModuleApiCallbacks @@ -156,6 +175,7 @@ from synapse.streams.events import EventSources from synapse.synapse_rust.rendezvous import RendezvousHandler from synapse.types import DomainSpecificString, ISynapseReactor +from synapse.util.caches import CACHE_METRIC_REGISTRY from synapse.util.clock import Clock from synapse.util.distributor import Distributor from synapse.util.macaroons import MacaroonGenerator @@ -166,7 +186,9 @@ logger = logging.getLogger(__name__) if TYPE_CHECKING: + # Old Python versions don't have `LiteralString` from txredisapi import ConnectionHandler + from typing_extensions import LiteralString from synapse.handlers.jwt import JwtHandler from synapse.handlers.oidc import OidcHandler @@ -196,6 +218,7 @@ T: TypeAlias 
= object
F = TypeVar("F", bound=Callable[["HomeServer"], T])
+R = TypeVar("R")


def cache_in_self(builder: F) -> F:
@@ -219,7 +242,8 @@ def cache_in_self(builder: F) -> F:
     @functools.wraps(builder)
     def _get(self: "HomeServer") -> T:
         try:
-            return getattr(self, depname)
+            dep = getattr(self, depname)
+            return dep
         except AttributeError:
             pass

@@ -239,6 +263,22 @@ def _get(self: "HomeServer") -> T:
     return cast(F, _get)


+@dataclass
+class ShutdownInfo:
+    """Information about a callable to be invoked at shutdown.
+
+    Attributes:
+        func: the callable to invoke during shutdown.
+        trigger_id: the ID returned when registering this event trigger.
+        kwargs: the keyword arguments to call the function with.
+    """
+
+    func: Callable[..., Any]
+    trigger_id: _SystemEventID
+    kwargs: Dict[str, object]
+
+
 class HomeServer(metaclass=abc.ABCMeta):
     """A basic homeserver object without lazy component builders.
@@ -289,6 +329,7 @@ def __init__(
            hostname : The hostname for the server.
            config: The full config for the homeserver.
        """
+
        if not reactor:
            from twisted.internet import reactor as _reactor
@@ -300,6 +341,7 @@
        self.signing_key = config.key.signing_key[0]
        self.config = config
        self._listening_services: List[Port] = []
+        self._metrics_listeners: List[Tuple[WSGIServer, Thread]] = []
        self.start_time: Optional[int] = None

        self._instance_id = random_string(5)
@@ -315,6 +357,211 @@
        # This attribute is set by the free function `refresh_certificate`.
        self.tls_server_context_factory: Optional[IOpenSSLContextFactory] = None

+        self._is_shutdown = False
+        self._async_shutdown_handlers: List[ShutdownInfo] = []
+        self._sync_shutdown_handlers: List[ShutdownInfo] = []
+        self._background_processes: set[defer.Deferred[Optional[Any]]] = set()
+
+    def run_as_background_process(
+        self,
+        desc: "LiteralString",
+        func: Callable[..., Awaitable[Optional[R]]],
+        *args: Any,
+        **kwargs: Any,
+    ) -> "defer.Deferred[Optional[R]]":
+        """Run the given function in its own logcontext, with resource metrics.
+
+        This should be used to wrap processes which are fired off to run in the
+        background, instead of being associated with a particular request.
+
+        It returns a Deferred which completes when the function completes, but it
+        doesn't follow the synapse logcontext rules, which makes it appropriate for
+        passing to clock.looping_call and friends (or for firing-and-forgetting in
+        the middle of a normal synapse async function).
+
+        Because the returned Deferred does not follow the synapse logcontext rules,
+        awaiting the result of this function will result in the log context being
+        cleared (bad). In order to properly await the result of this function and
+        maintain the current log context, use `make_deferred_yieldable`.
+
+        Args:
+            desc: a description for this background process type
+            func: a function, which may return a Deferred or a coroutine
+            bg_start_span: Whether to start an opentracing span. Defaults to True.
+                Should only be disabled for processes that will not log to or tag
+                a span. Forwarded to the bare `run_as_background_process` via
+                `kwargs`.
+            args: positional args for func
+            kwargs: keyword args for func
+
+        Returns:
+            Deferred which returns the result of func, or `None` if func raises.
+            Note that the returned Deferred does not follow the synapse logcontext
+            rules.
+        """
+        if self._is_shutdown:
+            raise Exception(
+                f"Cannot start background process `{desc}`: the HomeServer has been shut down"
+            )
+
+        # Ignore the linter error as this is the one location where the bare
+        # function should be called.
+        deferred = run_as_background_process(desc, self.hostname, func, *args, **kwargs)  # type: ignore[untracked-background-process]
+        self._background_processes.add(deferred)
+
+        def on_done(res: R) -> R:
+            try:
+                self._background_processes.remove(deferred)
+            except KeyError:
+                # If the background process isn't being tracked anymore we can just move on.
+                pass
+            return res
+
+        deferred.addBoth(on_done)
+        return deferred
+
+    async def shutdown(self) -> None:
+        """
+        Cleanly stops all aspects of the HomeServer and removes any references that
+        have been handed out, in order to allow the HomeServer object to be garbage
+        collected.
+
+        You must ensure the HomeServer object is not frozen in the garbage collector
+        for it to be cleaned up. By default, Synapse freezes the HomeServer object
+        in the garbage collector.
+        """
+
+        self._is_shutdown = True
+
+        logger.info(
+            "Received shutdown request for %s (%s).",
+            self.hostname,
+            self.get_instance_id(),
+        )
+
+        # Unregister sighups first. If a shutdown was requested we shouldn't be
+        # responding to things like config changes, so stop listening to these
+        # before anything else.
+        unregister_sighups(self._instance_id)
+
+        # TODO: It would be desirable to be able to report an error if the HomeServer
+        # object is frozen in the garbage collector, as that would prevent it from
+        # being collected after being shut down.
+        # In theory the following should work, but it doesn't seem to make a
+        # difference when tested locally.
+        #
+        # if gc.is_tracked(self):
+        #     logger.error("HomeServer object is tracked by garbage collection so cannot be fully cleaned up")
+
+        for listener in self._listening_services:
+            # During unit tests, an incomplete `twisted.pair.testing._FakePort` is used
+            # for listeners, so check the listener type here to ensure the shutdown
+            # procedure is only applied to actual `Port` instances.
+            if type(listener) is Port:
+                port_shutdown = listener.stopListening()
+                if port_shutdown is not None:
+                    await port_shutdown
+        self._listening_services.clear()
+
+        for server, thread in self._metrics_listeners:
+            server.shutdown()
+            thread.join()
+        self._metrics_listeners.clear()
+
+        # TODO: Cleanup replication pieces
+
+        self.get_keyring().shutdown()
+
+        # Cleanup metrics associated with the homeserver
+        for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values():
+            later_gauge.unregister_hooks_for_homeserver_instance_id(
+                self.get_instance_id()
+            )
+
+        CACHE_METRIC_REGISTRY.unregister_hooks_for_homeserver(
+            self.config.server.server_name
+        )
+
+        for db in self.get_datastores().databases:
+            db.stop_background_updates()
+
+        if self.should_send_federation():
+            try:
+                self.get_federation_sender().shutdown()
+            except Exception:
+                # Stopping the federation sender is best effort; a failure here
+                # shouldn't abort the rest of the shutdown.
+                pass
+
+        for shutdown_handler in self._async_shutdown_handlers:
+            try:
+                self.get_reactor().removeSystemEventTrigger(shutdown_handler.trigger_id)
+                defer.ensureDeferred(shutdown_handler.func(**shutdown_handler.kwargs))
+            except Exception as e:
+                logger.error("Error calling async shutdown handler: %s", e)
+        self._async_shutdown_handlers.clear()
+
+        for shutdown_handler in self._sync_shutdown_handlers:
+            try:
+                self.get_reactor().removeSystemEventTrigger(shutdown_handler.trigger_id)
+                shutdown_handler.func(**shutdown_handler.kwargs)
+            except Exception as e:
+                logger.error("Error calling sync shutdown handler: %s", e)
+        self._sync_shutdown_handlers.clear()
+
+        self.get_clock().shutdown()
+
+        for background_process in list(self._background_processes):
+            try:
+                background_process.cancel()
+            except Exception:
+                pass
+        self._background_processes.clear()
+
+        for db in self.get_datastores().databases:
+            db._db_pool.close()
+
+    def register_async_shutdown_handler(
+        self,
+        *,
+        phase: str,
+        eventType: str,
+        shutdown_func: Callable[..., Any],
+        **kwargs: object,
+    ) -> None:
+        """
+        Register an async system event trigger with the HomeServer so that it can
+        be cleanly removed (and run via `defer.ensureDeferred`) when the HomeServer
+        is shut down.
+        """
+        trigger_id = self.get_clock().add_system_event_trigger(
+            phase,
+            eventType,
+            shutdown_func,
+            **kwargs,
+        )
+        self._async_shutdown_handlers.append(
+            ShutdownInfo(func=shutdown_func, trigger_id=trigger_id, kwargs=kwargs)
+        )
+
+    def register_sync_shutdown_handler(
+        self,
+        *,
+        phase: str,
+        eventType: str,
+        shutdown_func: Callable[..., Any],
+        **kwargs: object,
+    ) -> None:
+        """
+        Register a synchronous system event trigger with the HomeServer so that it
+        can be cleanly removed when the HomeServer is shut down.
+        """
+        trigger_id = self.get_clock().add_system_event_trigger(
+            phase,
+            eventType,
+            shutdown_func,
+            **kwargs,
+        )
+        self._sync_shutdown_handlers.append(
+            ShutdownInfo(func=shutdown_func, trigger_id=trigger_id, kwargs=kwargs)
+        )
+
     def register_module_web_resource(self, path: str, resource: Resource) -> None:
         """Allows a module to register a web resource to be served at the given path.
@@ -366,36 +613,25 @@ def setup(self) -> None:
         self.datastores = Databases(self.DATASTORE_CLASS, self)
         logger.info("Finished setting up.")

-    def __del__(self) -> None:
-        """
-        Called when an the homeserver is garbage collected.
-
-        Make sure we actually do some clean-up, rather than leak data.
-        """
-        self.cleanup()
-
-    def cleanup(self) -> None:
-        """
-        WIP: Clean-up any references to the homeserver and stop any running related
-        processes, timers, loops, replication stream, etc.
-
-        This should be called wherever you care about the HomeServer being completely
-        garbage collected like in tests.
-        It's not necessary to call if you plan to just
-        shut down the whole Python process anyway.
-
-        Can be called multiple times.
-        """
-        logger.info("Received cleanup request for %s.", self.hostname)
-
-        # TODO: Stop background processes, timers, loops, replication stream, etc.
-
-        # Cleanup metrics associated with the homeserver
-        for later_gauge in all_later_gauges_to_clean_up_on_shutdown.values():
-            later_gauge.unregister_hooks_for_homeserver_instance_id(
-                self.get_instance_id()
-            )
-
-        logger.info("Cleanup complete for %s.", self.hostname)
+        # Register background tasks required by this server. This must be done
+        # somewhat manually due to the background tasks not being registered
+        # unless handlers are instantiated.
+        if self.config.worker.run_background_tasks:
+            self.start_background_tasks()
+
+    # def __del__(self) -> None:
+    #     """
+    #     Called when the homeserver is garbage collected.
+    #
+    #     Make sure we actually do some clean-up, rather than leak data.
+    #     """
+    #
+    #     # NOTE: This is a chicken-and-egg problem.
+    #     # __del__ will never be called, since the HomeServer cannot be garbage
+    #     # collected until the shutdown function has been called. So it makes no
+    #     # sense to call shutdown inside of __del__, even though that is a logical
+    #     # place to assume it should be called.
+    #     self.shutdown()

     def start_listening(self) -> None:  # noqa: B027 (no-op by design)
         """Start the HTTP, manhole, metrics, etc listeners
@@ -442,7 +678,8 @@ def is_mine_server_name(self, server_name: str) -> bool:
     @cache_in_self
     def get_clock(self) -> Clock:
-        return Clock(self._reactor, server_name=self.hostname)
+        # Ignore the linter error since this is the one place the `Clock` should be created.
+        return Clock(self._reactor, server_name=self.hostname)  # type: ignore[multiple-internal-clocks]

     def get_datastores(self) -> Databases:
         if not self.datastores:
@@ -452,7 +689,7 @@ def get_datastores(self) -> Databases:
     @cache_in_self
     def get_distributor(self) -> Distributor:
-        return Distributor(server_name=self.hostname)
+        return Distributor(hs=self)

     @cache_in_self
     def get_registration_ratelimiter(self) -> Ratelimiter:
@@ -1007,8 +1244,10 @@ def get_media_sender_thread_pool(self) -> ThreadPool:
         )
         media_threadpool.start()

-        self.get_clock().add_system_event_trigger(
-            "during", "shutdown", media_threadpool.stop
+        self.register_sync_shutdown_handler(
+            phase="during",
+            eventType="shutdown",
+            shutdown_func=media_threadpool.stop,
         )

         # Register the threadpool with our metrics.
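Taken together, the `synapse/server.py` hunks above define the lifecycle contract that the rest of this diff migrates callsites onto: periodic and delayed work goes through the shared `Clock`, background work goes through `HomeServer.run_as_background_process`, and shutdown work is registered via the new handler methods. The following is a minimal sketch of that call pattern from a component's point of view; the `ExampleCleaner` class and its method names are hypothetical, invented here for illustration:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from synapse.server import HomeServer


class ExampleCleaner:
    """Hypothetical component showing the tracked-call pattern this PR enforces."""

    def __init__(self, hs: "HomeServer") -> None:
        self._hs = hs

        # Looping calls made via the shared Clock are tracked, so
        # `HomeServer.shutdown()` can cancel them.
        hs.get_clock().looping_call(self._kick_off_clean, 30 * 1000)

        # Registered triggers are recorded as `ShutdownInfo` entries, letting
        # `shutdown()` remove them from the reactor and invoke them directly.
        hs.register_async_shutdown_handler(
            phase="before",
            eventType="shutdown",
            shutdown_func=self._flush,
        )

    def _kick_off_clean(self) -> None:
        # The wrapper tracks the returned Deferred for cancellation on
        # shutdown, and raises once `shutdown()` has already run.
        self._hs.run_as_background_process("example_clean", self._clean)

    async def _clean(self) -> None:
        ...  # periodic work

    async def _flush(self) -> None:
        ...  # flush state before the reactor stops
```

On `HomeServer.shutdown()`, the looping call is cancelled by the clock, any still-running tracked deferred is cancelled, and the registered trigger is removed from the reactor and invoked via `defer.ensureDeferred`.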
diff --git a/synapse/server_notices/server_notices_manager.py b/synapse/server_notices/server_notices_manager.py index 19f86b5a563..73cf4091eb4 100644 --- a/synapse/server_notices/server_notices_manager.py +++ b/synapse/server_notices/server_notices_manager.py @@ -36,6 +36,7 @@ class ServerNoticesManager: def __init__(self, hs: "HomeServer"): self.server_name = hs.hostname # nb must be called this for @cached + self.clock = hs.get_clock() # nb must be called this for @cached self._store = hs.get_datastores().main self._config = hs.config self._account_data_handler = hs.get_account_data_handler() diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index dd8d7135ba3..394dc72fa69 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -651,6 +651,7 @@ def __init__(self, hs: "HomeServer"): ExpiringCache( cache_name="state_cache", server_name=self.server_name, + hs=hs, clock=self.clock, max_len=100000, expiry_ms=EVICTION_TIMEOUT_SECONDS * 1000, diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index f214f558978..1fddcc0799a 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -56,7 +56,7 @@ def __init__( ): self.hs = hs self.server_name = hs.hostname # nb must be called this for @cached - self._clock = hs.get_clock() + self.clock = hs.get_clock() # nb must be called this for @cached self.database_engine = database.engine self.db_pool = database diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py index 9aa9e51aeb6..e3e793d5f59 100644 --- a/synapse/storage/background_updates.py +++ b/synapse/storage/background_updates.py @@ -41,7 +41,6 @@ import attr from synapse._pydantic_compat import BaseModel -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.engines import PostgresEngine from synapse.storage.types import Connection, Cursor from synapse.types import JsonDict, StrCollection @@ -285,6 +284,13 @@ def __init__(self, hs: "HomeServer", database: "DatabasePool"): self.sleep_duration_ms = hs.config.background_updates.sleep_duration_ms self.sleep_enabled = hs.config.background_updates.sleep_enabled + def shutdown(self) -> None: + """ + Stop any further background updates from happening. + """ + self.enabled = False + self._background_update_handlers.clear() + def get_status(self) -> UpdaterStatus: """An integer summarising the updater status. Used as a metric.""" if self._aborted: @@ -396,9 +402,8 @@ def start_doing_background_updates(self) -> None: # if we start a new background update, not all updates are done. 
self._all_done = False sleep = self.sleep_enabled - run_as_background_process( + self.hs.run_as_background_process( "background_updates", - self.server_name, self.run_background_updates, sleep, ) diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py index 120934af578..646e2cf1151 100644 --- a/synapse/storage/controllers/persist_events.py +++ b/synapse/storage/controllers/persist_events.py @@ -62,7 +62,6 @@ trace, ) from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.controllers.state import StateStorageController from synapse.storage.databases import Databases from synapse.storage.databases.main.events import DeltaState @@ -195,6 +194,7 @@ class _EventPeristenceQueue(Generic[_PersistResult]): def __init__( self, + hs: "HomeServer", server_name: str, per_item_callback: Callable[ [str, _EventPersistQueueTask], @@ -207,6 +207,7 @@ def __init__( and its result will be returned via the Deferreds returned from add_to_queue. """ self.server_name = server_name + self.hs = hs self._event_persist_queues: Dict[str, Deque[_EventPersistQueueItem]] = {} self._currently_persisting_rooms: Set[str] = set() self._per_item_callback = per_item_callback @@ -311,7 +312,7 @@ async def handle_queue_loop() -> None: self._currently_persisting_rooms.discard(room_id) # set handle_queue_loop off in the background - run_as_background_process("persist_events", self.server_name, handle_queue_loop) + self.hs.run_as_background_process("persist_events", handle_queue_loop) def _get_drainining_queue( self, room_id: str @@ -354,7 +355,7 @@ def __init__( self._instance_name = hs.get_instance_name() self.is_mine_id = hs.is_mine_id self._event_persist_queue = _EventPeristenceQueue( - self.server_name, self._process_event_persist_queue_task + hs, self.server_name, self._process_event_persist_queue_task ) self._state_resolution_handler = hs.get_state_resolution_handler() self._state_controller = state_controller diff --git a/synapse/storage/controllers/purge_events.py b/synapse/storage/controllers/purge_events.py index 14b37ac543b..ded9cb0567e 100644 --- a/synapse/storage/controllers/purge_events.py +++ b/synapse/storage/controllers/purge_events.py @@ -46,9 +46,8 @@ class PurgeEventsStorageController: """High level interface for purging rooms and event history.""" def __init__(self, hs: "HomeServer", stores: Databases): - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self.stores = stores if hs.config.worker.run_background_tasks: diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py index 66f3289d867..76978402b94 100644 --- a/synapse/storage/controllers/state.py +++ b/synapse/storage/controllers/state.py @@ -69,8 +69,8 @@ class StateStorageController: def __init__(self, hs: "HomeServer", stores: "Databases"): self.server_name = hs.hostname # nb must be called this for @cached + self.clock = hs.get_clock() self._is_mine_id = hs.is_mine_id - self._clock = hs.get_clock() self.stores = stores self._partial_state_events_tracker = PartialStateEventsTracker(stores.main) self._partial_state_room_tracker = PartialCurrentStateTracker(stores.main) @@ -78,7 +78,7 @@ def __init__(self, hs: "HomeServer", stores: "Databases"): # Used by `_get_joined_hosts` to ensure only one thing mutates the cache # at a time. 
Keyed by room_id. self._joined_host_linearizer = Linearizer( - name="_JoinedHostsCache", clock=self._clock + name="_JoinedHostsCache", clock=self.clock ) def notify_event_un_partial_stated(self, event_id: str) -> None: @@ -817,9 +817,7 @@ async def get_joined_hosts( state_group = object() assert state_group is not None - with Measure( - self._clock, name="get_joined_hosts", server_name=self.server_name - ): + with Measure(self.clock, name="get_joined_hosts", server_name=self.server_name): return await self._get_joined_hosts( room_id, state_group, state_entry=state_entry ) diff --git a/synapse/storage/database.py b/synapse/storage/database.py index 249a0a933c9..a4b2b26795c 100644 --- a/synapse/storage/database.py +++ b/synapse/storage/database.py @@ -62,7 +62,6 @@ make_deferred_yieldable, ) from synapse.metrics import SERVER_NAME_LABEL, register_threadpool -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage.background_updates import BackgroundUpdater from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.types import Connection, Cursor, SQLQueryParameters @@ -638,12 +637,17 @@ def __init__( # background updates of tables that aren't safe to update. self._clock.call_later( 0.0, - run_as_background_process, + self.hs.run_as_background_process, "upsert_safety_check", - self.server_name, self._check_safe_to_upsert, ) + def stop_background_updates(self) -> None: + """ + Stops the database from running any further background updates. + """ + self.updates.shutdown() + def name(self) -> str: "Return the name of this database" return self._database_config.name @@ -681,9 +685,8 @@ async def _check_safe_to_upsert(self) -> None: if background_update_names: self._clock.call_later( 15.0, - run_as_background_process, + self.hs.run_as_background_process, "upsert_safety_check", - self.server_name, self._check_safe_to_upsert, ) diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py index cad26fefa4b..674c6b921ee 100644 --- a/synapse/storage/databases/main/cache.py +++ b/synapse/storage/databases/main/cache.py @@ -751,7 +751,7 @@ def _send_invalidation_to_replication( "instance_name": self._instance_name, "cache_func": cache_name, "keys": keys, - "invalidation_ts": self._clock.time_msec(), + "invalidation_ts": self.clock.time_msec(), }, ) @@ -778,7 +778,7 @@ def _send_invalidation_to_replication_bulk( assert self._cache_id_gen is not None stream_ids = self._cache_id_gen.get_next_mult_txn(txn, len(key_tuples)) - ts = self._clock.time_msec() + ts = self.clock.time_msec() txn.call_after(self.hs.get_notifier().on_new_replication_data) self.db_pool.simple_insert_many_txn( txn, @@ -830,7 +830,8 @@ async def _clean_up_cache_invalidation_wrapper(self) -> None: next_interval = REGULAR_CLEANUP_INTERVAL_MS self.hs.get_clock().call_later( - next_interval / 1000, self._clean_up_cache_invalidation_wrapper + next_interval / 1000, + self._clean_up_cache_invalidation_wrapper, ) async def _clean_up_batch_of_old_cache_invalidations( diff --git a/synapse/storage/databases/main/censor_events.py b/synapse/storage/databases/main/censor_events.py index 3f9f482adda..45cfe97dba2 100644 --- a/synapse/storage/databases/main/censor_events.py +++ b/synapse/storage/databases/main/censor_events.py @@ -77,7 +77,7 @@ async def _censor_redactions(self) -> None: return before_ts = ( - self._clock.time_msec() - self.hs.config.server.redaction_retention_period + self.clock.time_msec() - 
self.hs.config.server.redaction_retention_period ) # We fetch all redactions that: diff --git a/synapse/storage/databases/main/client_ips.py b/synapse/storage/databases/main/client_ips.py index c7a330cc83d..dc6ab99a6c7 100644 --- a/synapse/storage/databases/main/client_ips.py +++ b/synapse/storage/databases/main/client_ips.py @@ -438,10 +438,11 @@ def __init__( cache_name="client_ip_last_seen", server_name=self.server_name, max_size=50000, + clock=hs.get_clock(), ) if hs.config.worker.run_background_tasks and self.user_ips_max_age: - self._clock.looping_call(self._prune_old_user_ips, 5 * 1000) + self.clock.looping_call(self._prune_old_user_ips, 5 * 1000) if self._update_on_this_worker: # This is the designated worker that can write to the client IP @@ -452,11 +453,11 @@ def __init__( Tuple[str, str, str], Tuple[str, Optional[str], int] ] = {} - self._client_ip_looper = self._clock.looping_call( - self._update_client_ips_batch, 5 * 1000 - ) - self.hs.get_clock().add_system_event_trigger( - "before", "shutdown", self._update_client_ips_batch + self.clock.looping_call(self._update_client_ips_batch, 5 * 1000) + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._update_client_ips_batch, ) @wrap_as_background_process("prune_old_user_ips") @@ -492,7 +493,7 @@ async def _prune_old_user_ips(self) -> None: ) """ - timestamp = self._clock.time_msec() - self.user_ips_max_age + timestamp = self.clock.time_msec() - self.user_ips_max_age def _prune_old_user_ips_txn(txn: LoggingTransaction) -> None: txn.execute(sql, (timestamp,)) @@ -628,7 +629,7 @@ async def insert_client_ip( return if not now: - now = int(self._clock.time_msec()) + now = int(self.clock.time_msec()) key = (user_id, access_token, ip) try: diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py index f6f3c94a0d0..a66e11f738c 100644 --- a/synapse/storage/databases/main/deviceinbox.py +++ b/synapse/storage/databases/main/deviceinbox.py @@ -96,7 +96,8 @@ def __init__( ] = ExpiringCache( cache_name="last_device_delete_cache", server_name=self.server_name, - clock=self._clock, + hs=hs, + clock=self.clock, max_len=10000, expiry_ms=30 * 60 * 1000, ) @@ -154,7 +155,7 @@ def __init__( ) if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( run_as_background_process, DEVICE_FEDERATION_INBOX_CLEANUP_INTERVAL_MS, "_delete_old_federation_inbox_rows", @@ -826,7 +827,7 @@ def add_messages_txn( ) async with self._to_device_msg_id_gen.get_next() as stream_id: - now_ms = self._clock.time_msec() + now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_to_device_inbox", add_messages_txn, now_ms, stream_id ) @@ -881,7 +882,7 @@ def add_messages_txn( ) async with self._to_device_msg_id_gen.get_next() as stream_id: - now_ms = self._clock.time_msec() + now_ms = self.clock.time_msec() await self.db_pool.runInteraction( "add_messages_from_remote_to_device_inbox", add_messages_txn, @@ -1002,7 +1003,7 @@ def _delete_old_federation_inbox_rows_txn(txn: LoggingTransaction) -> bool: # We delete at most 100 rows that are older than # DEVICE_FEDERATION_INBOX_CLEANUP_DELAY_MS delete_before_ts = ( - self._clock.time_msec() - DEVICE_FEDERATION_INBOX_CLEANUP_DELAY_MS + self.clock.time_msec() - DEVICE_FEDERATION_INBOX_CLEANUP_DELAY_MS ) sql = """ WITH to_delete AS ( @@ -1032,7 +1033,7 @@ def _delete_old_federation_inbox_rows_txn(txn: LoggingTransaction) -> bool: # We sleep a bit so that we don't hammer the 
database in a tight # loop first time we run this. - await self._clock.sleep(1) + await self.clock.sleep(1) async def get_devices_with_messages( self, user_id: str, device_ids: StrCollection diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index fc1e1c73f18..d4b9ce0ea0a 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -195,7 +195,7 @@ def __init__( ) if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( self._prune_old_outbound_device_pokes, 60 * 60 * 1000 ) @@ -1390,7 +1390,7 @@ def _mark_remote_users_device_caches_as_stale_txn( table="device_lists_remote_resync", keyvalues={"user_id": user_id}, values={}, - insertion_values={"added_ts": self._clock.time_msec()}, + insertion_values={"added_ts": self.clock.time_msec()}, ) await self.db_pool.runInteraction( @@ -1601,7 +1601,7 @@ async def _prune_old_outbound_device_pokes( that user when the destination comes back. It doesn't matter which device we keep. """ - yesterday = self._clock.time_msec() - prune_age + yesterday = self.clock.time_msec() - prune_age def _prune_txn(txn: LoggingTransaction) -> None: # look for (user, destination) pairs which have an update older than @@ -2086,7 +2086,7 @@ def _add_device_outbound_poke_to_stream_txn( stream_id, ) - now = self._clock.time_msec() + now = self.clock.time_msec() encoded_context = json_encoder.encode(context) mark_sent = not self.hs.is_mine_id(user_id) diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 2e9f62075a8..2d3d0c0036e 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -1564,7 +1564,7 @@ def impl(txn: LoggingTransaction) -> Tuple[List[str], int]: DELETE FROM e2e_one_time_keys_json WHERE {clause} AND ts_added_ms < ? AND length(key_id) = 6 """ - args.append(self._clock.time_msec() - (7 * 24 * 3600 * 1000)) + args.append(self.clock.time_msec() - (7 * 24 * 3600 * 1000)) txn.execute(sql, args) return users, txn.rowcount @@ -1585,7 +1585,7 @@ async def allow_master_cross_signing_key_replacement_without_uia( None, if there is no such key. Otherwise, the timestamp before which replacement is allowed without UIA. """ - timestamp = self._clock.time_msec() + duration_ms + timestamp = self.clock.time_msec() + duration_ms def impl(txn: LoggingTransaction) -> Optional[int]: txn.execute( diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 5c9bd2e848d..d77420ff475 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -167,6 +167,7 @@ def __init__( # Cache of event ID to list of auth event IDs and their depths. self._event_auth_cache: LruCache[str, List[Tuple[str, int]]] = LruCache( max_size=500000, + clock=self.hs.get_clock(), server_name=self.server_name, cache_name="_event_auth_cache", size_callback=len, @@ -176,7 +177,7 @@ def __init__( # index. 
self.tests_allow_no_chain_cover_index = True - self._clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) + self.clock.looping_call(self._get_stats_for_federation_staging, 30 * 1000) if isinstance(self.database_engine, PostgresEngine): self.db_pool.updates.register_background_validate_constraint_and_delete_rows( @@ -1328,7 +1329,7 @@ def get_backfill_points_in_room_txn( ( room_id, current_depth, - self._clock.time_msec(), + self.clock.time_msec(), BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS, limit, @@ -1841,7 +1842,7 @@ def _record_event_failed_pull_attempt_upsert_txn( last_cause=EXCLUDED.last_cause; """ - txn.execute(sql, (room_id, event_id, 1, self._clock.time_msec(), cause)) + txn.execute(sql, (room_id, event_id, 1, self.clock.time_msec(), cause)) @trace async def get_event_ids_with_failed_pull_attempts( @@ -1905,7 +1906,7 @@ async def get_event_ids_to_not_pull_from_backoff( ), ) - current_time = self._clock.time_msec() + current_time = self.clock.time_msec() event_ids_with_backoff = {} for event_id, last_attempt_ts, num_attempts in event_failed_pull_attempts: @@ -2025,7 +2026,7 @@ async def insert_received_event_to_staging( values={}, insertion_values={ "room_id": event.room_id, - "received_ts": self._clock.time_msec(), + "received_ts": self.clock.time_msec(), "event_json": json_encoder.encode(event.get_dict()), "internal_metadata": json_encoder.encode( event.internal_metadata.get_dict() @@ -2299,7 +2300,7 @@ def _get_stats_for_federation_staging_txn( # If there is nothing in the staging area default it to 0. age = 0 if received_ts is not None: - age = self._clock.time_msec() - received_ts + age = self.clock.time_msec() - received_ts return count, age diff --git a/synapse/storage/databases/main/event_push_actions.py b/synapse/storage/databases/main/event_push_actions.py index 4db0230421f..ec26aedc6bc 100644 --- a/synapse/storage/databases/main/event_push_actions.py +++ b/synapse/storage/databases/main/event_push_actions.py @@ -95,6 +95,8 @@ import attr +from twisted.internet.task import LoopingCall + from synapse.api.constants import MAIN_TIMELINE, ReceiptTypes from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause @@ -254,6 +256,8 @@ def _deserialize_action(actions: str, is_highlight: bool) -> List[Union[dict, st class EventPushActionsWorkerStore(ReceiptsWorkerStore, StreamWorkerStore, SQLBaseStore): + _background_tasks: List[LoopingCall] = [] + def __init__( self, database: DatabasePool, @@ -263,7 +267,7 @@ def __init__( super().__init__(database, db_conn, hs) # Track when the process started. 
- self._started_ts = self._clock.time_msec() + self._started_ts = self.clock.time_msec() # These get correctly set by _find_stream_orderings_for_times_txn self.stream_ordering_month_ago: Optional[int] = None @@ -273,18 +277,14 @@ def __init__( self._find_stream_orderings_for_times_txn(cur) cur.close() - self.find_stream_orderings_looping_call = self._clock.looping_call( - self._find_stream_orderings_for_times, 10 * 60 * 1000 - ) + self.clock.looping_call(self._find_stream_orderings_for_times, 10 * 60 * 1000) self._rotate_count = 10000 self._doing_notif_rotation = False if hs.config.worker.run_background_tasks: - self._rotate_notif_loop = self._clock.looping_call( - self._rotate_notifs, 30 * 1000 - ) + self.clock.looping_call(self._rotate_notifs, 30 * 1000) - self._clear_old_staging_loop = self._clock.looping_call( + self.clock.looping_call( self._clear_old_push_actions_staging, 30 * 60 * 1000 ) @@ -1190,7 +1190,7 @@ def _gen_entry( is_highlight, # highlight column int(count_as_unread), # unread column thread_id, # thread_id column - self._clock.time_msec(), # inserted_ts column + self.clock.time_msec(), # inserted_ts column ) await self.db_pool.simple_insert_many( @@ -1241,14 +1241,14 @@ async def _find_stream_orderings_for_times(self) -> None: def _find_stream_orderings_for_times_txn(self, txn: LoggingTransaction) -> None: logger.info("Searching for stream ordering 1 month ago") self.stream_ordering_month_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 30 * 24 * 60 * 60 * 1000 + txn, self.clock.time_msec() - 30 * 24 * 60 * 60 * 1000 ) logger.info( "Found stream ordering 1 month ago: it's %d", self.stream_ordering_month_ago ) logger.info("Searching for stream ordering 1 day ago") self.stream_ordering_day_ago = self._find_first_stream_ordering_after_ts_txn( - txn, self._clock.time_msec() - 24 * 60 * 60 * 1000 + txn, self.clock.time_msec() - 24 * 60 * 60 * 1000 ) logger.info( "Found stream ordering 1 day ago: it's %d", self.stream_ordering_day_ago @@ -1787,7 +1787,7 @@ async def _clear_old_push_actions_staging(self) -> None: # We delete anything more than an hour old, on the assumption that we'll # never take more than an hour to persist an event. - delete_before_ts = self._clock.time_msec() - 60 * 60 * 1000 + delete_before_ts = self.clock.time_msec() - 60 * 60 * 1000 if self._started_ts > delete_before_ts: # We need to wait for at least an hour before we started deleting, @@ -1824,7 +1824,7 @@ def _clear_old_push_actions_staging_txn(txn: LoggingTransaction) -> bool: return # We sleep to ensure that we don't overwhelm the DB. - await self._clock.sleep(1.0) + await self.clock.sleep(1.0) async def get_push_actions_for_user( self, diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py index 0a0102ee64b..37dd8e48d5d 100644 --- a/synapse/storage/databases/main/events_bg_updates.py +++ b/synapse/storage/databases/main/events_bg_updates.py @@ -730,7 +730,7 @@ def _redactions_received_ts_txn(txn: LoggingTransaction) -> int: WHERE ? <= event_id AND event_id <= ? 
""" - txn.execute(sql, (self._clock.time_msec(), last_event_id, upper_event_id)) + txn.execute(sql, (self.clock.time_msec(), last_event_id, upper_event_id)) self.db_pool.updates._background_update_progress_txn( txn, "redactions_received_ts", {"last_event_id": upper_event_id} diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py index 31e23122115..4f9a1a4f780 100644 --- a/synapse/storage/databases/main/events_worker.py +++ b/synapse/storage/databases/main/events_worker.py @@ -70,7 +70,6 @@ ) from synapse.metrics import SERVER_NAME_LABEL from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.replication.tcp.streams import BackfillStream, UnPartialStatedEventStream @@ -282,13 +281,14 @@ def __init__( if hs.config.worker.run_background_tasks: # We periodically clean out old transaction ID mappings - self._clock.looping_call( + self.clock.looping_call( self._cleanup_old_transaction_ids, 5 * 60 * 1000, ) self._get_event_cache: AsyncLruCache[Tuple[str], EventCacheEntry] = ( AsyncLruCache( + clock=hs.get_clock(), server_name=self.server_name, cache_name="*getEvent*", max_size=hs.config.caches.event_cache_size, @@ -1154,9 +1154,7 @@ def _maybe_start_fetch_thread(self) -> None: should_start = False if should_start: - run_as_background_process( - "fetch_events", self.server_name, self._fetch_thread - ) + self.hs.run_as_background_process("fetch_events", self._fetch_thread) async def _fetch_thread(self) -> None: """Services requests for events from `_event_fetch_list`.""" @@ -1276,7 +1274,7 @@ def _fetch_event_list( were not part of this request. """ with Measure( - self._clock, name="_fetch_event_list", server_name=self.server_name + self.clock, name="_fetch_event_list", server_name=self.server_name ): try: events_to_fetch = { @@ -2278,7 +2276,7 @@ async def _cleanup_old_transaction_ids(self) -> None: """Cleans out transaction id mappings older than 24hrs.""" def _cleanup_old_transaction_ids_txn(txn: LoggingTransaction) -> None: - one_day_ago = self._clock.time_msec() - 24 * 60 * 60 * 1000 + one_day_ago = self.clock.time_msec() - 24 * 60 * 60 * 1000 sql = """ DELETE FROM event_txn_id_device_id WHERE inserted_ts < ? @@ -2633,7 +2631,7 @@ def mark_event_rejected_txn( keyvalues={"event_id": event_id}, values={ "reason": rejection_reason, - "last_check": self._clock.time_msec(), + "last_check": self.clock.time_msec(), }, ) self.db_pool.simple_update_txn( diff --git a/synapse/storage/databases/main/lock.py b/synapse/storage/databases/main/lock.py index d0e4a91b595..e2b15eaf6a5 100644 --- a/synapse/storage/databases/main/lock.py +++ b/synapse/storage/databases/main/lock.py @@ -28,7 +28,6 @@ from twisted.internet.task import LoopingCall from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.storage._base import SQLBaseStore @@ -99,15 +98,15 @@ def __init__( # lead to a race, as we may drop the lock while we are still processing. # However, a) it should be a small window, b) the lock is best effort # anyway and c) we want to really avoid leaking locks when we restart. 
- hs.get_clock().add_system_event_trigger( - "before", - "shutdown", - self._on_shutdown, + hs.register_async_shutdown_handler( + phase="before", + eventType="shutdown", + shutdown_func=self._on_shutdown, ) self._acquiring_locks: Set[Tuple[str, str]] = set() - self._clock.looping_call( + self.clock.looping_call( self._reap_stale_read_write_locks, _LOCK_TIMEOUT_MS / 10.0 ) @@ -153,7 +152,7 @@ async def _try_acquire_lock( if lock and await lock.is_still_valid(): return None - now = self._clock.time_msec() + now = self.clock.time_msec() token = random_string(6) def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: @@ -202,7 +201,8 @@ def _try_acquire_lock_txn(txn: LoggingTransaction) -> bool: lock = Lock( self.server_name, self._reactor, - self._clock, + self.hs, + self.clock, self, read_write=False, lock_name=lock_name, @@ -251,7 +251,7 @@ def _try_acquire_read_write_lock_txn( # constraints. If it doesn't then we have acquired the lock, # otherwise we haven't. - now = self._clock.time_msec() + now = self.clock.time_msec() token = random_string(6) self.db_pool.simple_insert_txn( @@ -270,7 +270,8 @@ def _try_acquire_read_write_lock_txn( lock = Lock( self.server_name, self._reactor, - self._clock, + self.hs, + self.clock, self, read_write=True, lock_name=lock_name, @@ -338,7 +339,7 @@ async def _reap_stale_read_write_locks(self) -> None: """ def reap_stale_read_write_locks_txn(txn: LoggingTransaction) -> None: - txn.execute(delete_sql, (self._clock.time_msec() - _LOCK_TIMEOUT_MS,)) + txn.execute(delete_sql, (self.clock.time_msec() - _LOCK_TIMEOUT_MS,)) if txn.rowcount: logger.info("Reaped %d stale locks", txn.rowcount) @@ -374,6 +375,7 @@ def __init__( self, server_name: str, reactor: ISynapseReactor, + hs: "HomeServer", clock: Clock, store: LockStore, read_write: bool, @@ -387,6 +389,7 @@ def __init__( """ self._server_name = server_name self._reactor = reactor + self._hs = hs self._clock = clock self._store = store self._read_write = read_write @@ -410,6 +413,7 @@ def _setup_looping_call(self) -> None: _RENEWAL_INTERVAL_MS, self._server_name, self._store, + self._hs, self._clock, self._read_write, self._lock_name, @@ -421,6 +425,7 @@ def _setup_looping_call(self) -> None: def _renew( server_name: str, store: LockStore, + hs: "HomeServer", clock: Clock, read_write: bool, lock_name: str, @@ -457,9 +462,8 @@ async def _internal_renew( desc="renew_lock", ) - return run_as_background_process( + return hs.run_as_background_process( "Lock._renew", - server_name, _internal_renew, store, clock, diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py index f726846e57f..b8bd0042d78 100644 --- a/synapse/storage/databases/main/media_repository.py +++ b/synapse/storage/databases/main/media_repository.py @@ -565,7 +565,7 @@ def get_pending_media_txn(txn: LoggingTransaction) -> Tuple[int, int]: sql, ( user_id.to_string(), - self._clock.time_msec() - self.unused_expiration_time, + self.clock.time_msec() - self.unused_expiration_time, ), ) row = txn.fetchone() @@ -1059,7 +1059,7 @@ def _get_media_uploaded_size_for_user_txn( txn: LoggingTransaction, ) -> int: # Calculate the timestamp for the start of the time period - start_ts = self._clock.time_msec() - time_period_ms + start_ts = self.clock.time_msec() - time_period_ms txn.execute(sql, (user_id, start_ts)) row = txn.fetchone() if row is None: diff --git a/synapse/storage/databases/main/metrics.py b/synapse/storage/databases/main/metrics.py index a3467bff3dc..49411ed0341 100644 --- 
a/synapse/storage/databases/main/metrics.py +++ b/synapse/storage/databases/main/metrics.py @@ -78,7 +78,7 @@ def __init__( # Read the extrems every 60 minutes if hs.config.worker.run_background_tasks: - self._clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) + self.clock.looping_call(self._read_forward_extremities, 60 * 60 * 1000) # Used in _generate_user_daily_visits to keep track of progress self._last_user_visit_update = self._get_start_of_day() @@ -224,7 +224,7 @@ async def count_daily_users(self) -> int: """ Counts the number of users who used this homeserver in the last 24 hours. """ - yesterday = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24) + yesterday = int(self.clock.time_msec()) - (1000 * 60 * 60 * 24) return await self.db_pool.runInteraction( "count_daily_users", self._count_users, yesterday ) @@ -236,7 +236,7 @@ async def count_monthly_users(self) -> int: from the mau figure in synapse.storage.monthly_active_users which, amongst other things, includes a 3 day grace period before a user counts. """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + thirty_days_ago = int(self.clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) return await self.db_pool.runInteraction( "count_monthly_users", self._count_users, thirty_days_ago ) @@ -281,7 +281,7 @@ async def count_r30v2_users(self) -> Dict[str, int]: def _count_r30v2_users(txn: LoggingTransaction) -> Dict[str, int]: thirty_days_in_secs = 86400 * 30 - now = int(self._clock.time()) + now = int(self.clock.time()) sixty_days_ago_in_secs = now - 2 * thirty_days_in_secs one_day_from_now_in_secs = now + 86400 @@ -389,7 +389,7 @@ def _get_start_of_day(self) -> int: """ Returns millisecond unixtime for start of UTC day. """ - now = time.gmtime(self._clock.time()) + now = time.gmtime(self.clock.time()) today_start = calendar.timegm((now.tm_year, now.tm_mon, now.tm_mday, 0, 0, 0)) return today_start * 1000 @@ -403,7 +403,7 @@ def _generate_user_daily_visits(txn: LoggingTransaction) -> None: logger.info("Calling _generate_user_daily_visits") today_start = self._get_start_of_day() a_day_in_milliseconds = 24 * 60 * 60 * 1000 - now = self._clock.time_msec() + now = self.clock.time_msec() # A note on user_agent. Technically a given device can have multiple # user agents, so we need to decide which one to pick. 
We could have diff --git a/synapse/storage/databases/main/monthly_active_users.py b/synapse/storage/databases/main/monthly_active_users.py index f5a6b98be71..86744f616ce 100644 --- a/synapse/storage/databases/main/monthly_active_users.py +++ b/synapse/storage/databases/main/monthly_active_users.py @@ -49,7 +49,6 @@ def __init__( hs: "HomeServer", ): super().__init__(database, db_conn, hs) - self._clock = hs.get_clock() self.hs = hs if hs.config.redis.redis_enabled: @@ -226,7 +225,7 @@ def _reap_users(txn: LoggingTransaction, reserved_users: List[str]) -> None: reserved_users: reserved users to preserve """ - thirty_days_ago = int(self._clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) + thirty_days_ago = int(self.clock.time_msec()) - (1000 * 60 * 60 * 24 * 30) in_clause, in_clause_args = make_in_list_sql_clause( self.database_engine, "user_id", reserved_users @@ -328,7 +327,7 @@ def _initialise_reserved_users( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, + values={"timestamp": int(self.clock.time_msec())}, ) else: logger.warning("mau limit reserved threepid %s not found in db", tp) @@ -391,7 +390,7 @@ def upsert_monthly_active_user_txn( txn, table="monthly_active_users", keyvalues={"user_id": user_id}, - values={"timestamp": int(self._clock.time_msec())}, + values={"timestamp": int(self.clock.time_msec())}, ) self._invalidate_cache_and_stream(txn, self.get_monthly_active_count, ()) diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py index ff4eb9acb29..f1dbf68971d 100644 --- a/synapse/storage/databases/main/receipts.py +++ b/synapse/storage/databases/main/receipts.py @@ -1073,7 +1073,7 @@ async def insert_receipt( if event_ts is None: return None - now = self._clock.time_msec() + now = self.clock.time_msec() logger.debug( "Receipt %s for event %s in %s (%i ms old)", receipt_type, diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py index 117444e7b75..906d1a91f68 100644 --- a/synapse/storage/databases/main/registration.py +++ b/synapse/storage/databases/main/registration.py @@ -212,7 +212,7 @@ def __init__( ) if hs.config.worker.run_background_tasks: - self._clock.call_later( + self.clock.call_later( 0.0, self._set_expiration_date_when_missing, ) @@ -226,7 +226,7 @@ def __init__( # Create a background job for culling expired 3PID validity tokens if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( self.cull_expired_threepid_validation_tokens, THIRTY_MINUTES_IN_MS ) @@ -298,7 +298,7 @@ def _register_user( ) -> None: user_id_obj = UserID.from_string(user_id) - now = int(self._clock.time()) + now = int(self.clock.time()) user_approved = approved or not self._require_approval @@ -457,7 +457,7 @@ async def is_trial_user(self, user_id: str) -> bool: if not info: return False - now = self._clock.time_msec() + now = self.clock.time_msec() days = self.config.server.mau_appservice_trial_days.get( info.appservice_id, self.config.server.mau_trial_days ) @@ -640,7 +640,7 @@ def select_users_txn( return await self.db_pool.runInteraction( "get_users_expiring_soon", select_users_txn, - self._clock.time_msec(), + self.clock.time_msec(), self.config.account_validity.account_validity_renew_at, ) @@ -1084,7 +1084,7 @@ async def count_daily_user_type(self) -> Dict[str, int]: """ def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]: - yesterday = int(self._clock.time()) - (60 * 
60 * 24) + yesterday = int(self.clock.time()) - (60 * 60 * 24) sql = """ SELECT user_type, COUNT(*) AS count FROM ( @@ -1496,7 +1496,7 @@ def cull_expired_threepid_validation_tokens_txn( await self.db_pool.runInteraction( "cull_expired_threepid_validation_tokens", cull_expired_threepid_validation_tokens_txn, - self._clock.time_msec(), + self.clock.time_msec(), ) @wrap_as_background_process("account_validity_set_expiration_dates") @@ -1537,7 +1537,7 @@ def set_expiration_date_for_user_txn( random value in the [now + period - d ; now + period] range, d being a delta equal to 10% of the validity period. """ - now_ms = self._clock.time_msec() + now_ms = self.clock.time_msec() assert self._account_validity_period is not None expiration_ts = now_ms + self._account_validity_period @@ -1608,7 +1608,7 @@ async def update_access_token_last_validated(self, token_id: int) -> None: Raises: StoreError if there was a problem updating this. """ - now = self._clock.time_msec() + now = self.clock.time_msec() await self.db_pool.simple_update_one( "access_tokens", @@ -1639,7 +1639,7 @@ async def registration_token_is_valid(self, token: str) -> bool: uses_allowed, pending, completed, expiry_time = res # Check if the token has expired - now = self._clock.time_msec() + now = self.clock.time_msec() if expiry_time and expiry_time < now: return False @@ -1771,7 +1771,7 @@ def select_registration_tokens_txn( return await self.db_pool.runInteraction( "select_registration_tokens", select_registration_tokens_txn, - self._clock.time_msec(), + self.clock.time_msec(), valid, ) @@ -2251,7 +2251,7 @@ async def consume_login_token(self, token: str) -> LoginTokenLookupResult: "consume_login_token", self._consume_login_token, token, - self._clock.time_msec(), + self.clock.time_msec(), ) async def invalidate_login_tokens_by_session_id( @@ -2271,7 +2271,7 @@ async def invalidate_login_tokens_by_session_id( "auth_provider_id": auth_provider_id, "auth_provider_session_id": auth_provider_session_id, }, - updatevalues={"used_ts": self._clock.time_msec()}, + updatevalues={"used_ts": self.clock.time_msec()}, desc="invalidate_login_tokens_by_session_id", ) @@ -2640,7 +2640,6 @@ def __init__( ): super().__init__(database, db_conn, hs) - self._clock = hs.get_clock() self.config = hs.config self.db_pool.updates.register_background_index_update( @@ -2761,7 +2760,7 @@ def __init__( # Create a background job for removing expired login tokens if hs.config.worker.run_background_tasks: - self._clock.looping_call( + self.clock.looping_call( self._delete_expired_login_tokens, THIRTY_MINUTES_IN_MS ) @@ -2790,7 +2789,7 @@ async def add_access_token_to_user( The token ID """ next_id = self._access_tokens_id_gen.get_next() - now = self._clock.time_msec() + now = self.clock.time_msec() await self.db_pool.simple_insert( "access_tokens", @@ -2874,7 +2873,7 @@ def f(txn: LoggingTransaction) -> None: keyvalues={"name": user_id}, updatevalues={ "consent_version": consent_version, - "consent_ts": self._clock.time_msec(), + "consent_ts": self.clock.time_msec(), }, ) self._invalidate_cache_and_stream(txn, self.get_user_by_id, (user_id,)) @@ -2986,7 +2985,7 @@ def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]: txn, table="threepid_validation_session", keyvalues={"session_id": session_id}, - updatevalues={"validated_at": self._clock.time_msec()}, + updatevalues={"validated_at": self.clock.time_msec()}, ) return next_link @@ -3064,7 +3063,7 @@ def _delete_expired_login_tokens_txn(txn: LoggingTransaction, ts: int) -> None: # We keep the 
expired tokens for an extra 5 minutes so we can measure how many # times a token is being used after its expiry - now = self._clock.time_msec() + now = self.clock.time_msec() await self.db_pool.runInteraction( "delete_expired_login_tokens", _delete_expired_login_tokens_txn, diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index 9db2e14a06f..65caf4b1eaa 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -1002,7 +1002,7 @@ async def get_joined_user_ids_from_state( """ with Measure( - self._clock, + self.clock, name="get_joined_user_ids_from_state", server_name=self.server_name, ): diff --git a/synapse/storage/databases/main/session.py b/synapse/storage/databases/main/session.py index 8a5fa8386cd..1154bb2d599 100644 --- a/synapse/storage/databases/main/session.py +++ b/synapse/storage/databases/main/session.py @@ -55,7 +55,7 @@ def __init__( # Create a background job for culling expired sessions. if hs.config.worker.run_background_tasks: - self._clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) + self.clock.looping_call(self._delete_expired_sessions, 30 * 60 * 1000) async def create_session( self, session_type: str, value: JsonDict, expiry_ms: int @@ -133,7 +133,7 @@ def _get_session( _get_session, session_type, session_id, - self._clock.time_msec(), + self.clock.time_msec(), ) @wrap_as_background_process("delete_expired_sessions") @@ -147,5 +147,5 @@ def _delete_expired_sessions_txn(txn: LoggingTransaction, ts: int) -> None: await self.db_pool.runInteraction( "delete_expired_sessions", _delete_expired_sessions_txn, - self._clock.time_msec(), + self.clock.time_msec(), ) diff --git a/synapse/storage/databases/main/sliding_sync.py b/synapse/storage/databases/main/sliding_sync.py index f7af3e88d3f..c0c5087b13c 100644 --- a/synapse/storage/databases/main/sliding_sync.py +++ b/synapse/storage/databases/main/sliding_sync.py @@ -201,7 +201,7 @@ def persist_per_connection_state_txn( "user_id": user_id, "effective_device_id": device_id, "conn_id": conn_id, - "created_ts": self._clock.time_msec(), + "created_ts": self.clock.time_msec(), }, returning=("connection_key",), ) @@ -212,7 +212,7 @@ def persist_per_connection_state_txn( table="sliding_sync_connection_positions", values={ "connection_key": connection_key, - "created_ts": self._clock.time_msec(), + "created_ts": self.clock.time_msec(), }, returning=("connection_position",), ) diff --git a/synapse/storage/databases/main/transactions.py b/synapse/storage/databases/main/transactions.py index bfc324b80d2..41c94839273 100644 --- a/synapse/storage/databases/main/transactions.py +++ b/synapse/storage/databases/main/transactions.py @@ -81,11 +81,11 @@ def __init__( super().__init__(database, db_conn, hs) if hs.config.worker.run_background_tasks: - self._clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) + self.clock.looping_call(self._cleanup_transactions, 30 * 60 * 1000) @wrap_as_background_process("cleanup_transactions") async def _cleanup_transactions(self) -> None: - now = self._clock.time_msec() + now = self.clock.time_msec() day_ago = now - 24 * 60 * 60 * 1000 def _cleanup_transactions_txn(txn: LoggingTransaction) -> None: @@ -160,7 +160,7 @@ async def set_received_txn_response( insertion_values={ "response_code": code, "response_json": db_binary_type(encode_canonical_json(response_dict)), - "ts": self._clock.time_msec(), + "ts": self.clock.time_msec(), }, desc="set_received_txn_response", ) diff --git 
a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py index 9b3b7e086f9..b62f3e6f5ba 100644 --- a/synapse/storage/databases/state/store.py +++ b/synapse/storage/databases/state/store.py @@ -125,6 +125,7 @@ def __init__( self._state_group_cache: DictionaryCache[int, StateKey, str] = DictionaryCache( name="*stateGroupCache*", + clock=hs.get_clock(), server_name=self.server_name, # TODO: this hasn't been tuned yet max_entries=50000, @@ -132,6 +133,7 @@ def __init__( self._state_group_members_cache: DictionaryCache[int, StateKey, str] = ( DictionaryCache( name="*stateGroupMembersCache*", + clock=hs.get_clock(), server_name=self.server_name, max_entries=500000, ) diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 1f909885258..2a167f209cb 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -55,7 +55,6 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError -from twisted.internet.interfaces import IReactorTime from twisted.python.failure import Failure from synapse.logging.context import ( @@ -549,10 +548,9 @@ class Linearizer: def __init__( self, - *, name: str, - max_count: int = 1, clock: Clock, + max_count: int = 1, ): """ Args: @@ -772,7 +770,11 @@ async def _ctx_manager() -> AsyncIterator[None]: def timeout_deferred( - deferred: "defer.Deferred[_T]", timeout: float, reactor: IReactorTime + *, + deferred: "defer.Deferred[_T]", + timeout: float, + cancel_on_shutdown: bool = True, + clock: Clock, ) -> "defer.Deferred[_T]": """The in built twisted `Deferred.addTimeout` fails to time out deferreds that have a canceller that throws exceptions. This method creates a new @@ -790,7 +792,13 @@ def timeout_deferred( Args: deferred: The Deferred to potentially timeout. timeout: Timeout in seconds - reactor: The twisted reactor to use + cancel_on_shutdown: Whether this call should be tracked for cleanup during + shutdown. In general, all calls should be tracked. There may be a use case + not to track calls with a `timeout` of 0 (or similarly short) since tracking + them may result in rapid insertions and removals of tracked calls + unnecessarily. But unless a specific instance of tracking proves to be an + issue, we can just track all delayed calls. + clock: The `Clock` instance used to track delayed calls. Returns: @@ -814,7 +822,10 @@ def time_it_out() -> None: if not new_d.called: new_d.errback(defer.TimeoutError("Timed out after %gs" % (timeout,))) - delayed_call = reactor.callLater(timeout, time_it_out) + # Whether this call is tracked for shutdown cancellation is delegated to the + # caller via `cancel_on_shutdown` (tracked by default). + delayed_call = clock.call_later( + timeout, time_it_out, call_later_cancel_on_shutdown=cancel_on_shutdown + ) def convert_cancelled(value: Failure) -> Failure: # if the original deferred was cancelled, and our timeout has fired, then @@ -956,9 +967,9 @@ class AwakenableSleeper: currently sleeping.
""" - def __init__(self, reactor: IReactorTime) -> None: + def __init__(self, clock: Clock) -> None: self._streams: Dict[str, Set[defer.Deferred[None]]] = {} - self._reactor = reactor + self._clock = clock def wake(self, name: str) -> None: """Wake everything related to `name` that is currently sleeping.""" @@ -977,7 +988,11 @@ async def sleep(self, name: str, delay_ms: int) -> None: # Create a deferred that gets called in N seconds sleep_deferred: "defer.Deferred[None]" = defer.Deferred() - call = self._reactor.callLater(delay_ms / 1000, sleep_deferred.callback, None) + call = self._clock.call_later( + delay_ms / 1000, + sleep_deferred.callback, + None, + ) # Create a deferred that will get called if `wake` is called with # the same `name`. @@ -1011,8 +1026,8 @@ async def sleep(self, name: str, delay_ms: int) -> None: class DeferredEvent: """Like threading.Event but for async code""" - def __init__(self, reactor: IReactorTime) -> None: - self._reactor = reactor + def __init__(self, clock: Clock) -> None: + self._clock = clock self._deferred: "defer.Deferred[None]" = defer.Deferred() def set(self) -> None: @@ -1032,7 +1047,11 @@ async def wait(self, timeout_seconds: float) -> bool: # Create a deferred that gets called in N seconds sleep_deferred: "defer.Deferred[None]" = defer.Deferred() - call = self._reactor.callLater(timeout_seconds, sleep_deferred.callback, None) + call = self._clock.call_later( + timeout_seconds, + sleep_deferred.callback, + None, + ) try: await make_deferred_yieldable( diff --git a/synapse/util/batching_queue.py b/synapse/util/batching_queue.py index 4c4037412aa..f77301afd81 100644 --- a/synapse/util/batching_queue.py +++ b/synapse/util/batching_queue.py @@ -21,6 +21,7 @@ import logging from typing import ( + TYPE_CHECKING, Awaitable, Callable, Dict, @@ -38,9 +39,11 @@ from synapse.logging.context import PreserveLoggingContext, make_deferred_yieldable from synapse.metrics import SERVER_NAME_LABEL -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.clock import Clock +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -97,12 +100,13 @@ def __init__( self, *, name: str, - server_name: str, + hs: "HomeServer", clock: Clock, process_batch_callback: Callable[[List[V]], Awaitable[R]], ): self._name = name - self.server_name = server_name + self.hs = hs + self.server_name = hs.hostname self._clock = clock # The set of keys currently being processed. @@ -127,6 +131,14 @@ def __init__( name=self._name, **{SERVER_NAME_LABEL: self.server_name} ) + def shutdown(self) -> None: + """ + Prepares the object for garbage collection by removing any handed out + references. + """ + number_queued.remove(self._name, self.server_name) + number_of_keys.remove(self._name, self.server_name) + async def add_to_queue(self, value: V, key: Hashable = ()) -> R: """Adds the value to the queue with the given key, returning the result of the processing function for the batch that included the given value. @@ -145,9 +157,7 @@ async def add_to_queue(self, value: V, key: Hashable = ()) -> R: # If we're not currently processing the key fire off a background # process to start processing. 
if key not in self._processing_keys: - run_as_background_process( - self._name, self.server_name, self._process_queue, key - ) + self.hs.run_as_background_process(self._name, self._process_queue, key) with self._number_in_flight_metric.track_inprogress(): return await make_deferred_yieldable(d) diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py index 710a29e3f0f..08ff842af0f 100644 --- a/synapse/util/caches/__init__.py +++ b/synapse/util/caches/__init__.py @@ -244,7 +244,7 @@ def register_cache( collect_callback=collect_callback, ) metric_name = "cache_%s_%s_%s" % (cache_type, cache_name, server_name) - CACHE_METRIC_REGISTRY.register_hook(metric_name, metric.collect) + CACHE_METRIC_REGISTRY.register_hook(server_name, metric_name, metric.collect) return metric diff --git a/synapse/util/caches/deferred_cache.py b/synapse/util/caches/deferred_cache.py index 92d446ce2aa..016acbac710 100644 --- a/synapse/util/caches/deferred_cache.py +++ b/synapse/util/caches/deferred_cache.py @@ -47,6 +47,7 @@ from synapse.util.async_helpers import ObservableDeferred from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache, iterate_tree_cache_entry +from synapse.util.clock import Clock cache_pending_metric = Gauge( "synapse_util_caches_cache_pending", @@ -82,6 +83,7 @@ def __init__( self, *, name: str, + clock: Clock, server_name: str, max_entries: int = 1000, tree: bool = False, @@ -103,6 +105,7 @@ def __init__( prune_unread_entries: If True, cache entries that haven't been read recently will be evicted from the cache in the background. Set to False to opt-out of this behaviour. + clock: The homeserver `Clock` instance """ cache_type = TreeCache if tree else dict @@ -120,6 +123,7 @@ def metrics_cb() -> None: # a Deferred. self.cache: LruCache[KT, VT] = LruCache( max_size=max_entries, + clock=clock, server_name=server_name, cache_name=name, cache_type=cache_type, diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py index 47b8f4ddc81..6e3c8eada98 100644 --- a/synapse/util/caches/descriptors.py +++ b/synapse/util/caches/descriptors.py @@ -53,6 +53,7 @@ from synapse.util.async_helpers import delay_cancellation from synapse.util.caches.deferred_cache import DeferredCache from synapse.util.caches.lrucache import LruCache +from synapse.util.clock import Clock logger = logging.getLogger(__name__) @@ -154,13 +155,20 @@ def __init__( ) -class HasServerName(Protocol): +class HasServerNameAndClock(Protocol): server_name: str """ The homeserver name that this cache is associated with (used to label the metric) (`hs.hostname`). """ + clock: Clock + """ + The homeserver clock instance used to track delayed and looping calls. Important to + be able to fully cleanup the homeserver instance on server shutdown. + (`hs.get_clock()`). + """ + class DeferredCacheDescriptor(_CacheDescriptorBase): """A method decorator that applies a memoizing cache around the function. @@ -239,7 +247,7 @@ def __init__( self.prune_unread_entries = prune_unread_entries def __get__( - self, obj: Optional[HasServerName], owner: Optional[Type] + self, obj: Optional[HasServerNameAndClock], owner: Optional[Type] ) -> Callable[..., "defer.Deferred[Any]"]: # We need access to instance-level `obj.server_name` attribute assert obj is not None, ( @@ -249,9 +257,13 @@ def __get__( assert obj.server_name is not None, ( "The `server_name` attribute must be set on the object where `@cached` decorator is used." 
) + assert obj.clock is not None, ( + "The `clock` attribute must be set on the object where `@cached` decorator is used." + ) cache: DeferredCache[CacheKey, Any] = DeferredCache( name=self.name, + clock=obj.clock, server_name=obj.server_name, max_entries=self.max_entries, tree=self.tree, diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index 168ddc51cd5..eb5493d322d 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -37,6 +37,7 @@ from synapse.util.caches.lrucache import LruCache from synapse.util.caches.treecache import TreeCache +from synapse.util.clock import Clock logger = logging.getLogger(__name__) @@ -127,10 +128,13 @@ class DictionaryCache(Generic[KT, DKT, DV]): for the '2' dict key. """ - def __init__(self, *, name: str, server_name: str, max_entries: int = 1000): + def __init__( + self, *, name: str, clock: Clock, server_name: str, max_entries: int = 1000 + ): """ Args: name + clock: The homeserver `Clock` instance server_name: The homeserver name that this cache is associated with (used to label the metric) (`hs.hostname`). max_entries @@ -160,6 +164,7 @@ def __init__(self, *, name: str, server_name: str, max_entries: int = 1000): Union[_PerKeyValue, Dict[DKT, DV]], ] = LruCache( max_size=max_entries, + clock=clock, server_name=server_name, cache_name=name, cache_type=TreeCache, diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py index 305af5051c4..29ce6c0a776 100644 --- a/synapse/util/caches/expiringcache.py +++ b/synapse/util/caches/expiringcache.py @@ -21,17 +21,29 @@ import logging from collections import OrderedDict -from typing import Any, Generic, Iterable, Literal, Optional, TypeVar, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + Generic, + Iterable, + Literal, + Optional, + TypeVar, + Union, + overload, +) import attr from twisted.internet import defer from synapse.config import cache as cache_config -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.util.caches import EvictionReason, register_cache from synapse.util.clock import Clock +if TYPE_CHECKING: + from synapse.server import HomeServer + logger = logging.getLogger(__name__) @@ -49,6 +61,7 @@ def __init__( *, cache_name: str, server_name: str, + hs: "HomeServer", clock: Clock, max_len: int = 0, expiry_ms: int = 0, @@ -99,9 +112,7 @@ def __init__( return def f() -> "defer.Deferred[None]": - return run_as_background_process( - "prune_cache", server_name, self._prune_cache - ) + return hs.run_as_background_process("prune_cache", self._prune_cache) self._clock.looping_call(f, self._expiry_ms / 2) diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py index 2d4cde19a5a..324acb728ab 100644 --- a/synapse/util/caches/lrucache.py +++ b/synapse/util/caches/lrucache.py @@ -45,14 +45,10 @@ overload, ) -from twisted.internet import defer, reactor +from twisted.internet import defer from synapse.config import cache as cache_config -from synapse.metrics.background_process_metrics import ( - run_as_background_process, -) from synapse.metrics.jemalloc import get_jemalloc_stats -from synapse.types import ISynapseThreadlessReactor from synapse.util import caches from synapse.util.caches import CacheMetric, EvictionReason, register_cache from synapse.util.caches.treecache import ( @@ -123,6 +119,7 @@ def update_last_access(self, clock: Clock) -> None: def _expire_old_entries( server_name: str, + hs: "HomeServer", 
clock: Clock, expiry_seconds: float, autotune_config: Optional[dict], @@ -228,9 +225,8 @@ async def _internal_expire_old_entries( logger.info("Dropped %d items from caches", i) - return run_as_background_process( + return hs.run_as_background_process( "LruCache._expire_old_entries", - server_name, _internal_expire_old_entries, clock, expiry_seconds, @@ -261,6 +257,7 @@ def setup_expire_lru_cache_entries(hs: "HomeServer") -> None: _expire_old_entries, 30 * 1000, server_name, + hs, clock, expiry_time, hs.config.caches.cache_autotuning, @@ -404,13 +401,13 @@ def __init__( self, *, max_size: int, + clock: Clock, server_name: str, cache_name: str, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, - clock: Optional[Clock] = None, prune_unread_entries: bool = True, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): ... @@ -420,13 +417,13 @@ def __init__( self, *, max_size: int, + clock: Clock, server_name: str, cache_name: Literal[None] = None, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, - clock: Optional[Clock] = None, prune_unread_entries: bool = True, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): ... @@ -435,13 +432,13 @@ def __init__( self, *, max_size: int, + clock: Clock, server_name: str, cache_name: Optional[str] = None, cache_type: Type[Union[dict, TreeCache]] = dict, size_callback: Optional[Callable[[VT], int]] = None, metrics_collection_callback: Optional[Callable[[], None]] = None, apply_cache_factor_from_config: bool = True, - clock: Optional[Clock] = None, prune_unread_entries: bool = True, extra_index_cb: Optional[Callable[[KT, VT], KT]] = None, ): @@ -492,15 +489,6 @@ def __init__( Note: The new key does not have to be unique. """ - # Default `clock` to something sensible. Note that we rename it to - # `real_clock` so that mypy doesn't think its still `Optional`. - if clock is None: - real_clock = Clock( - cast(ISynapseThreadlessReactor, reactor), server_name=server_name - ) - else: - real_clock = clock - cache: Union[Dict[KT, _Node[KT, VT]], TreeCache] = cache_type() self.cache = cache # Used for introspection. self.apply_cache_factor_from_config = apply_cache_factor_from_config @@ -592,7 +580,7 @@ def add_node( key, value, weak_ref_to_self, - real_clock, + clock, callbacks, prune_unread_entries, ) @@ -610,7 +598,7 @@ def add_node( metrics.inc_memory_usage(node.memory) def move_node_to_front(node: _Node[KT, VT]) -> None: - node.move_to_front(real_clock, list_root) + node.move_to_front(clock, list_root) def delete_node(node: _Node[KT, VT]) -> int: node.drop_from_lists() diff --git a/synapse/util/caches/response_cache.py b/synapse/util/caches/response_cache.py index 79e34262df5..3d39357236a 100644 --- a/synapse/util/caches/response_cache.py +++ b/synapse/util/caches/response_cache.py @@ -198,7 +198,17 @@ def on_complete(r: RV) -> RV: # the should_cache bit, we leave it in the cache for now and schedule # its removal later. 
if self.timeout_sec and context.should_cache: - self.clock.call_later(self.timeout_sec, self._entry_timeout, key) + self.clock.call_later( + self.timeout_sec, + self._entry_timeout, + key, + # We don't need to track these calls since they don't hold any strong + # references which would keep the `HomeServer` in memory after shutdown. + # We don't want to track these because they can get cancelled really + # quickly and thrash the tracking mechanism, i.e. during repeated calls + # to /sync. + call_later_cancel_on_shutdown=False, + ) else: # otherwise, remove the result immediately. self.unset(key) diff --git a/synapse/util/clock.py b/synapse/util/clock.py index e85af170052..5e65cf32a4b 100644 --- a/synapse/util/clock.py +++ b/synapse/util/clock.py @@ -17,10 +17,12 @@ from typing import ( Any, Callable, + Dict, + List, ) -import attr from typing_extensions import ParamSpec +from zope.interface import implementer from twisted.internet import defer, task from twisted.internet.defer import Deferred @@ -34,24 +36,54 @@ P = ParamSpec("P") -@attr.s(slots=True) class Clock: """ A Clock wraps a Twisted reactor and provides utilities on top of it. + This clock should be used in place of the base reactor wherever a `LoopingCall` + or `DelayedCall` would be created (such as when calling `reactor.callLater`). This is to + ensure the calls made by this `HomeServer` instance are tracked and can be cleaned + up during `HomeServer.shutdown()`. + + We enforce usage of this clock instead of using the reactor directly via lints in + `scripts-dev/mypy_synapse_plugin.py`. + + Args: reactor: The Twisted reactor to use. server_name: The homeserver name that this clock is associated with (`hs.hostname`). """ - _reactor: ISynapseThreadlessReactor = attr.ib() - _server_name: str = attr.ib() + _reactor: ISynapseThreadlessReactor + + def __init__(self, reactor: ISynapseThreadlessReactor, server_name: str) -> None: + self._reactor = reactor + self._server_name = server_name + + self._delayed_call_id: int = 0 + """Unique ID used to track delayed calls""" + + self._looping_calls: List[LoopingCall] = [] + """List of active looping calls""" + + self._call_id_to_delayed_call: Dict[int, IDelayedCall] = {} + """Mapping from unique call ID to delayed call""" + + self._is_shutdown = False + """Whether shutdown has been requested by the HomeServer""" + + def shutdown(self) -> None: + self._is_shutdown = True + self.cancel_all_looping_calls() + self.cancel_all_delayed_calls()
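To make the new bookkeeping concrete, a minimal lifecycle sketch; `hs` and `do_housekeeping` are illustrative stand-ins rather than names from this PR:

```python
def do_housekeeping() -> None:
    ...  # illustrative callback


clock = hs.get_clock()

# Scheduled through the Clock, so the Clock records them:
clock.looping_call(do_housekeeping, 30 * 1000)    # appended to `_looping_calls`
handle = clock.call_later(60.0, do_housekeeping)  # stored in `_call_id_to_delayed_call`

# `HomeServer.shutdown()` can then tear everything down in one place; after this,
# attempting to schedule new work raises instead of leaking untracked calls.
clock.shutdown()
```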
async def sleep(self, seconds: float) -> None: d: defer.Deferred[float] = defer.Deferred() # Start task in the `sentinel` logcontext, to avoid leaking the current context # into the reactor once it finishes. with context.PreserveLoggingContext(): - self._reactor.callLater(seconds, d.callback, seconds) + # We can ignore the lint here since this class is the one location callLater should + # be called. + self._reactor.callLater(seconds, d.callback, seconds) # type: ignore[call-later-not-tracked] await d def time(self) -> float: @@ -124,6 +156,9 @@ def _looping_call_common( ) -> LoopingCall: """Common functionality for `looping_call` and `looping_call_now`""" + if self._is_shutdown: + raise Exception("Cannot start looping call. Clock has been shutdown") + def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> Deferred: assert context.current_context() is context.SENTINEL_CONTEXT, ( "Expected `looping_call` callback from the reactor to start with the sentinel logcontext " @@ -155,7 +190,9 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> Deferred: # logcontext to the reactor return context.run_in_background(f, *args, **kwargs) - call = task.LoopingCall(wrapped_f, *args, **kwargs) + # We can ignore the lint here since this is the one location `LoopingCall`s + # should be created. + call = task.LoopingCall(wrapped_f, *args, **kwargs) # type: ignore[prefer-synapse-clock-looping-call] call.clock = self._reactor # If `now=true`, the function will be called here immediately so we need to be # in the sentinel context now. @@ -165,10 +202,32 @@ def wrapped_f(*args: P.args, **kwargs: P.kwargs) -> Deferred: with context.PreserveLoggingContext(): d = call.start(msec / 1000.0, now=now) d.addErrback(log_failure, "Looping call died", consumeErrors=False) + self._looping_calls.append(call) return call + def cancel_all_looping_calls(self, consumeErrors: bool = True) -> None: + """ + Stop all running looping calls. + + Args: + consumeErrors: If True (the default), swallow errors encountered while + stopping a looping call; if False, re-raise them. + """ + for call in self._looping_calls: + try: + call.stop() + except Exception: + if not consumeErrors: + raise + self._looping_calls.clear()
+ def call_later( + self, + delay: float, + callback: Callable, + *args: Any, + call_later_cancel_on_shutdown: bool = True, + **kwargs: Any, ) -> IDelayedCall: """Call something later @@ -180,39 +239,78 @@ def call_later( delay: How long to wait in seconds. callback: Function to call *args: Postional arguments to pass to function. + call_later_cancel_on_shutdown: Whether this call should be tracked for cleanup during + shutdown. In general, all calls should be tracked. There may be a use case + not to track calls with a `delay` of 0 (or similarly short) since tracking + them may result in rapid insertions and removals of tracked calls + unnecessarily. But unless a specific instance of tracking proves to be an + issue, we can just track all delayed calls. **kwargs: Key arguments to pass to function. """ - def wrapped_callback(*args: Any, **kwargs: Any) -> None: - assert context.current_context() is context.SENTINEL_CONTEXT, ( - "Expected `call_later` callback from the reactor to start with the sentinel logcontext " - f"but saw {context.current_context()}. In other words, another task shouldn't have " - "leaked their logcontext to us." - ) - - # Because this is a callback from the reactor, we will be using the - # `sentinel` log context at this point. We want the function to log with - # some logcontext as we want to know which server the logs came from. - # - # We use `PreserveLoggingContext` to prevent our new `call_later` - # logcontext from finishing as soon as we exit this function, in case `f` - # returns an awaitable/deferred which would continue running and may try to - # restore the `loop_call` context when it's done (because it's trying to - # adhere to the Synapse logcontext rules.) - # - # This also ensures that we return to the `sentinel` context when we exit - # this function and yield control back to the reactor to avoid leaking the - # current logcontext to the reactor (which would then get picked up and - # associated with the next thing the reactor does) - with context.PreserveLoggingContext( - context.LoggingContext(name="call_later", server_name=self._server_name) - ): - # We use `run_in_background` to reset the logcontext after `f` (or the - # awaitable returned by `f`) completes to avoid leaking the current - # logcontext to the reactor - context.run_in_background(callback, *args, **kwargs) + if self._is_shutdown: + raise Exception("Cannot start delayed call. Clock has been shutdown") + + def create_wrapped_callback( + track_for_shutdown_cancellation: bool, + ) -> Callable[P, None]: + def wrapped_callback(*args: Any, **kwargs: Any) -> None: + assert context.current_context() is context.SENTINEL_CONTEXT, ( + "Expected `call_later` callback from the reactor to start with the sentinel logcontext " + f"but saw {context.current_context()}. In other words, another task shouldn't have " + "leaked their logcontext to us." + ) - return self._reactor.callLater(delay, wrapped_callback, *args, **kwargs) + # Because this is a callback from the reactor, we will be using the + # `sentinel` log context at this point. We want the function to log with + # some logcontext as we want to know which server the logs came from. + # + # We use `PreserveLoggingContext` to prevent our new `call_later` + # logcontext from finishing as soon as we exit this function, in case `f` + # returns an awaitable/deferred which would continue running and may try to + # restore the `call_later` context when it's done (because it's trying to + # adhere to the Synapse logcontext rules.) + # + # This also ensures that we return to the `sentinel` context when we exit + # this function and yield control back to the reactor to avoid leaking the + # current logcontext to the reactor (which would then get picked up and + # associated with the next thing the reactor does) + try: + with context.PreserveLoggingContext( + context.LoggingContext( + name="call_later", server_name=self._server_name + ) + ): + # We use `run_in_background` to reset the logcontext after `f` (or the + # awaitable returned by `f`) completes to avoid leaking the current + # logcontext to the reactor + context.run_in_background(callback, *args, **kwargs) + finally: + if track_for_shutdown_cancellation: + # We still want to remove the call from the tracking map, even if + # the callback raises an exception. + self._call_id_to_delayed_call.pop(call_id) + + return wrapped_callback + + if call_later_cancel_on_shutdown: + call_id = self._delayed_call_id + self._delayed_call_id = self._delayed_call_id + 1 + + # We can ignore the lint here since this class is the one location callLater + # should be called. + call = self._reactor.callLater( + delay, create_wrapped_callback(True), *args, **kwargs + ) # type: ignore[call-later-not-tracked] + call = DelayedCallWrapper(call, call_id, self) + self._call_id_to_delayed_call[call_id] = call + return call + else: + # We can ignore the lint here since this class is the one location callLater should + # be called. + return self._reactor.callLater( + delay, create_wrapped_callback(False), *args, **kwargs + ) # type: ignore[call-later-not-tracked]
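A short usage sketch of the two flavours this creates; `clock` is assumed to be the homeserver's `Clock`, and the callback names are illustrative:

```python
def prune_stale_entries() -> None:
    ...  # illustrative callback


def kick_next_request() -> None:
    ...  # illustrative callback


# Tracked (the default): recorded in `_call_id_to_delayed_call` and cancelled by
# `Clock.shutdown()`; the returned handle is a `DelayedCallWrapper` whose
# `cancel()` also removes the call from the tracking map.
handle = clock.call_later(30.0, prune_stale_entries)

# Untracked: for very short-lived or high-churn timers where tracking would
# just thrash the map (e.g. the /sync response-cache expiry seen earlier).
clock.call_later(0.0, kick_next_request, call_later_cancel_on_shutdown=False)
```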
def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> None: try: @@ -221,6 +319,24 @@ def cancel_call_later(self, timer: IDelayedCall, ignore_errs: bool = False) -> N if not ignore_errs: raise + def cancel_all_delayed_calls(self, ignore_errs: bool = True) -> None: + """ + Stop all scheduled calls that were marked with `cancel_on_shutdown` when they were created. + + Args: + ignore_errs: If True (the default), swallow errors encountered while + cancelling a scheduled call; if False, re-raise them. + """ + # We make a copy here since calling `cancel()` on a delayed_call + # will result in the call removing itself from the map mid-iteration. + for call in list(self._call_id_to_delayed_call.values()): + try: + call.cancel() + except Exception: + if not ignore_errs: + raise + self._call_id_to_delayed_call.clear() + def call_when_running( self, callback: Callable[P, object], @@ -285,7 +401,7 @@ def add_system_event_trigger( callback: Callable[P, object], *args: P.args, **kwargs: P.kwargs, - ) -> None: + ) -> Any: """ Add a function to be called when a system event occurs. @@ -299,6 +415,9 @@ def add_system_event_trigger( callback: Function to call *args: Postional arguments to pass to function. **kwargs: Key arguments to pass to function. + + Returns: + An ID that can be used to remove this call with `reactor.removeSystemEventTrigger`. """ def wrapped_callback(*args: Any, **kwargs: Any) -> None: @@ -334,6 +453,50 @@ def wrapped_callback(*args: Any, **kwargs: Any) -> None: # We can ignore the lint here since this class is the one location # `addSystemEventTrigger` should be called. - self._reactor.addSystemEventTrigger( + return self._reactor.addSystemEventTrigger( phase, event_type, wrapped_callback, *args, **kwargs ) # type: ignore[prefer-synapse-clock-add-system-event-trigger] + + +@implementer(IDelayedCall) +class DelayedCallWrapper: + """Wraps an `IDelayedCall` so that we can intercept the call to `cancel()` and + properly clean up the delayed call from the tracking map of the `Clock`. + + Args: + delayed_call: The actual `IDelayedCall` + call_id: Unique identifier for this delayed call + clock: The clock instance tracking this call + """ + + def __init__(self, delayed_call: IDelayedCall, call_id: int, clock: Clock): + self.delayed_call = delayed_call + self.call_id = call_id + self.clock = clock + + def cancel(self) -> None: + """Remove the call from the tracking map and propagate the call to the + underlying delayed_call. + """ + self.delayed_call.cancel() + try: + self.clock._call_id_to_delayed_call.pop(self.call_id) + except KeyError: + # If the delayed call isn't being tracked anymore we can just move on.
+ pass + + def getTime(self) -> float: + """Propagate the call to the underlying delayed_call.""" + return self.delayed_call.getTime() + + def delay(self, secondsLater: float) -> None: + """Propagate the call to the underlying delayed_call.""" + self.delayed_call.delay(secondsLater) + + def reset(self, secondsFromNow: float) -> None: + """Propagate the call to the underlying delayed_call.""" + self.delayed_call.reset(secondsFromNow) + + def active(self) -> bool: + """Propagate the call to the underlying delayed_call.""" + return self.delayed_call.active() diff --git a/synapse/util/distributor.py b/synapse/util/distributor.py index f48ae3373ce..dec6536e4e1 100644 --- a/synapse/util/distributor.py +++ b/synapse/util/distributor.py @@ -20,6 +20,7 @@ # import logging from typing import ( + TYPE_CHECKING, Any, Awaitable, Callable, @@ -36,10 +37,13 @@ from twisted.internet import defer from synapse.logging.context import make_deferred_yieldable, run_in_background -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.types import UserID from synapse.util.async_helpers import maybe_awaitable +if TYPE_CHECKING: + from synapse.server import HomeServer + + logger = logging.getLogger(__name__) @@ -58,13 +62,13 @@ class Distributor: model will do for today. """ - def __init__(self, server_name: str) -> None: + def __init__(self, hs: "HomeServer") -> None: """ Args: server_name: The homeserver name of the server (used to label metrics) (this should be `hs.hostname`). """ - self.server_name = server_name + self.hs = hs self.signals: Dict[str, Signal] = {} self.pre_registration: Dict[str, List[Callable]] = {} @@ -97,8 +101,8 @@ def fire(self, name: str, *args: Any, **kwargs: Any) -> None: if name not in self.signals: raise KeyError("%r does not have a signal named %s" % (self, name)) - run_as_background_process( - name, self.server_name, self.signals[name].fire, *args, **kwargs + self.hs.run_as_background_process( + name, self.signals[name].fire, *args, **kwargs ) diff --git a/synapse/util/metrics.py b/synapse/util/metrics.py index c4f3c8b9653..7b6ad0e459c 100644 --- a/synapse/util/metrics.py +++ b/synapse/util/metrics.py @@ -293,21 +293,46 @@ class DynamicCollectorRegistry(CollectorRegistry): def __init__(self) -> None: super().__init__() - self._pre_update_hooks: Dict[str, Callable[[], None]] = {} + self._server_name_to_pre_update_hooks: Dict[ + str, Dict[str, Callable[[], None]] + ] = {} + """ + Mapping of server name to a mapping of metric name to metric pre-update + hook + """ def collect(self) -> Generator[Metric, None, None]: """ Collects metrics, calling pre-update hooks first. """ - for pre_update_hook in self._pre_update_hooks.values(): - pre_update_hook() + for pre_update_hooks in self._server_name_to_pre_update_hooks.values(): + for pre_update_hook in pre_update_hooks.values(): + pre_update_hook() yield from super().collect() - def register_hook(self, metric_name: str, hook: Callable[[], None]) -> None: + def register_hook( + self, server_name: str, metric_name: str, hook: Callable[[], None] + ) -> None: """ Registers a hook that is called before metric collection. """ - self._pre_update_hooks[metric_name] = hook + server_hooks = self._server_name_to_pre_update_hooks.setdefault(server_name, {}) + if server_hooks.get(metric_name) is not None: + # TODO: This should be an `assert` since registering the same metric name + # multiple times will clobber the old metric. 
+ # We currently rely on this behaviour as we instantiate multiple + # `SyncRestServlet`, one per listener, and in the `__init__` we setup a new + # LruCache. + # Once the above behaviour is changed, this should be changed to an `assert`. + logger.error( + "Metric named %s already registered for server %s", + metric_name, + server_name, + ) + server_hooks[metric_name] = hook + + def unregister_hooks_for_homeserver(self, server_name: str) -> None: + self._server_name_to_pre_update_hooks.pop(server_name, None) diff --git a/synapse/util/ratelimitutils.py b/synapse/util/ratelimitutils.py index 695eb462bfe..756677fe6c4 100644 --- a/synapse/util/ratelimitutils.py +++ b/synapse/util/ratelimitutils.py @@ -419,4 +419,7 @@ def start_next_request() -> None: except KeyError: pass - self.clock.call_later(0.0, start_next_request) + self.clock.call_later( + 0.0, + start_next_request, + ) diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py index 42a0cc7aa81..96fe2bd5664 100644 --- a/synapse/util/retryutils.py +++ b/synapse/util/retryutils.py @@ -24,7 +24,6 @@ from typing import TYPE_CHECKING, Any, Optional, Type from synapse.api.errors import CodeMessageException -from synapse.metrics.background_process_metrics import run_as_background_process from synapse.storage import DataStore from synapse.types import StrCollection from synapse.util.clock import Clock @@ -32,6 +31,7 @@ if TYPE_CHECKING: from synapse.notifier import Notifier from synapse.replication.tcp.handler import ReplicationCommandHandler + from synapse.server import HomeServer logger = logging.getLogger(__name__) @@ -62,6 +62,7 @@ async def get_retry_limiter( *, destination: str, our_server_name: str, + hs: "HomeServer", clock: Clock, store: DataStore, ignore_backoff: bool = False, @@ -124,6 +125,7 @@ async def get_retry_limiter( return RetryDestinationLimiter( destination=destination, our_server_name=our_server_name, + hs=hs, clock=clock, store=store, failure_ts=failure_ts, @@ -163,6 +165,7 @@ def __init__( *, destination: str, our_server_name: str, + hs: "HomeServer", clock: Clock, store: DataStore, failure_ts: Optional[int], @@ -181,6 +184,7 @@ def __init__( Args: destination our_server_name: Our homeserver name (used to label metrics) (`hs.hostname`) + hs: The homeserver instance clock store failure_ts: when this destination started failing (in ms since @@ -197,6 +201,7 @@ def __init__( error code. """ self.our_server_name = our_server_name + self.hs = hs self.clock = clock self.store = store self.destination = destination @@ -331,6 +336,4 @@ async def store_retry_timings() -> None: logger.exception("Failed to store destination_retry_timings") # we deliberately do this in the background. 
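The hunk just below is representative of the pattern applied throughout this PR: bare `run_as_background_process(...)` calls are routed through the `HomeServer` so the process can be tracked and cancelled on shutdown. Schematically:

```python
# Before: the bare helper; nothing cancels this on shutdown.
run_as_background_process(
    "store_retry_timings", self.our_server_name, store_retry_timings
)

# After: routed through the HomeServer, which tracks the process and already
# knows its own server name.
self.hs.run_as_background_process("store_retry_timings", store_retry_timings)
```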
- run_as_background_process( - "store_retry_timings", self.our_server_name, store_retry_timings - ) + self.hs.run_as_background_process("store_retry_timings", store_retry_timings) diff --git a/synapse/util/task_scheduler.py b/synapse/util/task_scheduler.py index 0539989320f..7443d4e097e 100644 --- a/synapse/util/task_scheduler.py +++ b/synapse/util/task_scheduler.py @@ -32,7 +32,6 @@ ) from synapse.metrics import SERVER_NAME_LABEL, LaterGauge from synapse.metrics.background_process_metrics import ( - run_as_background_process, wrap_as_background_process, ) from synapse.types import JsonMapping, ScheduledTask, TaskStatus @@ -107,10 +106,8 @@ class TaskScheduler: OCCASIONAL_REPORT_INTERVAL_MS = 5 * 60 * 1000 # 5 minutes def __init__(self, hs: "HomeServer"): - self._hs = hs - self.server_name = ( - hs.hostname - ) # nb must be called this for @wrap_as_background_process + self.hs = hs # nb must be called this for @wrap_as_background_process + self.server_name = hs.hostname self._store = hs.get_datastores().main self._clock = hs.get_clock() self._running_tasks: Set[str] = set() @@ -215,7 +212,7 @@ async def schedule_task( if self._run_background_tasks: self._launch_scheduled_tasks() else: - self._hs.get_replication_command_handler().send_new_active_task(task.id) + self.hs.get_replication_command_handler().send_new_active_task(task.id) return task.id @@ -362,7 +359,7 @@ async def inner() -> None: finally: self._launching_new_tasks = False - run_as_background_process("launch_scheduled_tasks", self.server_name, inner) + self.hs.run_as_background_process("launch_scheduled_tasks", inner) @wrap_as_background_process("clean_scheduled_tasks") async def _clean_scheduled_tasks(self) -> None: @@ -473,7 +470,10 @@ async def wrapper() -> None: occasional_status_call.stop() # Try launch a new task since we've finished with this one. - self._clock.call_later(0.1, self._launch_scheduled_tasks) + self._clock.call_later( + 0.1, + self._launch_scheduled_tasks, + ) if len(self._running_tasks) >= TaskScheduler.MAX_CONCURRENT_RUNNING_TASKS: return @@ -493,4 +493,4 @@ async def wrapper() -> None: self._running_tasks.add(task.id) await self.update_task(task.id, status=TaskStatus.ACTIVE) - run_as_background_process(f"task-{task.action}", self.server_name, wrapper) + self.hs.run_as_background_process(f"task-{task.action}", wrapper) diff --git a/synmark/suites/logging.py b/synmark/suites/logging.py index c3f3cceaa65..cf9c836e062 100644 --- a/synmark/suites/logging.py +++ b/synmark/suites/logging.py @@ -86,7 +86,9 @@ class _logging: hs_config = Config() # To be able to sleep. - clock = Clock(reactor, server_name=hs_config.server.server_name) + # Ignore linter error here since we are running outside of the context of a + # Synapse `HomeServer`. + clock = Clock(reactor, server_name=hs_config.server.server_name) # type: ignore[multiple-internal-clocks] errors = StringIO() publisher = LogPublisher() diff --git a/synmark/suites/lrucache.py b/synmark/suites/lrucache.py index 6314035bd7c..830a3daa8fc 100644 --- a/synmark/suites/lrucache.py +++ b/synmark/suites/lrucache.py @@ -23,14 +23,19 @@ from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import LruCache +from synapse.util.clock import Clock async def main(reactor: ISynapseReactor, loops: int) -> float: """ Benchmark `loops` number of insertions into LruCache without eviction. """ + # Ignore linter error here since we are running outside of the context of a + # Synapse `HomeServer`. 
cache: LruCache[int, bool] = LruCache( - max_size=loops, server_name="synmark_benchmark" + max_size=loops, + clock=Clock(reactor, server_name="synmark_benchmark"), # type: ignore[multiple-internal-clocks] + server_name="synmark_benchmark", ) start = perf_counter() diff --git a/synmark/suites/lrucache_evict.py b/synmark/suites/lrucache_evict.py index b8cd5896970..c67e0c90017 100644 --- a/synmark/suites/lrucache_evict.py +++ b/synmark/suites/lrucache_evict.py @@ -23,6 +23,7 @@ from synapse.types import ISynapseReactor from synapse.util.caches.lrucache import LruCache +from synapse.util.clock import Clock async def main(reactor: ISynapseReactor, loops: int) -> float: @@ -30,8 +31,12 @@ async def main(reactor: ISynapseReactor, loops: int) -> float: Benchmark `loops` number of insertions into LruCache where half of them are evicted. """ + # Ignore linter error here since we are running outside of the context of a + # Synapse `HomeServer`. cache: LruCache[int, bool] = LruCache( - max_size=loops // 2, server_name="synmark_benchmark" + max_size=loops // 2, + clock=Clock(reactor, server_name="synmark_benchmark"), # type: ignore[multiple-internal-clocks] + server_name="synmark_benchmark", ) start = perf_counter()
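The new test file that follows exercises the end state of all of the above: after `HomeServer.shutdown()`, nothing should keep the homeserver alive. Its core assertion pattern, reduced to a self-contained sketch (`Thing` stands in for the real `SynapseHomeServer`):

```python
import gc
import weakref


class Thing:
    """Stand-in for the object under test."""


obj = Thing()
ref = weakref.ref(obj)  # a weak reference does not keep `obj` alive

# ... run shutdown/cleanup paths here ...

del obj       # drop the last strong reference we hold
gc.collect()  # collect any lingering reference cycles
assert ref() is None, "object is still reachable from somewhere"
```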
diff --git a/tests/app/test_homeserver_shutdown.py b/tests/app/test_homeserver_shutdown.py new file mode 100644 index 00000000000..d8119ba3102 --- /dev/null +++ b/tests/app/test_homeserver_shutdown.py @@ -0,0 +1,193 @@ +# +# This file is licensed under the Affero General Public License (AGPL) version 3. +# +# Copyright (C) 2025 New Vector, Ltd +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU Affero General Public License as +# published by the Free Software Foundation, either version 3 of the +# License, or (at your option) any later version. +# +# See the GNU Affero General Public License for more details: +# <https://www.gnu.org/licenses/agpl-3.0.html>. +# +# Originally licensed under the Apache License, Version 2.0: +# <http://www.apache.org/licenses/LICENSE-2.0>. +# +# [This file includes modifications made by New Vector Limited] +# +# +import gc +import weakref + +from synapse.app.homeserver import SynapseHomeServer +from synapse.storage.background_updates import UpdaterStatus + +from tests.server import ( + cleanup_test_reactor_system_event_triggers, + get_clock, + setup_test_homeserver, +) +from tests.unittest import HomeserverTestCase + + +class HomeserverCleanShutdownTestCase(HomeserverTestCase): + def setUp(self) -> None: + pass + + # NOTE: ideally we'd have another test to ensure we properly shutdown with + # real in-flight HTTP requests since those result in additional resources being + # set up that hold strong references to the homeserver. + # Mainly, the HTTP channel created by a real TCP connection from client to server + # is held open between requests and care needs to be taken in Twisted to ensure it is properly + # closed in a timely manner during shutdown. Simulating this behaviour in a unit test + # won't be as good as a proper integration test in Complement. + + def test_clean_homeserver_shutdown(self) -> None: + """Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected""" + self.reactor, self.clock = get_clock() + self.hs = setup_test_homeserver( + cleanup_func=self.addCleanup, + reactor=self.reactor, + homeserver_to_use=SynapseHomeServer, + clock=self.clock, + ) + self.wait_for_background_updates() + + hs_ref = weakref.ref(self.hs) + + # Run the reactor so any `callWhenRunning` functions can be cleared out. + self.reactor.run() + # This would normally happen as part of `HomeServer.shutdown` but the `MemoryReactor` + # we use in tests doesn't handle this properly (see doc comment) + cleanup_test_reactor_system_event_triggers(self.reactor) + + # Cleanup the homeserver. + self.get_success(self.hs.shutdown()) + + # Cleanup the internal reference in our test case + del self.hs + + # Force garbage collection. + gc.collect() + + # Ensure the `HomeServer` has been garbage collected by attempting to use the + # weakref to it. + if hs_ref() is not None: + self.fail("HomeServer reference should not be valid at this point") + + # To help debug this test when it fails, it is useful to leverage the + # `objgraph` module. + # The following code serves as an example of what I have found to be useful + # when tracking down references holding the `SynapseHomeServer` in memory: + # + # all_objects = gc.get_objects() + # for obj in all_objects: + # try: + # # These are a subset of types that are typically involved with + # # holding the `HomeServer` in memory. You may want to inspect + # # other types as well. + # if isinstance(obj, DataStore): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # db_obj = obj + # if isinstance(obj, SynapseHomeServer): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # synapse_hs = obj + # if isinstance(obj, SynapseSite): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # sysite = obj + # if isinstance(obj, DatabasePool): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # dbpool = obj + # except Exception: + # pass + # + # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) + # + # # The following values for `max_depth` and `too_many` have been found to + # # render a useful amount of information without taking an overly long time + # # to generate the result. + # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) + + def test_clean_homeserver_shutdown_mid_background_updates(self) -> None: + """Ensure the `SynapseHomeServer` can be fully shutdown and garbage collected + before background updates have completed""" + self.reactor, self.clock = get_clock() + self.hs = setup_test_homeserver( + cleanup_func=self.addCleanup, + reactor=self.reactor, + homeserver_to_use=SynapseHomeServer, + clock=self.clock, + ) + + # Pump the background updates by a single iteration, just to ensure any extra + # resources it uses have been started. + store = weakref.proxy(self.hs.get_datastores().main) + self.get_success(store.db_pool.updates.do_next_background_update(False), by=0.1) + + hs_ref = weakref.ref(self.hs) + + # Run the reactor so any `callWhenRunning` functions can be cleared out. + self.reactor.run() + # This would normally happen as part of `HomeServer.shutdown` but the `MemoryReactor` + # we use in tests doesn't handle this properly (see doc comment) + cleanup_test_reactor_system_event_triggers(self.reactor) + + # Ensure the background updates are not complete. + self.assertNotEqual(store.db_pool.updates.get_status(), UpdaterStatus.COMPLETE) + + # Cleanup the homeserver. + self.get_success(self.hs.shutdown()) + + # Cleanup the internal reference in our test case + del self.hs + + # Force garbage collection. + gc.collect() + + # Ensure the `HomeServer` has been garbage collected by attempting to use the + # weakref to it.
+ if hs_ref() is not None: + self.fail("HomeServer reference should not be valid at this point") + + # To help debug this test when it fails, it is useful to leverage the + # `objgraph` module. + # The following code serves as an example of what I have found to be useful + # when tracking down references holding the `SynapseHomeServer` in memory: + # + # all_objects = gc.get_objects() + # for obj in all_objects: + # try: + # # These are a subset of types that are typically involved with + # # holding the `HomeServer` in memory. You may want to inspect + # # other types as well. + # if isinstance(obj, DataStore): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # db_obj = obj + # if isinstance(obj, SynapseHomeServer): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # synapse_hs = obj + # if isinstance(obj, SynapseSite): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # sysite = obj + # if isinstance(obj, DatabasePool): + # print(sys.getrefcount(obj), "refs to", obj) + # if not isinstance(obj, weakref.ProxyType): + # dbpool = obj + # except Exception: + # pass + # + # print(sys.getrefcount(hs_ref()), "refs to", hs_ref()) + # + # # The following values for `max_depth` and `too_many` have been found to + # # render a useful amount of information without taking an overly long time + # # to generate the result. + # objgraph.show_backrefs(synapse_hs, max_depth=10, too_many=10) diff --git a/tests/appservice/test_scheduler.py b/tests/appservice/test_scheduler.py index 0385190f349..f4490a1a794 100644 --- a/tests/appservice/test_scheduler.py +++ b/tests/appservice/test_scheduler.py @@ -167,8 +167,9 @@ def test_single_service_up_txn_not_sent(self) -> None: ) -class ApplicationServiceSchedulerRecovererTestCase(unittest.TestCase): +class ApplicationServiceSchedulerRecovererTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: + super().setUp() self.reactor, self.clock = get_clock() self.as_api = Mock() self.store = Mock() @@ -176,6 +177,7 @@ def setUp(self) -> None: self.callback = AsyncMock() self.recoverer = _Recoverer( server_name="test_server", + hs=self.hs, clock=self.clock, as_api=self.as_api, store=self.store, diff --git a/tests/config/test_cache.py b/tests/config/test_cache.py index f56d6044a94..74db2dab087 100644 --- a/tests/config/test_cache.py +++ b/tests/config/test_cache.py @@ -24,6 +24,7 @@ from synapse.types import JsonDict from synapse.util.caches.lrucache import LruCache +from tests.server import get_clock from tests.unittest import TestCase @@ -32,6 +33,7 @@ def setUp(self) -> None: # Reset caches before each test since there's global state involved. self.config = CacheConfig(RootConfig()) self.config.reset() + _, self.clock = get_clock() def tearDown(self) -> None: # Also reset the caches after each test to leave state pristine. @@ -75,7 +77,9 @@ def test_individual_instantiated_before_config_load(self) -> None: the default cache size in the interim, and then resized once the config is loaded. 
""" - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) @@ -96,7 +100,9 @@ def test_individual_instantiated_after_config_load(self) -> None: self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 200) @@ -106,7 +112,9 @@ def test_global_instantiated_before_config_load(self) -> None: the default cache size in the interim, and then resized to the new default cache size once the config is loaded. """ - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 50) @@ -126,7 +134,9 @@ def test_global_instantiated_after_config_load(self) -> None: self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache: LruCache = LruCache(max_size=100, server_name="test_server") + cache: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("foo", cache_resize_callback=cache.set_cache_factor) self.assertEqual(cache.max_size, 150) @@ -145,15 +155,21 @@ def test_cache_with_asterisk_in_name(self) -> None: self.config.read_config(config, config_dir_path="", data_dir_path="") self.config.resize_all_caches() - cache_a: LruCache = LruCache(max_size=100, server_name="test_server") + cache_a: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("*cache_a*", cache_resize_callback=cache_a.set_cache_factor) self.assertEqual(cache_a.max_size, 200) - cache_b: LruCache = LruCache(max_size=100, server_name="test_server") + cache_b: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("*Cache_b*", cache_resize_callback=cache_b.set_cache_factor) self.assertEqual(cache_b.max_size, 300) - cache_c: LruCache = LruCache(max_size=100, server_name="test_server") + cache_c: LruCache = LruCache( + max_size=100, clock=self.clock, server_name="test_server" + ) add_resizable_cache("*cache_c*", cache_resize_callback=cache_c.set_cache_factor) self.assertEqual(cache_c.max_size, 200) @@ -168,6 +184,7 @@ def test_apply_cache_factor_from_config(self) -> None: cache: LruCache = LruCache( max_size=self.config.event_cache_size, + clock=self.clock, apply_cache_factor_from_config=False, server_name="test_server", ) diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py index 6516b7db174..df36185b99e 100644 --- a/tests/handlers/test_appservice.py +++ b/tests/handlers/test_appservice.py @@ -19,7 +19,17 @@ # # -from typing import Dict, Iterable, List, Optional +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Callable, + Dict, + Iterable, + List, + Optional, + TypeVar, +) from unittest.mock import AsyncMock, Mock from parameterized import parameterized @@ -36,6 +46,7 @@ TransactionUnusedFallbackKeys, ) from synapse.handlers.appservice import ApplicationServicesHandler +from 
synapse.metrics.background_process_metrics import run_as_background_process from synapse.rest.client import login, receipts, register, room, sendtodevice from synapse.server import HomeServer from synapse.types import ( @@ -53,6 +64,11 @@ from tests.test_utils import event_injection from tests.unittest import override_config +if TYPE_CHECKING: + from typing_extensions import LiteralString + +R = TypeVar("R") + class AppServiceHandlerTestCase(unittest.TestCase): """Tests the ApplicationServicesHandler.""" @@ -64,6 +80,17 @@ def setUp(self) -> None: self.reactor, self.clock = get_clock() hs = Mock() + + def test_run_as_background_process( + desc: "LiteralString", + func: Callable[..., Awaitable[Optional[R]]], + *args: Any, + **kwargs: Any, + ) -> "defer.Deferred[Optional[R]]": + # Ignore linter error as this is used only for testing purposes (i.e. outside of Synapse). + return run_as_background_process(desc, "test_server", func, *args, **kwargs) # type: ignore[untracked-background-process] + + hs.run_as_background_process = test_run_as_background_process hs.get_datastores.return_value = Mock(main=self.mock_store) self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None) self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None) diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 4d2807151ef..90c185bc3d4 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -79,15 +79,17 @@ def make_homeserver( ) -> HomeServer: # we mock out the keyring so as to skip the authentication check on the # federation API call. - mock_keyring = Mock(spec=["verify_json_for_server"]) + mock_keyring = Mock(spec=["verify_json_for_server", "shutdown"]) mock_keyring.verify_json_for_server = AsyncMock(return_value=True) + mock_keyring.shutdown = Mock() # we mock out the federation client too self.mock_federation_client = AsyncMock(spec=["put_json"]) self.mock_federation_client.put_json.return_value = (200, "OK") self.mock_federation_client.agent = MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", - reactor=reactor, + reactor=self.reactor, + clock=self.clock, tls_client_options_factory=None, user_agent=b"SynapseInTrialTest/0.0.0", ip_allowlist=None, @@ -96,7 +98,7 @@ def make_homeserver( ) # the tests assume that we are starting at unix time 1000 - reactor.pump((1000,)) + self.reactor.pump((1000,)) self.mock_hs_notifier = Mock() hs = self.setup_test_homeserver( diff --git a/tests/http/federation/test_matrix_federation_agent.py b/tests/http/federation/test_matrix_federation_agent.py index a5e1b7c2849..c66ca489a46 100644 --- a/tests/http/federation/test_matrix_federation_agent.py +++ b/tests/http/federation/test_matrix_federation_agent.py @@ -65,7 +65,7 @@ from tests import unittest from tests.http import dummy_address, get_test_ca_cert_file, wrap_server_factory_for_tls -from tests.server import FakeTransport, ThreadedMemoryReactorClock +from tests.server import FakeTransport, get_clock from tests.utils import checked_cast, default_config logger = logging.getLogger(__name__) @@ -73,7 +73,7 @@ class MatrixFederationAgentTests(unittest.TestCase): def setUp(self) -> None: - self.reactor = ThreadedMemoryReactorClock() + self.reactor, self.clock = get_clock() self.mock_resolver = AsyncMock(spec=SrvResolver) @@ -98,6 +98,7 @@ def setUp(self) -> None: self.well_known_resolver = WellKnownResolver( server_name="OUR_STUB_HOMESERVER_NAME", reactor=self.reactor, + clock=self.clock, agent=Agent(self.reactor, contextFactory=self.tls_factory), 
user_agent=b"test-agent", well_known_cache=self.well_known_cache, @@ -280,6 +281,7 @@ def _make_agent(self) -> MatrixFederationAgent: return MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", reactor=cast(ISynapseReactor, self.reactor), + clock=self.clock, tls_client_options_factory=self.tls_factory, user_agent=b"test-agent", # Note that this is unused since _well_known_resolver is provided. ip_allowlist=IPSet(), @@ -1024,6 +1026,7 @@ def test_get_well_known_unsigned_cert(self) -> None: agent = MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", reactor=self.reactor, + clock=self.clock, tls_client_options_factory=tls_factory, user_agent=b"test-agent", # This is unused since _well_known_resolver is passed below. ip_allowlist=IPSet(), @@ -1033,6 +1036,7 @@ def test_get_well_known_unsigned_cert(self) -> None: _well_known_resolver=WellKnownResolver( server_name="OUR_STUB_HOMESERVER_NAME", reactor=cast(ISynapseReactor, self.reactor), + clock=self.clock, agent=Agent(self.reactor, contextFactory=tls_factory), user_agent=b"test-agent", well_known_cache=self.well_known_cache, diff --git a/tests/logging/test_opentracing.py b/tests/logging/test_opentracing.py index 057ca0db456..31cdfacd2cd 100644 --- a/tests/logging/test_opentracing.py +++ b/tests/logging/test_opentracing.py @@ -163,7 +163,9 @@ def test_overlapping_spans(self) -> None: # implements `ISynapseThreadlessReactor` (combination of the normal Twisted # Reactor/Clock interfaces), via inheritance from # `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock` - clock = Clock( + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes. + clock = Clock( # type: ignore[multiple-internal-clocks] reactor, # type: ignore[arg-type] server_name="test_server", ) @@ -234,7 +236,9 @@ def test_run_in_background_active_scope_still_available(self) -> None: # implements `ISynapseThreadlessReactor` (combination of the normal Twisted # Reactor/Clock interfaces), via inheritance from # `twisted.internet.testing.MemoryReactor` and `twisted.internet.testing.Clock` - clock = Clock( + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes. 
+ clock = Clock( # type: ignore[multiple-internal-clocks] reactor, # type: ignore[arg-type] server_name="test_server", ) diff --git a/tests/metrics/test_metrics.py b/tests/metrics/test_metrics.py index 832e9917305..b3f42c76f18 100644 --- a/tests/metrics/test_metrics.py +++ b/tests/metrics/test_metrics.py @@ -164,7 +164,10 @@ def test_cache_metric(self) -> None: """ CACHE_NAME = "cache_metrics_test_fgjkbdfg" cache: DeferredCache[str, str] = DeferredCache( - name=CACHE_NAME, server_name=self.hs.hostname, max_entries=777 + name=CACHE_NAME, + clock=self.hs.get_clock(), + server_name=self.hs.hostname, + max_entries=777, ) metrics_map = get_latest_metrics() @@ -212,10 +215,10 @@ def test_cache_metric_multiple_servers(self) -> None: """ CACHE_NAME = "cache_metric_multiple_servers_test" cache1: DeferredCache[str, str] = DeferredCache( - name=CACHE_NAME, server_name="hs1", max_entries=777 + name=CACHE_NAME, clock=self.clock, server_name="hs1", max_entries=777 ) cache2: DeferredCache[str, str] = DeferredCache( - name=CACHE_NAME, server_name="hs2", max_entries=777 + name=CACHE_NAME, clock=self.clock, server_name="hs2", max_entries=777 ) metrics_map = get_latest_metrics() diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 36d32139088..1a2dab4c7d7 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -173,7 +173,13 @@ def handle_http_replication_attempt(self) -> SynapseRequest: # Set up the server side protocol server_address = IPv4Address("TCP", host, port) - channel = self.site.buildProtocol((host, port)) + # The type ignore is here because mypy doesn't think the host/port tuple is of + # the correct type, even though it is the exact example given for + # `twisted.internet.interfaces.IAddress`. + # Mypy was happy with the type before we overrode `buildProtocol` in + # `SynapseSite`, probably because there was enough inheritance indirection before, + # with the argument not having a type associated with it. + channel = self.site.buildProtocol((host, port)) # type: ignore[arg-type] # hook into the channel's request factory so that we can keep a record # of the requests @@ -185,7 +191,7 @@ def request_factory(*args: Any, **kwargs: Any) -> SynapseRequest: requests.append(request) return request - channel.requestFactory = request_factory + channel.requestFactory = request_factory # type: ignore[method-assign] # Connect client to server and vice versa. client_to_server_transport = FakeTransport( @@ -427,7 +433,7 @@ def _handle_http_replication_attempt(self, hs: HomeServer, repl_port: int) -> No # Set up the server side protocol server_address = IPv4Address("TCP", host, port) - channel = self._hs_to_site[hs].buildProtocol((host, port)) + channel = self._hs_to_site[hs].buildProtocol((host, port)) # type: ignore[arg-type] # Connect client to server and vice versa.
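The `DeferredCache` hunks above all follow the same shape: the cache now takes the homeserver's `Clock` so that its timers can be tracked. A minimal sketch, assuming the Synapse test environment (the cache name is made up):

```python
from synapse.util.caches.deferred_cache import DeferredCache
from tests.server import get_clock

# `get_clock()` hands back a memory reactor plus a homeserver-style Clock.
_, clock = get_clock()

cache: DeferredCache[str, str] = DeferredCache(
    name="example_cache",  # hypothetical name
    clock=clock,
    server_name="test_server",
    max_entries=777,
)
cache.prefill("foo", "bar")
assert cache.get_immediate("foo", None) == "bar"
```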
client_to_server_transport = FakeTransport( diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 92259f2542a..3896e0ce8a0 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -66,10 +66,11 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): def setUp(self) -> None: super().setUp() - reactor, _ = get_clock() + reactor, clock = get_clock() self.matrix_federation_agent = MatrixFederationAgent( server_name="OUR_STUB_HOMESERVER_NAME", reactor=reactor, + clock=clock, tls_client_options_factory=None, user_agent=b"SynapseInTrialTest/0.0.0", ip_allowlist=None, diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py index 8d5d0cce9a6..1cb898673bc 100644 --- a/tests/replication/test_module_cache_invalidation.py +++ b/tests/replication/test_module_cache_invalidation.py @@ -24,6 +24,7 @@ from synapse.module_api import cached from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import get_clock logger = logging.getLogger(__name__) @@ -36,6 +37,7 @@ class TestCache: current_value = FIRST_VALUE server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() async def cached_function(self, user_id: str) -> str: diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py index c22c1a6612c..bb83988d768 100644 --- a/tests/rest/client/test_transactions.py +++ b/tests/rest/client/test_transactions.py @@ -93,8 +93,10 @@ def test_logcontexts_with_async_result( ) -> Generator["defer.Deferred[Any]", object, None]: @defer.inlineCallbacks def cb() -> Generator["defer.Deferred[object]", object, Tuple[int, JsonDict]]: + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes. yield defer.ensureDeferred( - Clock(reactor, server_name="test_server").sleep(0) + Clock(reactor, server_name="test_server").sleep(0) # type: ignore[multiple-internal-clocks] ) return 1, {} diff --git a/tests/server.py b/tests/server.py index 226bdf4bbe9..a9a53eb8a42 100644 --- a/tests/server.py +++ b/tests/server.py @@ -28,6 +28,7 @@ import time import uuid import warnings +import weakref from collections import deque from io import SEEK_END, BytesIO from typing import ( @@ -56,7 +57,7 @@ import twisted from twisted.enterprise import adbapi -from twisted.internet import address, tcp, threads, udp +from twisted.internet import address, defer, tcp, threads, udp from twisted.internet._resolver import SimpleResolverComplexifier from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed @@ -524,6 +525,19 @@ def getHostByName( # overwrite it again. self.nameResolver = SimpleResolverComplexifier(FakeResolver()) + def run(self) -> None: + """ + Override the call from `MemoryReactorClock` to add an additional step that + cleans up any `whenRunningHooks` that have been called. + This is necessary for a clean shutdown to occur as these hooks can hold + references to the `SynapseHomeServer`. + """ + super().run() + + # `MemoryReactorClock` never clears the hooks that have already been called. + # So manually clear the hooks here after they have been run. 
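The `run()` override being added to the test reactor is worth seeing in isolation. Roughly, and assuming Twisted's `MemoryReactorClock` (the subclass name below is hypothetical):

```python
from twisted.internet.testing import MemoryReactorClock


class ShutdownFriendlyMemoryReactor(MemoryReactorClock):
    """Hypothetical name; mirrors the override in the hunk above."""

    def run(self) -> None:
        # Let MemoryReactorClock mark itself as running and fire the hooks.
        super().run()
        # MemoryReactorClock never clears hooks that have already fired, and
        # their closures can keep the homeserver alive; drop them here.
        self.whenRunningHooks.clear()
```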
+ self.whenRunningHooks.clear() + def installNameResolver(self, resolver: IHostnameResolver) -> IHostnameResolver: raise NotImplementedError() @@ -649,6 +663,19 @@ def advance(self, amount: float) -> None: super().advance(0) +def cleanup_test_reactor_system_event_triggers( + reactor: ThreadedMemoryReactorClock, +) -> None: + """Clean up any registered system event triggers. + The `twisted.internet.test.ThreadedMemoryReactor` does not implement + `removeSystemEventTrigger`, so it won't clean these triggers up on its own properly. + When trying to override `removeSystemEventTrigger` in `ThreadedMemoryReactorClock` + in order to implement this functionality, twisted complains about the reactor being + unclean and fails some tests. + """ + reactor.triggers.clear() + + def validate_connector(connector: tcp.Connector, expected_ip: str) -> None: """Try to validate the obtained connector as it would happen when synapse is running and the connection will be established. @@ -780,13 +807,18 @@ def _(res: Any) -> None: d: "Deferred[None]" = Deferred() d.addCallback(lambda x: function(*args, **kwargs)) d.addBoth(_) - self._reactor.callLater(0, d.callback, True) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for homeserver shutdown doesn't make sense. + self._reactor.callLater(0, d.callback, True) # type: ignore[call-later-not-tracked] return d def get_clock() -> Tuple[ThreadedMemoryReactorClock, Clock]: + # Ignore the linter error since this is an expected usage of creating a `Clock` for + # testing purposes. reactor = ThreadedMemoryReactorClock() - hs_clock = Clock(reactor, server_name="test_server") + hs_clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] return reactor, hs_clock @@ -898,10 +930,16 @@ def _produce() -> None: # some implementations of IProducer (for example, FileSender) # don't return a deferred. d = maybeDeferred(self.producer.resumeProducing) - d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for homeserver shutdown doesn't make sense. + d.addCallback(lambda x: self._reactor.callLater(0.1, _produce)) # type: ignore[call-later-not-tracked,call-overload] if not streaming: - self._reactor.callLater(0.0, _produce) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for homeserver shutdown doesn't make sense. + self._reactor.callLater(0.0, _produce) # type: ignore[call-later-not-tracked] def write(self, byt: bytes) -> None: if self.disconnecting: @@ -913,7 +951,10 @@ def write(self, byt: bytes) -> None: # TLSMemoryBIOProtocol) get very confused if a read comes back while they are # still doing a write. Doing a callLater here breaks the cycle. if self.autoflush: - self._reactor.callLater(0.0, self.flush) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for homeserver shutdown doesn't make sense.
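Taken together, `get_clock()` gives tests a paired memory reactor and tracked `Clock`; delayed calls scheduled through the `Clock` are then driven by pumping the reactor. A minimal sketch, assuming the Synapse test environment:

```python
from tests.server import get_clock

reactor, clock = get_clock()

fired = []
# Scheduling via the Synapse Clock keeps the delayed call tracked for
# shutdown; the memory reactor's `pump` advances fake time to fire it.
clock.call_later(1.0, fired.append, "done")
reactor.pump((1.0,))
assert fired == ["done"]
```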
+ self._reactor.callLater(0.0, self.flush) # type: ignore[call-later-not-tracked] def writeSequence(self, seq: Iterable[bytes]) -> None: for x in seq: @@ -943,7 +984,10 @@ def flush(self, maxbytes: Optional[int] = None) -> None: self.buffer = self.buffer[len(to_write) :] if self.buffer and self.autoflush: - self._reactor.callLater(0.0, self.flush) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for homeserver shutdown doesn't make sense. + self._reactor.callLater(0.0, self.flush) # type: ignore[call-later-not-tracked] if not self.buffer and self.disconnecting: logger.info("FakeTransport: Buffer now empty, completing disconnect") @@ -1020,7 +1064,7 @@ class TestHomeServer(HomeServer): def setup_test_homeserver( *, - cleanup_func: Callable[[Callable[[], None]], None], + cleanup_func: Callable[[Callable[[], Optional["Deferred[None]"]]], None], server_name: str = "test", config: Optional[HomeServerConfig] = None, reactor: Optional[ISynapseReactor] = None, @@ -1035,8 +1079,10 @@ def setup_test_homeserver( If no datastore is supplied, one is created and given to the homeserver. Args: - cleanup_func: The function used to register a cleanup routine for after the - test. + cleanup_func: The function used to register a cleanup routine for + after the test. If the function returns a Deferred, the + test case will wait until the Deferred has fired before + proceeding to the next cleanup function. server_name: Homeserver name config: Homeserver config reactor: Twisted reactor @@ -1062,7 +1108,9 @@ def setup_test_homeserver( raise ConfigError("Must be a string", ("server_name",)) if "clock" not in extra_homeserver_attributes: - extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name) + # Ignore `multiple-internal-clocks` linter error here since we are creating a `Clock` + # for testing purposes (i.e. outside of Synapse). + extra_homeserver_attributes["clock"] = Clock(reactor, server_name=server_name) # type: ignore[multiple-internal-clocks] config.caches.resize_all_caches() @@ -1154,8 +1202,21 @@ def setup_test_homeserver( reactor=reactor, ) - # Register the cleanup hook - cleanup_func(hs.cleanup) + # Capture the `hs` as a `weakref` here to ensure there is no scenario where uncalled + # cleanup functions result in holding the `hs` in memory. + cleanup_hs_ref = weakref.ref(hs) + + def shutdown_hs_on_cleanup() -> "Deferred[None]": + cleanup_hs = cleanup_hs_ref() + deferred: "Deferred[None]" = defer.succeed(None) + if cleanup_hs is not None: + deferred = defer.ensureDeferred(cleanup_hs.shutdown()) + return deferred + + # Register the cleanup hook for the homeserver. + # A full `hs.shutdown()` is necessary, otherwise CI tests will fail while exhibiting + # strange behaviours. + cleanup_func(shutdown_hs_on_cleanup) # Install @cache_in_self attributes for key, val in extra_homeserver_attributes.items(): @@ -1184,14 +1245,18 @@ def setup_test_homeserver( hs.get_datastores().main.USE_DEDICATED_DB_THREADS_FOR_EVENT_FETCHING = False if USE_POSTGRES_FOR_TESTS: - database_pool = hs.get_datastores().databases[0] + # Capture the `database_pool` as a `weakref` here to ensure there is no scenario where uncalled + # cleanup functions result in holding the `hs` in memory.
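The weakref-based cleanup registration above generalises to a small pattern that avoids pinning the homeserver in memory through uncalled cleanup closures. A sketch under the same assumptions as the diff (`register_shutdown_cleanup` is a hypothetical name):

```python
import weakref
from typing import Any, Callable

from twisted.internet import defer
from twisted.internet.defer import Deferred


def register_shutdown_cleanup(
    hs: Any, cleanup_func: Callable[[Callable[[], "Deferred[None]"]], None]
) -> None:
    """Illustrative only: register a cleanup that shuts the hs down without
    the closure itself keeping the hs alive."""
    hs_ref = weakref.ref(hs)

    def _shutdown() -> "Deferred[None]":
        target = hs_ref()
        if target is None:
            # Already garbage collected; nothing to shut down.
            return defer.succeed(None)
        return defer.ensureDeferred(target.shutdown())

    cleanup_func(_shutdown)
```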
+ database_pool = weakref.ref(hs.get_datastores().databases[0]) # We need to do cleanup on PostgreSQL def cleanup() -> None: import psycopg2 # Close all the db pools - database_pool._db_pool.close() + db_pool = database_pool() + if db_pool is not None: + db_pool._db_pool.close() dropped = False diff --git a/tests/test_distributor.py b/tests/test_distributor.py index 19dafe64ed4..2dd26833c8d 100644 --- a/tests/test_distributor.py +++ b/tests/test_distributor.py @@ -26,9 +26,10 @@ from . import unittest -class DistributorTestCase(unittest.TestCase): +class DistributorTestCase(unittest.HomeserverTestCase): def setUp(self) -> None: - self.dist = Distributor(server_name="test_server") + super().setUp() + self.dist = Distributor(hs=self.hs) def test_signal_dispatch(self) -> None: self.dist.declare("alert") diff --git a/tests/util/caches/test_deferred_cache.py b/tests/util/caches/test_deferred_cache.py index 7017d6d70ac..f0deb1554ef 100644 --- a/tests/util/caches/test_deferred_cache.py +++ b/tests/util/caches/test_deferred_cache.py @@ -26,20 +26,26 @@ from synapse.util.caches.deferred_cache import DeferredCache +from tests.server import get_clock from tests.unittest import TestCase class DeferredCacheTestCase(TestCase): + def setUp(self) -> None: + super().setUp() + + _, self.clock = get_clock() + def test_empty(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) with self.assertRaises(KeyError): cache.get("foo") def test_hit(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) cache.prefill("foo", 123) @@ -47,7 +53,7 @@ def test_hit(self) -> None: def test_hit_deferred(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) origin_d: "defer.Deferred[int]" = defer.Deferred() set_d = cache.set("k1", origin_d) @@ -72,7 +78,7 @@ def check1(r: str) -> str: def test_callbacks(self) -> None: """Invalidation callbacks are called at the right time""" cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) callbacks = set() @@ -107,7 +113,7 @@ def test_callbacks(self) -> None: def test_set_fail(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) callbacks = set() @@ -146,7 +152,7 @@ def test_set_fail(self) -> None: def test_get_immediate(self) -> None: cache: DeferredCache[str, int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) d1: "defer.Deferred[int]" = defer.Deferred() cache.set("key1", d1) @@ -164,7 +170,7 @@ def test_get_immediate(self) -> None: def test_invalidate(self) -> None: cache: DeferredCache[Tuple[str], int] = DeferredCache( - name="test", server_name="test_server" + name="test", clock=self.clock, server_name="test_server" ) cache.prefill(("foo",), 123) cache.invalidate(("foo",)) @@ -174,7 +180,7 @@ def test_invalidate(self) -> None: def test_invalidate_all(self) -> None: cache: DeferredCache[str, str] = DeferredCache( - name="testcache", server_name="test_server" + name="testcache", clock=self.clock, server_name="test_server" ) callback_record = [False, False] @@ -220,6 +226,7 @@ def 
record_callback(idx: int) -> None: def test_eviction(self) -> None: cache: DeferredCache[int, str] = DeferredCache( name="test", + clock=self.clock, server_name="test_server", max_entries=2, apply_cache_factor_from_config=False, @@ -238,6 +245,7 @@ def test_eviction(self) -> None: cache: DeferredCache[int, str] = DeferredCache( name="test", + clock=self.clock, server_name="test_server", max_entries=2, apply_cache_factor_from_config=False, @@ -260,6 +268,7 @@ def test_eviction_lru(self) -> None: cache: DeferredCache[int, List[str]] = DeferredCache( name="test", + clock=self.clock, server_name="test_server", max_entries=3, apply_cache_factor_from_config=False, diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py index 3eb502f9023..0e3b6ae36b7 100644 --- a/tests/util/caches/test_descriptors.py +++ b/tests/util/caches/test_descriptors.py @@ -49,6 +49,7 @@ from synapse.util.caches.descriptors import _CacheContext, cached, cachedList from tests import unittest +from tests.server import get_clock from tests.test_utils import get_awaitable_result logger = logging.getLogger(__name__) @@ -56,7 +57,10 @@ def run_on_reactor() -> "Deferred[int]": d: "Deferred[int]" = Deferred() - cast(IReactorTime, reactor).callLater(0, d.callback, 0) + # mypy ignored here because: + # - this is part of the test infrastructure (outside of Synapse) so tracking + # these calls for homeserver shutdown doesn't make sense. + cast(IReactorTime, reactor).callLater(0, d.callback, 0) # type: ignore[call-later-not-tracked] return make_deferred_yieldable(d) @@ -67,6 +71,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int) -> str: @@ -102,6 +107,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached(num_args=1) def fn(self, arg1: int, arg2: int) -> str: @@ -148,6 +154,7 @@ def fn(self, arg1: int, arg2: int, arg3: int) -> str: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached obj = Cls() obj.mock.return_value = "fish" @@ -179,6 +186,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, kwarg1: int = 2) -> str: @@ -214,6 +222,7 @@ def test_cache_with_sync_exception(self) -> None: class Cls: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> NoReturn: @@ -239,6 +248,7 @@ class Cls: result: Optional[Deferred] = None call_count = 0 server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> Deferred: @@ -293,6 +303,7 @@ def test_cache_logcontexts(self) -> Deferred: class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int) -> "Deferred[int]": @@ -337,6 +348,7 @@ def test_cache_logcontexts_with_exception(self) -> "Deferred[None]": class Cls: server_name =
"test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int) -> Deferred: @@ -381,6 +393,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int = 2, arg3: int = 3) -> str: @@ -419,6 +432,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached(iterable=True) def fn(self, arg1: int, arg2: int) -> Tuple[str, ...]: @@ -453,6 +467,7 @@ def test_cache_iterable_with_sync_exception(self) -> None: class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached(iterable=True) def fn(self, arg1: int) -> NoReturn: @@ -476,6 +491,7 @@ def test_invalidate_cascade(self) -> None: class Cls: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached(cache_context=True) async def func1(self, key: str, cache_context: _CacheContext) -> int: @@ -504,6 +520,7 @@ def test_cancel(self) -> None: class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @cached() async def fn(self, arg1: int) -> str: @@ -537,6 +554,7 @@ def test_cancel_logcontexts(self) -> None: class Cls: inner_context_was_finished = False server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() async def fn(self, arg1: int) -> str: @@ -583,6 +601,7 @@ class CacheDecoratorTestCase(unittest.HomeserverTestCase): def test_passthrough(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -599,6 +618,7 @@ def test_hit(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -619,6 +639,7 @@ def test_invalidate(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -639,6 +660,7 @@ def func(self, key: str) -> str: def test_invalidate_missing(self) -> None: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -652,6 +674,7 @@ def test_max_entries(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached(max_entries=10) def func(self, key: int) -> int: @@ -681,6 +704,7 @@ def test_prefill(self) -> None: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> "Deferred[int]": @@ -701,6 +725,7 @@ def test_invalidate_context(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, 
clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -736,6 +761,7 @@ def test_eviction_context(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached(max_entries=2) def func(self, key: str) -> str: @@ -775,6 +801,7 @@ def test_double_get(self) -> Generator["Deferred[Any]", object, None]: class A: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def func(self, key: str) -> str: @@ -824,6 +851,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int) -> None: @@ -890,6 +918,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int) -> None: @@ -934,6 +963,7 @@ class Cls: def __init__(self) -> None: self.mock = mock.Mock() self.server_name = "test_server" + _, self.clock = get_clock() # nb must be called this for @cached @descriptors.cached() def fn(self, arg1: int, arg2: int) -> None: @@ -975,6 +1005,7 @@ def test_cancel(self) -> None: class Cls: server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> None: @@ -1011,6 +1042,7 @@ def test_cancel_logcontexts(self) -> None: class Cls: inner_context_was_finished = False server_name = "test_server" # nb must be called this for @cached + _, clock = get_clock() # nb must be called this for @cached @cached() def fn(self, arg1: int) -> None: @@ -1055,6 +1087,7 @@ def test_num_args_mismatch(self) -> None: class Cls: server_name = "test_server" + _, clock = get_clock() # nb must be called this for @cached @descriptors.cached(tree=True) def fn(self, room_id: str, event_id: str) -> None: diff --git a/tests/util/test_async_helpers.py b/tests/util/test_async_helpers.py index 54f7b555117..fd8d576aea8 100644 --- a/tests/util/test_async_helpers.py +++ b/tests/util/test_async_helpers.py @@ -25,7 +25,6 @@ from twisted.internet import defer from twisted.internet.defer import CancelledError, Deferred, ensureDeferred -from twisted.internet.task import Clock from twisted.python.failure import Failure from synapse.logging.context import ( @@ -152,7 +151,7 @@ def test_cancellation(self) -> None: class TimeoutDeferredTest(TestCase): def setUp(self) -> None: - self.clock = Clock() + self.reactor, self.clock = get_clock() def test_times_out(self) -> None: """Basic test case that checks that the original deferred is cancelled and that @@ -165,12 +164,16 @@ def canceller(_d: Deferred) -> None: cancelled = True non_completing_d: Deferred = Deferred(canceller) - timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + timing_out_d = timeout_deferred( + deferred=non_completing_d, + timeout=1.0, + clock=self.clock, + ) self.assertNoResult(timing_out_d) self.assertFalse(cancelled, "deferred was cancelled prematurely") - self.clock.pump((1.0,)) + self.reactor.pump((1.0,)) self.assertTrue(cancelled, "deferred was not cancelled by timeout") self.failureResultOf(timing_out_d, defer.TimeoutError) @@ -183,11 +186,15 @@ def canceller(_d: Deferred) -> None: raise 
Exception("can't cancel this deferred") non_completing_d: Deferred = Deferred(canceller) - timing_out_d = timeout_deferred(non_completing_d, 1.0, self.clock) + timing_out_d = timeout_deferred( + deferred=non_completing_d, + timeout=1.0, + clock=self.clock, + ) self.assertNoResult(timing_out_d) - self.clock.pump((1.0,)) + self.reactor.pump((1.0,)) self.failureResultOf(timing_out_d, defer.TimeoutError) @@ -227,7 +234,7 @@ def mark_was_cancelled(res: Failure) -> None: timing_out_d = timeout_deferred( deferred=incomplete_d, timeout=1.0, - reactor=self.clock, + clock=self.clock, ) self.assertNoResult(timing_out_d) # We should still be in the logcontext we started in @@ -243,7 +250,7 @@ def mark_was_cancelled(res: Failure) -> None: # we're pumping the reactor in the block and return us back to our current # logcontext after the block. with PreserveLoggingContext(): - self.clock.pump( + self.reactor.pump( # We only need to pump `1.0` (seconds) as we set # `timeout_deferred(timeout=1.0)` above (1.0,) @@ -264,7 +271,7 @@ def mark_was_cancelled(res: Failure) -> None: self.assertEqual(current_context(), SENTINEL_CONTEXT) -class _TestException(Exception): +class _TestException(Exception): # pass @@ -560,8 +567,8 @@ class AwakenableSleeperTests(TestCase): "Tests AwakenableSleeper" def test_sleep(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d = defer.ensureDeferred(sleeper.sleep("name", 1000)) @@ -575,8 +582,8 @@ def test_sleep(self) -> None: self.assertTrue(d.called) def test_explicit_wake(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d = defer.ensureDeferred(sleeper.sleep("name", 1000)) @@ -592,8 +599,8 @@ def test_explicit_wake(self) -> None: reactor.advance(0.6) def test_multiple_sleepers_timeout(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d1 = defer.ensureDeferred(sleeper.sleep("name", 1000)) @@ -612,8 +619,8 @@ def test_multiple_sleepers_timeout(self) -> None: self.assertTrue(d2.called) def test_multiple_sleepers_wake(self) -> None: - reactor, _ = get_clock() - sleeper = AwakenableSleeper(reactor) + reactor, clock = get_clock() + sleeper = AwakenableSleeper(clock) d1 = defer.ensureDeferred(sleeper.sleep("name", 1000)) diff --git a/tests/util/test_batching_queue.py b/tests/util/test_batching_queue.py index 532582cf877..60bfdf38aaa 100644 --- a/tests/util/test_batching_queue.py +++ b/tests/util/test_batching_queue.py @@ -32,13 +32,12 @@ number_queued, ) -from tests.server import get_clock -from tests.unittest import TestCase +from tests.unittest import HomeserverTestCase -class BatchingQueueTestCase(TestCase): +class BatchingQueueTestCase(HomeserverTestCase): def setUp(self) -> None: - self.clock, hs_clock = get_clock() + super().setUp() # We ensure that we remove any existing metrics for "test_queue". try: @@ -51,8 +50,8 @@ def setUp(self) -> None: self._pending_calls: List[Tuple[List[str], defer.Deferred]] = [] self.queue: BatchingQueue[str, str] = BatchingQueue( name="test_queue", - server_name="test_server", - clock=hs_clock, + hs=self.hs, + clock=self.clock, process_batch_callback=self._process_queue, ) @@ -108,7 +107,7 @@ def test_simple(self) -> None: self.assertFalse(queue_d.called) # We should see a call to `_process_queue` after a reactor tick. 
- self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo"]) @@ -134,7 +133,7 @@ def test_batching(self) -> None: self._assert_metrics(queued=2, keys=1, in_flight=2) - self.clock.pump([0]) + self.reactor.pump([0]) # We should see only *one* call to `_process_queue` self.assertEqual(len(self._pending_calls), 1) @@ -158,7 +157,7 @@ def test_queuing(self) -> None: self.assertFalse(self._pending_calls) queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1")) - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) @@ -185,7 +184,7 @@ def test_queuing(self) -> None: self._assert_metrics(queued=2, keys=1, in_flight=2) # We should now see a second call to `_process_queue` - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo2", "foo3"]) self.assertFalse(queue_d2.called) @@ -206,9 +205,9 @@ def test_different_keys(self) -> None: self.assertFalse(self._pending_calls) queue_d1 = defer.ensureDeferred(self.queue.add_to_queue("foo1", key=1)) - self.clock.pump([0]) + self.reactor.pump([0]) queue_d2 = defer.ensureDeferred(self.queue.add_to_queue("foo2", key=2)) - self.clock.pump([0]) + self.reactor.pump([0]) # We queue up another item with key=2 to check that we will keep taking # things off the queue. @@ -240,7 +239,7 @@ def test_different_keys(self) -> None: self.assertFalse(queue_d3.called) # We should now see a call `_pending_calls` for `foo3` - self.clock.pump([0]) + self.reactor.pump([0]) self.assertEqual(len(self._pending_calls), 1) self.assertEqual(self._pending_calls[0][0], ["foo3"]) self.assertFalse(queue_d3.called) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 246e18fd155..16e096a4b25 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -23,12 +23,14 @@ from synapse.util.caches.dictionary_cache import DictionaryCache from tests import unittest +from tests.server import get_clock class DictCacheTestCase(unittest.TestCase): def setUp(self) -> None: + _, clock = get_clock() self.cache: DictionaryCache[str, str, str] = DictionaryCache( - name="foobar", server_name="test_server", max_entries=10 + name="foobar", clock=clock, server_name="test_server", max_entries=10 ) def test_simple_cache_hit_full(self) -> None: diff --git a/tests/util/test_expiring_cache.py b/tests/util/test_expiring_cache.py index eda2d586f63..35c0f02e3fb 100644 --- a/tests/util/test_expiring_cache.py +++ b/tests/util/test_expiring_cache.py @@ -34,6 +34,7 @@ def test_get_set(self) -> None: cache: ExpiringCache[str, str] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, max_len=1, ) @@ -47,6 +48,7 @@ def test_eviction(self) -> None: cache: ExpiringCache[str, str] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, max_len=2, ) @@ -66,6 +68,7 @@ def test_iterable_eviction(self) -> None: cache: ExpiringCache[str, List[int]] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, max_len=5, iterable=True, @@ -90,6 +93,7 @@ def test_time_eviction(self) -> None: cache: ExpiringCache[str, int] = ExpiringCache( cache_name="test", server_name="testserver", + hs=self.hs, clock=clock, expiry_ms=1000, ) diff --git a/tests/util/test_logcontext.py b/tests/util/test_logcontext.py index 966ea31f1a8..ca805bb20a0 100644 --- a/tests/util/test_logcontext.py +++ 
b/tests/util/test_logcontext.py @@ -66,7 +66,8 @@ async def test_sleep(self) -> None: """ Test `Clock.sleep` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -90,7 +91,7 @@ async def competing_callback() -> None: # so that the test can complete and we see the underlying error. callback_finished = True - reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) + reactor.callLater(0, lambda: defer.ensureDeferred(competing_callback())) # type: ignore[call-later-not-tracked] with LoggingContext(name="foo", server_name="test_server"): await clock.sleep(0) @@ -111,7 +112,8 @@ async def test_looping_call(self) -> None: """ Test `Clock.looping_call` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -161,7 +163,8 @@ async def test_looping_call_now(self) -> None: """ Test `Clock.looping_call_now` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -209,7 +212,8 @@ async def test_call_later(self) -> None: """ Test `Clock.call_later` """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -261,7 +265,8 @@ async def test_deferred_callback_await_in_current_logcontext(self) -> None: `d.callback(None)` without anything else. See the *Deferred callbacks* section of docs/log_contexts.md for more details. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -318,7 +323,8 @@ async def test_deferred_callback_preserve_logging_context(self) -> None: `d.callback(None)` without anything else. See the *Deferred callbacks* section of docs/log_contexts.md for more details. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -379,7 +385,8 @@ async def test_deferred_callback_fire_and_forget_with_current_context(self) -> N `d.callback(None)` without anything else. See the *Deferred callbacks* section of docs/log_contexts.md for more details. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. 
+ clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -450,7 +457,8 @@ async def competing_callback() -> None: self._check_test_key("sentinel") async def _test_run_in_background(self, function: Callable[[], object]) -> None: - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -492,7 +500,8 @@ def callback(result: object) -> object: @logcontext_clean async def test_run_in_background_with_blocking_fn(self) -> None: async def blocking_function() -> None: - await Clock(reactor, server_name="test_server").sleep(0) + # Ignore linter error since we are creating a `Clock` for testing purposes. + await Clock(reactor, server_name="test_server").sleep(0) # type: ignore[multiple-internal-clocks] await self._test_run_in_background(blocking_function) @@ -525,7 +534,8 @@ async def test_run_in_background_with_coroutine(self) -> None: async def testfunc() -> None: self._check_test_key("foo") - d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0)) + # Ignore linter error since we are creating a `Clock` for testing purposes. + d = defer.ensureDeferred(Clock(reactor, server_name="test_server").sleep(0)) # type: ignore[multiple-internal-clocks] self.assertIs(current_context(), SENTINEL_CONTEXT) await d self._check_test_key("foo") @@ -554,7 +564,8 @@ async def test_run_coroutine_in_background(self) -> None: This will stress the logic around incomplete deferreds in `run_coroutine_in_background`. """ - clock = Clock(reactor, server_name="test_server") + # Ignore linter error since we are creating a `Clock` for testing purposes. + clock = Clock(reactor, server_name="test_server") # type: ignore[multiple-internal-clocks] # Sanity check that we start in the sentinel context self._check_test_key("sentinel") @@ -645,7 +656,7 @@ def test_make_deferred_yieldable( # the synapse rules. 
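The repeated `Clock(reactor, server_name=...)` constructions in these logcontext tests share one idea: a test-local `Clock` needs the plugin suppression, and `Clock.sleep` must restore the calling logcontext when it resumes. A minimal sketch, assuming the import path the plugin checks (`synapse.util.clock`):

```python
from twisted.internet import reactor

from synapse.logging.context import LoggingContext
from synapse.util.clock import Clock


async def sleep_in_context() -> None:
    # Test-only Clock, hence the suppression of the plugin check.
    clock = Clock(reactor, server_name="test_server")  # type: ignore[multiple-internal-clocks]
    with LoggingContext(name="foo", server_name="test_server"):
        # `Clock.sleep` should resume in the "foo" context, not the sentinel.
        await clock.sleep(0)
```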
def blocking_function() -> defer.Deferred: d: defer.Deferred = defer.Deferred() - reactor.callLater(0, d.callback, None) + reactor.callLater(0, d.callback, None) # type: ignore[call-later-not-tracked] return d sentinel_context = current_context() @@ -692,7 +703,7 @@ def _chained_deferred_function() -> defer.Deferred: def cb(res: object) -> defer.Deferred: d2: defer.Deferred = defer.Deferred() - reactor.callLater(0, d2.callback, res) + reactor.callLater(0, d2.callback, res) # type: ignore[call-later-not-tracked] return d2 d.addCallback(cb) diff --git a/tests/util/test_lrucache.py b/tests/util/test_lrucache.py index 4d37ad0975a..56e9996b005 100644 --- a/tests/util/test_lrucache.py +++ b/tests/util/test_lrucache.py @@ -29,18 +29,28 @@ from synapse.util.caches.treecache import TreeCache from tests import unittest +from tests.server import get_clock from tests.unittest import override_config class LruCacheTestCase(unittest.HomeserverTestCase): + def setUp(self) -> None: + super().setUp() + + _, self.clock = get_clock() + def test_get_set(self) -> None: - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache["key"] = "value" self.assertEqual(cache.get("key"), "value") self.assertEqual(cache["key"], "value") def test_eviction(self) -> None: - cache: LruCache[int, int] = LruCache(max_size=2, server_name="test_server") + cache: LruCache[int, int] = LruCache( + max_size=2, clock=self.clock, server_name="test_server" + ) cache[1] = 1 cache[2] = 2 @@ -54,7 +64,9 @@ def test_eviction(self) -> None: self.assertEqual(cache.get(3), 3) def test_setdefault(self) -> None: - cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, int] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) self.assertEqual(cache.setdefault("key", 1), 1) self.assertEqual(cache.get("key"), 1) self.assertEqual(cache.setdefault("key", 2), 1) @@ -63,7 +75,9 @@ def test_setdefault(self) -> None: self.assertEqual(cache.get("key"), 2) def test_pop(self) -> None: - cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, int] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache["key"] = 1 self.assertEqual(cache.pop("key"), 1) self.assertEqual(cache.pop("key"), None) @@ -71,7 +85,10 @@ def test_pop(self) -> None: def test_del_multi(self) -> None: # The type here isn't quite correct as they don't handle TreeCache well. cache: LruCache[Tuple[str, str], str] = LruCache( - max_size=4, cache_type=TreeCache, server_name="test_server" + max_size=4, + clock=self.clock, + cache_type=TreeCache, + server_name="test_server", ) cache[("animal", "cat")] = "mew" cache[("animal", "dog")] = "woof" @@ -91,7 +108,9 @@ def test_del_multi(self) -> None: # Man from del_multi say "Yes". 
def test_clear(self) -> None: - cache: LruCache[str, int] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, int] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache["key"] = 1 cache.clear() self.assertEqual(len(cache), 0) @@ -99,7 +118,10 @@ def test_clear(self) -> None: @override_config({"caches": {"per_cache_factors": {"mycache": 10}}}) def test_special_size(self) -> None: cache: LruCache = LruCache( - max_size=10, server_name="test_server", cache_name="mycache" + max_size=10, + clock=self.clock, + server_name="test_server", + cache_name="mycache", ) self.assertEqual(cache.max_size, 100) @@ -107,7 +129,9 @@ def test_special_size(self) -> None: class LruCacheCallbacksTestCase(unittest.HomeserverTestCase): def test_get(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value") self.assertFalse(m.called) @@ -126,7 +150,9 @@ def test_get(self) -> None: def test_multi_get(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value") self.assertFalse(m.called) @@ -145,7 +171,9 @@ def test_multi_get(self) -> None: def test_set(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -161,7 +189,9 @@ def test_set(self) -> None: def test_pop(self) -> None: m = Mock() - cache: LruCache[str, str] = LruCache(max_size=1, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=1, clock=self.clock, server_name="test_server" + ) cache.set("key", "value", callbacks=[m]) self.assertFalse(m.called) @@ -182,7 +212,10 @@ def test_del_multi(self) -> None: m4 = Mock() # The type here isn't quite correct as they don't handle TreeCache well. 
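`LruCache` follows the same migration as the other caches: a mandatory `clock` argument alongside `server_name`. A minimal sketch, assuming the Synapse test environment:

```python
from synapse.util.caches.lrucache import LruCache
from tests.server import get_clock

_, clock = get_clock()

cache: LruCache[str, int] = LruCache(
    max_size=2, clock=clock, server_name="test_server"
)
cache["a"] = 1
cache["b"] = 2
cache["c"] = 3  # evicts the least-recently-used entry, "a"
assert cache.get("a") is None
```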
cache: LruCache[Tuple[str, str], str] = LruCache( - max_size=4, cache_type=TreeCache, server_name="test_server" + max_size=4, + clock=self.clock, + cache_type=TreeCache, + server_name="test_server", ) cache.set(("a", "1"), "value", callbacks=[m1]) @@ -205,7 +238,9 @@ def test_del_multi(self) -> None: def test_clear(self) -> None: m1 = Mock() m2 = Mock() - cache: LruCache[str, str] = LruCache(max_size=5, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=5, clock=self.clock, server_name="test_server" + ) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -222,7 +257,9 @@ def test_eviction(self) -> None: m1 = Mock(name="m1") m2 = Mock(name="m2") m3 = Mock(name="m3") - cache: LruCache[str, str] = LruCache(max_size=2, server_name="test_server") + cache: LruCache[str, str] = LruCache( + max_size=2, clock=self.clock, server_name="test_server" + ) cache.set("key1", "value", callbacks=[m1]) cache.set("key2", "value", callbacks=[m2]) @@ -259,7 +296,7 @@ def test_eviction(self) -> None: class LruCacheSizedTestCase(unittest.HomeserverTestCase): def test_evict(self) -> None: cache: LruCache[str, List[int]] = LruCache( - max_size=5, size_callback=len, server_name="test_server" + max_size=5, clock=self.clock, size_callback=len, server_name="test_server" ) cache["key1"] = [0] cache["key2"] = [1, 2] @@ -284,7 +321,10 @@ def test_evict(self) -> None: def test_zero_size_drop_from_cache(self) -> None: """Test that `drop_from_cache` works correctly with 0-sized entries.""" cache: LruCache[str, List[int]] = LruCache( - max_size=5, size_callback=lambda x: 0, server_name="test_server" + max_size=5, + clock=self.clock, + size_callback=lambda x: 0, + server_name="test_server", ) cache["key1"] = [] @@ -402,7 +442,10 @@ def test_evict_memory(self, jemalloc_interface: Mock) -> None: class ExtraIndexLruCacheTestCase(unittest.HomeserverTestCase): def test_invalidate_simple(self) -> None: cache: LruCache[str, int] = LruCache( - max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v) + max_size=10, + clock=self.hs.get_clock(), + server_name="test_server", + extra_index_cb=lambda k, v: str(v), ) cache["key1"] = 1 cache["key2"] = 2 @@ -417,7 +460,10 @@ def test_invalidate_simple(self) -> None: def test_invalidate_multi(self) -> None: cache: LruCache[str, int] = LruCache( - max_size=10, server_name="test_server", extra_index_cb=lambda k, v: str(v) + max_size=10, + clock=self.hs.get_clock(), + server_name="test_server", + extra_index_cb=lambda k, v: str(v), ) cache["key1"] = 1 cache["key2"] = 1 diff --git a/tests/util/test_retryutils.py b/tests/util/test_retryutils.py index 82baff58837..593be93ea3a 100644 --- a/tests/util/test_retryutils.py +++ b/tests/util/test_retryutils.py @@ -35,6 +35,7 @@ def test_new_destination(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -57,6 +58,7 @@ def test_limiter(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -89,6 +91,7 @@ def test_limiter(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ), @@ -104,6 +107,7 @@ def test_limiter(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -139,6 +143,7 @@ def test_limiter(self) -> None: get_retry_limiter( 
destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -165,6 +170,7 @@ def test_notifier_replication(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, notifier=notifier, @@ -238,6 +244,7 @@ def test_max_retry_interval(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -261,6 +268,7 @@ def test_max_retry_interval(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ), @@ -273,6 +281,7 @@ def test_max_retry_interval(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ) @@ -297,6 +306,7 @@ def test_max_retry_interval(self) -> None: get_retry_limiter( destination="test_dest", our_server_name=self.hs.hostname, + hs=self.hs, clock=self.clock, store=store, ),