Merged
Changes from 8 commits
2 changes: 1 addition & 1 deletion apps/grpo/main.py
@@ -305,7 +305,7 @@ async def main(cfg: DictConfig):
provisioner = await init_provisioner()

metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)

# ---- Setup services ---- #
11 changes: 6 additions & 5 deletions src/forge/controller/provisioner.py
@@ -310,11 +310,12 @@ def bootstrap(env: dict[str, str]):

self._proc_host_map[procs] = host_mesh

# Spawn local fetcher actor on each process and register with global logger
# Spawn LocalFetcherActor for this ProcMesh and register with GlobalLoggingActor.
# When called, the LocalFetcherActor is broadcast by Monarch to all ranks in the ProcMesh.
if not FORGE_DISABLE_METRICS.get_value():
from forge.observability.metric_actors import get_or_create_metric_logger

_ = await get_or_create_metric_logger(procs)
_ = await get_or_create_metric_logger(procs, process_name=mesh_name)
return procs

async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
@@ -333,14 +334,14 @@ async def stop_proc_mesh(self, proc_mesh: ProcMesh):
)
return
async with self._lock:
# Deregister local logger from global logger
if hasattr(proc_mesh, "_local_fetcher"):
# Deregister LocalFetcherActor from GlobalLoggingActor
if hasattr(proc_mesh, "_local_fetcher") and hasattr(proc_mesh, "_uid"):
Member:
Is it possible for a proc_mesh that has _local_fetcher but not _uid?

Contributor Author:
No, they should always have both. I guess I was being extra safe here. Is it confusing?

Member:
I understand you wrote it like this to be safe, but I worry it may hide some potential errors. How about raising an error if it has _local_fetcher but not _uid?
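
A minimal sketch of the reviewer's suggestion (hypothetical, not what this PR merged): fail loudly when the two attributes fall out of sync instead of silently skipping deregistration.

# Hypothetical stricter guard for stop_proc_mesh; assumes _uid and
# _local_fetcher are always set together by get_or_create_metric_logger.
if hasattr(proc_mesh, "_local_fetcher"):
    if not hasattr(proc_mesh, "_uid"):
        raise RuntimeError(
            "proc_mesh has _local_fetcher but no _uid; "
            "expected both to be set together"
        )
    global_logger = await get_or_create_metric_logger(proc_mesh)
    await global_logger.deregister_fetcher.call_one(proc_mesh._uid)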

from forge.observability.metric_actors import (
get_or_create_metric_logger,
)

global_logger = await get_or_create_metric_logger(proc_mesh)
await global_logger.deregister_fetcher.call_one(proc_mesh)
await global_logger.deregister_fetcher.call_one(proc_mesh._uid)

if hasattr(proc_mesh, "_gpu_ids"):
gpu_manager = self._host_gpu_map[proc_mesh._host._host_id]
6 changes: 3 additions & 3 deletions src/forge/observability/__init__.py
@@ -12,8 +12,6 @@
from .metrics import (
BackendRole,
ConsoleBackend,
get_actor_name_with_rank,
get_logger_backend_class,
LoggerBackend,
MaxAccumulator,
MeanAccumulator,
@@ -29,12 +27,12 @@
WandbBackend,
)
from .perf_tracker import trace, Tracer
from .utils import get_proc_name_with_rank

__all__ = [
# Main API functions
"record_metric",
"reduce_metrics_states",
"get_actor_name_with_rank",
"get_logger_backend_class",
"get_or_create_metric_logger",
# Performance tracking
@@ -45,6 +43,8 @@
"BackendRole",
# Enums
"Reduce",
# Utility functions
"get_proc_name_with_rank",
# Actor classes
"GlobalLoggingActor",
"LocalFetcherActor",
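
As a quick illustration of the updated import surface (a sketch; assumes only the __all__ entries shown above):

# After this PR, the rank-aware name helper lives in forge.observability.utils
# and is re-exported at the package root as get_proc_name_with_rank.
from forge.observability import get_proc_name_with_rank, record_metric, Reduce

record_metric("loss", 1.2, reduction_type=Reduce.MEAN)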
133 changes: 81 additions & 52 deletions src/forge/observability/metric_actors.py
@@ -6,9 +6,17 @@

import asyncio
import logging
import uuid
from typing import Any, Union

from monarch.actor import Actor, endpoint, get_or_spawn_controller, ProcMesh, this_proc
from monarch.actor import (
Actor,
context,
endpoint,
get_or_spawn_controller,
ProcMesh,
this_proc,
)

from forge.env import FORGE_DISABLE_METRICS
from forge.observability.metrics import (
@@ -27,36 +35,35 @@

async def get_or_create_metric_logger(
proc_mesh: ProcMesh | None = None,
process_name: str | None = None,
) -> "GlobalLoggingActor":
"""Initializes a LocalFetcherActor in the specified process mesh (or current process if None),
if not already initialized, registers it with the GlobalLoggingActor and returns the
GlobalLoggingActor instance.
"""Spawns a LocalFetcherActor for the specified ProcMesh (if not already initialized),
registers it with the GlobalLoggingActor, and returns the GlobalLoggingActor.

There are primarily two ways to use this function:
1. In the main process, call `get_or_create_metric_logger()` to get the global logger.
2. In service processes, call `get_or_create_metric_logger(proc_mesh)` to register the
local fetcher with the global logger.
Usage:
1. Main process: call `get_or_create_metric_logger()` to get the global logger
2. Service spawning: call `get_or_create_metric_logger(proc_mesh, process_name)` to register the
(proc_mesh, local fetcher) pair with the global logger, so it knows to broadcast to all ranks.

Args:
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None,
uses `monarch.actor.this_proc()`.
proc_mesh: Optional ProcMesh to spawn LocalFetcherActor on. If None, uses `this_proc()`.
process_name: Optional process name (e.g., "TrainActor") for logging. Auto-detected from the context if None.

Returns:
GlobalLoggingActor: The global logging controller.

Raises:
ValueError: If the logging state is inconsistent, i.e. the fetcher is already
registered, but only in the process or the global logger.
ValueError: If the logging state is inconsistent.

Example:
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric

# Main process setup
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")

# Initialize logging backends
await mlogger.init_backends({
await mlogger.init_backends.call_one({
"console": {"reduce_across_ranks": True},
"wandb": {"project": "my_project", "reduce_across_ranks": False}
})
@@ -66,12 +73,12 @@ async def get_or_create_metric_logger(

# Training loop
for step in range(max_steps):
record_metric("loss", 1.2, step, reduction_type=Reduce.MEAN)
record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
# ... training code with record_metric() calls ...
await mlogger.flush(step) # Log metrics for this step
await mlogger.flush.call_one(step) # Log metrics for this step

# Shutdown
await mlogger.shutdown()
await mlogger.shutdown.call_one()
"""
# Get or create the singleton global logger
global _global_logger
@@ -85,9 +92,15 @@
# Determine process context
proc = proc_mesh if proc_mesh is not None else this_proc()

# Auto-detect process_name from proc mesh if not provided
if process_name is None:
ctx = context()
process_name = ctx.actor_instance.actor_id.actor_name

# Check current state for consistency
proc_has_local_fetcher = hasattr(proc, "_local_fetcher")
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc)
proc_id = proc._uid if proc_has_local_fetcher else None
global_logger_has_local_fetcher = await global_logger.has_fetcher.call_one(proc_id)

# Consistency check: both should be in sync
if proc_has_local_fetcher != global_logger_has_local_fetcher:
@@ -102,24 +115,32 @@
# Setup local_fetcher_actor if needed (unless disabled by environment flag)
if not proc_has_local_fetcher and not FORGE_DISABLE_METRICS.get_value():
local_fetcher_actor = proc.spawn(
"local_fetcher_actor", LocalFetcherActor, global_logger
"local_fetcher_actor", LocalFetcherActor, global_logger, process_name
)
await global_logger.register_fetcher.call_one(local_fetcher_actor, proc)
# Generate a unique ID to map procmesh to fetcher
proc._uid = str(uuid.uuid4())
Member:
Thanks for the fix! LGTM

Contributor Author:
Do you mind approving it?

Member:
I am a little concerned about the broken CI. Wouldn't it cause all the subsequent commits to break as well?

Contributor Author (@felipemello1, Oct 17, 2025):
I don't think those errors are related to this PR, but let me confirm by opening a dummy PR.

proc._local_fetcher = local_fetcher_actor # pyre-ignore

await global_logger.register_fetcher.call_one(local_fetcher_actor, proc._uid)

return global_logger


class LocalFetcherActor(Actor):
"""Thin per-process actor used to trigger MetricCollector singleton
operations without direct access. It is what GlobalLoggingActor
uses to broadcast inits/flushes across ranks.
"""Actor spawned once per ProcMesh that, when called, runs on every rank in that ProcMesh
and accesses each rank's local MetricCollector.

GlobalLoggingActor -> per-rank LocalFetcherActor -> per-rank MetricCollector
Flow:
GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""

def __init__(self, global_logger: Union["GlobalLoggingActor", None] = None) -> None:
def __init__(
self,
global_logger: Union["GlobalLoggingActor", None] = None,
process_name: str | None = None,
) -> None:
self.global_logger = global_logger
self.process_name = process_name
_is_initialized = False

@endpoint
@@ -146,10 +167,22 @@ async def init_backends(
self,
metadata_per_primary_backend: dict[str, dict[str, Any]],
config: dict[str, Any],
global_step: int = 0,
) -> None:
"""Init local (per-rank) logger backends and MetricCollector."""
"""Init per-rank logger backends and MetricCollector.

Args:
metadata_per_primary_backend (dict[str, dict[str, Any]]): Metadata from primary backends for shared state.
config (dict[str, Any]): Backend configurations with logging modes and settings.
global_step (int): Initial step for metrics.
"""
collector = MetricCollector()
await collector.init_backends(metadata_per_primary_backend, config)
await collector.init_backends(
metadata_per_primary_backend,
config,
global_step,
process_name=self.process_name,
)

@endpoint
async def shutdown(self) -> None:
@@ -158,22 +191,17 @@ async def shutdown(self) -> None:


class GlobalLoggingActor(Actor):
"""Coordinates metric logging across all ranks for every training step.
"""Coordinates metric logging across all ProcMeshes and their ranks.

Supports multiple logging backends (e.g., WandB, TensorBoard, etc.),
for per-rank and/or global reduction logging modes.
with per-rank and/or global reduction logging modes.

If a backend config has flag `reduce_across_ranks=False`, an instance of the backend
is initialized per-rank, otherwise it is done once globally.

This GlobalLoggingActor should be spawned once in the controller. A LocalFetcherActor
is automatically spawned per-rank in `forge.controller.provisioner.py` and registered
with this actor. The LocalFetcherActor is responsible for instantiating
the per-rank MetricCollector.

In summary, the flow is:
- GlobalLoggingActor init_backends() -> LocalFetcherActor init_backends() -> per-rank MetricCollector
- GlobalLoggingActor flush() -> LocalFetcherActor flush() -> per-rank MetricCollector flush
Flow:
GlobalLoggingActor.method() -> per-procmesh LocalFetcherActor.method() -> per-rank MetricCollector.method() -> logger
"""

def __init__(self):
@@ -209,7 +237,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:

for backend_name, backend_config in config.items():
backend = get_logger_backend_class(backend_name)(backend_config)
await backend.init(role=BackendRole.GLOBAL)
await backend.init(role=BackendRole.GLOBAL, name="global_reduce")

# Extract metadata from primary logger to be shared with secondary loggers
# and store it
@@ -237,30 +265,31 @@ async def init_backends(self, config: dict[str, Any]) -> None:
await asyncio.gather(*tasks, return_exceptions=True)

@endpoint
async def register_fetcher(
self, fetcher: LocalFetcherActor, name: str | ProcMesh
) -> None:
"""Registers a fetcher with the global actor. Each key represents a process mesh.
If there are 2 processes, each with 2 replicas with N gpus, we would
have 4 keys, i.e. 2 proces meshes, each with 2 replicas."""
self.fetchers[name] = fetcher # pyre-ignore
async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> None:
"""Registers a LocalFetcherActor with the GlobalLoggingActor. One LocalFetcherActor per ProcMesh.

Args:
fetcher: The LocalFetcherActor instance for a ProcMesh
proc_id: Unique identifier for the ProcMesh
"""
self.fetchers[proc_id] = fetcher

# Self-init for respawned actors
if self.config:
logger.debug(f"Initializing new LocalFetcherActor {name}")
logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}")
await fetcher.init_backends.call(
self.metadata_per_primary_backend, self.config
)

@endpoint
async def deregister_fetcher(self, name: str | ProcMesh) -> None:
if name not in self.fetchers:
async def deregister_fetcher(self, proc_id: str) -> None:
if proc_id not in self.fetchers:
logger.warning(
f"Fetcher {name} not registered in GlobalLoggingActor. Cannot deregister."
f"Fetcher {proc_id} not registered in GlobalLoggingActor. Cannot deregister."
f"Available fetchers: {self.fetchers.keys()}"
)
return
del self.fetchers[name]
del self.fetchers[proc_id]

@endpoint
async def flush(self, global_step: int) -> None:
@@ -333,9 +362,9 @@ async def flush(self, global_step: int) -> None:
await logger_backend.log(reduced_metrics, global_step)

@endpoint
def has_fetcher(self, name: str | ProcMesh) -> bool:
"""Check if a fetcher is registered with the given name."""
return name in self.fetchers
def has_fetcher(self, proc_id: str) -> bool:
"""Check if a fetcher is registered with the given proc_id."""
return proc_id in self.fetchers

@endpoint
def get_fetcher_count(self) -> int:
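
Putting the pieces together, a minimal controller-side lifecycle sketch assembled from the docstring example in the diff above. The endpoint calls (init_backends/flush/shutdown via .call_one) and record_metric usage come from this PR; the backend config values and the plain asyncio entry point are illustrative assumptions.

# Sketch only: mirrors the get_or_create_metric_logger docstring example.
import asyncio

from forge.observability import record_metric, Reduce
from forge.observability.metric_actors import get_or_create_metric_logger


async def main() -> None:
    # Registers this process's LocalFetcherActor under a uuid-based proc_id
    # and returns the GlobalLoggingActor.
    mlogger = await get_or_create_metric_logger(process_name="Controller")

    # reduce_across_ranks=True -> one global backend instance;
    # False -> one backend instance per rank.
    await mlogger.init_backends.call_one({"console": {"reduce_across_ranks": True}})

    for step in range(3):
        record_metric("loss", 1.2, reduction_type=Reduce.MEAN)
        await mlogger.flush.call_one(step)  # global -> fetchers -> per-rank collectors

    await mlogger.shutdown.call_one()


if __name__ == "__main__":
    asyncio.run(main())  # assumes the controller can drive this loop directly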