apps/grpo/main.py (6 additions, 6 deletions)
```diff
@@ -32,7 +32,7 @@
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
 
-from forge.types import LauncherConfig, ProvisionerConfig
+from forge.types import LauncherConfig, ProvisionerConfig, Launcher
 from forge.util.ops import compute_logprobs
 from monarch.actor import endpoint
 from omegaconf import DictConfig
@@ -314,12 +314,12 @@ async def main(cfg: DictConfig):
     max_res_tokens = cfg.max_res_tokens
 
     # ---- Global setups ---- #
-    if cfg.get("provisioner", None) is not None:
-        await init_provisioner(
-            ProvisionerConfig(
-                launcher_config=LauncherConfig(**cfg.provisioner.launcher)
-            )
+    # if cfg.get("provisioner", None) is not None:
+    await init_provisioner(
+        ProvisionerConfig(
+            launcher_config=LauncherConfig(launcher=Launcher("slurm")),
+        )
     )
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
     mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)
```
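Note the behavioral change in this hunk: with the guard commented out, the provisioner is now initialized unconditionally, and the launcher is pinned to Slurm rather than read from `cfg.provisioner`. The `Launcher("slurm")` expression relies on enum lookup-by-value; a minimal sketch of the assumed shape of `forge.types.Launcher` (the member name is illustrative, only the `"slurm"` value is confirmed by this diff):

```python
from enum import Enum

# Hypothetical model of forge.types.Launcher; the member name SLURM is an
# assumption, only the string value "slurm" appears in this PR.
class Launcher(Enum):
    SLURM = "slurm"

# Enum lookup-by-value, as used in the new main.py call:
assert Launcher("slurm") is Launcher.SLURM
```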
apps/grpo/qwen3_32b.yaml (8 additions, 7 deletions)
```diff
@@ -3,15 +3,16 @@
 # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability
 
 # Global configuration
-group_size: 2
+group_size: 8
 batch_size: 8
 max_req_tokens: 512
 max_res_tokens: 512
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default
 
 provisioner:
-  launcher: slurm
+  launcher_config:
+    launcher: slurm
 
 # Main loop configuration
 rollout_threads: 1 # Recommended to set equal to policy.num_replicas
```
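The extra `launcher_config:` level makes the YAML mirror the dataclass nesting in `forge.types` (`ProvisionerConfig.launcher_config` wrapping `LauncherConfig.launcher`), although the new `main.py` above no longer reads this block at all. A sketch of the intended correspondence, using only names that appear in this PR:

```python
from forge.types import Launcher, LauncherConfig, ProvisionerConfig

# provisioner.launcher_config.launcher in the YAML maps onto:
provisioner = ProvisionerConfig(
    launcher_config=LauncherConfig(launcher=Launcher("slurm"))
)
```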
```diff
@@ -37,7 +38,7 @@ dataset:
 policy:
   engine_config:
     model: ${model}
-    tensor_parallel_size: 4
+    tensor_parallel_size: 8
     pipeline_parallel_size: 1
     enforce_eager: false
   sampling_config:
```
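Doubling `tensor_parallel_size` means each policy replica now shards the 32B model across 8 GPUs instead of 4, assuming vLLM-style semantics where GPUs per replica is the product of TP and PP:

```python
# GPUs required per policy replica under the new engine_config
# (simple arithmetic, assuming vLLM-style TP x PP sharding):
tensor_parallel_size = 8
pipeline_parallel_size = 1
gpus_per_replica = tensor_parallel_size * pipeline_parallel_size  # 8
```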
```diff
@@ -69,8 +70,8 @@ trainer:
     enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: -1
-    tensor_parallel_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 4
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
```
```diff
@@ -136,8 +137,8 @@ actors:
     procs: 1
     with_gpus: false
   trainer:
-    procs: 8
-    hosts: 1
+    procs: 4
+    # hosts: 1
     with_gpus: true
   replay_buffer:
     procs: 1
```
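The two trainer hunks are consistent with each other: the parallelism degrees now multiply out to the smaller proc count, and commenting out `hosts: 1` suggests the trainer no longer requests a dedicated host. A quick consistency check, assuming torchtitan-style semantics where the product of the degrees equals the trainer world size and `data_parallel_shard_degree: -1` previously meant "infer from world size":

```python
# Trainer world size implied by the new parallelism block:
degrees = {
    "data_parallel_replicate": 1,
    "data_parallel_shard": 1,   # was -1 (inferred) before this PR
    "tensor_parallel": 4,       # was 1
    "pipeline_parallel": 1,
    "context_parallel": 1,
    "expert_parallel": 1,
}
world_size = 1
for degree in degrees.values():
    world_size *= degree
assert world_size == 4  # matches actors.trainer.procs: 4 (down from 8)
```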
src/forge/controller/actor.py (2 additions, 0 deletions)
```diff
@@ -166,6 +166,8 @@ async def launch(cls, *args, **kwargs) -> "ForgeActor":
             mesh_name=cls.mesh_name,
         )
 
+        print(f"Spawning proc mesh {cls.mesh_name} with gpus {cls.with_gpus}")
+
         proc_mesh = await get_proc_mesh(process_config=cfg)
 
         actor_name = kwargs.pop("name", cls.__name__)
```
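The added line is plain `print` debugging. If it is meant to stay, a stdlib-`logging` variant would be easier to filter in multi-process runs (a sketch only; whether this module has an existing logging convention is not shown in the diff):

```python
import logging

logger = logging.getLogger(__name__)

# Drop-in replacement for the added print inside ForgeActor.launch(),
# where cls refers to the actor class being spawned:
logger.debug("Spawning proc mesh %s with gpus %s", cls.mesh_name, cls.with_gpus)
```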
src/forge/types.py (3 additions, 3 deletions)
```diff
@@ -159,9 +159,9 @@ class LauncherConfig:
     """A launcher config for the scheduler."""
 
     launcher: Launcher
-    job_name: str
-    services: dict[str, ServiceConfig]
-    actors: dict[str, ProcessConfig]
+    job_name: str | None = None
+    services: dict[str, ServiceConfig] | None = None
+    actors: dict[str, ProcessConfig] | None = None
 
 
 @dataclass
```
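Making the remaining fields optional is what allows the new `main.py` call site to construct a `LauncherConfig` from the launcher alone:

```python
from forge.types import Launcher, LauncherConfig

# Valid after this change; previously job_name, services, and actors
# were required fields:
cfg = LauncherConfig(launcher=Launcher("slurm"))
assert cfg.job_name is None and cfg.services is None and cfg.actors is None
```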