Skip to content

Commit c54bf4f

Browse files
committed
another fix
1 parent 1c026ce commit c54bf4f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

trl/trainer/grpo_trainer.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,13 +1607,14 @@ def _generate(self, prompts: list):
16071607
completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx]
16081608

16091609
# Decode post-tool completions
1610-
post_tool_completions = [parse_response(self.processing_class, ids) for ids in post_tool_ids]
1610+
post_tool_completions = [
1611+
parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids
1612+
]
16111613

16121614
# Add post-tool completions to the existing completions
16131615
for idx in range(len(idxs_with_tool)):
16141616
idx_with_tool = idxs_with_tool[idx]
1615-
# When the post-tool if completly truncated, content is empty.
1616-
if post_tool_completions[idx]["content"] or "tool_calls" in post_tool_completions[idx]:
1617+
if post_tool_completions[idx]: # {} if post-tool completions completely truncated
16171618
completions[idx_with_tool].append(post_tool_completions[idx])
16181619

16191620
# Check for further tool calls

0 commit comments

Comments
 (0)