11 changes: 9 additions & 2 deletions src/prime_rl/trainer/config.py
@@ -190,15 +190,22 @@ class ModelConfig(BaseConfig):
    ),
] = False

param_dtype: Annotated[
    Literal["bfloat16", "float16", "float32"],
    Field(
        description="The dtype to use for the model parameters.",
    ),
] = "bfloat16"

optimization_dtype: Annotated[
    Literal["bfloat16", "float32"],
    Literal["bfloat16", "float16", "float32"],
    Field(
        description="The dtype to use for the model optimization.",
    ),
] = "float32"

reduce_dtype: Annotated[
    Literal["bfloat16", "float32"],
    Literal["bfloat16", "float16", "float32"],
    Field(
        description="The dtype to use for the model reduce.",
    ),
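For reviewers who want to try the new field, here is a minimal sketch of the intended combination: fp16 parameters with fp32 optimizer math and fp32 gradient reductions. Constructing `ModelConfig` directly like this is illustrative only; it assumes the fields not shown in this diff have defaults.

```python
from prime_rl.trainer.config import ModelConfig

# Hypothetical construction; only the three dtype fields come from this diff.
config = ModelConfig(
    param_dtype="float16",         # new: parameter/compute dtype
    optimization_dtype="float32",  # keep optimizer math in fp32 for stability
    reduce_dtype="float32",        # all-reduce gradients in fp32
)
```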
3 changes: 2 additions & 1 deletion src/prime_rl/trainer/model.py
@@ -42,6 +42,7 @@

DTYPE_MAP = {
    "bfloat16": torch.bfloat16,
    "float16": torch.float16,
    "float32": torch.float32,
}

@@ -127,7 +128,7 @@ def setup_tokenizer(config: ModelConfig) -> PreTrainedTokenizer:


def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims):
    mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=DTYPE_MAP[config.reduce_dtype])
    mp_policy = MixedPrecisionPolicy(param_dtype=DTYPE_MAP[config.param_dtype], reduce_dtype=DTYPE_MAP[config.reduce_dtype])
    # TODO: Support dp_replicate
    if config.dp_replicate > 1:
        hsdp_mesh = parallel_dims.world_mesh["dp_replicate", "dp_shard_cp"]
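The policy change above is easy to sanity-check in isolation. A minimal sketch, assuming torch >= 2.6 (where `fully_shard` and `MixedPrecisionPolicy` are exported from `torch.distributed.fsdp`) and a process group already initialized, e.g. under `torchrun`:

```python
import torch
import torch.nn as nn
from torch.distributed.fsdp import MixedPrecisionPolicy, fully_shard

model = nn.Linear(1024, 1024)
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.float16,   # compute in fp16, as the new config allows
    reduce_dtype=torch.float32,  # but keep gradient all-reduce in fp32
)
fully_shard(model, mp_policy=mp_policy)  # shards params and applies the policy
```

Keeping `reduce_dtype` at fp32 confines the lossy dtype to forward/backward compute while cross-rank gradient accumulation stays in full precision.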
137 changes: 78 additions & 59 deletions src/prime_rl/trainer/rl/train.py
@@ -42,6 +42,7 @@
from prime_rl.utils.monitor import setup_monitor
from prime_rl.utils.pydantic_config import parse_argv
from prime_rl.utils.utils import clean_exit, to_col_format
from prime_rl.trainer.model import DTYPE_MAP


@clean_exit
@@ -79,6 +80,12 @@ def train(config: RLTrainerConfig):
model = setup_model(config.model, parallel_dims)
tokenizer = setup_tokenizer(config.model)

if config.model.param_dtype == "float16":
    # fp16 has a narrow dynamic range, so train behind a sharded loss scaler;
    # growth_interval=400 raises the scale after 400 consecutive stable steps.
    from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

    scaler = ShardedGradScaler(growth_interval=400)
else:
    scaler = None

# Set up the optimizer
logger.info(f"Initializing optimizer ({config.optim})")
logger.info(f"Using `{config.loss.ratio_type}` importance ratio ({config.loss})")
@@ -192,69 +199,81 @@ def train(config: RLTrainerConfig):
# we only all reduce at the last grad acc step
model.set_requires_all_reduce(micro_step == len(micro_batches) - 1)

input_ids = micro_batch["input_ids"].to("cuda")
position_ids = micro_batch["position_ids"].to("cuda")
advantages = micro_batch["advantages"].to("cuda")
loss_mask = micro_batch["loss_mask"].to("cuda")
inference_logprobs = micro_batch["inference_logprobs"].to("cuda")
temperature = micro_batch["temperature"]

# Forward pass
with maybe_record_function("forward"):
    logits = forward(model, input_ids, position_ids).float().contiguous()
    shifted_logits = shift_logits(logits)
    shifted_logits = shifted_logits / temperature
    trainer_logprobs = selective_log_softmax(shifted_logits, input_ids)

# Compute loss
response_lengths = get_response_lengths(position_ids)
loss, loss_tensors = compute_loss(
    trainer_logprobs=trainer_logprobs.squeeze().split(response_lengths),
    inference_logprobs=inference_logprobs.squeeze().split(response_lengths),
    advantages=advantages.squeeze().split(response_lengths),
    loss_mask=loss_mask.squeeze().split(response_lengths),
    loss_config=config.loss,
    loss_scale=loss_scale,
)

# Compute entropy
entropy = compute_entropy(shifted_logits)

# Delete logits and shifted_logits before backward pass to avoid memory spike
del logits, shifted_logits

# Backward pass
with maybe_record_function("backward"):
    loss.backward()

# Add relevant tensors to tensor dict for logging purposes
tensors["trainer_probs"].append(torch.exp(trainer_logprobs)[loss_mask].detach().to("cpu"))
tensors["inference_probs"].append(torch.exp(inference_logprobs)[loss_mask].detach().to("cpu"))
tensors["entropy"].append(entropy[loss_mask].detach().to("cpu"))
tensors["loss"].append(loss.detach().to("cpu").unsqueeze(0))

if is_tt_moe_model(model):
    load_balance_stats = get_load_balance_stats(model)
    for k, v in load_balance_stats.items():
        if v is not None:
            tensors[k].append(v)

# Add loss tensors to tensor dict for logging purposes
for key, loss_tensor in loss_tensors.items():
    loss_tensor = loss_tensor.detach().to("cpu")
    tensors[key].append(loss_tensor)

# Debug log with *local, micro step* stats
micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss: {tensors['loss'][-1].mean().item():.4f} | Entropy: {tensors['entropy'][-1].mean().item():.4f} | Mismatch KL: {tensors['mismatch_kl'][-1].mean().item():.4f}"
if "max_vio" in tensors:
    micro_step_message += f" | Max Vio: {tensors['max_vio'][-1].mean().item():.4f}"
logger.debug(micro_step_message)

with torch.autocast(device_type="cuda", dtype=DTYPE_MAP[config.model.param_dtype]):
    input_ids = micro_batch["input_ids"].to("cuda")
    position_ids = micro_batch["position_ids"].to("cuda")
    advantages = micro_batch["advantages"].to("cuda")
    loss_mask = micro_batch["loss_mask"].to("cuda")
    inference_logprobs = micro_batch["inference_logprobs"].to("cuda")
    temperature = micro_batch["temperature"]

    # Forward pass
    with maybe_record_function("forward"):
        logits = forward(model, input_ids, position_ids).float().contiguous()
        shifted_logits = shift_logits(logits)
        shifted_logits = shifted_logits / temperature
        trainer_logprobs = selective_log_softmax(shifted_logits, input_ids)

    # Compute loss
    response_lengths = get_response_lengths(position_ids)
    loss, loss_tensors = compute_loss(
        trainer_logprobs=trainer_logprobs.squeeze().split(response_lengths),
        inference_logprobs=inference_logprobs.squeeze().split(response_lengths),
        advantages=advantages.squeeze().split(response_lengths),
        loss_mask=loss_mask.squeeze().split(response_lengths),
        loss_config=config.loss,
        loss_scale=loss_scale,
    )

    # Compute entropy
    entropy = compute_entropy(shifted_logits)

    # Delete logits and shifted_logits before backward pass to avoid memory spike
    del logits, shifted_logits

    # Backward pass
    with maybe_record_function("backward"):
        if scaler is not None:
            scaler.scale(loss).backward()
        else:
            loss.backward()

    # Add relevant tensors to tensor dict for logging purposes
    tensors["trainer_probs"].append(torch.exp(trainer_logprobs)[loss_mask].detach().to("cpu"))
    tensors["inference_probs"].append(torch.exp(inference_logprobs)[loss_mask].detach().to("cpu"))
    tensors["entropy"].append(entropy[loss_mask].detach().to("cpu"))
    tensors["loss"].append(loss.detach().to("cpu").unsqueeze(0))

    if is_tt_moe_model(model):
        load_balance_stats = get_load_balance_stats(model)
        for k, v in load_balance_stats.items():
            if v is not None:
                tensors[k].append(v)

    # Add loss tensors to tensor dict for logging purposes
    for key, loss_tensor in loss_tensors.items():
        loss_tensor = loss_tensor.detach().to("cpu")
        tensors[key].append(loss_tensor)

    # Debug log with *local, micro step* stats
    micro_step_message = f"Micro Step {micro_step}/{len(micro_batches)} | Loss: {tensors['loss'][-1].mean().item():.4f} | Entropy: {tensors['entropy'][-1].mean().item():.4f} | Mismatch KL: {tensors['mismatch_kl'][-1].mean().item():.4f}"
    if "max_vio" in tensors:
        micro_step_message += f" | Max Vio: {tensors['max_vio'][-1].mean().item():.4f}"
    logger.debug(micro_step_message)


if scaler is not None:
    # Unscale first so clipping operates on true gradient magnitudes
    scaler.unscale_(optimizer)
# Optionally, clip the gradients
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=config.optim.max_norm).full_tensor()

# Update the model parameters
optimizer.step()
if scaler is not None:
    scaler.step(optimizer)
Review comment (Member): the scaler needs to be DP-aware; otherwise we get a weights mismatch between workers when the grad scaler skips a step.

    scaler.update()
else:
    optimizer.step()
optimizer.zero_grad()

# Update learning rate scheduler
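For context, the full scale, backward, unscale, clip, step, update sequence this diff implements can be exercised on a single GPU with the plain `torch.amp.GradScaler`. The trainer uses `ShardedGradScaler` instead, which, as I understand it, additionally synchronizes its inf/NaN check across ranks so every data-parallel worker skips the same steps (the reviewer's point above). The toy model and loss here are placeholders:

```python
import torch

model = torch.nn.Linear(64, 64).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.amp.GradScaler("cuda", growth_interval=400)

for _ in range(10):
    x = torch.randn(8, 64, device="cuda")
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(x).square().mean()
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)     # unscale before clipping
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)         # skips the update if grads contain inf/NaN
    scaler.update()                # grows or shrinks the scale accordingly
    optimizer.zero_grad()
```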