diff --git a/apps/grpo/main.py b/apps/grpo/main.py index ed33e7e2..99448312 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -493,22 +493,6 @@ async def continuous_training(): training_task.cancel() - # give mlogger time to shutdown backends, otherwise they can stay running. - # TODO (felipemello) find more elegant solution - await mlogger.shutdown.call_one() - await asyncio.sleep(2) - - await asyncio.gather( - DatasetActor.shutdown(dataloader), - policy.shutdown(), - RLTrainer.shutdown(trainer), - ReplayBuffer.shutdown(replay_buffer), - ComputeAdvantages.shutdown(compute_advantages), - ref_model.shutdown(), - reward_actor.shutdown(), - ) - # TODO - add a global shutdown that implicitly shuts down all services - # and remote allocations await shutdown() diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 4a5cbf17..c54fba6e 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -8,11 +8,19 @@ import math import sys -from typing import Any, Type, TypeVar +from typing import Any, Type, TYPE_CHECKING, TypeVar from monarch.actor import Actor, current_rank, current_size, endpoint -from forge.controller.provisioner import get_proc_mesh, stop_proc_mesh +if TYPE_CHECKING: + from monarch._src.actor.actor_mesh import ActorMesh + +from forge.controller.provisioner import ( + get_proc_mesh, + register_actor, + register_service, + stop_proc_mesh, +) from forge.types import ProcessConfig, ServiceConfig @@ -122,7 +130,7 @@ def options( .. code-block:: python actor = await MyForgeActor.as_actor(...) - await actor.shutdown() + await MyForgeActor.shutdown(actor) """ attrs = { @@ -164,7 +172,10 @@ async def as_service( logger.info("Spawning Service for %s", cls.__name__) service = Service(cfg, cls, actor_args, actor_kwargs) await service.__initialize__() - return ServiceInterface(service, cls) + service_interface = ServiceInterface(service, cls) + # Register this service with the provisioner so it can cleanly shut this down + await register_service(service_interface) + return service_interface @endpoint async def setup(self): @@ -182,7 +193,7 @@ async def setup(self): pass @classmethod - async def launch(cls, *args, **kwargs) -> "ForgeActor": + async def launch(cls, *args, **kwargs) -> "ActorMesh": """Provisions and deploys a new actor. This method is used by `Service` to provision a new replica. @@ -222,6 +233,8 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: """ logger.info("Spawning single actor %s", cls.__name__) actor = await cls.launch(*args, **actor_kwargs) + # Register this actor with the provisioner so it can cleanly shut this down + await register_actor(actor) return actor @classmethod diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 8f5a77f4..c23d5fdd 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -13,6 +13,7 @@ import socket import uuid +from monarch._src.actor.actor_mesh import ActorMesh from monarch._src.actor.shape import Extent from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host @@ -133,6 +134,9 @@ def __init__(self, cfg: ProvisionerConfig | None = None): if not self.launcher: logger.warning("Launcher not provided, remote allocations will not work.") + self._registered_actors: list["ForgeActor"] = [] + self._registered_services: list["ServiceInterface"] = [] + async def initialize(self): """Call this after creating the instance""" if self.launcher is not None: @@ -338,8 +342,55 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): commands.kill(server_name) del self._proc_host_map[proc_mesh] + def register_service(self, service: "ServiceInterface") -> None: + """Registers a service allocation for cleanup.""" + # Import ServiceInterface here instead of at top-level to avoid circular import + from forge.controller.service import ServiceInterface + + if not isinstance(service, ServiceInterface): + raise TypeError( + f"register_service expected ServiceInterface, got {type(service)}" + ) + + self._registered_services.append(service) + + def register_actor(self, actor: "ForgeActor") -> None: + """Registers a single actor allocation for cleanup.""" + + if not isinstance(actor, ActorMesh): + raise TypeError(f"register_actor expected ActorMesh, got {type(actor)}") + + self._registered_actors.append(actor) + + async def shutdown_all_allocations(self): + """Gracefully shut down all tracked actors and services.""" + logger.info( + f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..." + ) + # --- ServiceInterface --- + for service in reversed(self._registered_services): + try: + await service.shutdown() + + except Exception as e: + logger.warning(f"Failed to shut down {service}: {e}") + + # --- Actor instance (ForgeActor or underlying ActorMesh) --- + for actor in reversed(self._registered_actors): + try: + # Get the class to call shutdown on (ForgeActor or its bound class) + actor_cls = getattr(actor, "_class", None) or actor.__class__ + await actor_cls.shutdown(actor) + + except Exception as e: + logger.warning(f"Failed to shut down {actor}: {e}") + + self._registered_actors.clear() + self._registered_services.clear() + async def shutdown(self): """Tears down all remaining remote allocations.""" + await self.shutdown_all_allocations() async with self._lock: for server_name in self._server_names: commands.kill(server_name) @@ -408,12 +459,43 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): return await provisioner.host_mesh_from_proc(proc_mesh) +async def register_service(service: "ServiceInterface") -> None: + """Registers a service allocation with the global provisioner.""" + provisioner = await _get_provisioner() + provisioner.register_service(service) + + +async def register_actor(actor: "ForgeActor") -> None: + """Registers an actor allocation with the global provisioner.""" + provisioner = await _get_provisioner() + provisioner.register_actor(actor) + + async def stop_proc_mesh(proc_mesh: ProcMesh): provisioner = await _get_provisioner() return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh) +async def shutdown_metric_logger(): + """Shutdown the global metric logger and all its backends.""" + from forge.observability.metric_actors import get_or_create_metric_logger + + logger.info("Shutting down metric logger...") + try: + mlogger = await get_or_create_metric_logger() + await mlogger.shutdown.call_one() + except Exception as e: + logger.warning(f"Failed to shutdown metric logger: {e}") + + async def shutdown(): + + await shutdown_metric_logger() + logger.info("Shutting down provisioner..") + provisioner = await _get_provisioner() - return await provisioner.shutdown() + result = await provisioner.shutdown() + + logger.info("Shutdown completed successfully") + return result diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index 6d71ddf6..e5ee6fdd 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -232,9 +232,7 @@ async def continuous_training(): except KeyboardInterrupt: print("Training interrupted by user") finally: - print("Shutting down trainer...") - await RLTrainer.shutdown(trainer) - await mlogger.shutdown.call_one() + print("Shutting down...") await shutdown() print("Trainer shutdown complete.") diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index e862ac60..06938b0b 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -533,15 +533,6 @@ async def continuous_training(): training_task.cancel() finally: print("Shutting down...") - await asyncio.gather( - DatasetActor.shutdown(dataloader), - policy.shutdown(), - Trainer.shutdown(trainer), - ReplayBuffer.shutdown(replay_buffer), - reward_actor.shutdown(), - ) - # TODO - add a global shutdown that implicitly shuts down all services - # and remote allocations await shutdown() diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index d999fb70..57ccd97b 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -110,14 +110,6 @@ async def main(): await mlogger.flush.call_one(i) # shutdown - await mlogger.shutdown.call_one() - await asyncio.sleep(2) - - await asyncio.gather( - trainer.shutdown(), - generator.shutdown(), - ) - await shutdown() diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 54b09384..b79ff9b5 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -66,7 +66,6 @@ async def run(cfg: DictConfig): print("-" * 80) print("\nShutting down...") - await policy.shutdown() await shutdown()