From cc5fb5a5c29e61eb2067bc8cf7a9ec6850529ead Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 8 Oct 2025 16:56:48 -0700 Subject: [PATCH 01/16] centralized shutdown --- src/forge/controller/actor.py | 22 +++++++++++++--- src/forge/controller/provisioner.py | 32 ++++++++++++++++++++++- src/forge/controller/service/interface.py | 1 + src/forge/controller/service/replica.py | 2 +- tests/sandbox/toy_rl/sumdigits.py | 2 ++ 5 files changed, 54 insertions(+), 5 deletions(-) diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index a899da6f0..d55185fa6 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -12,7 +12,7 @@ from monarch.actor import Actor, current_rank, current_size, endpoint -from forge.controller.provisioner import get_proc_mesh, stop_proc_mesh +from forge.controller.provisioner import _get_provisioner, get_proc_mesh, stop_proc_mesh from forge.types import ProcessConfig, ServiceConfig @@ -127,7 +127,9 @@ async def as_service( logger.info("Spawning Service for %s", cls.__name__) service = Service(cfg, cls, actor_args, actor_kwargs) await service.__initialize__() - return ServiceInterface(service, cls) + service_interface = ServiceInterface(service, cls) + await cls.register_allocation(service_interface) + return service_interface @endpoint async def setup(self): @@ -144,6 +146,17 @@ async def setup(self): """ pass + @classmethod + async def register_allocation(cls, alloc: "ForgeActor | ServiceInterface") -> None: + """Registers an allocation (service/actor) with the provisioner.""" + provisioner = await _get_provisioner() + try: + provisioner = await _get_provisioner() + if provisioner is not None: + await provisioner.track_allocation(alloc) + except Exception as e: + logger.warning(f"Failed to register allocation {alloc}: {e}") + @classmethod async def launch(cls, *args, **kwargs) -> "ForgeActor": """Provisions and deploys a new actor. @@ -185,13 +198,16 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: """ logger.info("Spawning single actor %s", cls.__name__) actor = await cls.launch(*args, **actor_kwargs) + await cls.register_allocation(actor) return actor @classmethod - async def shutdown(cls, actor: "ForgeActor"): + async def shutdown(cls, actor: "ForgeActor", queit: bool = False): """Shuts down an actor. This method is used by `Service` to teardown a replica. """ + if not queit: + logger.info(f"Shutting down actor {getattr(actor, 'name', cls.__name__)}") if actor._proc_mesh is None: raise AssertionError("Called shutdown on a replica with no proc_mesh.") await stop_proc_mesh(actor._proc_mesh) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 5ca331f32..ddc145795 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -12,7 +12,7 @@ import os import socket import uuid -from typing import Optional +from typing import Any, Optional from monarch._src.actor.shape import NDSlice, Shape from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host @@ -132,6 +132,8 @@ def __init__(self, cfg: ProvisionerConfig | None = None): if not self.launcher: logger.warning("Launcher not provided, remote allocations will not work.") + self._allocations: list[Any] = [] # all live actor/service instances + async def initialize(self): """Call this after creating the instance""" if self.launcher is not None: @@ -303,8 +305,36 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): commands.kill(server_name) del self._proc_host_map[proc_mesh] + async def track_allocation(self, alloc: Any): + """Tracks an allocation for cleanup.""" + self._allocations.append(alloc) + + async def shutdown_all_allocations(self): + """Gracefully shut down all tracked actors and services.""" + from forge.controller.actor import ForgeActor + from forge.controller.service import ServiceInterface + + for alloc in self._allocations: + try: + # --- ServiceInterface --- + if isinstance(alloc, ServiceInterface): + await alloc.shutdown() + + # --- ForgeActor instance --- + elif isinstance(alloc, ForgeActor): + await alloc.__class__.shutdown(alloc) + + else: + logger.warning(f"Unknown allocation type: {type(alloc)}") + + except Exception as e: + logger.warning(f"Failed to shut down {alloc}: {e}") + + self._allocations.clear() + async def shutdown(self): """Tears down all remaining remote allocations.""" + await self.shutdown_all_allocations() async with self._lock: for server_name in self._server_names: commands.kill(server_name) diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index 5b7e2f884..6f26e444c 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -200,6 +200,7 @@ async def shutdown(self) -> None: """ Shut down the underlying Service. """ + logger.info(f"Shutting down service {self.actor_def.__name__}") await self._service.stop() def session(self) -> "SessionContext": diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 9ab8ec20a..28aa56d8a 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -405,7 +405,7 @@ async def stop(self): # Stop the actor if self.actor: try: - await self.actor_def.shutdown(self.actor) + await self.actor_def.shutdown(self.actor, queit=True) except Exception as e: logger.warning( "Error stopping proc_mesh for replica %d: %s", self.idx, e diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 14b5f6ebe..9517dcfcd 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -574,7 +574,9 @@ async def continuous_training(): Trainer.shutdown(trainer), ReplayBuffer.shutdown(replay_buffer), reward_actor.shutdown(), + ref_model.shutdown(), ) + # TODO - add a global shutdown that implicitly shuts down all services # and remote allocations await shutdown() From 9cc5531403829f65a7d5e839e8dc5647b0ee9f3c Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 8 Oct 2025 18:16:40 -0700 Subject: [PATCH 02/16] debug. TODO: simplify --- src/forge/controller/provisioner.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index ddc145795..cff22d682 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -307,14 +307,22 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): async def track_allocation(self, alloc: Any): """Tracks an allocation for cleanup.""" + from forge.controller.service import ServiceInterface + self._allocations.append(alloc) + alloc_type = "service" if isinstance(alloc, ServiceInterface) else "actor" + print( + f"Registered allocation {alloc_type} {alloc}, current allocations len: {len(self._allocations)}" + ) async def shutdown_all_allocations(self): """Gracefully shut down all tracked actors and services.""" + from monarch._src.actor.actor_mesh import ActorMesh + from forge.controller.actor import ForgeActor from forge.controller.service import ServiceInterface - for alloc in self._allocations: + for alloc in reversed(self._allocations): try: # --- ServiceInterface --- if isinstance(alloc, ServiceInterface): @@ -324,6 +332,21 @@ async def shutdown_all_allocations(self): elif isinstance(alloc, ForgeActor): await alloc.__class__.shutdown(alloc) + # --- ActorMesh (spawned actor group) --- + elif isinstance(alloc, ActorMesh): + actor_cls = getattr(alloc, "_class", None) + if actor_cls is not None and hasattr(actor_cls, "shutdown"): + await actor_cls.shutdown(alloc) + else: + # fallback if class not available + inner = getattr(alloc, "_inner", None) + if hasattr(inner, "shutdown"): + await inner.shutdown() + else: + logger.warning( + f"ActorMesh {alloc.__name__} has no shutdown()" + ) + else: logger.warning(f"Unknown allocation type: {type(alloc)}") From 84dcd17a3d2b1f98b6ce64020cfb92987eb2492b Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 8 Oct 2025 18:17:31 -0700 Subject: [PATCH 03/16] debug --- src/forge/actors/policy.py | 2 +- src/forge/controller/actor.py | 4 ++-- src/forge/controller/service/replica.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 330bead57..d9c05da28 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -223,7 +223,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride] @classmethod async def shutdown( # pyright: ignore[reportIncompatibleMethodOverride] - cls: type["Policy"], actor: "Policy" + cls: type["Policy"], actor: "Policy", quiet: bool = False ): assert ( actor._policy_proc is not None diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index d55185fa6..3b7192051 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -202,11 +202,11 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: return actor @classmethod - async def shutdown(cls, actor: "ForgeActor", queit: bool = False): + async def shutdown(cls, actor: "ForgeActor", quiet: bool = False): """Shuts down an actor. This method is used by `Service` to teardown a replica. """ - if not queit: + if not quiet: logger.info(f"Shutting down actor {getattr(actor, 'name', cls.__name__)}") if actor._proc_mesh is None: raise AssertionError("Called shutdown on a replica with no proc_mesh.") diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 28aa56d8a..979caf93f 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -405,7 +405,7 @@ async def stop(self): # Stop the actor if self.actor: try: - await self.actor_def.shutdown(self.actor, queit=True) + await self.actor_def.shutdown(self.actor, quiet=True) except Exception as e: logger.warning( "Error stopping proc_mesh for replica %d: %s", self.idx, e From 8ba3e5fedf67a322e3b128a81ffb5c03a6e6a19b Mon Sep 17 00:00:00 2001 From: DNXie Date: Wed, 8 Oct 2025 18:17:58 -0700 Subject: [PATCH 04/16] update all main --- apps/grpo/main.py | 12 ------------ tests/sandbox/rl_trainer/main.py | 1 - tests/sandbox/toy_rl/sumdigits.py | 11 ----------- tests/sandbox/toy_rl/toy_metrics/main.py | 6 ------ tests/sandbox/vllm/main.py | 1 - 5 files changed, 31 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index c64f00bc2..0fe549cef 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -495,18 +495,6 @@ async def continuous_training(): # TODO (felipemello) find more elegant solution await mlogger.shutdown.call_one() await asyncio.sleep(2) - - await asyncio.gather( - DatasetActor.shutdown(dataloader), - policy.shutdown(), - RLTrainer.shutdown(trainer), - ReplayBuffer.shutdown(replay_buffer), - ComputeAdvantages.shutdown(compute_advantages), - ref_model.shutdown(), - reward_actor.shutdown(), - ) - # TODO - add a global shutdown that implicitly shuts down all services - # and remote allocations await shutdown() diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index 1441bb9e3..fa089a6d7 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -222,7 +222,6 @@ async def continuous_training(): print("Training interrupted by user") finally: print("Shutting down trainer...") - await RLTrainer.shutdown(trainer) await mlogger.shutdown.call_one() await shutdown() print("Trainer shutdown complete.") diff --git a/tests/sandbox/toy_rl/sumdigits.py b/tests/sandbox/toy_rl/sumdigits.py index 9517dcfcd..e6684a0bb 100644 --- a/tests/sandbox/toy_rl/sumdigits.py +++ b/tests/sandbox/toy_rl/sumdigits.py @@ -568,17 +568,6 @@ async def continuous_training(): training_task.cancel() finally: print("Shutting down...") - await asyncio.gather( - DatasetActor.shutdown(dataloader), - policy.shutdown(), - Trainer.shutdown(trainer), - ReplayBuffer.shutdown(replay_buffer), - reward_actor.shutdown(), - ref_model.shutdown(), - ) - - # TODO - add a global shutdown that implicitly shuts down all services - # and remote allocations await shutdown() diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index d999fb700..4cb7f0339 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -112,12 +112,6 @@ async def main(): # shutdown await mlogger.shutdown.call_one() await asyncio.sleep(2) - - await asyncio.gather( - trainer.shutdown(), - generator.shutdown(), - ) - await shutdown() diff --git a/tests/sandbox/vllm/main.py b/tests/sandbox/vllm/main.py index 0f3ce662c..44ca6f1c5 100644 --- a/tests/sandbox/vllm/main.py +++ b/tests/sandbox/vllm/main.py @@ -67,7 +67,6 @@ async def run(cfg: DictConfig): print("-" * 80) print("\nShutting down...") - await policy.shutdown() await shutdown() From 5b8d81ca30f3a6e0db23bad23929ba8fef211512 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 9 Oct 2025 11:30:52 -0700 Subject: [PATCH 05/16] simplify shutting down --- apps/grpo/main.py | 1 - src/forge/controller/provisioner.py | 23 +++++------------------ 2 files changed, 5 insertions(+), 19 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 0fe549cef..a6d88e0f2 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -490,7 +490,6 @@ async def continuous_training(): training_task.cancel() finally: print("Shutting down...") - # give mlogger time to shutdown backends, otherwise they can stay running. # TODO (felipemello) find more elegant solution await mlogger.shutdown.call_one() diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index cff22d682..68d298627 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -328,24 +328,11 @@ async def shutdown_all_allocations(self): if isinstance(alloc, ServiceInterface): await alloc.shutdown() - # --- ForgeActor instance --- - elif isinstance(alloc, ForgeActor): - await alloc.__class__.shutdown(alloc) - - # --- ActorMesh (spawned actor group) --- - elif isinstance(alloc, ActorMesh): - actor_cls = getattr(alloc, "_class", None) - if actor_cls is not None and hasattr(actor_cls, "shutdown"): - await actor_cls.shutdown(alloc) - else: - # fallback if class not available - inner = getattr(alloc, "_inner", None) - if hasattr(inner, "shutdown"): - await inner.shutdown() - else: - logger.warning( - f"ActorMesh {alloc.__name__} has no shutdown()" - ) + # --- Actor instance (ForgeActor or underlying ActorMesh) --- + elif isinstance(alloc, (ForgeActor, ActorMesh)): + # Get the class to call shutdown on (ForgeActor or its bound class) + actor_cls = getattr(alloc, "_class", None) or alloc.__class__ + await actor_cls.shutdown(alloc) else: logger.warning(f"Unknown allocation type: {type(alloc)}") From 5fe99d571012fc61dacabace9f1b5835c3201395 Mon Sep 17 00:00:00 2001 From: DNXie Date: Thu, 9 Oct 2025 12:09:41 -0700 Subject: [PATCH 06/16] clean up code --- src/forge/controller/actor.py | 4 +--- src/forge/controller/provisioner.py | 6 ------ 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 3b7192051..2c7a88c27 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -149,11 +149,9 @@ async def setup(self): @classmethod async def register_allocation(cls, alloc: "ForgeActor | ServiceInterface") -> None: """Registers an allocation (service/actor) with the provisioner.""" - provisioner = await _get_provisioner() try: provisioner = await _get_provisioner() - if provisioner is not None: - await provisioner.track_allocation(alloc) + await provisioner.track_allocation(alloc) except Exception as e: logger.warning(f"Failed to register allocation {alloc}: {e}") diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 68d298627..be444eef6 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -307,13 +307,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): async def track_allocation(self, alloc: Any): """Tracks an allocation for cleanup.""" - from forge.controller.service import ServiceInterface - self._allocations.append(alloc) - alloc_type = "service" if isinstance(alloc, ServiceInterface) else "actor" - print( - f"Registered allocation {alloc_type} {alloc}, current allocations len: {len(self._allocations)}" - ) async def shutdown_all_allocations(self): """Gracefully shut down all tracked actors and services.""" From d705df9117e61d0a1737ac35154f5c058ac6c86e Mon Sep 17 00:00:00 2001 From: DNXie Date: Fri, 10 Oct 2025 16:44:46 -0700 Subject: [PATCH 07/16] fix lint and update docstring --- src/forge/controller/actor.py | 4 ++-- src/forge/controller/provisioner.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 2c7a88c27..d17bbcce6 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -81,11 +81,11 @@ def options( # Pre-configure a single actor actor = await MyForgeActor.options(procs=1, hosts=1).as_actor(...) - await actor.shutdown() + await MyForgeActor.shutdown(actor) # Default usage without calling options actor = await MyForgeActor.as_actor(...) - await actor.shutdown() + await MyForgeActor.shutdown(actor) """ attrs = { diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index bbaaecf22..eb0c60b16 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -13,7 +13,7 @@ import socket import uuid -from typing import Any, Optional +from typing import Any from monarch._src.actor.shape import NDSlice, Shape from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host From 52fa867795c6bc1805dfed771b60ee73b95205e5 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 11:06:45 -0700 Subject: [PATCH 08/16] add function register_xx --- apps/grpo/main.py | 11 ++--- src/forge/controller/actor.py | 22 ++++----- src/forge/controller/provisioner.py | 70 +++++++++++++++++++---------- 3 files changed, 62 insertions(+), 41 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index a6d88e0f2..1923cb73a 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -476,13 +476,14 @@ async def continuous_training(): print( f"Starting GRPO with {num_rollout_threads} rollout threads, {num_training_threads} training threads" ) - rollout_tasks = [ - asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) - ] - training_task = asyncio.create_task(continuous_training()) + # rollout_tasks = [ + # asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) + # ] + # training_task = asyncio.create_task(continuous_training()) try: - await asyncio.gather(*rollout_tasks, training_task) + # await asyncio.gather(*rollout_tasks, training_task) + pass except KeyboardInterrupt: print("Training interrupted by user") for rollout_task in rollout_tasks: diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index d17bbcce6..f1a8ae0e5 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -12,7 +12,12 @@ from monarch.actor import Actor, current_rank, current_size, endpoint -from forge.controller.provisioner import _get_provisioner, get_proc_mesh, stop_proc_mesh +from forge.controller.provisioner import ( + get_proc_mesh, + register_actor, + register_service, + stop_proc_mesh, +) from forge.types import ProcessConfig, ServiceConfig @@ -128,7 +133,7 @@ async def as_service( service = Service(cfg, cls, actor_args, actor_kwargs) await service.__initialize__() service_interface = ServiceInterface(service, cls) - await cls.register_allocation(service_interface) + await register_service(service_interface) return service_interface @endpoint @@ -147,16 +152,7 @@ async def setup(self): pass @classmethod - async def register_allocation(cls, alloc: "ForgeActor | ServiceInterface") -> None: - """Registers an allocation (service/actor) with the provisioner.""" - try: - provisioner = await _get_provisioner() - await provisioner.track_allocation(alloc) - except Exception as e: - logger.warning(f"Failed to register allocation {alloc}: {e}") - - @classmethod - async def launch(cls, *args, **kwargs) -> "ForgeActor": + async def launch(cls, *args, **kwargs) -> "ActorMesh": """Provisions and deploys a new actor. This method is used by `Service` to provision a new replica. @@ -196,7 +192,7 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: """ logger.info("Spawning single actor %s", cls.__name__) actor = await cls.launch(*args, **actor_kwargs) - await cls.register_allocation(actor) + await register_actor(actor) return actor @classmethod diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index eb0c60b16..e0cfd7845 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -13,8 +13,6 @@ import socket import uuid -from typing import Any - from monarch._src.actor.shape import NDSlice, Shape from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host from monarch.tools import commands @@ -133,7 +131,8 @@ def __init__(self, cfg: ProvisionerConfig | None = None): if not self.launcher: logger.warning("Launcher not provided, remote allocations will not work.") - self._allocations: list[Any] = [] # all live actor/service instances + self._registered_actors: list["ForgeActor"] = [] + self._registered_services: list["ServiceInterface"] = [] async def initialize(self): """Call this after creating the instance""" @@ -306,36 +305,49 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): commands.kill(server_name) del self._proc_host_map[proc_mesh] - async def track_allocation(self, alloc: Any): - """Tracks an allocation for cleanup.""" - self._allocations.append(alloc) + def register_service(self, service: "ServiceInterface") -> None: + """Registers a service allocation for cleanup.""" + from forge.controller.service import ServiceInterface - async def shutdown_all_allocations(self): - """Gracefully shut down all tracked actors and services.""" + if not isinstance(service, ServiceInterface): + raise TypeError( + f"register_service expected ServiceInterface, got {type(service)}" + ) + + self._registered_services.append(service) + + def register_actor(self, actor: "ForgeActor") -> None: + """Registers a single actor allocation for cleanup.""" + # from forge.controller.actor import ForgeActor from monarch._src.actor.actor_mesh import ActorMesh - from forge.controller.actor import ForgeActor - from forge.controller.service import ServiceInterface + if not isinstance(actor, ActorMesh): + raise TypeError(f"register_actor expected ActorMesh, got {type(actor)}") - for alloc in reversed(self._allocations): + self._registered_actors.append(actor) + + async def shutdown_all_allocations(self): + """Gracefully shut down all tracked actors and services.""" + # --- ServiceInterface --- + for service in reversed(self._registered_services): try: - # --- ServiceInterface --- - if isinstance(alloc, ServiceInterface): - await alloc.shutdown() + await service.shutdown() - # --- Actor instance (ForgeActor or underlying ActorMesh) --- - elif isinstance(alloc, (ForgeActor, ActorMesh)): - # Get the class to call shutdown on (ForgeActor or its bound class) - actor_cls = getattr(alloc, "_class", None) or alloc.__class__ - await actor_cls.shutdown(alloc) + except Exception as e: + logger.warning(f"Failed to shut down {service}: {e}") - else: - logger.warning(f"Unknown allocation type: {type(alloc)}") + # --- Actor instance (ForgeActor or underlying ActorMesh) --- + for actor in reversed(self._registered_actors): + try: + # Get the class to call shutdown on (ForgeActor or its bound class) + actor_cls = getattr(actor, "_class", None) or actor.__class__ + await actor_cls.shutdown(actor) except Exception as e: - logger.warning(f"Failed to shut down {alloc}: {e}") + logger.warning(f"Failed to shut down {actor}: {e}") - self._allocations.clear() + self._registered_actors.clear() + self._registered_services.clear() async def shutdown(self): """Tears down all remaining remote allocations.""" @@ -408,6 +420,18 @@ async def host_mesh_from_proc(proc_mesh: ProcMesh): return await provisioner.host_mesh_from_proc(proc_mesh) +async def register_service(service: "ServiceInterface") -> None: + """Registers a service allocation with the global provisioner.""" + provisioner = await _get_provisioner() + provisioner.register_service(service) + + +async def register_actor(actor: "ForgeActor") -> None: + """Registers an actor allocation with the global provisioner.""" + provisioner = await _get_provisioner() + provisioner.register_actor(actor) + + async def stop_proc_mesh(proc_mesh: ProcMesh): provisioner = await _get_provisioner() return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh) From f0ba99a7243c7567fcb1cc549603a4d312674288 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 11:07:11 -0700 Subject: [PATCH 09/16] rollback changes --- apps/grpo/main.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1923cb73a..a6d88e0f2 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -476,14 +476,13 @@ async def continuous_training(): print( f"Starting GRPO with {num_rollout_threads} rollout threads, {num_training_threads} training threads" ) - # rollout_tasks = [ - # asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) - # ] - # training_task = asyncio.create_task(continuous_training()) + rollout_tasks = [ + asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads) + ] + training_task = asyncio.create_task(continuous_training()) try: - # await asyncio.gather(*rollout_tasks, training_task) - pass + await asyncio.gather(*rollout_tasks, training_task) except KeyboardInterrupt: print("Training interrupted by user") for rollout_task in rollout_tasks: From 4e477b59f7d36aabc76d8f25b0ec24a4394e7331 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 11:10:20 -0700 Subject: [PATCH 10/16] cleanup --- src/forge/controller/provisioner.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index e0cfd7845..1a151364b 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -318,7 +318,6 @@ def register_service(self, service: "ServiceInterface") -> None: def register_actor(self, actor: "ForgeActor") -> None: """Registers a single actor allocation for cleanup.""" - # from forge.controller.actor import ForgeActor from monarch._src.actor.actor_mesh import ActorMesh if not isinstance(actor, ActorMesh): From e7ae73d8129da98690db0714ce96bdd8666b832f Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 11:40:49 -0700 Subject: [PATCH 11/16] resolve comment --- src/forge/controller/actor.py | 6 +++--- src/forge/controller/provisioner.py | 15 ++++++++++----- src/forge/controller/service/interface.py | 1 - src/forge/controller/service/replica.py | 2 +- 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index f1a8ae0e5..1512fb087 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -133,6 +133,7 @@ async def as_service( service = Service(cfg, cls, actor_args, actor_kwargs) await service.__initialize__() service_interface = ServiceInterface(service, cls) + # Register this service with the provisioner so it can cleanly shut this down await register_service(service_interface) return service_interface @@ -192,16 +193,15 @@ async def as_actor(cls: Type[T], *args, **actor_kwargs) -> T: """ logger.info("Spawning single actor %s", cls.__name__) actor = await cls.launch(*args, **actor_kwargs) + # Register this actor with the provisioner so it can cleanly shut this down await register_actor(actor) return actor @classmethod - async def shutdown(cls, actor: "ForgeActor", quiet: bool = False): + async def shutdown(cls, actor: "ForgeActor"): """Shuts down an actor. This method is used by `Service` to teardown a replica. """ - if not quiet: - logger.info(f"Shutting down actor {getattr(actor, 'name', cls.__name__)}") if actor._proc_mesh is None: raise AssertionError("Called shutdown on a replica with no proc_mesh.") await stop_proc_mesh(actor._proc_mesh) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 1a151364b..d4acea084 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -13,16 +13,18 @@ import socket import uuid -from monarch._src.actor.shape import NDSlice, Shape -from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host -from monarch.tools import commands - from forge.controller.launcher import BaseLauncher, get_launcher from forge.env_constants import FORGE_DISABLE_METRICS from forge.types import ProcessConfig, ProvisionerConfig +from monarch._src.actor.actor_mesh import ActorMesh + +from monarch._src.actor.shape import NDSlice, Shape +from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host +from monarch.tools import commands + logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -307,6 +309,7 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh): def register_service(self, service: "ServiceInterface") -> None: """Registers a service allocation for cleanup.""" + # Import ServiceInterface here instead of at top-level to avoid circular import from forge.controller.service import ServiceInterface if not isinstance(service, ServiceInterface): @@ -318,7 +321,6 @@ def register_service(self, service: "ServiceInterface") -> None: def register_actor(self, actor: "ForgeActor") -> None: """Registers a single actor allocation for cleanup.""" - from monarch._src.actor.actor_mesh import ActorMesh if not isinstance(actor, ActorMesh): raise TypeError(f"register_actor expected ActorMesh, got {type(actor)}") @@ -327,6 +329,9 @@ def register_actor(self, actor: "ForgeActor") -> None: async def shutdown_all_allocations(self): """Gracefully shut down all tracked actors and services.""" + logger.info( + f"Shutting down {len(self._registered_services)} service(s) and {len(self._registered_actors)} actor(s)..." + ) # --- ServiceInterface --- for service in reversed(self._registered_services): try: diff --git a/src/forge/controller/service/interface.py b/src/forge/controller/service/interface.py index e7413f724..c64d5c3f3 100644 --- a/src/forge/controller/service/interface.py +++ b/src/forge/controller/service/interface.py @@ -200,7 +200,6 @@ async def shutdown(self) -> None: """ Shut down the underlying Service. """ - logger.info(f"Shutting down service {self.actor_def.__name__}") await self._service.stop() def session(self) -> "SessionContext": diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index 4497762b7..d7c8c919c 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -404,7 +404,7 @@ async def stop(self): # Stop the actor if self.actor: try: - await self.actor_def.shutdown(self.actor, quiet=True) + await self.actor_def.shutdown() except Exception as e: logger.warning( "Error stopping proc_mesh for replica %d: %s", self.idx, e From 77c2344ea9bdb5cc568fd47654f6d7dbf52030e4 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 12:05:06 -0700 Subject: [PATCH 12/16] fix bug --- src/forge/controller/service/replica.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/forge/controller/service/replica.py b/src/forge/controller/service/replica.py index d7c8c919c..fa9a7791a 100644 --- a/src/forge/controller/service/replica.py +++ b/src/forge/controller/service/replica.py @@ -404,7 +404,7 @@ async def stop(self): # Stop the actor if self.actor: try: - await self.actor_def.shutdown() + await self.actor_def.shutdown(self.actor) except Exception as e: logger.warning( "Error stopping proc_mesh for replica %d: %s", self.idx, e From d4f3d57fb50f897d7538f9ba754c9c7cce5f54e0 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 12:22:23 -0700 Subject: [PATCH 13/16] fix lint --- src/forge/controller/provisioner.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index ed43dfaa4..cf9114cfe 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -13,8 +13,10 @@ import socket import uuid +from monarch._src.actor.actor_mesh import ActorMesh + from monarch._src.actor.shape import Extent, NDSlice, Shape -from monarch.actor import Actor, endpoint, ProcMesh +from monarch.actor import Actor, endpoint from monarch.tools import commands @@ -24,12 +26,6 @@ from forge.types import ProcessConfig, ProvisionerConfig -from monarch._src.actor.actor_mesh import ActorMesh - -from monarch._src.actor.shape import NDSlice, Shape -from monarch.actor import Actor, endpoint, HostMesh, ProcMesh, this_host -from monarch.tools import commands - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) From f31aa3a84316aad15f801c29b9105db614d6e496 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 14:00:32 -0700 Subject: [PATCH 14/16] move mlogger.shutdown into global shutdown() --- apps/grpo/main.py | 4 ---- src/forge/controller/provisioner.py | 20 +++++++++++++++++++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/apps/grpo/main.py b/apps/grpo/main.py index 1ff49da77..33c5df6f2 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -536,10 +536,6 @@ async def continuous_training(): training_task.cancel() - # give mlogger time to shutdown backends, otherwise they can stay running. - # TODO (felipemello) find more elegant solution - await mlogger.shutdown.call_one() - await asyncio.sleep(2) await shutdown() diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index cf9114cfe..030bf0aa3 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -498,7 +498,25 @@ async def stop_proc_mesh(proc_mesh: ProcMesh): return await provisioner.stop_proc_mesh(proc_mesh=proc_mesh) +async def shutdown_metric_logger(): + """Shutdown the global metric logger and all its backends.""" + from forge.observability.metric_actors import get_or_create_metric_logger + + logger.info("Shutting down metric logger...") + try: + mlogger = await get_or_create_metric_logger() + await mlogger.shutdown.call_one() + except Exception as e: + logger.warning(f"Failed to shutdown metric logger: {e}") + + async def shutdown(): logger.info("Shutting down provisioner..") + + await shutdown_metric_logger() + provisioner = await _get_provisioner() - return await provisioner.shutdown() + result = await provisioner.shutdown() + + logger.info("Shutdown completed successfully") + return result From ffbf0caae3140f6a85cf214d8888c29f398f1821 Mon Sep 17 00:00:00 2001 From: DNXie Date: Mon, 13 Oct 2025 16:02:07 -0700 Subject: [PATCH 15/16] cleanup, update other main --- src/forge/controller/provisioner.py | 5 +++-- tests/sandbox/rl_trainer/main.py | 3 +-- tests/sandbox/toy_rl/toy_metrics/main.py | 2 -- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/src/forge/controller/provisioner.py b/src/forge/controller/provisioner.py index 030bf0aa3..af0e580f2 100644 --- a/src/forge/controller/provisioner.py +++ b/src/forge/controller/provisioner.py @@ -16,7 +16,7 @@ from monarch._src.actor.actor_mesh import ActorMesh from monarch._src.actor.shape import Extent, NDSlice, Shape -from monarch.actor import Actor, endpoint +from monarch.actor import Actor, endpoint, ProcMesh from monarch.tools import commands @@ -511,10 +511,11 @@ async def shutdown_metric_logger(): async def shutdown(): - logger.info("Shutting down provisioner..") await shutdown_metric_logger() + logger.info("Shutting down provisioner..") + provisioner = await _get_provisioner() result = await provisioner.shutdown() diff --git a/tests/sandbox/rl_trainer/main.py b/tests/sandbox/rl_trainer/main.py index ddfda1404..e5ee6fddd 100644 --- a/tests/sandbox/rl_trainer/main.py +++ b/tests/sandbox/rl_trainer/main.py @@ -232,8 +232,7 @@ async def continuous_training(): except KeyboardInterrupt: print("Training interrupted by user") finally: - print("Shutting down trainer...") - await mlogger.shutdown.call_one() + print("Shutting down...") await shutdown() print("Trainer shutdown complete.") diff --git a/tests/sandbox/toy_rl/toy_metrics/main.py b/tests/sandbox/toy_rl/toy_metrics/main.py index 4cb7f0339..57ccd97b5 100644 --- a/tests/sandbox/toy_rl/toy_metrics/main.py +++ b/tests/sandbox/toy_rl/toy_metrics/main.py @@ -110,8 +110,6 @@ async def main(): await mlogger.flush.call_one(i) # shutdown - await mlogger.shutdown.call_one() - await asyncio.sleep(2) await shutdown() From 5dc138ce28fda64fb79537de189b11569499ad59 Mon Sep 17 00:00:00 2001 From: DNXie Date: Tue, 14 Oct 2025 15:35:08 -0700 Subject: [PATCH 16/16] fix broken ci --- src/forge/controller/actor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/forge/controller/actor.py b/src/forge/controller/actor.py index 952e52d5f..c54fba6e0 100644 --- a/src/forge/controller/actor.py +++ b/src/forge/controller/actor.py @@ -8,10 +8,13 @@ import math import sys -from typing import Any, Type, TypeVar +from typing import Any, Type, TYPE_CHECKING, TypeVar from monarch.actor import Actor, current_rank, current_size, endpoint +if TYPE_CHECKING: + from monarch._src.actor.actor_mesh import ActorMesh + from forge.controller.provisioner import ( get_proc_mesh, register_actor,