76 changes: 46 additions & 30 deletions apps/grpo/main.py
@@ -24,6 +24,7 @@
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
from forge.util.ops import selective_log_softmax
from monarch.actor import endpoint
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer
@@ -43,7 +44,7 @@ class Episode:
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
ref_logprobs: torch.Tensor | None = None
ref_logits: torch.Tensor | None = None
reward: float | None = None
advantage: float | None = None

@@ -107,8 +108,8 @@ def collate(batches: list[list[Episode]]):
response = [e.response_tensor for e in batch]
response = torch.stack(response) # [b x s]

ref_logprobs = [e.ref_logprobs for e in batch]
ref_logprobs = torch.stack(ref_logprobs).squeeze() # [b x s]
ref_logits = [e.ref_logits for e in batch]
ref_logits = torch.stack(ref_logits).squeeze() # [b x s x v]

advantages = [e.advantage for e in batch]
advantages = torch.tensor(advantages).unsqueeze(-1) # [b x 1]
@@ -119,7 +120,7 @@ def collate(batches: list[list[Episode]]):
input = {"tokens": torch.cat([request, response], dim=1)}
target = {
"response": response,
"ref_logprobs": ref_logprobs,
"ref_logits": ref_logits,
"advantages": advantages,
"padding_mask": mask,
}
@@ -129,30 +130,35 @@ def collate(batches: list[list[Episode]]):


def compute_logprobs(
logits: torch.Tensor, input_ids: torch.Tensor, temperature: float = 1.0
logits: torch.Tensor, target_ids: torch.Tensor, temperature: float = 1.0
) -> torch.Tensor:
context_length = logits.shape[1] - input_ids.shape[1]
logits = logits[:, context_length - 1 : -1]
logprobs = torch.log_softmax(logits / temperature, dim=-1).to(input_ids.device)
logprobs = torch.gather(logprobs, 2, input_ids.unsqueeze(-1)).squeeze(-1)
return logprobs
logits = logits[:, -target_ids.size(1) - 1 : -1, :].float()  # logits at position i predict token i+1
scaled_logits = logits / temperature
logprobs = selective_log_softmax(scaled_logits, target_ids)
return logprobs.to(target_ids.device)
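`selective_log_softmax` is imported from `forge.util.ops` above; the sketch below shows the behavior this call is assumed to have (gather the log-probability of each target id from the temperature-scaled logits). The real forge implementation may differ in details such as chunking for memory, so treat this as illustrative only.

```python
import torch

def selective_log_softmax_sketch(logits: torch.Tensor, target_ids: torch.Tensor) -> torch.Tensor:
    # logits: [b, s, vocab]; target_ids: [b, s] -> per-token log-probs of shape [b, s]
    logprobs = torch.log_softmax(logits, dim=-1)
    return torch.gather(logprobs, dim=-1, index=target_ids.unsqueeze(-1)).squeeze(-1)
```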


def simple_grpo_loss(
logits: torch.Tensor,
response: torch.Tensor,
ref_logprobs: torch.Tensor,
ref_logits: torch.Tensor,
advantages: torch.Tensor,
padding_mask: torch.Tensor,
beta: float = 0.1,
) -> torch.Tensor:
print(f"num of padding: {padding_mask.sum(dim=1)}")
# assert ref_logits.dtype == torch.long
# assert logits.dtype == torch.long
logprobs = compute_logprobs(logits, response)
ref_logprobs = compute_logprobs(ref_logits, response)
kl = torch.exp(ref_logprobs - logprobs) - (ref_logprobs - logprobs) - 1
print(f"kl (no padding): {(kl * padding_mask).mean(dim=1)}")
# Pad out via padding mask
per_token_policy_loss = torch.exp(logprobs - logprobs.detach()) * advantages
per_token_loss = -(per_token_policy_loss - beta * kl)
loss = (
((per_token_loss * padding_mask).sum(dim=1))
/ (padding_mask.sum(dim=1).clamp(min=1.0))
(per_token_loss * padding_mask).sum(dim=1)
/ padding_mask.sum(dim=1).clamp(min=1.0)
).mean()
return loss
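The `kl` term in `simple_grpo_loss` is the non-negative k3-style estimator of KL(π_θ ‖ π_ref) built from per-token log-probabilities. A small self-contained check of that identity (illustrative only; the probabilities here are made up):

```python
import torch

p = torch.tensor([0.7, 0.2, 0.1])     # current policy probabilities pi_theta
q = torch.tensor([0.5, 0.3, 0.2])     # reference policy probabilities pi_ref
r = torch.log(q) - torch.log(p)       # ref_logprobs - logprobs, as in the loss above
k3 = torch.exp(r) - r - 1             # per-token estimator; always >= 0
exact_kl = torch.sum(p * torch.log(p / q))

assert (k3 >= 0).all()
assert torch.allclose(torch.sum(p * k3), exact_kl)  # expectation under pi_theta recovers KL
```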

@@ -299,28 +305,31 @@ async def continuous_rollouts():
target=target,
)

input_ids = torch.ones(
(group_size, max_req_tokens + max_req_tokens),
dtype=torch.long,
device="cuda",
)
# Populate episode info and calculate rewards
for i, (episode, response) in enumerate(zip(group.episodes, responses)):
# Populate episode info, compute ref logprobs, and calculate rewards
for episode, response in zip(group.episodes, responses):
episode.request_tokens = response.prompt_ids
episode.response_tokens = response.token_ids
episode.response = response.text
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
episode.ref_logits = await ref_model.forward.choose(
torch.cat(
[episode.request_tensor, episode.response_tensor]
).unsqueeze(0)
)
episode.reward = await reward_actor.evaluate_response.choose(
prompt=prompt, response=response.text, target=target
)

# Calculate reference logprobs
ref_logits = await ref_model.forward.choose(input_ids)
ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
for i, episode in enumerate(group.episodes):
episode.ref_logprobs = ref_logprobs[i]
del ref_logits, ref_logprobs, input_ids
# # Calculate reference logprobs
# print(f" input ids dtype: {input_ids.dtype}")
# ref_logits = await ref_model.forward.choose(input_ids)
# # ref_logits = ref_logits[:, :-1, :] # Exclude the last token
# # ref_logits = ref_logits[:, -max_res_tokens:, :]
# print(f" ref logits dtype: {ref_logits.dtype}")
# print("Computed ref logits")
# # ref_logprobs = compute_logprobs(ref_logits, input_ids[:, max_req_tokens:])
# for i, episode in enumerate(group.episodes):
# episode.ref_logits = ref_logits[i]
# del ref_logits, input_ids

# Calculate advantages and add to replay buffer
advantages = await compute_advantages.compute.choose(group)
@@ -342,15 +351,22 @@

async def continuous_training():
training_step = 0
_tokenizer = get_tokenizer("Qwen/Qwen3-1.7B")
while True:
batch = await replay_buffer.sample.choose(curr_policy_version=training_step)
if batch is None:
await asyncio.sleep(0.1)
else:
inputs, targets = batch
loss = await trainer.train_step.choose(inputs, targets)
tokens = inputs[0]["tokens"]
print(f"Training input: {_tokenizer.batch_decode(tokens)}")
print(f"Num of padding tokens: {targets[0]['padding_mask'].sum(dim=1)}")
metrics = await trainer.train_step.choose(inputs, targets)
training_step += 1
mlogger.log("loss/training_step", loss, training_step)
mlogger.log("loss/training_step", metrics["loss"], training_step)
mlogger.log(
"grad_norm/training_step", metrics["grad_norm"], training_step
)
await trainer.push_weights.call(training_step)
await policy.update_weights.call(training_step)

11 changes: 7 additions & 4 deletions apps/grpo/qwen3_1_7b.yaml
@@ -3,11 +3,11 @@

# Global configuration
group_size: 8
batch_size: 16
batch_size: 8
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-1.7B"
off_by_n: 1 # Off by one by default
off_by_n: 0 # Allowed policy-version staleness; 0 = fully synchronous

# Dataset configuration
dataset:
@@ -24,11 +24,14 @@ policy:
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: false
dtype: "float32"
gpu_memory_utilization: 0.9
sampling_config:
n: ${group_size}
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
seed: 42
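For reference, these sampling fields map onto vLLM's `SamplingParams`; a sketch of the equivalent direct construction, with the `${...}` interpolations resolved (the construction itself is illustrative, not part of this PR):

```python
from vllm import SamplingParams

params = SamplingParams(n=8, max_tokens=512, temperature=1.0, top_p=1.0, seed=42)
```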

# Trainer configuration
trainer:
@@ -47,7 +50,7 @@ trainer:
seq_len: 2048
max_norm: 1.0
steps: 1000000
dtype: bfloat16
dtype: float32
gc_freq: 1
compile:
enable: false
@@ -83,7 +86,7 @@ ref_model:
flavor: 1.7B
hf_assets_path: hf://${model}
training:
dtype: bfloat16
dtype: float32
gc_freq: 1
compile:
enable: false
1 change: 1 addition & 0 deletions apps/rl/main.py
@@ -167,6 +167,7 @@ async def run(cfg: DictConfig):
inputs, targets = await replay_buffer.sample.choose(curr_policy_version=0)
outputs = await trainer.train_step.choose(inputs, targets)
print("Loss: ", outputs["loss"])
print("Gradient Norm: ", outputs["grad_norm"])

print("Shutting down...")
await trainer.shutdown()
19 changes: 10 additions & 9 deletions src/forge/actors/policy.py
@@ -19,6 +19,15 @@
import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore.state_dict_utils import DELIM
from vllm.config import VllmConfig
@@ -43,15 +52,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.data_models.completion import Completion
from forge.data_models.prompt import to_prompt

from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@@ -77,6 +77,7 @@ class SamplingConfig:
temperature: float = 1.0
top_p: float = 1.0
logprobs: int = 1
seed: int | None = None

def __post_init__(self):
super().__init__()
10 changes: 6 additions & 4 deletions src/forge/actors/reference_model.py
@@ -13,6 +13,8 @@
from dataclasses import dataclass, field, fields

import torch

from forge.controller import ForgeActor
from monarch.actor import current_rank, current_size, endpoint
from torch.distributed.tensor import DTensor

@@ -26,8 +28,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

@@ -86,13 +86,15 @@ def __post_init__(self):
async def setup(self):
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.model = self.engine.model_parts[0] # Currently not using PP
self.model.eval()

@endpoint
async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
self.engine.gc_handler.run(self.step)
model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims
input_ids = input_ids.to("cuda")
# print(f"Ref model input_ids: {input_ids}")
# optional_context_parallel_ctx = (
# dist_utils.create_context_parallel_ctx(
# cp_mesh=parallel_dims.world_mesh["cp"],
@@ -112,7 +114,7 @@ async def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
with self.engine.train_context(optional_context_parallel_ctx):
with self.engine.maybe_enable_amp:
with torch.inference_mode():
logits = model_parts[0](input_ids)
logits = self.model(input_ids)
self.step += 1
if isinstance(logits, DTensor):
logits = logits.full_tensor()
27 changes: 21 additions & 6 deletions src/forge/actors/trainer.py
@@ -15,7 +15,8 @@
import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from monarch.actor import current_rank, current_size, endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
@@ -32,11 +33,10 @@
Parallelism,
Training,
)
from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from transformers import AutoModelForCausalLM

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
@@ -166,6 +166,10 @@ def forward_backward(
assert len(model_parts) == 1
with self.engine.maybe_enable_amp:
logits = model_parts[0](**inputs)
# hf_logits = self.hf_model(input_ids=inputs["tokens"]).logits.to(
# "cpu"
# )
# assert torch.allclose(logits, hf_logits)
loss = self.loss(logits, **targets)
# need to free to before bwd to avoid peaking memory
del logits
@@ -176,7 +180,7 @@
@endpoint
def train_step(
self, inputs: list[dict[str, Tensor]], targets: list[dict[str, Tensor]]
) -> float:
) -> dict[str, float]:
self.engine.gc_handler.run(self.step)
local_inputs = inputs[self.engine.dp_rank]
local_targets = targets[self.engine.dp_rank]
@@ -193,6 +197,17 @@ def train_step(
loss = self.forward_backward(local_inputs, local_targets)
torch.distributed.all_reduce(loss)

grad_norm = dist_utils.clip_grad_norm_(
[p for m in self.engine.model_parts for p in m.parameters()],
self.training.max_norm,
foreach=True,
pp_mesh=None,
# (
# self.engine.parallel_dims.world_mesh["pp"] if self.engine.parallel_dims.pp_enabled else None
# ),
ep_enabled=False, # parallel_dims.ep_enabled,
)

self.engine.optimizers.step()
self.engine.optimizers.zero_grad()
self.engine.lr_schedulers.step()
@@ -203,7 +218,7 @@
last_step=self.step == self.num_training_steps,
)

return loss.item()
return {"loss": loss.item(), "grad_norm": grad_norm}
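The `grad_norm` returned above comes from torchtitan's `dist_utils.clip_grad_norm_`, which is assumed here to clip gradients in-place to `max_norm` and return the total (pre-clip) gradient norm. A non-distributed sketch of the same idea with plain PyTorch (illustrative; the real call additionally handles sharded and pipeline-parallel meshes):

```python
import torch

model = torch.nn.Linear(4, 4)
loss = model(torch.randn(2, 4)).sum()
loss.backward()

# Clip gradients in-place to max_norm and get the total gradient norm for logging.
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, foreach=True)
print(f"grad_norm before clipping: {grad_norm.item():.4f}")
```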

@endpoint
async def push_weights(self, policy_version: int) -> None: