forked from meta-pytorch/torchforge
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
555 lines (472 loc) · 18.7 KB
/
main.py
File metadata and controls
555 lines (472 loc) · 18.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
# Usage: python -m apps.grpo.main --config apps/grpo/qwen3_1_7b.yaml
import asyncio
import time
import uuid
from dataclasses import dataclass
from typing import Any, Callable
import torch
import torch.nn.functional as F
import torchstore as ts
from datasets import load_dataset
from forge.actors._torchstore_utils import (
get_dcp_whole_state_dict_key,
get_param_prefix,
)
from forge.actors.policy import Policy
from forge.actors.reference_model import ReferenceModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.actors.trainer import RLTrainer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer
from forge.types import LauncherConfig, ProvisionerConfig
from forge.util.ops import compute_logprobs
from monarch.actor import endpoint
from omegaconf import DictConfig
from vllm.transformers_utils.tokenizer import get_tokenizer
@dataclass
class Episode:
# TODO: add adtional layer for multi-turn
episode_id: str
request: str
policy_version: int
pad_id: int
request_len: int
response_len: int
target: Any | None = None
# processed data
response: str | None = None
request_tokens: list[int] | None = None
response_tokens: list[int] | None = None
ref_logprobs: torch.Tensor | None = None
reward: float | None = None
advantage: float | None = None
@property
def request_tensor(self):
tensor = torch.tensor(self.request_tokens, dtype=torch.long)
if tensor.shape[0] < self.request_len: # left pad
diff = self.request_len - tensor.shape[0]
tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
return tensor
@property
def response_tensor(self):
tensor = torch.tensor(self.response_tokens, dtype=torch.long)
if tensor.shape[0] < self.response_len: # right pad
diff = self.response_len - tensor.shape[0]
tensor = F.pad(tensor, (0, diff), value=self.pad_id)
return tensor
@dataclass
class Group:
    """The episodes sampled for one prompt, i.e. one GRPO group."""

    group_id: str
    episodes: list[Episode]

    @classmethod
    def new_group(
        cls,
        group_id: int,
        group_size: int,
        request: str,
        policy_version: int,
        pad_id: int,
        request_len: int,
        response_len: int,
        target: Any = None,
    ):
        """Build a group of ``group_size`` fresh episodes sharing one prompt.

        Each episode gets its own random UUID; every other constructor
        argument is shared. ``group_id`` is stored as a string.
        """
        shared_fields = dict(
            request=request,
            policy_version=policy_version,
            pad_id=pad_id,
            request_len=request_len,
            response_len=response_len,
            target=target,
        )
        members = [
            Episode(episode_id=str(uuid.uuid4()), **shared_fields)
            for _ in range(group_size)
        ]
        return cls(group_id=str(group_id), episodes=members)
def collate(batches: "list[list[Episode]]"):
    """Convert lists of episodes into ``(inputs, targets)`` for the trainer.

    Each inner list becomes one micro-batch. For a micro-batch of ``b`` episodes:

    * ``inputs[k]["tokens"]``: padded request tokens concatenated with padded
      response tokens, ``[b x (request_len + response_len)]``.
    * ``targets[k]``: a dict with ``response`` ``[b x response_len]``,
      ``ref_logprobs`` flattened to one row per episode ``[b x -1]``,
      ``advantages`` ``[b x 1]``, and ``padding_mask`` (``response != pad_id``).

    Args:
        batches: Micro-batches of episodes whose tokens, ``ref_logprobs`` and
            ``advantage`` have already been populated.

    Returns:
        Two lists with one entry per micro-batch: model inputs and loss targets.
    """
    inputs = []
    targets = []
    for batch in batches:
        request = torch.stack([e.request_tensor for e in batch])  # [b x s_req]
        response = torch.stack([e.response_tensor for e in batch])  # [b x s_res]
        # One row per episode. A bare `.squeeze()` would also drop the batch
        # dim when b == 1 (and the sequence dim when it has length 1), which
        # breaks broadcasting against `response` / `padding_mask` in the loss.
        ref_logprobs = torch.stack([e.ref_logprobs for e in batch]).reshape(
            len(batch), -1
        )
        advantages = torch.tensor([e.advantage for e in batch]).unsqueeze(-1)  # [b x 1]
        # All episodes in a micro-batch are assumed to share one pad id.
        pad_id = batch[0].pad_id
        mask = response != pad_id

        model_input = {"tokens": torch.cat([request, response], dim=1)}
        target = {
            "response": response,
            "ref_logprobs": ref_logprobs,
            "advantages": advantages,
            "padding_mask": mask,
        }
        inputs.append(model_input)
        targets.append(target)
    return inputs, targets
def simple_grpo_loss(
    logits: torch.Tensor,
    response: torch.Tensor,
    ref_logprobs: torch.Tensor,
    advantages: torch.Tensor,
    padding_mask: torch.Tensor,
    beta: float = 0.1,
) -> torch.Tensor:
    """Example GRPO loss function for ``RLTrainer``.

    Per token, the loss is ``-(ratio * advantage - beta * kl)``. It is averaged
    over the non-padding tokens of each sequence (count clamped to >= 1), then
    averaged over the batch.

    Args:
        logits: Policy logits; turned into per-token log-probs of ``response``
            by ``compute_logprobs`` (exact shape contract lives there).
        response: Response token ids, ``[b x s]`` as produced by ``collate``.
        ref_logprobs: Reference-model log-probs for the same tokens.
        advantages: Per-sequence advantages, ``[b x 1]`` (broadcast over tokens).
        padding_mask: ``True`` where a response token is not padding.
        beta: Weight of the KL penalty.

    Returns:
        A scalar loss tensor.
    """
    logprobs: torch.Tensor = compute_logprobs(logits, response)

    # Note: This is also available in losses.grpo_loss via `SimpleGRPOLoss`
    # KL estimator exp(d) - d - 1 with d = ref_logprob - policy_logprob.
    log_ratio = ref_logprobs - logprobs
    kl = torch.exp(log_ratio) - log_ratio - 1

    # exp(x - x.detach()) is 1 in value but carries the gradient of x, so this
    # term contributes advantage * d(logprob) to the backward pass.
    policy_term = torch.exp(logprobs - logprobs.detach()) * advantages
    per_token_loss = -(policy_term - beta * kl)

    masked_sum = (per_token_loss * padding_mask).sum(dim=1)
    token_count = padding_mask.sum(dim=1).clamp(min=1.0)
    return (masked_sum / token_count).mean()
@dataclass
class RewardActor(ForgeActor):
    """Reward actor that scores a response with a list of reward functions."""

    # Each entry is called as fn(prompt, response, target); results are
    # summed into a float total and averaged.
    reward_functions: list[Callable]

    @endpoint
    async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
        """Score ``response`` with every reward function and return their mean.

        For each function this records sum/avg/std metrics of its score, a
        call counter, and an ``avg_total_reward`` sample, all under
        ``reward/evaluate_response/``.

        Raises:
            ZeroDivisionError: if ``reward_functions`` is empty.
        """
        score_total = 0.0
        for scorer in self.reward_functions:
            score = scorer(prompt, response, target)
            score_total += score

            # Functions and lambdas carry __name__; callable instances fall
            # back to their class name.
            scorer_name = getattr(scorer, "__name__", scorer.__class__.__name__)

            metric_specs = (
                (f"reward/evaluate_response/sum_{scorer_name}_reward", score, Reduce.SUM),
                (f"reward/evaluate_response/avg_{scorer_name}_reward", score, Reduce.MEAN),
                (f"reward/evaluate_response/std_{scorer_name}_reward", score, Reduce.STD),
                # NOTE(review): despite the name, this receives each individual
                # function's score, not the combined reward — confirm intended.
                ("reward/evaluate_response/avg_total_reward", score, Reduce.MEAN),
                (f"reward/evaluate_response/count_{scorer_name}_calls", 1, Reduce.SUM),
            )
            for metric_name, value, reduction in metric_specs:
                record_metric(metric_name, value, reduction)

        return score_total / len(self.reward_functions)
@dataclass
class ComputeAdvantages(ForgeActor):
    """Compute advantages for GRPO using reward signals."""

    @endpoint
    async def compute(self, group: Group) -> list[float]:
        """Return group-normalized advantages, one per episode in ``group``.

        Each advantage is ``(reward - mean) / (std + 1e-4)`` over the group's
        rewards. A group with fewer than two episodes has no spread. Its std
        is taken as 0, so the advantages are 0, rather than the NaN that the
        unbiased ``Tensor.std`` returns for a single sample.

        Args:
            group: Episodes whose ``reward`` has already been filled in.

        Returns:
            Advantages in the same order as ``group.episodes``.
        """
        # TODO: add batch processing
        # Shape [1 x group_size]. An explicit floating dtype also accepts
        # integer rewards, which `mean()` would otherwise reject.
        rewards = torch.tensor(
            [[e.reward for e in group.episodes]], dtype=torch.get_default_dtype()
        )
        mean = rewards.mean(1, keepdim=True)
        if rewards.shape[1] > 1:
            std = rewards.std(1, keepdim=True)
        else:
            std = torch.zeros_like(mean)
        advantages = (rewards - mean) / (std + 1e-4)
        return advantages.squeeze(0).tolist()
@dataclass
class DatasetActor(ForgeActor):
    """Actor wrapper for a HuggingFace dataset that serves samples asynchronously.

    ``setup`` must run first: it creates the tokenizer and the shuffled sample
    iterator that ``sample`` and ``pad_token`` read.
    """

    path: str = "openai/gsm8k"
    # NOTE(review): forwarded as `load_dataset`'s `name` (dataset config)
    # argument, not as a git revision.
    revision: str = "main"
    data_split: str = "train"
    streaming: bool = True
    model: str = "Qwen/Qwen3-1.7B"

    @endpoint
    def setup(self):
        """Load the tokenizer and build a shuffled iterator of chat-formatted samples."""
        self._tokenizer = get_tokenizer(self.model)

        def gsm8k_transform(sample):
            # Map a raw row to {"request": chat prompt, "target": final answer}.
            system_prompt = """
Put all your scratchpad work between <think> and </think> tags.
Your final answer should be between <answer> and </answer> tags otherwise it will not be scored.
"""
            conversation = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": sample["question"]},
            ]
            request_text = self._tokenizer.apply_chat_template(
                conversation,
                tokenize=False,
                add_generation_prompt=True,
            )
            # Keep the segment after the first "#### " marker
            # (presumably GSM8K's final-answer delimiter).
            final_answer = sample["answer"].split("#### ")[1]
            return {"request": request_text, "target": final_answer}

        dataset = load_dataset(
            self.path,
            name=self.revision,
            split=self.data_split,
            streaming=self.streaming,
        )
        self._iterator = iter(dataset.map(gsm8k_transform).shuffle())

    @endpoint
    async def sample(self) -> dict[str, str] | None:
        """Return the next transformed sample, or ``None`` once the dataset is exhausted."""
        try:
            sample = next(self._iterator)
        except StopIteration:
            return None

        # Record dataset metrics (sample length is measured in characters).
        record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
        record_metric(
            "dataset/sample/avg_sample_len",
            len(sample["request"]),
            Reduce.MEAN,
        )
        return sample

    @endpoint
    async def pad_token(self):
        """Return the tokenizer's pad token id (requires ``setup`` to have run)."""
        return self._tokenizer.pad_token_id
async def drop_weights(version: int):
    """Remove every torchstore entry stored under weight ``version``'s prefix.

    If the whole-state-dict DCP key for this version is among them, its handle
    is fetched and ``drop()``-ed first (presumably to release the checkpoint it
    refers to), then all matching keys are deleted one by one. Progress and the
    elapsed time are printed.
    """
    print(f"Dropping weights @ version {version}")
    started_at = time.perf_counter()

    version_keys = await ts.keys(get_param_prefix(version))

    # TODO: once we have something like `get_meta()` in torchstore, we can just
    # query the type of the object instead of relying on keys.
    whole_state_dict_key = get_dcp_whole_state_dict_key(version)
    if whole_state_dict_key in version_keys:
        checkpoint_handle = await ts.get(whole_state_dict_key)
        checkpoint_handle.drop()

    for stale_key in version_keys:
        await ts.delete(stale_key)

    elapsed = time.perf_counter() - started_at
    print(f"Dropped weights @ version {version}, took {elapsed:.2f} seconds")
async def main(cfg: DictConfig):
    """Main GRPO training loop with rollout and training processes.

    Provisions every actor/service described by ``cfg`` and initializes
    torchstore on the trainer hosts. It then runs ``rollout_threads``
    rollout coroutines concurrently with one training coroutine. Training
    stops after ``cfg.trainer.training.steps`` steps, or runs until
    interrupted when that value is null. Shutdown of all actors and services
    happens in ``finally`` regardless of how training exits.

    Args:
        cfg: Parsed OmegaConf config (e.g. the YAML passed via ``--config``).
    """
    group_size = cfg.group_size
    max_req_tokens = cfg.max_req_tokens
    max_res_tokens = cfg.max_res_tokens

    # ---- Global setups ---- #
    if cfg.get("provisioner", None) is not None:
        provisioner = await init_provisioner(
            ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
        )
    else:
        provisioner = await init_provisioner()

    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one(metric_logging_cfg)

    # ---- Setup services ---- #
    (
        dataloader,
        policy,
        trainer,
        replay_buffer,
        compute_advantages,
        ref_model,
        reward_actor,
    ) = await asyncio.gather(
        DatasetActor.options(**cfg.actors.dataset).as_actor(**cfg.dataset),
        Policy.options(**cfg.services.policy).as_service(**cfg.policy),
        RLTrainer.options(**cfg.actors.trainer).as_actor(
            **cfg.trainer, loss=simple_grpo_loss
        ),
        ReplayBuffer.options(**cfg.actors.replay_buffer).as_actor(
            **cfg.replay_buffer, collate=collate
        ),
        ComputeAdvantages.options(**cfg.actors.compute_advantages).as_actor(),
        ReferenceModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
        RewardActor.options(**cfg.services.reward_actor).as_service(
            reward_functions=[MathReward(), ThinkingReward()]
        ),
    )

    # Set max_steps to the configured value, or -1 if not specified or Null.
    # An explicit 0 now means "no training steps" instead of being treated
    # like null (which `steps or -1` did, making the loop unbounded).
    configured_steps = cfg.trainer.training.steps
    max_steps = -1 if configured_steps is None else configured_steps

    print("All services initialized successfully!")
    shutdown_event = asyncio.Event()
    # Here we spawn a torchstore storage volume per trainer process.
    # We initialize after service initialization because torchstore currently
    # requires access to the underlying proc meshes in the local rank strategy.
    # We should be able to hide this in the future.
    # TODO: support multiple host meshes
    trainer_num_procs = cfg.actors.trainer["procs"]
    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
    trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
    await ts.initialize(
        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
        strategy=ts.LocalRankStrategy(),
    )
    print("Torchstore successfully initialized with local rank strategy")

    # ---- Core RL loops ---- #
    async def continuous_rollouts():
        """Generate, score and enqueue one group per iteration until shutdown or data runs out."""
        rollout_count = 0
        pad_id = await dataloader.pad_token.call_one()
        while not shutdown_event.is_set():
            t = Tracer("main_perf/continuous_rollouts")
            t.start()
            sample = await dataloader.sample.call_one()
            if sample is None:
                print("Dataloader is empty, exiting continuous rollout")
                return

            t.step("data_loading")

            prompt, target = sample["request"], sample["target"]
            responses = await policy.generate.route(prompt)
            t.step("policy_generation")

            assert (
                len(responses) > 0
            ), "Sanity check: Responses should NEVER return empty"
            # The version comes from the generation metadata itself. It is
            # assigned outside the `assert` so it survives `python -O`, which
            # strips asserts (and any walrus assignment inside them).
            version = responses[0].generator_version
            assert version is not None, "Response must indicate a version"

            group = Group.new_group(
                group_id=rollout_count,
                group_size=group_size,
                request=prompt,
                policy_version=version,
                pad_id=pad_id,
                request_len=max_req_tokens,
                response_len=max_res_tokens,
                target=target,
            )

            # Padded [request | response] token rows for the reference model.
            input_ids = torch.ones(
                (group_size, max_req_tokens + max_res_tokens),
                dtype=torch.long,
                device="cuda",
            )
            # Populate episode info and calculate rewards
            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
                episode.request_tokens = response.prompt_ids
                episode.response_tokens = response.token_ids
                episode.response = response.text
                input_ids[i, :max_req_tokens] = episode.request_tensor
                input_ids[i, max_req_tokens:] = episode.response_tensor
                episode.reward = await reward_actor.evaluate_response.route(
                    prompt=prompt, response=response.text, target=target
                )

            t.step("reward_evaluation")

            ref_logprobs = await ref_model.forward.route(
                input_ids, max_req_tokens, return_logprobs=True
            )
            t.step("reference_model_calculate_logprobs")

            for i, episode in enumerate(group.episodes):
                episode.ref_logprobs = ref_logprobs[i]
            # Release the batched tensors promptly; episodes keep their own rows.
            del ref_logprobs, input_ids
            t.step("compute_logprobs")

            # Calculate advantages and add to replay buffer
            advantages = await compute_advantages.compute.call_one(group)
            for episode, advantage in zip(group.episodes, advantages):
                episode.advantage = advantage
                await replay_buffer.add.call_one(episode)

            # Log metrics
            rollout_count += 1
            record_metric(
                "main/continuous_rollouts/count_rollout_iterations", 1, Reduce.SUM
            )
            t.stop()

    async def continuous_training():
        """Train on replay-buffer batches and publish new weights until max_steps is reached."""
        training_step = 0
        restart_tracer = True  # Flag to control when to restart tracer

        while max_steps == -1 or training_step < max_steps:
            # Restart tracer when needed (initial start or after completing a training step)
            # Otherwise, we cannot measure time waiting for buffer
            if restart_tracer:
                t = Tracer("main_perf/continuous_training")
                t.start()
                restart_tracer = False

            batch = await replay_buffer.sample.call_one(
                curr_policy_version=training_step
            )
            if batch is None:
                await asyncio.sleep(0.1)
            else:
                t.step("waiting_for_buffer")

                inputs, targets = batch
                await trainer.train_step.call(inputs, targets)
                training_step += 1
                t.step("train_step")

                await trainer.push_weights.call(training_step)
                t.step("push_weights")

                await policy.update_weights.fanout(training_step)
                t.step("update_weights")

                # The policy has been updated to `training_step`, so the
                # previously published version's torchstore entries are dropped.
                if training_step >= 2:
                    await drop_weights(training_step - 1)
                    t.step("drop_weights")

                t.stop()
                restart_tracer = True

                # Flush metrics every training step to WandB
                await mlogger.flush.call_one(training_step)

        print(
            f"Reached training limit ({max_steps} steps). Exiting continuous_training loop."
        )

    num_rollout_threads = cfg.get("rollout_threads", 1)
    num_training_threads = cfg.get("training_threads", 1)
    print(
        f"Starting GRPO with {num_rollout_threads} rollout threads, {num_training_threads} training threads"
    )
    rollout_tasks = [
        asyncio.create_task(continuous_rollouts()) for _ in range(num_rollout_threads)
    ]
    training_task = asyncio.create_task(continuous_training())

    try:
        await training_task
    except KeyboardInterrupt:
        print("Training interrupted by user")
    finally:
        print("Shutting down...")
        shutdown_event.set()

        try:
            # Give rollouts up to 5s to finish naturally
            await asyncio.wait_for(
                asyncio.gather(*rollout_tasks, return_exceptions=True),
                timeout=5,
            )
        except asyncio.TimeoutError:
            print("Timeout waiting for rollouts; forcing cancellation...")
            for task in rollout_tasks:
                task.cancel()
            await asyncio.gather(*rollout_tasks, return_exceptions=True)

        training_task.cancel()

        # give mlogger time to shutdown backends, otherwise they can stay running.
        # TODO (felipemello) find more elegant solution
        await mlogger.shutdown.call_one()
        await asyncio.sleep(2)

        await asyncio.gather(
            DatasetActor.shutdown(dataloader),
            policy.shutdown(),
            RLTrainer.shutdown(trainer),
            ReplayBuffer.shutdown(replay_buffer),
            ComputeAdvantages.shutdown(compute_advantages),
            ref_model.shutdown(),
            reward_actor.shutdown(),
        )
        # TODO - add a global shutdown that implicitly shuts down all services
        # and remote allocations
        await shutdown()
if __name__ == "__main__":
    # `parse` builds the config from the command line and passes it to the
    # wrapped function, so `_launch` is called with no arguments here.
    @parse
    def _launch(cfg):
        """Run the async GRPO entry point with the CLI-provided config."""
        asyncio.run(main(cfg))

    _launch()