diff --git a/apps/grpo/main.py b/apps/grpo/main.py index c64f00bc..88e6c3cc 100644 --- a/apps/grpo/main.py +++ b/apps/grpo/main.py @@ -462,7 +462,8 @@ async def continuous_training(): t.step("update_weights") if training_step >= 2: - await drop_weights(training_step - 1) + # TODO: figure out why setting this to training_step - 1 triggers an error + await drop_weights(training_step - 2) t.step("drop_weights") t.stop() diff --git a/apps/grpo/qwen3_1_7b.yaml b/apps/grpo/qwen3_1_7b.yaml index 8ff427ad..d3b73dc9 100644 --- a/apps/grpo/qwen3_1_7b.yaml +++ b/apps/grpo/qwen3_1_7b.yaml @@ -45,6 +45,7 @@ policy: # Trainer configuration trainer: + use_dcp: false model: name: qwen3 flavor: 1.7B diff --git a/apps/grpo/qwen3_32b.yaml b/apps/grpo/qwen3_32b.yaml index 8fc056a6..3e660091 100644 --- a/apps/grpo/qwen3_32b.yaml +++ b/apps/grpo/qwen3_32b.yaml @@ -3,10 +3,10 @@ # NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability # Global configuration -group_size: 2 -local_batch_size: 8 # per-device batch size +group_size: 16 +local_batch_size: 16 max_req_tokens: 512 -max_res_tokens: 512 +max_res_tokens: 1536 model: "Qwen/Qwen3-32B" off_by_n: 1 # Off by one by default @@ -14,7 +14,7 @@ provisioner: launcher: slurm # Main loop configuration -rollout_threads: 1 # Recommended to set equal to policy.num_replicas +rollout_threads: 16 # equal to batch size for now # Observability configuration metric_logging: @@ -48,6 +48,7 @@ policy: # Trainer configuration trainer: + use_dcp: false model: name: qwen3 flavor: 32B @@ -69,8 +70,8 @@ trainer: enable: false parallelism: data_parallel_replicate_degree: 1 - data_parallel_shard_degree: -1 - tensor_parallel_degree: 1 + data_parallel_shard_degree: 1 + tensor_parallel_degree: 8 pipeline_parallel_degree: 1 context_parallel_degree: 1 expert_parallel_degree: 1 @@ -90,7 +91,7 @@ replay_buffer: batch_size: ${local_batch_size} max_policy_age: ${off_by_n} # dp_size: 
${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree - dp_size: 8 + dp_size: 1 # Reference model configuration ref_model: @@ -119,7 +120,7 @@ ref_model: services: policy: procs: ${policy.engine_args.tensor_parallel_size} - num_replicas: 1 + num_replicas: 4 hosts: 1 with_gpus: true ref_model: diff --git a/apps/grpo/qwen3_8b.yaml b/apps/grpo/qwen3_8b.yaml index fedf2f36..be87aa87 100644 --- a/apps/grpo/qwen3_8b.yaml +++ b/apps/grpo/qwen3_8b.yaml @@ -41,7 +41,7 @@ policy: # Trainer configuration trainer: - use_dcp: true + use_dcp: false model: name: qwen3 flavor: 8B @@ -53,7 +53,7 @@ trainer: lr_scheduler: warmup_steps: 1 training: - local_local_batch_size: ${local_batch_size} + local_batch_size: ${local_batch_size} seq_len: 2048 max_norm: 1.0 steps: 1000000 diff --git a/src/forge/actors/policy.py b/src/forge/actors/policy.py index 686ec973..9a6705ea 100644 --- a/src/forge/actors/policy.py +++ b/src/forge/actors/policy.py @@ -9,6 +9,7 @@ import asyncio import logging import os +import socket import sys from collections.abc import Mapping from copy import copy @@ -61,6 +62,13 @@ logger = logging.getLogger(__name__) logger.setLevel(logging.INFO) +hostname = socket.gethostname() + + +async def _ts_parallel_get(keys: list[str]) -> list[torch.Tensor]: + coros = [ts.get(key) for key in keys] + return await asyncio.gather(*coros) + @dataclass class Policy(PolicyInterface): @@ -578,14 +586,10 @@ async def update(self, version: int): loaded_weights.update(loaded) else: # Load each parameter from torchstore directly without DCP hf_param_names = [extract_param_name(key) for key in matching_keys] - # We can't pass a generator since vllm load_weights is not async. - # Instead, we just call load_weights with one parameter at a time. 
- for name in hf_param_names: - param_key = get_param_key(version, name) - param = await ts.get(param_key) - loaded = model.load_weights([(name, param)]) - del param - loaded_weights.update(loaded) + param_keys = [get_param_key(version, name) for name in hf_param_names] + new_params = await _ts_parallel_get(param_keys) + loaded = model.load_weights(zip(hf_param_names, new_params)) + loaded_weights.update(loaded) t.stop() logger.debug(f"[PolicyWorker::update] Loaded weights: {loaded_weights}") diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py index f4199db7..a135c3a7 100644 --- a/src/forge/actors/trainer.py +++ b/src/forge/actors/trainer.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import asyncio import logging import math import os @@ -12,12 +13,23 @@ import time from collections.abc import Mapping from dataclasses import dataclass, field, fields -from typing import Callable +from typing import Callable, Iterable import torch import torch.distributed.checkpoint as dcp import torchstore as ts +from forge.actors._torchstore_utils import ( + DcpHandle, + get_dcp_whole_state_dict_key, + get_param_key, +) + +from forge.controller import ForgeActor +from forge.data.utils import batch_to_device +from forge.observability.metrics import record_metric, Reduce +from forge.observability.perf_tracker import Tracer + from monarch.actor import current_rank, current_size, endpoint from torch import Tensor from torch.distributed.checkpoint._nested_dict import flatten_state_dict @@ -38,17 +50,6 @@ from torchtitan.experiments.forge.engine import ForgeEngine from torchtitan.experiments.forge.job_config import ForgeJobConfig -from forge.actors._torchstore_utils import ( - DcpHandle, - get_dcp_whole_state_dict_key, - get_param_key, -) - -from forge.controller import ForgeActor -from forge.data.utils import batch_to_device -from forge.observability.metrics import 
record_metric, Reduce -from forge.observability.perf_tracker import Tracer - logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) @@ -336,7 +337,7 @@ async def push_weights(self, policy_version: int) -> None: else: for name, param in hf_state_dict.items(): key = get_param_key(policy_version, name) - await ts.put(key, param) + await ts.put(key, param.detach().cpu()) t.step("ts_save") t.stop() end_time = time.perf_counter() diff --git a/src/forge/env.py b/src/forge/env.py index 1699ecc9..071f7a43 100644 --- a/src/forge/env.py +++ b/src/forge/env.py @@ -105,6 +105,12 @@ def get_value(self) -> Any: description="Sets the maximum frame length for Monarch's actor message delivery in bytes.", ) +OMP_NUM_THREADS = EnvVar( + name="OMP_NUM_THREADS", + default=16, # Recommended to be at most (CPU cores / GPUs), since each process uses one GPU + description="Sets the number of OpenMP threads to use. This is used for CPU-bound operations in PyTorch.", +) + def all_env_vars() -> list[EnvVar]: """Retrieves all registered environment variable names."""