3 changes: 2 additions & 1 deletion apps/grpo/main.py
@@ -462,7 +462,8 @@ async def continuous_training():
t.step("update_weights")

if training_step >= 2:
await drop_weights(training_step - 1)
# TODO: figure out why setting to training_step - 1 will trigger error
await drop_weights(training_step - 2)
t.step("drop_weights")

t.stop()
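A plausible reading of this workaround, as an assumption on our part since the TODO above is still open: with off_by_n: 1, rollouts may still be generating against version v - 1 while the trainer pushes version v, so the newest version that is safe to delete is v - 2. A self-contained sketch of that retention rule (drop_weights below is a stand-in for the coroutine in apps/grpo/main.py):

    import asyncio

    async def drop_weights(version: int) -> None:
        # Stand-in for the drop_weights coroutine in apps/grpo/main.py.
        print(f"dropping weights for policy version {version}")

    async def drop_stale_weights(training_step: int, retain: int = 2) -> None:
        # Keep the version being pushed and the one rollouts may still be
        # reading (off_by_n: 1); only older versions are assumed safe to delete.
        if training_step >= retain:
            await drop_weights(training_step - retain)

    asyncio.run(drop_stale_weights(5))  # drops version 3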
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
@@ -45,6 +45,7 @@ policy:

 # Trainer configuration
 trainer:
+  use_dcp: false
   model:
     name: qwen3
     flavor: 1.7B
17 changes: 9 additions & 8 deletions apps/grpo/qwen3_32b.yaml
@@ -3,18 +3,18 @@
 # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

 # Global configuration
-group_size: 2
-local_batch_size: 8 # per-device batch size
+group_size: 16
+local_batch_size: 16
 max_req_tokens: 512
-max_res_tokens: 512
+max_res_tokens: 1536
 model: "Qwen/Qwen3-32B"
 off_by_n: 1 # Off by one by default

 provisioner:
   launcher: slurm

 # Main loop configuration
-rollout_threads: 1 # Recommended to set equal to policy.num_replicas
+rollout_threads: 16 # set equal to local_batch_size for now

 # Observability configuration
 metric_logging:
@@ -48,6 +48,7 @@ policy:

 # Trainer configuration
 trainer:
+  use_dcp: false
   model:
     name: qwen3
     flavor: 32B
@@ -69,8 +70,8 @@ trainer:
   enable: false
   parallelism:
     data_parallel_replicate_degree: 1
-    data_parallel_shard_degree: -1
-    tensor_parallel_degree: 1
+    data_parallel_shard_degree: 1
+    tensor_parallel_degree: 8
     pipeline_parallel_degree: 1
     context_parallel_degree: 1
     expert_parallel_degree: 1
@@ -90,7 +91,7 @@ replay_buffer:
   batch_size: ${local_batch_size}
   max_policy_age: ${off_by_n}
   # dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
-  dp_size: 8
+  dp_size: 1

 # Reference model configuration
 ref_model:
@@ -119,7 +120,7 @@ ref_model:
 services:
   policy:
     procs: ${policy.engine_args.tensor_parallel_size}
-    num_replicas: 1
+    num_replicas: 4
     hosts: 1
     with_gpus: true
   ref_model:
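Back-of-envelope accounting for the 32B config above, assuming policy.engine_args.tensor_parallel_size is 8 in the collapsed hunk (that value is not visible in this diff):

    # Hypothetical GPU accounting; names mirror the YAML keys above.
    trainer_gpus = 1 * 1 * 8 * 1 * 1   # replicate * shard * tensor * pipeline * context
    trainer_dp_degree = 1 * 1          # replicate * shard; must equal replay_buffer.dp_size
    policy_gpus_per_replica = 8        # procs = ${policy.engine_args.tensor_parallel_size}
    policy_gpus = policy_gpus_per_replica * 4  # num_replicas: 4 -> 32 GPUs for generation
    assert trainer_dp_degree == 1      # matches dp_size: 1 above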
2 changes: 1 addition & 1 deletion apps/grpo/qwen3_8b.yaml
@@ -41,7 +41,7 @@ policy:

 # Trainer configuration
 trainer:
-  use_dcp: true
+  use_dcp: false
   model:
     name: qwen3
     flavor: 8B
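All three configs now set use_dcp: false, steering weight transfer away from the single DCP handle and through per-parameter torchstore entries, which are exactly the paths this PR parallelizes. A minimal runnable sketch of the toggle; the key layout v{version}/... and the put stand-in are our illustration, not the real torchstore API:

    import asyncio

    async def put(key: str, value) -> None:
        # Stand-in for ts.put; real values are tensors or a DCP handle.
        await asyncio.sleep(0)

    async def parallel_put(kv_pairs) -> None:
        # Same fan-out as the new _parallel_put helper in trainer.py below.
        await asyncio.gather(*(put(k, v) for k, v in kv_pairs))

    async def push_weights(use_dcp: bool, version: int, state_dict: dict) -> None:
        if use_dcp:
            # Whole state dict behind a single key (cf. get_dcp_whole_state_dict_key).
            await put(f"v{version}/__dcp_handle__", state_dict)
        else:
            # One key per parameter; with use_dcp: false this branch is taken,
            # and the per-parameter puts can overlap.
            await parallel_put((f"v{version}/{n}", p) for n, p in state_dict.items())

    asyncio.run(push_weights(False, 1, {"w": 0, "b": 1}))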
95 changes: 59 additions & 36 deletions src/forge/actors/policy.py
@@ -9,6 +9,7 @@
 import asyncio
 import logging
 import os
+import socket
 import sys
 from collections.abc import Mapping
 from copy import copy
@@ -18,6 +19,7 @@
 import torch.distributed.checkpoint as dcp
 import torchstore as ts
 from monarch.actor import current_rank, endpoint, ProcMesh
+from torch.profiler import profile, ProfilerActivity, record_function
 from torchstore.state_dict_utils import DELIM
 from vllm.config import VllmConfig

@@ -61,6 +63,21 @@
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

+hostname = socket.gethostname()
+
+
+def trace_handler(rank, p):
+    # Write one Chrome trace per host/rank/profiler step (note: hardcoded output dir).
+    p.export_chrome_trace(
+        f"/mnt/data/yuxuanh/profiler/{hostname}_trace_rank_{rank}_"
+        + str(p.step_num)
+        + ".json"
+    )
+
+
+async def _ts_parallel_get(keys: list[str]) -> list[torch.Tensor]:
+    # Issue all torchstore reads up front and await them together, so the
+    # per-parameter round-trips overlap instead of running serially.
+    coros = [ts.get(key) for key in keys]
+    return await asyncio.gather(*coros)


 @dataclass
 class Policy(PolicyInterface):
@@ -552,42 +569,48 @@ async def _load_tensor_parallel_state_dict(
     @endpoint
     async def update(self, version: int):
         """Update model weights by reading state dict from torchstore"""
-        logger.info(
-            f"[PolicyWorker::update] start updating weights to version {version}"
-        )
-        model = self.worker.model_runner.model
-        prefix = get_param_prefix(version)
-        logger.debug(f"{prefix=}")
-        matching_keys = await ts.keys(prefix)
-        logger.debug(f"{matching_keys=}")
-        dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
-        loaded_weights = set()
-        t = Tracer("policy_worker_perf/update", timer="gpu")
-        t.start()
-        # Entire state dict is stored in a single DCP handle
-        if dcp_whole_state_dict_key in matching_keys:
-            logger.info(
-                f"Loading {dcp_whole_state_dict_key} from DCP with handle {dcp_whole_state_dict_key}"
-            )
-            dcp_handle = await ts.get(dcp_whole_state_dict_key)
-            hf_param_names = dcp_handle.param_names
-            for name in hf_param_names:
-                param = load_tensor_from_dcp(dcp_handle, name)
-                loaded = model.load_weights([(name, param)])
-                del param
-                loaded_weights.update(loaded)
-        else:  # Load each parameter from torchstore directly without DCP
-            hf_param_names = [extract_param_name(key) for key in matching_keys]
-            # We can't pass a generator since vllm load_weights is not async.
-            # Instead, we just call load_weights with one parameter at a time.
-            for name in hf_param_names:
-                param_key = get_param_key(version, name)
-                param = await ts.get(param_key)
-                loaded = model.load_weights([(name, param)])
-                del param
-                loaded_weights.update(loaded)
-        t.stop()
-        logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")
+        with profile(
+            activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
+            record_shapes=True,
+            on_trace_ready=lambda p: trace_handler(self.rank, p),
+            with_stack=True,
+            profile_memory=True,
+        ) as prof:
+            with record_function("policy_worker_perf/update"):
+                logger.info(
+                    f"[PolicyWorker::update] start updating weights to version {version}"
+                )
+                model = self.worker.model_runner.model
+                prefix = get_param_prefix(version)
+                logger.debug(f"{prefix=}")
+                matching_keys = await ts.keys(prefix)
+                logger.debug(f"{matching_keys=}")
+                dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
+                loaded_weights = set()
+                t = Tracer("policy_worker_perf/update", timer="gpu")
+                t.start()
+                # Entire state dict is stored in a single DCP handle
+                if dcp_whole_state_dict_key in matching_keys:
+                    logger.info(
+                        f"Loading whole state dict from DCP handle {dcp_whole_state_dict_key}"
+                    )
+                    dcp_handle = await ts.get(dcp_whole_state_dict_key)
+                    hf_param_names = dcp_handle.param_names
+                    for name in hf_param_names:
+                        param = load_tensor_from_dcp(dcp_handle, name)
+                        loaded = model.load_weights([(name, param)])
+                        del param
+                        loaded_weights.update(loaded)
+                else:  # Load each parameter from torchstore directly without DCP
+                    hf_param_names = [extract_param_name(key) for key in matching_keys]
+                    param_keys = [
+                        get_param_key(version, name) for name in hf_param_names
+                    ]
+                    # Fetch all parameters concurrently, then hand them to vLLM in one call.
+                    new_params = await _ts_parallel_get(param_keys)
+                    loaded = model.load_weights(zip(hf_param_names, new_params))
+                    loaded_weights.update(loaded)
+                t.stop()
+                logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}")

     @endpoint
     async def setup_kv_cache(self):
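For readers unfamiliar with the fan-out pattern behind _ts_parallel_get: issuing every ts.get before awaiting lets the store round-trips overlap instead of serializing. A self-contained sketch of the same pattern, with a sleep standing in for the network hop:

    import asyncio

    async def fetch(key: str) -> str:
        # Stand-in for ts.get: pretend each read costs one store round-trip.
        await asyncio.sleep(1.0)
        return f"tensor-for-{key}"

    async def main() -> None:
        keys = [f"v1/model.layers.{i}.weight" for i in range(8)]
        # All eight reads are in flight together, so this takes ~1s, not ~8s.
        tensors = await asyncio.gather(*(fetch(k) for k in keys))
        print(len(tensors), "tensors fetched")

    asyncio.run(main())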
19 changes: 10 additions & 9 deletions src/forge/actors/trainer.py
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import asyncio
 import logging
 import math
 import os
@@ -12,7 +13,7 @@
 import time
 from collections.abc import Mapping
 from dataclasses import dataclass, field, fields
-from typing import Callable
+from typing import Callable, Iterable

 import torch
 import torch.distributed.checkpoint as dcp
@@ -38,11 +39,7 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig

-from forge.actors._torchstore_utils import (
-    DcpHandle,
-    get_dcp_whole_state_dict_key,
-    get_param_key,
-)
+from forge.actors._torchstore_utils import DcpHandle, get_dcp_whole_state_dict_key, get_param_key

 from forge.controller import ForgeActor
 from forge.data.utils import batch_to_device
@@ -92,6 +89,12 @@ def cleanup_old_weight_versions(
logger.debug(f"Error deleting {item_path}: {e}")


async def _parallel_put(kv_pairs: Iterable[tuple[str, Tensor]]):
keys, tensors = zip(*kv_pairs)
coros = [ts.put(key, tensor.detach().cpu()) for key, tensor in zip(keys, tensors)]
await asyncio.gather(*coros)


@dataclass
class RLTrainer(ForgeActor):
job: Job = field(default_factory=Job)
@@ -334,9 +337,7 @@ async def push_weights(self, policy_version: int) -> None:
             await ts.put(key, dcp_handle)
             t.step("dcp_save")
         else:
-            for name, param in hf_state_dict.items():
-                key = get_param_key(policy_version, name)
-                await ts.put(key, param)
+            # Version the keys exactly as the serial loop did, then fan out the
+            # puts; PolicyWorker.update reads them back via get_param_key.
+            await _parallel_put(
+                (get_param_key(policy_version, name), param)
+                for name, param in hf_state_dict.items()
+            )
             t.step("ts_save")
         t.stop()
         end_time = time.perf_counter()
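If fanning out one put per parameter ever overwhelms the store, a bounded variant is a common alternative. This sketch is ours, not part of forge or torchstore; max_inflight and fake_put are illustrative:

    import asyncio

    async def parallel_put_bounded(kv_pairs, put, max_inflight: int = 32) -> None:
        # Bounded fan-out: same overlap as gathering all puts at once, but with
        # at most max_inflight requests outstanding at any moment.
        sem = asyncio.Semaphore(max_inflight)

        async def one(key, value):
            async with sem:
                await put(key, value)

        await asyncio.gather(*(one(k, v) for k, v in kv_pairs))

    async def demo() -> None:
        async def fake_put(key, value):
            await asyncio.sleep(0.01)
        await parallel_put_bounded(((f"k{i}", i) for i in range(100)), fake_put)

    asyncio.run(demo())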
6 changes: 6 additions & 0 deletions src/forge/env.py
@@ -105,6 +105,12 @@ def get_value(self) -> Any:
description="Sets the maximum frame length for Monarch's actor message delivery in bytes.",
)

OMP_NUM_THREADS = EnvVar(
name="OMP_NUM_THREADS",
default=16, # Recommended <= # cores / # of gpus since we are using 1 gpu per process
description="Sets the number of OpenMP threads to use. This is used for CPU-bound operations in PyTorch.",
)


def all_env_vars() -> list[EnvVar]:
"""Retrieves all registered environment variable names."""