🕵️♂️ Agent training #4300
Conversation
…_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw`
For information:
from transformers import AutoTokenizer
from trl.chat_template_utils import is_chat_template_prefix_preserving
tokenizers = [
    "trl-internal-testing/tiny-CohereForCausalLM",
    "trl-internal-testing/tiny-DbrxForCausalLM",
    "trl-internal-testing/tiny-DeepseekV3ForCausalLM",
    "trl-internal-testing/tiny-DeepseekV3ForCausalLM-0528",
    "trl-internal-testing/tiny-FalconMambaForCausalLM",
    "trl-internal-testing/tiny-Gemma2ForCausalLM",
    "trl-internal-testing/tiny-GemmaForCausalLM",
    "trl-internal-testing/tiny-GptOssForCausalLM",
    "trl-internal-testing/tiny-LlamaForCausalLM-3.1",
    "trl-internal-testing/tiny-LlamaForCausalLM-3.2",
    "trl-internal-testing/tiny-LlamaForCausalLM-3",
    "trl-internal-testing/tiny-MistralForCausalLM-0.1",
    "trl-internal-testing/tiny-MistralForCausalLM-0.2",
    "trl-internal-testing/tiny-Phi3ForCausalLM",
    "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
    "trl-internal-testing/tiny-Qwen3ForCausalLM",
]
print("| Tokenizer | Tool Calls | Prefix-Preserving |")
print("| --- | --- | --- |")
for tokenizer_name in tokenizers:
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    tool_support = "✅" if "tool_call" in tokenizer.chat_template else "❌"
    prefix_preserving = "✅" if is_chat_template_prefix_preserving(tokenizer) else "❌"
    print(f"| {tokenizer_name} | {tool_support} | {prefix_preserving} |")
trl/trainer/grpo_trainer.py
# Extract tool calls from the completions
if self.tools:
    tool_calls = [completion[0].get("tool_calls") for completion in completions]
If it is not too complicated then L1510-1625 should go in their own method called _run_tool_calls or similar
This block can then be:
if self.tools:
tool_mask, completions, ... = self._run_tool_calls(...)
else:
tool_mask = None
It would make it simpler to parse the codebase when you are not using tools. It may make the tool call block more testable? There is a lot of stuff going on here; overall I understand the logic, but there may be edge cases that slip through. Is there a way we can break up the while loop into smaller methods and unit test them?
I agree with Ed that it would be great if we can refactor this somewhat to isolate the tool-calling logic and enable it to be tested independently
Can this be made independent of the GRPO trainer, so that other online trainers could potentially benefit from this helper?
I refactored to isolate the _tool_call_loop. 94c2ff2. Now we have, as @edbeeching suggested:
if self.tools:
    tool_mask, ... = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs)
else:
    tool_mask = None

Much easier to read indeed.
It may make the tool call block more testable? There is a lot of stuff going on here
enable it to be tested independently
I get the motivation, and I agree that "there is a lot of stuff going on here." My take is that it would be better to split this into smaller helper functions or utilities. That makes the code easier to follow, easier to test, and more flexible overall. I'll try to do it.
However, I usually avoid testing private methods or internal logic directly, because those tests can break easily for the wrong reasons: any small change inside can force a test rewrite. I prefer testing through the public API, even if it means patching a bit to make sure the right part runs. Like here
trl/tests/test_grpo_trainer.py
Lines 1721 to 1797 in 94c2ff2
def test_training_with_tools(self):
    # In this test, we define a simple tool that multiplies two integers. Regardless of the input prompt,
    # the model will generate 3 completions, 2 of which will be valid tool calls. Among the 2 tool calls, one will
    # succeed and the other will fail (because of a wrong argument name).
    def multiply(a: int, b: int) -> int:
        """
        Multiplies two integers.

        Args:
            a: The first integer.
            b: The second integer.

        Returns:
            The product of the two integers.
        """
        return a * b

    dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
    training_args = GRPOConfig(
        output_dir=self.tmp_dir,
        learning_rate=0.1,
        per_device_train_batch_size=3,
        num_generations=3,
        max_completion_length=128,
        report_to="none",
    )
    trainer = GRPOTrainer(
        model="trl-internal-testing/tiny-Qwen3MoeForCausalLM",
        reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
        args=training_args,
        train_dataset=dataset,
        tools=[multiply],
    )
    previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

    def fake_generate(input_ids, **kwargs):
        if input_ids.shape[0] == 3:  # first call
            # fmt: off
            completion_ids = torch.tensor(
                [
                    # '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n</tool_call><|im_end|>'
                    [151657, 198, 4913, 606, 788, 330, 64648, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645],
                    # '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "c": 4}}\n</tool_call><|im_end|>'
                    [151657, 198, 4913, 606, 788, 330, 64648, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645],
                    # "I don't know any tool<|im_end|>"
                    [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643],
                ],
                device=input_ids.device,
            )
            # fmt: on
        else:  # second call will only have two inputs in the batch, because two examples have a tool call.
            completion_ids = torch.tensor(
                [
                    # 'Done!<|im_end|>'
                    [17453, 0, 151645],
                    # 'Done!<|im_end|>'
                    [17453, 0, 151645],
                ],
                device=input_ids.device,
            )
        return torch.cat([input_ids, completion_ids], dim=-1)

    with patch.object(trainer.model, "generate", side_effect=fake_generate):
        trainer.train()

    assert trainer.state.log_history[-1]["train_loss"] is not None
    assert trainer.state.log_history[-1]["tools/call_frequency"] is not None
    assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3)
    assert trainer.state.log_history[-1]["tools/failure_frequency"] is not None
    assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2)

    # Check that the params have changed
    for n, param in previous_trainable_params.items():
        new_param = trainer.model.get_parameter(n)
        assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
It is more difficult than it seems to find the right balance. 😅 I'd be curious to know what @albertvillanova thinks about it (when he is back).
can this be made independent of GRPO trainer as other online trainers could then potentially benefit from this helper?
It’s quite complex because a lot of the internal logic of this new method depends on other methods and attributes (self.chat_template_kwargs, self._tool_dict, self.processing_class, self._generate_single_turn, ...). In this case, I would lean more towards copying. From experience, even though this increases the amount of duplicated code in the codebase, it doesn’t really make maintenance harder; in fact, it often makes maintenance much easier. For example, RLOO and GRPO have a lot of duplicated code, but they are very easy to maintain together with very little shared abstraction.
agg_prompt_lengths = self.accelerator.gather(prompt_lengths)
agg_completion_lengths = self.accelerator.gather(completion_lengths)
total_prompt_tokens = agg_prompt_lengths.sum()
total_completion_tokens = agg_completion_lengths.sum()  # = num_items_in_batch, required for the DAPO loss
agg_completion_lengths.sum() # = num_items_in_batch
Does / should this include the tool call tokens?
It does include the tool call tokens, but shouldn't. I'll need to fix it. Excellent catch, thanks!
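For illustration, a self-contained sketch of one way to exclude tool-result tokens from the count, assuming a `tool_mask` with the same shape as `completion_mask` (names follow this PR, but the actual fix may differ):

```python
import torch

completion_mask = torch.tensor([[1, 1, 1, 1, 1, 0], [1, 1, 1, 0, 0, 0]])
tool_mask = torch.tensor([[1, 1, 0, 0, 1, 1], [1, 1, 1, 1, 1, 1]])  # 0 = token comes from a tool result

# Count only model-generated tokens toward num_items_in_batch.
completion_lengths = (completion_mask * tool_mask).sum(dim=1)
total_completion_tokens = completion_lengths.sum()  # excludes the 2 tool-result tokens: 3 + 3 = 6
print(total_completion_tokens)
```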
Thanks for this implementation. In general it is great, the only issue I have is that there is a lot of new code in L1510-1625 of the grpo_trainer.py file. It would be great to extract this into at least one method, or several methods for some of the blocks of code:
For example make a _run_tool_calls method which contains:
...
while idxs_with_tool:
...
# Call the tools, and build the new prompt for generation
... = call_the_tools(..)
# Tokenize and filter samples whose length exceeds max allowed length.
... = tokenize_and_filter(...)
# Generate new completions after tool execution
prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn(prompt_completion_tools)
# Sanity check: from experience
... etc
Then add tests for those methods; there seem to be a lot of edge cases you have already covered, and capturing those in tests would make this more robust.
The level of granularity you want with these methods is up to you of course. But at least one method for _run_tool_calls(...) would make the _generate method easier to parse, particularly for new users who are trying to understand the codebase and are not training LLMs to use tools (yet!).
lewtun left a comment:
Epic tour de force @qgallouedec 🔥 !
Overall LGTM, with some questions around the logic for truncating long multi-step rollouts in terms of max_model_len instead of max_completion_length.
# Ensure distributed rendezvous variables are set without colliding across concurrent runs
ensure_master_addr_port()

if self.max_prompt_length is not None and self.max_completion_length is not None:
Does this mean we now use the model's default max_model_len?
If so, would it make sense to expose max_model_len as an arg in the config to accommodate cases where the sequences one is training on are much smaller than the default value? This allows one to get better throughput / lower memory usage.
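For illustration, a sketch of what exposing it could look like; the field name `max_model_len` here is hypothetical, not an existing GRPOConfig argument:

```python
from dataclasses import dataclass, field
from typing import Optional

from trl import GRPOConfig


@dataclass
class GRPOConfigWithMaxModelLen(GRPOConfig):
    # Hypothetical extra argument: cap the generation backend's sequence length to save memory
    # when training sequences are much shorter than the model's default context.
    max_model_len: Optional[int] = field(
        default=None,
        metadata={"help": "If set, overrides the model's default max sequence length in the generation backend."},
    )
```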
# Tokenize and filter samples whose length exceeds max allowed length. This is important, because if vLLM
# is called with an input longer than its max model length, it will error out.
pct_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"]
overlong = [len(pct) - len(p) >= self.max_completion_length for p, pct in zip(p_ids, pct_ids, strict=True)]
Should this refer instead to max_model_len because technically one could have a small max_completion_length but a long context that is built up over many steps of turn calling.
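A toy illustration of the alternative check being suggested, assuming a `max_model_len` value is available (variable names follow the quoted snippet):

```python
# Filter on the *total* sequence length against the model's context window,
# rather than on the completion budget alone.
max_model_len = 32768  # illustrative value, e.g. taken from the vLLM engine config
pct_ids = [[101] * 31_000, [101] * 40_000]  # tokenized prompt+completion+tool sequences

overlong = [len(pct) >= max_model_len for pct in pct_ids]
print(overlong)  # [False, True]
```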
for idx in range(len(idxs_with_tool)):
    idx_with_tool = idxs_with_tool[idx]
    pct = prompt_completion_tool_ids[idx]  # = prompt-completion-tool
    assert prompt_ids[idx_with_tool] == pct[: len(prompt_ids[idx_with_tool])]
Should we raise a proper ValueError here with an informative error message for the user?
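Something along these lines, as a fragment meant to slot into the quoted loop (the message wording is only illustrative):

```python
# Sketch: replace the bare assert with an explicit, user-facing error.
if prompt_ids[idx_with_tool] != pct[: len(prompt_ids[idx_with_tool])]:
    raise ValueError(
        "Re-tokenizing the conversation after a tool call changed the prompt tokens, which suggests "
        "the chat template is not prefix-preserving. Tool calling requires a prefix-preserving chat template."
    )
```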
pct = prompt_completion_tool_ids[idx]  # = prompt-completion-tool
assert prompt_ids[idx_with_tool] == pct[: len(prompt_ids[idx_with_tool])]

# Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length
Same comment here, that I don't think the issue is about exceeding the max_completion_length, rather that we should left-truncate to ensure we have enough tokens to avoid exceeding the model's context length (set via max_model_len)
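As a rough fragment of what that left-truncation could look like (all names here are illustrative, not the trainer's actual variables):

```python
# Keep only the most recent tokens so that the running context plus the remaining
# completion budget still fits within the model's context window.
budget = max_model_len - remaining_completion_budget
if len(pct) > budget:
    pct = pct[-budget:]
```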
trl/trainer/grpo_trainer.py
# Extract tool calls from the completions
if self.tools:
    tool_calls = [completion[0].get("tool_calls") for completion in completions]
I agree with Ed that it would be great if we can refactor this somewhat to isolate the tool-calling logic and enable it to be tested independently
return tokenizer
raise ValueError(
    "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the "
    "tokenizer or processor."
To help users, shall we add a pointer to the docs?
| "tokenizer or processor." | |
| "tokenizer or processor. See the Transformers [docs](https://huggingface.co/docs/transformers/main/en/chat_response_parsing#response-parsing) for more details on response parsing." |
Added agent training example script (trackio)!
… calculation to exclude tool tokens
@codex review
Just a minor change to fix #4609.
After the merge of huggingface/transformers#40936, the attribute does not necessarily exist. Before it was None.
Feel free to ignore it if there is a better solution!
# At the time of initial implementation, most tokenizers do not have built-in support for response schemas.
# While waiting for broader adoption, we provide this utility function to manually set the response schema for
# known chat templates.
if tools and not processing_class.response_schema:
To fix #4609:
-    if tools and not processing_class.response_schema:
+    if tools and not getattr(processing_class, "response_schema", None):
guess = completion[-1]["content"].strip()
guess = completion[-1]["content"].strip().lower()
guess_clean = guess.replace("*", "").replace("`", "").strip()
reward = 0.0
Could L75-82 be simplified to:
if guess_clean == ans.lower():
    reward = 0.5
else:
    reward = -0.2
examples/scripts/grpo_agent.py
| if "error" in turn["content"].lower(): | ||
| reward -= 0.3 # penalize errors | ||
|
|
||
| if tool_called and tool_response_ok: |
107-112 would be easier to parse for a reader like this:
if tool_called:
    if tool_response_ok:
        reward += 0.25
    else:
        reward -= 0.2
else:
    reward -= 0.3

What does this PR do?
This PR implements tool calling for GRPO. The API is as follows:
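The code snippet from the original description is not reproduced above, so here is a minimal sketch reconstructed from the test added in this PR (`test_training_with_tools`); the `tools=[...]` argument is the new API surface, while the model, dataset, and reward function are illustrative:

```python
from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer


def multiply(a: int, b: int) -> int:
    """
    Multiplies two integers.

    Args:
        a: The first integer.
        b: The second integer.

    Returns:
        The product of the two integers.
    """
    return a * b


def reward_tool_usage(completions, **kwargs):
    # Toy reward for illustration: favor rollouts that contain more than one assistant turn,
    # i.e. completions in which a tool was actually called.
    return [float(len(completion) > 1) for completion in completions]


dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen3-0.6B",
    reward_funcs=reward_tool_usage,
    args=GRPOConfig(output_dir="grpo-agent", max_completion_length=128),
    train_dataset=dataset,
    tools=[multiply],  # new in this PR: pass Python callables as tools
)
trainer.train()
```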
This PR contains a few important changes:
🚨 Removal of `max_prompt_length`
This PR contains a breaking change: `max_prompt_length` has been removed from GRPO. Here are the reasons (tl;dr: because it’s extremely hard to implement reliably with multi-turn tool calling, likely harmful to training anyway, likely not used in practice, and dropping it simplifies the API while keeping it consistent across LLMs and VLMs):
Supporting `max_prompt_length` with tool calling is extremely complex.
For single-turn generation it works fine, but multi-turn generation introduces a major challenge: the prompt grows after every step. Since the model is called repeatedly with an increasingly long prompt, we would need to recalculate the allowed prompt length dynamically based on how many tokens have already been generated. Implementing this reliably is tricky and adds significant complexity.
Truncating prompts is likely worse than dropping samples altogether.
Although I’m not aware of formal studies, intuition suggests that truncation can remove information necessary to solve the task. Training on such incomplete examples can lead to strong biases, whereas simply skipping overly long samples avoids this risk.
It simplifies the API and removes confusing edge cases.
Previously, when training VLMs, we had to tell users to disable prompt truncation entirely because Transformers does not support truncating multimodal prompts. This led to inconsistent, non-user-friendly recommendations. Removing `max_prompt_length` allows us to provide one clean, unified API that works for all model types.
It is very likely not a widely used feature anyway.
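For users who relied on `max_prompt_length`, one possible workaround (not part of this PR, just a hedged sketch) is to filter overly long prompts out of the dataset before training:

```python
from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
MAX_PROMPT_TOKENS = 512  # illustrative budget


def is_short_enough(example):
    # Conversational prompts: render via the chat template and count tokens.
    input_ids = tokenizer.apply_chat_template(example["prompt"], add_generation_prompt=True)
    return len(input_ids) <= MAX_PROMPT_TOKENS


dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train")
dataset = dataset.filter(is_short_enough)
```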
Online decoding
Before calling the reward function, we need to decode the completion. Previously, this was done here:
trl/trl/trainer/grpo_trainer.py
Lines 1605 to 1617 in 1a9ff52
The issue is that, while this works for single-turn outputs, it does not allow reliable parsing of multi-turn text. See this internal discussion. The workaround is to parse after each turn, which requires moving the decoding logic inside the generation loop (in `_generate`):
trl/trl/trainer/grpo_trainer.py
Lines 1483 to 1495 in c54bf4f
trl/trl/trainer/grpo_trainer.py
Line 1543 in c54bf4f
trl/trl/trainer/grpo_trainer.py
Lines 1614 to 1618 in c54bf4f
The method then returns the list of messages:
trl/trl/trainer/grpo_trainer.py
Line 1669 in c54bf4f
Note that this change removes support for the "bootstrap" feature. I haven’t had time to investigate adding support for it.
Tool mask
We don't want the loss to be computed on the tokens corresponding to the tool result. Consequently, `_generate` builds and returns a `tool_mask`:
trl/trl/trainer/grpo_trainer.py
Line 1668 in c54bf4f
which is then used to mask these tokens in the loss computation.
trl/trl/trainer/grpo_trainer.py
Line 2100 in c54bf4f
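Conceptually, the masking amounts to something like the following self-contained sketch (simplified shapes and objective; not the exact lines from the trainer):

```python
import torch

# Toy batch: 2 completions of 5 tokens each.
per_token_logps = torch.randn(2, 5)
advantages = torch.tensor([0.7, -0.3])
completion_mask = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]])
tool_mask = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 1, 1]])  # 0 = token produced by a tool result

# Fold the tool mask into the completion mask so tool-result tokens contribute neither
# to the loss nor to the normalization count.
completion_mask = completion_mask * tool_mask

per_token_loss = -advantages.unsqueeze(1) * per_token_logps  # simplified GRPO objective
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
print(loss)
```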
Schema and fixed chat template
Chat template
To make this feature work, we need the chat template to be prefix-preserving, i.e.:
trl/trl/chat_template_utils.py
Lines 195 to 212 in 9f0aa3d
The issue is that some widely used tokenizers, such as GPT-OSS and Qwen3, are not prefix-preserving due to the way they handle think tokens. To address this, I suggest using a slightly modified version of the template that ensures it is prefix-preserving. Additionally, as @lewtun pointed out, it’s not even clear whether these templates might make the inference OOD
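As a rough illustration of the property being checked, here is a sketch of the idea behind `is_chat_template_prefix_preserving` (not its actual implementation):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3ForCausalLM")

messages = [{"role": "user", "content": "What is 3 * 4?"}]
extended = messages + [
    {"role": "assistant", "content": "12"},
    {"role": "user", "content": "And 5 * 6?"},
]

rendered_short = tokenizer.apply_chat_template(messages, tokenize=False)
rendered_long = tokenizer.apply_chat_template(extended, tokenize=False)

# Prefix-preserving means rendering more turns never rewrites what was already rendered,
# so previously generated token ids stay a valid prefix of the new prompt. Templates that
# strip or rewrite think tokens (e.g. Qwen3, GPT-OSS) can violate this.
print(rendered_long.startswith(rendered_short))
```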
Response schema
To parse tool calls from the model’s response, we rely on `tokenizer.parse_response`, introduced in huggingface/transformers#40894. This requires the tokenizer to have a `response_schema` (integrated in a similar way as chat templates). However, very few (no?) model repositories currently include such a schema. To enable this feature despite the lack of adoption, I propose adding a mapping from some popular chat templates to their response schemas (currently only Qwen3).
trl/trl/chat_template_utils.py
Lines 172 to 174 in fbb625f
Ideally, once adoption increases and model repos start including proper response schemas, we can remove this custom mapping entirely.
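For orientation, a hedged sketch of how response parsing would be used once a `response_schema` is available on the tokenizer (the exact output format follows transformers' response-parsing API, so details may differ):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")

completion_text = '<tool_call>\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n</tool_call>'

# Requires tokenizer.response_schema to be set, e.g. via the chat-template-to-schema
# mapping proposed in this PR; most Hub repos do not ship one yet.
parsed = tokenizer.parse_response(completion_text)
print(parsed)  # expected to expose the parsed tool call(s) if the schema matched
```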
A fair amount of complexity in the generation
This PR adds 60+ lines of intricate code with many special cases in the generation method. While it’s admittedly hard to follow, after a lot of iteration this is likely the simplest reliable way to implement the feature. Normally, I would be very reluctant to introduce this level of complexity, but given the impact of this feature, I believe it’s truly worth it.
trl/trl/trainer/grpo_trainer.py
Lines 1509 to 1623 in c54bf4f
Next steps