72 changes: 72 additions & 0 deletions tests/test_grpo_trainer.py
@@ -1850,6 +1850,78 @@ def test_training_sequence_importance_sampling(self):
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

def test_training_dynamic_temperature(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=32,
max_steps=10, # Very short training for testing
max_temp=2.0, # Dynamic temperature parameters
min_temp=0.1,
temp_warmup_steps=3,
temperature=1.0, # Fallback static temperature
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

# Test that dynamic temperature is enabled
self.assertTrue(trainer.use_dynamic_temp)
self.assertEqual(trainer.max_temp, 2.0)
self.assertEqual(trainer.min_temp, 0.1)
self.assertEqual(trainer.temp_warmup_steps, 3)

# Test temperature calculation at different steps
# Step 0: first warmup step, max_temp scaled by 1/temp_warmup_steps
temp_step_0 = trainer.get_temp(0)
self.assertAlmostEqual(temp_step_0, 2.0 * 1 / 3, places=4) # max_temp * (0+1) / temp_warmup_steps

# Warmup end: Should be max_temp
temp_warmup_end = trainer.get_temp(3)
self.assertAlmostEqual(temp_warmup_end, 2.0, places=4)

# Max steps: Should be min_temp
temp_max_steps = trainer.get_temp(10)
self.assertAlmostEqual(temp_max_steps, 0.1, places=4)

# Beyond max steps: Should remain min_temp
temp_beyond = trainer.get_temp(15)
self.assertAlmostEqual(temp_beyond, 0.1, places=4)

# Test that static temperature fallback works
static_args = GRPOConfig(
output_dir=tmp_dir,
temperature=1.5,
per_device_train_batch_size=3,
num_generations=3,
max_completion_length=32,
max_steps=10,
report_to="none",
)
static_trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=static_args,
train_dataset=dataset,
)

# Should use static temperature
self.assertFalse(static_trainer.use_dynamic_temp)
self.assertEqual(static_trainer.get_temp(0), 1.5)
self.assertEqual(static_trainer.get_temp(100), 1.5)

# Run the short training end-to-end to ensure the schedule raises no errors
trainer.train()


if __name__ == "__main__":
unittest.main()
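For reference, the values asserted above fall directly out of the warmup-then-cosine schedule that `get_temp` implements (see the trainer diff below). A minimal standalone sketch of that arithmetic, assuming the test's settings (max_temp=2.0, min_temp=0.1, temp_warmup_steps=3, max_steps=10):

import math

def sketch_get_temp(it, max_temp=2.0, min_temp=0.1, warmup=3, max_steps=10):
    if it < warmup:  # linear warmup
        return max_temp * (it + 1) / warmup
    if it >= max_steps:  # clamp past the end of training
        return min_temp
    # cosine decay from max_temp down to min_temp
    decay_ratio = (it - warmup) / (max_steps - warmup)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_temp + coeff * (max_temp - min_temp)

assert abs(sketch_get_temp(0) - 2.0 / 3) < 1e-4  # first warmup step
assert abs(sketch_get_temp(3) - 2.0) < 1e-4      # warmup end
assert abs(sketch_get_temp(10) - 0.1) < 1e-4     # end of training
assert abs(sketch_get_temp(15) - 0.1) < 1e-4     # beyond max_steps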
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
@@ -321,6 +321,18 @@ class GRPOConfig(TrainingArguments):
default=1.0,
metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
)
max_temp: Optional[float] = field(
default=None,
metadata={"help": "Maximum temperature for dynamic temperature scheduling. If None, uses static temperature."},
)
min_temp: Optional[float] = field(
default=None,
metadata={"help": "Minimum temperature for dynamic temperature scheduling. If None, uses static temperature."},
)
temp_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "Number of warmup steps for temperature scheduling. Temperature linearly increases from 0 to max_temp."},
)
top_p: float = field(
default=1.0,
metadata={
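A minimal usage sketch for the three new fields (hypothetical values; all three must be set for the schedule to activate, otherwise the trainer falls back to the static `temperature`):

from trl import GRPOConfig

training_args = GRPOConfig(
    output_dir="grpo-dynamic-temp",  # hypothetical output path
    max_temp=2.0,           # peak temperature, reached at the end of warmup
    min_temp=0.1,           # floor, reached at max_steps
    temp_warmup_steps=100,  # linear warmup before the cosine decay begins
    temperature=1.0,        # static fallback when the schedule is disabled
    max_steps=1000,
)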
55 changes: 49 additions & 6 deletions trl/trainer/grpo_trainer.py
@@ -14,6 +14,7 @@

import copy
import inspect
import math
import os
import re
import textwrap
@@ -658,6 +659,15 @@ def __init__(
self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
self.num_generations = args.num_generations # = G in the GRPO paper
self.temperature = args.temperature
# Dynamic temperature scheduling
self.max_temp = args.max_temp
self.min_temp = args.min_temp
self.temp_warmup_steps = args.temp_warmup_steps
self.use_dynamic_temp = (
self.max_temp is not None and
self.min_temp is not None and
self.temp_warmup_steps is not None
)
self.top_p = args.top_p
self.top_k = args.top_k
self.min_p = args.min_p
@@ -908,6 +918,33 @@ def __init__(
reward_func, evaluation_mode=True, device_placement=True
)

def get_temp(self, it):
"""
Compute dynamic temperature using cosine decay schedule.

Args:
it: current training step

Returns:
temperature value for current step
"""
if not self.use_dynamic_temp:
return self.temperature

max_steps = self.args.max_steps

# 1) linear warmup for temp_warmup_steps
if it < self.temp_warmup_steps:
return self.max_temp * (it + 1) / self.temp_warmup_steps
# 2) at or beyond max_steps, return min temperature (also avoids a zero
# division below when temp_warmup_steps == max_steps)
if it >= max_steps:
return self.min_temp
# 3) in between, use cosine decay down to min temperature
decay_ratio = (it - self.temp_warmup_steps) / (max_steps - self.temp_warmup_steps)
assert 0 <= decay_ratio <= 1
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff starts at 1 and goes to 0
return self.min_temp + coeff * (self.max_temp - self.min_temp)

def _set_signature_columns_if_needed(self):
# If `self.args.remove_unused_columns` is True, non-signature columns are removed.
# By default, this method sets `self._signature_columns` to the model's expected inputs.
@@ -1117,8 +1154,8 @@ def _get_per_token_logps_and_entropies(
logits = logits[:, -logits_to_keep:, :] # (B, logits_to_keep, H)
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature

current_temp = self.get_temp(self.state.global_step)
logits = logits / current_temp
completion_ids = input_ids_batch[:, -logits_to_keep:]
logps = selective_log_softmax(logits, completion_ids) # compute logprobs
all_logps.append(logps)
@@ -1444,7 +1481,7 @@ def _generate_and_score_completions(
images=ordered_set_of_images,
n=self.num_generations,
repetition_penalty=self.repetition_penalty,
temperature=self.temperature,
temperature=self.get_temp(self.state.global_step),
top_p=self.top_p,
top_k=-1 if self.top_k is None else self.top_k,
min_p=0.0 if self.min_p is None else self.min_p,
@@ -1473,7 +1510,7 @@ def _generate_and_score_completions(
generation_kwargs = {
"n": 1, # vLLM on each GPU generates only 1 in colocate mode
"repetition_penalty": self.repetition_penalty,
"temperature": self.temperature,
"temperature": self.get_temp(self.state.global_step),
"top_p": self.top_p,
"top_k": -1 if self.top_k is None else self.top_k,
"min_p": 0.0 if self.min_p is None else self.min_p,
@@ -1553,8 +1590,11 @@ def _generate_and_score_completions(
elif self.args.fp16:
unwrapped_model.to(torch.float16)
with torch.inference_mode():
# Update generation config with dynamic temperature
dynamic_generation_config = copy.deepcopy(self.generation_config)
dynamic_generation_config.temperature = self.get_temp(self.state.global_step)
all_outputs = unwrapped_model.generate_batch(
paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False
paged_prompt_inputs.input_ids, generation_config=dynamic_generation_config, progress_bar=False
)
completion_ids = [output.generated_tokens for output in all_outputs.values()]
completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
@@ -1575,8 +1615,11 @@ def _generate_and_score_completions(
FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(),
):
prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask
# Update generation config with dynamic temperature
dynamic_generation_config = copy.deepcopy(self.generation_config)
dynamic_generation_config.temperature = self.get_temp(self.state.global_step)
prompt_completion_ids = unwrapped_model.generate(
**prompt_inputs, generation_config=self.generation_config, disable_compile=True
**prompt_inputs, generation_config=dynamic_generation_config, disable_compile=True
)
# Compute prompt length and extract completion ids
prompt_length = prompt_ids.size(1)
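On the training side, the scheduled temperature replaces the static `self.temperature` wherever logits are scaled before the log-softmax. A minimal sketch of that scaling in plain PyTorch, standing in for TRL's `selective_log_softmax` helper:

import torch
import torch.nn.functional as F

def per_token_logps(logits, completion_ids, temp):
    # Scale logits by the scheduled temperature, then gather the
    # log-probability of each sampled completion token.
    logps = F.log_softmax(logits / temp, dim=-1)
    return logps.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)

logits = torch.randn(2, 5, 32)      # (batch, completion length, vocab size)
ids = torch.randint(0, 32, (2, 5))  # sampled completion token ids
print(per_token_logps(logits, ids, temp=0.7).shape)  # torch.Size([2, 5])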