From 552e899015c18c1a10a3b2ffe80eaa964d44afbb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 20:57:51 +0000 Subject: [PATCH 001/153] Refactor image handling: replace `image_split_sizes` with `image_grid_thw` in GRPO and RLOO trainers; update `split_pixel_values_by_grid` to use `image_grid_thw` --- tests/test_utils.py | 6 +++--- trl/trainer/grpo_trainer.py | 20 ++------------------ trl/trainer/rloo_trainer.py | 20 ++------------------ trl/trainer/utils.py | 4 ++-- 4 files changed, 9 insertions(+), 41 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 60730d685d0..f036a897e1b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -873,7 +873,7 @@ def test_with_scalar(self): class SplitPixelValuesByGridTester(TrlTestCase): def test_split_correctly_0(self): batch = { - "image_split_sizes": [4, 4], + "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) @@ -884,7 +884,7 @@ def test_split_correctly_0(self): def test_split_correctly_1(self): batch = { - "image_split_sizes": [4, 8], + "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]), "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) @@ -900,7 +900,7 @@ def test_missing_keys(self): def test_mismatched_length(self): batch = { - "image_split_sizes": torch.tensor([2, 2]), # Total = 4 + "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 8 "pixel_values": torch.randn(3, 5), # Only 3 rows } with self.assertRaises(ValueError): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0c2ad9a3121..9f618eefe8f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -791,7 +791,6 @@ def _get_per_token_logps_and_entropies( image_grid_thw=None, pixel_attention_mask=None, image_sizes=None, - image_split_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak @@ -807,7 +806,8 @@ def _get_per_token_logps_and_entropies( if image_grid_thw is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] if pixel_values is not None: - if image_split_sizes is not None: + if image_grid_thw is not None: + image_split_sizes = image_grid_thw.prod(dim=1).tolist() start_pixel_idx = sum(image_split_sizes[:start]) end_pixel_idx = sum(image_split_sizes[: start + batch_size]) model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] @@ -1078,7 +1078,6 @@ def _generate_and_score_completions( # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -1086,11 +1085,6 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1104,10 +1098,6 @@ def _generate_and_score_completions( prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, @@ -1392,7 +1382,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: old_per_token_logps = None @@ -1417,7 +1406,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1431,7 +1419,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -1580,8 +1567,6 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1656,7 +1641,6 @@ def _compute_loss(self, model, inputs): image_grid_thw=inputs.get("image_grid_thw"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), - image_split_sizes=inputs.get("image_split_sizes"), ) if self.top_entropy_quantile < 1.0: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 86aeb8910a0..b70c3b4db4f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -777,7 +777,6 @@ def _get_per_token_logps_and_entropies( image_grid_thw=None, pixel_attention_mask=None, image_sizes=None, - image_split_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: """Compute log-probs and (optionally) entropies for each token.""" batch_size = batch_size or input_ids.size(0) # Chunk inputs into smaller batches to reduce memory peak @@ -793,7 +792,8 @@ def _get_per_token_logps_and_entropies( if image_grid_thw is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] if pixel_values is not None: - if image_split_sizes is not None: + if image_grid_thw is not None: + image_split_sizes = image_grid_thw.prod(dim=1).tolist() start_pixel_idx = sum(image_split_sizes[:start]) end_pixel_idx = sum(image_split_sizes[: start + batch_size]) model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] @@ -1064,7 +1064,6 @@ def _generate_and_score_completions( # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -1072,11 +1071,6 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -1090,10 +1084,6 @@ def _generate_and_score_completions( prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, @@ -1346,7 +1336,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS @@ -1363,7 +1352,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1377,7 +1365,6 @@ def _generate_and_score_completions( image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -1498,8 +1485,6 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output @profiling_decorator @@ -1527,7 +1512,6 @@ def _compute_loss(self, model, inputs): image_grid_thw=inputs.get("image_grid_thw"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), - image_split_sizes=inputs.get("image_split_sizes"), ) logps = (per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 16ce8321612..37612a423bd 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1783,10 +1783,10 @@ def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Unio Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]`, while keeping other entries unchanged. """ - if "image_split_sizes" not in batch or "pixel_values" not in batch: + if "image_grid_thw" not in batch or "pixel_values" not in batch: return batch - lengths = batch["image_split_sizes"] # [batch_size] + lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size] pixel_values = batch["pixel_values"] # [total, feature_dim] if sum(lengths) != pixel_values.size(0): From 449ef079191ed50fae281c07b9d6775efcb345f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 21:05:47 +0000 Subject: [PATCH 002/153] simpler --- trl/trainer/grpo_trainer.py | 15 ++++++--------- trl/trainer/rloo_trainer.py | 15 ++++++--------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9f618eefe8f..bb902445d09 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -803,16 +803,13 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None: + if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - if pixel_values is not None: - if image_grid_thw is not None: - image_split_sizes = image_grid_thw.prod(dim=1).tolist() - start_pixel_idx = sum(image_split_sizes[:start]) - end_pixel_idx = sum(image_split_sizes[: start + batch_size]) - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - else: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index b70c3b4db4f..56ffbfe7fea 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -789,16 +789,13 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None: + if image_grid_thw is not None and pixel_values is not None: model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - if pixel_values is not None: - if image_grid_thw is not None: - image_split_sizes = image_grid_thw.prod(dim=1).tolist() - start_pixel_idx = sum(image_split_sizes[:start]) - end_pixel_idx = sum(image_split_sizes[: start + batch_size]) - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] - else: - model_inputs["pixel_values"] = pixel_values[start : start + batch_size] + start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() + end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + elif pixel_values is not None: + model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: model_inputs["pixel_attention_mask"] = pixel_attention_mask[start : start + batch_size] if image_sizes is not None: From c8933aa856b2b71d10470456356c48bae4aefa17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 21:10:06 +0000 Subject: [PATCH 003/153] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index b067e7410c7..f2a675fab16 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -93,7 +93,6 @@ def _generate_and_score_completions(self, inputs): # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} has_images = "image" in inputs[0] - image_split_sizes = None if has_images: images = [example.get("image") for example in inputs] kwargs = {"images": [[img] for img in images]} @@ -101,11 +100,6 @@ def _generate_and_score_completions(self, inputs): if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) - if hasattr(self.processing_class, "_get_num_multimodal_tokens"): - image_sizes = [(image.height, image.width) for image in images] - multimodal_extra_data = self.processing_class._get_num_multimodal_tokens(image_sizes) - image_split_sizes = multimodal_extra_data.num_image_patches - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] prompt_inputs = self.processing_class( @@ -116,13 +110,9 @@ def _generate_and_score_completions(self, inputs): add_special_tokens=False, **kwargs, ) - prompt_inputs = super(_GRPOTrainer, self)._prepare_inputs(prompt_inputs) + prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - if "image_grid_thw" in prompt_inputs and image_split_sizes is None: - # Fallback for VLMs that require image_grid_thw but don't provide _get_num_multimodal_tokens - image_split_sizes = prompt_inputs["image_grid_thw"].prod(dim=1).tolist() - if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, @@ -407,7 +397,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: old_per_token_logps = None @@ -432,7 +421,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -446,7 +434,6 @@ def _generate_and_score_completions(self, inputs): image_grid_thw=prompt_inputs.get("image_grid_thw"), pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), - image_split_sizes=image_split_sizes, ) else: ref_per_token_logps = None @@ -652,6 +639,4 @@ def _generate_and_score_completions(self, inputs): output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] - if image_split_sizes is not None: - output["image_split_sizes"] = image_split_sizes return output From 229c5549291b65c59537717893b7b09ad1cec0e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 22:45:57 +0000 Subject: [PATCH 004/153] multi-image grpo --- tests/test_utils.py | 27 +++++++++++++++++ trl/trainer/grpo_trainer.py | 60 +++++++++++++++++++------------------ trl/trainer/utils.py | 16 ++++++---- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index f036a897e1b..0fc16682336 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -874,6 +874,7 @@ class SplitPixelValuesByGridTester(TrlTestCase): def test_split_correctly_0(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), + "num_images": [1, 1], "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) @@ -881,10 +882,15 @@ def test_split_correctly_0(self): self.assertEqual(len(result["pixel_values"]), 2) self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))) def test_split_correctly_1(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 4]]), + "num_images": [1, 1], "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) @@ -892,6 +898,10 @@ def test_split_correctly_1(self): self.assertEqual(len(result["pixel_values"]), 2) self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))) def test_missing_keys(self): batch = {"pixel_values": torch.tensor([1.0])} @@ -901,11 +911,28 @@ def test_missing_keys(self): def test_mismatched_length(self): batch = { "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 1]]), # Total = 8 + "num_images": [1, 1], "pixel_values": torch.randn(3, 5), # Only 3 rows } with self.assertRaises(ValueError): split_pixel_values_by_grid(batch) + def test_multi_images(self): + batch = { + "image_grid_thw": torch.tensor([[1, 1, 2], [1, 2, 2], [1, 2, 1]]), # Total = 8 + "num_images": [1, 2], + "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] + } + result = split_pixel_values_by_grid(batch) + self.assertIsInstance(result["pixel_values"], list) + self.assertEqual(len(result["pixel_values"]), 2) + self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])) + self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])) + self.assertIsInstance(result["image_grid_thw"], list) + self.assertEqual(len(result["image_grid_thw"]), 2) + self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))) + self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))) + class TruncateWithProtectedTokensTester(TrlTestCase): def test_basic_example(self): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bb902445d09..f98d895fb18 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -464,7 +464,7 @@ def __init__( self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { - "image": deque(maxlen=args.generation_batch_size), + "images": deque(maxlen=args.generation_batch_size), "prompt": deque(maxlen=args.generation_batch_size), "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), @@ -609,7 +609,7 @@ def _set_signature_columns_if_needed(self): # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "image"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -804,9 +804,9 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) + start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() + end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] @@ -1070,14 +1070,19 @@ def _generate_and_score_completions( # VLM chat template. original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} + if images is not None: + kwargs = {"images": images} for prompt in prompts: if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=1) @@ -1152,7 +1157,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -1161,7 +1166,7 @@ def _generate_and_score_completions( # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -1226,7 +1231,7 @@ def _generate_and_score_completions( torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -1234,15 +1239,13 @@ def _generate_and_score_completions( all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -1507,8 +1510,8 @@ def _generate_and_score_completions( self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -1564,6 +1567,8 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1790,14 +1795,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["image"]: - table["image"] = [] - for img in self._logs["image"]: - if img is not None: - # Convert images to wandb Image objects for proper visualization - table["image"].append(wandb.Image(img)) - else: - table["image"].append(None) + if self._logs["images"]: + table["images"] = [] + for img in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append(wandb.Image(img)) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 37612a423bd..7cd16472c16 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -19,6 +19,7 @@ from collections.abc import Sequence, Sized from dataclasses import dataclass, field from importlib.metadata import version +from itertools import accumulate from typing import Any, Literal, Optional, Union import numpy as np @@ -1780,20 +1781,23 @@ def identity(x): def split_pixel_values_by_grid(batch: dict[str, torch.Tensor]) -> dict[str, Union[torch.Tensor, list[torch.Tensor]]]: """ - Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in - `batch["image_grid_thw"]`, while keeping other entries unchanged. + Splits `batch["pixel_values"]` into a list of tensors based on the product of each row in `batch["image_grid_thw"]` + and batch["num_images"] while keeping other entries unchanged. """ - if "image_grid_thw" not in batch or "pixel_values" not in batch: + if "image_grid_thw" not in batch or "pixel_values" not in batch or "num_images" not in batch: return batch - lengths = batch["image_grid_thw"].prod(-1).tolist() # [batch_size] + lengths = batch["image_grid_thw"].prod(-1).tolist() # [num_images] pixel_values = batch["pixel_values"] # [total, feature_dim] if sum(lengths) != pixel_values.size(0): raise ValueError(f"Mismatch: sum(lengths) = {sum(lengths)} != pixel_values.size(0) = {pixel_values.size(0)}") - split_values = list(torch.split(batch["pixel_values"], lengths, dim=0)) - return {**batch, "pixel_values": split_values} + boundaries = [0, *accumulate(batch["num_images"])] # [3, 4, 5] -> [0, 3, 7, 12] + sections = [sum(lengths[boundaries[i] : boundaries[i + 1]]) for i in range(len(batch["num_images"]))] + split_values = list(torch.split(batch["pixel_values"], sections, dim=0)) + image_grid_thw = list(torch.split(batch["image_grid_thw"], batch["num_images"], dim=0)) + return {**batch, "pixel_values": split_values, "image_grid_thw": image_grid_thw} def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch.Tensor]]]) -> dict[str, torch.Tensor]: From 3ca6ad50036aba363f8a87e7c227efceab7b4496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 19 Sep 2025 23:31:06 +0000 Subject: [PATCH 005/153] log with wandb --- trl/trainer/grpo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f98d895fb18..eb9e9bfcd9f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1083,9 +1083,9 @@ def _generate_and_score_completions( kwargs = {} if images is not None: kwargs = {"images": images} - for prompt in prompts: + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -1797,9 +1797,9 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non if self._logs["images"]: table["images"] = [] - for img in self._logs["images"]: + for image_list in self._logs["images"]: # Convert images to wandb Image objects for proper visualization - table["images"].append(wandb.Image(img)) + table["images"].append([wandb.Image(image) for image in image_list]) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: From dcf4b92da0085d2d94f3a8ca5bb9e1b18e20b86a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 00:18:18 +0000 Subject: [PATCH 006/153] no vlm reward models --- tests/test_grpo_trainer.py | 92 ++++++++++++++++++++++++++++++++++--- trl/trainer/grpo_trainer.py | 16 ++++--- 2 files changed, 94 insertions(+), 14 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 4e4321febcd..ced4de9d73a 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1258,6 +1258,10 @@ def test_prepare_input_called_with_correct_data(self): def test_training_vlm(self, model_id): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1269,7 +1273,7 @@ def test_training_vlm(self, model_id): ) trainer = GRPOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1301,6 +1305,10 @@ def test_training_vlm(self, model_id): def test_training_vlm_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, beta=0.1, # set beta to non-zero value to test the case where the reference model is used @@ -1312,7 +1320,7 @@ def test_training_vlm_beta_non_zero(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1342,6 +1350,10 @@ def test_training_vlm_peft(self): base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1352,7 +1364,7 @@ def test_training_vlm_peft(self): ) trainer = GRPOTrainer( model=model, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), @@ -1376,6 +1388,10 @@ def test_training_vlm_peft(self): def test_training_vlm_and_importance_sampling(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1387,7 +1403,7 @@ def test_training_vlm_and_importance_sampling(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1413,6 +1429,10 @@ def test_training_vlm_and_importance_sampling(self): def test_training_vlm_and_liger(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1425,7 +1445,7 @@ def test_training_vlm_and_liger(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1451,6 +1471,10 @@ def test_training_vlm_and_prompt_truncation(self): # If not handled properly, prompt truncation may truncate image token dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1462,7 +1486,7 @@ def test_training_vlm_and_prompt_truncation(self): ) trainer = GRPOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1495,6 +1519,10 @@ def test_training_vlm_and_prompt_truncation(self): def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = GRPOConfig( output_dir=self.tmp_dir, learning_rate=0.1, @@ -1508,7 +1536,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None: ) trainer = GRPOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vision + def test_training_vlm_multi_image(self): + dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train") + + # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples + dataset = dataset.filter(lambda x: len(x["images"]) > 0) + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1519,7 +1584,20 @@ def test_training_vlm_and_vllm(self, model_id) -> None: self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + # Check that the params have changed + # Because of the way the tiny models are initialized, the gradient does not flow properly through the + # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. + params_to_skip = ( + # "model.vision_tower.", + # "model.multi_modal_projector.", + # "model.vision_model.", + # "model.connector.modality_projection.", + # "model.visual.", + # "model.image_newline", + ) for n, param in previous_trainable_params.items(): + if n.startswith(params_to_skip): + continue new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index eb9e9bfcd9f..59135c232e6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import inspect import os import re @@ -1020,6 +1019,14 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): + # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed + for prompt in prompts: + for turn in prompt: + if isinstance(turn["content"], list): + logger.warning_once("Visual reward models aren't supported yet; dropping image.") + turn["content"] = " ".join( + e["text"] for e in turn["content"] if e["type"] == "text" + ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: @@ -1065,11 +1072,6 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. - original_prompts = copy.deepcopy(prompts) - if "images" in inputs[0]: images = [example.get("images") for example in inputs] elif "image" in inputs[0]: @@ -1436,7 +1438,7 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) From 30ad7ca371286e2d55998d285ae66a3f123eee83 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 00:37:54 +0000 Subject: [PATCH 007/153] rloo --- trl/trainer/rloo_trainer.py | 78 +++++++++++++++++++------------------ 1 file changed, 41 insertions(+), 37 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 56ffbfe7fea..48801496b01 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import inspect import os import re @@ -536,7 +535,7 @@ def decode(example, tokenizer): self.num_completions_to_print = args.num_completions_to_print # Keep logs sized to the generation batch to record only outputs from the latest model update. self._logs = { - "image": deque(maxlen=args.generation_batch_size), + "images": deque(maxlen=args.generation_batch_size), "prompt": deque(maxlen=args.generation_batch_size), "completion": deque(maxlen=args.generation_batch_size), "rewards": defaultdict(lambda: deque(maxlen=args.generation_batch_size)), @@ -678,7 +677,7 @@ def _set_signature_columns_if_needed(self): # In RLOOTrainer, we preprocess data, so using the model's signature columns doesn't work. # Instead, we set them to the columns expected by the `training_step` method, hence the override. if self._signature_columns is None: - self._signature_columns = ["prompt", "image"] + self._signature_columns = ["prompt", "image", "images"] # This method overrides `Trainer.get_train_dataloader` to support our custom batching strategy. # Instead of returning a standard per-step batch (i.e., `per_device_batch_size), our dataloader loads an @@ -790,9 +789,9 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = image_grid_thw[start : start + batch_size] - start_pixel_idx = image_grid_thw[:start].prod(-1).sum().item() - end_pixel_idx = image_grid_thw[: start + batch_size].prod(-1).sum().item() + model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) + start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() + end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] @@ -1006,6 +1005,14 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): + # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed + for prompt in prompts: + for turn in prompt: + if isinstance(turn["content"], list): + logger.warning_once("Visual reward models aren't supported yet; dropping image.") + turn["content"] = " ".join( + e["text"] for e in turn["content"] if e["type"] == "text" + ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: @@ -1051,22 +1058,22 @@ def _generate_and_score_completions( prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. - original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -1133,7 +1140,7 @@ def _generate_and_score_completions( # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -1142,7 +1149,7 @@ def _generate_and_score_completions( # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -1205,7 +1212,7 @@ def _generate_and_score_completions( torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -1213,15 +1220,13 @@ def _generate_and_score_completions( all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -1379,7 +1384,7 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -1463,8 +1468,8 @@ def _generate_and_score_completions( self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(gather_object(images)) + if images is not None: + self._logs["images"].extend(gather_object(images)) output = { "prompt_ids": prompt_ids, @@ -1482,6 +1487,8 @@ def _generate_and_score_completions( output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] return output @profiling_decorator @@ -1588,14 +1595,11 @@ def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> Non "advantage": self._logs["advantages"], } - if self._logs["image"]: - table["image"] = [] - for img in self._logs["image"]: - if img is not None: - # Convert images to wandb Image objects for proper visualization - table["image"].append(wandb.Image(img)) - else: - table["image"].append(None) + if self._logs["images"]: + table["images"] = [] + for image_list in self._logs["images"]: + # Convert images to wandb Image objects for proper visualization + table["images"].append([wandb.Image(image) for image in image_list]) df = pd.DataFrame(table) if self.wandb_log_unique_prompts: From 86cc30bf3c307eebcd9223ec7db8bcc8784573b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 00:43:43 +0000 Subject: [PATCH 008/153] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 50 +++++++++++++-------------- 1 file changed, 25 insertions(+), 25 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index f2a675fab16..af83e076dd9 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -83,22 +83,22 @@ def _generate_and_score_completions(self, inputs): prompts = [x["prompt"] for x in inputs] - # We don't yet support visual reward models/function, so we keep a copy of the original text-only prompts for - # later use in the reward computation. If images are present, we insert {"type": "image"} as required by the - # VLM chat template. - original_prompts = copy.deepcopy(prompts) + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] kwargs = {} - has_images = "image" in inputs[0] - if has_images: - images = [example.get("image") for example in inputs] - kwargs = {"images": [[img] for img in images]} - for prompt in prompts: + if images is not None: + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=1) + prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] @@ -170,7 +170,7 @@ def _generate_and_score_completions(self, inputs): # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts_text = gather_object(prompts_text) - if has_images: + if images is not None: all_images = gather_object(images) if self.accelerator.is_main_process: @@ -179,7 +179,7 @@ def _generate_and_score_completions(self, inputs): # prompt individually. ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - if has_images: + if images is not None: ordered_set_of_images = all_images[:: self.num_generations] else: ordered_set_of_images = None @@ -244,7 +244,7 @@ def _generate_and_score_completions(self, inputs): torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - if has_images: + if images is not None: gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) all_images = [img for sublist in gathered_images for img in sublist] @@ -252,15 +252,13 @@ def _generate_and_score_completions(self, inputs): all_images = None else: all_prompts_text = prompts_text - all_images = images if has_images else None + all_images = images - if has_images and all_images: + if images is not None and all_images: vllm_inputs = [] - for prompt, image in zip(all_prompts_text, all_images): - if image is not None: - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image}}) - else: - vllm_inputs.append(prompt) + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) + else: vllm_inputs = all_prompts_text @@ -451,7 +449,7 @@ def _generate_and_score_completions(self, inputs): # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - rewards_per_func = self._calculate_rewards(inputs, original_prompts, completions, completion_ids_list) + rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).nansum(dim=1) @@ -563,12 +561,12 @@ def _generate_and_score_completions(self, inputs): # Log prompt and completion texts all_prompts_text = gather_object(prompts_text) all_completions_text = gather_object(completions_text) - all_images = gather_object(images) if has_images else None + all_images = gather_object(images) if images is not None else None if self.num_remains_in_group is not None and mode == "train": group_global_indices_list = group_global_indices.tolist() all_prompts_text = [all_prompts_text[i] for i in group_global_indices_list] all_completions_text = [all_completions_text[i] for i in group_global_indices_list] - if has_images: + if images is not None: all_images = [all_images[i] for i in group_global_indices_list] self._logs["prompt"].extend(all_prompts_text) @@ -577,8 +575,8 @@ def _generate_and_score_completions(self, inputs): self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) - if has_images: - self._logs["image"].extend(all_images) + if images is not None: + self._logs["images"].extend(gather_object(images)) if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) @@ -639,4 +637,6 @@ def _generate_and_score_completions(self, inputs): output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] + if images is not None: + output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] return output From 088897b9cd37925268fb8fcb48b56aa4bc2b65d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 02:25:10 +0000 Subject: [PATCH 009/153] fix --- trl/trainer/grpo_trainer.py | 23 +++++++++++++++++------ trl/trainer/rloo_trainer.py | 21 ++++++++++++++++----- trl/trainer/utils.py | 12 ++++++++---- 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 59135c232e6..87a08096a13 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -788,6 +788,7 @@ def _get_per_token_logps_and_entropies( compute_entropy=False, pixel_values=None, image_grid_thw=None, + num_images=None, pixel_attention_mask=None, image_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: @@ -801,12 +802,16 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) - start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() - end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: @@ -1362,6 +1367,8 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -1382,6 +1389,7 @@ def _generate_and_score_completions( batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1406,6 +1414,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1419,6 +1428,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1570,7 +1580,7 @@ def _generate_and_score_completions( if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] if images is not None: - output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] + output["num_images"] = num_images return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -1643,6 +1653,7 @@ def _compute_loss(self, model, inputs): compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), ) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 48801496b01..4e50bbc4501 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -774,6 +774,7 @@ def _get_per_token_logps_and_entropies( compute_entropy=False, pixel_values=None, image_grid_thw=None, + num_images=None, pixel_attention_mask=None, image_sizes=None, ) -> dict[str, Optional[torch.Tensor]]: @@ -789,10 +790,15 @@ def _get_per_token_logps_and_entropies( model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} if image_grid_thw is not None and pixel_values is not None: - model_inputs["image_grid_thw"] = torch.cat(image_grid_thw[start : start + batch_size]) - start_pixel_idx = 0 if start == 0 else torch.cat(image_grid_thw[:start]).prod(-1).sum().item() - end_pixel_idx = torch.cat(image_grid_thw[: start + batch_size]).prod(-1).sum().item() - model_inputs["pixel_values"] = pixel_values[start_pixel_idx:end_pixel_idx] + rows_per_image = image_grid_thw.prod(dim=-1) + rows_per_sample = torch.split(rows_per_image, num_images) + rows_per_sample = torch.stack([s.sum() for s in rows_per_sample]) + cum_rows = torch.cat([torch.tensor([0], device=rows_per_sample.device), rows_per_sample.cumsum(0)]) + row_start, row_end = cum_rows[start].item(), cum_rows[start + batch_size].item() + model_inputs["pixel_values"] = pixel_values[row_start:row_end] + cum_imgs = torch.tensor([0] + num_images).cumsum(0) + img_start, img_end = cum_imgs[start], cum_imgs[start + batch_size] + model_inputs["image_grid_thw"] = image_grid_thw[img_start:img_end] elif pixel_values is not None: model_inputs["pixel_values"] = pixel_values[start : start + batch_size] if pixel_attention_mask is not None: @@ -1326,6 +1332,8 @@ def _generate_and_score_completions( logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # Compute the per-token log probabilities for the current model old_per_token_logps, _ = self._get_per_token_logps_and_entropies( @@ -1336,6 +1344,7 @@ def _generate_and_score_completions( batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1352,6 +1361,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1365,6 +1375,7 @@ def _generate_and_score_completions( batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -1488,7 +1499,7 @@ def _generate_and_score_completions( if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] if images is not None: - output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] + output["num_images"] = num_images return output @profiling_decorator diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 7cd16472c16..337e9857c1a 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1806,12 +1806,16 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch tensor along the first dimension. """ pixel_values = batch.get("pixel_values") - if isinstance(pixel_values, list): merged = torch.cat(pixel_values, dim=0) - return {**batch, "pixel_values": merged} - else: - return batch + batch = {**batch, "pixel_values": merged} + + image_grid_thw = batch.get("image_grid_thw") + if isinstance(image_grid_thw, list): + merged = torch.cat(image_grid_thw, dim=0) + batch = {**batch, "image_grid_thw": merged} + + return batch def truncate_with_protected_tokens( From d2adc63eb66c70592813b718ad5791e5fdead371 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 02:52:33 +0000 Subject: [PATCH 010/153] test peft --- tests/test_grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index ced4de9d73a..5577e1dd25d 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1348,7 +1348,7 @@ def test_training_vlm_peft(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration" ) base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") def reward_func(completions, **kwargs): """Reward function that rewards longer completions.""" From f4c82bfc0470c5a2fb590880e7a639eb77b93f2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 02:55:59 +0000 Subject: [PATCH 011/153] fix gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index af83e076dd9..6ac4b6acc7f 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import logging import re from contextlib import nullcontext @@ -373,6 +372,8 @@ def _generate_and_score_completions(self, inputs): logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + num_images = [len(img_list) for img_list in images] if images is not None else None + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the @@ -393,6 +394,7 @@ def _generate_and_score_completions(self, inputs): batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -417,6 +419,7 @@ def _generate_and_score_completions(self, inputs): batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -430,6 +433,7 @@ def _generate_and_score_completions(self, inputs): batch_size=batch_size, pixel_values=prompt_inputs.get("pixel_values"), image_grid_thw=prompt_inputs.get("image_grid_thw"), + num_images=num_images, pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), image_sizes=prompt_inputs.get("image_sizes"), ) @@ -638,5 +642,5 @@ def _generate_and_score_completions(self, inputs): if "image_sizes" in prompt_inputs: output["image_sizes"] = prompt_inputs["image_sizes"] if images is not None: - output["num_images"] = [len(img_list) if img_list is not None else 0 for img_list in images] + output["num_images"] = num_images return output From 1257796ba85a6566ff4a3749a36b4b90360fd5ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 03:01:47 +0000 Subject: [PATCH 012/153] rloo test --- tests/test_rloo_trainer.py | 67 +++++++++++++++++++++++++++++++++++--- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 12042bd2b3b..2bab06f218c 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1089,6 +1089,10 @@ def test_prepare_input_called_with_correct_data(self): def test_training_vlm(self, model_id): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1100,7 +1104,7 @@ def test_training_vlm(self, model_id): ) trainer = RLOOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1132,6 +1136,10 @@ def test_training_vlm(self, model_id): def test_training_vlm_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, beta=0.1, # set beta to non-zero value to test the case where the reference model is used @@ -1143,7 +1151,7 @@ def test_training_vlm_beta_non_zero(self): ) trainer = RLOOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1173,6 +1181,10 @@ def test_training_vlm_peft(self): base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1183,7 +1195,7 @@ def test_training_vlm_peft(self): ) trainer = RLOOTrainer( model=model, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, peft_config=LoraConfig(target_modules=["q_proj", "v_proj"]), @@ -1208,6 +1220,10 @@ def test_training_vlm_and_prompt_truncation(self): # If not handled properly, prompt truncation may truncate image token dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, # increase the learning rate to speed up the test @@ -1219,7 +1235,7 @@ def test_training_vlm_and_prompt_truncation(self): ) trainer = RLOOTrainer( model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) @@ -1252,6 +1268,10 @@ def test_training_vlm_and_prompt_truncation(self): def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + training_args = RLOOConfig( output_dir=self.tmp_dir, learning_rate=0.1, @@ -1265,7 +1285,44 @@ def test_training_vlm_and_vllm(self, model_id) -> None: ) trainer = RLOOTrainer( model=model_id, - reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + reward_funcs=reward_func, + args=training_args, + train_dataset=dataset, + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + trainer.train() + + self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + + @require_vision + def test_training_vlm_multi_image(self): + dataset = load_dataset("trl-internal-testing/zen-multi-image", "conversational_prompt_only", split="train") + + # For now, mixing image+text and text-only examples is not supported, so we filter out text-only examples + dataset = dataset.filter(lambda x: len(x["images"]) > 0) + + def reward_func(completions, **kwargs): + """Reward function that rewards longer completions.""" + return [float(len(completion[0]["content"])) for completion in completions] + + training_args = RLOOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, # increase the learning rate to speed up the test + per_device_train_batch_size=3, # reduce the batch size to reduce memory usage + num_generations=3, # reduce the number of generations to reduce memory usage + max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it + report_to="none", + ) + trainer = RLOOTrainer( + model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", + reward_funcs=reward_func, args=training_args, train_dataset=dataset, ) From 099a39bd6a9f90b082f3554facc699ac5463ee86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 03:04:07 +0000 Subject: [PATCH 013/153] peft rloo --- tests/test_rloo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 2bab06f218c..399419ec3c1 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1179,7 +1179,7 @@ def test_training_vlm_peft(self): "trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration" ) base_param_names = [f"base_model.model.{n}" for n, _ in model.named_parameters()] - dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") + dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") def reward_func(completions, **kwargs): """Reward function that rewards longer completions.""" From 529add673c30175a32be9454973e9435e28b2251 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 03:55:03 +0000 Subject: [PATCH 014/153] oops --- trl/trainer/rloo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4e50bbc4501..3671af229e2 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1525,6 +1525,7 @@ def _compute_loss(self, model, inputs): compute_entropy=True, pixel_values=inputs.get("pixel_values"), image_grid_thw=inputs.get("image_grid_thw"), + num_images=inputs.get("num_images"), pixel_attention_mask=inputs.get("pixel_attention_mask"), image_sizes=inputs.get("image_sizes"), ) From fc6b11fcaeb182edbe0c3e5c336bddcaeea7bf3c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 04:22:54 +0000 Subject: [PATCH 015/153] update test --- tests/test_utils.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0fc16682336..6f6ba1579ef 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1071,12 +1071,16 @@ def test_empty_protected_tokens_list(self): class UnsplitPixelValuesByGridTester(TrlTestCase): def test_unsplit_correctly(self): - split = [torch.randn(4, 5), torch.randn(2, 5)] - merged = torch.cat(split, dim=0) - batch = {"pixel_values": split, "other_key": torch.tensor([1])} + pixel_values = [torch.randn(4, 5), torch.randn(2, 5)] + pixel_values_merged = torch.cat(pixel_values, dim=0) + image_grid_thw = [torch.tensor([[1, 2, 2]]), torch.tensor([[1, 2, 1]])] + image_grid_thw_merged = torch.cat(image_grid_thw, dim=0) + batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])} result = unsplit_pixel_values_by_grid(batch) self.assertIsInstance(result["pixel_values"], torch.Tensor) - self.assertTrue(torch.allclose(result["pixel_values"], merged)) + self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged)) + self.assertIsInstance(result["image_grid_thw"], torch.Tensor) + self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged)) self.assertIn("other_key", result) def test_no_op_if_not_list(self): From ae1f497959032ae6ba0120cbb13b8f406b9e1799 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 05:08:48 +0000 Subject: [PATCH 016/153] generate method --- trl/trainer/grpo_trainer.py | 155 +++++++++++++++++++++--------------- 1 file changed, 92 insertions(+), 63 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 87a08096a13..bfc24a651c3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1069,21 +1069,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: + def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] @@ -1094,7 +1083,9 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=len(image_list)) - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] prompt_inputs = self.processing_class( text=prompts_text, @@ -1106,6 +1097,7 @@ def _generate_and_score_completions( ) prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. @@ -1279,8 +1271,9 @@ def _generate_and_score_completions( # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + completion_mask = pad(completion_mask, padding_value=0) sampling_per_token_logps = [ torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs ] @@ -1318,9 +1311,9 @@ def _generate_and_score_completions( completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn + else: # Regular generation path with ( @@ -1331,14 +1324,18 @@ def _generate_and_score_completions( torch.no_grad(), FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): - prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, ) # Compute prompt length and extract completion ids prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] + sampling_per_token_logps = None # not used in this case # Mask everything after the first EOS token is_eos = completion_ids == self.eos_token_id @@ -1347,10 +1344,6 @@ def _generate_and_score_completions( sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging completion_lengths = completion_mask.sum(1) agg_completion_lengths = self.accelerator.gather(completion_lengths) @@ -1361,7 +1354,72 @@ def _generate_and_score_completions( truncated_completions = ~is_eos.any(dim=1) completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + # Log the metrics + if mode == "train": + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) + term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + if images is not None: + self._logs["image"].extend(gather_object(images)) + + return ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) + + def _generate_and_score_completions( + self, inputs: list[dict[str, Union[torch.Tensor, Any]]] + ) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) = self._generate(prompts, images) + + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. + completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -1387,11 +1445,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: old_per_token_logps = None @@ -1412,11 +1467,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1426,16 +1478,14 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: ref_per_token_logps = None - # Decode the generated completions + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] @@ -1484,27 +1534,6 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1571,14 +1600,14 @@ def _generate_and_score_completions( output["importance_sampling_ratio"] = importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps - if "pixel_values" in prompt_inputs: - output["pixel_values"] = prompt_inputs["pixel_values"] - if "image_grid_thw" in prompt_inputs: - output["image_grid_thw"] = prompt_inputs["image_grid_thw"] - if "pixel_attention_mask" in prompt_inputs: - output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] - if "image_sizes" in prompt_inputs: - output["image_sizes"] = prompt_inputs["image_sizes"] + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] if images is not None: output["num_images"] = num_images return output From f99843262210380e08a43874e778b3270381bffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 05:18:40 +0000 Subject: [PATCH 017/153] debug --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4231ef227ec..48ee6cc9295 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -77,6 +77,7 @@ jobs: - name: Test with pytest run: | source .venv/bin/activate + export CUDA_LAUNCH_BLOCKING=1 make test - name: Post to Slack From fa738768c685196407db94a652d1e52cf88bcc22 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 15:21:36 +0000 Subject: [PATCH 018/153] skip failing test --- tests/test_online_dpo_trainer.py | 58 ++++++++++++++++---------------- 1 file changed, 29 insertions(+), 29 deletions(-) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 47fbd1f5a1f..30b39ce3464 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -419,35 +419,35 @@ def test_generation_config_setup(self): self.assertEqual(trainer.generation_config.max_new_tokens, 64) self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs - @require_torch_accelerator - @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) - def test_training_with_transformers_paged(self, config_name): - if Version(transformers.__version__) < Version("4.56.2"): - pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") - training_args = OnlineDPOConfig( - output_dir=self.tmp_dir, - per_device_train_batch_size=2, - max_steps=3, - learning_rate=5.0e-7, - eval_strategy="steps", - report_to="none", - use_transformers_paged=True, - ) - dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) - - trainer = OnlineDPOTrainer( - model=self.model, - reward_funcs=self.reward_model, - args=training_args, - train_dataset=dummy_dataset["train"], - eval_dataset=dummy_dataset["test"], - processing_class=self.tokenizer, - reward_processing_classes=self.reward_tokenizer, - ) - trainer.train() - - # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + # @require_torch_accelerator + # @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + # def test_training_with_transformers_paged(self, config_name): + # if Version(transformers.__version__) < Version("4.56.2"): + # pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") + # training_args = OnlineDPOConfig( + # output_dir=self.tmp_dir, + # per_device_train_batch_size=2, + # max_steps=3, + # learning_rate=5.0e-7, + # eval_strategy="steps", + # report_to="none", + # use_transformers_paged=True, + # ) + # dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + # trainer = OnlineDPOTrainer( + # model=self.model, + # reward_funcs=self.reward_model, + # args=training_args, + # train_dataset=dummy_dataset["train"], + # eval_dataset=dummy_dataset["test"], + # processing_class=self.tokenizer, + # reward_processing_classes=self.reward_tokenizer, + # ) + # trainer.train() + + # # Check if training loss is available + # self.assertIn("train_loss", trainer.state.log_history[-1]) @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_reward_funcs(self, config_name): From fc52e6832d5b0e9f3403a01fd9571f4dd537ee5f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 16:26:34 +0000 Subject: [PATCH 019/153] test fixed! --- .github/workflows/tests.yml | 1 - scripts/generate_tiny_models.py | 1 + tests/test_online_dpo_trainer.py | 58 ++++++++++++++++---------------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 48ee6cc9295..4231ef227ec 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -77,7 +77,6 @@ jobs: - name: Test with pytest run: | source .venv/bin/activate - export CUDA_LAUNCH_BLOCKING=1 make test - name: Post to Slack diff --git a/scripts/generate_tiny_models.py b/scripts/generate_tiny_models.py index f8e779896f6..0000f788d09 100644 --- a/scripts/generate_tiny_models.py +++ b/scripts/generate_tiny_models.py @@ -292,6 +292,7 @@ def init_weights_tiny_model(model): "hidden_size": 16, "num_attention_heads": 4, "num_key_value_heads": 2, + "embed_dim": 64, } config = AutoConfig.from_pretrained(model_id, text_config=text_config, vision_config=vision_config) diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 30b39ce3464..47fbd1f5a1f 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -419,35 +419,35 @@ def test_generation_config_setup(self): self.assertEqual(trainer.generation_config.max_new_tokens, 64) self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs - # @require_torch_accelerator - # @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) - # def test_training_with_transformers_paged(self, config_name): - # if Version(transformers.__version__) < Version("4.56.2"): - # pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") - # training_args = OnlineDPOConfig( - # output_dir=self.tmp_dir, - # per_device_train_batch_size=2, - # max_steps=3, - # learning_rate=5.0e-7, - # eval_strategy="steps", - # report_to="none", - # use_transformers_paged=True, - # ) - # dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) - - # trainer = OnlineDPOTrainer( - # model=self.model, - # reward_funcs=self.reward_model, - # args=training_args, - # train_dataset=dummy_dataset["train"], - # eval_dataset=dummy_dataset["test"], - # processing_class=self.tokenizer, - # reward_processing_classes=self.reward_tokenizer, - # ) - # trainer.train() - - # # Check if training loss is available - # self.assertIn("train_loss", trainer.state.log_history[-1]) + @require_torch_accelerator + @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + def test_training_with_transformers_paged(self, config_name): + if Version(transformers.__version__) < Version("4.56.2"): + pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.56.2") + training_args = OnlineDPOConfig( + output_dir=self.tmp_dir, + per_device_train_batch_size=2, + max_steps=3, + learning_rate=5.0e-7, + eval_strategy="steps", + report_to="none", + use_transformers_paged=True, + ) + dummy_dataset = load_dataset("trl-internal-testing/zen", config_name) + + trainer = OnlineDPOTrainer( + model=self.model, + reward_funcs=self.reward_model, + args=training_args, + train_dataset=dummy_dataset["train"], + eval_dataset=dummy_dataset["test"], + processing_class=self.tokenizer, + reward_processing_classes=self.reward_tokenizer, + ) + trainer.train() + + # Check if training loss is available + self.assertIn("train_loss", trainer.state.log_history[-1]) @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_reward_funcs(self, config_name): From 4fc2b5b71d7b11c4e5488c8ad90b8999061798ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 17:13:23 +0000 Subject: [PATCH 020/153] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 354 ++------------------------ 1 file changed, 27 insertions(+), 327 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 6ac4b6acc7f..d3b59a72c81 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -13,22 +13,15 @@ # limitations under the License. import logging -import re -from contextlib import nullcontext from typing import Any, Callable import torch -import torch.utils.data -from accelerate.utils import broadcast_object_list, gather_object -from torch.distributed.fsdp import FullyShardedDataParallel as FSDP -from transformers.utils import is_flash_attn_2_available +from accelerate.utils import gather_object -from ...data_utils import is_conversational, maybe_apply_chat_template, prepare_multimodal_messages -from ...extras.profiling import profiling_context +from ...data_utils import is_conversational from ...import_utils import is_vllm_available -from ...models import unwrap_model_for_generation from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer -from ...trainer.utils import nanmax, nanmin, nanstd, pad, truncate_with_protected_tokens +from ...trainer.utils import nanmax, nanmin, nanstd logger = logging.getLogger(__name__) @@ -36,8 +29,7 @@ GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]] if is_vllm_available(): - from vllm import SamplingParams - from vllm.sampling_params import GuidedDecodingParams + pass class GFPOTrainer(_GRPOTrainer): @@ -89,284 +81,22 @@ def _generate_and_score_completions(self, inputs): else: images = None - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] - - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - - if self.max_prompt_length is not None: - # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). - protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] - protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) - - prompts_text = self.processing_class.batch_decode( - prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] - - # The chat template sometimes inserts a single image token into the prompt text. However, when this text is - # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the - # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We - # collapse them back into a single token string to match the original chat template in case it originally - # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images - # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only - # the vision_start_token_id (e.g. ). - if self.image_token is not None: - escaped_img_token = re.escape(self.image_token) - # Search for the image token in the chat template - if re.search(escaped_img_token, self.processing_class.chat_template): - prompts_text = [ - re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text - ] - else: - # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id - if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) - prompts_text = [ - re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text - ] - else: - # If vision_end_token_id is None, just remove the image tokens - prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] - - # Generate completions using either vLLM or regular generation - if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() - - # First, update the vLLM weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step - - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) - - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) - payload = (output["completion_ids"], output["logprobs"]) - else: - payload = None - - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - completion_ids, all_logprobs = obj_list[0] - - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = completion_ids[process_slice] - all_logprobs = all_logprobs[process_slice] - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts - elif self.vllm_mode == "colocate": - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, - "logprobs": 0, # only return the logprob of the generated token - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text - - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - all_logprobs = all_logprobs[tp_slice] - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - sampling_per_token_logps = [ - torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs - ] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0) - - elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" - else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" - with ( - profiling_context(self, "transformers.generate_batch"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Cast to the appropriate dtype based on training configuration - if self.args.bf16: - unwrapped_model.to(torch.bfloat16) - elif self.args.fp16: - unwrapped_model.to(torch.float16) - with torch.inference_mode(): - all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False - ) - completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn - else: - # Regular generation path - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask - prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True - ) - # Compute prompt length and extract completion ids - prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] - completion_ids = prompt_completion_ids[:, prompt_length:] - - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) = self._generate(prompts, images) # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need # to re-tokenize completions if the reward is computed from tokens. completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss - - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() - # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -392,11 +122,8 @@ def _generate_and_score_completions(self, inputs): attention_mask, logits_to_keep, batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: old_per_token_logps = None @@ -417,11 +144,8 @@ def _generate_and_score_completions(self, inputs): attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -431,16 +155,14 @@ def _generate_and_score_completions(self, inputs): attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: ref_per_token_logps = None - # Decode the generated completions + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] @@ -529,28 +251,6 @@ def _generate_and_score_completions(self, inputs): completion_lengths = completion_mask.sum(1) agg_completion_lengths = self.accelerator.gather(completion_lengths) num_items_in_batch = agg_completion_lengths.sum() - is_eos = completion_ids == self.eos_token_id - - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): @@ -633,14 +333,14 @@ def _generate_and_score_completions(self, inputs): output["importance_sampling_ratio"] = importance_sampling_ratio if ref_per_token_logps is not None: output["ref_per_token_logps"] = ref_per_token_logps - if "pixel_values" in prompt_inputs: - output["pixel_values"] = prompt_inputs["pixel_values"] - if "image_grid_thw" in prompt_inputs: - output["image_grid_thw"] = prompt_inputs["image_grid_thw"] - if "pixel_attention_mask" in prompt_inputs: - output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] - if "image_sizes" in prompt_inputs: - output["image_sizes"] = prompt_inputs["image_sizes"] + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] if images is not None: output["num_images"] = num_images return output From b628744752d54c3f2028b4a2d420bd88fc06a188 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 17:15:02 +0000 Subject: [PATCH 021/153] rm vllm --- trl/experimental/gfpo/gfpo_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index d3b59a72c81..5e228c1e883 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -19,7 +19,6 @@ from accelerate.utils import gather_object from ...data_utils import is_conversational -from ...import_utils import is_vllm_available from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer from ...trainer.utils import nanmax, nanmin, nanstd @@ -28,9 +27,6 @@ GroupFilterFunc = Callable[[list[list[Any]], list[list[Any]]], list[list[float]]] -if is_vllm_available(): - pass - class GFPOTrainer(_GRPOTrainer): def __init__( From d3a769fe8fb5a8a4b9e21b6d96c8b1d8f84c7960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 20 Sep 2025 17:15:13 +0000 Subject: [PATCH 022/153] fix doc --- docs/source/experimental.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/experimental.md b/docs/source/experimental.md index 65471f4e421..1413e56a2df 100644 --- a/docs/source/experimental.md +++ b/docs/source/experimental.md @@ -66,7 +66,7 @@ class GroupFilter: return group_scores training_args = GFPOConfig( - output_dir="Qwen3-0.6B-GFPO" + output_dir="Qwen3-0.6B-GFPO", per_device_train_batch_size=4, num_remains_in_group=2, bf16=True, From 05270f820f69bad6b3edc1edbeee2d53189a9143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 22 Sep 2025 23:51:57 +0000 Subject: [PATCH 023/153] update layers to ignore --- tests/test_grpo_trainer.py | 11 ----------- tests/test_rloo_trainer.py | 2 -- 2 files changed, 13 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 5577e1dd25d..cc484c56d0f 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1291,7 +1291,6 @@ def reward_func(completions, **kwargs): "model.vision_tower.", "model.multi_modal_projector.", "model.vision_model.", - "model.connector.modality_projection.", "model.visual.", "model.image_newline", ) @@ -1587,17 +1586,7 @@ def reward_func(completions, **kwargs): # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. - params_to_skip = ( - # "model.vision_tower.", - # "model.multi_modal_projector.", - # "model.vision_model.", - # "model.connector.modality_projection.", - # "model.visual.", - # "model.image_newline", - ) for n, param in previous_trainable_params.items(): - if n.startswith(params_to_skip): - continue new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 399419ec3c1..cde52de6047 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1121,8 +1121,6 @@ def reward_func(completions, **kwargs): params_to_skip = ( "model.vision_tower.", "model.multi_modal_projector.", - "model.vision_model.", - "model.connector.modality_projection.", "model.visual.", "model.image_newline", ) From 1c530948681255ce59bc267939d79a80ec1c5d93 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 22 Sep 2025 23:57:13 +0000 Subject: [PATCH 024/153] clarify image column desc --- docs/source/dataset_formats.md | 2 +- docs/source/grpo_trainer.md | 4 +++- docs/source/rloo_trainer.md | 4 +++- 3 files changed, 7 insertions(+), 3 deletions(-) diff --git a/docs/source/dataset_formats.md b/docs/source/dataset_formats.md index da606a7d97e..8a105ff5e34 100644 --- a/docs/source/dataset_formats.md +++ b/docs/source/dataset_formats.md @@ -1037,7 +1037,7 @@ Some trainers also support fine-tuning vision-language models (VLMs) using image A conversational vision dataset differs from a standard conversational dataset in two key ways: -1. The dataset must contain the key `images` with the image data. +1. The dataset must contain the key `images` with the image data (as lists of PIL images) or `image` with a single PIL image. 2. The `"content"` field in messages must be a list of dictionaries, where each dictionary specifies the type of data: `"image"` or `"text"`. Example: diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 172e5f93111..e998ef63f69 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -562,7 +562,9 @@ Tested with: - **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + ### Quick Start @@ -605,7 +607,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli Each training sample should include: - `prompt`: Text formatted via the processor's chat template -- `image`: A single image (PIL or NumPy array) +- `image`/`images`: PIL Image or list of PIL Images The trainer automatically handles image-to-tensor conversion via the model’s image processor. diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 66a8f3e16e4..bce71b1f0bf 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -533,7 +533,9 @@ Tested with: - **SmolVLM2** — e.g., `HuggingFaceTB/SmolVLM2-2.2B-Instruct` + Compatibility with all VLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + ### Quick Start @@ -576,7 +578,7 @@ VLM training may fail if image tokens are truncated. We highly recommend disabli Each training sample should include: - `prompt`: Text formatted via the processor's chat template -- `image`: A single image (PIL or NumPy array) +- `image`/`images`: PIL Image or list of PIL Images The trainer automatically handles image-to-tensor conversion via the model’s image processor. From 9b6652eed4fdc6c7e73c40d81917a3e5c9ad024c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 23 Sep 2025 00:05:23 +0000 Subject: [PATCH 025/153] rm VLM x RM warning --- trl/trainer/grpo_trainer.py | 8 -------- trl/trainer/rloo_trainer.py | 8 -------- 2 files changed, 16 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 87a08096a13..69825102b27 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1024,14 +1024,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): - # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed - for prompt in prompts: - for turn in prompt: - if isinstance(turn["content"], list): - logger.warning_once("Visual reward models aren't supported yet; dropping image.") - turn["content"] = " ".join( - e["text"] for e in turn["content"] if e["type"] == "text" - ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 3671af229e2..359cb68e43b 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1011,14 +1011,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): with profiling_context(self, reward_func_name): if isinstance(reward_func, nn.Module): # Module (no PretrainedModel) for compat with compiled models if is_conversational(inputs[0]): - # VLM reward models aren't supported yet, so we drop the image and raise a warning if needed - for prompt in prompts: - for turn in prompt: - if isinstance(turn["content"], list): - logger.warning_once("Visual reward models aren't supported yet; dropping image.") - turn["content"] = " ".join( - e["text"] for e in turn["content"] if e["type"] == "text" - ) messages = [{"messages": p + c} for p, c in zip(prompts, completions)] texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] else: From c83e7108319d19ffe55866a8f7401f9741b93df1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 24 Sep 2025 17:17:14 +0000 Subject: [PATCH 026/153] same for rloo --- trl/trainer/grpo_trainer.py | 7 +- trl/trainer/rloo_trainer.py | 150 ++++++++++++++++++++---------------- 2 files changed, 87 insertions(+), 70 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 78c3ccff638..a947d52cec3 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -480,7 +480,7 @@ def __init__( if not is_vllm_available(): raise ImportError( "vLLM is not available and `use_vllm` is set to True. Please install vLLM with " - "`pip install [vllm]` to use it." + "`pip install trl[vllm]` to use it." ) if self.vllm_mode == "server": @@ -533,7 +533,7 @@ def __init__( distributed_executor_backend="external_launcher", # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, - # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory + # Latest vLLM v1 memory profiler is misled by the high default value (i.e., 32768) - thinking there's not enough memory max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, @@ -1366,9 +1366,6 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - if images is not None: - self._logs["image"].extend(gather_object(images)) - return ( prompt_ids, completion_ids, diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index e87ecf95b37..5ff29112e9c 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -81,7 +81,6 @@ if is_peft_available(): from peft import PeftConfig, PeftModel - if is_vllm_available(): from vllm import LLM, SamplingParams from vllm.sampling_params import GuidedDecodingParams @@ -788,7 +787,6 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} - if image_grid_thw is not None and pixel_values is not None: rows_per_image = image_grid_thw.prod(dim=-1) rows_per_sample = torch.split(rows_per_image, num_images) @@ -1048,21 +1046,10 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: + def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompts = [x["prompt"] for x in inputs] - - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] @@ -1073,7 +1060,9 @@ def _generate_and_score_completions( if isinstance(prompt, list): # i.e., when using conversational data prepare_multimodal_messages(prompt, num_images=len(image_list)) - prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] prompt_inputs = self.processing_class( text=prompts_text, @@ -1085,6 +1074,7 @@ def _generate_and_score_completions( ) prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. @@ -1250,8 +1240,9 @@ def _generate_and_score_completions( # Pad the completions, and concatenate them with the prompts completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + completion_mask = pad(completion_mask, padding_value=0) elif self.use_transformers_paged: # Re-process inputs for paged generation if needed @@ -1286,9 +1277,9 @@ def _generate_and_score_completions( completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn + else: # Regular generation path with ( @@ -1299,9 +1290,12 @@ def _generate_and_score_completions( torch.no_grad(), FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): - prompt_inputs["input_ids"], prompt_inputs["attention_mask"] = prompt_ids, prompt_mask prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, ) # Compute prompt length and extract completion ids prompt_length = prompt_ids.size(1) @@ -1315,10 +1309,6 @@ def _generate_and_score_completions( sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging completion_lengths = completion_mask.sum(1) @@ -1327,7 +1317,66 @@ def _generate_and_score_completions( truncated_completions = ~is_eos.any(dim=1) completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + # Log the metrics + if mode == "train": + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + agg_completion_lengths = self.accelerator.gather(completion_lengths) + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) + term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + return ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + forward_kwargs + ) + + def _generate_and_score_completions( + self, inputs: list[dict[str, Union[torch.Tensor, Any]]] + ) -> dict[str, Union[torch.Tensor, Any]]: + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompts = [x["prompt"] for x in inputs] + + if "images" in inputs[0]: + images = [example.get("images") for example in inputs] + elif "image" in inputs[0]: + images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] + else: + images = None + + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + forward_kwargs, + ) = self._generate(prompts, images) + + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. + completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens @@ -1343,11 +1392,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) old_logps = (old_per_token_logps * completion_mask).sum(1) # mask out padding and tokens after EOS @@ -1360,11 +1406,8 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: with self.accelerator.unwrap_model(self.model).disable_adapter(): @@ -1374,16 +1417,14 @@ def _generate_and_score_completions( attention_mask, logits_to_keep, batch_size=batch_size, - pixel_values=prompt_inputs.get("pixel_values"), - image_grid_thw=prompt_inputs.get("image_grid_thw"), num_images=num_images, - pixel_attention_mask=prompt_inputs.get("pixel_attention_mask"), - image_sizes=prompt_inputs.get("image_sizes"), + **forward_kwargs, # may contain pixel_values, image_grid_thw, pixel_attention_mask and image_sizes ) else: ref_per_token_logps = None - # Decode the generated completions + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) if is_conversational(inputs[0]): completions = [] @@ -1436,33 +1477,12 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Log the metrics - if mode == "train": - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Calculate and log the mean KL divergence between current and reference model if self.beta != 0.0: mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) self._metrics[mode]["kl"].append(self.accelerator.gather(mean_kl).nanmean().item()) - # Log completion lengths, mean, min, max - agg_completion_lengths = self.accelerator.gather(completion_lengths) - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1491,14 +1511,14 @@ def _generate_and_score_completions( "old_logps": old_logps, "advantages": advantages, } - if "pixel_values" in prompt_inputs: - output["pixel_values"] = prompt_inputs["pixel_values"] - if "image_grid_thw" in prompt_inputs: - output["image_grid_thw"] = prompt_inputs["image_grid_thw"] - if "pixel_attention_mask" in prompt_inputs: - output["pixel_attention_mask"] = prompt_inputs["pixel_attention_mask"] - if "image_sizes" in prompt_inputs: - output["image_sizes"] = prompt_inputs["image_sizes"] + if "pixel_values" in forward_kwargs: + output["pixel_values"] = forward_kwargs["pixel_values"] + if "image_grid_thw" in forward_kwargs: + output["image_grid_thw"] = forward_kwargs["image_grid_thw"] + if "pixel_attention_mask" in forward_kwargs: + output["pixel_attention_mask"] = forward_kwargs["pixel_attention_mask"] + if "image_sizes" in forward_kwargs: + output["image_sizes"] = forward_kwargs["image_sizes"] if images is not None: output["num_images"] = num_images return output From ec6ad259d22cbb817eb3b5a6cf948799721893bc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 24 Sep 2025 17:26:25 +0000 Subject: [PATCH 027/153] nits style and align --- trl/trainer/grpo_trainer.py | 1 + trl/trainer/rloo_trainer.py | 18 +++--------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a947d52cec3..7d9138319e6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1303,6 +1303,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn + sampling_per_token_logps = None # not used in this case else: # Regular generation path diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 5ff29112e9c..a4e31c7a9cf 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -787,6 +787,7 @@ def _get_per_token_logps_and_entropies( # Build model inputs - check if the model supports logits_to_keep (some models and VLMs don't) model_inputs = {"input_ids": input_ids_batch, "attention_mask": attention_mask_batch} + if image_grid_thw is not None and pixel_values is not None: rows_per_image = image_grid_thw.prod(dim=-1) rows_per_sample = torch.split(rows_per_image, num_images) @@ -1340,13 +1341,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - forward_kwargs - ) + return prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1363,13 +1358,7 @@ def _generate_and_score_completions( else: images = None - ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - forward_kwargs, - ) = self._generate(prompts, images) + prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs = self._generate(prompts, images) # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need # to re-tokenize completions if the reward is computed from tokens. @@ -1477,7 +1466,6 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Calculate and log the mean KL divergence between current and reference model if self.beta != 0.0: mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) From b0dceb97ac87021eb9f61a1632625c38a0128a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 25 Sep 2025 04:03:39 +0000 Subject: [PATCH 028/153] restart --- trl/trainer/grpo_trainer.py | 459 +++++++++--------------------------- 1 file changed, 117 insertions(+), 342 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 7d9138319e6..6adee765ba9 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -14,7 +14,6 @@ import inspect import os -import re import textwrap from collections import defaultdict, deque from contextlib import nullcontext @@ -27,7 +26,7 @@ import torch.utils.data import transformers from accelerate import logging -from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed +from accelerate.utils import gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -46,9 +45,9 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from ..data_utils import apply_chat_template, is_conversational from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available @@ -73,7 +72,6 @@ shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, - truncate_with_protected_tokens, unsplit_pixel_values_by_grid, ) @@ -537,6 +535,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + enforce_eager=True, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1058,324 +1057,104 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate_transformers(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - - if self.max_prompt_length is not None: - # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). - protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] - protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) - - prompts_text = self.processing_class.batch_decode( - prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + _is_conversational = is_conversational({"prompt": prompts}) + if _is_conversational: + prompts_ids_list = self.processing_class.apply_chat_template(prompts, add_generation_prompt=True) + else: + prompts_ids_list = self.processing_class(prompts)["input_ids"] + prompt_ids = [torch.tensor(ids, device=device) for ids in prompts_ids_list] + prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + forward_kwargs = {} + prompt_completion_ids = unwrapped_model.generate( + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] - - # The chat template sometimes inserts a single image token into the prompt text. However, when this text is - # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the - # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We - # collapse them back into a single token string to match the original chat template in case it originally - # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images - # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only - # the vision_start_token_id (e.g. ). - if self.image_token is not None: - escaped_img_token = re.escape(self.image_token) - # Search for the image token in the chat template - if re.search(escaped_img_token, self.processing_class.chat_template): - prompts_text = [ - re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text - ] - else: - # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id - if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) - prompts_text = [ - re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text - ] - else: - # If vision_end_token_id is None, just remove the image tokens - prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] - - # Generate completions using either vLLM or regular generation - if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() + completion_ids = prompt_completion_ids[:, prompt_ids.shape[1] :] + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if _is_conversational: + completions = [[{"role": "assistant", "content": content}] for content in completions] + return completions + + def _generate_vllm_colocate(self, prompts: list[str], images: Optional[list]): + if self.guided_decoding_regex: + guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) + else: + guided_decoding = None + + generation_kwargs = { + "n": 1, # vLLM on each GPU generates only 1 in colocate mode + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "guided_decoding": guided_decoding, + "logprobs": 0, # only return the logprob of the generated token + } + if self.args.generation_kwargs is not None: + generation_kwargs.update(self.args.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) + + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] + else: + all_prompts = prompts - # First, update the vLLM weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step + with profiling_context(self, "vLLM.generate"): + all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) - # Generate completions using vLLM: gather all prompts and use them in a single call in the main process - if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_contents = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=True) + all_completions = [[{"role": "assistant", "content": content}] for content in all_contents] - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + if self.vllm_tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + completions = all_completions[tp_slice] + else: + completions = all_completions - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) - payload = (output["completion_ids"], output["logprobs"]) - else: - payload = None + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=1) - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - completion_ids, all_logprobs = obj_list[0] + return completions - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = completion_ids[process_slice] - all_logprobs = all_logprobs[process_slice] - - # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts + def _generate(self, prompts, images): + if self.use_vllm: + if self.vllm_mode == "server": + raise NotImplementedError("vLLM server mode is not supported yet.") elif self.vllm_mode == "colocate": - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, - "logprobs": 0, # only return the logprob of the generated token - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None - else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text - - with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] - - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - all_logprobs = all_logprobs[tp_slice] - - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) - - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - completion_mask = pad(completion_mask, padding_value=0) - sampling_per_token_logps = [ - torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs - ] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0) - - elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" + return self._generate_vllm_colocate(prompts, images) else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" - with ( - profiling_context(self, "transformers.generate_batch"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - # Cast to the appropriate dtype based on training configuration - if self.args.bf16: - unwrapped_model.to(torch.bfloat16) - elif self.args.fp16: - unwrapped_model.to(torch.float16) - with torch.inference_mode(): - all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False - ) - unwrapped_model.train() # restore training mode, as generate_batch forces eval mode - completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn - sampling_per_token_logps = None # not used in this case - + raise ValueError(f"Invalid vLLM mode: {self.vllm_mode}") + elif self.use_transformers_paged: + raise NotImplementedError("Transformers paged generation is not supported yet.") else: - # Regular generation path - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_completion_ids = unwrapped_model.generate( - input_ids=prompt_ids, - attention_mask=prompt_mask, - **forward_kwargs, - generation_config=self.generation_config, - disable_compile=True, - ) - # Compute prompt length and extract completion ids - prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] - completion_ids = prompt_completion_ids[:, prompt_length:] - sampling_per_token_logps = None # not used in this case - - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss - - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() - - # Log the metrics - if mode == "train": - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - return ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) + return self._generate_transformers(prompts, images) def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1392,23 +1171,25 @@ def _generate_and_score_completions( else: images = None - ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) = self._generate(prompts, images) - - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - - # Concatenate prompt_mask with completion_mask for logit computation - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + completions = self._generate(prompts, images) + + prompts_completions = [p + c for p, c in zip(prompts, completions)] + prompt_ids_list = self.processing_class.apply_chat_template(prompts, add_generation_prompt=True) + prompt_completion_ids_list = self.processing_class.apply_chat_template(prompts_completions) + completion_ids_list = [pc[len(p) :] for p, pc in zip(prompt_ids_list, prompt_completion_ids_list)] + + # Convert to tensors and pad + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids) for ids in completion_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + forward_kwargs = {} logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size @@ -1472,17 +1253,6 @@ def _generate_and_score_completions( else: ref_per_token_logps = None - # Decode - prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) - completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text - # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. @@ -1522,6 +1292,11 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] + # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging + completion_lengths = completion_mask.sum(1) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss + # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1533,8 +1308,8 @@ def _generate_and_score_completions( self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) # Log prompt and completion texts - self._logs["prompt"].extend(gather_object(prompts_text)) - self._logs["completion"].extend(gather_object(completions_text)) + self._logs["prompt"].extend(gather_object([str(p) for p in prompts])) + self._logs["completion"].extend(gather_object([str(c) for c in completions])) for i, name in enumerate(self.reward_func_names): self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) From ebe32c26d83618301c17924b0625ce1d98561a79 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 25 Sep 2025 06:14:02 +0000 Subject: [PATCH 029/153] progress --- trl/data_utils.py | 50 ++++++++++++++++++++++++++++++++++++- trl/trainer/grpo_trainer.py | 41 ++++++++++++++++++++++++------ 2 files changed, 82 insertions(+), 9 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index 75e7a76f979..72191757d73 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -16,7 +16,7 @@ from collections.abc import Sequence from itertools import takewhile from typing import Any, Callable, Optional, TypeVar, Union - +import copy import numpy as np import pyarrow as pa import pyarrow.compute as pc @@ -57,6 +57,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) ] ``` """ + messages = copy.deepcopy(messages) image_included = False for message in messages: if message["role"] == "system": @@ -74,7 +75,54 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) message["content"] = [{"type": "text", "text": message["content"]}] else: raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.") + return messages + + +def prepare_multimodal_messages_2(messages: list[dict[str, Any]], images: list) -> None: + """ + Convert messages into a structured multimodal format if needed. + + Each message's content is transformed from a raw string into a list of typed parts. The first user message is + prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries. + + Args: + messages (`list[dict[str, Any]]`): + Messages with `"role"` and `"content"`. Content may be a raw string before transformation. + images (`list`): + Images to include in the first user message. + """ + messages = copy.deepcopy(messages) + image_idx = 0 + for message in messages: + if message["role"] == "user": + for part in message["content"]: + if part["type"] == "image": + part["image"] = images[image_idx] + image_idx += 1 + return messages + +def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> None: + """ + Convert messages into a structured multimodal format if needed. + + Each message's content is transformed from a raw string into a list of typed parts. The first user message is + prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries. + + Args: + messages (`list[dict[str, Any]]`): + Messages with `"role"` and `"content"`. Content may be a raw string before transformation. + images (`list`): + Images to include in the first user message. + """ + messages = copy.deepcopy(messages) + for message in messages: + if message["role"] == "user": + for part in message["content"]: + if part["type"] == "image": + part["type"] = "image_pil" + part["image_pil"] = part.pop("image") + return messages def is_conversational(example: dict[str, Any]) -> bool: r""" diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 6adee765ba9..b286f7eeb8f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -47,7 +47,7 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, is_conversational +from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_2, prepare_multimodal_messages, prepare_multimodal_messages_vllm from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available @@ -1122,12 +1122,26 @@ def _generate_vllm_colocate(self, prompts: list[str], images: Optional[list]): else: all_prompts = prompts + _is_conversational = is_conversational({"prompt": prompts[0]}) + + if images: + all_prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in all_prompts] + with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + if _is_conversational: + all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_contents = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=True) - all_completions = [[{"role": "assistant", "content": content}] for content in all_contents] + + if _is_conversational: + all_contents = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=True) + if images: + all_contents = [[{"type": "text", "text": content}] for content in all_contents] + all_completions = [[{"role": "assistant", "content": content}] for content in all_contents] + else: + all_completions = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=False) if self.vllm_tensor_parallel_size > 1: # Slice completions for this rank within its TP group. @@ -1163,6 +1177,7 @@ def _generate_and_score_completions( mode = "train" if self.model.training else "eval" prompts = [x["prompt"] for x in inputs] + _is_conversational = is_conversational(inputs[0]) if "images" in inputs[0]: images = [example.get("images") for example in inputs] @@ -1170,12 +1185,23 @@ def _generate_and_score_completions( images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] else: images = None + + if images is not None: + prompts = [prepare_multimodal_messages(prompt, len(image_list)) for prompt, image_list in zip(prompts, images)] + prompts = [prepare_multimodal_messages_2(prompt, image_list) for prompt, image_list in zip(prompts, images)] completions = self._generate(prompts, images) - + # Tokenize and extract completion ids prompts_completions = [p + c for p, c in zip(prompts, completions)] - prompt_ids_list = self.processing_class.apply_chat_template(prompts, add_generation_prompt=True) - prompt_completion_ids_list = self.processing_class.apply_chat_template(prompts_completions) + if _is_conversational: + forward_kwargs = self.processing_class.apply_chat_template(prompts, tokenize=True, add_generation_prompt=True, return_dict=True) + forward_kwargs = super()._prepare_inputs(forward_kwargs) + prompt_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") + prompt_completion_ids_list = self.processing_class.apply_chat_template(prompts_completions, tokenize=True) + else: + prompt_ids_list = self.processing_class(prompts)["input_ids"] + prompt_completion_ids_list = self.processing_class(prompts_completions)["input_ids"] + forward_kwargs = {} completion_ids_list = [pc[len(p) :] for p, pc in zip(prompt_ids_list, prompt_completion_ids_list)] # Convert to tensors and pad @@ -1189,7 +1215,6 @@ def _generate_and_score_completions( completion_mask = pad(completion_mask, padding_value=0, padding_side="right") prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - forward_kwargs = {} logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size From 0213662cd4403767d919505d5d7fa7231b829f89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 25 Sep 2025 18:24:46 +0000 Subject: [PATCH 030/153] progress continues --- trl/data_utils.py | 30 ++++++---------- trl/trainer/grpo_trainer.py | 68 ++++++++++++++++++++----------------- 2 files changed, 47 insertions(+), 51 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index 72191757d73..bbe0a4af934 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from collections import defaultdict, deque from collections.abc import Sequence from itertools import takewhile from typing import Any, Callable, Optional, TypeVar, Union -import copy + import numpy as np import pyarrow as pa import pyarrow.compute as pc @@ -78,20 +79,9 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) return messages -def prepare_multimodal_messages_2(messages: list[dict[str, Any]], images: list) -> None: - """ - Convert messages into a structured multimodal format if needed. - - Each message's content is transformed from a raw string into a list of typed parts. The first user message is - prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries. - - Args: - messages (`list[dict[str, Any]]`): - Messages with `"role"` and `"content"`. Content may be a raw string before transformation. - images (`list`): - Images to include in the first user message. - """ - messages = copy.deepcopy(messages) +def insert_images(messages: list[dict[str, Any]], images: list) -> None: + """ """ + messages = prepare_multimodal_messages(messages, num_images=len(images)) image_idx = 0 for message in messages: if message["role"] == "user": @@ -118,12 +108,14 @@ def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> None: messages = copy.deepcopy(messages) for message in messages: if message["role"] == "user": - for part in message["content"]: - if part["type"] == "image": - part["type"] = "image_pil" - part["image_pil"] = part.pop("image") + if isinstance(message["content"], list): # if already prepared, the content will be a list + for part in message["content"]: + if part["type"] == "image": + part["type"] = "image_pil" + part["image_pil"] = part.pop("image") return messages + def is_conversational(example: dict[str, Any]) -> bool: r""" Check if the example is in a conversational format. diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b286f7eeb8f..af17cade57f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -47,7 +47,7 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages_2, prepare_multimodal_messages, prepare_multimodal_messages_vllm +from ..data_utils import apply_chat_template, insert_images, is_conversational, prepare_multimodal_messages_vllm from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available @@ -1057,13 +1057,18 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_transformers(self, prompts: list[str], images: Optional[list]): + def _generate_transformers(self, prompts: list[str]): device = self.accelerator.device - _is_conversational = is_conversational({"prompt": prompts}) + _is_conversational = is_conversational({"prompt": prompts[0]}) if _is_conversational: - prompts_ids_list = self.processing_class.apply_chat_template(prompts, add_generation_prompt=True) + forward_kwargs = self.processing_class.apply_chat_template( + prompts, add_generation_prompt=True, return_dict=True, tokenize=True + ) + forward_kwargs = super()._prepare_inputs(forward_kwargs) + prompts_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") else: prompts_ids_list = self.processing_class(prompts)["input_ids"] + forward_kwargs = {} prompt_ids = [torch.tensor(ids, device=device) for ids in prompts_ids_list] prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") @@ -1077,7 +1082,6 @@ def _generate_transformers(self, prompts: list[str], images: Optional[list]): torch.no_grad(), FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): - forward_kwargs = {} prompt_completion_ids = unwrapped_model.generate( input_ids=prompt_ids, attention_mask=prompt_mask, @@ -1087,11 +1091,9 @@ def _generate_transformers(self, prompts: list[str], images: Optional[list]): ) completion_ids = prompt_completion_ids[:, prompt_ids.shape[1] :] completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if _is_conversational: - completions = [[{"role": "assistant", "content": content}] for content in completions] return completions - def _generate_vllm_colocate(self, prompts: list[str], images: Optional[list]): + def _generate_vllm_colocate(self, prompts: list[str]): if self.guided_decoding_regex: guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) else: @@ -1124,7 +1126,8 @@ def _generate_vllm_colocate(self, prompts: list[str], images: Optional[list]): _is_conversational = is_conversational({"prompt": prompts[0]}) - if images: + is_multimodal = isinstance(prompts[0], list) and isinstance(prompts[0][0]["content"], list) + if is_multimodal: all_prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in all_prompts] with profiling_context(self, "vLLM.generate"): @@ -1135,13 +1138,7 @@ def _generate_vllm_colocate(self, prompts: list[str], images: Optional[list]): all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - if _is_conversational: - all_contents = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=True) - if images: - all_contents = [[{"type": "text", "text": content}] for content in all_contents] - all_completions = [[{"role": "assistant", "content": content}] for content in all_contents] - else: - all_completions = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=False) + all_completions = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=False) if self.vllm_tensor_parallel_size > 1: # Slice completions for this rank within its TP group. @@ -1157,18 +1154,18 @@ def _generate_vllm_colocate(self, prompts: list[str], images: Optional[list]): return completions - def _generate(self, prompts, images): + def _generate(self, prompts): if self.use_vllm: if self.vllm_mode == "server": raise NotImplementedError("vLLM server mode is not supported yet.") elif self.vllm_mode == "colocate": - return self._generate_vllm_colocate(prompts, images) + return self._generate_vllm_colocate(prompts) else: raise ValueError(f"Invalid vLLM mode: {self.vllm_mode}") elif self.use_transformers_paged: raise NotImplementedError("Transformers paged generation is not supported yet.") else: - return self._generate_transformers(prompts, images) + return self._generate_transformers(prompts) def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1185,23 +1182,23 @@ def _generate_and_score_completions( images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] else: images = None - + if images is not None: - prompts = [prepare_multimodal_messages(prompt, len(image_list)) for prompt, image_list in zip(prompts, images)] - prompts = [prepare_multimodal_messages_2(prompt, image_list) for prompt, image_list in zip(prompts, images)] + prompts = [insert_images(prompt, image_list) for prompt, image_list in zip(prompts, images)] - completions = self._generate(prompts, images) - # Tokenize and extract completion ids - prompts_completions = [p + c for p, c in zip(prompts, completions)] + completion_texts = self._generate(prompts) if _is_conversational: - forward_kwargs = self.processing_class.apply_chat_template(prompts, tokenize=True, add_generation_prompt=True, return_dict=True) - forward_kwargs = super()._prepare_inputs(forward_kwargs) - prompt_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") - prompt_completion_ids_list = self.processing_class.apply_chat_template(prompts_completions, tokenize=True) + prompt_texts = self.processing_class.apply_chat_template( + prompts, add_generation_prompt=True, tokenize=False + ) else: - prompt_ids_list = self.processing_class(prompts)["input_ids"] - prompt_completion_ids_list = self.processing_class(prompts_completions)["input_ids"] - forward_kwargs = {} + prompt_texts = prompts + prompts_completions = [p + c for p, c in zip(prompt_texts, completion_texts)] + kwargs = {"images": images} if images is not None else {} + forward_kwargs = self.processing_class(text=prompt_texts, **kwargs) + forward_kwargs = super()._prepare_inputs(forward_kwargs) + prompt_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") + prompt_completion_ids_list = self.processing_class(text=prompts_completions, **kwargs)["input_ids"] completion_ids_list = [pc[len(p) :] for p, pc in zip(prompt_ids_list, prompt_completion_ids_list)] # Convert to tensors and pad @@ -1281,6 +1278,13 @@ def _generate_and_score_completions( # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. + if _is_conversational: + # Note that we need to decode the completion ids instead of using completion_texts, because in + # completions_texts we may have special tokens like EOS. + decodes_completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in decodes_completions] + else: + completions = completion_texts rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum From 8b3a72460253b4eb5655311948eecd1a7b50b70c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 25 Sep 2025 23:27:53 +0000 Subject: [PATCH 031/153] progress again again --- trl/trainer/grpo_trainer.py | 190 +++++++++++++++++++++++++----------- 1 file changed, 132 insertions(+), 58 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index af17cade57f..071e0ccd81d 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -26,7 +26,7 @@ import torch.utils.data import transformers from accelerate import logging -from accelerate.utils import gather, gather_object, is_peft_model, set_seed +from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP @@ -1057,41 +1057,55 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_transformers(self, prompts: list[str]): - device = self.accelerator.device - _is_conversational = is_conversational({"prompt": prompts[0]}) - if _is_conversational: - forward_kwargs = self.processing_class.apply_chat_template( - prompts, add_generation_prompt=True, return_dict=True, tokenize=True - ) - forward_kwargs = super()._prepare_inputs(forward_kwargs) - prompts_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") + def _generate_vllm_server(self, prompts: list[str]): + prompts_text = self.processing_class.apply_chat_template(prompts, add_generation_prompt=True, tokenize=False) + + all_prompts_text = gather_object(prompts_text) + images = None + if images is not None: + all_images = gather_object(images) + + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + + if images is not None: + ordered_set_of_images = all_images[:: self.num_generations] + else: + ordered_set_of_images = None + + with profiling_context(self, "vLLM.generate"): + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) + payload = (output["completion_ids"], output["logprobs"]) else: - prompts_ids_list = self.processing_class(prompts)["input_ids"] - forward_kwargs = {} - prompt_ids = [torch.tensor(ids, device=device) for ids in prompts_ids_list] - prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + payload = None - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_completion_ids = unwrapped_model.generate( - input_ids=prompt_ids, - attention_mask=prompt_mask, - **forward_kwargs, - generation_config=self.generation_config, - disable_compile=True, - ) - completion_ids = prompt_completion_ids[:, prompt_ids.shape[1] :] - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - return completions + # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. + obj_list = [payload] + broadcast_object_list(obj_list, from_process=0) + all_completion_ids, all_logprobs = obj_list[0] + + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), + ) + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] + return completion_ids, logprobs def _generate_vllm_colocate(self, prompts: list[str]): if self.guided_decoding_regex: @@ -1137,33 +1151,92 @@ def _generate_vllm_colocate(self, prompts: list[str]): all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - - all_completions = self.processing_class.batch_decode(all_completion_ids, skip_special_tokens=False) + all_logprobs = [ + [next(iter(lp.values())).logprob for lp in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] if self.vllm_tensor_parallel_size > 1: # Slice completions for this rank within its TP group. # Each rank generates all outputs — we keep only our share. local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completions = all_completions[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] else: - completions = all_completions + completion_ids = all_completion_ids + logprobs = all_logprobs if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) - return completions + return completion_ids, logprobs + + def _generate_transformers_paged(self, prompts: list[str]): + raise NotImplementedError("Transformers with model paging is not yet supported.") + + def _generate_transformers(self, prompts: list[str]): + device = self.accelerator.device + _is_conversational = is_conversational({"prompt": prompts[0]}) + if _is_conversational: + forward_kwargs = self.processing_class.apply_chat_template( + prompts, add_generation_prompt=True, return_dict=True, tokenize=True + ) + forward_kwargs = super()._prepare_inputs(forward_kwargs) + prompts_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") + else: + prompts_ids_list = self.processing_class(prompts)["input_ids"] + forward_kwargs = {} + prompt_ids = [torch.tensor(ids, device=device) for ids in prompts_ids_list] + prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, + ) + completion_ids = prompt_completion_ids[:, prompt_ids.shape[1] :] + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask)] + return completion_ids, None def _generate(self, prompts): if self.use_vllm: + if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: + # Wake up colocated vLLM instances if needed + torch.cuda.empty_cache() # required to avoid OOM in some cases + self.llm.wake_up() + + # First, update the vLLM weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + if self.vllm_mode == "server": - raise NotImplementedError("vLLM server mode is not supported yet.") + return self._generate_vllm_server(prompts) elif self.vllm_mode == "colocate": return self._generate_vllm_colocate(prompts) else: - raise ValueError(f"Invalid vLLM mode: {self.vllm_mode}") + raise ValueError(f"Unsupported vLLM mode: {self.vllm_mode}.") elif self.use_transformers_paged: - raise NotImplementedError("Transformers paged generation is not supported yet.") + return self._generate_transformers_paged(prompts) else: return self._generate_transformers(prompts) @@ -1186,20 +1259,16 @@ def _generate_and_score_completions( if images is not None: prompts = [insert_images(prompt, image_list) for prompt, image_list in zip(prompts, images)] - completion_texts = self._generate(prompts) + completion_ids_list, logprobs = self._generate(prompts) if _is_conversational: - prompt_texts = self.processing_class.apply_chat_template( - prompts, add_generation_prompt=True, tokenize=False + forward_kwargs = self.processing_class.apply_chat_template( + prompts, add_generation_prompt=True, tokenize=True, return_dict=True ) + forward_kwargs = super()._prepare_inputs(forward_kwargs) + prompt_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") else: - prompt_texts = prompts - prompts_completions = [p + c for p, c in zip(prompt_texts, completion_texts)] - kwargs = {"images": images} if images is not None else {} - forward_kwargs = self.processing_class(text=prompt_texts, **kwargs) - forward_kwargs = super()._prepare_inputs(forward_kwargs) - prompt_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") - prompt_completion_ids_list = self.processing_class(text=prompts_completions, **kwargs)["input_ids"] - completion_ids_list = [pc[len(p) :] for p, pc in zip(prompt_ids_list, prompt_completion_ids_list)] + prompt_ids_list = self.processing_class(text=prompts)["input_ids"] + forward_kwargs = {} # Convert to tensors and pad prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1210,6 +1279,9 @@ def _generate_and_score_completions( prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if logprobs is not None: + sampling_per_token_logps = [torch.tensor(lp, device=device) for lp in logprobs] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) @@ -1279,12 +1351,14 @@ def _generate_and_score_completions( # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. if _is_conversational: - # Note that we need to decode the completion ids instead of using completion_texts, because in - # completions_texts we may have special tokens like EOS. - decodes_completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - completions = [[{"role": "assistant", "content": content}] for content in decodes_completions] + # Reconstructing a conversational turn this way might not be perfectly accurate, since a completion could + # be truncated. Applying the chat template may therefore produce text that differs slightly from the + # sequence generated by the model. We assume the reward model is robust enough to tolerate these small + # discrepancies. + completions = self.processing_class.batch_decode(completion_ids_list, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in completions] else: - completions = completion_texts + completions = self.processing_class.batch_decode(completion_ids_list, skip_special_tokens=False) rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum From c1ae6aa787cab340438f9d22a6e13f919b4a3d9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 25 Sep 2025 23:56:11 +0000 Subject: [PATCH 032/153] back to working point --- trl/trainer/grpo_trainer.py | 528 ++++++++++++++++++++++-------------- 1 file changed, 325 insertions(+), 203 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 071e0ccd81d..7d9138319e6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -14,6 +14,7 @@ import inspect import os +import re import textwrap from collections import defaultdict, deque from contextlib import nullcontext @@ -45,9 +46,9 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, insert_images, is_conversational, prepare_multimodal_messages_vllm +from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available @@ -72,6 +73,7 @@ shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, + truncate_with_protected_tokens, unsplit_pixel_values_by_grid, ) @@ -535,7 +537,6 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, - enforce_eager=True, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1057,188 +1058,324 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_vllm_server(self, prompts: list[str]): - prompts_text = self.processing_class.apply_chat_template(prompts, add_generation_prompt=True, tokenize=False) + def _generate(self, prompts: list[str], images: Optional[list]): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" - all_prompts_text = gather_object(prompts_text) - images = None + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] + kwargs = {} if images is not None: - all_images = gather_object(images) + kwargs = {"images": images} + for prompt, image_list in zip(prompts, images): + if isinstance(prompt, list): # i.e., when using conversational data + prepare_multimodal_messages(prompt, num_images=len(image_list)) - if self.accelerator.is_main_process: - # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate - # num_generations outputs for each one. This is faster than generating outputs for each duplicate - # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + prompts_text = [ + maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - - with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + **kwargs, + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + + if self.max_prompt_length is not None: + # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. + # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, + # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). + protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] + protected = [token for token in protected if token is not None] + prompt_ids, prompt_mask = truncate_with_protected_tokens( + prompt_ids, prompt_mask, self.max_prompt_length, protected + ) + + prompts_text = self.processing_class.batch_decode( + prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False + ) + prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] + + # The chat template sometimes inserts a single image token into the prompt text. However, when this text is + # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the + # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We + # collapse them back into a single token string to match the original chat template in case it originally + # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images + # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only + # the vision_start_token_id (e.g. ). + if self.image_token is not None: + escaped_img_token = re.escape(self.image_token) + # Search for the image token in the chat template + if re.search(escaped_img_token, self.processing_class.chat_template): + prompts_text = [ + re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text + ] + else: + # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id + if self.vision_end_token_id is not None: + escaped_eoi_token = re.escape( + self.processing_class.tokenizer.decode([self.vision_end_token_id]) + ) + prompts_text = [ + re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text + ] + else: + # If vision_end_token_id is None, just remove the image tokens + prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] + + # Generate completions using either vLLM or regular generation + if self.use_vllm: + if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: + # wake up colocated vLLM instances if needed + torch.cuda.empty_cache() # required to avoid OOM in some cases + self.llm.wake_up() + + # First, update the vLLM weights if needed + if self.state.global_step != self._last_loaded_step: + self._move_model_to_vllm() + self._last_loaded_step = self.state.global_step + + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process + if self.vllm_mode == "server": + all_prompts_text = gather_object(prompts_text) + if images is not None: + all_images = gather_object(images) + + if self.accelerator.is_main_process: + # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate + # num_generations outputs for each one. This is faster than generating outputs for each duplicate + # prompt individually. + ordered_set_of_prompts = all_prompts_text[:: self.num_generations] + + if images is not None: + ordered_set_of_images = all_images[:: self.num_generations] + else: + ordered_set_of_images = None + + with profiling_context(self, "vLLM.generate"): + output = self.vllm_client.generate( + prompts=ordered_set_of_prompts, + images=ordered_set_of_images, + n=self.num_generations, + repetition_penalty=self.repetition_penalty, + temperature=self.temperature, + top_p=self.top_p, + top_k=-1 if self.top_k is None else self.top_k, + min_p=0.0 if self.min_p is None else self.min_p, + max_tokens=self.max_completion_length, + guided_decoding_regex=self.guided_decoding_regex, + generation_kwargs=self.args.generation_kwargs, + ) + payload = (output["completion_ids"], output["logprobs"]) + else: + payload = None + + # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. + obj_list = [payload] + broadcast_object_list(obj_list, from_process=0) + completion_ids, all_logprobs = obj_list[0] + + process_slice = slice( + self.accelerator.process_index * len(prompts), + (self.accelerator.process_index + 1) * len(prompts), ) - payload = (output["completion_ids"], output["logprobs"]) - else: - payload = None + completion_ids = completion_ids[process_slice] + all_logprobs = all_logprobs[process_slice] - # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. - obj_list = [payload] - broadcast_object_list(obj_list, from_process=0) - all_completion_ids, all_logprobs = obj_list[0] + # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts + elif self.vllm_mode == "colocate": + if self.guided_decoding_regex: + guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) + else: + guided_decoding = None + + generation_kwargs = { + "n": 1, # vLLM on each GPU generates only 1 in colocate mode + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "guided_decoding": guided_decoding, + "logprobs": 0, # only return the logprob of the generated token + } + if self.args.generation_kwargs is not None: + generation_kwargs.update(self.args.generation_kwargs) + sampling_params = SamplingParams(**generation_kwargs) - process_slice = slice( - self.accelerator.process_index * len(prompts), - (self.accelerator.process_index + 1) * len(prompts), - ) - completion_ids = all_completion_ids[process_slice] - logprobs = all_logprobs[process_slice] - return completion_ids, logprobs + if self.vllm_tensor_parallel_size > 1: + # Gather prompts from all ranks in the TP group and flatten. + # Each rank starts with its own prompts; after gathering, all ranks see the full group set. + orig_size = len(prompts_text) + gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) + all_prompts_text = [p for sublist in gathered_prompts for p in sublist] + + if images is not None: + gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] + torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) + all_images = [img for sublist in gathered_images for img in sublist] + else: + all_images = None + else: + all_prompts_text = prompts_text + all_images = images - def _generate_vllm_colocate(self, prompts: list[str]): - if self.guided_decoding_regex: - guided_decoding = GuidedDecodingParams(regex=self.guided_decoding_regex) - else: - guided_decoding = None - - generation_kwargs = { - "n": 1, # vLLM on each GPU generates only 1 in colocate mode - "repetition_penalty": self.repetition_penalty, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": -1 if self.top_k is None else self.top_k, - "min_p": 0.0 if self.min_p is None else self.min_p, - "max_tokens": self.max_completion_length, - "guided_decoding": guided_decoding, - "logprobs": 0, # only return the logprob of the generated token - } - if self.args.generation_kwargs is not None: - generation_kwargs.update(self.args.generation_kwargs) - sampling_params = SamplingParams(**generation_kwargs) - - if self.vllm_tensor_parallel_size > 1: - # Gather prompts from all ranks in the TP group and flatten. - # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts) - gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) - all_prompts = [p for sublist in gathered_prompts for p in sublist] - else: - all_prompts = prompts + if images is not None and all_images: + vllm_inputs = [] + for prompt, image_list in zip(all_prompts_text, all_images): + vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - _is_conversational = is_conversational({"prompt": prompts[0]}) + else: + vllm_inputs = all_prompts_text - is_multimodal = isinstance(prompts[0], list) and isinstance(prompts[0][0]["content"], list) - if is_multimodal: - all_prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in all_prompts] + with profiling_context(self, "vLLM.generate"): + all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - with profiling_context(self, "vLLM.generate"): - if _is_conversational: - all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) - else: - all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) + completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_logprobs = [ + [next(iter(lp.values())).logprob for lp in output.logprobs] + for outputs in all_outputs + for output in outputs.outputs + ] - all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] - all_logprobs = [ - [next(iter(lp.values())).logprob for lp in output.logprobs] - for outputs in all_outputs - for output in outputs.outputs - ] + if self.vllm_tensor_parallel_size > 1: + # Slice completions for this rank within its TP group. + # Each rank generates all outputs — we keep only our share. + local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) + tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + completion_ids = completion_ids[tp_slice] + all_logprobs = all_logprobs[tp_slice] - if self.vllm_tensor_parallel_size > 1: - # Slice completions for this rank within its TP group. - # Each rank generates all outputs — we keep only our share. - local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) - tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = all_completion_ids[tp_slice] - logprobs = all_logprobs[tp_slice] - else: - completion_ids = all_completion_ids - logprobs = all_logprobs + if self.args.vllm_enable_sleep_mode: + self.llm.sleep(level=1) - if self.args.vllm_enable_sleep_mode: - self.llm.sleep(level=1) + # Pad the completions, and concatenate them with the prompts + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id) + completion_mask = pad(completion_mask, padding_value=0) + sampling_per_token_logps = [ + torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs + ] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0) - return completion_ids, logprobs + elif self.use_transformers_paged: + # Re-process inputs for paged generation if needed + # Note: images are already validated and preprocessed above + paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) + previous_attn = self.model_wrapped.config._attn_implementation - def _generate_transformers_paged(self, prompts: list[str]): - raise NotImplementedError("Transformers with model paging is not yet supported.") + if is_flash_attn_2_available(): + self.model_wrapped.config._attn_implementation = "paged_attention" + else: + self.model_wrapped.config._attn_implementation = "sdpa_paged" + with ( + profiling_context(self, "transformers.generate_batch"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + # Cast to the appropriate dtype based on training configuration + if self.args.bf16: + unwrapped_model.to(torch.bfloat16) + elif self.args.fp16: + unwrapped_model.to(torch.float16) + with torch.inference_mode(): + all_outputs = unwrapped_model.generate_batch( + paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + ) + unwrapped_model.train() # restore training mode, as generate_batch forces eval mode + completion_ids = [output.generated_tokens for output in all_outputs.values()] + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + # Restore the original attention implementation, training mode + self.model_wrapped.config._attn_implementation = previous_attn + sampling_per_token_logps = None # not used in this case - def _generate_transformers(self, prompts: list[str]): - device = self.accelerator.device - _is_conversational = is_conversational({"prompt": prompts[0]}) - if _is_conversational: - forward_kwargs = self.processing_class.apply_chat_template( - prompts, add_generation_prompt=True, return_dict=True, tokenize=True - ) - forward_kwargs = super()._prepare_inputs(forward_kwargs) - prompts_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") else: - prompts_ids_list = self.processing_class(prompts)["input_ids"] - forward_kwargs = {} - prompt_ids = [torch.tensor(ids, device=device) for ids in prompts_ids_list] - prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - - with ( - profiling_context(self, "transformers.generate"), - unwrap_model_for_generation( - self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation - ) as unwrapped_model, - torch.no_grad(), - FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), - ): - prompt_completion_ids = unwrapped_model.generate( - input_ids=prompt_ids, - attention_mask=prompt_mask, - **forward_kwargs, - generation_config=self.generation_config, - disable_compile=True, - ) - completion_ids = prompt_completion_ids[:, prompt_ids.shape[1] :] + # Regular generation path + with ( + profiling_context(self, "transformers.generate"), + unwrap_model_for_generation( + self.model_wrapped, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation + ) as unwrapped_model, + torch.no_grad(), + FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), + ): + prompt_completion_ids = unwrapped_model.generate( + input_ids=prompt_ids, + attention_mask=prompt_mask, + **forward_kwargs, + generation_config=self.generation_config, + disable_compile=True, + ) + # Compute prompt length and extract completion ids + prompt_length = prompt_ids.size(1) + prompt_ids = prompt_completion_ids[:, :prompt_length] + completion_ids = prompt_completion_ids[:, prompt_length:] + sampling_per_token_logps = None # not used in this case + + # Mask everything after the first EOS token is_eos = completion_ids == self.eos_token_id eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask)] - return completion_ids, None - def _generate(self, prompts): - if self.use_vllm: - if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: - # Wake up colocated vLLM instances if needed - torch.cuda.empty_cache() # required to avoid OOM in some cases - self.llm.wake_up() + # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging + completion_lengths = completion_mask.sum(1) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss - # First, update the vLLM weights if needed - if self.state.global_step != self._last_loaded_step: - self._move_model_to_vllm() - self._last_loaded_step = self.state.global_step + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + truncated_completions = ~is_eos.any(dim=1) + completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() - if self.vllm_mode == "server": - return self._generate_vllm_server(prompts) - elif self.vllm_mode == "colocate": - return self._generate_vllm_colocate(prompts) - else: - raise ValueError(f"Unsupported vLLM mode: {self.vllm_mode}.") - elif self.use_transformers_paged: - return self._generate_transformers_paged(prompts) - else: - return self._generate_transformers(prompts) + # Log the metrics + if mode == "train": + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] + + # Log completion lengths, mean, min, max + self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) + + # Identify sequences that terminated with EOS and log their lengths + agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) + term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] + clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) + self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found + term_completion_lengths = torch.zeros(1, device=device) + self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) + self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) + self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + + return ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1247,7 +1384,6 @@ def _generate_and_score_completions( mode = "train" if self.model.training else "eval" prompts = [x["prompt"] for x in inputs] - _is_conversational = is_conversational(inputs[0]) if "images" in inputs[0]: images = [example.get("images") for example in inputs] @@ -1256,34 +1392,23 @@ def _generate_and_score_completions( else: images = None - if images is not None: - prompts = [insert_images(prompt, image_list) for prompt, image_list in zip(prompts, images)] + ( + prompt_ids, + completion_ids, + prompt_mask, + completion_mask, + num_items_in_batch, + sampling_per_token_logps, + forward_kwargs, + ) = self._generate(prompts, images) - completion_ids_list, logprobs = self._generate(prompts) - if _is_conversational: - forward_kwargs = self.processing_class.apply_chat_template( - prompts, add_generation_prompt=True, tokenize=True, return_dict=True - ) - forward_kwargs = super()._prepare_inputs(forward_kwargs) - prompt_ids_list, _ = forward_kwargs.pop("input_ids"), forward_kwargs.pop("attention_mask") - else: - prompt_ids_list = self.processing_class(text=prompts)["input_ids"] - forward_kwargs = {} - - # Convert to tensors and pad - prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] - prompt_mask = [torch.ones_like(ids) for ids in prompt_ids] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids) for ids in completion_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - completion_mask = pad(completion_mask, padding_value=0, padding_side="right") - if logprobs is not None: - sampling_per_token_logps = [torch.tensor(lp, device=device) for lp in logprobs] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") - prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) + # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need + # to re-tokenize completions if the reward is computed from tokens. + completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + + # Concatenate prompt_mask with completion_mask for logit computation + prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) + attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size @@ -1347,18 +1472,20 @@ def _generate_and_score_completions( else: ref_per_token_logps = None + # Decode + prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) + completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + if is_conversational(inputs[0]): + completions = [] + for prompt, completion in zip(prompts, completions_text): + bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" + completions.append([{"role": "assistant", "content": bootstrap + completion}]) + else: + completions = completions_text + # Calculate rewards for each reward function. rewards_per_func aggregates rewards across all processes. This is # important because rewards will be normalized per group, and completions are distributed. We will later slice # rewards_per_func to extract each process's subset. - if _is_conversational: - # Reconstructing a conversational turn this way might not be perfectly accurate, since a completion could - # be truncated. Applying the chat template may therefore produce text that differs slightly from the - # sequence generated by the model. We assume the reward model is robust enough to tolerate these small - # discrepancies. - completions = self.processing_class.batch_decode(completion_ids_list, skip_special_tokens=True) - completions = [[{"role": "assistant", "content": content}] for content in completions] - else: - completions = self.processing_class.batch_decode(completion_ids_list, skip_special_tokens=False) rewards_per_func = self._calculate_rewards(inputs, prompts, completions, completion_ids_list) # Apply weights to each reward function's output and sum @@ -1395,11 +1522,6 @@ def _generate_and_score_completions( all_process_advantages = advantages.clone() # keep the aggregated advantages for logging advantages = advantages[process_slice] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) - agg_completion_lengths = self.accelerator.gather(completion_lengths) - num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss - # Calculate mean reward per function, but only for samples where the function was applied (non-NaN values) for i, reward_func_name in enumerate(self.reward_func_names): mean_rewards = torch.nanmean(rewards_per_func[:, i]).item() @@ -1411,8 +1533,8 @@ def _generate_and_score_completions( self._metrics[mode]["frac_reward_zero_std"].append(is_std_zero.float().mean().item()) # Log prompt and completion texts - self._logs["prompt"].extend(gather_object([str(p) for p in prompts])) - self._logs["completion"].extend(gather_object([str(c) for c in completions])) + self._logs["prompt"].extend(gather_object(prompts_text)) + self._logs["completion"].extend(gather_object(completions_text)) for i, name in enumerate(self.reward_func_names): self._logs["rewards"][name].extend(rewards_per_func[:, i].tolist()) self._logs["advantages"].extend(all_process_advantages.tolist()) From 1a66b431d00def9527472f49b11eb14b7ecd4158 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 25 Sep 2025 23:57:14 +0000 Subject: [PATCH 033/153] revert chage data utils --- trl/data_utils.py | 40 ---------------------------------------- 1 file changed, 40 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index bbe0a4af934..75e7a76f979 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy from collections import defaultdict, deque from collections.abc import Sequence from itertools import takewhile @@ -58,7 +57,6 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) ] ``` """ - messages = copy.deepcopy(messages) image_included = False for message in messages: if message["role"] == "system": @@ -76,44 +74,6 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) message["content"] = [{"type": "text", "text": message["content"]}] else: raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.") - return messages - - -def insert_images(messages: list[dict[str, Any]], images: list) -> None: - """ """ - messages = prepare_multimodal_messages(messages, num_images=len(images)) - image_idx = 0 - for message in messages: - if message["role"] == "user": - for part in message["content"]: - if part["type"] == "image": - part["image"] = images[image_idx] - image_idx += 1 - return messages - - -def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> None: - """ - Convert messages into a structured multimodal format if needed. - - Each message's content is transformed from a raw string into a list of typed parts. The first user message is - prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries. - - Args: - messages (`list[dict[str, Any]]`): - Messages with `"role"` and `"content"`. Content may be a raw string before transformation. - images (`list`): - Images to include in the first user message. - """ - messages = copy.deepcopy(messages) - for message in messages: - if message["role"] == "user": - if isinstance(message["content"], list): # if already prepared, the content will be a list - for part in message["content"]: - if part["type"] == "image": - part["type"] = "image_pil" - part["image_pil"] = part.pop("image") - return messages def is_conversational(example: dict[str, Any]) -> bool: From 9435a9400fc6a6d65b648db730e85a44eaf1d38f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 02:48:11 +0000 Subject: [PATCH 034/153] refactor in grpo --- trl/trainer/grpo_trainer.py | 118 +++++++++++++++++------------------- 1 file changed, 55 insertions(+), 63 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 216591acf6a..9a79b68179a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1192,14 +1192,14 @@ def _generate(self, prompts: list[str], images: Optional[list]): # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - completion_ids, all_logprobs = obj_list[0] + all_completion_ids, all_logprobs = obj_list[0] process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), ) - completion_ids = completion_ids[process_slice] - all_logprobs = all_logprobs[process_slice] + completion_ids = all_completion_ids[process_slice] + logprobs = all_logprobs[process_slice] # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": @@ -1252,7 +1252,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): with profiling_context(self, "vLLM.generate"): all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] all_logprobs = [ [next(iter(lp.values())).logprob for lp in output.logprobs] for outputs in all_outputs @@ -1264,22 +1264,12 @@ def _generate(self, prompts: list[str], images: Optional[list]): # Each rank generates all outputs — we keep only our share. local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] - all_logprobs = all_logprobs[tp_slice] + completion_ids = all_completion_ids[tp_slice] + logprobs = all_logprobs[tp_slice] if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - completion_mask = pad(completion_mask, padding_value=0) - sampling_per_token_logps = [ - torch.tensor(logprobs, device=device, dtype=torch.float32) for logprobs in all_logprobs - ] - sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0) - elif self.use_transformers_paged: # Re-process inputs for paged generation if needed # Note: images are already validated and preprocessed above @@ -1309,13 +1299,11 @@ def _generate(self, prompts: list[str], images: Optional[list]): ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_ids = paged_prompt_inputs.input_ids # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn - sampling_per_token_logps = None # not used in this case + + logprobs = None # not used in this case else: # Regular generation path @@ -1338,29 +1326,27 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - sampling_per_token_logps = None # not used in this case - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] + + logprobs = None # not used in this case - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) + # Get completion length per sequence, used for logging + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) agg_completion_lengths = self.accelerator.gather(completion_lengths) num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() - # Log the metrics if mode == "train": - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self.state.num_input_tokens_seen += num_items_in_batch.item() self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Log completion lengths, mean, min, max @@ -1369,25 +1355,18 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found term_completion_lengths = torch.zeros(1, device=device) self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) + return prompt_ids, completion_ids, num_items_in_batch, logprobs, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1404,19 +1383,32 @@ def _generate_and_score_completions( else: images = None - ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) = self._generate(prompts, images) - - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + (prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, forward_kwargs) = ( + self._generate(prompts, images) + ) + + # Identify truncated sequences (not ending with EOS or PAD) before any padding is applied + eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) From 3d8ea27c6807207e444c7bcef8ce1ffef1a4ddc1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 02:54:26 +0000 Subject: [PATCH 035/153] wrong merge commit --- trl/trainer/grpo_trainer.py | 62 ------------------------------------- 1 file changed, 62 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f92998c0feb..9a79b68179a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1326,7 +1326,6 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_length = prompt_ids.size(1) prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - sampling_per_token_logps = None # not used in this case # Mask everything after the first EOS token is_eos = completion_ids == self.eos_token_id @@ -1411,67 +1410,6 @@ def _generate_and_score_completions( if self.mask_truncated_completions: completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() - # Log the metrics - if mode == "train": - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() - self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] - - # Log completion lengths, mean, min, max - self._metrics[mode]["completions/mean_length"].append(agg_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_length"].append(agg_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) - - # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) - if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found - term_completion_lengths = torch.zeros(1, device=device) - self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) - self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) - self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - - return ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) - - def _generate_and_score_completions( - self, inputs: list[dict[str, Union[torch.Tensor, Any]]] - ) -> dict[str, Union[torch.Tensor, Any]]: - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - prompts = [x["prompt"] for x in inputs] - - if "images" in inputs[0]: - images = [example.get("images") for example in inputs] - elif "image" in inputs[0]: - images = [[example.get("image")] if example.get("image") is not None else None for example in inputs] - else: - images = None - - ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, - num_items_in_batch, - sampling_per_token_logps, - forward_kwargs, - ) = self._generate(prompts, images) - - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] - # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) From 27dc9585a04f1231e4d4057ff852eee52e9cf2a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 03:09:42 +0000 Subject: [PATCH 036/153] fix num_input_tokens_seen --- trl/trainer/grpo_trainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9a79b68179a..664041ef756 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1302,7 +1302,6 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_ids = paged_prompt_inputs.input_ids # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn - logprobs = None # not used in this case else: @@ -1333,20 +1332,21 @@ def _generate(self, prompts: list[str], images: Optional[list]): eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - logprobs = None # not used in this case # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) agg_completion_lengths = self.accelerator.gather(completion_lengths) - num_items_in_batch = agg_completion_lengths.sum() # this is required for the DAPO loss + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # Log the metrics if mode == "train": - self.state.num_input_tokens_seen += num_items_in_batch.item() + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Log completion lengths, mean, min, max @@ -1366,7 +1366,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, num_items_in_batch, logprobs, forward_kwargs + return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] From 53772ef7b8a21520afc7b0570577cd7ed46fd511 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 16:02:03 +0000 Subject: [PATCH 037/153] getting closer --- trl/trainer/grpo_trainer.py | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 664041ef756..d63ae64d367 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1070,9 +1070,8 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device - mode = "train" if self.model.training else "eval" # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to @@ -1088,15 +1087,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - prompt_inputs = super()._prepare_inputs(prompt_inputs) + prompt_inputs = self.processing_class(text=prompts_text, add_special_tokens=False, **kwargs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} @@ -1266,6 +1257,9 @@ def _generate(self, prompts: list[str], images: Optional[list]): tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) completion_ids = all_completion_ids[tp_slice] logprobs = all_logprobs[tp_slice] + else: + completion_ids = all_completion_ids + logprobs = all_logprobs if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1336,6 +1330,14 @@ def _generate(self, prompts: list[str], images: Optional[list]): completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] logprobs = None # not used in this case + return prompt_ids, completion_ids, logprobs, forward_kwargs + + def _generate(self, prompts: list[str], images: Optional[list]): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompt_ids, completion_ids, completion_logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) @@ -1366,7 +1368,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs + return prompt_ids, completion_ids, completion_logprobs, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1383,8 +1385,8 @@ def _generate_and_score_completions( else: images = None - (prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, forward_kwargs) = ( - self._generate(prompts, images) + (prompt_ids_list, completion_ids_list, sampling_per_token_logps_list, forward_kwargs) = self._generate( + prompts, images ) # Identify truncated sequences (not ending with EOS or PAD) before any padding is applied @@ -1406,6 +1408,8 @@ def _generate_and_score_completions( else: sampling_per_token_logps = None + num_items_in_batch = completion_mask.sum() + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() From 8766fa5cc0153aac737380d29b64b644e26393c5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 16:12:07 +0000 Subject: [PATCH 038/153] consistent naming --- trl/trainer/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d63ae64d367..e281848d458 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1336,7 +1336,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, completion_logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1368,7 +1368,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, completion_logprobs, forward_kwargs + return prompt_ids, completion_ids, logprobs, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] From 236b78b455c6e3aea031939a90af118ce1e593dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 16:14:18 +0000 Subject: [PATCH 039/153] better --- trl/trainer/grpo_trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e281848d458..55fa7677229 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1389,10 +1389,6 @@ def _generate_and_score_completions( prompts, images ) - # Identify truncated sequences (not ending with EOS or PAD) before any padding is applied - eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] - is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) - # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] @@ -1412,6 +1408,8 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: + eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation From 9da4830c536b0cda3b736ec21e61ad13d692bd11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 16:22:44 +0000 Subject: [PATCH 040/153] simplify a bit + comment --- trl/trainer/grpo_trainer.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 55fa7677229..ec3bf224f09 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1344,7 +1344,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): agg_prompt_lengths = self.accelerator.gather(prompt_lengths) agg_completion_lengths = self.accelerator.gather(completion_lengths) total_prompt_tokens = agg_prompt_lengths.sum() - total_completion_tokens = agg_completion_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss # Log the metrics if mode == "train": @@ -1368,7 +1368,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, logprobs, forward_kwargs + return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1385,9 +1385,13 @@ def _generate_and_score_completions( else: images = None - (prompt_ids_list, completion_ids_list, sampling_per_token_logps_list, forward_kwargs) = self._generate( - prompts, images - ) + ( + prompt_ids_list, + completion_ids_list, + num_items_in_batch, + sampling_per_token_logps_list, + forward_kwargs, + ) = self._generate(prompts, images) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1404,8 +1408,6 @@ def _generate_and_score_completions( else: sampling_per_token_logps = None - num_items_in_batch = completion_mask.sum() - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] From b3bd0b05d4e5e242fb4b622836c122c1d467d41a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 18:05:49 +0000 Subject: [PATCH 041/153] another one --- tests/test_utils.py | 109 +++++++++--------------------------- trl/trainer/grpo_trainer.py | 24 ++++++-- trl/trainer/utils.py | 68 ++++++++-------------- 3 files changed, 66 insertions(+), 135 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 6f6ba1579ef..a2dce2a3d49 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -937,136 +937,79 @@ def test_multi_images(self): class TruncateWithProtectedTokensTester(TrlTestCase): def test_basic_example(self): """Test the basic example from the problem description.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]]) - prompt_mask = torch.ones_like(prompt_ids) - protected_tokens = [2, 3, 6] + prompt_ids = [1, 2, 3, 4, 5] + protected_tokens = [2, 3] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) - - expected_ids = torch.tensor([[2, 3, 5], [6, 9, 10]]) - expected_mask = torch.ones_like(expected_ids) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertTrue(torch.equal(new_ids, expected_ids)) - self.assertTrue(torch.equal(new_mask, expected_mask)) + expected_ids = [2, 3, 5] + self.assertEqual(new_ids, expected_ids) def test_no_truncation_needed(self): """Test when target length equals current length.""" - prompt_ids = torch.tensor([[1, 2, 3]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3] protected_tokens = [2] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertTrue(torch.equal(new_ids, prompt_ids)) - self.assertTrue(torch.equal(new_mask, prompt_mask)) + self.assertEqual(new_ids, prompt_ids) def test_no_protected_tokens(self): """Test truncation with no protected tokens (normal right truncation).""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3, 4, 5] protected_tokens = [] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - expected_ids = torch.tensor([[3, 4, 5]]) # Last 3 tokens - self.assertTrue(torch.equal(new_ids, expected_ids)) + expected_ids = [3, 4, 5] # Last 3 tokens + self.assertEqual(new_ids, expected_ids) def test_all_tokens_protected(self): """Test when all remaining tokens are protected.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3, 4, 5] protected_tokens = [3, 4, 5] target_length = 3 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - expected_ids = torch.tensor([[3, 4, 5]]) - self.assertTrue(torch.equal(new_ids, expected_ids)) + expected_ids = [3, 4, 5] + self.assertEqual(new_ids, expected_ids) def test_too_many_protected_tokens(self): """Test error when too many protected tokens for target length.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [1, 2, 3, 4, 5] protected_tokens = [1, 2, 3, 4] target_length = 3 with self.assertRaises(ValueError): - truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) def test_single_batch_single_token(self): """Test edge case with single batch and single token.""" - prompt_ids = torch.tensor([[5]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [5] protected_tokens = [5] target_length = 1 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) - - self.assertTrue(torch.equal(new_ids, prompt_ids)) - - def test_mask_preservation(self): - """Test that mask values are correctly preserved.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.tensor([[1, 0, 1, 0, 1]]) # Mixed mask values - protected_tokens = [2, 4] - target_length = 3 - - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) - - expected_ids = torch.tensor([[2, 4, 5]]) - expected_mask = torch.tensor([[0, 0, 1]]) # Corresponding mask values - - self.assertTrue(torch.equal(new_ids, expected_ids)) - self.assertTrue(torch.equal(new_mask, expected_mask)) - - def test_multiple_batches_different_protected(self): - """Test multiple batches where protected tokens appear differently.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5], [2, 6, 7, 8, 9], [10, 11, 12, 2, 13]]) - prompt_mask = torch.ones_like(prompt_ids) - protected_tokens = [2] - target_length = 3 - - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - expected_ids = torch.tensor( - [ - [2, 4, 5], # 2 is protected, keep last 2 non-protected (4,5) - [2, 8, 9], # 2 is protected, keep last 2 non-protected (8,9) - [12, 2, 13], # 2 is protected, keep last 2 non-protected (12,13) - ] - ) - - self.assertTrue(torch.equal(new_ids, expected_ids)) + self.assertEqual(new_ids, prompt_ids) def test_order_preservation(self): """Test that relative order is preserved.""" - prompt_ids = torch.tensor([[10, 2, 20, 3, 30, 40]]) - prompt_mask = torch.ones_like(prompt_ids) + prompt_ids = [10, 2, 20, 3, 30, 40] protected_tokens = [2, 3] target_length = 4 - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - # Should keep protected tokens 2,3 and last 2 non-protected tokens 30,40 + # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40 # Order should be: 2, 3, 30, 40 (maintaining original relative positions) - expected_ids = torch.tensor([[2, 3, 30, 40]]) - - self.assertTrue(torch.equal(new_ids, expected_ids)) - - def test_empty_protected_tokens_list(self): - """Test with empty protected tokens list.""" - prompt_ids = torch.tensor([[1, 2, 3, 4, 5]]) - prompt_mask = torch.ones_like(prompt_ids) - protected_tokens = [] - target_length = 2 - - new_ids, new_mask = truncate_with_protected_tokens(prompt_ids, prompt_mask, target_length, protected_tokens) + expected_ids = [2, 3, 30, 40] - expected_ids = torch.tensor([[4, 5]]) # Last 2 tokens - self.assertTrue(torch.equal(new_ids, expected_ids)) + self.assertEqual(new_ids, expected_ids) class UnsplitPixelValuesByGridTester(TrlTestCase): diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ec3bf224f09..948bfa1dde6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1087,9 +1087,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - prompt_inputs = self.processing_class(text=prompts_text, add_special_tokens=False, **kwargs) + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + add_special_tokens=False, + **kwargs, + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. @@ -1097,9 +1106,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) + prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False @@ -1300,6 +1307,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + with ( profiling_context(self, "transformers.generate"), unwrap_model_for_generation( @@ -1357,7 +1369,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] + eos_and_pad = [self.eos_token_id, self.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) agg_is_truncated = self.accelerator.gather(is_truncated) self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) @@ -1410,7 +1422,7 @@ def _generate_and_score_completions( # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: - eos_and_pad = [self.processing_class.eos_token_id, self.processing_class.pad_token_id] + eos_and_pad = [self.eos_token_id, self.pad_token_id] is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index b9f97020ed2..d192036671d 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1861,63 +1861,39 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch return batch -def truncate_with_protected_tokens( - ids: torch.Tensor, mask: torch.Tensor, target_length: int, protected_tokens: list[int] -) -> tuple[torch.Tensor, torch.Tensor]: +def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]: """ - Truncate tensors to target length while preserving protected tokens. + Truncate list to target length while preserving protected tokens. Args: - ids (`torch.Tensor`): + sequences (`list[int]`): Input tensor of token IDs, shape (batch_size, sequence_length). - mask (`torch.Tensor`): - Input tensor of attention masks, shape (batch_size, sequence_length). target_length (`int`): Desired length of the output sequences. protected_tokens (`list[int]`): List of token IDs that should be preserved in the output. """ protected_set = set(protected_tokens) - # Create protected_tokens tensor once to avoid recreating it on every call - protected_tokens_tensor = torch.tensor(list(protected_set), device=ids.device) - - def process_sequence(ids, mask): - # Create boolean masks - is_protected = torch.isin(ids, protected_tokens_tensor) - is_non_protected = ~is_protected - - # Count tokens - num_protected = is_protected.sum().item() - num_non_protected_needed = target_length - num_protected - - if num_non_protected_needed < 0: - raise ValueError( - f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " - f"Please increase target length to at least {num_protected} or disable truncation." - ) - - # Select which non-protected tokens to keep (rightmost ones) - non_protected_indices = torch.where(is_non_protected)[0] - keep_non_protected = torch.zeros_like(is_non_protected) - if num_non_protected_needed > 0: - keep_indices = non_protected_indices[-num_non_protected_needed:] - keep_non_protected[keep_indices] = True - - # Final mask: protected OR selected non-protected - keep_mask = is_protected | keep_non_protected - return ids[keep_mask], mask[keep_mask] - - # Process each sequence in the batch - truncated_seq = [] - truncated_mask = [] - - for i in range(ids.shape[0]): - new_ids, new_mask = process_sequence(ids[i], mask[i]) - truncated_seq.append(new_ids) - truncated_mask.append(new_mask) - - return torch.stack(truncated_seq), torch.stack(truncated_mask) + # Count protected tokens + num_protected = sum(1 for t in ids if t in protected_set) + if num_protected > target_length: + raise ValueError( + f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " + f"Please increase target length to at least {num_protected} or disable truncation." + ) + num_non_protected_needed = target_length - num_protected + result = [] + + # Iterate backward to select rightmost non-protected tokens + for t in reversed(ids): + if t in protected_set: + result.append(t) + elif num_non_protected_needed > 0: + result.append(t) + num_non_protected_needed -= 1 + # Reverse to restore original order + return result[::-1] def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel: From d79b9e1c8fef750db650bafb6eac3f0cf229220c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 18:41:51 +0000 Subject: [PATCH 042/153] get prompt ids from generation --- trl/extras/vllm_client.py | 2 +- trl/scripts/vllm_serve.py | 3 ++- trl/trainer/grpo_trainer.py | 13 +++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 0932697d6ee..962cc1664e0 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -246,7 +246,7 @@ def pil_to_base64(image): ) if response.status_code == 200: json_response = response.json() - return {"completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]} + return {"prompt_ids": json_response["prompt_ids"], "completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]} else: raise Exception(f"Request failed: {response.status_code}, {response.text}") diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 3e448aedf13..7e7174f261c 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -596,13 +596,14 @@ async def generate(request: GenerateRequest): # Flatten and combine all results all_outputs = list(chain.from_iterable(all_outputs)) # from list of list to single list + prompt_ids = [output.prompt_token_ids for output in all_outputs] completion_ids = [list(output.token_ids) for outputs in all_outputs for output in outputs.outputs] logprobs: list[list[float]] = [ [sanitize_logprob(next(iter(logprob.values()))) for logprob in output.logprobs] for outputs in all_outputs for output in outputs.outputs ] - return {"completion_ids": completion_ids, "logprobs": logprobs} + return {"prompt_ids": prompt_ids, "completion_ids": completion_ids, "logprobs": logprobs} class InitCommunicatorRequest(BaseModel): host: str diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 948bfa1dde6..9c2ca3e1e88 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1096,11 +1096,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): **kwargs, ) prompt_inputs = super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] + # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). @@ -1183,19 +1184,20 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): guided_decoding_regex=self.guided_decoding_regex, generation_kwargs=self.args.generation_kwargs, ) - payload = (output["completion_ids"], output["logprobs"]) + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - all_completion_ids, all_logprobs = obj_list[0] + all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), ) + prompt_ids = all_prompt_ids[process_slice] completion_ids = all_completion_ids[process_slice] logprobs = all_logprobs[process_slice] @@ -1250,6 +1252,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): with profiling_context(self, "vLLM.generate"): all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + all_prompt_ids = [output.prompt_token_ids for output in all_outputs] all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] all_logprobs = [ [next(iter(lp.values())).logprob for lp in output.logprobs] @@ -1262,9 +1265,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): # Each rank generates all outputs — we keep only our share. local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + prompt_ids = all_prompt_ids[tp_slice] completion_ids = all_completion_ids[tp_slice] logprobs = all_logprobs[tp_slice] else: + prompt_ids = all_prompt_ids completion_ids = all_completion_ids logprobs = all_logprobs From 8d34d546bb2f0dc00bb4853e8fe2562802c5d3cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 18:56:45 +0000 Subject: [PATCH 043/153] remove pad token removal --- trl/trainer/grpo_trainer.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 948bfa1dde6..215161be1ad 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1102,8 +1102,8 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). + # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special + # tokens are needed for generation. protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] protected = [token for token in protected if token is not None] prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] @@ -1111,7 +1111,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] # The chat template sometimes inserts a single image token into the prompt text. However, when this text is # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the From 0e2ae34a9342d552daef42906daf697a6c46617f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 19:41:24 +0000 Subject: [PATCH 044/153] rely on generator for prompt truncation --- trl/extras/vllm_client.py | 12 ++++- trl/scripts/vllm_serve.py | 5 +++ trl/trainer/grpo_trainer.py | 90 +++++++++++-------------------------- 3 files changed, 41 insertions(+), 66 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 962cc1664e0..fdc0b88a8c4 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -178,6 +178,7 @@ def generate( top_k: int = -1, min_p: float = 0.0, max_tokens: int = 16, + truncate_prompt_tokens: Optional[int] = None, guided_decoding_regex: Optional[str] = None, generation_kwargs: Optional[dict] = None, ) -> list[list[int]]: @@ -203,6 +204,10 @@ def generate( Minimum probability for sampling. max_tokens (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each prompt. + truncate_prompt_tokens (`int`, *optional*): + If set to `-1`, will use the truncation size supported by the model. If set to an integer k, will use + only the last k tokens from the prompt (i.e., left truncation). If set to `None`, truncation is + disabled. guided_decoding_regex (`str`, *optional*): Regular expression to guide the decoding process. generation_kwargs (`dict`, *optional*): @@ -240,13 +245,18 @@ def pil_to_base64(image): "top_k": top_k, "min_p": min_p, "max_tokens": max_tokens, + "truncate_prompt_tokens": truncate_prompt_tokens, "guided_decoding_regex": guided_decoding_regex, "generation_kwargs": generation_kwargs or {}, }, ) if response.status_code == 200: json_response = response.json() - return {"prompt_ids": json_response["prompt_ids"], "completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]} + return { + "prompt_ids": json_response["prompt_ids"], + "completion_ids": json_response["completion_ids"], + "logprobs": json_response["logprobs"], + } else: raise Exception(f"Request failed: {response.status_code}, {response.text}") diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 7e7174f261c..03d8c8dc06c 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -495,6 +495,7 @@ class GenerateRequest(BaseModel): top_k: int = -1 min_p: float = 0.0 max_tokens: int = 16 + truncate_prompt_tokens: Optional[int] = None guided_decoding_regex: Optional[str] = None generation_kwargs: dict = field(default_factory=dict) @@ -524,6 +525,9 @@ async def generate(request: GenerateRequest): - `min_p` (`float`, *optional*, defaults to `0.0`): Minimum probability threshold for sampling. - `max_tokens` (`int`, *optional*, defaults to `16`): Maximum number of tokens to generate for each completion. + - `truncate_prompt_tokens` (`int`, *optional*): If set to `-1`, will use the truncation size supported + by the model. If set to an integer k, will use only the last k tokens from the prompt (i.e., left + truncation). If set to `None`, truncation is disabled. - `guided_decoding_regex` (`str`, *optional*): A regex pattern for guided decoding. If provided, the model will only generate tokens that match this regex pattern. - `generation_kwargs` (`dict`, *optional*): Additional generation parameters to pass to the vLLM @@ -569,6 +573,7 @@ async def generate(request: GenerateRequest): "top_k": request.top_k, "min_p": request.min_p, "max_tokens": request.max_tokens, + "truncate_prompt_tokens": request.truncate_prompt_tokens, "guided_decoding": guided_decoding, "logprobs": 0, } diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index cca71fca1b8..c1fc44caec8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -14,7 +14,6 @@ import inspect import os -import re import textwrap from collections import defaultdict, deque from contextlib import nullcontext @@ -71,7 +70,6 @@ shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, - truncate_with_protected_tokens, unsplit_pixel_values_by_grid, ) @@ -205,14 +203,16 @@ def reward_func(completions, **kwargs): "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", "id": "2402.03300", # docstyle-ignore - "citation": textwrap.dedent("""\ + "citation": textwrap.dedent( + """\ @article{shao2024deepseekmath, title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, year = 2024, eprint = {arXiv:2402.03300}, } - """), + """ + ), } def __init__( @@ -549,6 +549,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + enforce_eager=True, # avoid ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1087,58 +1088,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - - if self.max_prompt_length is not None: - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] - - # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special - # tokens are needed for generation. - protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] - protected = [token for token in protected if token is not None] - prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] - - prompts_text = self.processing_class.batch_decode( - prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - # The chat template sometimes inserts a single image token into the prompt text. However, when this text is - # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the - # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We - # collapse them back into a single token string to match the original chat template in case it originally - # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images - # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only - # the vision_start_token_id (e.g. ). - if self.image_token is not None: - escaped_img_token = re.escape(self.image_token) - # Search for the image token in the chat template - if re.search(escaped_img_token, self.processing_class.chat_template): - prompts_text = [ - re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text - ] - else: - # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id - if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) - prompts_text = [ - re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text - ] - else: - # If vision_end_token_id is None, just remove the image tokens - prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] + if images is not None: + prompt_inputs = self.processing_class(text=prompts_text, return_tensors="pt", **kwargs) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1180,6 +1135,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): top_k=-1 if self.top_k is None else self.top_k, min_p=0.0 if self.min_p is None else self.min_p, max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, guided_decoding_regex=self.guided_decoding_regex, generation_kwargs=self.args.generation_kwargs, ) @@ -1215,6 +1171,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): "top_k": -1 if self.top_k is None else self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding": guided_decoding, "logprobs": 0, # only return the logprob of the generated token } @@ -1311,10 +1268,17 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path - prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + prompt_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + truncation=True, + truncation_side="left", + add_special_tokens=False, + **kwargs, + ) + prompt_inputs = super()._prepare_inputs(prompt_inputs) with ( profiling_context(self, "transformers.generate"), @@ -1325,11 +1289,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): prompt_completion_ids = unwrapped_model.generate( - input_ids=prompt_ids, - attention_mask=prompt_mask, - **forward_kwargs, - generation_config=self.generation_config, - disable_compile=True, + **prompt_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids prompt_length = prompt_ids.size(1) From 46d8eb79cfb2fa604b5f70242a5aba30397a7b95 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 19:43:17 +0000 Subject: [PATCH 045/153] revert --- trl/trainer/grpo_trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c1fc44caec8..49dd085c4d4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -203,16 +203,14 @@ def reward_func(completions, **kwargs): "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", "id": "2402.03300", # docstyle-ignore - "citation": textwrap.dedent( - """\ + "citation": textwrap.dedent("""\ @article{shao2024deepseekmath, title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, year = 2024, eprint = {arXiv:2402.03300}, } - """ - ), + """), } def __init__( From 11acc758c257c2e3d37f3796cb84d3d08f573b76 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 19:43:45 +0000 Subject: [PATCH 046/153] rm enforce eager --- trl/trainer/grpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 49dd085c4d4..78b4934c531 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -547,7 +547,6 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, - enforce_eager=True, # avoid ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) From acee7d817f641b5a3d8afb564a3368f008758784 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 19:45:09 +0000 Subject: [PATCH 047/153] rm truncate_with_protected_tokens --- tests/test_utils.py | 79 -------------------------------------------- trl/trainer/utils.py | 35 -------------------- 2 files changed, 114 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index a2dce2a3d49..fcb8337e1bb 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -42,7 +42,6 @@ shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, - truncate_with_protected_tokens, unsplit_pixel_values_by_grid, ) @@ -934,84 +933,6 @@ def test_multi_images(self): self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))) -class TruncateWithProtectedTokensTester(TrlTestCase): - def test_basic_example(self): - """Test the basic example from the problem description.""" - prompt_ids = [1, 2, 3, 4, 5] - protected_tokens = [2, 3] - target_length = 3 - - new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - expected_ids = [2, 3, 5] - self.assertEqual(new_ids, expected_ids) - - def test_no_truncation_needed(self): - """Test when target length equals current length.""" - prompt_ids = [1, 2, 3] - protected_tokens = [2] - target_length = 3 - - new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - self.assertEqual(new_ids, prompt_ids) - - def test_no_protected_tokens(self): - """Test truncation with no protected tokens (normal right truncation).""" - prompt_ids = [1, 2, 3, 4, 5] - protected_tokens = [] - target_length = 3 - - new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - expected_ids = [3, 4, 5] # Last 3 tokens - self.assertEqual(new_ids, expected_ids) - - def test_all_tokens_protected(self): - """Test when all remaining tokens are protected.""" - prompt_ids = [1, 2, 3, 4, 5] - protected_tokens = [3, 4, 5] - target_length = 3 - - new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - expected_ids = [3, 4, 5] - self.assertEqual(new_ids, expected_ids) - - def test_too_many_protected_tokens(self): - """Test error when too many protected tokens for target length.""" - prompt_ids = [1, 2, 3, 4, 5] - protected_tokens = [1, 2, 3, 4] - target_length = 3 - - with self.assertRaises(ValueError): - truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - def test_single_batch_single_token(self): - """Test edge case with single batch and single token.""" - prompt_ids = [5] - protected_tokens = [5] - target_length = 1 - - new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - self.assertEqual(new_ids, prompt_ids) - - def test_order_preservation(self): - """Test that relative order is preserved.""" - prompt_ids = [10, 2, 20, 3, 30, 40] - protected_tokens = [2, 3] - target_length = 4 - - new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - - # Should keep protected tokens 2, 3 and last 2 non-protected tokens 30, 40 - # Order should be: 2, 3, 30, 40 (maintaining original relative positions) - expected_ids = [2, 3, 30, 40] - - self.assertEqual(new_ids, expected_ids) - - class UnsplitPixelValuesByGridTester(TrlTestCase): def test_unsplit_correctly(self): pixel_values = [torch.randn(4, 5), torch.randn(2, 5)] diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index d192036671d..21473b67178 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1861,41 +1861,6 @@ def unsplit_pixel_values_by_grid(batch: dict[str, Union[torch.Tensor, list[torch return batch -def truncate_with_protected_tokens(ids: list[int], target_length: int, protected_tokens: list[int]) -> list[int]: - """ - Truncate list to target length while preserving protected tokens. - - Args: - sequences (`list[int]`): - Input tensor of token IDs, shape (batch_size, sequence_length). - target_length (`int`): - Desired length of the output sequences. - protected_tokens (`list[int]`): - List of token IDs that should be preserved in the output. - """ - protected_set = set(protected_tokens) - - # Count protected tokens - num_protected = sum(1 for t in ids if t in protected_set) - if num_protected > target_length: - raise ValueError( - f"target_length ({target_length}) is too small for the protected tokens ({num_protected} tokens). " - f"Please increase target length to at least {num_protected} or disable truncation." - ) - num_non_protected_needed = target_length - num_protected - result = [] - - # Iterate backward to select rightmost non-protected tokens - for t in reversed(ids): - if t in protected_set: - result.append(t) - elif num_non_protected_needed > 0: - result.append(t) - num_non_protected_needed -= 1 - # Reverse to restore original order - return result[::-1] - - def create_model_from_path(model_id: str, **kwargs) -> PreTrainedModel: """ Create a model from a given path using the specified initialization arguments. From 0b5865e8f5e4dbae566af42e9f5becb048dc1248 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 19:57:23 +0000 Subject: [PATCH 048/153] ensure proper truncation and side --- trl/trainer/grpo_trainer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 78b4934c531..ffd75710c1c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1265,17 +1265,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path - prompt_inputs = self.processing_class( + self.processing_class.truncation_side = "left" # ensure left truncation for generation + generate_inputs = self.processing_class( text=prompts_text, return_tensors="pt", padding=True, padding_side="left", + max_length=self.max_prompt_length, truncation=True, - truncation_side="left", add_special_tokens=False, **kwargs, ) - prompt_inputs = super()._prepare_inputs(prompt_inputs) + generate_inputs = super()._prepare_inputs(generate_inputs) with ( profiling_context(self, "transformers.generate"), @@ -1286,11 +1287,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): prompt_completion_ids = unwrapped_model.generate( - **prompt_inputs, generation_config=self.generation_config, disable_compile=True + **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids + prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token From d8af0039fa0d2bd6e19622d3faedf2b1409fe135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 19:59:12 +0000 Subject: [PATCH 049/153] rm useless comment --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ffd75710c1c..15c31caf49b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1265,7 +1265,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path - self.processing_class.truncation_side = "left" # ensure left truncation for generation + self.processing_class.truncation_side = "left" generate_inputs = self.processing_class( text=prompts_text, return_tensors="pt", From fc263a309a8956b6803af7f3618c6f2d3bb19dfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 20:01:37 +0000 Subject: [PATCH 050/153] rm imports --- .../grpo_with_replay_buffer_trainer.py | 8 +------- trl/trainer/rloo_trainer.py | 1 - 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 5be414b0c96..9a726c5a381 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -28,13 +28,7 @@ from trl.import_utils import is_vllm_available from trl.models import unwrap_model_for_generation from trl.trainer.grpo_trainer import GRPOTrainer -from trl.trainer.utils import ( - nanmax, - nanmin, - nanstd, - pad, - truncate_with_protected_tokens, -) +from trl.trainer.utils import nanmax, nanmin, nanstd, pad from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 3c7490eaac3..a67b9871aac 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -71,7 +71,6 @@ shuffle_sequence_dict, split_pixel_values_by_grid, split_tensor_dict, - truncate_with_protected_tokens, unsplit_pixel_values_by_grid, ) From 35f99fd8674a1c8e3e5285fd1fb630e3163f6ce4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 21:27:33 +0000 Subject: [PATCH 051/153] requires padding --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 15c31caf49b..c2ce1d4ee36 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1086,7 +1086,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): ] if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, return_tensors="pt", **kwargs) + prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) prompt_inputs = super()._prepare_inputs(prompt_inputs) forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} else: From 8149d0578fccf3feab9091cfed00881a76d98e63 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 21:27:47 +0000 Subject: [PATCH 052/153] rm truncation test --- tests/test_grpo_trainer.py | 41 -------------------------------------- 1 file changed, 41 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index a839a654bca..5ad8f671691 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1472,47 +1472,6 @@ def reward_func(completions, **kwargs): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") - @require_vision - def test_training_vlm_and_prompt_truncation(self): - # If not handled properly, prompt truncation may truncate image token - dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") - - def reward_func(completions, **kwargs): - """Reward function that rewards longer completions.""" - return [float(len(completion[0]["content"])) for completion in completions] - - training_args = GRPOConfig( - output_dir=self.tmp_dir, - learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=3, # reduce the batch size to reduce memory usage - num_generations=3, # reduce the number of generations to reduce memory usage - max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=18, - report_to="none", - ) - trainer = GRPOTrainer( - model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs=reward_func, - args=training_args, - train_dataset=dataset, - ) - - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - - trainer.train() - - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - - # Check that the params have changed - # Because of the way the tiny models are initialized, the gradient does not flow properly through the - # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. - params_to_skip = ("model.visual.",) - for n, param in previous_trainable_params.items(): - if n.startswith(params_to_skip): - continue - new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") - @require_vision @require_vllm @parameterized.expand( From 9925199ee9937fbc647a8a7596f14df3c608dee4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 22:14:58 +0000 Subject: [PATCH 053/153] move forward_kwargs outside of generate --- trl/trainer/grpo_trainer.py | 40 ++++++++++++++++++++----------------- 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c2ce1d4ee36..d08a810ecdc 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1084,14 +1084,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): prompts_text = [ maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1304,13 +1296,13 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] logprobs = None # not used in this case - return prompt_ids, completion_ids, logprobs, forward_kwargs + return prompt_ids, completion_ids, logprobs def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs, forward_kwargs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1342,7 +1334,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs, forward_kwargs + return prompt_ids, completion_ids, total_completion_tokens, logprobs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1359,13 +1351,9 @@ def _generate_and_score_completions( else: images = None - ( - prompt_ids_list, - completion_ids_list, - num_items_in_batch, - sampling_per_token_logps_list, - forward_kwargs, - ) = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( + prompts, images + ) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1397,6 +1385,22 @@ def _generate_and_score_completions( num_images = [len(img_list) for img_list in images] if images is not None else None + # Get forward_kwargs for models with multimodal inputs + if images is not None: + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] + for prompt, image_list in zip(prompts, images): + prepare_multimodal_messages(prompt, num_images=len(image_list)) + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the From 48a1c30e7e95fdc4ccea3a60faf893a207c415eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 22:20:23 +0000 Subject: [PATCH 054/153] don't re-prepare data --- trl/trainer/grpo_trainer.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d08a810ecdc..b6852dfd60b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1387,11 +1387,6 @@ def _generate_and_score_completions( # Get forward_kwargs for models with multimodal inputs if images is not None: - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - for prompt, image_list in zip(prompts, images): - prepare_multimodal_messages(prompt, num_images=len(image_list)) prompts_text = [ apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] From 15c6620c84c29d9e7f72262e4de3e5c257dca060 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 23:32:38 +0000 Subject: [PATCH 055/153] refactor: update prepare_multimodal_messages to accept images directly and enhance handling of structured messages --- tests/test_data_utils.py | 97 ++++++++++++++++---------- trl/data_utils.py | 95 +++++++++++++++++++++----- trl/trainer/grpo_trainer.py | 133 ++++++++++++++++-------------------- trl/trainer/sft_trainer.py | 24 ++++--- 4 files changed, 213 insertions(+), 136 deletions(-) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 0a9eba7f7bb..5e5ce8fdb40 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import copy import itertools import textwrap import unittest @@ -20,6 +19,7 @@ from datasets import Dataset, DatasetDict from parameterized import parameterized +from PIL import Image from transformers import AutoProcessor, AutoTokenizer from trl.data_utils import ( @@ -47,30 +47,46 @@ def test_basic_user_assistant_conversation(self): {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (32, 32), color="red") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, ] self.assertEqual(messages, expected) def test_first_user_message_gets_image(self): - """Test that only the first user message gets an image placeholder.""" + """Test that only the first user message gets an image.""" messages = [ {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, {"role": "user", "content": "How about the grass?"}, ] - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (32, 32), color="red") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, - {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]}, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + { + "role": "user", + "content": [{"type": "text", "text": "How about the grass?"}], + }, ] self.assertEqual(messages, expected) @@ -81,20 +97,23 @@ def test_multiple_images(self): {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - - prepare_multimodal_messages(messages, num_images=3) + images = [Image.new("RGB", (32, 32), color=color) for color in ["red", "green", "blue"]] + messages = prepare_multimodal_messages(messages, images=images) expected = [ { "role": "user", "content": [ - {"type": "image"}, - {"type": "image"}, - {"type": "image"}, + {"type": "image", "image": images[0]}, + {"type": "image", "image": images[1]}, + {"type": "image", "image": images[2]}, {"type": "text", "text": "What color is the sky?"}, ], }, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, ] self.assertEqual(messages, expected) @@ -106,11 +125,18 @@ def test_system_message_transformation(self): {"role": "user", "content": "What color is the sky?"}, ] - prepare_multimodal_messages(messages, num_images=1) + image = Image.new("RGB", (32, 32), color="red") + messages = prepare_multimodal_messages(messages, images=[image]) expected = [ - {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant"}]}, - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant"}], + }, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, ] self.assertEqual(messages, expected) @@ -123,25 +149,22 @@ def test_already_prepared_messages_unchanged(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - original = copy.deepcopy(messages) - prepare_multimodal_messages(messages, num_images=1) - - self.assertEqual(messages, original) - - def test_mixed_prepared_and_unprepared_messages(self): - """Test handling of mixed prepared and unprepared messages.""" - messages = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, - {"role": "user", "content": "What about the grass?"}, - ] + image = Image.new("RGB", (32, 32), color="red") + messages = prepare_multimodal_messages(messages, images=[image]) - prepare_multimodal_messages(messages, num_images=1) - - expected = [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, - {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, - {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]}, + expected = messages = [ + { + "role": "system", + "content": [{"type": "text", "text": "You are a helpful assistant"}], + }, + { + "role": "user", + "content": [{"type": "image", "image": image}, {"type": "text", "text": "What color is the sky?"}], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, ] self.assertEqual(messages, expected) diff --git a/trl/data_utils.py b/trl/data_utils.py index 75e7a76f979..a1b188bf0de 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy from collections import defaultdict, deque from collections.abc import Sequence from itertools import takewhile @@ -28,19 +29,29 @@ DatasetType = TypeVar("DatasetType", Dataset, DatasetDict) -def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) -> None: +def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> list[dict[str, Any]]: """ - Convert messages into a structured multimodal format if needed. - - Each message's content is transformed from a raw string into a list of typed parts. The first user message is - prefixed with an image placeholder, while all other user and assistant messages are wrapped as text entries. + Convert messages into a structured multimodal format and inject the provided images into the message contents. Args: messages (`list[dict[str, Any]]`): Messages with `"role"` and `"content"`. Content may be a raw string before transformation. - num_images (`int`): - Number of images to include in the first user message. This is used to determine how many image - placeholders to add. + List of messages a `"role"` key (`"system"`, `"user"`, or `"assistant"`) and a `"content"` key containing + either a string or a list of structured blocks if already prepared. + images (`list`): + List of image objects to insert. + + Returns: + `list[dict[str, Any]]`: A deep-copied list of messages where every `"content"` value is a list of structured + content blocks, and all `"image"` placeholders are populated with the corresponding image objects. + + Notes: + - When the input `messages` isn't already in the structured format, (i.e., all `"content"` values are strings), + the function transforms them into the structured format by wrapping text in `{"type": "text", "text": ...}` + and inserting `{"type": "image"}` placeholders for the images *before* the first user message. + - When the input `messages` is already in the structured format (i.e., all `"content"` values are lists of + structured blocks), the function only fills in the actual images in the existing `{"type": "image"}` + placeholders. If the number of placeholders does not match the number of provided images, an error is raised. Example: ```python @@ -50,24 +61,28 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) {"role": "assistant", "content": "It looks like a cat."}, ] - # Output (num_images=1) + # Output, one image provided [ - {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What's in this image?"}]}, + {"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]}, ] ``` """ - image_included = False + + messages = copy.deepcopy(messages) # avoid modifying the original messages + + # First, convert all messages to the structured format if needed, and insert image placeholders if needed + images_included = False for message in messages: if message["role"] == "system": if isinstance(message["content"], str): # if already prepared, the content will be a list message["content"] = [{"type": "text", "text": message["content"]}] elif message["role"] == "user": - if isinstance(message["content"], str) and not image_included: - placeholders = [{"type": "image"}] * num_images - message["content"] = [*placeholders, {"type": "text", "text": message["content"]}] - image_included = True - elif isinstance(message["content"], str) and image_included: + if isinstance(message["content"], str) and not images_included: + image_entries = [{"type": "image"}] * len(images) + message["content"] = [*image_entries, {"type": "text", "text": message["content"]}] + images_included = True + elif isinstance(message["content"], str) and images_included: message["content"] = [{"type": "text", "text": message["content"]}] elif message["role"] == "assistant": if isinstance(message["content"], str): @@ -75,6 +90,54 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], num_images: int) else: raise ValueError(f"Invalid role in message: {message['role']}. Expected 'user', 'assistant', or 'system'.") + # Then, check that the number of image placeholders matches the number of images provided + num_placeholders = sum(sum(1 for part in message["content"] if part["type"] == "image") for message in messages) + if num_placeholders != len(images): + raise ValueError( + f"Number of images provided ({len(images)}) does not match number of image placeholders ({num_placeholders})." + ) + + # Then, fill in the actual images in the placeholders + img_idx = 0 + for message in messages: + for part in message["content"]: + if part["type"] == "image": + part["image"] = images[img_idx] + img_idx += 1 + + return messages + + +def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + """ + Convert structured multimodal messages into a format compatible with vLLM. Replaces `"type": "image"` blocks with + `"type": "image_pil"` blocks, and `"image": Image` with `"image_pil": Image`. + + Args: + messages (`list[dict[str, Any]]`): + Messages with `"role"` and `"content"`. Content is expected to be a list of structured blocks. + + Returns: + `list[dict[str, Any]]`: + A deep-copied list of messages compatible with vLLM's expected input format. + + Example: + ```python + # Input + [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}] + + # Output + [{"role": "user", "content": [{"type": "image_pil", "image_pil": }, {"type": "text", "text": "What's in this image?"}]}] + ``` + """ + messages = copy.deepcopy(messages) # avoid modifying the original messages + for message in messages: + for part in message["content"]: + if part["type"] == "image": + part["type"] = "image_pil" # vLLM expects 'image_pil' key for images + part["image_pil"] = part.pop("image") + return messages + def is_conversational(example: dict[str, Any]) -> bool: r""" diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index b6852dfd60b..cb661587f04 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -46,7 +46,12 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available -from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from ..data_utils import ( + apply_chat_template, + is_conversational, + prepare_multimodal_messages, + prepare_multimodal_messages_vllm, +) from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_liger_kernel_available, is_vllm_available @@ -547,6 +552,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + enforce_eager=True, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1068,22 +1074,9 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list): device = self.accelerator.device - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1096,38 +1089,35 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) + all_prompts = gather_object(prompts) if self.accelerator.is_main_process: # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - + ordered_set_of_prompts = all_prompts[:: self.num_generations] + + sampling_params = { + "n": self.num_generations, + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, + "guided_decoding_regex": self.guided_decoding_regex, + "generation_kwargs": self.args.generation_kwargs, + } with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None @@ -1171,31 +1161,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if self.vllm_tensor_parallel_size > 1: # Gather prompts from all ranks in the TP group and flatten. # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) + orig_size = len(prompts) gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text + all_prompts = prompts with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) all_prompt_ids = [output.prompt_token_ids for output in all_outputs] all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] @@ -1258,16 +1235,20 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path self.processing_class.truncation_side = "left" - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - **kwargs, - ) + processor_kwargs = { + "return_tensors": "pt", + "padding": True, + "padding_side": "left", + "max_length": self.max_prompt_length, + "truncation": True, + "return_dict": True, + } + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True + ) + else: + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1298,11 +1279,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): return prompt_ids, completion_ids, logprobs - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate(self, prompts: list[str]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1351,8 +1332,14 @@ def _generate_and_score_completions( else: images = None + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( - prompts, images + prompts ) # Convert lists of token IDs to padded tensors diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py index c706b420418..f564354cf2d 100644 --- a/trl/trainer/sft_trainer.py +++ b/trl/trainer/sft_trainer.py @@ -384,9 +384,7 @@ def _collate_language_modeling(self, examples: list[dict[str, Any]]) -> dict[str images = [example["images"] for example in examples] if "messages" in examples[0]: # conversational case - for example in examples: - prepare_multimodal_messages(example["messages"], len(example["images"])) - messages = [example["messages"] for example in examples] + messages = [prepare_multimodal_messages(example["messages"], example["images"]) for example in examples] texts = self.processor.apply_chat_template(messages) elif self.dataset_text_field in examples[0]: # standard case texts = [example[self.dataset_text_field] for example in examples] @@ -424,7 +422,8 @@ def _collate_prompt_completion(self, examples: list[dict[str, Any]]) -> dict[str images = [example["images"] for example in examples] if is_conversational(examples[0]): # conversational case for example in examples: - prepare_multimodal_messages(example["prompt"] + example["completion"], len(example["images"])) + example["prompt"] = prepare_multimodal_messages(example["prompt"], images=example["images"]) + example["completion"] = prepare_multimodal_messages(example["completion"], images=[]) examples = [apply_chat_template(example, self.processor) for example in examples] prompts = [example["prompt"] for example in examples] @@ -970,10 +969,13 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo output = {} if is_conversational(example): if self._is_vlm: - prepare_multimodal_messages(example["prompt"], num_images=0) - prepare_multimodal_messages(example["completion"], num_images=0) + prompt = prepare_multimodal_messages(example["prompt"], images=[]) + completion = prepare_multimodal_messages(example["completion"], images=[]) + else: + prompt = example["prompt"] + completion = example["completion"] prompt_ids = processing_class.apply_chat_template( - example["prompt"], + prompt, tokenize=True, tools=example.get("tools"), **example.get("chat_template_kwargs", {}), @@ -982,7 +984,7 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo # even for single examples, while for LLMs it returns lists of ints. prompt_ids = prompt_ids[0] if isinstance(prompt_ids[0], list) else prompt_ids prompt_completion_processed = processing_class.apply_chat_template( - example["prompt"] + example["completion"], + prompt + completion, return_dict=True, tokenize=True, return_assistant_tokens_mask=assistant_only_loss, @@ -1020,9 +1022,11 @@ def tokenize_fn(example, processing_class, dataset_text_field, assistant_only_lo else: # language modeling case if is_conversational(example): if self._is_vlm: - prepare_multimodal_messages(example["messages"], num_images=0) + messages = prepare_multimodal_messages(example["messages"], images=[]) + else: + messages = example["messages"] processed = processing_class.apply_chat_template( - example["messages"], + messages, return_dict=True, tokenize=True, return_assistant_tokens_mask=assistant_only_loss, From 55a2480195a31559412a0873c3886f9222bb9c21 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 26 Sep 2025 23:46:50 +0000 Subject: [PATCH 056/153] rloo + doc --- trl/trainer/rloo_trainer.py | 107 +++++++++++++++++++++--------------- trl/trainer/utils.py | 4 +- 2 files changed, 64 insertions(+), 47 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 3c7490eaac3..a1095bc1736 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1061,9 +1061,8 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device - mode = "train" if self.model.training else "eval" # If the prompts are conversational and the inputs contain images, we need to convert the prompts from # [{"role": "user", "content": "What color is the sky?"}] to @@ -1090,21 +1089,19 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_inputs = super()._prepare_inputs(prompt_inputs) prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We manually remove leading pad tokens from the decoded text, - # because we can't use `skip_special_tokens=True` (some special tokens are still needed for generation). + # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special + # tokens are needed for generation. protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] protected = [token for token in protected if token is not None] - prompt_ids, prompt_mask = truncate_with_protected_tokens( - prompt_ids, prompt_mask, self.max_prompt_length, protected - ) + prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] prompts_text = self.processing_class.batch_decode( prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False ) - prompts_text = [re.sub(rf"^({re.escape(self.pad_token)})+", "", text) for text in prompts_text] # The chat template sometimes inserts a single image token into the prompt text. However, when this text is # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the @@ -1183,13 +1180,13 @@ def _generate(self, prompts: list[str], images: Optional[list]): # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - completion_ids, _ = obj_list[0] + all_completion_ids, _ = obj_list[0] process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), ) - completion_ids = completion_ids[process_slice] + completion_ids = all_completion_ids[process_slice] # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts elif self.vllm_mode == "colocate": @@ -1241,24 +1238,20 @@ def _generate(self, prompts: list[str], images: Optional[list]): with profiling_context(self, "vLLM.generate"): all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) - completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] + all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] if self.vllm_tensor_parallel_size > 1: # Slice completions for this rank within its TP group. # Each rank generates all outputs — we keep only our share. local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) - completion_ids = completion_ids[tp_slice] + completion_ids = all_completion_ids[tp_slice] + else: + completion_ids = all_completion_ids if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) - # Pad the completions, and concatenate them with the prompts - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_mask = [torch.ones(len(ids), device=device, dtype=torch.long) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id) - completion_mask = pad(completion_mask, padding_value=0) - elif self.use_transformers_paged: # Re-process inputs for paged generation if needed # Note: images are already validated and preprocessed above @@ -1288,15 +1281,17 @@ def _generate(self, prompts: list[str], images: Optional[list]): ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] - completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") - prompt_ids = [torch.tensor(ids, device=device) for ids in paged_prompt_inputs.input_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_ids = paged_prompt_inputs.input_ids # Restore the original attention implementation, training mode self.model_wrapped.config._attn_implementation = previous_attn else: # Regular generation path + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + with ( profiling_context(self, "transformers.generate"), unwrap_model_for_generation( @@ -1317,25 +1312,34 @@ def _generate(self, prompts: list[str], images: Optional[list]): prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] - # Mask everything after the first EOS token - is_eos = completion_ids == self.eos_token_id - eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) - eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] - sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) - completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + # Mask everything after the first EOS token + is_eos = completion_ids == self.eos_token_id + eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) + eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] + sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) + completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] + completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - # Sum along sequence dimension (dim=1) to get completion length per sequence, used for logging - completion_lengths = completion_mask.sum(1) + return prompt_ids, completion_ids, forward_kwargs - # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask - if self.mask_truncated_completions: - truncated_completions = ~is_eos.any(dim=1) - completion_mask = completion_mask * (~truncated_completions).unsqueeze(1).int() + def _generate(self, prompts: list[str], images: Optional[list]): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images) + + # Get completion length per sequence, used for logging + prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + agg_prompt_lengths = self.accelerator.gather(prompt_lengths) + agg_completion_lengths = self.accelerator.gather(completion_lengths) + total_prompt_tokens = agg_prompt_lengths.sum() + total_completion_tokens = agg_completion_lengths.sum() # = num_items_in_batch, required for the DAPO loss # Log the metrics if mode == "train": - attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) - self.state.num_input_tokens_seen += self.accelerator.gather(attention_mask.sum()).sum().item() + self.state.num_input_tokens_seen += (total_prompt_tokens + total_completion_tokens).item() self._metrics[mode]["num_tokens"] = [self.state.num_input_tokens_seen] # Log completion lengths, mean, min, max @@ -1345,17 +1349,18 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/max_length"].append(agg_completion_lengths.float().max().item()) # Identify sequences that terminated with EOS and log their lengths - agg_terminated_with_eos = self.accelerator.gather(is_eos.any(dim=1)) - term_completion_lengths = agg_completion_lengths[agg_terminated_with_eos] - clipped_completions_ratio = 1 - len(term_completion_lengths) / len(agg_completion_lengths) - self._metrics[mode]["completions/clipped_ratio"].append(clipped_completions_ratio) + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids], device=device) + agg_is_truncated = self.accelerator.gather(is_truncated) + self._metrics[mode]["completions/clipped_ratio"].append(agg_is_truncated.float().mean().item()) + term_completion_lengths = agg_completion_lengths[~agg_is_truncated] if len(term_completion_lengths) == 0: # edge case where no terminated sequences are found term_completion_lengths = torch.zeros(1, device=device) self._metrics[mode]["completions/mean_terminated_length"].append(term_completion_lengths.float().mean().item()) self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs + return prompt_ids, completion_ids, forward_kwargs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1372,11 +1377,23 @@ def _generate_and_score_completions( else: images = None - prompt_ids, completion_ids, prompt_mask, completion_mask, forward_kwargs = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images) + + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index d192036671d..7af4c69c6c9 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -1867,9 +1867,9 @@ def truncate_with_protected_tokens(ids: list[int], target_length: int, protected Args: sequences (`list[int]`): - Input tensor of token IDs, shape (batch_size, sequence_length). + Input sequence of token IDs. target_length (`int`): - Desired length of the output sequences. + Desired length of the output sequence. protected_tokens (`list[int]`): List of token IDs that should be preserved in the output. """ From 7b7a11d83351c2919094ea607396eeb4e574924e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 27 Sep 2025 00:00:52 +0000 Subject: [PATCH 057/153] test and doc --- tests/test_vllm_client_server.py | 110 +++++++++++++++++++------------ trl/extras/vllm_client.py | 8 ++- 2 files changed, 76 insertions(+), 42 deletions(-) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 08a302da41c..6171845c1de 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -74,36 +74,42 @@ def setUpClass(cls): def test_generate(self): prompts = ["Hello, AI!", "Tell me a joke"] - outputs = self.client.generate(prompts)["completion_ids"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] - # Check that the output is a list - self.assertIsInstance(outputs, list) + # Check that the outputs are lists + self.assertIsInstance(prompt_ids, list) + self.assertIsInstance(completion_ids, list) - # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + # Check that the number of sequences are equal to the number of prompts + self.assertEqual(len(prompt_ids), len(prompts)) + self.assertEqual(len(completion_ids), len(prompts)) - # Check that the generated sequences are lists of integers - for seq in outputs: + # Check that the sequences are lists of integers + for seq in prompt_ids: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + for seq in completion_ids: self.assertTrue(all(isinstance(tok, int) for tok in seq)) def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] - outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ + completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ "completion_ids" ] # Check that the output is a list - self.assertIsInstance(outputs, list) + self.assertIsInstance(completion_ids, list) # Check that the number of generated sequences is 2 times the number of prompts - self.assertEqual(len(outputs), 2 * len(prompts)) + self.assertEqual(len(completion_ids), 2 * len(prompts)) # Check that the generated sequences are lists of integers - for seq in outputs: + for seq in completion_ids: self.assertTrue(all(isinstance(tok, int) for tok in seq)) # Check that the length of the generated sequences is less than or equal to 32 - for seq in outputs: + for seq in completion_ids: self.assertLessEqual(len(seq), 32) def test_update_model_params(self): @@ -150,36 +156,42 @@ def setUpClass(cls): def test_generate(self): prompts = ["Hello, AI!", "Tell me a joke"] - outputs = self.client.generate(prompts)["completion_ids"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] - # Check that the output is a list - self.assertIsInstance(outputs, list) + # Check that the outputs are lists + self.assertIsInstance(prompt_ids, list) + self.assertIsInstance(completion_ids, list) - # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + # Check that the number of sequences are equal to the number of prompts + self.assertEqual(len(prompt_ids), len(prompts)) + self.assertEqual(len(completion_ids), len(prompts)) - # Check that the generated sequences are lists of integers - for seq in outputs: + # Check that the sequences are lists of integers + for seq in prompt_ids: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + for seq in completion_ids: self.assertTrue(all(isinstance(tok, int) for tok in seq)) def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] - outputs = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ + completion_ids = self.client.generate(prompts, n=2, repetition_penalty=0.9, temperature=0.8, max_tokens=32)[ "completion_ids" ] # Check that the output is a list - self.assertIsInstance(outputs, list) + self.assertIsInstance(completion_ids, list) # Check that the number of generated sequences is 2 times the number of prompts - self.assertEqual(len(outputs), 2 * len(prompts)) + self.assertEqual(len(completion_ids), 2 * len(prompts)) # Check that the generated sequences are lists of integers - for seq in outputs: + for seq in completion_ids: self.assertTrue(all(isinstance(tok, int) for tok in seq)) # Check that the length of the generated sequences is less than or equal to 32 - for seq in outputs: + for seq in completion_ids: self.assertLessEqual(len(seq), 32) def test_update_model_params(self): @@ -228,16 +240,22 @@ def setUpClass(cls): def test_generate(self): prompts = ["Hello, AI!", "Tell me a joke"] - outputs = self.client.generate(prompts)["completion_ids"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] - # Check that the output is a list - self.assertIsInstance(outputs, list) + # Check that the outputs are lists + self.assertIsInstance(prompt_ids, list) + self.assertIsInstance(completion_ids, list) - # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + # Check that the number of sequences are equal to the number of prompts + self.assertEqual(len(prompt_ids), len(prompts)) + self.assertEqual(len(completion_ids), len(prompts)) - # Check that the generated sequences are lists of integers - for seq in outputs: + # Check that the sequences are lists of integers + for seq in prompt_ids: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + for seq in completion_ids: self.assertTrue(all(isinstance(tok, int) for tok in seq)) def test_update_model_params(self): @@ -286,16 +304,22 @@ def setUpClass(cls): def test_generate(self): prompts = ["Hello, AI!", "Tell me a joke"] - outputs = self.client.generate(prompts)["completion_ids"] + outputs = self.client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] - # Check that the output is a list - self.assertIsInstance(outputs, list) + # Check that the outputs are lists + self.assertIsInstance(prompt_ids, list) + self.assertIsInstance(completion_ids, list) - # Check that the number of generated sequences is equal to the number of prompts - self.assertEqual(len(outputs), len(prompts)) + # Check that the number of sequences are equal to the number of prompts + self.assertEqual(len(prompt_ids), len(prompts)) + self.assertEqual(len(completion_ids), len(prompts)) - # Check that the generated sequences are lists of integers - for seq in outputs: + # Check that the sequences are lists of integers + for seq in prompt_ids: + self.assertTrue(all(isinstance(tok, int) for tok in seq)) + for seq in completion_ids: self.assertTrue(all(isinstance(tok, int) for tok in seq)) def test_update_model_params(self): @@ -344,9 +368,13 @@ def test_init_communicator_with_device_int(self): # Test basic functionality prompts = ["Hello, AI!"] - outputs = client.generate(prompts)["completion_ids"] - self.assertIsInstance(outputs, list) - self.assertEqual(len(outputs), len(prompts)) + outputs = client.generate(prompts) + prompt_ids = outputs["prompt_ids"] + completion_ids = outputs["completion_ids"] + self.assertIsInstance(prompt_ids, list) + self.assertEqual(len(prompt_ids), len(prompts)) + self.assertIsInstance(completion_ids, list) + self.assertEqual(len(completion_ids), len(prompts)) client.close_communicator() diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 962cc1664e0..d8c6c679ad7 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -212,6 +212,8 @@ def generate( Returns: `dict` with keys: + - `prompt_ids` (`list[list[int]]`): + List of lists of token IDs representing the tokenized input prompts. - `completion_ids` (`list[list[int]]`): List of lists of token IDs representing the model-generated completions for each prompt. - `logprobs` (`list[list[float]]`): @@ -246,7 +248,11 @@ def pil_to_base64(image): ) if response.status_code == 200: json_response = response.json() - return {"prompt_ids": json_response["prompt_ids"], "completion_ids": json_response["completion_ids"], "logprobs": json_response["logprobs"]} + return { + "prompt_ids": json_response["prompt_ids"], + "completion_ids": json_response["completion_ids"], + "logprobs": json_response["logprobs"], + } else: raise Exception(f"Request failed: {response.status_code}, {response.text}") From c5064d61ea9cc02ac345ee076476cc15b9165496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 27 Sep 2025 00:04:17 +0000 Subject: [PATCH 058/153] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 33 +++++++++++++++++++-------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 5e228c1e883..28202a52f27 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -20,7 +20,7 @@ from ...data_utils import is_conversational from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer -from ...trainer.utils import nanmax, nanmin, nanstd +from ...trainer.utils import nanmax, nanmin, nanstd, pad logger = logging.getLogger(__name__) @@ -78,18 +78,33 @@ def _generate_and_score_completions(self, inputs): images = None ( - prompt_ids, - completion_ids, - prompt_mask, - completion_mask, + prompt_ids_list, + completion_ids_list, num_items_in_batch, - sampling_per_token_logps, + sampling_per_token_logps_list, forward_kwargs, ) = self._generate(prompts, images) - # Convert tensor to a list of lists of token IDs. This will be passed to the reward function, avoiding the need - # to re-tokenize completions if the reward is computed from tokens. - completion_ids_list = [row[mask_row].tolist() for row, mask_row in zip(completion_ids, completion_mask.bool())] + # Convert lists of token IDs to padded tensors + prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] + prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] + prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") + prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") + completion_mask = pad(completion_mask, padding_value=0, padding_side="right") + if sampling_per_token_logps_list is not None: + sampling_per_token_logps = [torch.tensor(logps, device=device) for logps in sampling_per_token_logps_list] + sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") + else: + sampling_per_token_logps = None + + # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask + if self.mask_truncated_completions: + eos_and_pad = [self.eos_token_id, self.pad_token_id] + is_truncated = torch.tensor([ids[-1] not in eos_and_pad for ids in completion_ids_list], device=device) + completion_mask = completion_mask * (~is_truncated).unsqueeze(1).int() # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) From 6bc15a3185ccfd428177880494847d4f473920da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sun, 28 Sep 2025 16:57:37 +0000 Subject: [PATCH 059/153] wip --- trl/data_utils.py | 11 ++-- trl/trainer/grpo_trainer.py | 107 +++++++++++++++++++++++++++++++++--- 2 files changed, 104 insertions(+), 14 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index a1b188bf0de..4ae778627c0 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -132,10 +132,11 @@ def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dic """ messages = copy.deepcopy(messages) # avoid modifying the original messages for message in messages: - for part in message["content"]: - if part["type"] == "image": - part["type"] = "image_pil" # vLLM expects 'image_pil' key for images - part["image_pil"] = part.pop("image") + if isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image": + part["type"] = "image_pil" # vLLM expects 'image_pil' key for images + part["image_pil"] = part.pop("image") return messages @@ -211,7 +212,7 @@ def apply_chat_template( # Apply the chat template to the prompt, adding the generation prompt if "prompt" in example: last_role = example["prompt"][-1]["role"] - if last_role == "user": + if last_role in ["user", "tool"]: add_generation_prompt = True continue_final_message = False elif last_role == "assistant": diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index cb661587f04..135c6ea604f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -13,8 +13,10 @@ # limitations under the License. import inspect +import json import os import textwrap +import traceback from collections import defaultdict, deque from contextlib import nullcontext from functools import partial @@ -92,6 +94,7 @@ if is_wandb_available(): import wandb +import re logger = logging.get_logger(__name__) @@ -99,6 +102,23 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] +def extract_tool_calls(text: str) -> list[dict[str, Any]]: + """ + Extract JSON objects from ... blocks in `text` + and return them in the format: + {"type": "function", "function": {...}} + """ + # Find every block between and + blocks = re.findall(r'\s*(\{.*?\})\s*', text, flags=re.DOTALL) + + result = [] + for block in blocks: + try: + parsed = json.loads(block) + except json.JSONDecodeError as e: + continue + result.append({"type": "function", "function": parsed}) + return result or None class GRPOTrainer(BaseTrainer): """ @@ -208,14 +228,16 @@ def reward_func(completions, **kwargs): "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", "id": "2402.03300", # docstyle-ignore - "citation": textwrap.dedent("""\ + "citation": textwrap.dedent( + """\ @article{shao2024deepseekmath, title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, year = 2024, eprint = {arXiv:2402.03300}, } - """), + """ + ), } def __init__( @@ -230,7 +252,10 @@ def __init__( callbacks: Optional[list[TrainerCallback]] = None, optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), peft_config: Optional["PeftConfig"] = None, + tools=None, ): + self.tools = tools or [] + self._tool_dict = {tool.__name__: tool for tool in self.tools} # Args if args is None: model_name = model if isinstance(model, str) else model.config._name_or_path @@ -278,7 +303,7 @@ def __init__( # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path) + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, padding_side="left") # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -534,7 +559,7 @@ def __init__( ensure_master_addr_port() if self.max_prompt_length is not None and self.max_completion_length is not None: - max_model_len = self.max_prompt_length + self.max_completion_length + max_model_len = self.max_prompt_length + self.max_completion_length + 512 else: max_model_len = None self.llm = LLM( @@ -552,7 +577,7 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, - enforce_eager=True, + # enforce_eager=True, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) @@ -1115,7 +1140,9 @@ def _generate_single_turn(self, prompts: list): } with profiling_context(self, "vLLM.generate"): if is_conversational({"prompt": ordered_set_of_prompts[0]}): - output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + output = self.vllm_client.chat( + prompts=ordered_set_of_prompts, tools=self.tools, **sampling_params + ) else: output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) @@ -1170,7 +1197,9 @@ def _generate_single_turn(self, prompts: list): with profiling_context(self, "vLLM.generate"): if is_conversational({"prompt": prompts[0]}): - all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + all_outputs = self.llm.chat( + all_prompts, sampling_params=sampling_params, tools=self.tools, use_tqdm=False + ) else: all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) @@ -1245,7 +1274,7 @@ def _generate_single_turn(self, prompts: list): } if is_conversational({"prompt": prompts[0]}): generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, **processor_kwargs, tokenize=True + conversation=prompts, **processor_kwargs, tokenize=True, tools=self.tools ) else: generate_inputs = self.processing_class(text=prompts, **processor_kwargs) @@ -1284,6 +1313,65 @@ def _generate(self, prompts: list[str]): mode = "train" if self.model.training else "eval" prompt_ids, completion_ids, logprobs = self._generate_single_turn(prompts) + completion_contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + # parsed_completions = [] + # for content in completion_contents: + # try: + # parsed_completions.append(self.processing_class.parse_response(content)) + # except Exception as e: + # logger.warning(f"Failed to parse model output: {content}\nError: {e}") + # parsed_completions.append(None) + # tool_calls = [completion.get("tool_calls") if completion is not None else None for completion in parsed_completions] + tool_calls = [extract_tool_calls(content) for content in completion_contents] + idxs_with_tool = [i for i, t in enumerate(tool_calls) if t] # find indices that actually have a tool call + tool_calls = [tool_calls[i] for i in idxs_with_tool] + + while idxs_with_tool: + prompts_for_generation = [prompts[i] for i in idxs_with_tool] + for idx, tool_call_list, prompt_for_generation in zip(idxs_with_tool, tool_calls, prompts_for_generation): + prompt_for_generation.append({"role": "assistant", "content": completion_contents[idx]}) + for tool_call in tool_call_list: + if tool_call["type"] == "function": + function = tool_call["function"] + try: + result = self._tool_dict[function["name"]](**function["arguments"]) + except Exception as e: + # store the full traceback as a string in the result + result = {"error": str(e), "traceback": traceback.format_exc()} + else: + result = {"error": f"Unsupported tool call type: {tool_call['type']}"} + tool_call["result"] = result + tool_message = {"role": "tool", "name": function["name"], "content": str(result)} + prompt_for_generation.append(tool_message) + + prompt_completion_tool_ids, post_tool_ids, _ = self._generate_single_turn(prompts_for_generation) + + # Truncate post-tool completion so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length + for i in range(len(post_tool_ids)): + excess_length = len(prompt_completion_tool_ids[i]) + len(post_tool_ids[i]) - ( + self.max_prompt_length + self.max_completion_length + ) + if excess_length > 0: + post_tool_ids[i] = post_tool_ids[i][:-excess_length] + + for idx, pct, post_tool in zip(idxs_with_tool, prompt_completion_tool_ids, post_tool_ids): + completion_ids[idx] = pct[len(prompt_ids[idx]) :] + post_tool + + cc = self.processing_class.batch_decode(post_tool_ids, skip_special_tokens=True) + # parsed_completions = [] + # for content in cc: + # try: + # parsed_completions.append(self.processing_class.parse_response(content)) + # except Exception as e: + # logger.warning(f"Failed to parse model output: {content}\nError: {e}") + # parsed_completions.append(None) + # tool_calls = [completion.get("tool_calls") if completion is not None else None for completion in parsed_completions] + tool_calls = [extract_tool_calls(content) for content in cc] + completion_contents =[None] * len(completion_contents) + for i, content in zip(idxs_with_tool, cc): + completion_contents[i] = content + idxs_with_tool = [idx for idx, tc in zip(idxs_with_tool, tool_calls) if tc] + tool_calls = [tc for tc in tool_calls if tc] # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1375,7 +1463,8 @@ def _generate_and_score_completions( # Get forward_kwargs for models with multimodal inputs if images is not None: prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + apply_chat_template({"prompt": prompt}, self.processing_class, tools=self.tools)["prompt"] + for prompt in prompts ] prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") prompt_inputs = super()._prepare_inputs(prompt_inputs) From e7aa9452735865b4a42096f981527dd4722738dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 30 Sep 2025 23:10:16 +0000 Subject: [PATCH 060/153] fix vllm client server --- trl/extras/vllm_client.py | 8 ++++++-- trl/scripts/vllm_serve.py | 8 +++++++- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index d8c6c679ad7..00e5f2c817b 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -83,8 +83,12 @@ class VLLMClient: >>> client = VLLMClient() >>> client.generate(["Hello, AI!", "Tell me a joke"]) - [[2980, 498, 1492, 752, 448, 264, 13027, 8645, 30, 358, 2776, 4460, 311, 3270, 264, 2025], - [911, 7988, 1251, 382, 3838, 653, 498, 1618, 4325, 879, 2581, 20027, 264, 21428, 30, 362]] + {'prompt_ids': [[9707, 11, 15235, 0], + [40451, 752, 264, 21646]], + 'completion_ids': [[11479, 752, 5046, 279, 1465, 304, 419, 23670, 2038, 358, 2776, 4378, 369, 847, 15549, 6733], + [911, 19654, 382, 3838, 1558, 279, 16158, 1977, 979, 498, 2299, 4460, 311, 10542, 432, 518]], + 'logprobs': [[-5.193126201629639, -0.05592319369316101, -4.861808776855469, -1.673396110534668, -2.6316866874694824, -0.2861405313014984, -0.35006725788116455, -5.23351526260376, -0.1447441577911377, -5.21489953994751, -1.6022650003433228, -1.9649192094802856, -2.1338791847229004, -1.2775304317474365, -10.004860877990723, -4.171003818511963], + [-0.012896230444312096, -5.747106552124023, -1.5248860120773315, -1.9286258220672607, -2.8512537479400635, -2.8055880069732666, -3.019822835922241, -0.37132859230041504, -0.6311739087104797, -2.562908411026001, -3.1664533615112305, -2.685293436050415, -0.007259538397192955, -7.339841842651367, -1.188662052154541, -3.54781436920166]]} >>> from transformers import AutoModelForCausalLM diff --git a/trl/scripts/vllm_serve.py b/trl/scripts/vllm_serve.py index 7e7174f261c..901f8177ce0 100644 --- a/trl/scripts/vllm_serve.py +++ b/trl/scripts/vllm_serve.py @@ -499,6 +499,7 @@ class GenerateRequest(BaseModel): generation_kwargs: dict = field(default_factory=dict) class GenerateResponse(BaseModel): + prompt_ids: list[list[int]] completion_ids: list[list[int]] logprobs: list[list[float]] @@ -532,6 +533,7 @@ async def generate(request: GenerateRequest): Returns: `GenerateResponse`: + - `prompt_ids` (list of list of `int`): A list of lists of token IDs for each input prompt. - `completion_ids` (list of list of `int`): A list of lists of token IDs for each generated completion. - `logprobs` (list of list of `float`): A list of lists of log probabilities for each token in the generated completions. @@ -543,7 +545,11 @@ async def generate(request: GenerateRequest): Example response: ```json - {"completion_ids": [[101, 102, 103], [201, 202, 203]], "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]]} + { + "prompt_ids": [[101, 102], [201, 202]], + "completion_ids": [[103, 104, 105], [203, 204, 205]], + "logprobs": [[-0.1, -0.2, -0.3], [-0.4, -0.5, -0.6]] + } ``` """ request.images = request.images or [None] * len(request.prompts) From e164ec5aabc1334abe739db9c889915b7931aae3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 00:11:48 +0000 Subject: [PATCH 061/153] repicate all_prompt_ids --- trl/trainer/grpo_trainer.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 2e6b04c789b..8be2b27a918 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1100,7 +1100,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): ) prompt_inputs = super()._prepare_inputs(prompt_inputs) forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] @@ -1195,7 +1194,9 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): obj_list = [payload] broadcast_object_list(obj_list, from_process=0) all_prompt_ids, all_completion_ids, all_logprobs = obj_list[0] - all_completion_ids, all_logprobs = obj_list[0] + + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] process_slice = slice( self.accelerator.process_index * len(prompts), From 49577adb19575ee8edbc3341cca4f7ecdddccc8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 00:17:37 +0000 Subject: [PATCH 062/153] Same for RLOO --- trl/trainer/rloo_trainer.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 4c9a3623c3d..f539881d9b5 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1090,11 +1090,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): **kwargs, ) prompt_inputs = super()._prepare_inputs(prompt_inputs) - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] if self.max_prompt_length is not None: + prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] + prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] + # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special # tokens are needed for generation. @@ -1176,19 +1177,23 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): guided_decoding_regex=self.guided_decoding_regex, generation_kwargs=self.args.generation_kwargs, ) - payload = (output["completion_ids"], output["logprobs"]) + payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None # Broadcast the completions from the main process to all processes, ensuring each process receives its corresponding slice. obj_list = [payload] broadcast_object_list(obj_list, from_process=0) - all_completion_ids, _ = obj_list[0] + all_prompt_ids, all_completion_ids, _ = obj_list[0] + + # At this point, we only get 1 copy of each prompt, so we need to repeat them num_generations times + all_prompt_ids = [ids for ids in all_prompt_ids for _ in range(self.num_generations)] process_slice = slice( self.accelerator.process_index * len(prompts), (self.accelerator.process_index + 1) * len(prompts), ) + prompt_ids = all_prompt_ids[process_slice] completion_ids = all_completion_ids[process_slice] # Generate completions using colocated vLLM instances: each device holds vLLM copy and work on their own batch of prompts @@ -1241,6 +1246,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): with profiling_context(self, "vLLM.generate"): all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + all_prompt_ids = [output.prompt_token_ids for output in all_outputs] all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] if self.vllm_tensor_parallel_size > 1: @@ -1248,8 +1254,10 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): # Each rank generates all outputs — we keep only our share. local_rank_in_group = torch.distributed.get_rank(group=self.tp_group) tp_slice = slice(local_rank_in_group * orig_size, (local_rank_in_group + 1) * orig_size) + prompt_ids = all_prompt_ids[tp_slice] completion_ids = all_completion_ids[tp_slice] else: + prompt_ids = all_prompt_ids completion_ids = all_completion_ids if self.args.vllm_enable_sleep_mode: From 5fca5b88020d182b1bd0e2c5c5ffa2793ae057e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 00:46:15 +0000 Subject: [PATCH 063/153] fix normal generation path --- trl/trainer/grpo_trainer.py | 1 + trl/trainer/rloo_trainer.py | 1 + 2 files changed, 2 insertions(+) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 8be2b27a918..fe248e7bc84 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1317,6 +1317,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path + prompt_ids = prompt_inputs["input_ids"] prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index f539881d9b5..24ecb0bfc8f 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1298,6 +1298,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path + prompt_ids = prompt_inputs["input_ids"] prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") From 4dce145d40f61c88c7b4cad926e564f3d11e54f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 01:09:40 +0000 Subject: [PATCH 064/153] remove vision tokens --- trl/trainer/grpo_trainer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 8ae52c50218..e8c442bbbc4 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -289,10 +289,6 @@ def __init__( self.pad_token = tokenizer.pad_token self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id - self.image_token = getattr(processing_class, "image_token", None) - self.image_token_id = getattr(processing_class, "image_token_id", None) - self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None) - self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None) # Reward functions if not isinstance(reward_funcs, list): From ddfd3b58c9822658b5c61917f3296fd4756bf18c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 02:14:43 +0000 Subject: [PATCH 065/153] same for rloo --- trl/trainer/rloo_trainer.py | 90 +++++++++---------------------------- 1 file changed, 22 insertions(+), 68 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 2e7f74d6fbc..0c4e6ea16c8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -14,7 +14,6 @@ import inspect import os -import re import textwrap import warnings from collections import defaultdict, deque @@ -402,10 +401,6 @@ def decode(example, tokenizer): self.pad_token = tokenizer.pad_token self.pad_token_id = tokenizer.pad_token_id self.eos_token_id = tokenizer.eos_token_id - self.image_token = getattr(processing_class, "image_token", None) - self.image_token_id = getattr(processing_class, "image_token_id", None) - self.vision_start_token_id = getattr(model.config, "vision_start_token_id", None) - self.vision_end_token_id = getattr(model.config, "vision_end_token_id", None) # Reward functions if not isinstance(reward_funcs, list): @@ -1080,58 +1075,12 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - prompt_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - add_special_tokens=False, - **kwargs, - ) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - - if self.max_prompt_length is not None: - prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] - prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] - - # If max_prompt_length is set, we trim the prompt to keep only the last `max_prompt_length` tokens. - # Then we decode those tokens back into text. We set `skip_special_tokens=False` because some special - # tokens are needed for generation. - protected = [self.image_token_id, self.vision_start_token_id, self.vision_end_token_id] - protected = [token for token in protected if token is not None] - prompt_ids = [truncate_with_protected_tokens(ids, self.max_prompt_length, protected) for ids in prompt_ids] - - prompts_text = self.processing_class.batch_decode( - prompt_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False - ) - - # The chat template sometimes inserts a single image token into the prompt text. However, when this text is - # later tokenized, the single image token string is expanded into multiple image token IDs, depending on the - # image size. Since we're detokenizing here, we may see repeated image tokens in the decoded text. We - # collapse them back into a single token string to match the original chat template in case it originally - # applies it. Otherwise, it assumes that the chat template uses only vision_start_token_id to indicate images - # (e.g. Gemma 3) and removes all image_token instances and vision_end_token_id as well, leaving only - # the vision_start_token_id (e.g. ). - if self.image_token is not None: - escaped_img_token = re.escape(self.image_token) - # Search for the image token in the chat template - if re.search(escaped_img_token, self.processing_class.chat_template): - prompts_text = [ - re.sub(rf"({escaped_img_token})+", self.image_token, text) for text in prompts_text - ] - else: - # If the chat template doesn't use the image token, we remove all instances of it + vision_end_token_id - if self.vision_end_token_id is not None: - escaped_eoi_token = re.escape( - self.processing_class.tokenizer.decode([self.vision_end_token_id]) - ) - prompts_text = [ - re.sub(rf"({escaped_img_token})+{escaped_eoi_token}", "", text) for text in prompts_text - ] - else: - # If vision_end_token_id is None, just remove the image tokens - prompts_text = [re.sub(rf"({escaped_img_token})+", "", text) for text in prompts_text] + if images is not None: + prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1173,6 +1122,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): top_k=-1 if self.top_k is None else self.top_k, min_p=0.0 if self.min_p is None else self.min_p, max_tokens=self.max_completion_length, + truncate_prompt_tokens=self.max_prompt_length, guided_decoding_regex=self.guided_decoding_regex, generation_kwargs=self.args.generation_kwargs, ) @@ -1210,6 +1160,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): "top_k": -1 if self.top_k is None else self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding": guided_decoding, } if self.args.generation_kwargs is not None: @@ -1297,11 +1248,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path - prompt_ids = prompt_inputs["input_ids"] - prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids] - prompt_mask = [torch.ones_like(ids, dtype=torch.long) for ids in prompt_ids] - prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") - prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") + self.processing_class.truncation_side = "left" + generate_inputs = self.processing_class( + text=prompts_text, + return_tensors="pt", + padding=True, + padding_side="left", + max_length=self.max_prompt_length, + truncation=True, + add_special_tokens=False, + **kwargs, + ) + generate_inputs = super()._prepare_inputs(generate_inputs) with ( profiling_context(self, "transformers.generate"), @@ -1312,15 +1270,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): FSDP.summon_full_params(self.model_wrapped, recurse=False) if self.is_fsdp_enabled else nullcontext(), ): prompt_completion_ids = unwrapped_model.generate( - input_ids=prompt_ids, - attention_mask=prompt_mask, - **forward_kwargs, - generation_config=self.generation_config, - disable_compile=True, + **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids + prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] prompt_length = prompt_ids.size(1) - prompt_ids = prompt_completion_ids[:, :prompt_length] completion_ids = prompt_completion_ids[:, prompt_length:] # Mask everything after the first EOS token From c434fa23bfbd6016b68dba47b233f7e8dbefa337 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 02:28:07 +0000 Subject: [PATCH 066/153] truncation_side=left --- trl/trainer/grpo_trainer.py | 3 +-- trl/trainer/rloo_trainer.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e8c442bbbc4..5ddd2c224ed 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -273,7 +273,7 @@ def __init__( # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path) + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side = "left") # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -1267,7 +1267,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path - self.processing_class.truncation_side = "left" generate_inputs = self.processing_class( text=prompts_text, return_tensors="pt", diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 0c4e6ea16c8..841f59dc617 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -385,7 +385,7 @@ def decode(example, tokenizer): # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path) + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side = "left") # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -1248,7 +1248,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): else: # Regular generation path - self.processing_class.truncation_side = "left" generate_inputs = self.processing_class( text=prompts_text, return_tensors="pt", From 377b0811c9bcad05532d7afa82bb9a90def5708b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 1 Oct 2025 02:28:38 +0000 Subject: [PATCH 067/153] rm test_training_vlm_and_prompt_truncation --- tests/test_rloo_trainer.py | 41 -------------------------------------- 1 file changed, 41 deletions(-) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index cde52de6047..149e84ac9fe 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1213,47 +1213,6 @@ def reward_func(completions, **kwargs): elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") - @require_vision - def test_training_vlm_and_prompt_truncation(self): - # If not handled properly, prompt truncation may truncate image token - dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") - - def reward_func(completions, **kwargs): - """Reward function that rewards longer completions.""" - return [float(len(completion[0]["content"])) for completion in completions] - - training_args = RLOOConfig( - output_dir=self.tmp_dir, - learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=3, # reduce the batch size to reduce memory usage - num_generations=3, # reduce the number of generations to reduce memory usage - max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=18, - report_to="none", - ) - trainer = RLOOTrainer( - model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs=reward_func, - args=training_args, - train_dataset=dataset, - ) - - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - - trainer.train() - - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - - # Check that the params have changed - # Because of the way the tiny models are initialized, the gradient does not flow properly through the - # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. - params_to_skip = ("model.visual.",) - for n, param in previous_trainable_params.items(): - if n.startswith(params_to_skip): - continue - new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") - @require_vision @require_vllm @parameterized.expand( From e82db740f08d6562a070efc81fae6ade796db725 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Wed, 1 Oct 2025 09:16:42 -0600 Subject: [PATCH 068/153] =?UTF-8?q?=F0=9F=94=A3=20Fix=20test:=20replace=20?= =?UTF-8?q?`trainer.tokenizer`=20by=20`trainer.processing=5Fclass`=20(#418?= =?UTF-8?q?5)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_bco_trainer.py | 4 ++-- tests/test_kto_trainer.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index d609c3d5b90..68dbf7d9a30 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -192,7 +192,7 @@ def test_tokenize_and_process_tokens(self): tokenized_dataset = dataset.map( _tokenize, - fn_kwargs={"tokenizer": trainer.tokenizer}, + fn_kwargs={"tokenizer": trainer.processing_class}, batched=True, batch_size=2, ) @@ -207,7 +207,7 @@ def test_tokenize_and_process_tokens(self): fn_kwargs = { "prefix": "", "is_encoder_decoder": trainer.is_encoder_decoder, - "tokenizer": trainer.tokenizer, + "tokenizer": trainer.processing_class, "max_length": trainer.max_length, "truncation_mode": trainer.truncation_mode, "label_pad_token_id": trainer.label_pad_token_id, diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index 21b425fec05..fa17881544e 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -145,7 +145,7 @@ def test_tokenize_and_process_tokens(self): train_dataset = dummy_dataset["train"] tokenized_dataset = train_dataset.map( _tokenize, - fn_kwargs={"tokenizer": trainer.tokenizer}, + fn_kwargs={"tokenizer": trainer.processing_class}, batched=True, batch_size=2, ) @@ -182,7 +182,7 @@ def test_tokenize_and_process_tokens(self): fn_kwargs = { "prefix": "", "is_encoder_decoder": trainer.is_encoder_decoder, - "tokenizer": trainer.tokenizer, + "tokenizer": trainer.processing_class, "max_length": trainer.max_length, "truncation_mode": trainer.truncation_mode, "label_pad_token_id": trainer.label_pad_token_id, From 192deb3b2b69c0817d8fc11b0601df494bb55a21 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 18:01:56 +0200 Subject: [PATCH 069/153] Fix CI ImportError: FlashAttention2 and decorator order for all parameterized tests (#4176) --- tests/slow/test_grpo_slow.py | 6 +++--- tests/test_dpo_trainer.py | 2 +- tests/test_grpo_trainer.py | 4 ++-- tests/test_online_dpo_trainer.py | 4 ++-- tests/test_rloo_trainer.py | 4 ++-- tests/test_xpo_trainer.py | 2 +- 6 files changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 75dd1dc8d9e..69ba2b6ce34 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -208,14 +208,14 @@ def test_training_with_transformers_paged(self, model_name): release_memory(model, trainer) - @require_flash_attn - @require_bitsandbytes - @require_peft @parameterized.expand( [ ("HuggingFaceTB/SmolVLM-Instruct",), # Only test the smaller model to avoid OOM ] ) + @require_flash_attn + @require_bitsandbytes + @require_peft def test_vlm_training(self, model_name): """ Test VLM training with aggressive memory optimization. diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index 6a4bfd22301..bf4a5d9826b 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -1303,7 +1303,6 @@ def test_train_with_length_desensitization(self): if param.sum() != 0: # ignore 0 biases self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) - @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9") @parameterized.expand( [ (0.1, "sigmoid"), @@ -1319,6 +1318,7 @@ def test_train_with_length_desensitization(self): ] ) @require_liger_kernel + @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9") def test_dpo_trainer_with_liger(self, beta, loss_type): """Test DPO trainer with Liger loss enabled across supported loss types. diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index a839a654bca..03d44e50603 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1513,14 +1513,14 @@ def reward_func(completions, **kwargs): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") - @require_vision - @require_vllm @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",), ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",), ] ) + @require_vision + @require_vllm @unittest.skip("We should add a mock for the vLLM server.") def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index 336db8a089f..b9dec1135e6 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -238,8 +238,8 @@ def test_training_with_peft_model_and_peft_config(self): # Check if training loss is available self.assertIn("train_loss", trainer.state.log_history[-1]) - @require_llm_blender @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + @require_llm_blender def test_training_with_judge(self, config_name): training_args = OnlineDPOConfig( output_dir=self.tmp_dir, @@ -419,8 +419,8 @@ def test_generation_config_setup(self): self.assertEqual(trainer.generation_config.max_new_tokens, 64) self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs - @require_torch_accelerator @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + @require_torch_accelerator def test_training_with_transformers_paged(self, config_name): if Version(transformers.__version__) < Version("4.57.0"): pytest.xfail("Upstream bug in transformers (GH#40692). Fix merged; awaiting release >= 4.57.0") diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index cde52de6047..a2b1d3bf8b7 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1254,14 +1254,14 @@ def reward_func(completions, **kwargs): new_param = trainer.model.get_parameter(n) self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") - @require_vision - @require_vllm @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",), ("trl-internal-testing/tiny-Gemma3ForConditionalGeneration",), ] ) + @require_vision + @require_vllm @unittest.skip("We should add a mock for the vLLM server.") def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index 9d50b542a03..9af803830cf 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -184,8 +184,8 @@ def test_training_pre_pefted_model_implicit_ref(self): self.assertIn("train_loss", trainer.state.log_history[-1]) - @require_llm_blender @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) + @require_llm_blender def test_xpo_trainer_judge_training(self, config_name): training_args = XPOConfig( output_dir=self.tmp_dir, From cf9d8e76c45e69f98dad1b6ca9d3cdf05ea3d583 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Wed, 1 Oct 2025 19:42:36 +0200 Subject: [PATCH 070/153] Hotfix wrong formatting of docstrings with blockquote tips (#4187) --- trl/models/utils.py | 1 + trl/trainer/judges.py | 1 + trl/trainer/utils.py | 2 ++ 3 files changed, 4 insertions(+) diff --git a/trl/models/utils.py b/trl/models/utils.py index efdba75fbda..1bdaad82e8c 100644 --- a/trl/models/utils.py +++ b/trl/models/utils.py @@ -90,6 +90,7 @@ def setup_chat_format( format: Optional[Literal["chatml"]] = "chatml", resize_to_multiple_of: Optional[int] = None, ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + # docstyle-ignore """ Setup chat format by adding special tokens to the tokenizer, setting the correct format, and extending the embedding layer of the model based on the new special tokens. diff --git a/trl/trainer/judges.py b/trl/trainer/judges.py index 5c8c80c3726..923e10f7259 100644 --- a/trl/trainer/judges.py +++ b/trl/trainer/judges.py @@ -185,6 +185,7 @@ def judge( class PairRMJudge(BasePairwiseJudge): + # docstyle-ignore """ LLM judge based on the PairRM model from AllenAI. diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 76691bff7fb..9b92706d5cf 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -217,6 +217,7 @@ def ensure_master_addr_port(addr: Optional[str] = None, port: Optional[int] = No @dataclass class RewardDataCollatorWithPadding: + # docstyle-ignore r""" Reward DataCollator class that pads the inputs to the maximum length of the batch. @@ -1251,6 +1252,7 @@ def empty_cache() -> None: def decode_and_strip_padding(inputs: torch.Tensor, tokenizer: PreTrainedTokenizerBase) -> list[str]: + # docstyle-ignore """ Decodes the input tensor and strips the padding tokens. From f9c3c3c72642a350b92ce81ae57d6bf2e80e59a3 Mon Sep 17 00:00:00 2001 From: YonatanGideoni Date: Wed, 1 Oct 2025 18:58:13 +0100 Subject: [PATCH 071/153] =?UTF-8?q?=F0=9F=8C=A1=EF=B8=8F=20Have=20vLLM=20r?= =?UTF-8?q?eturn=20processed=20(temperature=20scaled)=20log=20probs=20(#41?= =?UTF-8?q?63)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/vllm_integration.md | 2 +- setup.cfg | 2 +- trl/import_utils.py | 9 +++------ trl/trainer/grpo_trainer.py | 2 ++ 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index b2838215d4c..9d3f6beee11 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -3,7 +3,7 @@ This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥 > [!WARNING] -> TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. Please ensure you have one of these versions installed to avoid compatibility issues. +> TRL currently only supports vLLM version `0.10.2`. Please ensure you have this version installed to avoid compatibility issues. ## 🚀 How can I use vLLM with TRL to speed up training? diff --git a/setup.cfg b/setup.cfg index f84ba1f950e..cf82100e979 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,7 +62,7 @@ test = pytest-xdist pytest vllm = - vllm>=0.10.0,<=0.10.2 + vllm==0.10.2 fastapi pydantic requests diff --git a/trl/import_utils.py b/trl/import_utils.py index 0f15a17222c..10709dc549c 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -82,13 +82,10 @@ def is_uvicorn_available() -> bool: def is_vllm_available() -> bool: - if _vllm_available and ( - version.parse(_vllm_version) < version.parse("0.10.0") - or version.parse(_vllm_version) > version.parse("0.10.2") - ): + if _vllm_available and version.parse(_vllm_version) != version.parse("0.10.2"): warnings.warn( - "TRL currently only supports vLLM versions `0.10.0`, `0.10.1`, and `0.10.2`. You have version " - f"{_vllm_version} installed. We recommend to install one of these versions to avoid compatibility issues.", + f"TRL currently only supports vLLM version `0.10.2`. You have version {_vllm_version} installed. We " + "recommend to install this version to avoid compatibility issues.", UserWarning, ) return _vllm_available diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index fe248e7bc84..bf0f519cd9f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -549,6 +549,8 @@ def __init__( max_num_batched_tokens=4096, model_impl=self.args.vllm_model_impl, enable_sleep_mode=self.args.vllm_enable_sleep_mode, + # Important so temperature scaling/logit tweaking affects the TIS log probs + logprobs_mode="processed_logprobs", ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=1) From 648947911a7b8ee7724ba373d014777b30cc7461 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Fri, 3 Oct 2025 09:08:53 +0200 Subject: [PATCH 072/153] Replace remaining trainer.tokenizer with trainer.processing_class in GRPO test (#4192) --- tests/test_grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 03d44e50603..b29e0769b89 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1882,7 +1882,7 @@ def test_update_with_inputs_different_seq_len(self): Test with inputs where the sequence lengths are different from the prepopulated buffer. """ self._prepopulate_buffer() - pad_token_id = self.trainer.tokenizer.pad_token_id + pad_token_id = self.trainer.processing_class.pad_token_id group_advantages = torch.tensor([[0.6, 0.6], [0.3, 0.45]]) # one no-variance, one variance inputs = { "group_advantages": group_advantages, From 21a67fc43ff7046a1c28515f3d40b08fd1141268 Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Fri, 3 Oct 2025 20:40:37 +0200 Subject: [PATCH 073/153] [DOCS] Lora without regret (#4181) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> Co-authored-by: sergiopaniego Co-authored-by: lewtun Co-authored-by: Kashif Rasul --- docs/source/_toctree.yml | 2 + docs/source/lora_without_regret.md | 447 +++++++++++++++++++++++++++++ 2 files changed, 449 insertions(+) create mode 100644 docs/source/lora_without_regret.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 054534c1584..566587e156f 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -55,6 +55,8 @@ title: Example Overview - local: community_tutorials title: Community Tutorials + - local: lora_without_regret + title: LoRA Without Regret - local: sentiment_tuning title: Sentiment Tuning - local: using_llama_models diff --git a/docs/source/lora_without_regret.md b/docs/source/lora_without_regret.md new file mode 100644 index 00000000000..04eed828d27 --- /dev/null +++ b/docs/source/lora_without_regret.md @@ -0,0 +1,447 @@ +# LoRA Without Regret + +Recent research from the team at [Thinking Machines Lab](https://thinkingmachines.ai/blog/lora/) (Schulman et al., 2025) shows that **LoRA can match full fine-tuning performance** when configured correctly, while using only ~67% of the compute. These findings are exciting to TRL users because they're straightforward to implement and can improve model performance on smaller budgets. + +This guide provides simple instructions to reproduce the results of the blog post in TRL. + +> [!TIP] +> It is recommended to read the blog post before following this guide, or to consult both resources in parallel for best results. + +## Benefits of LoRA over full fine-tuning + +First of all, let's remind ourselves of the benefits of [LoRA over full fine-tuning](https://huggingface.co/docs/trl/en/peft_integration). + +LoRA adds adapter layers on top of the base model, which contains significantly fewer parameters than the base model itself. This design reduces GPU memory requirements and enables more efficient training. As described in the [blog](https://thinkingmachines.ai/blog/lora/), this approach was originally thought to involve a performance trade-off, although careful configuration can overcome this trade-off and match full fine-tuning performance. + +## Examples with TRL + +Let's implement and train LoRA adapters in TRL scripts based on the core findings of the blog post. Afterwards, we'll revisit each finding in light of the TRL results. + +### Supervised Fine-Tuning (SFT) + +The blog post performs SFT on a range of models and datasets from the Hub, which we can reproduce in TRL. + +| Model | Dataset | +|-------|---------| +| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) | +| [Llama-3.2-1B-Instruct](https://huggingface.co/meta-llama/Llama-3.2-1B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) | +| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [allenai/tulu-3-sft-mixture](https://huggingface.co/datasets/allenai/tulu-3-sft-mixture) | +| [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) | [open-thoughts/OpenThoughts-114k](https://huggingface.co/datasets/open-thoughts/OpenThoughts-114k) | + + + + + +We can integrate these findings with the TRL Python API like so: + +```python + +from datasets import load_dataset +from peft import LoraConfig +from trl import SFTTrainer, SFTConfig + +dataset = load_dataset("open-thoughts/OpenThoughts-114k", split="train") + +peft_config = LoraConfig(lora_r=256, lora_alpha=16, lora_target_modules="all-linear") + +training_args = SFTConfig( + learning_rate=2e-4, + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + num_train_epochs=1, + report_to=["trackio"], +) + +trainer = SFTTrainer( + model="Qwen/Qwen2.5-3B-Instruct", + train_dataset=dataset, + peft_config=peft_config, + args=training_args, +) + +trainer.train() + +``` + + + + + +```bash + +hf jobs uv run \ + --flavor a100-large \ + --timeout 8h \ + --secrets HF_TOKEN \ + "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \ + --model_name_or_path Qwen/Qwen2.5-3B-Instruct \ + --dataset_name open-thoughts/OpenThoughts-114k \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --use_peft \ + --lora_r 256 \ + --lora_alpha 16 \ + --lora_target_modules all-linear \ + --output_dir Qwen2.5-3B-OpenThoughts-LoRA \ + --report_to trackio \ + --push_to_hub + +``` + +To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details. + + + + +```bash + +uv run "https://raw.githubusercontent.com/huggingface/trl/main/trl/scripts/sft.py" \ + --model_name_or_path Qwen/Qwen2.5-3B-Instruct \ + --dataset_name open-thoughts/OpenThoughts-114k \ + --learning_rate 2.0e-5 \ + --num_train_epochs 1 \ + --packing \ + --per_device_train_batch_size 2 \ + --gradient_accumulation_steps 16 \ + --gradient_checkpointing \ + --eval_strategy no \ + --use_peft \ + --lora_r 256 \ + --lora_alpha 16 \ + --lora_target_modules all-linear \ + --output_dir Qwen2.5-3B-OpenThoughts-LoRA \ + --report_to trackio \ + --push_to_hub + +``` + +To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details. + + + + +Once training starts, you can monitor the progress in [Trackio](https://huggingface.co/trackio), which will log the URL. + +### Reinforcement Learning (GRPO) + +The blog post performs GRPO on a range of models and datasets from the Hub, and once again we can reproduce the results in TRL. + +| Model | Dataset | +|-------|---------| +| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [GSM8k](https://huggingface.co/datasets/openai/gsm8k) | +| [Llama-3.1-8B-Base](https://huggingface.co/meta-llama/Llama-3.2-1B) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) | +| [Qwen3-8b-base](https://huggingface.co/Qwen/Qwen3-8b-base) | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) | + +For reinforcement learning, the blog uses a math reasoning task that we can reproduce as a Python function. + +
+Reward function + +```python +def strip_reasoning_accuracy_reward( + completions: list[list[dict[str, str]]], solution: list[str], **kwargs +) -> list[Optional[float]]: + """Reward function that strips reasoning tags and checks mathematical accuracy. + + This function: + 1. Extracts the content from completions + 2. Removes tags (for reasoning that shouldn't be evaluated) + 3. Parses both the gold solution and the predicted answer + 4. Uses math_verify to check if they are mathematically equivalent + + Args: + completions: List of model completions, each containing a list of messages + solution: List of ground truth solutions + **kwargs: Additional arguments (ignored but required for trainer compatibility) + + Returns: + List of rewards where: + - 1.0 if the answer is correct + - 0.0 if the answer is incorrect + - None if the solution is not parseable (skips this example) + """ + contents = [completion[0]["content"] for completion in completions] + rewards = [] + + for content, sol in zip(contents, solution): + # Strip reasoning tags from completion + while "" in content and "" in content: + start = content.find("") + end = content.find("", start) + if start != -1 and end != -1: + content = content[:start] + content[end + len("") :] + else: + break + + # Parse gold solution + gold_parsed = parse( + f"${sol}$", + extraction_config=[ + LatexExtractionConfig( + boxed_match_priority=0, try_extract_without_anchor=True + ) + ], + ) + + if len(gold_parsed) != 0: + # We require the answer to be provided in correct latex (no malformed operators) + answer_parsed = parse( + content, + extraction_config=[ + LatexExtractionConfig( + boxed_match_priority=0, + normalization_config=NormalizationConfig( + basic_latex=True, + units=True, + malformed_operators=False, + nits=False, + boxed=True, + ), + try_extract_without_anchor=False, + ) + ], + extraction_mode="first_match", + ) + + # Compute binary rewards if verifiable, `None` otherwise to skip this example + try: + reward = float(verify(gold_parsed, answer_parsed)) + except Exception as e: + print( + f"verify failed: {e}, answer: {answer_parsed}, gold: {gold_parsed}" + ) + reward = None + else: + # If the gold solution is not parseable, we assign `None` to skip this example + reward = None + + rewards.append(reward) + + return rewards +``` + +
+ + + + + +We can implement these recommendations with the TRL Python API like so: + +```python + +from datasets import load_dataset +from peft import LoraConfig +from trl import GRPOConfig, GRPOTrainer + +dataset = load_dataset("HuggingFaceH4/OpenR1-Math-220k-default-verified", split="train") + +def strip_reasoning_accuracy_reward(completions, **kwargs): + """Reward function that strips reasoning and accuracy scores from the model outputs.""" + + ... + +peft_config = LoraConfig( + lora_r=1, + lora_alpha=32, + lora_target_modules="all-linear" +) + +training_args = GRPOConfig( + learning_rate=5e-5, + per_device_train_batch_size=1, + gradient_accumulation_steps=4, + num_train_epochs=1, + num_generations=8, + generation_batch_size=8, + report_to=["trackio"], +) + +trainer = GRPOTrainer( + model="Qwen/Qwen3-0.6B", + reward_funcs=strip_reasoning_accuracy_reward, + args=training_args, + train_dataset=dataset, + peft_config=peft_config, +) + +trainer.train() + +``` + +> [!WARNING] +> This snippet skips the reward function which is defined above to keep the example concise. + + + + + +```bash + +hf jobs uv run \ + --flavor a100-large \ + --timeout 4h \ + --secrets HF_TOKEN \ + --env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True \ + "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \ + --output_dir grpo-full-qwen3-0.6b \ + --learning_rate 1.0e-6 \ + --lr_scheduler_type cosine \ + --warmup_ratio 0.0 \ + --max_grad_norm 1.0 \ + --beta 0.0 \ + --max_prompt_length 1024 \ + --max_completion_length 4096 \ + --num_generations 16 \ + --generation_batch_size 16 \ + --gradient_accumulation_steps 8 \ + --per_device_train_batch_size 1 \ + --num_train_epochs 1 \ + --lora_r 1 \ + --lora_alpha 32 \ + --lora_dropout 0.0 \ + --lora_target_modules all-linear \ + --vllm_mode colocate \ + --save_strategy steps \ + --save_steps 50 \ + --save_total_limit 1 \ + --logging_steps 1 \ + --max_steps 200 \ + --report_to trackio +``` + +To use Hugging Face Jobs, you will need to be logged in to the Hugging Face Hub (`hf auth login`) and have a [Pro](https://hf.co/pro), [Team](https://hf.co/enterprise), or [Enterprise](https://hf.co/enterprise) plan. Check out the [Jobs documentation](https://huggingface.co/docs/huggingface_hub/en/guides/jobs) for more details. + + + + +```bash + +uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/main/grpo.py" \ + --model_name_or_path Qwen/Qwen3-0.6B \ + --dataset_name HuggingFaceH4/OpenR1-Math-220k-default-verified \ + --output_dir grpo-full-qwen3-0.6b \ + --learning_rate 1.0e-6 \ + --lr_scheduler_type cosine \ + --warmup_ratio 0.0 \ + --max_grad_norm 1.0 \ + --beta 0.0 \ + --max_prompt_length 1024 \ + --max_completion_length 4096 \ + --num_generations 16 \ + --generation_batch_size 16 \ + --gradient_accumulation_steps 8 \ + --per_device_train_batch_size 1 \ + --num_train_epochs 1 \ + --lora_r 1 \ + --lora_alpha 32 \ + --lora_dropout 0.0 \ + --lora_target_modules all-linear \ + --vllm_mode colocate \ + --save_strategy steps \ + --save_steps 50 \ + --save_total_limit 1 \ + --logging_steps 1 \ + --max_steps 200 \ + --report_to trackio +``` + +To run the script locally, you will need to have `uv` installed. Check out the [uv documentation](https://docs.astral.sh/uv/) for more details. + + + + +The reinforcement learning script with GRPO is implemented as a custom script in TRL, which uses the reward function shown above. You can review it at [`grpo.py`](https://huggingface.co/datasets/burtenshaw/lora-without-regrets/blob/main/grpo.py) - Reinforcement learning with LoRA best practices + +## Key findings in optimizing LoRA + +The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. + +We were able to reproduce the results of the blog post using TRL and the SmolLM3 model. We trained the model for 500 steps on the [Math 220k dataset](https://huggingface.co/datasets/HuggingFaceH4/OpenR1-Math-220k-default-verified) with the reward function and configuration above. As you can see in the figure below, the LoRA model's average train reward curve matches the full fine-tuning curve. + +![train reward](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/5.png) + +And most importantly, the LoRA model uses significantly less memory than the full fine-tuning model, as we can see in the figure below. + +![memory usage](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/6.png) + +Here are the parameters we used to train the above models + +| Parameter | LoRA | Full FT | +|----------------------------------|----------------------------------------------------|-------------------------------| +| `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B | +| `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified | +| `--learning_rate` | 1.0e-6 | 1.0e-5 | +| `--max_prompt_length` | 1024 | 1024 | +| `--max_completion_length` | 4096 | 4096 | +| `--lora_r` | 1 | - | +| `--lora_alpha` | 32 | - | +| `--lora_dropout` | 0.0 | - | +| `--lora_target_modules` | all-linear | - | + +Let's break down the key findings of the blog post and how we were able to reproduce them. + +### 1. *LoRA performs better when applied to all weight matrices* + +The authors recommend applying LoRA to all weight matrices rather than limiting it to attention layers, as increasing the rank does not compensate for this restriction. + +![all layers](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/1.png) + +Attention-only LoRA underperforms even when using a higher rank to match parameter count. In TRL, this can be configured using `--lora_target_modules all-linear` to apply LoRA to all weight matrices. In Python, we can do this like so: + +```python +from peft import LoraConfig + +peft_config = LoraConfig(target_modules="all-linear") +``` + +### 2. *The adapter needs sufficient capacity to learn from the dataset* + +The blog post recommends using a sufficient LoRA rank to learn from the dataset. The rank determines the number of trainable parameters in the LoRA adapter. Therefore, "For datasets that exceed LoRA capacity, LoRA underperforms FullFT". + +![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/3.png) + +In the TRL script, we could use `--lora_r` to set the rank and adapt it based on the task and dataset we're training on. The blog post recommends the following ranks based on the task and dataset size: + +Reinforcement learning tasks typically require lower capacity, so smaller LoRA ranks can be used. This is because policy gradient algorithms extract roughly ~1 bit of information per episode, demanding minimal parameter capacity. + +The blog post defines the ideal dataset size for LoRA to match full fine-tuning as "Post-training scale". Which we can use to determine the recommended rank for SFT and RL LoRAs as: + +| Task Type | Dataset Size | Recommended Rank | +|-----------|-------------|------------------| +| **SFT** | Post-training scale | 256 | +| **RL** | Any size | 1-32 | + +### 3. *"FullFT and high-rank LoRAs have similar learning curves"* + +Counterintuitively, the blog post recommends using similar learning rates to full fine-tuning. In the TRL script, we could use `--learning_rate` to set the learning rate. The \\( \frac{1}{r} \\) scaling in LoRA makes the optimal learning rate approximately rank-independent. + +![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/2.png) + +### 4. *"In some scenarios, LoRA is less tolerant of large batch sizes than full fine-tuning."* + +The blog post recommends using an effective batch size < 32 because the authors found LoRA to be less tolerant of large batch sizes. This could not be mitigated by increasing the LoRA rank. In the TRL script, we could use `--per_device_train_batch_size` and `--gradient_accumulation_steps` to set the batch size. + +![learning rate](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lora_without_regret/4.png) + +## Takeaways + +Using TRL, you can efficiently implement LoRA adapters to match full fine-tuning performance, applying the core insights (targeting all weight matrices, choosing the right rank, and managing batch size and learning rate) without the heavy compute cost of FullFT. + +## Citation + +```bibtex +@article{schulman2025lora, + title = {{LoRA Without Regret}}, + author = {John Schulman and Thinking Machines Lab}, + year = 2025, + journal = {Thinking Machines Lab: Connectionism}, + doi = {10.64434/tml.20250929}, + note = {https://thinkingmachines.ai/blog/lora/} +} +``` From c1e7ad2696bbc3c006d8a1fd53ec5c90a1946f5a Mon Sep 17 00:00:00 2001 From: burtenshaw Date: Mon, 6 Oct 2025 08:23:11 +0200 Subject: [PATCH 074/153] [DOCS/FIX] lora without regrets - fix lr (#4207) --- docs/source/lora_without_regret.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/lora_without_regret.md b/docs/source/lora_without_regret.md index 04eed828d27..3dd2061ecaa 100644 --- a/docs/source/lora_without_regret.md +++ b/docs/source/lora_without_regret.md @@ -376,7 +376,7 @@ Here are the parameters we used to train the above models |----------------------------------|----------------------------------------------------|-------------------------------| | `--model_name_or_path` | HuggingFaceTB/SmolLM3-3B | HuggingFaceTB/SmolLM3-3B | | `--dataset_name` | HuggingFaceH4/OpenR1-Math-220k-default-verified | HuggingFaceH4/OpenR1-Math-220k-default-verified | -| `--learning_rate` | 1.0e-6 | 1.0e-5 | +| `--learning_rate` | 1.0e-5 | 1.0e-6 | | `--max_prompt_length` | 1024 | 1024 | | `--max_completion_length` | 4096 | 4096 | | `--lora_r` | 1 | - | From 5d34144b6f489dae0784f4f2cfe6b07c793878c6 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 6 Oct 2025 08:31:58 +0200 Subject: [PATCH 075/153] Remove custome_container for building the docs (#4198) --- .github/workflows/build_documentation.yml | 1 - .github/workflows/build_pr_documentation.yml | 1 - 2 files changed, 2 deletions(-) diff --git a/.github/workflows/build_documentation.yml b/.github/workflows/build_documentation.yml index d66349f6c85..5570c872b16 100644 --- a/.github/workflows/build_documentation.yml +++ b/.github/workflows/build_documentation.yml @@ -14,6 +14,5 @@ jobs: commit_sha: ${{ github.sha }} package: trl version_tag_suffix: "" - custom_container: huggingface/transformers-doc-builder secrets: hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} diff --git a/.github/workflows/build_pr_documentation.yml b/.github/workflows/build_pr_documentation.yml index 53134b68500..d8febd5d87e 100644 --- a/.github/workflows/build_pr_documentation.yml +++ b/.github/workflows/build_pr_documentation.yml @@ -16,4 +16,3 @@ jobs: pr_number: ${{ github.event.number }} package: trl version_tag_suffix: "" - custom_container: huggingface/transformers-doc-builder From ae2a0e71adbb2ea2b1c7feaec9e97ab8b8447122 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Mon, 6 Oct 2025 11:04:20 +0200 Subject: [PATCH 076/153] Remove tokenizer creation from `sft` example script (#4197) --- trl/scripts/sft.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/trl/scripts/sft.py b/trl/scripts/sft.py index 742cf29d741..96d3dd901e0 100644 --- a/trl/scripts/sft.py +++ b/trl/scripts/sft.py @@ -68,7 +68,7 @@ from accelerate import logging from datasets import load_dataset -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM from transformers.models.auto.modeling_auto import MODEL_FOR_IMAGE_TEXT_TO_TEXT_MAPPING_NAMES from trl import ( @@ -93,7 +93,7 @@ def main(script_args, training_args, model_args, dataset_args): ################ - # Model init kwargs & Tokenizer + # Model init kwargs ################ model_kwargs = dict( revision=model_args.model_revision, @@ -118,11 +118,6 @@ def main(script_args, training_args, model_args, dataset_args): else: model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs) - # Create tokenizer - tokenizer = AutoTokenizer.from_pretrained( - model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True - ) - # Load the dataset if dataset_args.datasets and script_args.dataset_name: logger.warning( @@ -145,7 +140,6 @@ def main(script_args, training_args, model_args, dataset_args): args=training_args, train_dataset=dataset[script_args.dataset_train_split], eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, - processing_class=tokenizer, peft_config=get_peft_config(model_args), ) From 6543f51a9dafa2427fdfbc5621b91aa0bb7ecffe Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 6 Oct 2025 11:13:21 +0200 Subject: [PATCH 077/153] Hotfix: Exclude transformers 4.57.0 for Python 3.9 (#4209) Co-authored-by: Sergio Paniego Blanco --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index cf82100e979..a0dabfe0908 100644 --- a/setup.cfg +++ b/setup.cfg @@ -30,6 +30,7 @@ install_requires = accelerate>=1.4.0 datasets>=3.0.0 transformers>=4.56.1 + transformers!=4.57.0; python_version == "3.9" [options.packages.find] exclude = From 8319ce0b75b59c7eb87236cb1b75882713a0dead Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 6 Oct 2025 11:14:54 +0200 Subject: [PATCH 078/153] Replace unittest with pytest (#4188) --- tests/slow/test_dpo_slow.py | 20 +- tests/slow/test_grpo_slow.py | 53 ++- tests/slow/test_sft_slow.py | 23 +- tests/test_activation_offloading.py | 15 +- tests/test_bco_trainer.py | 64 ++-- tests/test_best_of_n_sampler.py | 8 +- tests/test_callbacks.py | 53 ++- tests/test_cli.py | 14 +- tests/test_cli_utils.py | 149 +++++---- tests/test_collators.py | 5 +- tests/test_core.py | 11 +- tests/test_cpo_trainer.py | 22 +- tests/test_data_utils.py | 200 +++++------- tests/test_dataset_formatting.py | 86 +++-- tests/test_dpo_trainer.py | 198 +++++------- tests/test_gkd_trainer.py | 95 +++--- tests/test_grpo_trainer.py | 295 ++++++++--------- tests/test_judges.py | 27 +- tests/test_kto_trainer.py | 81 +++-- ...test_modeling_geometric_mixture_wrapper.py | 17 +- tests/test_modeling_value_head.py | 128 ++++---- tests/test_nash_md_trainer.py | 18 +- tests/test_online_dpo_trainer.py | 126 ++++---- tests/test_orpo_trainer.py | 18 +- tests/test_peft_models.py | 63 ++-- tests/test_ppo_trainer.py | 14 +- tests/test_prm_trainer.py | 93 +++--- tests/test_reward_trainer.py | 121 ++++--- tests/test_rewards.py | 21 +- tests/test_rich_progress_callback.py | 3 +- tests/test_rloo_trainer.py | 179 +++++----- tests/test_sft_trainer.py | 284 ++++++++-------- tests/test_trainers_args.py | 188 +++++------ tests/test_utils.py | 305 +++++++++--------- tests/test_vllm_client_server.py | 130 ++++---- tests/test_xpo_trainer.py | 18 +- tests/testing_utils.py | 104 ++---- 37 files changed, 1489 insertions(+), 1760 deletions(-) diff --git a/tests/slow/test_dpo_slow.py b/tests/slow/test_dpo_slow.py index 3b76fd8ea07..26feb388c6b 100644 --- a/tests/slow/test_dpo_slow.py +++ b/tests/slow/test_dpo_slow.py @@ -21,12 +21,12 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig -from transformers.testing_utils import backend_empty_cache, require_peft, require_torch_accelerator, torch_device +from transformers.testing_utils import backend_empty_cache, require_torch_accelerator, torch_device from transformers.utils import is_peft_available from trl import DPOConfig, DPOTrainer -from ..testing_utils import TrlTestCase, require_bitsandbytes +from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft from .testing_constants import DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST @@ -37,9 +37,8 @@ @pytest.mark.slow @require_torch_accelerator @require_peft -class DPOTrainerSlowTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestDPOTrainerSlow(TrlTestCase): + def setup_method(self): self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference") self.peft_config = LoraConfig( lora_alpha=16, @@ -50,11 +49,10 @@ def setUp(self): ) self.max_length = 128 - def tearDown(self): + def teardown_method(self): gc.collect() backend_empty_cache(torch_device) gc.collect() - super().tearDown() @parameterized.expand(list(itertools.product(MODELS_TO_TEST, DPO_LOSS_TYPES, DPO_PRECOMPUTE_LOGITS))) def test_dpo_bare_model(self, model_id, loss_type, pre_compute_logits): @@ -151,8 +149,8 @@ def test_dpo_peft_model(self, model_id, loss_type, pre_compute_logits, gradient_ peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) - self.assertIsNone(trainer.ref_model) + assert isinstance(trainer.model, PeftModel) + assert trainer.ref_model is None # train the model trainer.train() @@ -215,8 +213,8 @@ def test_dpo_peft_model_qlora(self, model_id, loss_type, pre_compute_logits, gra peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) - self.assertIsNone(trainer.ref_model) + assert isinstance(trainer.model, PeftModel) + assert trainer.ref_model is None # train the model trainer.train() diff --git a/tests/slow/test_grpo_slow.py b/tests/slow/test_grpo_slow.py index 69ba2b6ce34..17798d10b11 100644 --- a/tests/slow/test_grpo_slow.py +++ b/tests/slow/test_grpo_slow.py @@ -35,7 +35,6 @@ backend_empty_cache, require_flash_attn, require_liger_kernel, - require_peft, require_torch_accelerator, torch_device, ) @@ -44,7 +43,7 @@ from trl import GRPOConfig, GRPOTrainer from trl.trainer.utils import get_kbit_device_map -from ..testing_utils import TrlTestCase, require_bitsandbytes, require_vllm +from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft, require_vllm from .testing_constants import MODELS_TO_TEST @@ -54,18 +53,16 @@ @pytest.mark.slow @require_torch_accelerator -class GRPOTrainerSlowTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestGRPOTrainerSlow(TrlTestCase): + def setup_method(self): self.train_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") self.eval_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="test") self.max_length = 128 - def tearDown(self): + def teardown_method(self): gc.collect() backend_empty_cache(torch_device) gc.collect() - super().tearDown() @parameterized.expand(MODELS_TO_TEST) @require_liger_kernel @@ -103,7 +100,7 @@ def test_training_with_liger_grpo_loss(self, model_name): for n, param in previous_trainable_params.items(): new_param = model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." release_memory(model, trainer) @@ -153,20 +150,20 @@ def test_training_with_liger_grpo_loss_and_peft(self, model_name): # Verify PEFT adapter is properly initialized from peft import PeftModel - self.assertTrue(isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT") + assert isinstance(trainer.model, PeftModel), "Model should be wrapped with PEFT" # Store adapter weights before training previous_trainable_params = { n: param.clone() for n, param in trainer.model.named_parameters() if param.requires_grad } - self.assertTrue(len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model") + assert len(previous_trainable_params) > 0, "No trainable parameters found in PEFT model" trainer.train() # Verify adapter weights have changed after training for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." release_memory(model, trainer) @@ -199,12 +196,12 @@ def test_training_with_transformers_paged(self, model_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." release_memory(model, trainer) @@ -310,13 +307,13 @@ def reward_func(prompts, completions, **kwargs): peft_config=lora_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that LoRA parameters have changed # For VLM models, we're more permissive about which parameters can change @@ -328,7 +325,7 @@ def reward_func(prompts, completions, **kwargs): lora_params_changed = True # At least some LoRA parameters should have changed during training - self.assertTrue(lora_params_changed, "No LoRA parameters were updated during training.") + assert lora_params_changed, "No LoRA parameters were updated during training." except torch.OutOfMemoryError as e: self.skipTest(f"Skipping VLM training test due to insufficient GPU memory: {e}") @@ -378,8 +375,8 @@ def test_vlm_processor_vllm_colocate_mode(self): processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct", use_fast=True, padding_side="left") # Verify processor has both required attributes for VLM detection - self.assertTrue(hasattr(processor, "tokenizer")) - self.assertTrue(hasattr(processor, "image_processor")) + assert hasattr(processor, "tokenizer") + assert hasattr(processor, "image_processor") def dummy_reward_func(completions, **kwargs): return [1.0] * len(completions) @@ -438,16 +435,14 @@ def dummy_reward_func(completions, **kwargs): ) # Should detect VLM processor correctly and allow vLLM - self.assertTrue(trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode") - self.assertEqual(trainer.vllm_mode, "colocate", "Should use colocate mode") + assert trainer.use_vllm, "vLLM should be enabled for VLM processors in colocate mode" + assert trainer.vllm_mode == "colocate", "Should use colocate mode" # Check if signature columns were set properly if trainer._signature_columns is not None: # Should include 'image' in signature columns for VLM processors - self.assertIn( - "image", - trainer._signature_columns, - "Should include 'image' in signature columns for VLM", + assert "image" in trainer._signature_columns, ( + "Should include 'image' in signature columns for VLM" ) # Should not emit any warnings about VLM incompatibility @@ -457,10 +452,8 @@ def dummy_reward_func(completions, **kwargs): if "does not support VLMs" in str(w_item.message) or "not compatible" in str(w_item.message).lower() ] - self.assertEqual( - len(incompatibility_warnings), - 0, - f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}", + assert len(incompatibility_warnings) == 0, ( + f"Should not emit VLM incompatibility warnings, but got: {incompatibility_warnings}" ) # Test passes if we get this far without exceptions @@ -525,12 +518,12 @@ def test_training_vllm(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." except Exception as e: # If vLLM fails to initialize due to hardware constraints or other issues, that's expected diff --git a/tests/slow/test_sft_slow.py b/tests/slow/test_sft_slow.py index db762df107d..b6928b697c7 100755 --- a/tests/slow/test_sft_slow.py +++ b/tests/slow/test_sft_slow.py @@ -24,7 +24,6 @@ from transformers.testing_utils import ( backend_empty_cache, require_liger_kernel, - require_peft, require_torch_accelerator, require_torch_multi_accelerator, torch_device, @@ -33,7 +32,7 @@ from trl import SFTConfig, SFTTrainer -from ..testing_utils import TrlTestCase, require_bitsandbytes +from ..testing_utils import TrlTestCase, require_bitsandbytes, require_peft from .testing_constants import DEVICE_MAP_OPTIONS, GRADIENT_CHECKPOINTING_KWARGS, MODELS_TO_TEST, PACKING_OPTIONS @@ -44,9 +43,8 @@ @pytest.mark.slow @require_torch_accelerator @require_peft -class SFTTrainerSlowTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestSFTTrainerSlow(TrlTestCase): + def setup_method(self): self.train_dataset = load_dataset("stanfordnlp/imdb", split="train[:10%]") self.eval_dataset = load_dataset("stanfordnlp/imdb", split="test[:10%]") self.max_length = 128 @@ -58,11 +56,10 @@ def setUp(self): task_type="CAUSAL_LM", ) - def tearDown(self): + def teardown_method(self): gc.collect() backend_empty_cache(torch_device) gc.collect() - super().tearDown() @parameterized.expand(list(itertools.product(MODELS_TO_TEST, PACKING_OPTIONS))) def test_sft_trainer_str(self, model_name, packing): @@ -148,7 +145,7 @@ def test_sft_trainer_peft(self, model_name, packing): peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -252,7 +249,7 @@ def test_sft_trainer_transformers_mp_gc_peft(self, model_name, packing, gradient peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -332,7 +329,7 @@ def test_sft_trainer_transformers_mp_gc_peft_qlora(self, model_name, packing, gr peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -372,7 +369,7 @@ def test_sft_trainer_with_chat_format_qlora(self, model_name, packing): peft_config=self.peft_config, ) - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) trainer.train() @@ -447,11 +444,11 @@ def test_train_offloading(self, model_name, packing): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" release_memory(trainer.model, trainer) diff --git a/tests/test_activation_offloading.py b/tests/test_activation_offloading.py index a80563005f5..d1a9ea921f5 100644 --- a/tests/test_activation_offloading.py +++ b/tests/test_activation_offloading.py @@ -16,12 +16,12 @@ import torch from torch import nn from transformers import AutoModelForCausalLM -from transformers.testing_utils import require_peft, require_torch_accelerator, torch_device +from transformers.testing_utils import require_torch_accelerator, torch_device from transformers.utils import is_peft_available from trl.models.activation_offloading import NoOpManager, OffloadActivations -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): @@ -72,9 +72,8 @@ def test_offloading_with_peft_models(self) -> None: for name_orig, grad_orig in grads_original: for name_param, param in model.named_parameters(): if name_param == name_orig and param.requires_grad and param.grad is not None: - self.assertTrue( - torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), - f"Gradient mismatch for {name_orig}", + assert torch.allclose(grad_orig, param.grad, rtol=1e-4, atol=1e-5), ( + f"Gradient mismatch for {name_orig}" ) @require_torch_accelerator @@ -105,7 +104,7 @@ def test_noop_manager_with_offloading(self): # Gradients should match as NoOpManager should have prevented offloading for g1, g2 in zip(grads1, grads2): - self.assertTrue(torch.allclose(g1, g2, rtol=1e-4, atol=1e-5)) + assert torch.allclose(g1, g2, rtol=1e-4, atol=1e-5) @require_torch_accelerator def test_min_offload_size(self): @@ -152,6 +151,6 @@ def test_real_hf_model(self): grads2 = [p.grad.clone() for p in model.parameters()] # Check outputs and gradients match - self.assertTrue(torch.allclose(out1, out2, rtol=1e-5)) + assert torch.allclose(out1, out2, rtol=1e-5) for g1, g2 in zip(grads1, grads2): - self.assertTrue(torch.allclose(g1, g2, rtol=1e-5)) + assert torch.allclose(g1, g2, rtol=1e-5) diff --git a/tests/test_bco_trainer.py b/tests/test_bco_trainer.py index 68dbf7d9a30..7b7f0414438 100644 --- a/tests/test_bco_trainer.py +++ b/tests/test_bco_trainer.py @@ -14,25 +14,25 @@ from functools import partial +import pytest import torch from accelerate import Accelerator from datasets import load_dataset from parameterized import parameterized from transformers import AutoModel, AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import BCOConfig, BCOTrainer from trl.trainer.bco_trainer import _process_tokens, _tokenize -from .testing_utils import TrlTestCase, require_no_wandb, require_sklearn +from .testing_utils import TrlTestCase, require_no_wandb, require_peft, require_sklearn if is_peft_available(): from peft import LoraConfig -class BCOTrainerTester(TrlTestCase): +class TestBCOTrainer(TrlTestCase): @parameterized.expand( [ ("standard_preference",), @@ -71,13 +71,13 @@ def test_train(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn def test_train_with_precompute(self): @@ -108,13 +108,13 @@ def test_train_with_precompute(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn def test_train_eval(self): @@ -158,7 +158,7 @@ def test_init_with_ref_model_is_model(self): report_to="none", ) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): BCOTrainer( model=model, ref_model=model, # ref_model can't be the same as model @@ -196,13 +196,13 @@ def test_tokenize_and_process_tokens(self): batched=True, batch_size=2, ) - self.assertListEqual(tokenized_dataset["prompt"][:], dataset["prompt"][:]) - self.assertListEqual(tokenized_dataset["completion"][:], dataset["completion"][:]) - self.assertListEqual(tokenized_dataset["label"][:], dataset["label"][:]) - self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) - self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1]) + assert tokenized_dataset["prompt"][:] == dataset["prompt"][:] + assert tokenized_dataset["completion"][:] == dataset["completion"][:] + assert tokenized_dataset["label"][:] == dataset["label"][:] + assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert tokenized_dataset["answer_input_ids"][0] == [27261, 13] + assert tokenized_dataset["answer_attention_mask"][0] == [1, 1] fn_kwargs = { "prefix": "", @@ -214,14 +214,14 @@ def test_tokenize_and_process_tokens(self): "max_prompt_length": trainer.max_prompt_length, } processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs) - self.assertListEqual(processed_dataset["prompt"][:], dataset["prompt"][:]) - self.assertListEqual(processed_dataset["completion"][:], dataset["completion"][:]) - self.assertListEqual(processed_dataset["label"][:], dataset["label"][:]) - self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]) - self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1]) - self.assertListEqual(processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]) + assert processed_dataset["prompt"][:] == dataset["prompt"][:] + assert processed_dataset["completion"][:] == dataset["completion"][:] + assert processed_dataset["label"][:] == dataset["label"][:] + assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645] + assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1] + assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645] @require_sklearn def test_train_without_providing_ref_model(self): @@ -249,13 +249,13 @@ def test_train_without_providing_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn def test_train_udm(self): @@ -298,13 +298,13 @@ def embed_prompt(input_ids, attention_mask, model): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn @require_peft @@ -335,14 +335,14 @@ def test_train_without_providing_ref_model_with_lora(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param.cpu(), new_param.cpu())) + assert not torch.equal(param.cpu(), new_param.cpu()) @require_sklearn @require_no_wandb @@ -362,9 +362,9 @@ def test_generate_during_eval_no_wandb(self): report_to="none", ) - with self.assertRaisesRegex( + with pytest.raises( ValueError, - expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." " Please install `wandb` or `comet-ml` to resolve.", ): BCOTrainer( @@ -440,4 +440,4 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/tests/test_best_of_n_sampler.py b/tests/test_best_of_n_sampler.py index 471f75c0c7c..d52538c71d0 100644 --- a/tests/test_best_of_n_sampler.py +++ b/tests/test_best_of_n_sampler.py @@ -27,7 +27,7 @@ def queries_to_scores(list_of_strings): return [torch.rand(1).item() for _ in list_of_strings] -class BestOfNSamplerTester(TrlTestCase): +class TestBestOfNSampler(TrlTestCase): """ Tests the BestOfNSampler class """ @@ -74,8 +74,8 @@ def test_different_input_types(self): for q, expected_length in various_queries_formats: results = best_of_n.generate(q) - self.assertIsInstance(results, list) - self.assertEqual(len(results), expected_length) + assert isinstance(results, list) + assert len(results) == expected_length def test_different_sample_sizes_and_n_candidates_values(self): r""" @@ -110,4 +110,4 @@ def test_different_sample_sizes_and_n_candidates_values(self): tokenized_queries = [self.tokenizer.encode(query) for query in queries] results = best_of_n.generate(tokenized_queries) for result in results: - self.assertEqual(len(result), expected) + assert len(result) == expected diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py index 7904a4ae374..316c9b35ae1 100644 --- a/tests/test_callbacks.py +++ b/tests/test_callbacks.py @@ -18,11 +18,10 @@ from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig, Trainer, TrainingArguments -from transformers.testing_utils import require_peft, require_wandb +from transformers.testing_utils import require_wandb from transformers.trainer_utils import get_last_checkpoint from transformers.utils import is_peft_available -from tests.testing_utils import require_comet, require_mergekit from trl import ( BasePairwiseJudge, BEMACallback, @@ -34,7 +33,7 @@ ) from trl.mergekit_utils import MergeConfig -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_comet, require_mergekit, require_peft if is_peft_available(): @@ -66,9 +65,8 @@ def __init__(self, model, ref_model, args, train_dataset, eval_dataset, processi self.ref_model = ref_model -class WinRateCallbackTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestWinRateCallback(TrlTestCase): + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -119,7 +117,7 @@ def test_basic(self): trainer.train() winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] for history_row, expected_row in zip(winrate_history, self.expected_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) def test_without_ref_model(self): # Same as before, but without the ref_model attribute. It should use the model attribute instead @@ -145,7 +143,7 @@ def test_without_ref_model(self): trainer.train() winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] for history_row, expected_row in zip(winrate_history, self.expected_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) def test_soft_judge(self): """Test that the soft judge functionality works correctly""" @@ -188,7 +186,7 @@ def test_soft_judge(self): if "eval_avg_win_prob" in h ] for history_row, expected_row in zip(winrate_history, expected_soft_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) @require_peft def test_lora(self): @@ -222,12 +220,11 @@ def test_lora(self): trainer.train() winrate_history = [h for h in trainer.state.log_history if "eval_win_rate" in h] for history_row, expected_row in zip(winrate_history, self.expected_winrates): - self.assertTrue(all(key in history_row and history_row[key] == expected_row[key] for key in expected_row)) + assert all(key in history_row and history_row[key] == expected_row[key] for key in expected_row) -class LogCompletionsCallbackTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestLogCompletionsCallback(TrlTestCase): + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer.pad_token = self.tokenizer.eos_token @@ -273,12 +270,12 @@ def test_basic_wandb(self): completions = json.load(f) # Check that the columns are correct - self.assertIn("step", completions["columns"]) - self.assertIn("prompt", completions["columns"]) - self.assertIn("completion", completions["columns"]) + assert "step" in completions["columns"] + assert "prompt" in completions["columns"] + assert "completion" in completions["columns"] # Check that the prompt is in the log - self.assertIn(self.dataset["test"][0]["prompt"], completions["data"][0]) + assert self.dataset["test"][0]["prompt"] in completions["data"][0] @require_comet def test_basic_comet(self): @@ -320,9 +317,8 @@ def test_basic_comet(self): @require_mergekit -class MergeModelCallbackTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestMergeModelCallback(TrlTestCase): + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") @@ -347,7 +343,7 @@ def test_callback(self): trainer.train() last_checkpoint = get_last_checkpoint(self.tmp_dir) merged_path = os.path.join(last_checkpoint, "merged") - self.assertTrue(os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint.") + assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint." def test_every_checkpoint(self): training_args = DPOConfig( @@ -374,12 +370,11 @@ def test_every_checkpoint(self): for checkpoint in checkpoints: merged_path = os.path.join(checkpoint, "merged") - self.assertTrue(os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}.") + assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}." -class BEMACallbackTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestBEMACallback(TrlTestCase): + def setup_method(self): self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.tokenizer.pad_token = self.tokenizer.eos_token @@ -409,7 +404,7 @@ def test_model_saved(self): # Check that the BEMA model was saved and can be loaded bema_path = os.path.join(self.tmp_dir, "bema") - self.assertTrue(os.path.isdir(bema_path), "BEMA directory was not created") + assert os.path.isdir(bema_path), "BEMA directory was not created" AutoModelForCausalLM.from_pretrained(bema_path) def test_update_frequency_0(self): @@ -430,7 +425,7 @@ def test_update_frequency_0(self): # Total 9 steps (17 samples, batch size 8, 3 epochs). # BEMA starts after step 0 and updates every 2 steps → updates at 2, 4, 5, 8 - self.assertEqual(mock_update.call_args_list, [call(2), call(4), call(6), call(8)]) + assert mock_update.call_args_list == [call(2), call(4), call(6), call(8)] def test_update_frequency_1(self): """Test that BEMA callback respects the update frequency.""" @@ -450,7 +445,7 @@ def test_update_frequency_1(self): # Total 9 steps (17 samples, batch size 8, 3 epochs). # BEMA starts after step 0 and updates every 3 steps → updates at 3, 6, 9 - self.assertEqual(mock_update.call_args_list, [call(3), call(6), call(9)]) + assert mock_update.call_args_list == [call(3), call(6), call(9)] def test_update_frequency_2(self): """Test that BEMA callback respects the update frequency.""" @@ -470,7 +465,7 @@ def test_update_frequency_2(self): # Total 9 steps (17 samples, batch size 8, 3 epochs). # BEMA starts after step 3 and updates every 2 steps → updates at 5, 7, 9 - self.assertEqual(mock_update.call_args_list, [call(5), call(7), call(9)]) + assert mock_update.call_args_list == [call(5), call(7), call(9)] def test_no_bema(self): """Test that BEMACallback works without BEMA updates.""" diff --git a/tests/test_cli.py b/tests/test_cli.py index 23b5d6bcff7..48087f5054c 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -15,18 +15,18 @@ import os import sys -import unittest from io import StringIO from unittest.mock import patch +import pytest import yaml from .testing_utils import TrlTestCase -@unittest.skipIf( +@pytest.mark.skipif( sys.version_info < (3, 10), - "Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " + reason="Transformers' generation codebase uses a Python >3.10 syntax (`str | None`), which seems to cause the CLI tests " "to fail on Python <3.10.", # let's say it's a known issue, but not expected to be fixed, because too niche ) class TestCLI(TrlTestCase): @@ -51,7 +51,7 @@ def test_env(self, mock_stdout): command = "trl env" with patch("sys.argv", command.split(" ")): main() - self.assertIn("TRL version: ", mock_stdout.getvalue().strip()) + assert "TRL version: " in mock_stdout.getvalue().strip() def test_grpo(self): from trl.cli import main @@ -112,8 +112,4 @@ def test_sft_config_file(self): main() # Verify that output directory was created - self.assertTrue(os.path.exists(output_dir)) - - -if __name__ == "__main__": - unittest.main() + assert os.path.exists(output_dir) diff --git a/tests/test_cli_utils.py b/tests/test_cli_utils.py index 271dd6f5e5b..a7cda2fddf8 100644 --- a/tests/test_cli_utils.py +++ b/tests/test_cli_utils.py @@ -13,10 +13,10 @@ # limitations under the License. import tempfile -import unittest from dataclasses import dataclass from unittest.mock import mock_open, patch +import pytest from datasets import DatasetDict, load_dataset from trl import DatasetMixtureConfig, TrlParser, get_dataset @@ -40,13 +40,12 @@ class TestTrlParser(TrlTestCase): def test_init_without_config_field(self): """Test initialization without 'config' field in the dataclasses.""" parser = TrlParser(dataclass_types=[MyDataclass]) - self.assertIsInstance(parser, TrlParser) + assert isinstance(parser, TrlParser) def test_init_with_config_field(self): """Test initialization with a 'config' field in the dataclass (should raise ValueError).""" - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="has a field named 'config'"): TrlParser(dataclass_types=[InvalidDataclass]) - self.assertTrue("has a field named 'config'" in str(context.exception)) @patch("builtins.open", mock_open(read_data="env:\n VAR1: value1\n VAR2: value2\narg1: 2")) @patch("yaml.safe_load") @@ -67,14 +66,14 @@ def test_parse_args_and_config_with_valid_config(self, mock_environ, mock_yaml_l mock_environ["VAR2"] = "value2" # Ensure that the environment variables were set correctly - self.assertEqual(mock_environ.get("VAR1"), "value1") - self.assertEqual(mock_environ.get("VAR2"), "value2") + assert mock_environ.get("VAR1") == "value1" + assert mock_environ.get("VAR2") == "value2" # Check the parsed arguments - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[0].arg2, "value") + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" @patch("builtins.open", mock_open(read_data="arg1: 2")) @patch("yaml.safe_load") @@ -90,9 +89,9 @@ def test_parse_args_and_arg_override_config(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Check the parsed arguments - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 3) + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 3 @patch("builtins.open", mock_open(read_data="env: not_a_dict")) @patch("yaml.safe_load") @@ -104,11 +103,9 @@ def test_parse_args_and_config_with_invalid_env(self, mock_yaml_load): args = ["--arg1", "2", "--arg2", "value", "--config", "config.yaml"] - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="`env` field should be a dict in the YAML file."): parser.parse_args_and_config(args) - self.assertEqual(str(context.exception), "`env` field should be a dict in the YAML file.") - def test_parse_args_and_config_without_config(self): """Test parse_args_and_config without the `--config` argument.""" parser = TrlParser(dataclass_types=[MyDataclass]) @@ -119,10 +116,10 @@ def test_parse_args_and_config_without_config(self): result_args = parser.parse_args_and_config(args) # Check that the arguments are parsed as is - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[0].arg2, "value") + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" def test_set_defaults_with_config(self): """Test set_defaults_with_config updates the defaults.""" @@ -133,9 +130,9 @@ def test_set_defaults_with_config(self): # Ensure the default value is updated result_args = parser.parse_args_and_config([]) - self.assertEqual(len(result_args), 1) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 42) + assert len(result_args) == 1 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 42 def test_parse_args_and_config_with_remaining_strings(self): parser = TrlParser(dataclass_types=[MyDataclass]) @@ -146,11 +143,11 @@ def test_parse_args_and_config_with_remaining_strings(self): result_args = parser.parse_args_and_config(args, return_remaining_strings=True) # Check that the arguments are parsed as is - self.assertEqual(len(result_args), 2) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[0].arg2, "value") - self.assertEqual(result_args[1], ["remaining"]) + assert len(result_args) == 2 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[0].arg2 == "value" + assert result_args[1] == ["remaining"] @patch("builtins.open", mock_open(read_data="remaining_string_in_config: abc")) @patch("yaml.safe_load") @@ -165,10 +162,10 @@ def test_parse_args_and_config_with_remaining_strings_in_config_and_args(self, m result_args = parser.parse_args_and_config(args, return_remaining_strings=True) # Check that the arguments are parsed as is - self.assertEqual(len(result_args), 2) - self.assertIsInstance(result_args[0], MyDataclass) - self.assertEqual(result_args[0].arg1, 2) - self.assertEqual(result_args[1], ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"]) + assert len(result_args) == 2 + assert isinstance(result_args[0], MyDataclass) + assert result_args[0].arg1 == 2 + assert result_args[1] == ["--remaining_string_in_config", "abc", "--remaining_string_in_args", "def"] @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) @patch("yaml.safe_load") @@ -190,11 +187,11 @@ def test_subparsers_with_config_defaults(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Check main parser arguments - self.assertEqual(len(result_args), 1) + assert len(result_args) == 1 # Check that config values were applied to the subparser - self.assertEqual(result_args[0].arg1, 2) # Default from config - self.assertEqual(result_args[0].arg2, "config_value") # Default from config + assert result_args[0].arg1 == 2 # Default from config + assert result_args[0].arg2 == "config_value" # Default from config @patch("builtins.open", mock_open(read_data="arg1: 2\narg2: config_value")) @patch("yaml.safe_load") @@ -216,8 +213,8 @@ def test_subparsers_with_config_defaults_and_arg_override(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Command line arguments should override config - self.assertEqual(result_args[0].arg1, 3) - self.assertEqual(result_args[0].arg2, "config_value") # Still from config + assert result_args[0].arg1 == 3 + assert result_args[0].arg2 == "config_value" # Still from config @patch("builtins.open", mock_open(read_data="arg1: 2\nthis_arg_does_not_exist: config_value")) @patch("yaml.safe_load") @@ -236,7 +233,7 @@ def test_subparsers_with_config_defaults_and_arg_override_wrong_name(self, mock_ # Test with command line arguments overriding config args = ["subcommand", "--arg1", "3", "--config", "config.yaml"] - with self.assertRaises(ValueError): + with pytest.raises(ValueError): parser.parse_args_and_config(args) parser.parse_args_and_config(args, fail_with_unknown_args=False) @@ -263,21 +260,21 @@ def test_subparsers_multiple_with_config_defaults(self, mock_yaml_load): result_args = parser.parse_args_and_config(args) # Check main parser arguments - self.assertEqual(len(result_args), 1) + assert len(result_args) == 1 # Check that config values were applied to the subparser - self.assertEqual(result_args[0].arg1, 2) # Default from config - self.assertEqual(result_args[0].arg2, "config_value") # Default from config + assert result_args[0].arg1 == 2 # Default from config + assert result_args[0].arg2 == "config_value" # Default from config -class TestGetDataset(unittest.TestCase): +class TestGetDataset: def test_single_dataset_with_config(self): mixture_config = DatasetMixtureConfig( datasets=[DatasetConfig(path="trl-internal-testing/zen", name="standard_language_modeling")] ) result = get_dataset(mixture_config) expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") - self.assertEqual(expected["train"][:], result["train"][:]) + assert expected["train"][:] == result["train"][:] def test_single_dataset_preference_config(self): mixture_config = DatasetMixtureConfig( @@ -285,7 +282,7 @@ def test_single_dataset_preference_config(self): ) result = get_dataset(mixture_config) expected = load_dataset("trl-internal-testing/zen", "standard_preference") - self.assertEqual(expected["train"][:], result["train"][:]) + assert expected["train"][:] == result["train"][:] def test_single_dataset_streaming(self): mixture_config = DatasetMixtureConfig( @@ -294,7 +291,7 @@ def test_single_dataset_streaming(self): ) result = get_dataset(mixture_config) expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") - self.assertEqual(expected["train"].to_list(), list(result["train"])) + assert expected["train"].to_list() == list(result["train"]) def test_dataset_mixture_basic(self): dataset_config1 = DatasetConfig( @@ -305,15 +302,15 @@ def test_dataset_mixture_basic(self): ) mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) + assert isinstance(result, DatasetDict) + assert "train" in result train_dataset = result["train"] - self.assertEqual(train_dataset.column_names, ["prompt"]) + assert train_dataset.column_names == ["prompt"] prompts = train_dataset["prompt"] expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") - self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"]) + assert prompts[: len(prompts) // 2] == expected_first_half["prompt"] expected_second_half = load_dataset("trl-internal-testing/zen", "standard_prompt_completion", split="train") - self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"]) + assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"] def test_dataset_mixture_with_weights(self): dataset_config1 = DatasetConfig( @@ -324,17 +321,17 @@ def test_dataset_mixture_with_weights(self): ) mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) + assert isinstance(result, DatasetDict) + assert "train" in result train_dataset = result["train"] - self.assertEqual(train_dataset.column_names, ["prompt"]) + assert train_dataset.column_names == ["prompt"] prompts = train_dataset["prompt"] expected_first_half = load_dataset("trl-internal-testing/zen", "standard_preference", split="train[:50%]") - self.assertEqual(prompts[: len(prompts) // 2], expected_first_half["prompt"]) + assert prompts[: len(prompts) // 2] == expected_first_half["prompt"] expected_second_half = load_dataset( "trl-internal-testing/zen", "standard_prompt_completion", split="train[:50%]" ) - self.assertEqual(prompts[len(prompts) // 2 :], expected_second_half["prompt"]) + assert prompts[len(prompts) // 2 :] == expected_second_half["prompt"] def test_dataset_mixture_with_test_split(self): mixture_config = DatasetMixtureConfig( @@ -342,20 +339,18 @@ def test_dataset_mixture_with_test_split(self): test_split_size=2, ) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) - self.assertIn("test", result) - self.assertEqual(len(result["train"]), 15) - self.assertEqual(len(result["test"]), 2) + assert isinstance(result, DatasetDict) + assert "train" in result + assert "test" in result + assert len(result["train"]) == 15 + assert len(result["test"]) == 2 def test_empty_dataset_mixture_raises_error(self): mixture_config = DatasetMixtureConfig(datasets=[]) - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="No datasets were loaded"): get_dataset(mixture_config) - self.assertIn("No datasets were loaded", str(context.exception)) - def test_mixture_multiple_different_configs(self): dataset_config1 = DatasetConfig( path="trl-internal-testing/zen", name="conversational_preference", split="train", columns=["prompt"] @@ -365,9 +360,9 @@ def test_mixture_multiple_different_configs(self): ) mixture_config = DatasetMixtureConfig(datasets=[dataset_config1, dataset_config2]) result = get_dataset(mixture_config) - self.assertIsInstance(result, DatasetDict) - self.assertIn("train", result) - self.assertGreater(len(result["train"]), 0) + assert isinstance(result, DatasetDict) + assert "train" in result + assert len(result["train"]) > 0 def test_trlparser_parses_yaml_config_correctly(self): # Prepare YAML content exactly like your example @@ -390,24 +385,24 @@ def test_trlparser_parses_yaml_config_correctly(self): args = parser.parse_args_and_config(args=["--config", tmpfile.name])[0] # Assert that we got DatasetMixtureConfig instance - self.assertIsInstance(args, DatasetMixtureConfig) + assert isinstance(args, DatasetMixtureConfig) # Assert datasets list length - self.assertEqual(len(args.datasets), 2) + assert len(args.datasets) == 2 # Check first dataset dataset_config1 = args.datasets[0] - self.assertIsInstance(dataset_config1, DatasetConfig) - self.assertEqual(dataset_config1.path, "trl-internal-testing/zen") - self.assertEqual(dataset_config1.name, "standard_prompt_only") - self.assertIsNone(dataset_config1.columns) # No columns specified + assert isinstance(dataset_config1, DatasetConfig) + assert dataset_config1.path == "trl-internal-testing/zen" + assert dataset_config1.name == "standard_prompt_only" + assert dataset_config1.columns is None # No columns specified # Check second dataset dataset_config2 = args.datasets[1] - self.assertIsInstance(dataset_config2, DatasetConfig) - self.assertEqual(dataset_config2.path, "trl-internal-testing/zen") - self.assertEqual(dataset_config2.name, "standard_preference") - self.assertEqual(dataset_config2.columns, ["prompt"]) # Columns specified + assert isinstance(dataset_config2, DatasetConfig) + assert dataset_config2.path == "trl-internal-testing/zen" + assert dataset_config2.name == "standard_preference" + assert dataset_config2.columns == ["prompt"] # Columns specified def test_trlparser_parses_yaml_and_loads_dataset(self): # Prepare YAML content exactly like your example @@ -428,4 +423,4 @@ def test_trlparser_parses_yaml_and_loads_dataset(self): # Load the dataset using get_dataset result = get_dataset(args) expected = load_dataset("trl-internal-testing/zen", "standard_language_modeling") - self.assertEqual(expected["train"][:], result["train"][:]) + assert expected["train"][:] == result["train"][:] diff --git a/tests/test_collators.py b/tests/test_collators.py index b578758f027..3159184f558 100644 --- a/tests/test_collators.py +++ b/tests/test_collators.py @@ -21,12 +21,11 @@ class TestDataCollatorForPreference(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.collator = DataCollatorForPreference(pad_token_id=0) def assertTensorEqual(self, tensor1, tensor2): - self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}") + assert torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}" def test_padding_behavior(self): examples = [ diff --git a/tests/test_core.py b/tests/test_core.py index bab69ca9da2..85d99615be9 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -20,22 +20,21 @@ from .testing_utils import TrlTestCase -class CoreTester(TrlTestCase): +class TestCore(TrlTestCase): """ A wrapper class for testing core utils functions """ - def setUp(self): - super().setUp() + def setup_method(self): self.test_input = torch.Tensor([1, 2, 3, 4]) self.test_mask = torch.Tensor([0, 1, 1, 0]) self.test_input_unmasked = self.test_input[1:3] def test_masked_mean(self): - self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask)) + assert torch.mean(self.test_input_unmasked) == masked_mean(self.test_input, self.test_mask) def test_masked_var(self): - self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask)) + assert torch.var(self.test_input_unmasked) == masked_var(self.test_input, self.test_mask) def test_masked_whiten(self): def whiten(values: torch.Tensor) -> torch.Tensor: @@ -45,4 +44,4 @@ def whiten(values: torch.Tensor) -> torch.Tensor: whiten_unmasked = whiten(self.test_input_unmasked) whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] diffs = (whiten_unmasked - whiten_masked).sum() - self.assertLess(abs(diffs.item()), 0.00001) + assert abs(diffs.item()) < 0.00001 diff --git a/tests/test_cpo_trainer.py b/tests/test_cpo_trainer.py index cc3e394846d..56792f608dc 100644 --- a/tests/test_cpo_trainer.py +++ b/tests/test_cpo_trainer.py @@ -17,17 +17,15 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_peft from trl import CPOConfig, CPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft -class CPOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestCPOTrainer(TrlTestCase): + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) @@ -87,13 +85,13 @@ def test_cpo_trainer(self, name, loss_type, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @parameterized.expand( [ @@ -143,14 +141,14 @@ def test_cpo_trainer_with_lora(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_compute_metrics(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") @@ -180,7 +178,7 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 def test_alphapo_trainer(self): training_args = CPOConfig( @@ -212,9 +210,9 @@ def test_alphapo_trainer(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index 0a9eba7f7bb..8fe8a24bd50 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -15,7 +15,6 @@ import copy import itertools import textwrap -import unittest from time import strftime from datasets import Dataset, DatasetDict @@ -40,7 +39,7 @@ from .testing_utils import TrlTestCase -class PrepareMultimodalMessagesTester(unittest.TestCase): +class TestPrepareMultimodalMessages: def test_basic_user_assistant_conversation(self): """Test basic conversation with user and assistant messages.""" messages = [ @@ -55,7 +54,7 @@ def test_basic_user_assistant_conversation(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_first_user_message_gets_image(self): """Test that only the first user message gets an image placeholder.""" @@ -73,7 +72,7 @@ def test_first_user_message_gets_image(self): {"role": "user", "content": [{"type": "text", "text": "How about the grass?"}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_multiple_images(self): """Test that multiple images are added to the first user message.""" @@ -97,7 +96,7 @@ def test_multiple_images(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_system_message_transformation(self): """Test that system messages are properly transformed.""" @@ -113,7 +112,7 @@ def test_system_message_transformation(self): {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}, ] - self.assertEqual(messages, expected) + assert messages == expected def test_already_prepared_messages_unchanged(self): """Test that messages with list content are not modified.""" @@ -126,7 +125,7 @@ def test_already_prepared_messages_unchanged(self): original = copy.deepcopy(messages) prepare_multimodal_messages(messages, num_images=1) - self.assertEqual(messages, original) + assert messages == original def test_mixed_prepared_and_unprepared_messages(self): """Test handling of mixed prepared and unprepared messages.""" @@ -144,10 +143,10 @@ def test_mixed_prepared_and_unprepared_messages(self): {"role": "user", "content": [{"type": "text", "text": "What about the grass?"}]}, ] - self.assertEqual(messages, expected) + assert messages == expected -class IsConversationalTester(TrlTestCase): +class TestIsConversational(TrlTestCase): conversational_examples = [ { # Language modeling "messages": [ @@ -250,14 +249,14 @@ class IsConversationalTester(TrlTestCase): @parameterized.expand(itertools.product(conversational_examples)) def test_conversational(self, example): - self.assertTrue(is_conversational(example)) + assert is_conversational(example) @parameterized.expand(itertools.product(non_conversational_examples)) def test_non_conversational(self, example): - self.assertFalse(is_conversational(example)) + assert not is_conversational(example) -class IsConversationalFromValueTester(TrlTestCase): +class TestIsConversationalFromValue(TrlTestCase): def test_positive_1(self): example = { "conversations": [ @@ -265,7 +264,7 @@ def test_positive_1(self): {"from": "assistant", "value": "It is blue."}, ], } - self.assertTrue(is_conversational_from_value(example)) + assert is_conversational_from_value(example) def test_negative_1(self): example = { @@ -274,14 +273,14 @@ def test_negative_1(self): {"role": "assistant", "content": "It is blue."}, ], } - self.assertFalse(is_conversational_from_value(example)) + assert not is_conversational_from_value(example) def test_negative_2(self): example = {"text": "The sky is blue."} - self.assertFalse(is_conversational_from_value(example)) + assert not is_conversational_from_value(example) -class ApplyChatTemplateTester(TrlTestCase): +class TestApplyChatTemplate(TrlTestCase): tokenizers = [ "trl-internal-testing/tiny-CohereForCausalLM", "trl-internal-testing/tiny-DbrxForCausalLM", @@ -352,24 +351,24 @@ def test_apply_chat_template(self, tokenizer_id, example): result = apply_chat_template(example, tokenizer) # Checking if the result is a dictionary - self.assertIsInstance(result, dict) + assert isinstance(result, dict) # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: - self.assertIn(key, result) - self.assertIsInstance(result[key], str) + assert key in result + assert isinstance(result[key], str) # Exception for messages, the key is "text" once the chat template is applied if "messages" in example: - self.assertIn("text", result) - self.assertIsInstance(result["text"], str) + assert "text" in result + assert isinstance(result["text"], str) # The label should be kept if "label" in example: - self.assertIn("label", result) - self.assertIsInstance(result["label"], bool) - self.assertEqual(result["label"], example["label"]) + assert "label" in result + assert isinstance(result["label"], bool) + assert result["label"] == example["label"] # both conversational and non-conversational examples @parameterized.expand(itertools.product(tokenizers, conversational_examples + non_conversational_examples)) @@ -378,24 +377,24 @@ def test_maybe_apply_chat_template(self, tokenizer_id, example): result = maybe_apply_chat_template(example, tokenizer) # Checking if the result is a dictionary - self.assertIsInstance(result, dict) + assert isinstance(result, dict) # The chat template should be applied to the following keys for key in ["prompt", "chosen", "rejected", "completion"]: if key in example: - self.assertIn(key, result) - self.assertIsInstance(result[key], str) + assert key in result + assert isinstance(result[key], str) # Exception for messages, the key is "text" once the chat template is applied if "messages" in example: - self.assertIn("text", result) - self.assertIsInstance(result["text"], str) + assert "text" in result + assert isinstance(result["text"], str) # The label should be kept if "label" in example: - self.assertIn("label", result) - self.assertIsInstance(result["label"], bool) - self.assertEqual(result["label"], example["label"]) + assert "label" in result + assert isinstance(result["label"], bool) + assert result["label"] == example["label"] def test_apply_chat_template_with_tools(self): tokenizer = AutoProcessor.from_pretrained("trl-internal-testing/tiny-LlamaForCausalLM-3.2") @@ -420,16 +419,16 @@ def get_current_temperature(location: str): result_with_tools = apply_chat_template(test_case, tokenizer, tools=[get_current_temperature]) # Verify tools are included in the output - self.assertIn("get_current_temperature", result_with_tools["prompt"]) + assert "get_current_temperature" in result_with_tools["prompt"] # Test without tools result_without_tools = apply_chat_template(test_case, tokenizer, tools=None) # Verify tools are not included in the output - self.assertNotIn("get_current_temperature", result_without_tools["prompt"]) + assert "get_current_temperature" not in result_without_tools["prompt"] -class ApplyChatTemplateHarmonyTester(TrlTestCase): +class TestApplyChatTemplateHarmony(TrlTestCase): def test_language_modeling(self): messages = { "messages": [ @@ -459,7 +458,7 @@ def test_language_modeling(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>""") - self.assertEqual(output["text"], expected) + assert output["text"] == expected def test_prompt_only(self): messages = { @@ -489,7 +488,7 @@ def test_prompt_only(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") - self.assertEqual(output["prompt"], expected) + assert output["prompt"] == expected def test_prompt_completion(self): messages = { @@ -523,8 +522,8 @@ def test_prompt_completion(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" - self.assertEqual(output["prompt"], expected_prompt) - self.assertEqual(output["completion"], expected_completion) + assert output["prompt"] == expected_prompt + assert output["completion"] == expected_completion def test_preference(self): messages = { @@ -562,9 +561,9 @@ def test_preference(self): expected_chosen = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" expected_rejected = "<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>" - self.assertEqual(output["prompt"], expected_prompt) - self.assertEqual(output["chosen"], expected_chosen) - self.assertEqual(output["rejected"], expected_rejected) + assert output["prompt"] == expected_prompt + assert output["chosen"] == expected_chosen + assert output["rejected"] == expected_rejected def test_preference_with_implicit_prompt(self): messages = { @@ -614,8 +613,8 @@ def test_preference_with_implicit_prompt(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant<|channel|>analysis<|message|>The user asks the color of the tree...<|end|><|start|>assistant<|channel|>final<|message|>It is green.<|return|>""") - self.assertEqual(output["chosen"], expected_chosen) - self.assertEqual(output["rejected"], expected_rejected) + assert output["chosen"] == expected_chosen + assert output["rejected"] == expected_rejected def test_unpaired_preference(self): messages = { @@ -650,12 +649,12 @@ def test_unpaired_preference(self): <|end|><|start|>user<|message|>What color is the sky?<|end|><|start|>assistant""") expected_completion = "<|channel|>analysis<|message|>The user asks the color of the sky...<|end|><|start|>assistant<|channel|>final<|message|>It is blue.<|return|>" - self.assertEqual(output["prompt"], expected_prompt) - self.assertEqual(output["completion"], expected_completion) - self.assertTrue(output["label"]) + assert output["prompt"] == expected_prompt + assert output["completion"] == expected_completion + assert output["label"] -class UnpairPreferenceDatasetTester(TrlTestCase): +class TestUnpairPreferenceDataset(TrlTestCase): paired_dataset = Dataset.from_dict( { "prompt": ["The sky is", "The sun is"], @@ -675,61 +674,49 @@ class UnpairPreferenceDatasetTester(TrlTestCase): def test_unpair_preference_dataset(self): # Test that a paired dataset is correctly converted to unpaired unpaired_dataset = unpair_preference_dataset(self.paired_dataset) - self.assertEqual( - unpaired_dataset.to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", + assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), ( + "The paired dataset should be converted to unpaired." ) def test_unpair_preference_dataset_dict(self): # Test that a paired dataset dict is correctly converted to unpaired paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = unpair_preference_dataset(paired_dataset_dict) - self.assertEqual( - unpaired_dataset_dict["abc"].to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", + assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), ( + "The paired dataset should be converted to unpaired." ) def test_maybe_unpair_preference_dataset(self): # Test that a paired dataset is correctly converted to unpaired with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.paired_dataset) - self.assertEqual( - unpaired_dataset.to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", + assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), ( + "The paired dataset should be converted to unpaired." ) def test_maybe_unpair_preference_dataset_dict(self): # Test that a paired dataset dict is correctly converted to unpaired with maybe_unpair_preference_dataset paired_dataset_dict = DatasetDict({"abc": self.paired_dataset}) unpaired_dataset_dict = maybe_unpair_preference_dataset(paired_dataset_dict) - self.assertEqual( - unpaired_dataset_dict["abc"].to_dict(), - self.unpaired_dataset.to_dict(), - "The paired dataset should be converted to unpaired.", + assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), ( + "The paired dataset should be converted to unpaired." ) def test_maybe_unpair_preference_dataset_already_paired(self): # Test that a paired dataset remains unchanged with maybe_unpair_preference_dataset unpaired_dataset = maybe_unpair_preference_dataset(self.unpaired_dataset) - self.assertEqual( - unpaired_dataset.to_dict(), - self.unpaired_dataset.to_dict(), - "The unpaired dataset should remain unchanged.", + assert unpaired_dataset.to_dict() == self.unpaired_dataset.to_dict(), ( + "The unpaired dataset should remain unchanged." ) def test_maybe_unpair_preference_dataset_dict_already_paired(self): # Test that a paired dataset dict remains unchanged with maybe_unpair_preference_dataset unpaired_dataset_dict = maybe_unpair_preference_dataset(DatasetDict({"abc": self.unpaired_dataset})) - self.assertEqual( - unpaired_dataset_dict["abc"].to_dict(), - self.unpaired_dataset.to_dict(), - "The unpaired dataset should remain unchanged.", + assert unpaired_dataset_dict["abc"].to_dict() == self.unpaired_dataset.to_dict(), ( + "The unpaired dataset should remain unchanged." ) -class ExtractPromptTester(TrlTestCase): +class TestExtractPrompt(TrlTestCase): example_implicit_prompt_conversational = { "chosen": [ {"role": "user", "content": "What color is the sky?"}, @@ -767,56 +754,42 @@ class ExtractPromptTester(TrlTestCase): def test_extract_prompt_conversational(self): # Test that the prompt is correctly extracted from the dataset example_extracted_prompt = extract_prompt(self.example_implicit_prompt_conversational) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_conversational, - "The prompt is not correctly extracted from the dataset.", + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( + "The prompt is not correctly extracted from the dataset." ) def test_maybe_extract_prompt_conversational(self): # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_conversational) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_conversational, - "The prompt is not correctly extracted from the dataset.", + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( + "The prompt is not correctly extracted from the dataset." ) def test_maybe_extract_prompt_conversational_already_explicit(self): # Test that the prompt remains unchanged with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_conversational) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_conversational, - "The prompt should remain unchanged.", + assert example_extracted_prompt == self.example_explicit_prompt_conversational, ( + "The prompt should remain unchanged." ) def test_extract_prompt_standard(self): # Test that the prompt is correctly extracted from the dataset example_extracted_prompt = extract_prompt(self.example_implicit_prompt_standard) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_standard, - "The prompt is not correctly extracted from the dataset.", + assert example_extracted_prompt == self.example_explicit_prompt_standard, ( + "The prompt is not correctly extracted from the dataset." ) def test_maybe_extract_prompt_standard(self): # Test that the prompt is correctly extracted from the dataset with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_implicit_prompt_standard) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_standard, - "The prompt is not correctly extracted from the dataset.", + assert example_extracted_prompt == self.example_explicit_prompt_standard, ( + "The prompt is not correctly extracted from the dataset." ) def test_maybe_extract_prompt_standard_already_explicit(self): # Test that the prompt remains unchanged with maybe_extract_prompt example_extracted_prompt = maybe_extract_prompt(self.example_explicit_prompt_standard) - self.assertEqual( - example_extracted_prompt, - self.example_explicit_prompt_standard, - "The prompt should remain unchanged.", - ) + assert example_extracted_prompt == self.example_explicit_prompt_standard, "The prompt should remain unchanged." class TestPackDatasetWrapped(TrlTestCase): @@ -832,7 +805,7 @@ def test_with_dataset(self): "attention_mask": [[0, 1, 1], [0, 0, 1], [1, 1]], } dataset = pack_dataset(dataset, seq_length, strategy="wrapped") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_iterable_dataset(self): examples = { @@ -847,7 +820,7 @@ def test_with_iterable_dataset(self): } dataset = pack_dataset(dataset, seq_length, strategy="wrapped") num_examples = len(examples[next(iter(examples))]) - self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output class TestPackDatasetBfd(TrlTestCase): @@ -864,7 +837,7 @@ def test_simple(self): "seq_lengths": [[4], [3, 1]], } dataset = pack_dataset(dataset, seq_length, strategy="bfd") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_iterable_dataset(self): examples = { @@ -880,7 +853,7 @@ def test_with_iterable_dataset(self): } dataset = pack_dataset(dataset, seq_length, strategy="bfd") num_examples = len(examples[next(iter(examples))]) - self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output def test_with_truncation(self): examples = { @@ -895,7 +868,7 @@ def test_with_truncation(self): "seq_lengths": [[4], [4], [2, 1]], } dataset = pack_dataset(dataset, seq_length, strategy="bfd") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_non_power_of_2(self): examples = { @@ -910,7 +883,7 @@ def test_with_non_power_of_2(self): "seq_lengths": [[5], [4, 1], [3]], } dataset = pack_dataset(dataset, seq_length, strategy="bfd") - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output class TestTruncateExamples(TrlTestCase): @@ -926,7 +899,7 @@ def test_with_dataset(self): "attention_mask": [[0, 1], [0, 0], [1]], } dataset = truncate_dataset(dataset, max_length) - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output def test_with_iterable_dataset(self): examples = { @@ -941,7 +914,7 @@ def test_with_iterable_dataset(self): } dataset = truncate_dataset(dataset, max_length) num_examples = len(examples[next(iter(examples))]) - self.assertEqual(next(iter(dataset.batch(batch_size=num_examples))), expected_output) + assert next(iter(dataset.batch(batch_size=num_examples))) == expected_output def test_with_extra_column(self): examples = { @@ -957,7 +930,7 @@ def test_with_extra_column(self): "my_column": ["a", "b", "c"], } dataset = truncate_dataset(dataset, max_length) - self.assertEqual(dataset.to_dict(), expected_output) + assert dataset.to_dict() == expected_output class TestMaybeConvertToChatML(TrlTestCase): @@ -975,7 +948,7 @@ def test_with_conversations_key(self): {"role": "assistant", "content": "It is blue."}, ] } - self.assertEqual(maybe_convert_to_chatml(example), expected_output) + assert maybe_convert_to_chatml(example) == expected_output def test_without_conversations_key(self): # Same as before, but we don't rename the keys @@ -987,12 +960,12 @@ def test_without_conversations_key(self): "prompt": [{"role": "user", "content": "What color is the sky?"}], "completion": [{"role": "assistant", "content": "It is blue."}], } - self.assertEqual(maybe_convert_to_chatml(example), expected_output) + assert maybe_convert_to_chatml(example) == expected_output def test_not_conversional(self): # When not needed, the example should remain unchanged example = {"text": "The sky is blue."} - self.assertEqual(maybe_convert_to_chatml(example), example) + assert maybe_convert_to_chatml(example) == example def test_already_chatml(self): # When the example is already in ChatML format, it should remain unchanged @@ -1002,9 +975,4 @@ def test_already_chatml(self): {"role": "assistant", "content": "It is blue."}, ] } - self.assertEqual(maybe_convert_to_chatml(example), example) - - -# Run the tests -if __name__ == "__main__": - unittest.main() + assert maybe_convert_to_chatml(example) == example diff --git a/tests/test_dataset_formatting.py b/tests/test_dataset_formatting.py index c85845e34c3..80f65f964de 100644 --- a/tests/test_dataset_formatting.py +++ b/tests/test_dataset_formatting.py @@ -23,9 +23,8 @@ from .testing_utils import TrlTestCase -class DatasetFormattingTestCase(TrlTestCase): - def setUp(self): - super().setUp() +class TestDatasetFormatting(TrlTestCase): + def setup_method(self): self.llama_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-MistralForCausalLM-0.1") self.chatml_tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -44,20 +43,20 @@ def test_get_formatting_func_from_dataset_with_chatml_messages(self): # Llama tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsInstance(formatting_func, Callable) + assert isinstance(formatting_func, Callable) formatted_text = formatting_func(dataset[0]) expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] # ChatML tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer) formatted_text = formatting_func(dataset[0]) expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] def test_get_formatting_func_from_dataset_with_chatml_conversations(self): dataset = Dataset.from_dict( @@ -73,53 +72,52 @@ def test_get_formatting_func_from_dataset_with_chatml_conversations(self): ) # Llama tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsInstance(formatting_func, Callable) + assert isinstance(formatting_func, Callable) formatted_text = formatting_func(dataset[0]) expected = " [INST] You are helpful\n\nHello [/INST] Hi, how can I help you?" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] # ChatML tokenizer formatting_func = get_formatting_func_from_dataset(dataset, self.chatml_tokenizer) formatted_text = formatting_func(dataset[0]) expected = "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" - self.assertEqual(formatted_text, expected) + assert formatted_text == expected formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [expected]) + assert formatted_text == [expected] def test_get_formatting_func_from_dataset_with_instruction(self): dataset = Dataset.from_list( [{"prompt": "What is 2+2?", "completion": "4"}, {"prompt": "What is 3+3?", "completion": "6"}] ) formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsNotNone(formatting_func) - self.assertIsInstance(formatting_func, Callable) + assert formatting_func is not None + assert isinstance(formatting_func, Callable) formatted_text = formatting_func(dataset[0]) - self.assertEqual(formatted_text, " [INST] What is 2+2? [/INST] 4") + assert formatted_text == " [INST] What is 2+2? [/INST] 4" formatted_text = formatting_func(dataset[0:1]) - self.assertListEqual(formatted_text, [" [INST] What is 2+2? [/INST] 4"]) + assert formatted_text == [" [INST] What is 2+2? [/INST] 4"] def test_get_formatting_func_from_dataset_from_hub(self): ds_1 = load_dataset("philschmid/trl-test-instruction", split="train") ds_2 = load_dataset("philschmid/dolly-15k-oai-style", split="train") for ds in [ds_1, ds_2]: formatting_func = get_formatting_func_from_dataset(ds, self.llama_tokenizer) - self.assertIsNotNone(formatting_func) - self.assertIsInstance(formatting_func, Callable) + assert formatting_func is not None + assert isinstance(formatting_func, Callable) ds_3 = load_dataset("philschmid/guanaco-sharegpt-style", split="train") formatting_func = get_formatting_func_from_dataset(ds_3, self.llama_tokenizer) - self.assertIsNone(formatting_func) + assert formatting_func is None def test_get_formatting_func_from_dataset_with_unknown_format(self): dataset = Dataset.from_dict({"text": "test"}) formatting_func = get_formatting_func_from_dataset(dataset, self.llama_tokenizer) - self.assertIsNone(formatting_func) + assert formatting_func is None -class SetupChatFormatTestCase(TrlTestCase): - def setUp(self): - super().setUp() +class TestSetupChatFormat(TrlTestCase): + def setup_method(self): self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") # remove built-in chat_template to simulate a model having no chat_template @@ -132,13 +130,13 @@ def test_setup_chat_format(self): _chatml = ChatMlSpecialTokens() # Check if special tokens are correctly set - self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") - self.assertEqual(modified_tokenizer.pad_token, "<|im_end|>") - self.assertEqual(modified_tokenizer.bos_token, "<|im_start|>") - self.assertEqual(modified_tokenizer.eos_token, _chatml.eos_token) - self.assertEqual(modified_tokenizer.pad_token, _chatml.pad_token) - self.assertEqual(modified_tokenizer.bos_token, _chatml.bos_token) - self.assertEqual((modified_model.vocab_size % 123), 0) + assert modified_tokenizer.eos_token == "<|im_end|>" + assert modified_tokenizer.pad_token == "<|im_end|>" + assert modified_tokenizer.bos_token == "<|im_start|>" + assert modified_tokenizer.eos_token == _chatml.eos_token + assert modified_tokenizer.pad_token == _chatml.pad_token + assert modified_tokenizer.bos_token == _chatml.bos_token + assert (modified_model.vocab_size % 123) == 0 def test_example_with_setup_model(self): modified_model, modified_tokenizer = setup_chat_format( @@ -152,13 +150,13 @@ def test_example_with_setup_model(self): ] prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) - self.assertEqual( - prompt, - "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n", + assert ( + prompt + == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\nHi, how can I help you?<|im_end|>\n" ) -class CloneChatTemplateTestCase(TrlTestCase): +class TestCloneChatTemplate(TrlTestCase): def test_clone(self): # This tokenizer doesn't have a chat_template by default tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-BloomForCausalLM") @@ -168,7 +166,7 @@ def test_clone(self): _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) # Check if special tokens are correctly set - self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") + assert modified_tokenizer.eos_token == "<|im_end|>" def test_clone_with_resize(self): # This tokenizer doesn't have a chat_template by default @@ -181,9 +179,9 @@ def test_clone_with_resize(self): ) # Check that the input embeddings have been resized to a multiple of 123 - self.assertEqual((modified_model.vocab_size % 123), 0) + assert (modified_model.vocab_size % 123) == 0 # Check that the input embeddings size matches the tokenizer vocabulary size - self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab)) + assert model.vocab_size == len(modified_tokenizer.vocab) def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): # This tokenizer doesn't have a chat_template by default @@ -201,9 +199,9 @@ def test_clone_with_resize_and_extra_tokens_already_in_vocab(self): ) # Check that the input embeddings have been resized to a multiple of 123 - self.assertEqual((modified_model.vocab_size % 124), 0) + assert (modified_model.vocab_size % 124) == 0 # Check that the input embeddings size matches the tokenizer vocabulary size - self.assertEqual(model.vocab_size, len(modified_tokenizer.vocab)) + assert model.vocab_size == len(modified_tokenizer.vocab) def test_apply_new_chat_template(self): # This tokenizer doesn't have a chat_template by default @@ -219,9 +217,9 @@ def test_apply_new_chat_template(self): ] prompt = modified_tokenizer.apply_chat_template(messages, tokenize=False) - self.assertEqual( - prompt, - "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n", + assert ( + prompt + == "<|im_start|>system\nYou are helpful<|im_end|>\n<|im_start|>user\nHello<|im_end|>\n<|im_start|>assistant\n\n\n\n\nHi, how can I help you?<|im_end|>\n" ) def test_clone_with_sequence_classification_model(self): @@ -235,4 +233,4 @@ def test_clone_with_sequence_classification_model(self): _, modified_tokenizer, _ = clone_chat_template(model, tokenizer, source) # Check if special tokens are correctly set - self.assertEqual(modified_tokenizer.eos_token, "<|im_end|>") + assert modified_tokenizer.eos_token == "<|im_end|>" diff --git a/tests/test_dpo_trainer.py b/tests/test_dpo_trainer.py index bf4a5d9826b..1d1e94fcf99 100644 --- a/tests/test_dpo_trainer.py +++ b/tests/test_dpo_trainer.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import sys -import unittest from unittest.mock import MagicMock import numpy as np +import pytest import torch from datasets import Dataset, features, load_dataset from parameterized import parameterized @@ -32,14 +33,12 @@ from transformers.testing_utils import ( get_device_properties, require_liger_kernel, - require_peft, require_torch_gpu_if_bnb_not_multi_backend_enabled, - require_vision, ) from trl import DPOConfig, DPOTrainer, FDivergenceType -from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb +from .testing_utils import TrlTestCase, require_bitsandbytes, require_no_wandb, require_peft, require_vision if is_vision_available(): @@ -47,8 +46,7 @@ class TestTokenizeRow(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Set up the mock tokenizer with specific behaviors self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) self.tokenizer.bos_token_id = 0 @@ -84,14 +82,11 @@ def test_tokenize_row_no_truncation_no_special_tokens(self): ) # Assert the correct output without truncation or special tokens - self.assertEqual( - result, - { - "prompt_input_ids": [464, 6766, 318], - "chosen_input_ids": [4171, 2], # eos_token added - "rejected_input_ids": [4077, 2], # eos_token added - }, - ) + assert result == { + "prompt_input_ids": [464, 6766, 318], + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + } def test_tokenize_row_with_truncation(self): # Define the input features @@ -107,14 +102,11 @@ def test_tokenize_row_with_truncation(self): ) # Assert the correct output with truncation applied - self.assertEqual( - result, - { - "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens - "chosen_input_ids": [4171], # truncated to 1 token - "rejected_input_ids": [4077], # truncated to 1 token - }, - ) + assert result == { + "prompt_input_ids": [6766, 318], # truncated to the last 2 tokens + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + } def test_tokenize_row_with_special_tokens(self): # Define the input features @@ -130,14 +122,11 @@ def test_tokenize_row_with_special_tokens(self): ) # Assert the correct output with special tokens added - self.assertEqual( - result, - { - "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added - "chosen_input_ids": [4171, 2], # eos_token added - "rejected_input_ids": [4077, 2], # eos_token added - }, - ) + assert result == { + "prompt_input_ids": [0, 464, 6766, 318, 2], # bos_token and eos_token added + "chosen_input_ids": [4171, 2], # eos_token added + "rejected_input_ids": [4077, 2], # eos_token added + } def test_tokenize_row_with_truncation_and_special_tokens(self): # Define the input features @@ -153,19 +142,15 @@ def test_tokenize_row_with_truncation_and_special_tokens(self): ) # Assert the correct output with both truncation and special tokens - self.assertEqual( - result, - { - "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token - "chosen_input_ids": [4171], # truncated to 1 token - "rejected_input_ids": [4077], # truncated to 1 token - }, - ) + assert result == { + "prompt_input_ids": [464, 6766, 318, 2], # truncated to 4 tokens with bos_token and eos_token + "chosen_input_ids": [4171], # truncated to 1 token + "rejected_input_ids": [4077], # truncated to 1 token + } -class DPOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestDPOTrainer(TrlTestCase): + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -193,13 +178,13 @@ def test_train(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @parameterized.expand( [ @@ -241,13 +226,13 @@ def test_train_loss_types(self, loss_type): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_liger_kernel def test_train_encoder_decoder_liger(self): @@ -274,13 +259,13 @@ def test_train_encoder_decoder_liger(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_dpo_trainer_with_weighting(self): dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") @@ -304,13 +289,13 @@ def test_dpo_trainer_with_weighting(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_train_with_multiple_loss_types(self): """ @@ -338,22 +323,21 @@ def test_train_with_multiple_loss_types(self): # Test that training works trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Verify SFT loss is computed in the first test too with torch.no_grad(): batch = next(iter(trainer.get_train_dataloader())) loss, metrics = trainer.get_batch_loss_metrics(trainer.model, batch) - self.assertIn("nll_loss", metrics) # SFT loss should be computed + assert "nll_loss" in metrics # SFT loss should be computed def test_wrong_loss_weights_length(self): - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="Length of loss_weights list"): DPOConfig( output_dir=self.tmp_dir, loss_type=["sigmoid", "bco_pair"], loss_weights=[1.0, 0.5, 0.1], # Wrong length ) - self.assertIn("Length of loss_weights list", str(context.exception)) @parameterized.expand([(None,), (0.5,)]) def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha): @@ -386,13 +370,13 @@ def test_dpo_trainer_without_providing_ref_model(self, rpo_alpha): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_dpo_trainer_with_ref_model_is_model(self): training_args = DPOConfig( @@ -404,7 +388,7 @@ def test_dpo_trainer_with_ref_model_is_model(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): DPOTrainer( model=self.model, ref_model=self.model, # ref_model can't be the same as model @@ -437,13 +421,13 @@ def test_precompute_ref_batch_size(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_peft def test_dpo_trainer_without_providing_ref_model_with_lora(self): @@ -486,14 +470,14 @@ def test_dpo_trainer_without_providing_ref_model_with_lora(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_dpo_trainer_w_dataset_num_proc(self): training_args = DPOConfig( @@ -555,13 +539,13 @@ def test_tr_dpo_trainer(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.ref_model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @require_no_wandb def test_dpo_trainer_generate_during_eval_no_wandb(self): @@ -580,9 +564,9 @@ def test_dpo_trainer_generate_during_eval_no_wandb(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - with self.assertRaisesRegex( + with pytest.raises( ValueError, - expected_regex="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." + match="`generate_during_eval=True` requires Weights and Biases, MLFlow or Comet to be installed." " Please install `wandb`, `mlflow` or `comet-ml` to resolve.", ): DPOTrainer( @@ -645,7 +629,7 @@ def test_dpo_lora_save(self): try: AutoModelForCausalLM.from_pretrained(self.tmp_dir) except OSError: - self.fail("Loading the saved peft adapter failed") + pytest.fail("Loading the saved peft adapter failed") @require_peft @require_torch_gpu_if_bnb_not_multi_backend_enabled @@ -729,9 +713,9 @@ def test_dpo_lora_bf16_autocast_llama(self): ) @require_bitsandbytes @require_peft - @unittest.skipIf( + @pytest.mark.skipif( get_device_properties()[0] == "cuda" and get_device_properties()[1] < 8, - "Skipping because bf16 not supported on CUDA GPU with capability < 8.0", + reason="Skipping because bf16 not supported on CUDA GPU with capability < 8.0", ) def test_dpo_lora_bf16_autocast(self, loss_type, pre_compute, gen_during_eval): from peft import LoraConfig @@ -826,7 +810,7 @@ def test_dpo_lora_tags(self): ) for tag in ["dpo", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_dpo_tags(self): @@ -861,7 +845,7 @@ def test_dpo_tags(self): ) for tag in ["dpo", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_dpo_lora_force_use_ref(self): @@ -895,7 +879,7 @@ def test_dpo_lora_force_use_ref(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_preference") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): # passing a peft_model as model and ref_model should error out, # unless you pass `force_use_ref_model` trainer = DPOTrainer( @@ -953,8 +937,8 @@ def test_dpo_trainer_dtype(self): args=training_args, train_dataset=dummy_dataset["train"], ) - self.assertEqual(trainer.model.config.dtype, torch.float16) - self.assertEqual(trainer.ref_model.config.dtype, torch.float16) + assert trainer.model.config.dtype == torch.float16 + assert trainer.ref_model.config.dtype == torch.float16 # Now test when `dtype` is provided but is wrong to either the model or the ref_model training_args = DPOConfig( @@ -965,7 +949,12 @@ def test_dpo_trainer_dtype(self): report_to="none", ) - with self.assertRaises(ValueError) as context: + with pytest.raises( + ValueError, + match=re.escape( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1." + ), + ): _ = DPOTrainer( model=self.model_id, processing_class=self.tokenizer, @@ -973,12 +962,6 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - self.assertIn( - "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " - "`torch.dtype` (e.g., 'float32'), but got -1.", - str(context.exception), - ) - training_args = DPOConfig( output_dir=self.tmp_dir, per_device_train_batch_size=2, @@ -987,7 +970,12 @@ def test_dpo_trainer_dtype(self): report_to="none", ) - with self.assertRaises(ValueError) as context: + with pytest.raises( + ValueError, + match=re.escape( + "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid `torch.dtype` (e.g., 'float32'), but got -1." + ), + ): _ = DPOTrainer( model=self.model_id, ref_model=self.model_id, @@ -996,12 +984,6 @@ def test_dpo_trainer_dtype(self): train_dataset=dummy_dataset["train"], ) - self.assertIn( - "Invalid `dtype` passed to the config. Expected either 'auto' or a string representing a valid " - "`torch.dtype` (e.g., 'float32'), but got -1.", - str(context.exception), - ) - def test_dpo_loss_alpha_div_f(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -1041,7 +1023,7 @@ def test_dpo_loss_alpha_div_f(self): losses, _, _ = trainer.dpo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps ) - self.assertTrue(torch.isfinite(losses).cpu().numpy().all()) + assert torch.isfinite(losses).cpu().numpy().all() def test_dpo_loss_js_div_f(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -1083,7 +1065,7 @@ def test_dpo_loss_js_div_f(self): losses, _, _ = trainer.dpo_loss( policy_chosen_logps, policy_rejected_logps, reference_chosen_logps, reference_rejected_logps ) - self.assertTrue(torch.isfinite(losses).cpu().numpy().all()) + assert torch.isfinite(losses).cpu().numpy().all() def test_dpo_trainer_use_logits_to_keep(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" @@ -1199,7 +1181,7 @@ def get_current_temperature(location: str): # We don't run the training, but at this stage, the dataset is supposed to be pre-processed. When # pre-processing, we expect the available tools to be explicitly mentioned in the system prompt. That's # what we're checking here - self.assertIn("get_current_temperature", tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0])) + assert "get_current_temperature" in tokenizer.decode(trainer.train_dataset["prompt_input_ids"][0]) def test_padding_free(self): model_id = "trl-internal-testing/tiny-LlamaForCausalLM-3.2" @@ -1235,7 +1217,7 @@ def test_padding_free(self): for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_compute_metrics(self): model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -1270,7 +1252,7 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 def test_train_with_length_desensitization(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -1295,13 +1277,13 @@ def test_train_with_length_desensitization(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @parameterized.expand( [ @@ -1318,7 +1300,7 @@ def test_train_with_length_desensitization(self): ] ) @require_liger_kernel - @unittest.skipUnless(sys.version_info >= (3, 10), "Liger kernel is not supported on Python 3.9") + @pytest.mark.skipif(not (sys.version_info >= (3, 10)), reason="Liger kernel is not supported on Python 3.9") def test_dpo_trainer_with_liger(self, beta, loss_type): """Test DPO trainer with Liger loss enabled across supported loss types. @@ -1359,20 +1341,20 @@ def test_dpo_trainer_with_liger(self, beta, loss_type): train_output = trainer.train() # Verify training completed successfully - self.assertIsNotNone(train_output) - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert train_output is not None + assert trainer.state.log_history[-1]["train_loss"] is not None # Verify loss is finite - self.assertTrue(np.isfinite(trainer.state.log_history[-1]["train_loss"])) + assert np.isfinite(trainer.state.log_history[-1]["train_loss"]) # Check parameters have been updated for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) # Only check non-zero parameters if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) # Verify new parameters are finite - self.assertTrue(torch.isfinite(new_param).all()) + assert torch.isfinite(new_param).all() # Verify model can still do forward pass after training dummy_batch = next(iter(trainer.get_train_dataloader())) @@ -1382,8 +1364,8 @@ def test_dpo_trainer_with_liger(self, beta, loss_type): } with torch.no_grad(): output = trainer.model(**model_inputs) - self.assertIsNotNone(output) - self.assertFalse("loss" in output.keys()) + assert output is not None + assert "loss" not in output.keys() def test_train_with_iterable_dataset(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -1411,17 +1393,17 @@ def test_train_with_iterable_dataset(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_vision -class DPOVisionTrainerTester(TrlTestCase): +class TestDPOVisionTrainer(TrlTestCase): @parameterized.expand( [ # ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), device issue from transformers, see https://github.com/huggingface/transformers/pull/39975 @@ -1494,7 +1476,7 @@ def test_vdpo_trainer(self, model_id): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the trainable params have changed for n, param in previous_trainable_params.items(): @@ -1510,7 +1492,7 @@ def test_vdpo_trainer(self, model_id): # For some reason, these params are not updated. This is probably not related to TRL, but to # the model itself. We should investigate this further, but for now we just skip these params. continue - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" class TestDPOConfig(TrlTestCase): @@ -1529,7 +1511,3 @@ def test_f_divergence_type(self, f_divergence_type, as_string: bool): # Serialization: TrainingArguments.to_dict should yield the enum's string value configparser_dict = training_args.to_dict() assert configparser_dict["f_divergence_type"] == f_divergence_type.value - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/test_gkd_trainer.py b/tests/test_gkd_trainer.py index 4a0d458440c..b311ce2b0b6 100644 --- a/tests/test_gkd_trainer.py +++ b/tests/test_gkd_trainer.py @@ -27,9 +27,9 @@ from .testing_utils import TrlTestCase -class TestGKDTrainer(TrlTestCase): +class TestGKDTrainerGenerateOnPolicy(TrlTestCase): @classmethod - def setUpClass(cls): + def setup_class(cls): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" cls.tokenizer = AutoTokenizer.from_pretrained(model_id) cls.tokenizer.pad_token = cls.tokenizer.eos_token @@ -70,9 +70,8 @@ def test_generate_on_policy_outputs_deterministic(self): # Check if the generated texts start with the original prompts for prompt, generated_text in zip(prompts, generated_texts): - self.assertTrue( - generated_text.startswith(prompt), - f"Generated text '{generated_text}' does not start with prompt '{prompt}'", + assert generated_text.startswith(prompt), ( + f"Generated text '{generated_text}' does not start with prompt '{prompt}'" ) # Run the generation twice and check if the outputs are identical @@ -83,15 +82,11 @@ def test_generate_on_policy_outputs_deterministic(self): new_input_ids2, new_attention_mask2, new_labels2 = outputs2 # Check if the two generations are identical - self.assertTrue(torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical") - self.assertTrue( - torch.all(new_attention_mask.eq(new_attention_mask2)), - "Attention masks for deterministic generations are not identical", - ) - self.assertTrue( - torch.all(new_labels.eq(new_labels2)), - "Labels for deterministic generations are not identical", + assert torch.all(new_input_ids.eq(new_input_ids2)), "Deterministic generations are not identical" + assert torch.all(new_attention_mask.eq(new_attention_mask2)), ( + "Attention masks for deterministic generations are not identical" ) + assert torch.all(new_labels.eq(new_labels2)), "Labels for deterministic generations are not identical" def test_generate_on_policy_outputs(self): prompts = ["Hello, how are you?", "What's the weather like today?"] @@ -107,30 +102,29 @@ def test_generate_on_policy_outputs(self): ) # Check that outputs is a tuple of three tensors - self.assertIsInstance(outputs, tuple) - self.assertEqual(len(outputs), 3) + assert isinstance(outputs, tuple) + assert len(outputs) == 3 new_input_ids, new_attention_mask, new_labels = outputs # Check shapes batch_size = len(prompts) - self.assertEqual(new_input_ids.shape[0], batch_size) - self.assertEqual(new_attention_mask.shape[0], batch_size) - self.assertEqual(new_labels.shape[0], batch_size) + assert new_input_ids.shape[0] == batch_size + assert new_attention_mask.shape[0] == batch_size + assert new_labels.shape[0] == batch_size # Check types - self.assertIsInstance(new_input_ids, torch.Tensor) - self.assertIsInstance(new_attention_mask, torch.Tensor) - self.assertIsInstance(new_labels, torch.Tensor) + assert isinstance(new_input_ids, torch.Tensor) + assert isinstance(new_attention_mask, torch.Tensor) + assert isinstance(new_labels, torch.Tensor) # Check that new_input_ids and new_attention_mask have the same shape - self.assertEqual(new_input_ids.shape, new_attention_mask.shape) - self.assertEqual(new_labels.shape, new_attention_mask.shape) + assert new_input_ids.shape == new_attention_mask.shape + assert new_labels.shape == new_attention_mask.shape class TestGeneralizedJSDLoss(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.batch_size = 2 self.seq_length = 3 self.vocab_size = 5 @@ -140,7 +134,7 @@ def setUp(self): def test_uniform_distribution(self): logits = torch.ones(1, 1, self.vocab_size) loss = GKDTrainer.generalized_jsd_loss(logits, logits) - self.assertAlmostEqual(loss.item(), 0, places=5) + assert round(abs(loss.item() - 0), 5) == 0 def test_generalized_jsd_loss_edge_cases(self): # Setup @@ -152,29 +146,29 @@ def test_generalized_jsd_loss_edge_cases(self): expected_loss_beta_1 = F.kl_div( F.log_softmax(teacher_logits, dim=-1), F.softmax(student_logits, dim=-1), reduction="batchmean" ) - self.assertAlmostEqual(loss_beta_1.item(), expected_loss_beta_1.item(), places=5) + assert round(abs(loss_beta_1.item() - expected_loss_beta_1.item()), 5) == 0 # Case 2: beta = 0 (should be equivalent to KL(teacher || student)) loss_beta_0 = GKDTrainer.generalized_jsd_loss(student_logits, teacher_logits, beta=0) expected_loss_beta_0 = F.kl_div( F.log_softmax(student_logits, dim=-1), F.softmax(teacher_logits, dim=-1), reduction="batchmean" ) - self.assertAlmostEqual(loss_beta_0.item(), expected_loss_beta_0.item(), places=5) + assert round(abs(loss_beta_0.item() - expected_loss_beta_0.item()), 5) == 0 def test_output_shape(self): loss = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits) - self.assertTrue(torch.is_tensor(loss)) - self.assertEqual(loss.shape, torch.Size([])) + assert torch.is_tensor(loss) + assert loss.shape == torch.Size([]) def test_beta_values(self): loss_beta_0 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0) loss_beta_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=1) - self.assertNotEqual(loss_beta_0, loss_beta_1) + assert loss_beta_0 != loss_beta_1 def test_temperature_scaling(self): loss_temp_1 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=1) loss_temp_2 = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, temperature=2) - self.assertNotEqual(loss_temp_1, loss_temp_2) + assert loss_temp_1 != loss_temp_2 def test_reduction_methods(self): loss_batchmean = GKDTrainer.generalized_jsd_loss( @@ -184,29 +178,28 @@ def test_reduction_methods(self): loss_mean = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="mean") loss_none = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, reduction="none") - self.assertEqual(loss_batchmean.shape, torch.Size([])) - self.assertEqual(loss_sum.shape, torch.Size([])) - self.assertEqual(loss_mean.shape, torch.Size([])) - self.assertEqual(loss_none.shape, self.student_logits.shape) + assert loss_batchmean.shape == torch.Size([]) + assert loss_sum.shape == torch.Size([]) + assert loss_mean.shape == torch.Size([]) + assert loss_none.shape == self.student_logits.shape def test_symmetry(self): student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.1) teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.1) - self.assertNotEqual(student_teacher, teacher_student) + assert student_teacher != teacher_student student_teacher = GKDTrainer.generalized_jsd_loss(self.student_logits, self.teacher_logits, beta=0.5) teacher_student = GKDTrainer.generalized_jsd_loss(self.teacher_logits, self.student_logits, beta=0.5) - self.assertEqual(student_teacher, teacher_student) + assert student_teacher == teacher_student def test_zero_loss_for_identical_inputs(self): identical_logits = torch.randn(self.batch_size, self.seq_length, self.vocab_size) loss = GKDTrainer.generalized_jsd_loss(identical_logits, identical_logits) - self.assertAlmostEqual(loss.item(), 0, places=6) + assert round(abs(loss.item() - 0), 6) == 0 -class GKDTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestGKDTrainer(TrlTestCase): + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.teacher_model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -242,9 +235,9 @@ def test_gkd_trainer(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) - self.assertIn("model.safetensors", os.listdir(self.tmp_dir + "/checkpoint-2")) + assert trainer.state.log_history[(-1)]["train_loss"] is not None + assert trainer.state.log_history[0]["eval_loss"] is not None + assert "model.safetensors" in os.listdir(self.tmp_dir + "/checkpoint-2") @require_liger_kernel @pytest.mark.xfail(reason="Computing the Liger loss spikes GPU memory usage, causing the test to run OOM.") @@ -271,7 +264,7 @@ def test_gkd_trainer_with_liger(self): trainer.train() # Check we logged a train loss - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None def test_generation_config_init(self): training_args = GKDConfig(output_dir=self.tmp_dir) @@ -286,8 +279,8 @@ def test_generation_config_init(self): processing_class=self.tokenizer, ) - self.assertEqual(trainer.generation_config.pad_token_id, self.tokenizer.eos_token_id) - self.assertEqual(trainer.generation_config.eos_token_id, self.model.generation_config.eos_token_id) - self.assertEqual(trainer.generation_config.max_new_tokens, training_args.max_new_tokens) - self.assertEqual(trainer.generation_config.temperature, training_args.temperature) - self.assertEqual(trainer.generation_config.top_k, 0) + assert trainer.generation_config.pad_token_id == self.tokenizer.eos_token_id + assert trainer.generation_config.eos_token_id == self.model.generation_config.eos_token_id + assert trainer.generation_config.max_new_tokens == training_args.max_new_tokens + assert trainer.generation_config.temperature == training_args.temperature + assert trainer.generation_config.top_k == 0 diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index b29e0769b89..17766ef0b98 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from unittest.mock import patch import pytest @@ -25,7 +24,7 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from transformers.testing_utils import require_liger_kernel, require_peft, require_vision +from transformers.testing_utils import require_liger_kernel from transformers.utils import is_peft_available from trl import GRPOConfig, GRPOTrainer @@ -36,14 +35,14 @@ ) from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer -from .testing_utils import TrlTestCase, require_vllm +from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm if is_peft_available(): from peft import LoraConfig, PeftModel -class GetHighEntropyMaskTester(TrlTestCase): +class TestGetHighEntropyMask(TrlTestCase): def get_high_entropy_mask(self, entropies, mask, threshold): """Helper method to test the get_high_entropy_mask functionality.""" # Create a mock trainer with minimal setup @@ -115,7 +114,7 @@ def test_compute_entropy_all_masked(self): torch.testing.assert_close(entropy_mask, expected_mask) -class GRPOTrainerTester(TrlTestCase): +class TestGRPOTrainer(TrlTestCase): def test_init_minimal(self): # Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -148,12 +147,12 @@ def test_training(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @parameterized.expand([("bnpo",), ("dr_grpo",), ("dapo",)]) def test_training_loss_types(self, loss_type): @@ -180,12 +179,12 @@ def test_training_loss_types(self, loss_type): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") @@ -233,12 +232,12 @@ def test_training_multiple_iterations(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft(self): @@ -266,15 +265,15 @@ def test_training_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft_with_gradient_checkpointing(self): @@ -308,22 +307,22 @@ def test_training_peft_with_gradient_checkpointing(self): ) # Verify gradient checkpointing is enabled - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Store initial parameters to check which ones change previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that only LoRA parameters have changed, base model parameters remain unchanged for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "lora" in n.lower(): # LoRA parameters should change - self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"LoRA parameter {n} has not changed." else: # Base model parameters should not change - self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.") + assert torch.equal(param, new_param), f"Base parameter {n} has changed." def test_training_different_reward_model(self): # Use a reward model different from the model: different chat template, tokenization, etc. @@ -357,12 +356,12 @@ def test_training_different_reward_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_standard(self): # Test if trainer can handle reward function with standard format @@ -391,12 +390,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_conversational(self): # Test if trainer can handle reward function with conversational format @@ -426,12 +425,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs(self): # Test that GRPOTrainer can be instantiated with multiple reward functions @@ -464,12 +463,12 @@ def reward_func2(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_None_output(self): """Test that a valid math reward function is processed correctly while the code reward function returns None.""" @@ -508,12 +507,12 @@ def non_applicable_reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_weights(self): """Test that GRPOTrainer can handle multiple reward functions with weights.""" @@ -548,16 +547,16 @@ def reward_func2(completions, **kwargs): trainer.train() # Check that training logs contain both reward metrics - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert "rewards/reward_func1/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func1/std" in trainer.state.log_history[-1] + assert "rewards/reward_func2/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func2/std" in trainer.state.log_history[-1] # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_mixed_reward_funcs(self): # Test if the trainer can handle a mix of reward functions and reward models @@ -586,12 +585,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_additional_column(self): # Test if trainer can handle reward function that rely on additional columns in the dataset @@ -624,12 +623,12 @@ def reward_func(completions, some_values, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_sync_ref_model(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -655,12 +654,12 @@ def test_training_with_sync_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_beta_non_zero(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -684,12 +683,12 @@ def test_training_beta_non_zero(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_entropy_filter(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -713,16 +712,16 @@ def test_training_with_entropy_filter(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @unittest.skip("We should add a mock for the vLLM server.") @require_peft @require_vllm + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_and_peft(self): """Test that training works with vLLM for generation.""" model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM @@ -755,19 +754,19 @@ def test_training_vllm_and_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n and "original_module" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_guided_decoding(self): """Test that training works with vLLM for generation with guided decoding.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -793,15 +792,15 @@ def test_training_vllm_guided_decoding(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_importance_sampling_correction(self): """Test that training works with vLLM for generation with guided decoding.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -828,12 +827,12 @@ def test_training_vllm_importance_sampling_correction(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_additional_generation_kwargs(self): """Test that training works with additional generation kwargs.""" @@ -863,15 +862,15 @@ def test_training_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_with_additional_generation_kwargs(self): """Test that training works with vLLM and additional generation kwargs.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -901,12 +900,12 @@ def test_training_vllm_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @parameterized.expand([(False,), ("group",), ("batch",), (True,), ("none",)]) def test_training_scale_rewards(self, scale_rewards): @@ -932,12 +931,12 @@ def test_training_scale_rewards(self, scale_rewards): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @patch("transformers.generation.utils.GenerationMixin.generate") def test_training_with_mask_truncated_completions(self, mock_generate): @@ -982,12 +981,12 @@ def fake_generate(input_ids, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_mask_truncated_completions_all_masked(self): """ @@ -1020,14 +1019,14 @@ def test_training_with_mask_truncated_completions_all_masked(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.") + assert torch.equal(param, new_param), f"Parameter {n} has changed." - def test_warning_raised_all_rewards_none(self): + def test_warning_raised_all_rewards_none(self, caplog): """Test that a proper warning is raised when all rewards are None.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1050,11 +1049,11 @@ def always_none_reward_func(completions, **kwargs): train_dataset=dataset, ) - with self.assertLogs("trl.trainer.grpo_trainer", level="WARNING") as cm: + with caplog.at_level("WARNING", logger="trl.trainer.grpo_trainer"): trainer.train() expected_warning = "All reward functions returned None for the following kwargs:" - self.assertIn(expected_warning, cm.output[0]) + assert expected_warning in caplog.text def test_training_num_generations_larger_than_batch_size(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1079,12 +1078,12 @@ def test_training_num_generations_larger_than_batch_size(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_delta_clipping(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1109,12 +1108,12 @@ def test_training_delta_clipping(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_dataloader_workers(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1139,12 +1138,12 @@ def test_training_multiple_dataloader_workers(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_generation_kwargs(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1169,12 +1168,12 @@ def test_training_with_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_reward_func_accessing_trainer_state(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1246,7 +1245,7 @@ def test_prepare_input_called_with_correct_data(self): with patch.object(GRPOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare: trainer.train() # 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation - self.assertEqual(mock_prepare.call_count, 48) + assert mock_prepare.call_count == 48 for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations) assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch for i in range(8, 16): @@ -1289,7 +1288,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1305,7 +1304,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_beta_non_zero(self): @@ -1335,7 +1334,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1345,7 +1344,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_peft @@ -1380,15 +1379,15 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_and_importance_sampling(self): @@ -1418,7 +1417,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1428,7 +1427,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_liger_kernel @@ -1460,7 +1459,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1470,7 +1469,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_and_prompt_truncation(self): @@ -1501,7 +1500,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1511,7 +1510,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @parameterized.expand( [ @@ -1521,7 +1520,7 @@ def reward_func(completions, **kwargs): ) @require_vision @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") @@ -1551,11 +1550,11 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_multi_image(self): @@ -1588,14 +1587,14 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_sequence_importance_sampling(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1621,12 +1620,12 @@ def test_training_sequence_importance_sampling(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_mismatched_reward_processing_classes_length(self): """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" @@ -1645,7 +1644,7 @@ def test_mismatched_reward_processing_classes_length(self): training_args = GRPOConfig(output_dir=self.tmp_dir, report_to="none") - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="must match"): GRPOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=reward_models, @@ -1654,8 +1653,6 @@ def test_mismatched_reward_processing_classes_length(self): train_dataset=dataset, ) - self.assertIn("must match", str(context.exception)) - def test_correct_reward_processing_classes_list(self): """Test that correct list of reward_processing_classes works properly.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1685,7 +1682,7 @@ def test_correct_reward_processing_classes_list(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), len(reward_models)) + assert len(trainer.reward_processing_classes) == len(reward_models) def test_single_reward_model_with_single_processing_class(self): """Test that single reward model with single processing class works.""" @@ -1709,13 +1706,13 @@ def test_single_reward_model_with_single_processing_class(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), 1) - self.assertEqual(trainer.reward_processing_classes[0], single_processing_class) + assert len(trainer.reward_processing_classes) == 1 + assert trainer.reward_processing_classes[0] == single_processing_class @pytest.mark.low_priority -class TestReplayBuffer(unittest.TestCase): - def setUp(self): +class TestReplayBuffer: + def setup_method(self): self.replay_buffer = ReplayBuffer(max_size=5) def test_add(self): @@ -1731,12 +1728,12 @@ def test_add(self): self.replay_buffer.add(scores, data) # Check if the buffer contains the correct number of elements - self.assertEqual(len(self.replay_buffer.heap), 5) + assert len(self.replay_buffer.heap) == 5 # Check if the buffer maintains the min-heap property heap_scores = [item[0] for item in self.replay_buffer.heap] - self.assertEqual(heap_scores[0], min(heap_scores)) - self.assertEqual(heap_scores[0], 0.3) + assert heap_scores[0] == min(heap_scores) + assert heap_scores[0] == 0.3 def test_add_more_than_maxlen(self): # Add elements to the replay buffer @@ -1753,12 +1750,12 @@ def test_add_more_than_maxlen(self): self.replay_buffer.add(scores, data) # Check if the buffer contains the correct number of elements - self.assertEqual(len(self.replay_buffer.heap), 5) + assert len(self.replay_buffer.heap) == 5 # Check if the buffer maintains the min-heap property heap_scores = [item[0] for item in self.replay_buffer.heap] - self.assertEqual(heap_scores[0], min(heap_scores)) - self.assertEqual(heap_scores[0], 0.5) # 0.3 and 0.4 should be removed + assert heap_scores[0] == min(heap_scores) + assert heap_scores[0] == 0.5 # 0.3 and 0.4 should be removed def test_sample(self): # Add elements to the replay buffer @@ -1776,14 +1773,14 @@ def test_sample(self): sampled = self.replay_buffer.sample(num_samples=3) # Check if the sampled elements are from the buffer - self.assertEqual(len(sampled), 3) + assert len(sampled) == 3 for item in sampled: - self.assertIn(item, [entry[1] for entry in self.replay_buffer.heap]) + assert item in [entry[1] for entry in self.replay_buffer.heap] @pytest.mark.low_priority -class TestUpdateWithReplayBuffer(unittest.TestCase): - def setUp(self): +class TestUpdateWithReplayBuffer: + def setup_method(self): config = GRPOWithReplayBufferConfig( replay_buffer_size=5, ) @@ -1841,12 +1838,12 @@ def test_update_with_replay_buffer_no_variance(self): outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) - self.assertIsNotNone(outputs) - self.assertIn("pixel_values", outputs) - self.assertIn("old_per_token_logps", outputs) - self.assertEqual(len(self.trainer.replay_buffer.heap), 2) + assert outputs is not None + assert "pixel_values" in outputs + assert "old_per_token_logps" in outputs + assert len(self.trainer.replay_buffer.heap) == 2 for pid in outputs["prompt_ids"]: - self.assertNotIn(pid.tolist(), original_prompt_ids.tolist()) + assert pid.tolist() not in original_prompt_ids.tolist() def test_update_with_replay_buffer_with_variance(self): self._prepopulate_buffer() @@ -1855,8 +1852,8 @@ def test_update_with_replay_buffer_with_variance(self): sampled = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) - self.assertEqual(len(self.trainer.replay_buffer.heap), 4) # grew - self.assertIsNone(sampled) + assert len(self.trainer.replay_buffer.heap) == 4 # grew + assert sampled is None def test_update_with_mixed_variance(self): self._prepopulate_buffer() @@ -1866,16 +1863,16 @@ def test_update_with_mixed_variance(self): outputs = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) - self.assertEqual(len(self.trainer.replay_buffer.heap), 3) # grew by 1 + assert len(self.trainer.replay_buffer.heap) == 3 # grew by 1 output_prompt_ids = outputs["prompt_ids"].view(-1, self.trainer.num_generations, 2).tolist() buffer_ids = [item[1]["prompt_ids"].tolist() for item in self.trainer.replay_buffer.heap] found_from_buffer = any(pid in buffer_ids for pid in output_prompt_ids) found_from_original = any(pid in original_prompt_ids for pid in output_prompt_ids) - self.assertTrue(found_from_buffer) - self.assertTrue(found_from_original) - self.assertNotIn([[1, 2], [3, 4]], output_prompt_ids) # excluded no-variance group + assert found_from_buffer + assert found_from_original + assert [[1, 2], [3, 4]] not in output_prompt_ids # excluded no-variance group def test_update_with_inputs_different_seq_len(self): """ @@ -1910,8 +1907,8 @@ def test_update_with_inputs_different_seq_len(self): outputs_after_sampling = self.trainer.update_with_replay_buffer(**inputs, num_items_in_batch=4) # Seq length of current batch should be preserved - self.assertEqual(outputs_after_sampling["prompt_ids"].shape[-1], 3) - self.assertEqual(len(self.trainer.replay_buffer.heap), 3) + assert outputs_after_sampling["prompt_ids"].shape[-1] == 3 + assert len(self.trainer.replay_buffer.heap) == 3 output_prompt_ids = outputs_after_sampling["prompt_ids"].view(-1, self.trainer.num_generations, 3).tolist() buffered_prompt_completion_ids = [ @@ -1921,24 +1918,20 @@ def test_update_with_inputs_different_seq_len(self): buffered_prompt_ids, buffered_completion_ids = zip(*buffered_prompt_completion_ids) # Check for new entry with seq len 3 in buffer - self.assertIn([[3, 4, 5], [3, 4, 5]], buffered_prompt_ids) # excluded no-variance group - self.assertIn( - [[1013, 1014, pad_token_id], [1015, 1016, 1017]], buffered_completion_ids - ) # excluded no-variance group + assert [[3, 4, 5], [3, 4, 5]] in buffered_prompt_ids # excluded no-variance group + assert [ + [1013, 1014, pad_token_id], + [1015, 1016, 1017], + ] in buffered_completion_ids # excluded no-variance group # Check that sampled outputs contain one group with prompt_ids starting with a pad token - self.assertTrue( - [ - [pad_token_id, 101, 102], - [pad_token_id, 102, 103], - ] - in output_prompt_ids - or [ - [pad_token_id, 104, 105], - [pad_token_id, 106, 107], - ] - in output_prompt_ids - ) + assert [ + [pad_token_id, 101, 102], + [pad_token_id, 102, 103], + ] in output_prompt_ids or [ + [pad_token_id, 104, 105], + [pad_token_id, 106, 107], + ] in output_prompt_ids @pytest.mark.low_priority @@ -1973,15 +1966,15 @@ def custom_reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." -class GSPOTokenTrainerTester(TrlTestCase): +class TestGSPOTokenTrainer(TrlTestCase): def test_training(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -2006,13 +1999,9 @@ def test_training(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") - - -if __name__ == "__main__": - unittest.main() + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." diff --git a/tests/test_judges.py b/tests/test_judges.py index cce8f961a5f..bba1ffffbbb 100644 --- a/tests/test_judges.py +++ b/tests/test_judges.py @@ -13,7 +13,8 @@ # limitations under the License. import time -import unittest + +import pytest from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge @@ -35,17 +36,17 @@ def test_all_true_judge(self): judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()]) prompts, completions = self._get_prompts_and_single_completions() judgements = judge.judge(prompts=prompts, completions=completions) - self.assertEqual(len(judgements), 2) - self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements)) + assert len(judgements) == 2 + assert all(judgement in {0, 1, -1} for judgement in judgements) - @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.") + @pytest.mark.skip(reason="This test needs to be run manually since it requires a valid Hugging Face API key.") def test_hugging_face_judge(self): judge = HfPairwiseJudge() prompts, completions = self._get_prompts_and_pairwise_completions() ranks = judge.judge(prompts=prompts, completions=completions) - self.assertEqual(len(ranks), 2) - self.assertTrue(all(isinstance(rank, int) for rank in ranks)) - self.assertEqual(ranks, [0, 1]) + assert len(ranks) == 2 + assert all(isinstance(rank, int) for rank in ranks) + assert ranks == [0, 1] def load_pair_rm_judge(self): # When using concurrent tests, PairRM may fail to load the model while another job is still downloading. @@ -62,15 +63,15 @@ def test_pair_rm_judge(self): judge = self.load_pair_rm_judge() prompts, completions = self._get_prompts_and_pairwise_completions() ranks = judge.judge(prompts=prompts, completions=completions) - self.assertEqual(len(ranks), 2) - self.assertTrue(all(isinstance(rank, int) for rank in ranks)) - self.assertEqual(ranks, [0, 1]) + assert len(ranks) == 2 + assert all(isinstance(rank, int) for rank in ranks) + assert ranks == [0, 1] @require_llm_blender def test_pair_rm_judge_return_scores(self): judge = self.load_pair_rm_judge() prompts, completions = self._get_prompts_and_pairwise_completions() probs = judge.judge(prompts=prompts, completions=completions, return_scores=True) - self.assertEqual(len(probs), 2) - self.assertTrue(all(isinstance(prob, float) for prob in probs)) - self.assertTrue(all(0 <= prob <= 1 for prob in probs)) + assert len(probs) == 2 + assert all(isinstance(prob, float) for prob in probs) + assert all(0 <= prob <= 1 for prob in probs) diff --git a/tests/test_kto_trainer.py b/tests/test_kto_trainer.py index fa17881544e..e2c325149f2 100644 --- a/tests/test_kto_trainer.py +++ b/tests/test_kto_trainer.py @@ -13,21 +13,21 @@ # limitations under the License. +import pytest import torch from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_liger_kernel, require_peft +from transformers.testing_utils import require_liger_kernel from trl import KTOConfig, KTOTrainer from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize -from .testing_utils import TrlTestCase, require_no_wandb +from .testing_utils import TrlTestCase, require_no_wandb, require_peft -class KTOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestKTOTrainer(TrlTestCase): + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -91,13 +91,13 @@ def test_kto_trainer(self, name, config_name, loss_type, pre_compute, eval_datas trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_kto_trainer_with_ref_model_is_model(self): training_args = KTOConfig( @@ -109,7 +109,7 @@ def test_kto_trainer_with_ref_model_is_model(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): KTOTrainer( model=self.model, ref_model=self.model, # ref_model can't be the same as model @@ -149,13 +149,13 @@ def test_tokenize_and_process_tokens(self): batched=True, batch_size=2, ) - self.assertListEqual(tokenized_dataset["prompt"][:], train_dataset["prompt"][:]) - self.assertListEqual(tokenized_dataset["completion"][:], train_dataset["completion"][:]) - self.assertListEqual(tokenized_dataset["label"][:], train_dataset["label"][:]) - self.assertListEqual(tokenized_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(tokenized_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(tokenized_dataset["answer_input_ids"][0], [27261, 13]) - self.assertListEqual(tokenized_dataset["answer_attention_mask"][0], [1, 1]) + assert tokenized_dataset["prompt"][:] == train_dataset["prompt"][:] + assert tokenized_dataset["completion"][:] == train_dataset["completion"][:] + assert tokenized_dataset["label"][:] == train_dataset["label"][:] + assert tokenized_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert tokenized_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert tokenized_dataset["answer_input_ids"][0] == [27261, 13] + assert tokenized_dataset["answer_attention_mask"][0] == [1, 1] # Test corruption of (prompt, completion) pairs for KL dataset for batch_size in [2, 3]: @@ -166,18 +166,11 @@ def test_tokenize_and_process_tokens(self): # the last batch remains unaltered. This is a rare scenario that does not impact the training # process, so we exclude it from testing by iterating only up to len - 1. for i in range(len(tokenized_kl_dataset["answer_input_ids"]) - 1): - self.assertListEqual( - tokenized_dataset["prompt_input_ids"][i], - tokenized_kl_dataset["prompt_input_ids"][i], - ) - self.assertListEqual( - tokenized_dataset["prompt_attention_mask"][i], - tokenized_kl_dataset["prompt_attention_mask"][i], - ) - self.assertNotEqual( - tokenized_dataset["answer_input_ids"][i], - tokenized_kl_dataset["answer_input_ids"][i], + assert tokenized_dataset["prompt_input_ids"][i] == tokenized_kl_dataset["prompt_input_ids"][i] + assert ( + tokenized_dataset["prompt_attention_mask"][i] == tokenized_kl_dataset["prompt_attention_mask"][i] ) + assert tokenized_dataset["answer_input_ids"][i] != tokenized_kl_dataset["answer_input_ids"][i] fn_kwargs = { "prefix": "", @@ -189,14 +182,14 @@ def test_tokenize_and_process_tokens(self): "max_prompt_length": trainer.max_prompt_length, } processed_dataset = tokenized_dataset.map(_process_tokens, fn_kwargs=fn_kwargs, num_proc=2) - self.assertListEqual(processed_dataset["prompt"][:], train_dataset["prompt"][:]) - self.assertListEqual(processed_dataset["completion"][:], train_dataset["completion"][:]) - self.assertListEqual(processed_dataset["label"][:], train_dataset["label"][:]) - self.assertListEqual(processed_dataset["prompt_input_ids"][0], [46518, 374, 2664, 1091]) - self.assertListEqual(processed_dataset["prompt_attention_mask"][0], [1, 1, 1, 1]) - self.assertListEqual(processed_dataset["completion_input_ids"][0], [46518, 374, 2664, 1091, 27261, 13, 151645]) - self.assertListEqual(processed_dataset["completion_attention_mask"][0], [1, 1, 1, 1, 1, 1, 1]) - self.assertListEqual(processed_dataset["completion_labels"][0], [-100, -100, -100, -100, 27261, 13, 151645]) + assert processed_dataset["prompt"][:] == train_dataset["prompt"][:] + assert processed_dataset["completion"][:] == train_dataset["completion"][:] + assert processed_dataset["label"][:] == train_dataset["label"][:] + assert processed_dataset["prompt_input_ids"][0] == [46518, 374, 2664, 1091] + assert processed_dataset["prompt_attention_mask"][0] == [1, 1, 1, 1] + assert processed_dataset["completion_input_ids"][0] == [46518, 374, 2664, 1091, 27261, 13, 151645] + assert processed_dataset["completion_attention_mask"][0] == [1, 1, 1, 1, 1, 1, 1] + assert processed_dataset["completion_labels"][0] == [-100, -100, -100, -100, 27261, 13, 151645] def test_kto_trainer_without_providing_ref_model(self): training_args = KTOConfig( @@ -226,13 +219,13 @@ def test_kto_trainer_without_providing_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @require_peft def test_kto_trainer_without_providing_ref_model_with_lora(self): @@ -274,14 +267,14 @@ def test_kto_trainer_without_providing_ref_model_with_lora(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @require_no_wandb def test_kto_trainer_generate_during_eval_no_wandb(self): @@ -300,9 +293,9 @@ def test_kto_trainer_generate_during_eval_no_wandb(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_unpaired_preference") - with self.assertRaisesRegex( + with pytest.raises( ValueError, - expected_regex="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." + match="`generate_during_eval=True` requires Weights and Biases or Comet to be installed." " Please install `wandb` or `comet-ml` to resolve.", ): KTOTrainer( @@ -365,7 +358,7 @@ def test_kto_lora_save(self): try: AutoModelForCausalLM.from_pretrained(self.tmp_dir) except OSError: - self.fail("Loading the saved peft adapter failed") + pytest.fail("Loading the saved peft adapter failed") @require_liger_kernel def test_kto_trainer_with_liger(self): @@ -389,14 +382,14 @@ def test_kto_trainer_with_liger(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) # check the params have changed - ignore 0 biases if param.sum() != 0: - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_compute_metrics(self): model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -432,4 +425,4 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/tests/test_modeling_geometric_mixture_wrapper.py b/tests/test_modeling_geometric_mixture_wrapper.py index ae6f5010821..7dcd89f757e 100644 --- a/tests/test_modeling_geometric_mixture_wrapper.py +++ b/tests/test_modeling_geometric_mixture_wrapper.py @@ -22,8 +22,7 @@ class TestGeometricMixtureWrapper(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.device = "cuda" if torch.cuda.is_available() else "cpu" self.model = AutoModelForCausalLM.from_pretrained(model_id).to(self.device) @@ -40,9 +39,9 @@ def test_forward(self): output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) - self.assertIsNotNone(output) - self.assertTrue(hasattr(output, "logits")) - self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size)) + assert output is not None + assert hasattr(output, "logits") + assert output.logits.shape == (1, 5, self.model.config.vocab_size) def test_mixture_coefficient(self): input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) @@ -57,7 +56,7 @@ def test_mixture_coefficient(self): self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 ) - self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)) + assert torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5) def test_prepare_inputs_for_generation(self): input_ids = torch.tensor([[1, 2, 3, 4, 5]], device=self.device) @@ -65,6 +64,6 @@ def test_prepare_inputs_for_generation(self): inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) - self.assertIn("input_ids", inputs) - self.assertIn("attention_mask", inputs) - self.assertFalse(inputs.get("use_cache", False)) + assert "input_ids" in inputs + assert "attention_mask" in inputs + assert not inputs.get("use_cache", False) diff --git a/tests/test_modeling_value_head.py b/tests/test_modeling_value_head.py index b0a75211175..a2fde5a12c4 100644 --- a/tests/test_modeling_value_head.py +++ b/tests/test_modeling_value_head.py @@ -13,8 +13,8 @@ # limitations under the License. import gc -import unittest +import pytest import torch from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, GenerationConfig @@ -55,8 +55,7 @@ class VHeadModelTester(TrlTestCase): trl_model_class = None transformers_model_class = None - def setUp(self): - super().setUp() + def setup_method(self): self.device = "cuda" if torch.cuda.is_available() else "cpu" def test_value_head(self): @@ -65,7 +64,7 @@ def test_value_head(self): """ for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) - self.assertTrue(hasattr(model, "v_head")) + assert hasattr(model, "v_head") def test_value_head_shape(self): r""" @@ -73,7 +72,7 @@ def test_value_head_shape(self): """ for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) - self.assertEqual(model.v_head.summary.weight.shape[0], 1) + assert model.v_head.summary.weight.shape[0] == 1 def test_value_head_init_random(self): r""" @@ -82,9 +81,7 @@ def test_value_head_init_random(self): """ for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) - self.assertFalse( - torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)) - ) + assert not torch.allclose(model.v_head.summary.bias, torch.zeros_like(model.v_head.summary.bias)) def test_value_head_not_str(self): r""" @@ -94,7 +91,7 @@ def test_value_head_not_str(self): for model_name in self.all_model_names: pretrained_model = self.transformers_model_class.from_pretrained(model_name) model = self.trl_model_class.from_pretrained(pretrained_model) - self.assertTrue(hasattr(model, "v_head")) + assert hasattr(model, "v_head") def test_from_save_trl(self): """ @@ -110,7 +107,7 @@ def test_from_save_trl(self): # Check if the weights are the same for key in model_from_save.state_dict(): - self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]) def test_from_save_trl_sharded(self): """ @@ -125,7 +122,7 @@ def test_from_save_trl_sharded(self): # Check if the weights are the same for key in model_from_save.state_dict(): - self.assertTrue(torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key])) + assert torch.allclose(model_from_save.state_dict()[key], model.state_dict()[key]) def test_from_save_transformers_sharded(self): """ @@ -143,10 +140,8 @@ def test_from_save_transformers_sharded(self): # Check if the weights are the same for key in transformers_model.state_dict(): - self.assertTrue( - torch.allclose( - transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] - ) + assert torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] ) def test_from_save_transformers(self): @@ -166,30 +161,25 @@ def test_from_save_transformers(self): # Check if the weights are the same for key in transformers_model.state_dict(): - self.assertTrue( - torch.allclose( - transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] - ) + assert torch.allclose( + transformers_model_from_save.state_dict()[key], transformers_model.state_dict()[key] ) # Check if the trl model has the same keys as the transformers model # except the v_head for key in trl_model.state_dict(): if "v_head" not in key: - self.assertIn(key, transformers_model.state_dict()) + assert key in transformers_model.state_dict() # check if the weights are the same - self.assertTrue( - torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]) - ) + assert torch.allclose(trl_model.state_dict()[key], transformers_model.state_dict()[key]) # check if they have the same modules - self.assertEqual( - set(transformers_model_from_save.state_dict().keys()), - set(transformers_model.state_dict().keys()), + assert set(transformers_model_from_save.state_dict().keys()) == set( + transformers_model.state_dict().keys() ) -class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): +class TestCausalLMValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): """ Testing suite for v-head models. """ @@ -198,10 +188,9 @@ class CausalLMValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): trl_model_class = AutoModelForCausalLMWithValueHead transformers_model_class = AutoModelForCausalLM - def tearDown(self): + def teardown_method(self): # free memory gc.collect() - super().tearDown() def test_inference(self): r""" @@ -217,7 +206,7 @@ def test_inference(self): # Check if the outputs are of the right size - here # we always output 3 values - logits, loss, and value states - self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + assert len(outputs) == EXPECTED_OUTPUT_SIZE def test_dropout_config(self): r""" @@ -229,7 +218,7 @@ def test_dropout_config(self): model = self.trl_model_class.from_pretrained(pretrained_model) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob def test_dropout_kwargs(self): r""" @@ -241,12 +230,12 @@ def test_dropout_kwargs(self): model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 @parameterized.expand(ALL_CAUSAL_LM_MODELS) def test_generate(self, model_name): @@ -271,21 +260,20 @@ def test_transformers_bf16_kwargs(self): lm_head_namings = ["lm_head", "embed_out", "output_layer"] - self.assertTrue( - any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), - "Can't test the model because it doesn't have any of the expected lm_head namings", + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings), ( + "Can't test the model because it doesn't have any of the expected lm_head namings" ) for lm_head_naming in lm_head_namings: if hasattr(trl_model.pretrained_model, lm_head_naming): - self.assertEqual(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype, torch.bfloat16) + assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16 dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device) # check dummy forward pass works in half precision _ = trl_model(dummy_input) - @unittest.skip("This test needs to be run manually due to HF token issue.") + @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.") def test_push_to_hub(self): for model_name in self.all_model_names: model = AutoModelForCausalLMWithValueHead.from_pretrained(model_name) @@ -296,16 +284,15 @@ def test_push_to_hub(self): model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(model_name + "-ppo") # check all keys - self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() for name, param in model.state_dict().items(): - self.assertTrue( - torch.allclose(param, model_from_pretrained.state_dict()[name]), - f"Parameter {name} is not the same after push_to_hub and from_pretrained", + assert torch.allclose(param, model_from_pretrained.state_dict()[name]), ( + f"Parameter {name} is not the same after push_to_hub and from_pretrained" ) -class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): +class TestSeq2SeqValueHeadModel(BaseTester.VHeadModelTester, TrlTestCase): """ Testing suite for v-head models. """ @@ -314,10 +301,9 @@ class Seq2SeqValueHeadModelTester(BaseTester.VHeadModelTester, TrlTestCase): trl_model_class = AutoModelForSeq2SeqLMWithValueHead transformers_model_class = AutoModelForSeq2SeqLM - def tearDown(self): + def teardown_method(self): # free memory gc.collect() - super().tearDown() def test_inference(self): r""" @@ -334,7 +320,7 @@ def test_inference(self): # Check if the outputs are of the right size - here # we always output 3 values - logits, loss, and value states - self.assertEqual(len(outputs), EXPECTED_OUTPUT_SIZE) + assert len(outputs) == EXPECTED_OUTPUT_SIZE def test_dropout_config(self): r""" @@ -346,7 +332,7 @@ def test_dropout_config(self): model = self.trl_model_class.from_pretrained(pretrained_model) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, pretrained_model.config.summary_dropout_prob) + assert model.v_head.dropout.p == pretrained_model.config.summary_dropout_prob def test_dropout_kwargs(self): r""" @@ -358,12 +344,12 @@ def test_dropout_kwargs(self): model = self.trl_model_class.from_pretrained(model_name, **v_head_kwargs) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 model = self.trl_model_class.from_pretrained(model_name, summary_dropout_prob=0.5) # Check if v head of the model has the same dropout as the config - self.assertEqual(model.v_head.dropout.p, 0.5) + assert model.v_head.dropout.p == 0.5 @parameterized.expand(ALL_SEQ2SEQ_MODELS) def test_generate(self, model_name): @@ -378,7 +364,7 @@ def test_generate(self, model_name): # Just check if the generation works _ = model.generate(input_ids, decoder_input_ids=decoder_input_ids, generation_config=generation_config) - @unittest.skip("This test needs to be run manually due to HF token issue.") + @pytest.mark.skip(reason="This test needs to be run manually due to HF token issue.") def test_push_to_hub(self): for model_name in self.all_model_names: model = self.trl_model_class.from_pretrained(model_name) @@ -389,12 +375,11 @@ def test_push_to_hub(self): model_from_pretrained = self.trl_model_class.from_pretrained(model_name + "-ppo") # check all keys - self.assertEqual(model.state_dict().keys(), model_from_pretrained.state_dict().keys()) + assert model.state_dict().keys() == model_from_pretrained.state_dict().keys() for name, param in model.state_dict().items(): - self.assertTrue( - torch.allclose(param, model_from_pretrained.state_dict()[name]), - f"Parameter {name} is not the same after push_to_hub and from_pretrained", + assert torch.allclose(param, model_from_pretrained.state_dict()[name]), ( + f"Parameter {name} is not the same after push_to_hub and from_pretrained" ) def test_transformers_bf16_kwargs(self): @@ -408,13 +393,11 @@ def test_transformers_bf16_kwargs(self): lm_head_namings = self.trl_model_class.lm_head_namings - self.assertTrue( - any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) - ) + assert any(hasattr(trl_model.pretrained_model, lm_head_naming) for lm_head_naming in lm_head_namings) for lm_head_naming in lm_head_namings: if hasattr(trl_model.pretrained_model, lm_head_naming): - self.assertTrue(getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16) + assert getattr(trl_model.pretrained_model, lm_head_naming).weight.dtype == torch.bfloat16 dummy_input = torch.LongTensor([[0, 1, 0, 1]]).to(self.device) @@ -422,9 +405,8 @@ def test_transformers_bf16_kwargs(self): _ = trl_model(input_ids=dummy_input, decoder_input_ids=dummy_input) -class ReferenceModelTest(TrlTestCase): - def setUp(self): - super().setUp() +class TestReferenceModel(TrlTestCase): + def setup_method(self): self.model = AutoModelForCausalLMWithValueHead.from_pretrained("trl-internal-testing/tiny-GPT2LMHeadModel") self.test_input = torch.tensor([[0, 1, 2, 3]]) self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=1) @@ -453,16 +435,16 @@ def test_independent_reference(self): last_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() # before optimization ref and model are identical - self.assertTrue((first_layer_before == first_ref_layer_before).all()) - self.assertTrue((last_layer_before == last_ref_layer_before).all()) + assert (first_layer_before == first_ref_layer_before).all() + assert (last_layer_before == last_ref_layer_before).all() # ref model stays identical after optimization - self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) - self.assertTrue((last_ref_layer_before == last_ref_layer_after).all()) + assert (first_ref_layer_before == first_ref_layer_after).all() + assert (last_ref_layer_before == last_ref_layer_after).all() # optimized model changes - self.assertFalse((first_layer_before == first_layer_after).all()) - self.assertFalse((last_layer_before == last_layer_after).all()) + assert not (first_layer_before == first_layer_after).all() + assert not (last_layer_before == last_layer_after).all() def test_shared_layers(self): layer_0 = self.layer_format.format(layer=0) @@ -487,15 +469,15 @@ def test_shared_layers(self): second_ref_layer_after = ref_model.get_parameter(layer_1).data.clone() # before optimization ref and model are identical - self.assertTrue((first_layer_before == first_ref_layer_before).all()) - self.assertTrue((second_layer_before == second_ref_layer_before).all()) + assert (first_layer_before == first_ref_layer_before).all() + assert (second_layer_before == second_ref_layer_before).all() # ref model stays identical after optimization - self.assertTrue((first_ref_layer_before == first_ref_layer_after).all()) - self.assertTrue((second_ref_layer_before == second_ref_layer_after).all()) + assert (first_ref_layer_before == first_ref_layer_after).all() + assert (second_ref_layer_before == second_ref_layer_after).all() # first layer of optimized model stays the same - self.assertTrue((first_layer_before == first_layer_after).all()) + assert (first_layer_before == first_layer_after).all() # other layers in optimized model change - self.assertFalse((second_layer_before == second_layer_after).all()) + assert not (second_layer_before == second_layer_after).all() diff --git a/tests/test_nash_md_trainer.py b/tests/test_nash_md_trainer.py index 4550c35e1d1..d6026e73443 100644 --- a/tests/test_nash_md_trainer.py +++ b/tests/test_nash_md_trainer.py @@ -16,12 +16,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import NashMDConfig, NashMDTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender +from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft if is_peft_available(): @@ -29,8 +28,7 @@ class TestNashMDTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -65,7 +63,7 @@ def test_nash_md_trainer_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft(self): @@ -93,7 +91,7 @@ def test_training_with_peft(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_and_ref_model(self): @@ -122,7 +120,7 @@ def test_training_with_peft_and_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_model_and_peft_config(self): @@ -153,7 +151,7 @@ def test_training_with_peft_model_and_peft_config(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): @@ -184,7 +182,7 @@ def test_training_pre_pefted_model_implicit_ref_with_reward_model(self): trainer.train() - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_llm_blender @@ -215,4 +213,4 @@ def test_nash_md_trainer_judge_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] diff --git a/tests/test_online_dpo_trainer.py b/tests/test_online_dpo_trainer.py index b9dec1135e6..f8706770371 100644 --- a/tests/test_online_dpo_trainer.py +++ b/tests/test_online_dpo_trainer.py @@ -18,12 +18,19 @@ from packaging.version import Version from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft, require_torch_accelerator, require_vision +from transformers.testing_utils import require_torch_accelerator from transformers.utils import is_peft_available, is_vision_available from trl import OnlineDPOConfig, OnlineDPOTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_vllm +from .testing_utils import ( + RandomPairwiseJudge, + TrlTestCase, + require_llm_blender, + require_peft, + require_vision, + require_vllm, +) if is_peft_available(): @@ -36,8 +43,7 @@ class TestOnlineDPOTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -73,7 +79,7 @@ def test_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_training_model_str(self): training_args = OnlineDPOConfig( @@ -98,7 +104,7 @@ def test_training_model_str(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_training_with_ref_model(self): training_args = OnlineDPOConfig( @@ -124,7 +130,7 @@ def test_training_with_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_ref_model_is_model(self): training_args = OnlineDPOConfig( @@ -136,7 +142,7 @@ def test_ref_model_is_model(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") - with self.assertRaises(ValueError): + with pytest.raises(ValueError): OnlineDPOTrainer( model=self.model, ref_model=self.model, # ref_model can't be the same as model @@ -174,7 +180,7 @@ def test_training_with_peft(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_and_ref_model(self): @@ -204,7 +210,7 @@ def test_training_with_peft_and_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_model_and_peft_config(self): @@ -236,7 +242,7 @@ def test_training_with_peft_model_and_peft_config(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_llm_blender @@ -262,7 +268,7 @@ def test_training_with_judge(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_torch_accelerator @@ -293,7 +299,7 @@ def test_training_with_vllm(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_vllm def test_training_with_vllm_colocate(self): @@ -330,57 +336,57 @@ def test_training_with_vllm_colocate(self): ) # Verify vLLM setup - self.assertTrue(trainer.use_vllm) - self.assertEqual(trainer.vllm_mode, "colocate") - self.assertIsNotNone(trainer.llm) + assert trainer.use_vllm + assert trainer.vllm_mode == "colocate" + assert trainer.llm is not None # self.assertIsNone(trainer.vllm_client) # self.assertEqual(trainer.vllm_gpu_memory_utilization, 0.2) # Verify generation parameters - self.assertEqual(trainer.temperature, 0.9) - self.assertEqual(trainer.top_p, 0.95) - self.assertEqual(trainer.top_k, 50) - self.assertEqual(trainer.repetition_penalty, 1.1) + assert trainer.temperature == 0.9 + assert trainer.top_p == 0.95 + assert trainer.top_k == 50 + assert trainer.repetition_penalty == 1.1 # Verify generation config - self.assertIsNotNone(trainer.generation_config) - self.assertEqual(trainer.generation_config.temperature, 0.9) - self.assertEqual(trainer.generation_config.top_p, 0.95) - self.assertEqual(trainer.generation_config.top_k, 50) - self.assertEqual(trainer.generation_config.repetition_penalty, 1.1) - self.assertEqual(trainer.generation_config.max_tokens, 32) + assert trainer.generation_config is not None + assert trainer.generation_config.temperature == 0.9 + assert trainer.generation_config.top_p == 0.95 + assert trainer.generation_config.top_k == 50 + assert trainer.generation_config.repetition_penalty == 1.1 + assert trainer.generation_config.max_tokens == 32 trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] def test_vllm_config_validation(self): """Test vLLM configuration validation""" # Test valid vllm_mode values config = OnlineDPOConfig(use_vllm=True, vllm_mode="server") - self.assertEqual(config.vllm_mode, "server") + assert config.vllm_mode == "server" config = OnlineDPOConfig(use_vllm=True, vllm_mode="colocate") - self.assertEqual(config.vllm_mode, "colocate") + assert config.vllm_mode == "colocate" # Test default values config = OnlineDPOConfig() - self.assertEqual(config.vllm_mode, "server") - self.assertIsNone(config.vllm_server_base_url) - self.assertEqual(config.vllm_server_host, "0.0.0.0") - self.assertEqual(config.vllm_server_port, 8000) - self.assertEqual(config.vllm_server_timeout, 240.0) - self.assertEqual(config.vllm_gpu_memory_utilization, 0.55) + assert config.vllm_mode == "server" + assert config.vllm_server_base_url is None + assert config.vllm_server_host == "0.0.0.0" + assert config.vllm_server_port == 8000 + assert config.vllm_server_timeout == 240.0 + assert config.vllm_gpu_memory_utilization == 0.55 # Test generation parameters - self.assertEqual(config.top_p, 1.0) - self.assertIsNone(config.top_k) - self.assertIsNone(config.min_p) - self.assertEqual(config.repetition_penalty, 1.0) - self.assertFalse(config.use_transformers_paged) - self.assertIsNone(config.cache_implementation) - self.assertIsNone(config.generation_kwargs) + assert config.top_p == 1.0 + assert config.top_k is None + assert config.min_p is None + assert config.repetition_penalty == 1.0 + assert not config.use_transformers_paged + assert config.cache_implementation is None + assert config.generation_kwargs is None def test_generation_config_setup(self): """Test that generation configuration is properly set up for both vLLM and transformers""" @@ -407,17 +413,17 @@ def test_generation_config_setup(self): ) # Verify transformers generation config - self.assertFalse(trainer.use_vllm) + assert not trainer.use_vllm # When not using vLLM, these attributes should not be set - self.assertFalse(hasattr(trainer, "llm") and trainer.llm is not None) - self.assertFalse(hasattr(trainer, "vllm_client") and trainer.vllm_client is not None) - self.assertIsNotNone(trainer.generation_config) - self.assertEqual(trainer.generation_config.temperature, 0.8) - self.assertEqual(trainer.generation_config.top_p, 0.9) - self.assertEqual(trainer.generation_config.top_k, 40) - self.assertEqual(trainer.generation_config.repetition_penalty, 1.2) - self.assertEqual(trainer.generation_config.max_new_tokens, 64) - self.assertFalse(trainer.generation_config.do_sample) # From generation_kwargs + assert not (hasattr(trainer, "llm") and trainer.llm is not None) + assert not (hasattr(trainer, "vllm_client") and trainer.vllm_client is not None) + assert trainer.generation_config is not None + assert trainer.generation_config.temperature == 0.8 + assert trainer.generation_config.top_p == 0.9 + assert trainer.generation_config.top_k == 40 + assert trainer.generation_config.repetition_penalty == 1.2 + assert trainer.generation_config.max_new_tokens == 64 + assert not trainer.generation_config.do_sample # From generation_kwargs @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_torch_accelerator @@ -447,7 +453,7 @@ def test_training_with_transformers_paged(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) def test_training_with_reward_funcs(self, config_name): @@ -475,15 +481,15 @@ def simple_reward_func(prompts, completions, completion_ids, **kwargs): ) trainer.train() - self.assertIn("train_loss", trainer.state.log_history[-1]) - self.assertEqual(len(trainer.reward_funcs), 2) - self.assertIsNotNone(trainer.reward_weights) - self.assertAlmostEqual(trainer.reward_weights[0].item(), 0.7, places=5) - self.assertAlmostEqual(trainer.reward_weights[1].item(), 0.3, places=5) + assert "train_loss" in trainer.state.log_history[-1] + assert len(trainer.reward_funcs) == 2 + assert trainer.reward_weights is not None + assert round(abs(trainer.reward_weights[0].item() - 0.7), 5) == 0 + assert round(abs(trainer.reward_weights[1].item() - 0.3), 5) == 0 @require_vision -class OnlineDPOVisionTrainerTester(TrlTestCase): +class TestOnlineDPOVisionTrainer(TrlTestCase): @parameterized.expand( [ ("trl-internal-testing/tiny-Idefics2ForConditionalGeneration",), @@ -531,4 +537,4 @@ def test_online_dpo_vlm_trainer(self, model_id): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None diff --git a/tests/test_orpo_trainer.py b/tests/test_orpo_trainer.py index 5898ac8d7dd..dedfc4c36c9 100644 --- a/tests/test_orpo_trainer.py +++ b/tests/test_orpo_trainer.py @@ -17,17 +17,15 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer -from transformers.testing_utils import require_peft from trl import ORPOConfig, ORPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft -class ORPOTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestORPOTrainer(TrlTestCase): + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.tokenizer = AutoTokenizer.from_pretrained(self.model_id) @@ -82,13 +80,13 @@ def test_orpo_trainer(self, name, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) @parameterized.expand( [ @@ -137,14 +135,14 @@ def test_orpo_trainer_with_lora(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): if "lora" in n: new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.equal(param, new_param)) + assert not torch.equal(param, new_param) def test_compute_metrics(self): model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") @@ -178,4 +176,4 @@ def dummy_compute_metrics(*args, **kwargs): trainer.train() - self.assertEqual(trainer.state.log_history[-2]["eval_test"], 0.0) + assert trainer.state.log_history[-2]["eval_test"] == 0.0 diff --git a/tests/test_peft_models.py b/tests/test_peft_models.py index 0543ee31c3c..508ad175565 100644 --- a/tests/test_peft_models.py +++ b/tests/test_peft_models.py @@ -16,15 +16,12 @@ import torch from transformers import AutoModelForCausalLM -from transformers.testing_utils import ( - require_peft, - require_torch_gpu_if_bnb_not_multi_backend_enabled, -) +from transformers.testing_utils import require_torch_gpu_if_bnb_not_multi_backend_enabled from transformers.utils import is_peft_available from trl import AutoModelForCausalLMWithValueHead -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): @@ -32,9 +29,8 @@ @require_peft -class PeftModelTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestPeftModel(TrlTestCase): + def setup_method(self): self.causal_lm_model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.lora_config = LoraConfig( r=16, @@ -63,7 +59,7 @@ def test_peft_requires_grad(self): model = AutoModelForCausalLMWithValueHead.from_pretrained(pretrained_model) # Check that the value head has requires_grad=True - self.assertTrue(model.v_head.summary.weight.requires_grad) + assert model.v_head.summary.weight.requires_grad def test_check_peft_model_nb_trainable_params(self): r""" @@ -76,12 +72,12 @@ def test_check_peft_model_nb_trainable_params(self): # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 # Check that the number of trainable param for the non-peft model is correct non_peft_model = AutoModelForCausalLMWithValueHead.from_pretrained(self.causal_lm_model_id) nb_trainable_params = sum(p.numel() for p in non_peft_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 2428641) + assert nb_trainable_params == 2428641 def test_create_peft_model_from_config(self): r""" @@ -92,13 +88,13 @@ def test_create_peft_model_from_config(self): ) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 causal_lm_model = AutoModelForCausalLM.from_pretrained(self.causal_lm_model_id) trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 @require_torch_gpu_if_bnb_not_multi_backend_enabled def test_create_bnb_peft_model_from_config(self): @@ -112,8 +108,8 @@ def test_create_bnb_peft_model_from_config(self): ) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) - self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + assert nb_trainable_params == 905 + assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) causal_lm_model = AutoModelForCausalLM.from_pretrained( self.causal_lm_model_id, load_in_8bit=True, device_map="auto" @@ -121,8 +117,8 @@ def test_create_bnb_peft_model_from_config(self): trl_model = AutoModelForCausalLMWithValueHead.from_pretrained(causal_lm_model, peft_config=self.lora_config) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in trl_model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) - self.assertIsInstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) + assert nb_trainable_params == 905 + assert isinstance(trl_model.pretrained_model.model.model.layers[0].mlp.gate_proj, Linear8bitLt) def test_save_pretrained_peft(self): r""" @@ -136,31 +132,27 @@ def test_save_pretrained_peft(self): model.save_pretrained(self.tmp_dir) # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory - self.assertTrue( - os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), - f"{self.tmp_dir}/adapter_model.safetensors does not exist", + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), ( + f"{self.tmp_dir}/adapter_model.safetensors does not exist" ) - self.assertTrue( - os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), ( + f"{self.tmp_dir}/adapter_config.json does not exist" ) # check also for `pytorch_model.bin` and make sure it only contains `v_head` weights - self.assertTrue( - os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist" - ) + assert os.path.exists(f"{self.tmp_dir}/pytorch_model.bin"), f"{self.tmp_dir}/pytorch_model.bin does not exist" # check that only keys that starts with `v_head` are in the dict maybe_v_head = torch.load(f"{self.tmp_dir}/pytorch_model.bin", weights_only=True) - self.assertTrue( - all(k.startswith("v_head") for k in maybe_v_head.keys()), - f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`", + assert all(k.startswith("v_head") for k in maybe_v_head.keys()), ( + f"keys in {self.tmp_dir}/pytorch_model.bin do not start with `v_head`" ) model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) # check all the weights are the same for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): - self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}") + assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}" def test_load_pretrained_peft(self): r""" @@ -175,18 +167,17 @@ def test_load_pretrained_peft(self): model_from_pretrained = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir) # check that the files `adapter_model.safetensors` and `adapter_config.json` are in the directory - self.assertTrue( - os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), - f"{self.tmp_dir}/adapter_model.safetensors does not exist", + assert os.path.isfile(f"{self.tmp_dir}/adapter_model.safetensors"), ( + f"{self.tmp_dir}/adapter_model.safetensors does not exist" ) - self.assertTrue( - os.path.exists(f"{self.tmp_dir}/adapter_config.json"), f"{self.tmp_dir}/adapter_config.json does not exist" + assert os.path.exists(f"{self.tmp_dir}/adapter_config.json"), ( + f"{self.tmp_dir}/adapter_config.json does not exist" ) # check all the weights are the same for p1, p2 in zip(model.named_parameters(), model_from_pretrained.named_parameters()): if p1[0] not in ["v_head.summary.weight", "v_head.summary.bias"]: - self.assertTrue(torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}") + assert torch.allclose(p1[1], p2[1]), f"{p1[0]} != {p2[0]}" def test_continue_training_peft_model(self): r""" @@ -200,4 +191,4 @@ def test_continue_training_peft_model(self): model = AutoModelForCausalLMWithValueHead.from_pretrained(self.tmp_dir, is_trainable=True) # Check that the number of trainable parameters is correct nb_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - self.assertEqual(nb_trainable_params, 905) + assert nb_trainable_params == 905 diff --git a/tests/test_ppo_trainer.py b/tests/test_ppo_trainer.py index f8b95a8e5ff..6e62e742115 100644 --- a/tests/test_ppo_trainer.py +++ b/tests/test_ppo_trainer.py @@ -16,13 +16,12 @@ import torch from datasets import load_dataset from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import PPOConfig, PPOTrainer from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): @@ -30,8 +29,7 @@ class TestPPOTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Set up the models and tokenizer using the test model self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -107,8 +105,8 @@ def test_basic_training(self): policy_weights_updated = True break - self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") - self.assertTrue(policy_weights_updated, "Policy weights were not updated during training") + assert critic_weights_updated, "Critic weights were not updated during training" + assert policy_weights_updated, "Policy weights were not updated during training" @require_peft def test_peft_training(self): @@ -171,5 +169,5 @@ def test_peft_training(self): policy_weights_updated = True break - self.assertTrue(critic_weights_updated, "Critic weights were not updated during training") - self.assertTrue(policy_weights_updated, "Policy LoRA weights were not updated during training") + assert critic_weights_updated, "Critic weights were not updated during training" + assert policy_weights_updated, "Policy LoRA weights were not updated during training" diff --git a/tests/test_prm_trainer.py b/tests/test_prm_trainer.py index e26428c5203..16876c6df62 100644 --- a/tests/test_prm_trainer.py +++ b/tests/test_prm_trainer.py @@ -18,12 +18,11 @@ from datasets import Dataset, load_dataset from parameterized import parameterized from transformers import AutoModelForTokenClassification, AutoTokenizer, PreTrainedTokenizerBase -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import PRMConfig, PRMTrainer -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): @@ -31,8 +30,7 @@ class TestTokenizeRow(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Set up the mock tokenizer with specific behaviors self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase) self.tokenizer.bos_token_id = 0 @@ -75,13 +73,10 @@ def test_tokenize_row_no_truncation(self): is_eval=False, ) - self.assertEqual( - result, - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], - }, - ) + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + } def test_tokenize_row_train_on_last_step_only(self): # Define the input features @@ -102,13 +97,10 @@ def test_tokenize_row_train_on_last_step_only(self): is_eval=False, ) - self.assertEqual( - result, - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], - }, - ) + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100, 0], + } def test_tokenize_row_prompt_truncation(self): # Define the input features @@ -130,13 +122,10 @@ def test_tokenize_row_prompt_truncation(self): is_eval=False, ) - self.assertEqual( - result, - { - "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], - "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], - }, - ) + assert result == { + "input_ids": [6766, 318, 298, 4, 322, 12, 1030, 4995, 11, 22, 1030], + "labels": [-100, -100, -100, -100, -100, -100, 1, -100, -100, -100, 0], + } def test_tokenize_row_completion_truncation(self): # Define the input features @@ -158,13 +147,10 @@ def test_tokenize_row_completion_truncation(self): is_eval=False, ) - self.assertEqual( - result, - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], - }, - ) + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 4995, 11], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100], + } def test_tokenize_row_prompt_completion_truncation(self): # Define the input features @@ -186,13 +172,10 @@ def test_tokenize_row_prompt_completion_truncation(self): is_eval=False, ) - self.assertEqual( - result, - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], - }, - ) + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, 1], + } def test_tokenize_row_multi_token_separator(self): # Define the input features @@ -214,18 +197,14 @@ def test_tokenize_row_multi_token_separator(self): is_eval=False, ) - self.assertEqual( - result, - { - "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], - "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], - }, - ) + assert result == { + "input_ids": [0, 465, 6766, 318, 298, 4, 322, 12, 1030, 1030, 4995, 11, 22, 1030, 1030], + "labels": [-100, -100, -100, -100, -100, -100, -100, -100, -100, 1, -100, -100, -100, -100, 0], + } -class PRMTrainerTester(TrlTestCase): - def setUp(self): - super().setUp() +class TestPRMTrainer(TrlTestCase): + def setup_method(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForTokenClassification.from_pretrained(model_id) self.tokenizer = AutoTokenizer.from_pretrained(model_id) @@ -244,12 +223,12 @@ def test_train_full(self, train_on_last_step_only): previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) def test_train_full_pretokenized(self): dummy_dataset = Dataset.from_dict( @@ -297,12 +276,12 @@ def test_train_full_pretokenized(self): previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if param.sum() != 0: # ignore 0 biases - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12)) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12) @require_peft def test_train_lora(self): @@ -337,17 +316,17 @@ def test_train_lora(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[(-1)]["train_loss"]) + assert trainer.state.log_history[(-1)]["train_loss"] is not None # Check that the parameters have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + assert not torch.allclose(param, new_param, atol=1e-12, rtol=1e-12) # Check that the non trainable parameters have not changed for n, param in previous_non_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.allclose(param, new_param, atol=1e-12, rtol=1e-12)) + assert torch.allclose(param, new_param, atol=1e-12, rtol=1e-12) def test_tags(self): dummy_dataset = load_dataset("trl-internal-testing/zen", "standard_stepwise_supervision", split="train") @@ -355,4 +334,4 @@ def test_tags(self): trainer = PRMTrainer( model=self.model, args=training_args, processing_class=self.tokenizer, train_dataset=dummy_dataset ) - self.assertEqual(trainer.model.model_tags, trainer._tag_names) + assert trainer.model.model_tags == trainer._tag_names diff --git a/tests/test_reward_trainer.py b/tests/test_reward_trainer.py index b4d53e16941..ab6d6656e99 100644 --- a/tests/test_reward_trainer.py +++ b/tests/test_reward_trainer.py @@ -13,19 +13,18 @@ # limitations under the License. import pathlib -import unittest +import pytest import torch from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import RewardConfig, RewardTrainer from trl.trainer.reward_trainer import DataCollatorForPreference -from .testing_utils import TrlTestCase +from .testing_utils import TrlTestCase, require_peft if is_peft_available(): @@ -108,7 +107,7 @@ def test_collate_with_margin(self): torch.testing.assert_close(result["margin"], torch.tensor([0.1, 0.2])) -class RewardTrainerTester(TrlTestCase): +class TestRewardTrainer(TrlTestCase): @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",), @@ -131,12 +130,12 @@ def test_train(self, model_id): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @parameterized.expand( [ @@ -165,12 +164,12 @@ def test_train_dataset_types(self, config_name): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model(self): # Instantiate the model @@ -192,12 +191,12 @@ def test_train_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_from_causal_lm(self): # Get the dataset @@ -216,12 +215,12 @@ def test_train_from_causal_lm(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model_dtype(self): # Get the dataset @@ -247,7 +246,7 @@ def test_train_model_dtype(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -257,8 +256,8 @@ def test_train_model_dtype(self): continue new_param = trainer.model.get_parameter(n) # Check the torch dtype - self.assertEqual(new_param.dtype, torch.float16) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert new_param.dtype == torch.float16 + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config(self): @@ -287,15 +286,15 @@ def test_train_dense_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config(self): @@ -324,15 +323,15 @@ def test_train_moe_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_peft_model(self): @@ -361,15 +360,15 @@ def test_train_peft_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config_and_gradient_checkpointing(self): @@ -398,15 +397,15 @@ def test_train_dense_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config_and_gradient_checkpointing(self): @@ -435,15 +434,15 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_with_peft_model_and_gradient_checkpointing(self): @@ -462,7 +461,7 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer = RewardTrainer(model=model, args=training_args, train_dataset=dataset) # Verify model is a PeftModel - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -471,15 +470,15 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_pretokenized_data(self): # Get the dataset @@ -507,12 +506,12 @@ def tokenize_example(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_iterable_dataset(self): # Get the dataset @@ -535,12 +534,12 @@ def test_train_with_iterable_dataset(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_chat_template_kwargs(self): # Get the dataset @@ -569,12 +568,12 @@ def test_train_with_chat_template_kwargs(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_model(self): # Get the dataset @@ -596,7 +595,7 @@ def test_train_with_set_chat_template_from_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -606,7 +605,7 @@ def test_train_with_set_chat_template_from_model(self): # this parameter. if n == "gpt_neox.final_layer_norm.bias": continue - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_path(self): # Get the dataset @@ -632,7 +631,7 @@ def test_train_with_set_chat_template_from_path(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -642,21 +641,19 @@ def test_train_with_set_chat_template_from_path(self): # this parameter. if n == "gpt_neox.final_layer_norm.bias": continue - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Check that the template saved in the output directory is the same as the one used for training template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" - self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}") + assert template_path.exists(), f"Chat template not found at {template_path}" with open(template_path) as f: template_content = f.read() with open(training_args.chat_template_path) as f: original_template_content = f.read() - self.assertEqual( - template_content, original_template_content, "Chat template content does not match the original" - ) + assert template_content == original_template_content, "Chat template content does not match the original" - @unittest.skip("Skipping until we have a dataset with tool calls") + @pytest.mark.skip(reason="Skipping until we have a dataset with tool calls") def test_train_toolcall_data(self): # Get the dataset dataset = load_dataset("trl-internal-testing/toolcall", split="train") @@ -676,12 +673,12 @@ def test_train_toolcall_data(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_eval(self): # Get the dataset @@ -700,7 +697,7 @@ def test_train_with_eval(self): trainer.train() # Check that the eval loss is not None - self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + assert trainer.state.log_history[0]["eval_loss"] is not None def test_train_with_multiple_eval_dataset(self): # Get the dataset @@ -718,8 +715,8 @@ def test_train_with_multiple_eval_dataset(self): trainer.train() # Check that the eval losses are not None - self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"]) - self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"]) + assert trainer.state.log_history[-3]["eval_data1_loss"] is not None + assert trainer.state.log_history[-2]["eval_data2_loss"] is not None def test_train_with_gradient_checkpointing(self): # Get the dataset @@ -740,12 +737,12 @@ def test_train_with_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_tag_added(self): # Get the dataset @@ -758,7 +755,7 @@ def test_tag_added(self): ) for tag in ["reward-trainer", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_tag_added_peft(self): @@ -773,7 +770,7 @@ def test_tag_added_peft(self): ) for tag in ["reward-trainer", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags def test_train_with_margin(self): # Get the dataset @@ -800,12 +797,12 @@ def add_margin(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_center_rewards_coefficient(self): # Get the dataset @@ -826,9 +823,9 @@ def test_train_with_center_rewards_coefficient(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" diff --git a/tests/test_rewards.py b/tests/test_rewards.py index 8b20a0ff7e9..0764ce5d9ea 100644 --- a/tests/test_rewards.py +++ b/tests/test_rewards.py @@ -12,14 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from trl.rewards import get_soft_overlong_punishment, think_format_reward from .testing_utils import TrlTestCase -class ThinkFormatRewardTester(TrlTestCase): +class TestThinkFormatReward(TrlTestCase): def test_valid_format(self): completions = [ "This is my reasoning.This is my answer.", # Simple, one-line reasoning @@ -31,7 +30,7 @@ def test_valid_format(self): completions = [[{"content": completion}] for completion in completions] expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid rewards = think_format_reward(completions) - self.assertEqual(rewards, expected_rewards) + assert rewards == expected_rewards def test_invalid_format(self): completions = [ @@ -48,7 +47,7 @@ def test_invalid_format(self): completions = [[{"content": completion}] for completion in completions] expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid rewards = think_format_reward(completions) - self.assertEqual(rewards, expected_rewards) + assert rewards == expected_rewards def test_mixed_format(self): completions = [ @@ -60,17 +59,17 @@ def test_mixed_format(self): completions = [[{"content": completion}] for completion in completions] expected_rewards = [1.0, 1.0, 0.0, 0.0] rewards = think_format_reward(completions) - self.assertEqual(rewards, expected_rewards) + assert rewards == expected_rewards -class SoftOverlongPunishmentRewardTester(unittest.TestCase): +class TestSoftOverlongPunishmentReward: def test_soft_overlong_punishment_short_completion(self): """Test soft overlong punishment reward function with a short completion.""" # length 50, with max=100 and soft cache=20, reward should be 0. reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 50] # 50 <= 80 rewards = reward_fn(completion_ids=completion_ids) - self.assertEqual(rewards, [0]) + assert rewards == [0] def test_soft_overlong_punishment_long_completion(self): """Test soft overlong punishment reward function with a longer than max completion.""" @@ -78,15 +77,11 @@ def test_soft_overlong_punishment_long_completion(self): reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 110] rewards = reward_fn(completion_ids) - self.assertEqual(rewards, [-1]) + assert rewards == [-1] def test_soft_overlong_punishment_intermediate_completion(self): """Test soft overlong punishment reward function for intermediate length completion.""" reward_fn = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20) completion_ids = [[1] * 90] # 90 is between 80 and 100 rewards = reward_fn(completion_ids) - self.assertAlmostEqual(rewards[0], -0.5, places=4) - - -if __name__ == "__main__": - unittest.main() + assert round(abs(rewards[0] - -0.5), 4) == 0 diff --git a/tests/test_rich_progress_callback.py b/tests/test_rich_progress_callback.py index d9069481263..d246b694b72 100644 --- a/tests/test_rich_progress_callback.py +++ b/tests/test_rich_progress_callback.py @@ -34,8 +34,7 @@ def forward(self, x): @require_rich class TestRichProgressCallback(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.dummy_model = DummyModel() self.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5) self.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index a2b1d3bf8b7..1de4eca479e 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest from unittest.mock import patch +import pytest import torch from datasets import load_dataset from parameterized import parameterized @@ -24,19 +24,18 @@ AutoModelForSequenceClassification, AutoTokenizer, ) -from transformers.testing_utils import require_peft, require_vision from transformers.utils import is_peft_available from trl import RLOOConfig, RLOOTrainer -from .testing_utils import TrlTestCase, require_vllm +from .testing_utils import TrlTestCase, require_peft, require_vision, require_vllm if is_peft_available(): from peft import LoraConfig, PeftModel -class RLOOTrainerTester(TrlTestCase): +class TestRLOOTrainer(TrlTestCase): def test_init_minimal(self): # Test that RLOOTrainer can be instantiated with only model, reward_model and train_dataset dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -69,12 +68,12 @@ def test_training(self, config_name): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_eval(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only") @@ -122,12 +121,12 @@ def test_training_multiple_iterations(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft(self): @@ -155,15 +154,15 @@ def test_training_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_peft def test_training_peft_with_gradient_checkpointing(self): @@ -197,22 +196,22 @@ def test_training_peft_with_gradient_checkpointing(self): ) # Verify gradient checkpointing is enabled - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Store initial parameters to check which ones change previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that only LoRA parameters have changed, base model parameters remain unchanged for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "lora" in n.lower(): # LoRA parameters should change - self.assertFalse(torch.equal(param, new_param), f"LoRA parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"LoRA parameter {n} has not changed." else: # Base model parameters should not change - self.assertTrue(torch.equal(param, new_param), f"Base parameter {n} has changed.") + assert torch.equal(param, new_param), f"Base parameter {n} has changed." def test_training_different_reward_model(self): # Use a reward model different from the model: different chat template, tokenization, etc. @@ -246,12 +245,12 @@ def test_training_different_reward_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_standard(self): # Test if trainer can handle reward function with standard format @@ -280,12 +279,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_conversational(self): # Test if trainer can handle reward function with conversational format @@ -315,12 +314,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs(self): # Test that RLOOTrainer can be instantiated with multiple reward functions @@ -353,12 +352,12 @@ def reward_func2(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_None_output(self): """Test that a valid math reward function is processed correctly while the code reward function returns None.""" @@ -397,12 +396,12 @@ def non_applicable_reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_reward_funcs_with_weights(self): """Test that RLOOTrainer can handle multiple reward functions with weights.""" @@ -437,16 +436,16 @@ def reward_func2(completions, **kwargs): trainer.train() # Check that training logs contain both reward metrics - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIn("rewards/reward_func1/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func1/std", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/mean", trainer.state.log_history[-1]) - self.assertIn("rewards/reward_func2/std", trainer.state.log_history[-1]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert "rewards/reward_func1/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func1/std" in trainer.state.log_history[-1] + assert "rewards/reward_func2/mean" in trainer.state.log_history[-1] + assert "rewards/reward_func2/std" in trainer.state.log_history[-1] # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_mixed_reward_funcs(self): # Test if the trainer can handle a mix of reward functions and reward models @@ -475,12 +474,12 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_reward_func_additional_column(self): # Test if trainer can handle reward function that rely on additional columns in the dataset @@ -513,12 +512,12 @@ def reward_func(completions, some_values, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_sync_ref_model(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -544,12 +543,12 @@ def test_training_with_sync_ref_model(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_beta_zero(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -573,16 +572,16 @@ def test_training_beta_zero(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @unittest.skip("We should add a mock for the vLLM server.") @require_peft @require_vllm + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_and_peft(self): """Test that training works with vLLM for generation.""" model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct") # tiny model is too small for vLLM @@ -615,19 +614,19 @@ def test_training_vllm_and_peft(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n and "original_module" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_guided_decoding(self): """Test that training works with vLLM for generation with guided decoding.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -653,12 +652,12 @@ def test_training_vllm_guided_decoding(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_additional_generation_kwargs(self): """Test that training works with additional generation kwargs.""" @@ -688,15 +687,15 @@ def test_training_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vllm_with_additional_generation_kwargs(self): """Test that training works with vLLM and additional generation kwargs.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -726,12 +725,12 @@ def test_training_vllm_with_additional_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_normalized_advantages(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -756,12 +755,12 @@ def test_training_with_normalized_advantages(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_clipped_rewards(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -786,12 +785,12 @@ def test_training_with_clipped_rewards(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @patch("transformers.generation.utils.GenerationMixin.generate") def test_training_with_mask_truncated_completions(self, mock_generate): @@ -836,12 +835,12 @@ def fake_generate(input_ids, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_mask_truncated_completions_all_masked(self): """ @@ -874,14 +873,14 @@ def test_training_with_mask_truncated_completions_all_masked(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertTrue(torch.equal(param, new_param), f"Parameter {n} has changed.") + assert torch.equal(param, new_param), f"Parameter {n} has changed." - def test_warning_raised_all_rewards_none(self): + def test_warning_raised_all_rewards_none(self, caplog): """Test that a proper warning is raised when all rewards are None.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -904,11 +903,11 @@ def always_none_reward_func(completions, **kwargs): train_dataset=dataset, ) - with self.assertLogs("trl.trainer.rloo_trainer", level="WARNING") as cm: + with caplog.at_level("WARNING", logger="trl.trainer.rloo_trainer"): trainer.train() expected_warning = "All reward functions returned None for the following kwargs:" - self.assertIn(expected_warning, cm.output[0]) + assert expected_warning in caplog.text def test_training_num_generations_larger_than_batch_size(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -933,12 +932,12 @@ def test_training_num_generations_larger_than_batch_size(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_multiple_dataloader_workers(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -963,12 +962,12 @@ def test_training_multiple_dataloader_workers(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_generation_kwargs(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -993,12 +992,12 @@ def test_training_with_generation_kwargs(self): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_training_with_reward_func_accessing_trainer_state(self): dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1070,7 +1069,7 @@ def test_prepare_input_called_with_correct_data(self): with patch.object(RLOOTrainer, "training_step", wraps=trainer.training_step) as mock_prepare: trainer.train() # 3 epochs * 2 iterations * 2 generation batches to cover the dataset * 4 steps_per_generation - self.assertEqual(mock_prepare.call_count, 48) + assert mock_prepare.call_count == 48 for i in range(0, 8): # Generation batch repeated 8 times (steps_per_generation*num_iterations) assert mock_prepare.call_args_list[i].args[1] == expected_first_generation_batch for i in range(8, 16): @@ -1113,7 +1112,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1128,7 +1127,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_beta_non_zero(self): @@ -1158,7 +1157,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1168,7 +1167,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision @require_peft @@ -1203,15 +1202,15 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model params to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed.") + assert torch.allclose(param, new_param), f"Parameter {n} has changed." elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed.") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_and_prompt_truncation(self): @@ -1242,7 +1241,7 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check that the params have changed # Because of the way the tiny models are initialized, the gradient does not flow properly through the @@ -1252,7 +1251,7 @@ def reward_func(completions, **kwargs): if n.startswith(params_to_skip): continue new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @parameterized.expand( [ @@ -1262,7 +1261,7 @@ def reward_func(completions, **kwargs): ) @require_vision @require_vllm - @unittest.skip("We should add a mock for the vLLM server.") + @pytest.mark.skip(reason="We should add a mock for the vLLM server.") def test_training_vlm_and_vllm(self, model_id) -> None: dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") @@ -1292,11 +1291,11 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." @require_vision def test_training_vlm_multi_image(self): @@ -1329,11 +1328,11 @@ def reward_func(completions, **kwargs): trainer.train() - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.") + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." def test_mismatched_reward_processing_classes_length(self): """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" @@ -1352,7 +1351,7 @@ def test_mismatched_reward_processing_classes_length(self): training_args = RLOOConfig(output_dir=self.tmp_dir, report_to="none") - with self.assertRaises(ValueError) as context: + with pytest.raises(ValueError, match="must match"): RLOOTrainer( model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", reward_funcs=reward_models, @@ -1361,8 +1360,6 @@ def test_mismatched_reward_processing_classes_length(self): train_dataset=dataset, ) - self.assertIn("must match", str(context.exception)) - def test_correct_reward_processing_classes_list(self): """Test that correct list of reward_processing_classes works properly.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") @@ -1392,7 +1389,7 @@ def test_correct_reward_processing_classes_list(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), len(reward_models)) + assert len(trainer.reward_processing_classes) == len(reward_models) def test_single_reward_model_with_single_processing_class(self): """Test that single reward model with single processing class works.""" @@ -1416,9 +1413,5 @@ def test_single_reward_model_with_single_processing_class(self): train_dataset=dataset, ) - self.assertEqual(len(trainer.reward_processing_classes), 1) - self.assertEqual(trainer.reward_processing_classes[0], single_processing_class) - - -if __name__ == "__main__": - unittest.main() + assert len(trainer.reward_processing_classes) == 1 + assert trainer.reward_processing_classes[0] == single_processing_class diff --git a/tests/test_sft_trainer.py b/tests/test_sft_trainer.py index 19d5a1b5d70..5d1aacf2876 100644 --- a/tests/test_sft_trainer.py +++ b/tests/test_sft_trainer.py @@ -20,20 +20,20 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer -from transformers.testing_utils import require_flash_attn, require_liger_kernel, require_peft, require_vision +from transformers.testing_utils import require_flash_attn, require_liger_kernel from transformers.utils import is_peft_available from trl import SFTConfig, SFTTrainer from trl.trainer.sft_trainer import DataCollatorForLanguageModeling, dft_loss -from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes +from .testing_utils import TrlTestCase, ignore_warnings, require_bitsandbytes, require_peft, require_vision if is_peft_available(): from peft import LoraConfig, PeftModel, PromptEncoderConfig, TaskType, get_peft_model -class DFTLossTester(TrlTestCase): +class TestDFTLoss(TrlTestCase): def test_dft_loss(self): batch_size = 2 seq_len = 3 @@ -64,7 +64,7 @@ def test_basic_padding(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -79,7 +79,7 @@ def test_completion_mask(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3], [-100, 5, -100]])) @@ -95,7 +95,7 @@ def test_completion_only_loss_disabled(self): result = collator(examples) # Labels should not be masked when completion_only_loss=False - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -107,7 +107,7 @@ def test_padding_free_mode(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5]])) @@ -122,7 +122,7 @@ def test_padding_free_with_completion_mask(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, -100, 3, -100, 5]])) @@ -139,7 +139,7 @@ def test_packing(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 2, 3, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, 6, -100, 8, 9, 10, -100]])) @@ -151,7 +151,7 @@ def test_pad_to_multiple_of(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, -100], [4, 5, -100, -100]])) @@ -163,7 +163,7 @@ def test_pad_to_multiple_of_and_padding_free(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "position_ids", "labels"}) + assert set(result.keys()) == {"input_ids", "position_ids", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0]])) torch.testing.assert_close(result["position_ids"], torch.tensor([[0, 1, 2, 0, 1, 0, 0, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[-100, 2, 3, -100, 5, -100, -100, -100]])) @@ -175,7 +175,7 @@ def test_custom_position_ids_but_no_padding_free(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 0]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -187,7 +187,7 @@ def test_single_example(self): result = self.collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3, 4]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1, 1]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3, 4]])) @@ -199,7 +199,7 @@ def test_different_pad_token_id(self): result = collator(examples) - self.assertEqual(set(result.keys()), {"input_ids", "attention_mask", "labels"}) + assert set(result.keys()) == {"input_ids", "attention_mask", "labels"} torch.testing.assert_close(result["input_ids"], torch.tensor([[1, 2, 3], [4, 5, 999]])) torch.testing.assert_close(result["attention_mask"], torch.tensor([[1, 1, 1], [1, 1, 0]])) torch.testing.assert_close(result["labels"], torch.tensor([[1, 2, 3], [4, 5, -100]])) @@ -221,25 +221,25 @@ def test_assistant_masks(self): def test_single_example_single_doc(self): batch_seq_lengths = [[5]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) - self.assertEqual(len(result), 1) - self.assertTrue(torch.equal(result[0], torch.arange(5))) + assert len(result) == 1 + assert torch.equal(result[0], torch.arange(5)) def test_single_example_multiple_docs(self): batch_seq_lengths = [[3, 2]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) - self.assertEqual(len(result), 1) + assert len(result) == 1 # First sequence: 0, 1, 2; second sequence: 0, 1 - self.assertTrue(torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1]))) + assert torch.equal(result[0], torch.tensor([0, 1, 2, 0, 1])) def test_multiple_examples(self): batch_seq_lengths = [[2, 2], [3]] result = DataCollatorForLanguageModeling.get_position_ids_from_packed_seq_lengths(batch_seq_lengths) - self.assertEqual(len(result), 2) - self.assertTrue(torch.equal(result[0], torch.tensor([0, 1, 0, 1]))) - self.assertTrue(torch.equal(result[1], torch.arange(3))) + assert len(result) == 2 + assert torch.equal(result[0], torch.tensor([0, 1, 0, 1])) + assert torch.equal(result[1], torch.arange(3)) -class SFTTrainerTester(TrlTestCase): +class TestSFTTrainer(TrlTestCase): @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",), @@ -262,12 +262,12 @@ def test_train(self, model_id): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Special case for harmony def test_train_gpt_oss(self): @@ -287,12 +287,12 @@ def test_train_gpt_oss(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model(self): # Instantiate the model @@ -312,12 +312,12 @@ def test_train_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_dft_loss(self): # Get the dataset @@ -348,12 +348,12 @@ def test_train_dft_loss(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_moe_model_with_aux_loss(self): # Get the dataset @@ -375,13 +375,13 @@ def test_train_moe_model_with_aux_loss(self): trainer.train() # Check that the training loss and aux loss are not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[-1]["aux_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["aux_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_formatting_func(self): # Dummy formatting function @@ -408,12 +408,12 @@ def formatting_prompts_func(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_model_dtype(self): # Get the dataset @@ -437,7 +437,7 @@ def test_train_model_dtype(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -447,8 +447,8 @@ def test_train_model_dtype(self): continue new_param = trainer.model.get_parameter(n) # Check the torch dtype - self.assertEqual(new_param.dtype, torch.float16) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert new_param.dtype == torch.float16 + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config(self): @@ -477,15 +477,15 @@ def test_train_dense_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config(self): @@ -514,15 +514,15 @@ def test_train_moe_with_peft_config(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_peft_model(self): @@ -551,15 +551,15 @@ def test_train_peft_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_dense_with_peft_config_and_gradient_checkpointing(self): @@ -588,15 +588,15 @@ def test_train_dense_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_moe_with_peft_config_and_gradient_checkpointing(self): @@ -625,15 +625,15 @@ def test_train_moe_with_peft_config_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_peft def test_train_with_peft_model_and_gradient_checkpointing(self): @@ -652,7 +652,7 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer = SFTTrainer(model=model, args=training_args, train_dataset=dataset) # Verify model is a PeftModel - self.assertIsInstance(trainer.model, PeftModel) + assert isinstance(trainer.model, PeftModel) # Save the initial parameters to compare them later previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} @@ -661,15 +661,15 @@ def test_train_with_peft_model_and_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n in base_param_names: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "base_layer" not in n: # We expect the peft parameters to be different (except for the base layer) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_liger_kernel def test_train_with_liger(self): @@ -689,12 +689,12 @@ def test_train_with_liger(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_non_chatml_conversational_data(self): # Get the dataset @@ -719,12 +719,12 @@ def rename_fields(example: list[dict]): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_pretokenized_data(self): # Get the dataset @@ -749,12 +749,12 @@ def tokenize_example(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_iterable_dataset(self): # Get the dataset @@ -773,12 +773,12 @@ def test_train_with_iterable_dataset(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @require_flash_attn def test_train_padding_free(self): @@ -804,12 +804,12 @@ def test_train_padding_free(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @parameterized.expand([("bfd",), ("wrapped",)]) @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @@ -833,12 +833,12 @@ def test_train_packing(self, packing_strategy): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) @@ -863,16 +863,16 @@ def test_eval_packing(self): # Check the number of sequences in train and eval datasets num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"]) num_eval_seqs = sum(len(x) for x in trainer.eval_dataset["seq_lengths"]) - self.assertEqual(num_train_seqs, 17) # we should still have 17 seqs - self.assertEqual(num_eval_seqs, 2) # we should still have 2 seqs + assert num_train_seqs == 17 # we should still have 17 seqs + assert num_eval_seqs == 2 # we should still have 2 seqs # Check that all sequences are shorter than the max length - self.assertTrue(all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])) - self.assertTrue(all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"])) + assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]) + assert all(sum(x) <= 64 for x in trainer.eval_dataset["seq_lengths"]) # Check the number of sequences in train and eval datasets - self.assertEqual(len(trainer.train_dataset["input_ids"]), 3) # w/ this dataset, we end up with 46 seqs - self.assertEqual(len(trainer.eval_dataset["input_ids"]), 1) # w/ this dataset, we end up with 6 seqs + assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs + assert len(trainer.eval_dataset["input_ids"]) == 1 # w/ this dataset, we end up with 6 seqs @ignore_warnings(message="You are using packing, but the attention implementation is not.*", category=UserWarning) @ignore_warnings(message="Padding-free training is enabled, but the attention.*", category=UserWarning) @@ -897,17 +897,17 @@ def test_only_train_packing(self): # Check the number of sequences in train dataset num_train_seqs = sum(len(x) for x in trainer.train_dataset["seq_lengths"]) - self.assertEqual(num_train_seqs, 17) # we should still have 17 seqs + assert num_train_seqs == 17 # we should still have 17 seqs # We expect eval dataset not having "seq_lengths" as eval_packing is False - self.assertNotIn("seq_lengths", trainer.eval_dataset) + assert "seq_lengths" not in trainer.eval_dataset # Check that all sequences are shorter than the max length - self.assertTrue(all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"])) + assert all(sum(x) <= 64 for x in trainer.train_dataset["seq_lengths"]) # Check the number of sequences in train and eval datasets - self.assertEqual(len(trainer.train_dataset["input_ids"]), 3) # w/ this dataset, we end up with 46 seqs - self.assertEqual(len(trainer.eval_dataset["input_ids"]), 2) # w/ this dataset, we end up with 6 seqs + assert len(trainer.train_dataset["input_ids"]) == 3 # w/ this dataset, we end up with 46 seqs + assert len(trainer.eval_dataset["input_ids"]) == 2 # w/ this dataset, we end up with 6 seqs def test_train_with_chat_template_kwargs(self): # Get the dataset @@ -934,12 +934,12 @@ def test_train_with_chat_template_kwargs(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only(self): # Get the dataset @@ -958,12 +958,12 @@ def test_train_assistant_only(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_completion_only(self): # Get the dataset @@ -982,12 +982,12 @@ def test_train_completion_only(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_completion_only_harmony(self): # Get the dataset @@ -1006,12 +1006,12 @@ def test_train_completion_only_harmony(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only_and_completion_only(self): # Get the dataset @@ -1040,12 +1040,12 @@ def add_to_completion(example): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_assistant_only_iterable_dataset(self): # Get the dataset @@ -1066,12 +1066,12 @@ def test_train_assistant_only_iterable_dataset(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_model(self): # Get the dataset @@ -1091,12 +1091,12 @@ def test_train_with_set_chat_template_from_model(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_set_chat_template_from_path(self): # Get the dataset @@ -1120,24 +1120,22 @@ def test_train_with_set_chat_template_from_path(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" # Check that the template saved in the output directory is the same as the one used for training template_path = pathlib.Path(self.tmp_dir) / "checkpoint-9" / "chat_template.jinja" - self.assertTrue(template_path.exists(), f"Chat template not found at {template_path}") + assert template_path.exists(), f"Chat template not found at {template_path}" with open(template_path) as f: template_content = f.read() with open(training_args.chat_template_path) as f: original_template_content = f.read() - self.assertEqual( - template_content, original_template_content, "Chat template content does not match the original" - ) + assert template_content == original_template_content, "Chat template content does not match the original" def test_train_toolcall_data(self): # Get the dataset @@ -1156,12 +1154,12 @@ def test_train_toolcall_data(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_train_with_eval(self): # Get the dataset @@ -1180,7 +1178,7 @@ def test_train_with_eval(self): trainer.train() # Check that the eval loss is not None - self.assertIsNotNone(trainer.state.log_history[0]["eval_loss"]) + assert trainer.state.log_history[0]["eval_loss"] is not None def test_train_with_multiple_eval_dataset(self): # Get the dataset @@ -1198,8 +1196,8 @@ def test_train_with_multiple_eval_dataset(self): trainer.train() # Check that the eval losses are not None - self.assertIsNotNone(trainer.state.log_history[-3]["eval_data1_loss"]) - self.assertIsNotNone(trainer.state.log_history[-2]["eval_data2_loss"]) + assert trainer.state.log_history[-3]["eval_data1_loss"] is not None + assert trainer.state.log_history[-2]["eval_data2_loss"] is not None def test_train_with_gradient_checkpointing(self): # Get the dataset @@ -1218,12 +1216,12 @@ def test_train_with_gradient_checkpointing(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" def test_tag_added(self): # Get the dataset @@ -1236,7 +1234,7 @@ def test_tag_added(self): ) for tag in ["sft", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @require_peft def test_tag_added_peft(self): @@ -1251,7 +1249,7 @@ def test_tag_added_peft(self): ) for tag in ["sft", "trl"]: - self.assertIn(tag, trainer.model.model_tags) + assert tag in trainer.model.model_tags @parameterized.expand( [ @@ -1285,7 +1283,7 @@ def test_train_vlm(self, model_id): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -1302,9 +1300,7 @@ def test_train_vlm(self, model_id): ): # fmt: on continue - self.assertFalse( - torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" - ) + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" @require_vision def test_train_vlm_prompt_completion(self): @@ -1330,12 +1326,12 @@ def test_train_vlm_prompt_completion(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" # Gemma 3n uses a timm encoder, making it difficult to create a smaller variant for testing. # To ensure coverage, we run tests on the full model but mark them as slow to exclude from default runs. @@ -1363,7 +1359,7 @@ def test_train_vlm_gemma_3n(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): @@ -1371,7 +1367,7 @@ def test_train_vlm_gemma_3n(self): if "model.vision_tower" in n: # The vision tower is not updated, not sure why at this point. continue - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" @require_vision def test_train_vlm_text_only_data(self): @@ -1393,15 +1389,15 @@ def test_train_vlm_text_only_data(self): trainer.train() # Check that the training loss is not None - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) + assert trainer.state.log_history[-1]["train_loss"] is not None # Check the params have changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if n.startswith("model.visual"): - self.assertTrue(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated") + assert torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is updated" else: - self.assertFalse(torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated") + assert not torch.allclose(param, new_param, rtol=1e-12, atol=1e-12), f"Param {n} is not updated" @require_peft def test_prompt_tuning(self): @@ -1422,16 +1418,16 @@ def test_prompt_tuning(self): trainer.train() # Check that training completed successfully - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[-1]["mean_token_accuracy"]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "base_model" in n: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "prompt_encoder" in n: # We expect the peft parameters to be different - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" else: raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}") @@ -1455,7 +1451,7 @@ def test_peft_model_with_quantization(self): # Verify that this triggers the is_qlora condition is_qlora = getattr(model, "is_loaded_in_4bit", False) or getattr(model, "is_loaded_in_8bit", False) - self.assertTrue(is_qlora, "Model should be detected as QLoRA (quantized)") + assert is_qlora, "Model should be detected as QLoRA (quantized)" # Create LoRA configuration suitable for QLoRA lora_config = LoraConfig( @@ -1470,7 +1466,7 @@ def test_peft_model_with_quantization(self): peft_model = get_peft_model(model, lora_config) # Verify the quantization attributes are preserved on the PeftModel - self.assertTrue(getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag") + assert getattr(peft_model, "is_loaded_in_4bit", False), "PeftModel should preserve quantization flag" # Get the dataset dataset = load_dataset("trl-internal-testing/zen", "standard_language_modeling", split="train") @@ -1489,9 +1485,9 @@ def test_peft_model_with_quantization(self): base_params_before.append(name) # Ensure we have the expected parameter distribution for QLoRA - self.assertTrue(len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially") - self.assertTrue(len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters") - self.assertEqual(len(base_params_before), 0, "Base model parameters should already be frozen in PeftModel") + assert len(trainable_params_before) > 0, "PeftModel should have trainable parameters initially" + assert len(lora_params_before) > 0, "PeftModel should have trainable LoRA parameters" + assert len(base_params_before) == 0, "Base model parameters should already be frozen in PeftModel" # Initialize the trainer with the already configured PeftModel training_args = SFTConfig(output_dir=self.tmp_dir, report_to="none", max_steps=1) @@ -1508,31 +1504,25 @@ def test_peft_model_with_quantization(self): lora_params_after.append(name) # LoRA parameters should remain trainable - self.assertTrue( - len(trainable_params_after) > 0, + assert len(trainable_params_after) > 0, ( f"PeftModel should still have trainable parameters after SFTTrainer initialization. " f"Found {len(trainable_params_after)} trainable params. " - f"This test fails without the fix for issue #3926.", + f"This test fails without the fix for issue #3926." ) - self.assertTrue( - len(lora_params_after) > 0, + assert len(lora_params_after) > 0, ( f"LoRA adapter parameters should remain trainable. " - f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original.", + f"Found {len(lora_params_after)} trainable LoRA params out of {len(lora_params_before)} original." ) # Ensure the parameter counts are preserved (no additional freezing occurred) - self.assertEqual( - len(trainable_params_before), - len(trainable_params_after), - "Number of trainable parameters should not change after SFTTrainer initialization", + assert len(trainable_params_before) == len(trainable_params_after), ( + "Number of trainable parameters should not change after SFTTrainer initialization" ) # Verify that all original LoRA parameters are still trainable - self.assertEqual( - set(lora_params_before), - set(lora_params_after), - "All original LoRA parameters should remain trainable after SFTTrainer initialization", + assert set(lora_params_before) == set(lora_params_after), ( + "All original LoRA parameters should remain trainable after SFTTrainer initialization" ) @require_peft @@ -1552,15 +1542,15 @@ def test_prompt_tuning_peft_model(self): trainer.train() # Check that training completed successfully - self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"]) - self.assertIsNotNone(trainer.state.log_history[-1]["mean_token_accuracy"]) + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["mean_token_accuracy"] is not None # Check the peft params have changed and the base model params have not changed for n, param in previous_trainable_params.items(): new_param = trainer.model.get_parameter(n) if "base_model" in n: # We expect the base model parameters to be the same - self.assertTrue(torch.allclose(param, new_param), f"Parameter {n} has changed") + assert torch.allclose(param, new_param), f"Parameter {n} has changed" elif "prompt_encoder" in n: # We expect the peft parameters to be different - self.assertFalse(torch.allclose(param, new_param), f"Parameter {n} has not changed") + assert not torch.allclose(param, new_param), f"Parameter {n} has not changed" else: raise ValueError(f"Unexpected parameter {n} in model: {trainer.model}") diff --git a/tests/test_trainers_args.py b/tests/test_trainers_args.py index 61ab72130f2..b76110d5f17 100644 --- a/tests/test_trainers_args.py +++ b/tests/test_trainers_args.py @@ -44,7 +44,7 @@ from .testing_utils import TrlTestCase, require_sklearn -class TrainerArgTester(TrlTestCase): +class TestTrainerArg(TrlTestCase): @require_sklearn def test_bco(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -76,22 +76,22 @@ def test_bco(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.padding_value, -99) - self.assertEqual(trainer.args.truncation_mode, "keep_start") + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.padding_value == -99 + assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.is_encoder_decoder, True) - self.assertEqual(trainer.args.precompute_ref_log_probs, True) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.prompt_sample_size, 512) - self.assertEqual(trainer.args.min_density_ratio, 0.2) - self.assertEqual(trainer.args.max_density_ratio, 20.0) + assert trainer.args.is_encoder_decoder + assert trainer.args.precompute_ref_log_probs + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.prompt_sample_size == 512 + assert trainer.args.min_density_ratio == 0.2 + assert trainer.args.max_density_ratio == 20.0 def test_cpo(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -117,22 +117,22 @@ def test_cpo(self): dataset_num_proc=4, ) trainer = CPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.label_smoothing, 0.5) - self.assertEqual(trainer.args.loss_type, "hinge") - self.assertEqual(trainer.args.disable_dropout, False) - self.assertEqual(trainer.args.cpo_alpha, 0.5) - self.assertEqual(trainer.args.simpo_gamma, 0.2) - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.padding_value, -99) - self.assertEqual(trainer.args.truncation_mode, "keep_start") + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.label_smoothing == 0.5 + assert trainer.args.loss_type == "hinge" + assert not trainer.args.disable_dropout + assert trainer.args.cpo_alpha == 0.5 + assert trainer.args.simpo_gamma == 0.2 + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.padding_value == -99 + assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.is_encoder_decoder, True) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.dataset_num_proc, 4) + assert trainer.args.is_encoder_decoder + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.dataset_num_proc == 4 def test_dpo(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -174,32 +174,32 @@ def test_dpo(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.label_smoothing, 0.5) - self.assertEqual(trainer.args.loss_type, "hinge") - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.pad_token, ".") - self.assertEqual(trainer.args.truncation_mode, "keep_start") - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.disable_dropout, False) + assert trainer.args.beta == 0.5 + assert trainer.args.label_smoothing == 0.5 + assert trainer.args.loss_type == "hinge" + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.pad_token == "." + assert trainer.args.truncation_mode == "keep_start" + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert not trainer.args.disable_dropout # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.precompute_ref_log_probs, True) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.model_adapter_name, "dummy_adapter") - self.assertEqual(trainer.args.ref_adapter_name, "dummy_adapter") - self.assertEqual(trainer.args.reference_free, True) - self.assertEqual(trainer.args.force_use_ref_model, True) - self.assertEqual(trainer.args.f_divergence_type, FDivergenceType.JS_DIVERGENCE) - self.assertEqual(trainer.args.f_alpha_divergence_coef, 0.5) + assert trainer.args.precompute_ref_log_probs + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.model_adapter_name == "dummy_adapter" + assert trainer.args.ref_adapter_name == "dummy_adapter" + assert trainer.args.reference_free + assert trainer.args.force_use_ref_model + assert trainer.args.f_divergence_type == FDivergenceType.JS_DIVERGENCE + assert trainer.args.f_alpha_divergence_coef == 0.5 # self.assertEqual(trainer.args.sync_ref_model, True) - self.assertEqual(trainer.args.ref_model_mixup_alpha, 0.5) - self.assertEqual(trainer.args.ref_model_sync_steps, 32) - self.assertEqual(trainer.args.rpo_alpha, 0.5) - self.assertEqual(trainer.args.discopop_tau, 0.1) + assert trainer.args.ref_model_mixup_alpha == 0.5 + assert trainer.args.ref_model_sync_steps == 32 + assert trainer.args.rpo_alpha == 0.5 + assert trainer.args.discopop_tau == 0.1 def test_kto(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -230,21 +230,21 @@ def test_kto(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.desirable_weight, 0.5) - self.assertEqual(trainer.args.undesirable_weight, 0.5) - self.assertEqual(trainer.args.label_pad_token_id, -99) - self.assertEqual(trainer.args.padding_value, -99) - self.assertEqual(trainer.args.truncation_mode, "keep_start") + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert trainer.args.desirable_weight == 0.5 + assert trainer.args.undesirable_weight == 0.5 + assert trainer.args.label_pad_token_id == -99 + assert trainer.args.padding_value == -99 + assert trainer.args.truncation_mode == "keep_start" # self.assertEqual(trainer.args.generate_during_eval, True) - self.assertEqual(trainer.args.is_encoder_decoder, True) - self.assertEqual(trainer.args.precompute_ref_log_probs, True) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.ref_model_init_kwargs, {"trust_remote_code": True}) - self.assertEqual(trainer.args.dataset_num_proc, 4) + assert trainer.args.is_encoder_decoder + assert trainer.args.precompute_ref_log_probs + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.ref_model_init_kwargs == {"trust_remote_code": True} + assert trainer.args.dataset_num_proc == 4 @parameterized.expand([(False,), (True,)]) def test_nash_md(self, mixtures_coef_list): @@ -266,7 +266,7 @@ def test_nash_md(self, mixtures_coef_list): reward_funcs=reward_model, train_dataset=dataset, ) - self.assertEqual(trainer.args.mixture_coef, 0.5 if not mixtures_coef_list else [0.5, 0.6]) + assert trainer.args.mixture_coef == (0.5 if not mixtures_coef_list else [0.5, 0.6]) @parameterized.expand([(False,), (True,)]) def test_online_dpo(self, beta_list): @@ -293,11 +293,11 @@ def test_online_dpo(self, beta_list): processing_class=tokenizer, reward_processing_classes=tokenizer, ) - self.assertEqual(trainer.args.max_new_tokens, 42) - self.assertEqual(trainer.args.temperature, 0.5) - self.assertEqual(trainer.args.missing_eos_penalty, 0.33) - self.assertEqual(trainer.args.beta, 0.6 if not beta_list else [0.6, 0.7]) - self.assertEqual(trainer.args.loss_type, "hinge") + assert trainer.args.max_new_tokens == 42 + assert trainer.args.temperature == 0.5 + assert trainer.args.missing_eos_penalty == 0.33 + assert trainer.args.beta == (0.6 if not beta_list else [0.6, 0.7]) + assert trainer.args.loss_type == "hinge" def test_orpo(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -319,12 +319,12 @@ def test_orpo(self): dataset_num_proc=4, ) trainer = ORPOTrainer(model=model_id, args=training_args, train_dataset=dataset, processing_class=tokenizer) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.max_prompt_length, 64) - self.assertEqual(trainer.args.max_completion_length, 64) - self.assertEqual(trainer.args.beta, 0.5) - self.assertEqual(trainer.args.disable_dropout, False) - self.assertEqual(trainer.args.label_pad_token_id, -99) + assert trainer.args.max_length == 256 + assert trainer.args.max_prompt_length == 64 + assert trainer.args.max_completion_length == 64 + assert trainer.args.beta == 0.5 + assert not trainer.args.disable_dropout + assert trainer.args.label_pad_token_id == -99 def test_reward(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -343,9 +343,9 @@ def test_reward(self): train_dataset=dataset, processing_class=tokenizer, ) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.center_rewards_coefficient, 0.1) + assert trainer.args.max_length == 256 + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.center_rewards_coefficient == 0.1 def test_sft(self): model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" @@ -362,15 +362,15 @@ def test_sft(self): eval_packing=True, ) trainer = SFTTrainer(model_id, args=training_args, train_dataset=dataset) - self.assertEqual(trainer.args.dataset_text_field, "dummy_text_field") - self.assertEqual(trainer.args.packing, True) - self.assertEqual(trainer.args.max_length, 256) - self.assertEqual(trainer.args.dataset_num_proc, 4) - self.assertEqual(trainer.args.neftune_noise_alpha, 0.1) - self.assertEqual(trainer.args.model_init_kwargs, {"trust_remote_code": True}) - self.assertIn("append_concat_token", trainer.args.dataset_kwargs) - self.assertEqual(trainer.args.dataset_kwargs["append_concat_token"], True) - self.assertEqual(trainer.args.eval_packing, True) + assert trainer.args.dataset_text_field == "dummy_text_field" + assert trainer.args.packing + assert trainer.args.max_length == 256 + assert trainer.args.dataset_num_proc == 4 + assert trainer.args.neftune_noise_alpha == 0.1 + assert trainer.args.model_init_kwargs == {"trust_remote_code": True} + assert "append_concat_token" in trainer.args.dataset_kwargs + assert trainer.args.dataset_kwargs["append_concat_token"] + assert trainer.args.eval_packing @parameterized.expand([(False,), (True,)]) def test_xpo(self, alpha_list): @@ -392,4 +392,4 @@ def test_xpo(self, alpha_list): reward_funcs=reward_model, train_dataset=dataset, ) - self.assertEqual(trainer.args.alpha, 0.5 if not alpha_list else [0.5, 0.6]) + assert trainer.args.alpha == (0.5 if not alpha_list else [0.5, 0.6]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 4758109e08b..14b3fb570af 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -17,11 +17,11 @@ from unittest.mock import patch import numpy as np +import pytest import torch from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import ModelConfig @@ -46,7 +46,7 @@ unsplit_pixel_values_by_grid, ) -from .testing_utils import TrlTestCase, require_rich +from .testing_utils import TrlTestCase, require_peft, require_rich if is_peft_available(): @@ -59,14 +59,14 @@ def test_pad_1_dim_left(self): y = torch.tensor([4, 5]) output = pad((x, y), padding_value=0, padding_side="left") expected = torch.tensor([[1, 2, 3], [0, 4, 5]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_1_dim_right(self): x = torch.tensor([1, 2, 3]) y = torch.tensor([4, 5]) output = pad((x, y), padding_value=0, padding_side="right") expected = torch.tensor([[1, 2, 3], [4, 5, 0]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_2_dim_left(self): x = torch.tensor([[1, 2], [3, 4]]) @@ -78,7 +78,7 @@ def test_pad_2_dim_left(self): [[0, 0], [5, 6]], ] ) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_2_dim_right(self): x = torch.tensor([[1, 2], [3, 4]]) @@ -90,7 +90,7 @@ def test_pad_2_dim_right(self): [[5, 6], [0, 0]], ] ) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_2_dim_right_multidim(self): x = torch.tensor([[1, 2], [3, 4]]) @@ -102,7 +102,7 @@ def test_pad_2_dim_right_multidim(self): [[5, 0], [0, 0]], ] ) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_1(self): x = torch.tensor([1, 2, 3]) @@ -110,7 +110,7 @@ def test_pad_to_multiple_of_1(self): # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 0], [4, 5, 0, 0]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_2(self): x = torch.tensor([1, 2, 3, 4, 5]) @@ -118,7 +118,7 @@ def test_pad_to_multiple_of_2(self): # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="right", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 4, 5, 0, 0, 0], [6, 7, 8, 0, 0, 0, 0, 0]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_side_left(self): x = torch.tensor([1, 2, 3, 4, 5]) @@ -126,7 +126,7 @@ def test_pad_to_multiple_of_side_left(self): # Max length is 3, pad to multiple of 4 output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) expected = torch.tensor([[0, 0, 0, 1, 2, 3, 4, 5], [0, 0, 0, 0, 0, 6, 7, 8]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) def test_pad_to_multiple_of_no_extra_padding(self): x = torch.tensor([1, 2, 3, 4]) @@ -134,7 +134,7 @@ def test_pad_to_multiple_of_no_extra_padding(self): # Already multiple of 4 output = pad((x, y), padding_value=0, padding_side="left", pad_to_multiple_of=4) expected = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]]) - self.assertTrue(torch.equal(output, expected)) + assert torch.equal(output, expected) @require_peft @@ -143,7 +143,7 @@ def test_create_peft_config_use_peft_false(self): """Test that when use_peft is False, the function returns None.""" model_args = ModelConfig(use_peft=False) peft_config = get_peft_config(model_args) - self.assertIsNone(peft_config) + assert peft_config is None def test_create_peft_config_use_peft_true(self): """Test that when use_peft is True, the function returns a LoraConfig object.""" @@ -159,7 +159,7 @@ def test_create_peft_config_use_peft_true(self): } model_args = ModelConfig(use_peft=True, **peft_kwargs) peft_config = get_peft_config(model_args) - self.assertTrue(isinstance(peft_config, LoraConfig)) + assert isinstance(peft_config, LoraConfig) for arg, value in peft_kwargs.items(): # Test that lists of modules are converted to sets if arg == "lora_target_modules": @@ -168,23 +168,22 @@ def test_create_peft_config_use_peft_true(self): if arg in ["lora_r", "lora_task_type", "lora_target_modules", "lora_modules_to_save"]: arg = arg[len("lora_") :] if arg.startswith("lora_") else arg - self.assertEqual(getattr(peft_config, arg), value) + assert getattr(peft_config, arg) == value class TestDecodeAndStripPadding(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") def test_example_with_padding(self): inputs = self.tokenizer(["Hello world", "Hello"], padding=True, return_tensors="pt") decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) - self.assertEqual(decoded, ["Hello world", "Hello"]) + assert decoded == ["Hello world", "Hello"] def test_example_without_padding(self): inputs = self.tokenizer(["Hello", "Hello"], padding=False, return_tensors="pt") decoded = decode_and_strip_padding(inputs["input_ids"], self.tokenizer) - self.assertEqual(decoded, ["Hello", "Hello"]) + assert decoded == ["Hello", "Hello"] class TestGenerateModelCard(TrlTestCase): @@ -203,15 +202,15 @@ def test_full(self): paper_id="1234.56789", ) card_text = str(model_card) - self.assertIn("[username/my_base_model](https://huggingface.co/username/my_base_model)", card_text) - self.assertIn("my_model", card_text) - self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) - self.assertIn("datasets: username/my_dataset", card_text) - self.assertIn("](https://wandb.ai/username/project_id/runs/abcd1234)", card_text) - self.assertIn("](https://www.comet.com/username/project_id/experiment_id", card_text) - self.assertIn("My Trainer", card_text) - self.assertIn("```bibtex\n@article{my_trainer, ...}\n```", card_text) - self.assertIn("[My Paper](https://huggingface.co/papers/1234.56789)", card_text) + assert "[username/my_base_model](https://huggingface.co/username/my_base_model)" in card_text + assert "my_model" in card_text + assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text + assert "datasets: username/my_dataset" in card_text + assert "](https://wandb.ai/username/project_id/runs/abcd1234)" in card_text + assert "](https://www.comet.com/username/project_id/experiment_id" in card_text + assert "My Trainer" in card_text + assert "```bibtex\n@article{my_trainer, ...}\n```" in card_text + assert "[My Paper](https://huggingface.co/papers/1234.56789)" in card_text def test_val_none(self): model_card = generate_model_card( @@ -228,14 +227,13 @@ def test_val_none(self): paper_id=None, ) card_text = str(model_card) - self.assertIn("my_model", card_text) - self.assertIn('pipeline("text-generation", model="username/my_hub_model", device="cuda")', card_text) - self.assertIn("My Trainer", card_text) + assert "my_model" in card_text + assert 'pipeline("text-generation", model="username/my_hub_model", device="cuda")' in card_text + assert "My Trainer" in card_text class TestDataCollatorForChatML(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Initialize the tokenizer self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") if self.tokenizer.pad_token is None: @@ -265,11 +263,11 @@ def test_data_collator_for_chatml(self): data = self.collator(self.examples) # Verify basic shapes and types - self.assertIn("input_ids", data) - self.assertIn("attention_mask", data) - self.assertIn("labels", data) - self.assertIn("prompts", data) - self.assertIn("prompt_attention_mask", data) + assert "input_ids" in data + assert "attention_mask" in data + assert "labels" in data + assert "prompts" in data + assert "prompt_attention_mask" in data # Decode input_ids and labels for verification input_ids = data["input_ids"][0].tolist() @@ -278,22 +276,21 @@ def test_data_collator_for_chatml(self): # Get the last assistant's response for comparison last_message = self.examples[0][self.messages_key][-1] - self.assertEqual(last_message["role"], "assistant", "Last message should be from assistant") + assert last_message["role"] == "assistant", "Last message should be from assistant" last_assistant_response = last_message["content"] # Verify that input_ids contain both prompt and response decoded_input = self.tokenizer.decode(input_ids) - self.assertIn(last_assistant_response, decoded_input, "Input should contain assistant's response") + assert last_assistant_response in decoded_input, "Input should contain assistant's response" # Verify that prompts only contain the conversation up to the last response decoded_prompt = self.tokenizer.decode(prompt_only) - self.assertNotIn(last_assistant_response, decoded_prompt, "Prompt should not contain assistant's response") + assert last_assistant_response not in decoded_prompt, "Prompt should not contain assistant's response" # Verify labels are -100 for non-assistant parts prompt_length = len(prompt_only) - self.assertTrue( - all(label == self.ignore_index for label in labels[:prompt_length]), - "Labels should be ignore_index for prompt tokens", + assert all(label == self.ignore_index for label in labels[:prompt_length]), ( + "Labels should be ignore_index for prompt tokens" ) # Verify labels match assistant response after prompt @@ -310,29 +307,19 @@ def test_data_collator_for_chatml(self): response_labels.append(label) if label == self.tokenizer.convert_tokens_to_ids("<|im_end|>"): break - self.assertEqual( - response_labels, - last_assistant_response_tokens, - "Labels should match assistant response tokens", - ) + assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens" # Verify there isn't a generation prompt at the end generation_prompt = "<|im_start|>assistant" - self.assertFalse( - decoded_input.strip().endswith(generation_prompt), - f"Input should not end with generation prompt '{generation_prompt}'", + assert not decoded_input.strip().endswith(generation_prompt), ( + f"Input should not end with generation prompt '{generation_prompt}'" ) - self.assertEqual( - response_labels, - last_assistant_response_tokens, - "Labels should match assistant response tokens", - ) + assert response_labels == last_assistant_response_tokens, "Labels should match assistant response tokens" class TestBatchGeneration(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): # Initialize the tokenizer self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.device = "cuda" if torch.cuda.is_available() else "cpu" @@ -367,9 +354,9 @@ def test_mini_batch_generation(self): max_length_query = query_responses.shape[1] max_length_logits = max_length_query - context_length - self.assertGreater(max_length_query, context_length) - self.assertEqual(query_responses.shape, (bs, max_length_query)) - self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + assert max_length_query > context_length + assert query_responses.shape == (bs, max_length_query) + assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size) def test_single_batch_generation(self): batch = [ @@ -386,9 +373,9 @@ def test_single_batch_generation(self): max_length_query = query_responses.shape[1] max_length_logits = max_length_query - context_length - self.assertGreater(max_length_query, context_length) - self.assertEqual(query_responses.shape, (bs, max_length_query)) - self.assertEqual(logits.shape, (bs, max_length_logits, self.model.config.vocab_size)) + assert max_length_query > context_length + assert query_responses.shape == (bs, max_length_query) + assert logits.shape == (bs, max_length_logits, self.model.config.vocab_size) class TestComputeAccuracy(TrlTestCase): @@ -404,7 +391,7 @@ def test_token_classification_task(self): ) expected_accuracy = 0.5 # 2 matches, 2 mismatches result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 def test_token_classification_task_with_ignored_tokens_0(self): eval_pred = ( @@ -418,7 +405,7 @@ def test_token_classification_task_with_ignored_tokens_0(self): ) expected_accuracy = 1.0 # All non-ignored tokens match result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 def test_token_classification_task_with_ignored_tokens_1(self): eval_pred = ( @@ -432,9 +419,9 @@ def test_token_classification_task_with_ignored_tokens_1(self): ) expected_accuracy = 1 / 3 # 1 match, 2 mismatch, 1 ignored result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 - def test_rewards_comparison_task(self): + def test_rewards_comparison_task(self, caplog): eval_pred = ( np.array( [ @@ -447,15 +434,15 @@ def test_rewards_comparison_task(self): ) expected_accuracy = 0.5 # 1 match, 1 mismatch, 1 equal (ignored) - with self.assertLogs("trl.trainer.utils", level="WARNING") as cm: + with caplog.at_level("WARNING", logger="trl.trainer.utils"): result = compute_accuracy(eval_pred) - self.assertAlmostEqual(result["accuracy"], expected_accuracy) + assert round(abs(result["accuracy"] - expected_accuracy), 7) == 0 expected_warning = ( "There are 1 out of 3 instances where the predictions for both options are equal. " "These instances are ignored in the accuracy computation." ) - self.assertIn(expected_warning, cm.output[0]) + assert expected_warning in caplog.text class TestFlushLeft(TrlTestCase): @@ -469,9 +456,9 @@ def test_basic_case(self): expected_tensor1 = torch.tensor([[2, 3, 4], [5, 6, 0]]) expected_tensor2 = torch.tensor([[7, 8, 9], [10, 11, 0]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) - self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) + assert torch.equal(new_tensor2, expected_tensor2) def test_single_row(self): mask = torch.tensor([[0, 0, 1, 1]]) @@ -481,8 +468,8 @@ def test_single_row(self): expected_mask = torch.tensor([[1, 1]]) expected_tensor1 = torch.tensor([[2, 3]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_shift_needed(self): mask = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]]) @@ -492,14 +479,14 @@ def test_no_shift_needed(self): expected_mask = torch.tensor([[1, 1], [1, 0]]) expected_tensor1 = torch.tensor([[5, 6], [7, 0]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_tensors(self): mask = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]]) new_mask = flush_left(mask) expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) + assert torch.equal(new_mask, expected_mask) class TestFlushRight(TrlTestCase): @@ -513,9 +500,9 @@ def test_basic_case(self): expected_tensor1 = torch.tensor([[2, 3, 4], [0, 5, 6]]) expected_tensor2 = torch.tensor([[7, 8, 9], [0, 10, 11]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) - self.assertTrue(torch.equal(new_tensor2, expected_tensor2)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) + assert torch.equal(new_tensor2, expected_tensor2) def test_single_row(self): mask = torch.tensor([[1, 1, 0, 0]]) @@ -525,8 +512,8 @@ def test_single_row(self): expected_mask = torch.tensor([[1, 1]]) expected_tensor1 = torch.tensor([[2, 3]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_shift_needed(self): mask = torch.tensor([[0, 0, 1, 1], [0, 0, 0, 1]]) @@ -536,17 +523,17 @@ def test_no_shift_needed(self): expected_mask = torch.tensor([[1, 1], [0, 1]]) expected_tensor1 = torch.tensor([[5, 6], [0, 7]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) - self.assertTrue(torch.equal(new_tensor1, expected_tensor1)) + assert torch.equal(new_mask, expected_mask) + assert torch.equal(new_tensor1, expected_tensor1) def test_no_tensors(self): mask = torch.tensor([[1, 1, 1, 0, 0], [0, 0, 1, 1, 0]]) new_mask = flush_right(mask) expected_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) - self.assertTrue(torch.equal(new_mask, expected_mask)) + assert torch.equal(new_mask, expected_mask) -class RepeatRandomSamplerTester(TrlTestCase): +class TestRepeatRandomSampler(TrlTestCase): def test_sampler(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] sampler = RepeatSampler(dataset, mini_repeat_count=2) @@ -564,7 +551,7 @@ def test_sampler_no_shuffle(self): sampler = RepeatSampler(dataset, mini_repeat_count=2, shuffle=False) sampled = list(sampler) expected = [0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6] - self.assertEqual(sampled, expected) + assert sampled == expected def test_sampler_no_repeat(self): dataset = ["a", "b", "c", "d", "e", "f", "g"] @@ -706,7 +693,7 @@ def test_print_output(self, mock_stdout): ╰──────────────────────────────────────────────────────────────────╯ """) - self.assertEqual(output, expected_output) + assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_num_samples(self, mock_stdout): @@ -741,7 +728,7 @@ def test_num_samples(self, mock_stdout): ╰─────────────────────────────────────────────╯ """), ] - self.assertIn(output, possible_outputs) + assert output in possible_outputs @patch("sys.stdout", new_callable=StringIO) def test_print_messages(self, mock_stdout): @@ -790,7 +777,7 @@ def test_print_messages(self, mock_stdout): ╰──────────────────────────────────────────────────────────────────────────────╯ """) - self.assertEqual(output, expected_output) + assert output == expected_output @patch("sys.stdout", new_callable=StringIO) def test_print_messages_with_tools(self, mock_stdout): @@ -829,7 +816,7 @@ def test_print_messages_with_tools(self, mock_stdout): ╰──────────────────────────────────────────────────────────────────────────────╯ """) - self.assertEqual(output, expected_output) + assert output == expected_output class TestSelectiveLogSoftmax(TrlTestCase): @@ -848,12 +835,12 @@ def test_selective_log_softmax(self, dtype): if dtype in [torch.float16, torch.bfloat16]: # half-precision dtypes fall back to an exact method - self.assertTrue(torch.equal(actual_output, expected_output)) + assert torch.equal(actual_output, expected_output) else: torch.testing.assert_close(actual_output, expected_output, rtol=1e-5, atol=1e-5) -class ShuffleSequenceDictTester(TrlTestCase): +class TestShuffleSequenceDict(TrlTestCase): def test_shuffle_preserves_shape(self): x = torch.arange(6).reshape(3, 2) y = torch.arange(3).reshape(3, 1) @@ -861,8 +848,8 @@ def test_shuffle_preserves_shape(self): shuffled = shuffle_sequence_dict(tensor_dict) - self.assertEqual(shuffled["x"].shape, x.shape) - self.assertEqual(shuffled["y"].shape, y.shape) + assert shuffled["x"].shape == x.shape + assert shuffled["y"].shape == y.shape def test_shuffle_consistent_across_tensors(self): # Use known patterns to check alignment @@ -878,13 +865,13 @@ def test_shuffle_consistent_across_tensors(self): y_val = shuffled["y"][i].item() if torch.equal(x_row, torch.tensor([10, 11])): - self.assertEqual(y_val, 1) + assert y_val == 1 elif torch.equal(x_row, torch.tensor([20, 21])): - self.assertEqual(y_val, 2) + assert y_val == 2 elif torch.equal(x_row, torch.tensor([30, 31])): - self.assertEqual(y_val, 3) + assert y_val == 3 else: - self.fail("Unexpected x row in shuffled output.") + pytest.fail("Unexpected x row in shuffled output.") def test_none_tensor_remains_none(self): x = torch.arange(6).reshape(3, 2) @@ -892,8 +879,8 @@ def test_none_tensor_remains_none(self): shuffled = shuffle_sequence_dict(tensor_dict) - self.assertIsNone(shuffled["y"]) - self.assertEqual(shuffled["x"].shape, x.shape) + assert shuffled["y"] is None + assert shuffled["x"].shape == x.shape def test_shuffle_with_list(self): x = torch.tensor([[10, 11], [20, 21], [30, 31]]) @@ -909,16 +896,16 @@ def test_shuffle_with_list(self): y_val = shuffled["y"][i] if torch.equal(x_row, torch.tensor([10, 11])): - self.assertEqual(y_val, "a") + assert y_val == "a" elif torch.equal(x_row, torch.tensor([20, 21])): - self.assertEqual(y_val, "b") + assert y_val == "b" elif torch.equal(x_row, torch.tensor([30, 31])): - self.assertEqual(y_val, "c") + assert y_val == "c" else: - self.fail("Unexpected x row in shuffled output.") + pytest.fail("Unexpected x row in shuffled output.") -class SplitTensorDictTester(TrlTestCase): +class TestSplitTensorDict(TrlTestCase): def test_split_equal_chunks(self): x = torch.arange(12).reshape(6, 2) y = torch.arange(6).reshape(6, 1) @@ -928,10 +915,10 @@ def test_split_equal_chunks(self): expected_x_chunks = torch.chunk(x, 3, dim=0) expected_y_chunks = torch.chunk(y, 3, dim=0) - self.assertEqual(len(result), 3) + assert len(result) == 3 for i in range(3): - self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) - self.assertTrue(torch.equal(result[i]["y"], expected_y_chunks[i])) + assert torch.equal(result[i]["x"], expected_x_chunks[i]) + assert torch.equal(result[i]["y"], expected_y_chunks[i]) def test_with_none_tensor(self): x = torch.arange(12).reshape(6, 2) @@ -940,10 +927,10 @@ def test_with_none_tensor(self): result = split_tensor_dict(tensor_dict, 2) expected_x_chunks = torch.chunk(x, 2, dim=0) - self.assertEqual(len(result), 2) + assert len(result) == 2 for i in range(2): - self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) - self.assertIsNone(result[i]["y"]) + assert torch.equal(result[i]["x"], expected_x_chunks[i]) + assert result[i]["y"] is None def test_with_scalar(self): x = torch.arange(12).reshape(6, 2) @@ -952,13 +939,13 @@ def test_with_scalar(self): result = split_tensor_dict(tensor_dict, 2) expected_x_chunks = torch.chunk(x, 2, dim=0) - self.assertEqual(len(result), 2) + assert len(result) == 2 for i in range(2): - self.assertTrue(torch.equal(result[i]["x"], expected_x_chunks[i])) - self.assertTrue(torch.equal(result[i]["y"], torch.tensor(1))) + assert torch.equal(result[i]["x"], expected_x_chunks[i]) + assert torch.equal(result[i]["y"], torch.tensor(1)) -class SplitPixelValuesByGridTester(TrlTestCase): +class TestSplitPixelValuesByGrid(TrlTestCase): def test_split_correctly_0(self): batch = { "image_grid_thw": torch.tensor([[1, 2, 2], [1, 2, 2]]), @@ -966,14 +953,14 @@ def test_split_correctly_0(self): "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], list) - self.assertEqual(len(result["pixel_values"]), 2) - self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) - self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:])) - self.assertIsInstance(result["image_grid_thw"], list) - self.assertEqual(len(result["image_grid_thw"]), 2) - self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) - self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]]))) + assert isinstance(result["pixel_values"], list) + assert len(result["pixel_values"]) == 2 + assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]) + assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:]) + assert isinstance(result["image_grid_thw"], list) + assert len(result["image_grid_thw"]) == 2 + assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])) + assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2]])) def test_split_correctly_1(self): batch = { @@ -982,19 +969,19 @@ def test_split_correctly_1(self): "pixel_values": torch.arange(12 * 3).reshape(12, 3), # Shape: [12, 3] } result = split_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], list) - self.assertEqual(len(result["pixel_values"]), 2) - self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:4])) - self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12])) - self.assertIsInstance(result["image_grid_thw"], list) - self.assertEqual(len(result["image_grid_thw"]), 2) - self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]]))) - self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]]))) + assert isinstance(result["pixel_values"], list) + assert len(result["pixel_values"]) == 2 + assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:4]) + assert torch.equal(result["pixel_values"][1], batch["pixel_values"][4:12]) + assert isinstance(result["image_grid_thw"], list) + assert len(result["image_grid_thw"]) == 2 + assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 2, 2]])) + assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 4]])) def test_missing_keys(self): batch = {"pixel_values": torch.tensor([1.0])} result = split_pixel_values_by_grid(batch) - self.assertEqual(result, batch) + assert result == batch def test_mismatched_length(self): batch = { @@ -1002,7 +989,7 @@ def test_mismatched_length(self): "num_images": [1, 1], "pixel_values": torch.randn(3, 5), # Only 3 rows } - with self.assertRaises(ValueError): + with pytest.raises(ValueError): split_pixel_values_by_grid(batch) def test_multi_images(self): @@ -1012,17 +999,17 @@ def test_multi_images(self): "pixel_values": torch.arange(8 * 3).reshape(8, 3), # Shape: [8, 3] } result = split_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], list) - self.assertEqual(len(result["pixel_values"]), 2) - self.assertTrue(torch.equal(result["pixel_values"][0], batch["pixel_values"][:2])) - self.assertTrue(torch.equal(result["pixel_values"][1], batch["pixel_values"][2:])) - self.assertIsInstance(result["image_grid_thw"], list) - self.assertEqual(len(result["image_grid_thw"]), 2) - self.assertTrue(torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]]))) - self.assertTrue(torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]]))) + assert isinstance(result["pixel_values"], list) + assert len(result["pixel_values"]) == 2 + assert torch.equal(result["pixel_values"][0], batch["pixel_values"][:2]) + assert torch.equal(result["pixel_values"][1], batch["pixel_values"][2:]) + assert isinstance(result["image_grid_thw"], list) + assert len(result["image_grid_thw"]) == 2 + assert torch.equal(result["image_grid_thw"][0], torch.tensor([[1, 1, 2]])) + assert torch.equal(result["image_grid_thw"][1], torch.tensor([[1, 2, 2], [1, 2, 1]])) -class TruncateWithProtectedTokensTester(TrlTestCase): +class TestTruncateWithProtectedTokens(TrlTestCase): def test_basic_example(self): """Test the basic example from the problem description.""" prompt_ids = [1, 2, 3, 4, 5] @@ -1032,7 +1019,7 @@ def test_basic_example(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) expected_ids = [2, 3, 5] - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids def test_no_truncation_needed(self): """Test when target length equals current length.""" @@ -1042,7 +1029,7 @@ def test_no_truncation_needed(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertEqual(new_ids, prompt_ids) + assert new_ids == prompt_ids def test_no_protected_tokens(self): """Test truncation with no protected tokens (normal right truncation).""" @@ -1053,7 +1040,7 @@ def test_no_protected_tokens(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) expected_ids = [3, 4, 5] # Last 3 tokens - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids def test_all_tokens_protected(self): """Test when all remaining tokens are protected.""" @@ -1064,7 +1051,7 @@ def test_all_tokens_protected(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) expected_ids = [3, 4, 5] - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids def test_too_many_protected_tokens(self): """Test error when too many protected tokens for target length.""" @@ -1072,7 +1059,7 @@ def test_too_many_protected_tokens(self): protected_tokens = [1, 2, 3, 4] target_length = 3 - with self.assertRaises(ValueError): + with pytest.raises(ValueError): truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) def test_single_batch_single_token(self): @@ -1083,7 +1070,7 @@ def test_single_batch_single_token(self): new_ids = truncate_with_protected_tokens(prompt_ids, target_length, protected_tokens) - self.assertEqual(new_ids, prompt_ids) + assert new_ids == prompt_ids def test_order_preservation(self): """Test that relative order is preserved.""" @@ -1097,10 +1084,10 @@ def test_order_preservation(self): # Order should be: 2, 3, 30, 40 (maintaining original relative positions) expected_ids = [2, 3, 30, 40] - self.assertEqual(new_ids, expected_ids) + assert new_ids == expected_ids -class UnsplitPixelValuesByGridTester(TrlTestCase): +class TestUnsplitPixelValuesByGrid(TrlTestCase): def test_unsplit_correctly(self): pixel_values = [torch.randn(4, 5), torch.randn(2, 5)] pixel_values_merged = torch.cat(pixel_values, dim=0) @@ -1108,14 +1095,14 @@ def test_unsplit_correctly(self): image_grid_thw_merged = torch.cat(image_grid_thw, dim=0) batch = {"pixel_values": pixel_values, "image_grid_thw": image_grid_thw, "other_key": torch.tensor([1])} result = unsplit_pixel_values_by_grid(batch) - self.assertIsInstance(result["pixel_values"], torch.Tensor) - self.assertTrue(torch.allclose(result["pixel_values"], pixel_values_merged)) - self.assertIsInstance(result["image_grid_thw"], torch.Tensor) - self.assertTrue(torch.equal(result["image_grid_thw"], image_grid_thw_merged)) - self.assertIn("other_key", result) + assert isinstance(result["pixel_values"], torch.Tensor) + assert torch.allclose(result["pixel_values"], pixel_values_merged) + assert isinstance(result["image_grid_thw"], torch.Tensor) + assert torch.equal(result["image_grid_thw"], image_grid_thw_merged) + assert "other_key" in result def test_no_op_if_not_list(self): original = torch.randn(5, 3) batch = {"pixel_values": original} result = unsplit_pixel_values_by_grid(batch) - self.assertTrue(torch.equal(result["pixel_values"], original)) + assert torch.equal(result["pixel_values"], original) diff --git a/tests/test_vllm_client_server.py b/tests/test_vllm_client_server.py index 6171845c1de..ad93022aff4 100644 --- a/tests/test_vllm_client_server.py +++ b/tests/test_vllm_client_server.py @@ -27,28 +27,28 @@ class TestChunkList(TrlTestCase): def test_even_split(self): - self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 2), [[1, 2, 3], [4, 5, 6]]) + assert chunk_list([1, 2, 3, 4, 5, 6], 2) == [[1, 2, 3], [4, 5, 6]] def test_uneven_split(self): - self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 4), [[1, 2], [3, 4], [5], [6]]) + assert chunk_list([1, 2, 3, 4, 5, 6], 4) == [[1, 2], [3, 4], [5], [6]] def test_more_chunks_than_elements(self): - self.assertEqual(chunk_list([1, 2, 3, 4, 5, 6], 8), [[1], [2], [3], [4], [5], [6], [], []]) + assert chunk_list([1, 2, 3, 4, 5, 6], 8) == [[1], [2], [3], [4], [5], [6], [], []] def test_n_equals_len(self): - self.assertEqual(chunk_list([1, 2, 3], 3), [[1], [2], [3]]) + assert chunk_list([1, 2, 3], 3) == [[1], [2], [3]] def test_n_is_1(self): - self.assertEqual(chunk_list([1, 2, 3], 1), [[1, 2, 3]]) + assert chunk_list([1, 2, 3], 1) == [[1, 2, 3]] def test_single_element_list(self): - self.assertEqual(chunk_list([42], 2), [[42], []]) + assert chunk_list([42], 2) == [[42], []] def test_any_dtype(self): - self.assertEqual( - chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2), - [[1, "two", 3.0], [{"four": 4}, ["f", "i", "v", "e"]]], - ) + assert chunk_list([1, "two", 3.0, {"four": 4}, ["f", "i", "v", "e"]], 2) == [ + [1, "two", 3.0], + [{"four": 4}, ["f", "i", "v", "e"]], + ] @pytest.mark.slow @@ -57,7 +57,7 @@ class TestVLLMClientServer(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -79,18 +79,18 @@ def test_generate(self): completion_ids = outputs["completion_ids"] # Check that the outputs are lists - self.assertIsInstance(prompt_ids, list) - self.assertIsInstance(completion_ids, list) + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) # Check that the number of sequences are equal to the number of prompts - self.assertEqual(len(prompt_ids), len(prompts)) - self.assertEqual(len(completion_ids), len(prompts)) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) # Check that the sequences are lists of integers for seq in prompt_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) for seq in completion_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] @@ -99,18 +99,18 @@ def test_generate_with_params(self): ] # Check that the output is a list - self.assertIsInstance(completion_ids, list) + assert isinstance(completion_ids, list) # Check that the number of generated sequences is 2 times the number of prompts - self.assertEqual(len(completion_ids), 2 * len(prompts)) + assert len(completion_ids) == 2 * len(prompts) # Check that the generated sequences are lists of integers for seq in completion_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) # Check that the length of the generated sequences is less than or equal to 32 for seq in completion_ids: - self.assertLessEqual(len(seq), 32) + assert len(seq) <= 32 def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -121,9 +121,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -139,7 +137,7 @@ class TestVLLMClientServerBaseURL(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -161,18 +159,18 @@ def test_generate(self): completion_ids = outputs["completion_ids"] # Check that the outputs are lists - self.assertIsInstance(prompt_ids, list) - self.assertIsInstance(completion_ids, list) + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) # Check that the number of sequences are equal to the number of prompts - self.assertEqual(len(prompt_ids), len(prompts)) - self.assertEqual(len(completion_ids), len(prompts)) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) # Check that the sequences are lists of integers for seq in prompt_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) for seq in completion_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_generate_with_params(self): prompts = ["Hello, AI!", "Tell me a joke"] @@ -181,18 +179,18 @@ def test_generate_with_params(self): ] # Check that the output is a list - self.assertIsInstance(completion_ids, list) + assert isinstance(completion_ids, list) # Check that the number of generated sequences is 2 times the number of prompts - self.assertEqual(len(completion_ids), 2 * len(prompts)) + assert len(completion_ids) == 2 * len(prompts) # Check that the generated sequences are lists of integers for seq in completion_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) # Check that the length of the generated sequences is less than or equal to 32 for seq in completion_ids: - self.assertLessEqual(len(seq), 32) + assert len(seq) <= 32 def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -203,9 +201,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -220,7 +216,7 @@ class TestVLLMClientServerTP(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -245,18 +241,18 @@ def test_generate(self): completion_ids = outputs["completion_ids"] # Check that the outputs are lists - self.assertIsInstance(prompt_ids, list) - self.assertIsInstance(completion_ids, list) + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) # Check that the number of sequences are equal to the number of prompts - self.assertEqual(len(prompt_ids), len(prompts)) - self.assertEqual(len(completion_ids), len(prompts)) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) # Check that the sequences are lists of integers for seq in prompt_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) for seq in completion_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -267,9 +263,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -284,7 +278,7 @@ class TestVLLMClientServerDP(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1 and 2, so we set VISIBLE_DEVICES to "1,2" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -309,18 +303,18 @@ def test_generate(self): completion_ids = outputs["completion_ids"] # Check that the outputs are lists - self.assertIsInstance(prompt_ids, list) - self.assertIsInstance(completion_ids, list) + assert isinstance(prompt_ids, list) + assert isinstance(completion_ids, list) # Check that the number of sequences are equal to the number of prompts - self.assertEqual(len(prompt_ids), len(prompts)) - self.assertEqual(len(completion_ids), len(prompts)) + assert len(prompt_ids) == len(prompts) + assert len(completion_ids) == len(prompts) # Check that the sequences are lists of integers for seq in prompt_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) for seq in completion_ids: - self.assertTrue(all(isinstance(tok, int) for tok in seq)) + assert all(isinstance(tok, int) for tok in seq) def test_update_model_params(self): model = AutoModelForCausalLM.from_pretrained(self.model_id, device_map=torch_device) @@ -331,9 +325,7 @@ def test_reset_prefix_cache(self): self.client.reset_prefix_cache() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # Close the client cls.client.close_communicator() @@ -350,7 +342,7 @@ class TestVLLMClientServerDeviceParameter(TrlTestCase): model_id = "Qwen/Qwen2.5-1.5B" @classmethod - def setUpClass(cls): + def setup_class(cls): # We want the server to run on accelerator 1, so we set VISIBLE_DEVICES to "1" env = os.environ.copy() VISIBLE_DEVICES = "ZE_AFFINITY_MASK" if torch_device == "xpu" else "CUDA_VISIBLE_DEVICES" @@ -371,10 +363,10 @@ def test_init_communicator_with_device_int(self): outputs = client.generate(prompts) prompt_ids = outputs["prompt_ids"] completion_ids = outputs["completion_ids"] - self.assertIsInstance(prompt_ids, list) - self.assertEqual(len(prompt_ids), len(prompts)) - self.assertIsInstance(completion_ids, list) - self.assertEqual(len(completion_ids), len(prompts)) + assert isinstance(prompt_ids, list) + assert len(prompt_ids) == len(prompts) + assert isinstance(completion_ids, list) + assert len(completion_ids) == len(prompts) client.close_communicator() @@ -386,8 +378,8 @@ def test_init_communicator_with_device_string(self): # Test basic functionality prompts = ["Hello, AI!"] outputs = client.generate(prompts)["completion_ids"] - self.assertIsInstance(outputs, list) - self.assertEqual(len(outputs), len(prompts)) + assert isinstance(outputs, list) + assert len(outputs) == len(prompts) client.close_communicator() @@ -402,15 +394,13 @@ def test_init_communicator_with_torch_device(self): # Test basic functionality prompts = ["Hello, AI!"] outputs = client.generate(prompts)["completion_ids"] - self.assertIsInstance(outputs, list) - self.assertEqual(len(outputs), len(prompts)) + assert isinstance(outputs, list) + assert len(outputs) == len(prompts) client.close_communicator() @classmethod - def tearDownClass(cls): - super().tearDownClass() - + def teardown_class(cls): # vLLM x pytest (or Popen) seems not to handle process termination well. To avoid zombie processes, we need to # kill the server process and its children explicitly. kill_process(cls.server_process) diff --git a/tests/test_xpo_trainer.py b/tests/test_xpo_trainer.py index 9af803830cf..4d41471187c 100644 --- a/tests/test_xpo_trainer.py +++ b/tests/test_xpo_trainer.py @@ -16,12 +16,11 @@ from datasets import load_dataset from parameterized import parameterized from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer -from transformers.testing_utils import require_peft from transformers.utils import is_peft_available from trl import XPOConfig, XPOTrainer -from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender +from .testing_utils import RandomPairwiseJudge, TrlTestCase, require_llm_blender, require_peft if is_peft_available(): @@ -29,8 +28,7 @@ class TestXPOTrainer(TrlTestCase): - def setUp(self): - super().setUp() + def setup_method(self): self.model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" self.model = AutoModelForCausalLM.from_pretrained(self.model_id) self.ref_model = AutoModelForCausalLM.from_pretrained(self.model_id) @@ -65,7 +63,7 @@ def test_xpo_trainer_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft(self): @@ -93,7 +91,7 @@ def test_training_with_peft(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_and_ref_model(self): @@ -122,7 +120,7 @@ def test_training_with_peft_and_ref_model(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_with_peft_model_and_peft_config(self): @@ -153,7 +151,7 @@ def test_training_with_peft_model_and_peft_config(self): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @require_peft def test_training_pre_pefted_model_implicit_ref(self): @@ -182,7 +180,7 @@ def test_training_pre_pefted_model_implicit_ref(self): trainer.train() - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] @parameterized.expand([("standard_prompt_only",), ("conversational_prompt_only",)]) @require_llm_blender @@ -213,4 +211,4 @@ def test_xpo_trainer_judge_training(self, config_name): trainer.train() # Check if training loss is available - self.assertIn("train_loss", trainer.state.log_history[-1]) + assert "train_loss" in trainer.state.log_history[-1] diff --git a/tests/testing_utils.py b/tests/testing_utils.py index 85026a53947..cbe677255b5 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -14,88 +14,36 @@ import functools import random -import shutil import signal -import tempfile -import unittest import warnings import psutil +import pytest import torch from transformers import is_bitsandbytes_available, is_comet_available, is_sklearn_available, is_wandb_available from transformers.testing_utils import torch_device -from transformers.utils import is_rich_available +from transformers.utils import is_peft_available, is_rich_available, is_vision_available from trl import BaseBinaryJudge, BasePairwiseJudge from trl.import_utils import is_joblib_available, is_llm_blender_available, is_mergekit_available, is_vllm_available -# transformers.testing_utils contains a require_bitsandbytes function, but relies on pytest markers which we don't use -# in our test suite. We therefore need to implement our own version of this function. -def require_bitsandbytes(test_case): - """ - Decorator marking a test that requires bitsandbytes. Skips the test if bitsandbytes is not available. - """ - return unittest.skipUnless(is_bitsandbytes_available(), "test requires bitsandbytes")(test_case) - - -def require_comet(test_case): - """ - Decorator marking a test that requires Comet. Skips the test if Comet is not available. - """ - return unittest.skipUnless(is_comet_available(), "test requires comet_ml")(test_case) - - -def require_llm_blender(test_case): - """ - Decorator marking a test that requires llm-blender. Skips the test if llm-blender is not available. - """ - return unittest.skipUnless(is_llm_blender_available(), "test requires llm-blender")(test_case) - - -def require_mergekit(test_case): - """ - Decorator marking a test that requires mergekit. Skips the test if mergekit is not available. - """ - return unittest.skipUnless(is_mergekit_available(), "test requires mergekit")(test_case) - - -def require_rich(test_case): - """ - Decorator marking a test that requires rich. Skips the test if rich is not available. - """ - return unittest.skipUnless(is_rich_available(), "test requires rich")(test_case) - - -def require_sklearn(test_case): - """ - Decorator marking a test that requires sklearn. Skips the test if sklearn is not available. - """ - return unittest.skipUnless(is_sklearn_available() and is_joblib_available(), "test requires sklearn")(test_case) - - -def require_vllm(test_case): - """ - Decorator marking a test that requires vllm. Skips the test if vllm is not available. - """ - return unittest.skipUnless(is_vllm_available(), "test requires vllm")(test_case) - - -def require_no_wandb(test_case): - """ - Decorator marking a test that requires no wandb. Skips the test if wandb is available. - """ - return unittest.skipUnless(not is_wandb_available(), "test requires no wandb")(test_case) - - -def require_3_accelerators(test_case): - """ - Decorator marking a test that requires at least 3 accelerators. Skips the test if 3 accelerators are not available. - """ - torch_accelerator_module = getattr(torch, torch_device, torch.cuda) - return unittest.skipUnless( - torch_accelerator_module.device_count() >= 3, f"test requires at least 3 {torch_device}s" - )(test_case) +require_bitsandbytes = pytest.mark.skipif(not is_bitsandbytes_available(), reason="test requires bitsandbytes") +require_comet = pytest.mark.skipif(not is_comet_available(), reason="test requires comet_ml") +require_llm_blender = pytest.mark.skipif(not is_llm_blender_available(), reason="test requires llm-blender") +require_mergekit = pytest.mark.skipif(not is_mergekit_available(), reason="test requires mergekit") +require_peft = pytest.mark.skipif(not is_peft_available(), reason="test requires peft") +require_rich = pytest.mark.skipif(not is_rich_available(), reason="test requires rich") +require_sklearn = pytest.mark.skipif( + not (is_sklearn_available() and is_joblib_available()), reason="test requires sklearn" +) +require_vision = pytest.mark.skipif(not is_vision_available(), reason="test requires vision") +require_vllm = pytest.mark.skipif(not is_vllm_available(), reason="test requires vllm") +require_no_wandb = pytest.mark.skipif(is_wandb_available(), reason="test requires no wandb") +require_3_accelerators = pytest.mark.skipif( + not (getattr(torch, torch_device, torch.cuda).device_count() >= 3), + reason=f"test requires at least 3 {torch_device}s", +) class RandomBinaryJudge(BaseBinaryJudge): @@ -119,18 +67,10 @@ def judge(self, prompts, completions, shuffle_order=True, return_scores=False): return [random.random() for _ in range(len(prompts))] -class TrlTestCase(unittest.TestCase): - """ - Base test case for TRL tests. Sets up a temporary directory for testing. - """ - - def setUp(self): - super().setUp() - self.tmp_dir = tempfile.mkdtemp() - - def tearDown(self): - shutil.rmtree(self.tmp_dir) - super().tearDown() +class TrlTestCase: + @pytest.fixture(autouse=True) + def set_tmp_dir(self, tmp_path): + self.tmp_dir = str(tmp_path) def ignore_warnings(message: str = None, category: type[Warning] = Warning) -> callable: From 4fdaa4c67290f2a46b995baaa2f33d640492f8b5 Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Mon, 6 Oct 2025 15:57:17 +0200 Subject: [PATCH 079/153] Updated vLLM integration guide (#4162) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com> --- docs/source/vllm_integration.md | 392 ++++++++++++++++++++++++++++---- 1 file changed, 348 insertions(+), 44 deletions(-) diff --git a/docs/source/vllm_integration.md b/docs/source/vllm_integration.md index 9d3f6beee11..9cfea500a0a 100644 --- a/docs/source/vllm_integration.md +++ b/docs/source/vllm_integration.md @@ -1,16 +1,27 @@ # vLLM Integration -This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. Let's go! 🔥 +This document will guide you through the process of using vLLM with TRL for faster generation in online methods like GRPO and Online DPO. We first summarize a tl;dr on how to use vLLM with TRL, and then we will go into the details of how it works under the hood. > [!WARNING] > TRL currently only supports vLLM version `0.10.2`. Please ensure you have this version installed to avoid compatibility issues. +> [!TIP] +> The following trainers currently support generation with vLLM: +> +> - [`GRPOTrainer`] +> - [`OnlineDPOTrainer`] +> - [`NashMDTrainer`] +> - [`XPOTrainer`] +> - [`RLOOTrainer`] + + ## 🚀 How can I use vLLM with TRL to speed up training? 💡 **Note**: Resources required for this specific example: a single node with 8 GPUs. > [!WARNING] -> vLLM server and TRL trainer must use different CUDA devices to avoid conflicts. +> When using vLLM with TRL, the **vLLM server** and the **trainer** must run on **separate CUDA devices** to prevent conflicts. +> For guidance on configuring this properly, see [Modes of using vLLM during training](#modes-of-using-vllm-during-training). First, install vLLM using the following command: @@ -24,12 +35,15 @@ Then run the server on specific GPUs (e.g., GPUs 0-3): CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model Qwen/Qwen2.5-7B --tensor-parallel-size 2 --data-parallel-size 2 ``` -Once the server is running, you can use it to generate completions for training. In the example below, we are using the `GRPOTrainer` to train a model using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs. +Once the server is running, you can use it to generate completions for training. In the example below, we are using the different supported trainers using the vLLM server for generation. The `--tensor-parallel-size` and `--data-parallel-size` arguments control how the model and data are sharded across GPUs. In this example, we are sharding two copies of the model across 4 GPUs. Increasing data parallelism increases throughput, while increasing tensor parallelism allows for serving larger models. Then, run the training script on different GPUs (e.g., GPUs 4-7) by passing `use_vllm=True` in the training arguments as follows: Sample of a simple `train.py` script: + + + ```python from datasets import load_dataset from trl import GRPOTrainer, GRPOConfig @@ -57,21 +71,148 @@ trainer = GRPOTrainer( trainer.train() ``` + + + +```python +from datasets import load_dataset +from trl import OnlineDPOTrainer, OnlineDPOConfig + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Dummy reward function: count the number of unique characters in the completions +def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + +training_args = OnlineDPOConfig( + output_dir="my_test", + use_vllm=True, + bf16=True, + gradient_checkpointing=True, +) + +trainer = OnlineDPOTrainer( + model="Qwen/Qwen2.5-7B", + args=training_args, + reward_funcs=reward_num_unique_chars, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl import NashMDTrainer, NashMDConfig + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Dummy reward function: count the number of unique characters in the completions +def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + +training_args = NashMDConfig( + output_dir="my_test", + use_vllm=True, + bf16=True, + gradient_checkpointing=True, +) + +trainer = NashMDTrainer( + model="Qwen/Qwen2.5-7B", + args=training_args, + reward_funcs=reward_num_unique_chars, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl import XPOTrainer, XPOConfig + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Dummy reward function: count the number of unique characters in the completions +def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + +training_args = XPOConfig( + output_dir="my_test", + use_vllm=True, + bf16=True, + gradient_checkpointing=True, +) + +trainer = XPOTrainer( + model="Qwen/Qwen2.5-7B", + args=training_args, + reward_funcs=reward_num_unique_chars, + train_dataset=dataset, +) + +trainer.train() +``` + + + + +```python +from datasets import load_dataset +from trl import RLOOTrainer, RLOOConfig + +dataset = load_dataset("trl-lib/tldr", split="train") + +# Dummy reward function: count the number of unique characters in the completions +def reward_num_unique_chars(completions, **kwargs): + return [len(set(c)) for c in completions] + +training_args = RLOOConfig( + output_dir="my_test", + use_vllm=True, + bf16=True, + gradient_checkpointing=True, +) + +trainer = RLOOTrainer( + model="Qwen/Qwen2.5-7B", + args=training_args, + reward_funcs=reward_num_unique_chars, + train_dataset=dataset, +) + +trainer.train() +``` + + + + And the train command on separate GPUs from the server: ```sh CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py ``` -## 🎬 Flashback: Why do we need to use vLLM in online methods? +## Why using vLLM? + +### 🎬 Flashback: Why do we need to use vLLM in online methods? Online methods like GRPO or Online DPO require the model to generate completions during training, which are then used to compute reward signals. However, generation can be extremely time-consuming, especially with large or reasoning models. In the default setup (without vLLM), completions are generated using the [(unwrapped) model's `generate` method](https://github.com/huggingface/trl/blob/f3e8c2304428ef16e9ae5de9e5741ed84d533b7b/trl/trainer/grpo_trainer.py#L965C39-L965C66). This approach quickly becomes a major bottleneck — generation is slow and inefficient, particularly for large batches or models. As a result, training times increase significantly, and overall efficiency drops. To address this, we turn to vLLM, which enables much faster and more scalable generation, helping eliminate this bottleneck in online methods. -## 🤔 How does vLLM solve the slow generation issue? +### 🤔 How does vLLM solve the slow generation issue? If you've ever done autoregressive decoder training, you know all the input tokens to the LLM produce their attention key and value tensors, and these tensors are kept in GPU memory to later generate subsequent tokens based on them. These cached key and value tensors are often referred to as the KV cache. However, storing the KV cache occupies a lot of memory, so vLLM uses a technique called **PagedAttention** to solve this problem. PagedAttention, which is inspired by the OS’s virtual memory concept, stores continuous keys and values in **non-contiguous memory space**, which is much more efficient. The details of this are beyond the scope of this document, but in short, it allows the model to store the keys and values in a more efficient way, reducing the memory footprint and speeding up the generation process. If you are interested, make sure to check out the [vLLM PagedAttention](https://blog.vllm.ai/2023/06/20/vllm.html) for more details. -## 🤔 What exactly happens when you run `trl vllm-serve --model `? +## How vLLM Works (Under the Hood) 🔍 + +### 🤔 What exactly happens when you run `trl vllm-serve --model `? When you run for example @@ -92,7 +233,7 @@ Each worker operates independently and processes a chunk of the incoming request This GPU-to-GPU communication is managed efficiently by NVIDIA’s NCCL library. The communication mainly ensures that each GPU gets its correct portion of the incoming requests — it’s lightweight and doesn’t interfere with generation itself. Separately, the number of completions to generate per prompt is controlled by the `num_generations` setting in the GRPO config. For instance, if you set `num_generations=2` (like in the picture above), each prompt will have 2 completions. So, with 8 prompts and `num_generations=2`, you would end up with 16 completions total — regardless of the number of GPUs or parallelism settings. -## 🥸 More detail on what happens under the hood when running the server +### 🥸 More detail on what happens under the hood when running the server * The vLLM server starts by running the command: `trl vllm-serve --model Qwen/Qwen2.5-7B`. * Once the server is running, it generates completions based on requests from the client (trainer) using `vllm_client.generate` [here](https://github.com/huggingface/trl/blob/cc044e35b285be7dc062764b3364e1e684db4c7c/trl/trainer/grpo_trainer.py#L1025-L1035). @@ -114,19 +255,21 @@ For example, if you want to use GPUs 4–7 for training while the server runs on CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py ``` -## 🍷 More customization options with vLLM? +## Advanced usage + +### 🍷 More customization options with vLLM? You can customize the server configuration by passing additional arguments. -``` +```txt $ trl vllm-serve --help -usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] - [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] [--port PORT] - [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN] - [--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager ENFORCE_EAGER] [--log_level LOG_LEVEL] +usage: trl vllm-serve [-h] --model MODEL [--revision REVISION] [--tensor_parallel_size TENSOR_PARALLEL_SIZE] [--data_parallel_size DATA_PARALLEL_SIZE] [--host HOST] + [--port PORT] [--gpu_memory_utilization GPU_MEMORY_UTILIZATION] [--dtype DTYPE] [--max_model_len MAX_MODEL_LEN] + [--enable_prefix_caching ENABLE_PREFIX_CACHING] [--enforce_eager [ENFORCE_EAGER]] [--kv_cache_dtype KV_CACHE_DTYPE] + [--trust_remote_code [TRUST_REMOTE_CODE]] [--log_level LOG_LEVEL] [--vllm_model_impl VLLM_MODEL_IMPL] options: - -h, --help Show this help message and exit + -h, --help show this help message and exit --model MODEL Model name or path to load the model from. (default: None) --revision REVISION Revision to use for the model. If not specified, the default branch will be used. (default: None) --tensor_parallel_size TENSOR_PARALLEL_SIZE, --tensor-parallel-size TENSOR_PARALLEL_SIZE @@ -136,39 +279,33 @@ options: --host HOST Host address to run the server on. (default: 0.0.0.0) --port PORT Port to run the server on. (default: 8000) --gpu_memory_utilization GPU_MEMORY_UTILIZATION, --gpu-memory-utilization GPU_MEMORY_UTILIZATION - Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device - dedicated to generation powered by vLLM. Higher values will increase the KV cache size and thus improve the - model's throughput. However, if the value is too high, it may cause out-of-memory (OOM) errors during - initialization. (default: 0.9) - --dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on - the model configuration. Find the supported values in the vLLM documentation. (default: auto) + Ratio (between 0 and 1) of GPU memory to reserve for the model weights, activations, and KV cache on the device dedicated to generation + powered by vLLM. Higher values will increase the KV cache size and thus improve the model's throughput. However, if the value is too high, + it may cause out-of-memory (OOM) errors during initialization. (default: 0.9) + --dtype DTYPE Data type to use for vLLM generation. If set to 'auto', the data type will be automatically determined based on the model configuration. + Find the supported values in the vLLM documentation. (default: auto) --max_model_len MAX_MODEL_LEN, --max-model-len MAX_MODEL_LEN - If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced - `vllm_gpu_memory_utilization`, leading to a reduced KV cache size. If not set, vLLM will use the model context - size, which might be much larger than the KV cache, leading to inefficiencies. (default: None) + If set, the `max_model_len` to use for vLLM. This can be useful when running with reduced `vllm_gpu_memory_utilization`, leading to a + reduced KV cache size. If not set, vLLM will use the model context size, which might be much larger than the KV cache, leading to + inefficiencies. (default: None) --enable_prefix_caching ENABLE_PREFIX_CACHING, --enable-prefix-caching ENABLE_PREFIX_CACHING - Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this - feature. (default: None) - --enforce_eager ENFORCE_EAGER, --enforce-eager ENFORCE_EAGER - Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model - in eager mode. If `False` (default behavior), we will use CUDA graph and eager execution in hybrid. (default: - None) + Whether to enable prefix caching in vLLM. If set to `True`, ensure that the model and the hardware support this feature. (default: None) + --enforce_eager [ENFORCE_EAGER], --enforce-eager [ENFORCE_EAGER] + Whether to enforce eager execution. If set to `True`, we will disable CUDA graph and always execute the model in eager mode. If `False` + (default behavior), we will use CUDA graph and eager execution in hybrid. (default: False) + --kv_cache_dtype KV_CACHE_DTYPE, --kv-cache-dtype KV_CACHE_DTYPE + Data type to use for KV cache. If set to 'auto', the dtype will default to the model data type. (default: auto) + --trust_remote_code [TRUST_REMOTE_CODE], --trust-remote-code [TRUST_REMOTE_CODE] + Whether to trust remote code when loading models. Set to True to allow executing code from model repositories. This is required for some + custom models but introduces security risks. (default: False) --log_level LOG_LEVEL, --log-level LOG_LEVEL - Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default: - info) + Log level for uvicorn. Possible choices: 'critical', 'error', 'warning', 'info', 'debug', 'trace'. (default: info) + --vllm_model_impl VLLM_MODEL_IMPL, --vllm-model-impl VLLM_MODEL_IMPL + Model implementation to use for vLLM. Must be one of `transformers` or `vllm`. `transformers`: Use the `transformers` backend for model + implementation. `vllm`: Use the `vllm` library for model implementation. (default: vllm) ``` -## 🥳 Okay, now that we have the server running, how can we use it to generate completions? - -Run the training script and pass `use_vllm=True` in the training arguments: - -```python -from trl import GRPOConfig - -training_args = GRPOConfig(..., use_vllm=True) -``` - -## 💆🏻‍♀️ What's the best distributed setup? +### 💆🏻‍♀️ What's the best distributed setup? ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_8_gpus.png) ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/tp_dp_throughput_4_gpus.png) @@ -188,11 +325,178 @@ Given these factors, our experiments on the Qwen model family (3B, 7B, 14B, 32B) * For reasonable-sized models (3B–14B) and a moderate context window (`max_len < 8k`), using full capacity for data parallelism gives better throughput. The setup `(tp=1, dp=8)` yields the best results. * For larger models (32B) and longer context windows (`max_len > 8k`), a smaller DP size combined with some model-side parallelism performs better. For example, `(tp=2, dp=4)` is a good setup for 32B models with a larger context window. -## vLLM with Transformers Backend +### vLLM with Transformers Backend + +vLLM can use the **Transformers backend** for model implementations, which works for both LLMs and VLMs. +To enable this, set `vllm_model_impl="transformers"` in your configuration or pass it via the command-line argument. + +For more details, check out [vLLM Transformers Backend](https://blog.vllm.ai/2025/04/11/transformers-backend.html). -vLLM now supports transformers backend for model implementations. Simply passing in `transformers` in `vllm_model_impl` in configurations or through argument parser will set use transformers backend. This works for both LLMs and VLMs. See an example below, you can get more information [here](https://blog.vllm.ai/2025/04/11/transformers-backend.html). +Example: ``` CUDA_DEVICE_ORDER=PCI_BUS_ID CUDA_VISIBLE_DEVICES=0 trl vllm-serve --model Qwen/Qwen 2.5-VL-3B-Instruct --tensor-parallel-size 1 --port 8000 --enforce_eager --vllm_model_impl transformers ``` + +### Modes of Using vLLM During Training + +TRL supports **two modes** for integrating vLLM during training: **server mode** and **colocate mode**. + +#### Server Mode + +In **server mode**, vLLM runs as a separate process on dedicated GPUs and communicates with the trainer via HTTP. +This setup is ideal if you have GPUs dedicated to inference. + +Example configuration: + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl import OnlineDPOConfig + +training_args = OnlineDPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl import NashMDConfig + +training_args = NashMDConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl import XPOConfig + +training_args = XPOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig( + ..., + use_vllm=True, + vllm_mode="server", # default value, can be omitted +) +``` + + + + +#### Colocate Mode + +In **colocate mode**, vLLM runs inside the trainer process and shares GPU memory with the training model. +This avoids launching a separate server and can improve GPU utilization, but may lead to memory contention on the training GPUs. + +Example configuration: + + + + +```python +from trl import GRPOConfig + +training_args = GRPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl import OnlineDPOConfig + +training_args = OnlineDPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl import NashMDConfig + +training_args = NashMDConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl import XPOConfig + +training_args = XPOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +```python +from trl import RLOOConfig + +training_args = RLOOConfig( + ..., + use_vllm=True, + vllm_mode="colocate", +) +``` + + + + +> [!WARNING] +> Check the documentation of the trainer you are using for specific details on vLLM usage and parameters. + + +> [!WARNING] +> To reduce GPU memory usage when running vLLM, consider [enabling vLLM sleep mode](reducing_memory_usage#vllm-sleep-mode). + From d258e36e4557e61bbd82ca5c06948592009d5dca Mon Sep 17 00:00:00 2001 From: Sergio Paniego Blanco Date: Mon, 6 Oct 2025 16:04:06 +0200 Subject: [PATCH 080/153] Remove `Optional` from `processing_class` in `PPOTrainer` (#4212) --- trl/trainer/ppo_trainer.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trl/trainer/ppo_trainer.py b/trl/trainer/ppo_trainer.py index 70912fb336c..5e941a181e8 100644 --- a/trl/trainer/ppo_trainer.py +++ b/trl/trainer/ppo_trainer.py @@ -146,9 +146,7 @@ class PPOTrainer(BaseTrainer): def __init__( self, args: PPOConfig, - processing_class: Optional[ - Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin] - ], + processing_class: Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin], model: nn.Module, ref_model: Optional[nn.Module], reward_model: nn.Module, From 7f5b4995b659d02f0c52dc41691a4271b5b49914 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 6 Oct 2025 17:45:44 +0200 Subject: [PATCH 081/153] Replace setup with pyproject and fix packaging unintended modules (#4194) --- MANIFEST.in | 5 +- pyproject.toml | 126 +++++++++++++++++++++++++++++++++++++++++++++++++ setup.cfg | 93 ------------------------------------ setup.py | 18 ------- 4 files changed, 129 insertions(+), 113 deletions(-) delete mode 100644 setup.cfg delete mode 100644 setup.py diff --git a/MANIFEST.in b/MANIFEST.in index 8855af1a5ae..595e7a15053 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,6 +1,7 @@ include LICENSE include CONTRIBUTING.md include README.md -recursive-exclude * __pycache__ +include trl/accelerate_configs/*.yaml include trl/templates/*.md -include trl/accelerate_configs/*.yaml \ No newline at end of file +recursive-exclude * __pycache__ +prune tests diff --git a/pyproject.toml b/pyproject.toml index 8674324a4a8..a0b6c1072ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,129 @@ +[build-system] +requires = ["setuptools >= 77.0.3"] +build-backend = "setuptools.build_meta" + +[project] +name = "trl" +description = "Train transformer language models with reinforcement learning." +authors = [ + { name = "Leandro von Werra", email = "leandro.vonwerra@gmail.com" } +] +readme = { file = "README.md", content-type = "text/markdown" } +license = "Apache-2.0" +license-files = ["LICENSE"] +keywords = [ + "transformers", "huggingface", "language modeling", "post-training", "rlhf", "sft", "dpo", "grpo" +] +classifiers = [ + "Development Status :: 2 - Pre-Alpha", + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "Natural Language :: English", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13" +] +requires-python = ">=3.9" +dependencies = [ + "accelerate>=1.4.0", + "datasets>=3.0.0", + "transformers>=4.56.1", + "transformers!=4.57.0; python_version == '3.9'" +] +dynamic = ["version"] + +[project.urls] +Homepage = "https://github.com/huggingface/trl" + +[project.scripts] +trl = "trl.cli:main" + +[project.optional-dependencies] +bco = [ + "scikit-learn", + "joblib" +] +deepspeed = [ + "deepspeed>=0.14.4" +] +judges = [ + "openai>=1.23.2", + "llm-blender>=0.0.2" +] +liger = [ + "liger-kernel>=0.6.2" +] +peft = [ + "peft>=0.8.0" +] +quality = [ + "pre-commit", + "hf-doc-builder" +] +quantization = [ + "bitsandbytes" +] +scikit = [ + "scikit-learn" +] +test = [ + "parameterized", + "pytest-cov", + "pytest-rerunfailures==15.1", + "pytest-xdist", + "pytest" +] +vllm = [ + "vllm==0.10.2", + "fastapi", + "pydantic", + "requests", + "uvicorn" +] +vlm = [ + "Pillow", + "torchvision", + "num2words==0.5.14" +] +dev = [ + "scikit-learn", + "joblib", + "deepspeed>=0.14.4", + "openai>=1.23.2", + "llm-blender>=0.0.2", + "liger-kernel>=0.6.2", + "peft>=0.8.0", + "pre-commit", + "hf-doc-builder", + "bitsandbytes", + "parameterized", + "pytest-cov", + "pytest-rerunfailures==15.1", + "pytest-xdist", + "pytest", + "vllm==0.10.2", + "fastapi", + "pydantic", + "requests", + "uvicorn", + "Pillow", + "torchvision", + "num2words==0.5.14" +] + +[tool.setuptools] +package-dir = {"trl" = "trl"} + +[tool.setuptools.dynamic] +version = { file = "VERSION" } + +[tool.coverage.run] +branch = true + [tool.ruff] target-version = "py39" line-length = 119 diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index a0dabfe0908..00000000000 --- a/setup.cfg +++ /dev/null @@ -1,93 +0,0 @@ -[metadata] -name = trl -version = file: VERSION -description = Train transformer language models with reinforcement learning. -long_description = file: README.md -long_description_content_type = text/markdown -author = Leandro von Werra -author_email = leandro.vonwerra@gmail.com -url = https://github.com/huggingface/trl -keywords = transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo -license_file = LICENSE -classifiers = - Development Status :: 2 - Pre-Alpha - Intended Audience :: Developers - Intended Audience :: Science/Research - Natural Language :: English - Operating System :: OS Independent - Programming Language :: Python :: 3 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Programming Language :: Python :: 3.12 - Programming Language :: Python :: 3.13 - -[options] -packages = find_namespace: -python_requires = >=3.9 -include_package_data = True -install_requires = - accelerate>=1.4.0 - datasets>=3.0.0 - transformers>=4.56.1 - transformers!=4.57.0; python_version == "3.9" - -[options.packages.find] -exclude = - tests* - -[options.extras_require] -bco = - scikit-learn - joblib -deepspeed = - deepspeed>=0.14.4 -judges = - openai>=1.23.2 - llm-blender>=0.0.2 -liger = - liger-kernel>=0.6.2 -peft = - peft>=0.8.0 -quality = - pre-commit - hf-doc-builder -quantization = - bitsandbytes -scikit = - scikit-learn -test = - parameterized - pytest-cov - pytest-rerunfailures==15.1 - pytest-xdist - pytest -vllm = - vllm==0.10.2 - fastapi - pydantic - requests - uvicorn - -vlm = - Pillow - torchvision - num2words==0.5.14 -dev = - %(bco)s - %(deepspeed)s - %(judges)s - %(liger)s - %(peft)s - %(quality)s - %(quantization)s - %(scikit)s - %(test)s - %(vlm)s - -[options.entry_points] -console_scripts = - trl = trl.cli:main - -[coverage:run] -branch = True diff --git a/setup.py b/setup.py deleted file mode 100644 index 26f52a2806c..00000000000 --- a/setup.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright 2020-2025 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from setuptools import setup - - -setup() From a84325c73b45b9ef27d2a78bbd2cc97e0c002c29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Oct 2025 22:35:42 +0000 Subject: [PATCH 082/153] style --- trl/trainer/grpo_trainer.py | 2 +- trl/trainer/rloo_trainer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 56cdb447eed..2305546b797 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -273,7 +273,7 @@ def __init__( # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side = "left") + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left") # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 841f59dc617..d30de80e2ac 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -385,7 +385,7 @@ def decode(example, tokenizer): # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side = "left") + processing_class = AutoProcessor.from_pretrained(model.config._name_or_path, truncation_side="left") # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): From 2ce6c1ff412c3494189a004a16a435e29eaea5ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Oct 2025 23:53:53 +0000 Subject: [PATCH 083/153] token_type_ids and RLOO --- trl/trainer/grpo_trainer.py | 14 ++++++++------ trl/trainer/rloo_trainer.py | 38 ++++++++++++++++++++----------------- 2 files changed, 29 insertions(+), 23 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d4efbcf8ab6..3144af816c6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1085,6 +1085,7 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): prompts_text = [ maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] + # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1382,12 +1383,6 @@ def _generate_and_score_completions( # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - # If token_type_ids are used, extend them with zeros for the completion part - if "token_type_ids" in forward_kwargs: - token_type_ids = forward_kwargs["token_type_ids"] - forward_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 - ) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size @@ -1405,6 +1400,13 @@ def _generate_and_score_completions( else: forward_kwargs = {} + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) + with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index d30de80e2ac..c0372afcc7a 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1075,13 +1075,6 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts ] - if images is not None: - prompt_inputs = self.processing_class(text=prompts_text, padding=True, return_tensors="pt", **kwargs) - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: @@ -1285,13 +1278,13 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): prompt_ids = [p[m].tolist() for p, m in zip(prompt_ids, prompt_mask.bool())] completion_ids = [c[m].tolist() for c, m in zip(completion_ids, completion_mask.bool())] - return prompt_ids, completion_ids, forward_kwargs + return prompt_ids, completion_ids def _generate(self, prompts: list[str], images: Optional[list]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids, forward_kwargs = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids = self._generate_single_turn(prompts, images) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1324,7 +1317,7 @@ def _generate(self, prompts: list[str], images: Optional[list]): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, forward_kwargs + return prompt_ids, completion_ids def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1341,7 +1334,7 @@ def _generate_and_score_completions( else: images = None - prompt_ids_list, completion_ids_list, forward_kwargs = self._generate(prompts, images) + prompt_ids_list, completion_ids_list = self._generate(prompts, images) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1362,18 +1355,29 @@ def _generate_and_score_completions( # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - # If token_type_ids are used, extend them with zeros for the completion part - if "token_type_ids" in forward_kwargs: - token_type_ids = forward_kwargs["token_type_ids"] - forward_kwargs["token_type_ids"] = torch.cat( - [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 - ) logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size num_images = [len(img_list) for img_list in images] if images is not None else None + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + + # If token_type_ids are used, extend them with zeros for the completion part + if "token_type_ids" in forward_kwargs: + token_type_ids = forward_kwargs["token_type_ids"] + forward_kwargs["token_type_ids"] = torch.cat( + [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 + ) with torch.no_grad(): # Compute the per-token log probabilities for the current model old_per_token_logps, _ = self._get_per_token_logps_and_entropies( From ddf3405c6cbfec95712b15aec36b12392e56c8dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Oct 2025 23:59:08 +0000 Subject: [PATCH 084/153] gfpo --- trl/experimental/gfpo/gfpo_trainer.py | 34 +++++++++++++++++---------- 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/trl/experimental/gfpo/gfpo_trainer.py b/trl/experimental/gfpo/gfpo_trainer.py index 58ca39a6f45..20a28520f31 100644 --- a/trl/experimental/gfpo/gfpo_trainer.py +++ b/trl/experimental/gfpo/gfpo_trainer.py @@ -18,7 +18,7 @@ import torch from accelerate.utils import gather_object -from ...data_utils import is_conversational +from ...data_utils import apply_chat_template, is_conversational from ...trainer.grpo_trainer import GRPOTrainer as _GRPOTrainer from ...trainer.utils import nanmax, nanmin, nanstd, pad @@ -77,13 +77,9 @@ def _generate_and_score_completions(self, inputs): else: images = None - ( - prompt_ids_list, - completion_ids_list, - num_items_in_batch, - sampling_per_token_logps_list, - forward_kwargs, - ) = self._generate(prompts, images) + prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list = self._generate( + prompts, images + ) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -109,6 +105,23 @@ def _generate_and_score_completions(self, inputs): # Concatenate prompt_mask with completion_mask for logit computation prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) + + logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens + batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size + + num_images = [len(img_list) for img_list in images] if images is not None else None + + # Get forward_kwargs for models with multimodal inputs + if images is not None: + prompts_text = [ + apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts + ] + prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") + prompt_inputs = super()._prepare_inputs(prompt_inputs) + forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} + else: + forward_kwargs = {} + # If token_type_ids are used, extend them with zeros for the completion part if "token_type_ids" in forward_kwargs: token_type_ids = forward_kwargs["token_type_ids"] @@ -116,11 +129,6 @@ def _generate_and_score_completions(self, inputs): [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - num_images = [len(img_list) for img_list in images] if images is not None else None - with torch.no_grad(): # If the generation and optimization steps are misaligned—i.e., if generation does not occur at the end of # a full optimizer step (when gradient_accumulation_steps is not a multiple of generate_every)—then the From e3c679c9c71179a07bd676d5d4eb77513ee4ad4f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 6 Oct 2025 23:59:17 +0000 Subject: [PATCH 085/153] style --- trl/trainer/rloo_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index c0372afcc7a..ed51f04c0f8 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1378,6 +1378,7 @@ def _generate_and_score_completions( forward_kwargs["token_type_ids"] = torch.cat( [token_type_ids, token_type_ids.new_zeros(completion_ids.shape)], dim=1 ) + with torch.no_grad(): # Compute the per-token log probabilities for the current model old_per_token_logps, _ = self._get_per_token_logps_and_entropies( From ee03478a14a0f474051edcef081dbf915993d697 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 7 Oct 2025 00:32:37 +0000 Subject: [PATCH 086/153] remove test case for prompt truncation --- tests/test_grpo_trainer.py | 41 -------------------------------------- tests/test_rloo_trainer.py | 41 -------------------------------------- 2 files changed, 82 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 17766ef0b98..5202c6a0927 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1471,47 +1471,6 @@ def reward_func(completions, **kwargs): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @require_vision - def test_training_vlm_and_prompt_truncation(self): - # If not handled properly, prompt truncation may truncate image token - dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") - - def reward_func(completions, **kwargs): - """Reward function that rewards longer completions.""" - return [float(len(completion[0]["content"])) for completion in completions] - - training_args = GRPOConfig( - output_dir=self.tmp_dir, - learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=3, # reduce the batch size to reduce memory usage - num_generations=3, # reduce the number of generations to reduce memory usage - max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=18, - report_to="none", - ) - trainer = GRPOTrainer( - model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs=reward_func, - args=training_args, - train_dataset=dataset, - ) - - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - - trainer.train() - - assert trainer.state.log_history[-1]["train_loss"] is not None - - # Check that the params have changed - # Because of the way the tiny models are initialized, the gradient does not flow properly through the - # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. - params_to_skip = ("model.visual.",) - for n, param in previous_trainable_params.items(): - if n.startswith(params_to_skip): - continue - new_param = trainer.model.get_parameter(n) - assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",), diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 1de4eca479e..7e0135d5410 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1212,47 +1212,6 @@ def reward_func(completions, **kwargs): elif "base_layer" not in n: # We expect the peft params to be different (except for the base layer) assert not torch.allclose(param, new_param), f"Parameter {n} has not changed." - @require_vision - def test_training_vlm_and_prompt_truncation(self): - # If not handled properly, prompt truncation may truncate image token - dataset = load_dataset("trl-internal-testing/zen-image", "conversational_prompt_only", split="train") - - def reward_func(completions, **kwargs): - """Reward function that rewards longer completions.""" - return [float(len(completion[0]["content"])) for completion in completions] - - training_args = RLOOConfig( - output_dir=self.tmp_dir, - learning_rate=0.1, # increase the learning rate to speed up the test - per_device_train_batch_size=3, # reduce the batch size to reduce memory usage - num_generations=3, # reduce the number of generations to reduce memory usage - max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=18, - report_to="none", - ) - trainer = RLOOTrainer( - model="trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration", - reward_funcs=reward_func, - args=training_args, - train_dataset=dataset, - ) - - previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} - - trainer.train() - - assert trainer.state.log_history[-1]["train_loss"] is not None - - # Check that the params have changed - # Because of the way the tiny models are initialized, the gradient does not flow properly through the - # vision parts of the model, so we skip them. Ideally, we should fix the init of these models. - params_to_skip = ("model.visual.",) - for n, param in previous_trainable_params.items(): - if n.startswith(params_to_skip): - continue - new_param = trainer.model.get_parameter(n) - assert not torch.equal(param, new_param), f"Parameter {n} has not changed." - @parameterized.expand( [ ("trl-internal-testing/tiny-Qwen2_5_VLForConditionalGeneration",), From fe11512100384733bf7518c81240f8b5fc79ecfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 00:02:48 +0000 Subject: [PATCH 087/153] dedup and some fixes --- tests/test_data_utils.py | 2 +- trl/data_utils.py | 6 ++--- trl/trainer/grpo_trainer.py | 50 ++++++++++++------------------------- 3 files changed, 20 insertions(+), 38 deletions(-) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index d2a68c70e8a..a7ad80bb2bf 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -150,7 +150,7 @@ def test_already_prepared_messages_unchanged(self): image = Image.new("RGB", (32, 32), color="red") messages = prepare_multimodal_messages(messages, images=[image]) - + expected = [ { "role": "system", diff --git a/trl/data_utils.py b/trl/data_utils.py index eaeb5fcff0e..e0f7a9c3ddf 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -35,9 +35,9 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> Args: messages (`list[dict[str, Any]]`): - Messages with `"role"` and `"content"`. Content may be a raw string before transformation. - List of messages a `"role"` key (`"system"`, `"user"`, or `"assistant"`) and a `"content"` key containing - either a string or a list of structured blocks if already prepared. + Messages with `"role"` and `"content"`. Content may be a raw string before transformation. List of messages + a `"role"` key (`"system"`, `"user"`, or `"assistant"`) and a `"content"` key containing either a string or + a list of structured blocks if already prepared. images (`list`): List of image objects to insert. diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c7861dc2db1..c9da96d60e9 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -44,7 +44,7 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available from ..data_utils import ( apply_chat_template, @@ -1154,7 +1154,6 @@ def _generate_single_turn(self, prompts: list): "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, "truncate_prompt_tokens": self.max_prompt_length, - "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding": guided_decoding, "logprobs": 0, # only return the logprob of the generated token } @@ -1203,15 +1202,20 @@ def _generate_single_turn(self, prompts: list): self.llm.sleep(level=1) elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" + processor_kwargs = { + "max_length": self.max_prompt_length, + "truncation": True, + "return_dict": True, + "add_special_tokens": False, + } + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True + ) else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) + generate_inputs["inputs"] = generate_inputs.pop("input_ids") + with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -1227,13 +1231,12 @@ def _generate_single_turn(self, prompts: list): unwrapped_model.to(torch.float16) with torch.inference_mode(): all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + **generate_inputs, generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids + prompt_ids = generate_inputs["inputs"] # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn logprobs = None # not used in this case else: @@ -1265,11 +1268,9 @@ def _generate_single_turn(self, prompts: list): ): prompt_completion_ids = unwrapped_model.generate( **generate_inputs, generation_config=self.generation_config, disable_compile=True - **generate_inputs, generation_config=self.generation_config, disable_compile=True ) # Compute prompt length and extract completion ids prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] - prompt_ids, prompt_mask = generate_inputs["input_ids"], generate_inputs["attention_mask"] prompt_length = prompt_ids.size(1) completion_ids = prompt_completion_ids[:, prompt_length:] @@ -1284,7 +1285,6 @@ def _generate_single_turn(self, prompts: list): logprobs = None # not used in this case return prompt_ids, completion_ids, logprobs - return prompt_ids, completion_ids, logprobs def _generate(self, prompts: list[str]): device = self.accelerator.device @@ -1323,7 +1323,6 @@ def _generate(self, prompts: list[str]): self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) return prompt_ids, completion_ids, total_completion_tokens, logprobs - return prompt_ids, completion_ids, total_completion_tokens, logprobs def _generate_and_score_completions( self, inputs: list[dict[str, Union[torch.Tensor, Any]]] @@ -1378,23 +1377,6 @@ def _generate_and_score_completions( prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) # (B, P+C) attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens - batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size - - num_images = [len(img_list) for img_list in images] if images is not None else None - - # Get forward_kwargs for models with multimodal inputs - if images is not None: - prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] - prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") - prompt_inputs = super()._prepare_inputs(prompt_inputs) - forward_kwargs = {k: v for k, v in prompt_inputs.items() if k not in ["input_ids", "attention_mask"]} - else: - forward_kwargs = {} - - logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens batch_size = self.args.per_device_train_batch_size if mode == "train" else self.args.per_device_eval_batch_size From c0c88071a3af2043e8c5c57721854dc8c0415741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 00:08:25 +0000 Subject: [PATCH 088/153] fix style --- trl/data_utils.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index e0f7a9c3ddf..ff6a907b817 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -30,6 +30,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> list[dict[str, Any]]: + # docstyle-ignore # because is not parsable in the code block """ Convert messages into a structured multimodal format and inject the provided images into the message contents. @@ -63,7 +64,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> # Output, one image provided [ - {"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}, + {"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}, {"role": "assistant", "content": [{"type": "text", "text": "It looks like a cat."}]}, ] ``` @@ -109,6 +110,7 @@ def prepare_multimodal_messages(messages: list[dict[str, Any]], images: list) -> def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dict[str, Any]]: + # docstyle-ignore # because is not parsable in the code block """ Convert structured multimodal messages into a format compatible with vLLM. Replaces `"type": "image"` blocks with `"type": "image_pil"` blocks, and `"image": Image` with `"image_pil": Image`. @@ -124,10 +126,10 @@ def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dic Example: ```python # Input - [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}] + [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What's in this image?"}]}] # Output - [{"role": "user", "content": [{"type": "image_pil", "image_pil": }, {"type": "text", "text": "What's in this image?"}]}] + [{"role": "user", "content": [{"type": "image_pil", "image_pil": }, {"type": "text", "text": "What's in this image?"}]}] ``` """ messages = copy.deepcopy(messages) # avoid modifying the original messages From ba8b93831f027e1d19c4a6b5eea106c1a27dc761 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 00:37:20 +0000 Subject: [PATCH 089/153] rloo --- trl/trainer/grpo_trainer.py | 1 - trl/trainer/rloo_trainer.py | 165 +++++++++++++++++------------------- 2 files changed, 78 insertions(+), 88 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c9da96d60e9..885d112972f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1236,7 +1236,6 @@ def _generate_single_turn(self, prompts: list): unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] prompt_ids = generate_inputs["inputs"] - # Restore the original attention implementation, training mode logprobs = None # not used in this case else: diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index 59fc9c285e0..f50e4775167 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -45,9 +45,15 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_flash_attn_2_available, is_peft_available, is_rich_available - -from ..data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template, prepare_multimodal_messages +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available + +from ..data_utils import ( + apply_chat_template, + is_conversational, + maybe_apply_chat_template, + prepare_multimodal_messages, + prepare_multimodal_messages_vllm, +) from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient from ..import_utils import is_vllm_available @@ -1065,22 +1071,9 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): rewards_per_func = gather(rewards_per_func) return rewards_per_func - def _generate_single_turn(self, prompts: list[str], images: Optional[list]): + def _generate_single_turn(self, prompts: list): device = self.accelerator.device - # If the prompts are conversational and the inputs contain images, we need to convert the prompts from - # [{"role": "user", "content": "What color is the sky?"}] to - # [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What color is the sky?"}]}] - kwargs = {} - if images is not None: - kwargs = {"images": images} - for prompt, image_list in zip(prompts, images): - if isinstance(prompt, list): # i.e., when using conversational data - prepare_multimodal_messages(prompt, num_images=len(image_list)) - - prompts_text = [ - maybe_apply_chat_template({"prompt": prompt}, self.processing_class)["prompt"] for prompt in prompts - ] # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1094,38 +1087,35 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self._move_model_to_vllm() self._last_loaded_step = self.state.global_step + prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": - all_prompts_text = gather_object(prompts_text) - if images is not None: - all_images = gather_object(images) + all_prompts = gather_object(prompts) if self.accelerator.is_main_process: # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate # num_generations outputs for each one. This is faster than generating outputs for each duplicate # prompt individually. - ordered_set_of_prompts = all_prompts_text[:: self.num_generations] - - if images is not None: - ordered_set_of_images = all_images[:: self.num_generations] - else: - ordered_set_of_images = None - + ordered_set_of_prompts = all_prompts[:: self.num_generations] + + sampling_params = { + "n": self.num_generations, + "repetition_penalty": self.repetition_penalty, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": -1 if self.top_k is None else self.top_k, + "min_p": 0.0 if self.min_p is None else self.min_p, + "max_tokens": self.max_completion_length, + "truncate_prompt_tokens": self.max_prompt_length, + "guided_decoding_regex": self.guided_decoding_regex, + "generation_kwargs": self.args.generation_kwargs, + } with profiling_context(self, "vLLM.generate"): - output = self.vllm_client.generate( - prompts=ordered_set_of_prompts, - images=ordered_set_of_images, - n=self.num_generations, - repetition_penalty=self.repetition_penalty, - temperature=self.temperature, - top_p=self.top_p, - top_k=-1 if self.top_k is None else self.top_k, - min_p=0.0 if self.min_p is None else self.min_p, - max_tokens=self.max_completion_length, - truncate_prompt_tokens=self.max_prompt_length, - guided_decoding_regex=self.guided_decoding_regex, - generation_kwargs=self.args.generation_kwargs, - ) + if is_conversational({"prompt": ordered_set_of_prompts[0]}): + output = self.vllm_client.chat(prompts=ordered_set_of_prompts, **sampling_params) + else: + output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) payload = (output["prompt_ids"], output["completion_ids"], output["logprobs"]) else: payload = None @@ -1170,31 +1160,18 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): if self.vllm_tensor_parallel_size > 1: # Gather prompts from all ranks in the TP group and flatten. # Each rank starts with its own prompts; after gathering, all ranks see the full group set. - orig_size = len(prompts_text) + orig_size = len(prompts) gathered_prompts = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_prompts, prompts_text, group=self.tp_group) - all_prompts_text = [p for sublist in gathered_prompts for p in sublist] - - if images is not None: - gathered_images = [None for _ in range(self.vllm_tensor_parallel_size)] - torch.distributed.all_gather_object(gathered_images, images, group=self.tp_group) - all_images = [img for sublist in gathered_images for img in sublist] - else: - all_images = None + torch.distributed.all_gather_object(gathered_prompts, prompts, group=self.tp_group) + all_prompts = [p for sublist in gathered_prompts for p in sublist] else: - all_prompts_text = prompts_text - all_images = images - - if images is not None and all_images: - vllm_inputs = [] - for prompt, image_list in zip(all_prompts_text, all_images): - vllm_inputs.append({"prompt": prompt, "multi_modal_data": {"image": image_list}}) - - else: - vllm_inputs = all_prompts_text + all_prompts = prompts with profiling_context(self, "vLLM.generate"): - all_outputs = self.llm.generate(vllm_inputs, sampling_params=sampling_params, use_tqdm=False) + if is_conversational({"prompt": prompts[0]}): + all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + else: + all_outputs = self.llm.generate(all_prompts, sampling_params=sampling_params, use_tqdm=False) all_prompt_ids = [output.prompt_token_ids for output in all_outputs] all_completion_ids = [output.token_ids for outputs in all_outputs for output in outputs.outputs] @@ -1214,15 +1191,20 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): self.llm.sleep(level=1) elif self.use_transformers_paged: - # Re-process inputs for paged generation if needed - # Note: images are already validated and preprocessed above - paged_prompt_inputs = self.processing_class(text=prompts_text, **kwargs) - previous_attn = self.model_wrapped.config._attn_implementation - - if is_flash_attn_2_available(): - self.model_wrapped.config._attn_implementation = "paged_attention" + processor_kwargs = { + "max_length": self.max_prompt_length, + "truncation": True, + "return_dict": True, + "add_special_tokens": False, + } + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True + ) else: - self.model_wrapped.config._attn_implementation = "sdpa_paged" + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) + generate_inputs["inputs"] = generate_inputs.pop("input_ids") + with ( profiling_context(self, "transformers.generate_batch"), unwrap_model_for_generation( @@ -1238,26 +1220,29 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): unwrapped_model.to(torch.float16) with torch.inference_mode(): all_outputs = unwrapped_model.generate_batch( - paged_prompt_inputs.input_ids, generation_config=self.generation_config, progress_bar=False + **generate_inputs, generation_config=self.generation_config, progress_bar=False ) unwrapped_model.train() # restore training mode, as generate_batch forces eval mode completion_ids = [output.generated_tokens for output in all_outputs.values()] - prompt_ids = paged_prompt_inputs.input_ids - # Restore the original attention implementation, training mode - self.model_wrapped.config._attn_implementation = previous_attn + prompt_ids = generate_inputs["inputs"] else: # Regular generation path - generate_inputs = self.processing_class( - text=prompts_text, - return_tensors="pt", - padding=True, - padding_side="left", - max_length=self.max_prompt_length, - truncation=True, - add_special_tokens=False, - **kwargs, - ) + processor_kwargs = { + "return_tensors": "pt", + "padding": True, + "padding_side": "left", + "max_length": self.max_prompt_length, + "truncation": True, + "return_dict": True, + "add_special_tokens": False, + } + if is_conversational({"prompt": prompts[0]}): + generate_inputs = self.processing_class.apply_chat_template( + conversation=prompts, **processor_kwargs, tokenize=True + ) + else: + generate_inputs = self.processing_class(text=prompts, **processor_kwargs) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1287,11 +1272,11 @@ def _generate_single_turn(self, prompts: list[str], images: Optional[list]): return prompt_ids, completion_ids - def _generate(self, prompts: list[str], images: Optional[list]): + def _generate(self, prompts: list[str]): device = self.accelerator.device mode = "train" if self.model.training else "eval" - prompt_ids, completion_ids = self._generate_single_turn(prompts, images) + prompt_ids, completion_ids = self._generate_single_turn(prompts) # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1344,7 +1329,13 @@ def _generate_and_score_completions( if images is not None and all(img_list == [] for img_list in images): images = None - prompt_ids_list, completion_ids_list = self._generate(prompts, images) + # If the prompts are conversational and the inputs contain images, we need to convert the prompts from + # [{"role": "user", "content": "What color is the sky?"}] to + # [{"role": "user", "content": [{"type": "image", "image": }, {"type": "text", "text": "What color is the sky?"}]}] + if images is not None: + prompts = [prepare_multimodal_messages(prompt, image_list) for prompt, image_list in zip(prompts, images)] + + prompt_ids_list, completion_ids_list = self._generate(prompts) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] From 7a2936e0a2d0350ad66136c9d1edabf0bbb9bdd8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 00:38:17 +0000 Subject: [PATCH 090/153] style --- trl/trainer/rloo_trainer.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index f50e4775167..f5017b6faf1 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -50,7 +50,6 @@ from ..data_utils import ( apply_chat_template, is_conversational, - maybe_apply_chat_template, prepare_multimodal_messages, prepare_multimodal_messages_vllm, ) @@ -1074,7 +1073,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): def _generate_single_turn(self, prompts: list): device = self.accelerator.device - # Generate completions using either vLLM or regular generation if self.use_vllm: if self.vllm_mode == "colocate" and self.args.vllm_enable_sleep_mode: From 1a6f04000b8278f132dadd4ca579feffd62260e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 00:53:54 +0000 Subject: [PATCH 091/153] test --- tests/test_data_utils.py | 105 ++++++++++++++++++++++++++++++++++++--- 1 file changed, 99 insertions(+), 6 deletions(-) diff --git a/tests/test_data_utils.py b/tests/test_data_utils.py index a7ad80bb2bf..25a2ed323b1 100644 --- a/tests/test_data_utils.py +++ b/tests/test_data_utils.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import itertools import textwrap from time import strftime @@ -32,6 +33,7 @@ maybe_unpair_preference_dataset, pack_dataset, prepare_multimodal_messages, + prepare_multimodal_messages_vllm, truncate_dataset, unpair_preference_dataset, ) @@ -46,7 +48,7 @@ def test_basic_user_assistant_conversation(self): {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - image = Image.new("RGB", (32, 32), color="red") + image = Image.new("RGB", (10, 10), color="blue") messages = prepare_multimodal_messages(messages, images=[image]) expected = [ @@ -70,7 +72,7 @@ def test_first_user_message_gets_image(self): {"role": "user", "content": "How about the grass?"}, ] - image = Image.new("RGB", (32, 32), color="red") + image = Image.new("RGB", (10, 10), color="blue") messages = prepare_multimodal_messages(messages, images=[image]) expected = [ @@ -96,7 +98,7 @@ def test_multiple_images(self): {"role": "user", "content": "What color is the sky?"}, {"role": "assistant", "content": "It is blue."}, ] - images = [Image.new("RGB", (32, 32), color=color) for color in ["red", "green", "blue"]] + images = [Image.new("RGB", (10, 10), color=color) for color in ["red", "green", "blue"]] messages = prepare_multimodal_messages(messages, images=images) expected = [ @@ -124,7 +126,7 @@ def test_system_message_transformation(self): {"role": "user", "content": "What color is the sky?"}, ] - image = Image.new("RGB", (32, 32), color="red") + image = Image.new("RGB", (10, 10), color="blue") messages = prepare_multimodal_messages(messages, images=[image]) expected = [ @@ -148,7 +150,7 @@ def test_already_prepared_messages_unchanged(self): {"role": "assistant", "content": [{"type": "text", "text": "It is blue."}]}, ] - image = Image.new("RGB", (32, 32), color="red") + image = Image.new("RGB", (10, 10), color="blue") messages = prepare_multimodal_messages(messages, images=[image]) expected = [ @@ -176,7 +178,7 @@ def test_mixed_prepared_and_unprepared_messages(self): {"role": "user", "content": "What about the grass?"}, ] - image = Image.new("RGB", (32, 32), color="red") + image = Image.new("RGB", (10, 10), color="blue") messages = prepare_multimodal_messages(messages, images=[image]) expected = [ @@ -197,6 +199,97 @@ def test_mixed_prepared_and_unprepared_messages(self): assert messages == expected +class TestPrepareMultimodalMessagesVLLM: + def test_single_image_conversion(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + {"type": "text", "text": "What color is the sky?"}, + ], + } + ] + + result = prepare_multimodal_messages_vllm(messages) + + # Original should remain unchanged (deepcopy test) + assert messages[0]["content"][0]["type"] == "image" + + # Converted version should have correct structure + assert result[0]["content"][0]["type"] == "image_pil" + assert "image_pil" in result[0]["content"][0] + assert "image" not in result[0]["content"][0] + assert isinstance(result[0]["content"][0]["image_pil"], Image.Image) + assert result[0]["content"][1]["type"] == "text" + + def test_mixed_content_conversion(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + } + ] + + result = prepare_multimodal_messages_vllm(messages) + + # The image part should be converted, text should be unchanged + assert result[0]["content"][0]["type"] == "text" + assert result[0]["content"][1]["type"] == "image_pil" + + def test_no_images(self): + messages = [{"role": "user", "content": [{"type": "text", "text": "What color is the sky?"}]}] + + result = prepare_multimodal_messages_vllm(messages) + + # Should be identical since there are no images + assert result == messages + # And a deepcopy — not the same object + assert result is not messages + assert result[0] is not messages[0] + + def test_multiple_messages(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + }, + { + "role": "assistant", + "content": [{"type": "text", "text": "It is blue."}], + }, + ] + + result = prepare_multimodal_messages_vllm(messages) + + assert result[0]["content"][1]["type"] == "image_pil" + assert result[1]["content"][0]["type"] == "text" + assert result[1]["content"][0]["text"] == "It is blue." + + def test_deepcopy_integrity(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What color is the sky?"}, + {"type": "image", "image": Image.new("RGB", (10, 10), color="blue")}, + ], + }, + ] + original = copy.deepcopy(messages) + + _ = prepare_multimodal_messages_vllm(messages) + + # Original should not be mutated + assert messages == original + + class TestIsConversational(TrlTestCase): conversational_examples = [ { # Language modeling From 26ffb043db0306c179580b34c63f119bcb033b4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 01:14:31 +0000 Subject: [PATCH 092/153] style --- trl/trainer/grpo_trainer.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 94a5a5bedde..077114df70c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -15,6 +15,7 @@ import inspect import json import os +import re import textwrap import traceback from collections import defaultdict, deque @@ -94,7 +95,6 @@ if is_wandb_available(): import wandb -import re logger = logging.get_logger(__name__) @@ -102,24 +102,25 @@ # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] + def extract_tool_calls(text: str) -> list[dict[str, Any]]: """ - Extract JSON objects from ... blocks in `text` - and return them in the format: - {"type": "function", "function": {...}} + Extract JSON objects from ... blocks in `text` and return them in the format: `[{"type": + "function", "function": {...}}, ...]` """ # Find every block between and - blocks = re.findall(r'\s*(\{.*?\})\s*', text, flags=re.DOTALL) - + blocks = re.findall(r"\s*(\{.*?\})\s*", text, flags=re.DOTALL) + result = [] for block in blocks: try: parsed = json.loads(block) - except json.JSONDecodeError as e: + except json.JSONDecodeError: continue result.append({"type": "function", "function": parsed}) return result or None + class GRPOTrainer(BaseTrainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -228,16 +229,14 @@ def reward_func(completions, **kwargs): "title": "DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", "id": "2402.03300", # docstyle-ignore - "citation": textwrap.dedent( - """\ + "citation": textwrap.dedent("""\ @article{shao2024deepseekmath, title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, year = 2024, eprint = {arXiv:2402.03300}, } - """ - ), + """), } def __init__( @@ -1354,8 +1353,10 @@ def _generate(self, prompts: list[str]): # Truncate post-tool completion so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length for i in range(len(post_tool_ids)): - excess_length = len(prompt_completion_tool_ids[i]) + len(post_tool_ids[i]) - ( - self.max_prompt_length + self.max_completion_length + excess_length = ( + len(prompt_completion_tool_ids[i]) + + len(post_tool_ids[i]) + - (self.max_prompt_length + self.max_completion_length) ) if excess_length > 0: post_tool_ids[i] = post_tool_ids[i][:-excess_length] @@ -1373,7 +1374,7 @@ def _generate(self, prompts: list[str]): # parsed_completions.append(None) # tool_calls = [completion.get("tool_calls") if completion is not None else None for completion in parsed_completions] tool_calls = [extract_tool_calls(content) for content in cc] - completion_contents =[None] * len(completion_contents) + completion_contents = [None] * len(completion_contents) for i, content in zip(idxs_with_tool, cc): completion_contents[i] = content idxs_with_tool = [idx for idx, tc in zip(idxs_with_tool, tool_calls) if tc] From ced5450e0dc2db74376501a0b1cf3dae7ae841f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 01:17:17 +0000 Subject: [PATCH 093/153] safe prepare_multimodal_messages_vllm --- trl/data_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index ff6a907b817..cd4b5908f49 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -134,10 +134,11 @@ def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dic """ messages = copy.deepcopy(messages) # avoid modifying the original messages for message in messages: - for part in message["content"]: - if part["type"] == "image": - part["type"] = "image_pil" # vLLM expects 'image_pil' key for images - part["image_pil"] = part.pop("image") + if not isinstance(message["content"], list): + for part in message["content"]: + if part["type"] == "image": + part["type"] = "image_pil" # vLLM expects 'image_pil' key for images + part["image_pil"] = part.pop("image") return messages From 23d13f9ae9476bdbad63e7d0381f443338b15b45 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 01:18:04 +0000 Subject: [PATCH 094/153] oops --- trl/data_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/data_utils.py b/trl/data_utils.py index cd4b5908f49..5aec3eaf501 100644 --- a/trl/data_utils.py +++ b/trl/data_utils.py @@ -134,7 +134,7 @@ def prepare_multimodal_messages_vllm(messages: list[dict[str, Any]]) -> list[dic """ messages = copy.deepcopy(messages) # avoid modifying the original messages for message in messages: - if not isinstance(message["content"], list): + if isinstance(message["content"], list): for part in message["content"]: if part["type"] == "image": part["type"] = "image_pil" # vLLM expects 'image_pil' key for images From 5f87ee989d055a882a89f6d485332891cb65a9dd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 18 Oct 2025 03:55:42 +0000 Subject: [PATCH 095/153] fix return-dict --- trl/trainer/grpo_trainer.py | 12 +++--------- trl/trainer/rloo_trainer.py | 15 ++++++--------- 2 files changed, 9 insertions(+), 18 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 885d112972f..61d94476a11 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1202,15 +1202,10 @@ def _generate_single_turn(self, prompts: list): self.llm.sleep(level=1) elif self.use_transformers_paged: - processor_kwargs = { - "max_length": self.max_prompt_length, - "truncation": True, - "return_dict": True, - "add_special_tokens": False, - } + processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False} if is_conversational({"prompt": prompts[0]}): generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, **processor_kwargs, tokenize=True + conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True ) else: generate_inputs = self.processing_class(text=prompts, **processor_kwargs) @@ -1246,12 +1241,11 @@ def _generate_single_turn(self, prompts: list): "padding_side": "left", "max_length": self.max_prompt_length, "truncation": True, - "return_dict": True, "add_special_tokens": False, } if is_conversational({"prompt": prompts[0]}): generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, **processor_kwargs, tokenize=True + conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True ) else: generate_inputs = self.processing_class(text=prompts, **processor_kwargs) diff --git a/trl/trainer/rloo_trainer.py b/trl/trainer/rloo_trainer.py index f5017b6faf1..4f37003924e 100644 --- a/trl/trainer/rloo_trainer.py +++ b/trl/trainer/rloo_trainer.py @@ -1189,15 +1189,13 @@ def _generate_single_turn(self, prompts: list): self.llm.sleep(level=1) elif self.use_transformers_paged: - processor_kwargs = { - "max_length": self.max_prompt_length, - "truncation": True, - "return_dict": True, - "add_special_tokens": False, - } + processor_kwargs = {"max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False} if is_conversational({"prompt": prompts[0]}): generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, **processor_kwargs, tokenize=True + conversation=prompts, + **processor_kwargs, + tokenize=True, + return_dict=True, ) else: generate_inputs = self.processing_class(text=prompts, **processor_kwargs) @@ -1232,12 +1230,11 @@ def _generate_single_turn(self, prompts: list): "padding_side": "left", "max_length": self.max_prompt_length, "truncation": True, - "return_dict": True, "add_special_tokens": False, } if is_conversational({"prompt": prompts[0]}): generate_inputs = self.processing_class.apply_chat_template( - conversation=prompts, **processor_kwargs, tokenize=True + conversation=prompts, **processor_kwargs, tokenize=True, return_dict=True ) else: generate_inputs = self.processing_class(text=prompts, **processor_kwargs) From 2d945f2d38901a09b0743e0e081e51349b9642d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Fri, 14 Nov 2025 22:22:02 +0000 Subject: [PATCH 096/153] move extraction to util + doc --- trl/trainer/grpo_trainer.py | 59 +++++++++++++++++-------------------- trl/trainer/utils.py | 19 ++++++++++++ 2 files changed, 46 insertions(+), 32 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0fb6d802f2d..f6a05dbe953 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -13,12 +13,10 @@ # limitations under the License. import inspect -import json import os -import re import textwrap -import traceback import time +import traceback import warnings from collections import defaultdict, deque from collections.abc import Callable @@ -74,6 +72,7 @@ disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, + extract_tool_calls, get_config_model_id, identity, nanmax, @@ -120,24 +119,6 @@ RolloutFunc = Callable[[list[str], "GRPOTrainer"], dict[str, Any]] -def extract_tool_calls(text: str) -> list[dict[str, Any]]: - """ - Extract JSON objects from ... blocks in `text` and return them in the format: `[{"type": - "function", "function": {...}}, ...]` - """ - # Find every block between and - blocks = re.findall(r"\s*(\{.*?\})\s*", text, flags=re.DOTALL) - - result = [] - for block in blocks: - try: - parsed = json.loads(block) - except json.JSONDecodeError: - continue - result.append({"type": "function", "function": parsed}) - return result or None - - class GRPOTrainer(BaseTrainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -238,6 +219,13 @@ def reward_func(completions, **kwargs): model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. peft_config ([`~peft.PeftConfig`], *optional*): PEFT configuration used to wrap the model. If `None`, the model is not wrapped. + tools (list of `Callable`, *optional*): + A list of callable tool functions that the model can invoke during generation. Each tool should be a + standard Python function with properly type-hinted arguments and return values, and a Google-style + docstring describing its purpose, arguments, and return value. For more details, see: + https://huggingface.co/docs/transformers/en/chat_extras#passing-tools. The model uses the function's name, + type hints, and docstring to determine how to call it. Ensure that the model's chat template supports tool + use and that it has been fine-tuned for tool calling. rollout_func (`RolloutFunc`, *optional*): Function to use for generating completions. It receives the list of prompts allocated to the current process and the trainer instance. It must return a dict with `"prompt_ids"`, `"completion_ids"`, and @@ -273,11 +261,9 @@ def __init__( callbacks: list[TrainerCallback] | None = None, optimizers: tuple[torch.optim.Optimizer | None, torch.optim.lr_scheduler.LambdaLR | None] = (None, None), peft_config: "PeftConfig | None" = None, - tools=None, + tools: list[Callable] | None = None, rollout_func: RolloutFunc | None = None, ): - self.tools = tools or [] - self._tool_dict = {tool.__name__: tool for tool in self.tools} # Args if args is None: model_name = model if isinstance(model, str) else get_config_model_id(model.config) @@ -405,6 +391,10 @@ def __init__( ) self.rollout_func = rollout_func + # Tools + self.tools = tools or [] + self._tool_dict = {tool.__name__: tool for tool in self.tools} + # Training arguments self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper @@ -620,7 +610,7 @@ def cast_outputs_to_original_dtype(module, args, output): ensure_master_addr_port() if self.max_prompt_length is not None and self.max_completion_length is not None: - max_model_len = self.max_prompt_length + self.max_completion_length + 512 + max_model_len = self.max_prompt_length + self.max_completion_length else: max_model_len = None @@ -1424,7 +1414,6 @@ def _generate_single_turn(self, prompts: list): tools=self.tools, return_dict=True, **self.chat_template_kwargs, - ) else: generate_inputs = self.processing_class(text=prompts, **processor_kwargs) @@ -1464,6 +1453,9 @@ def _generate(self, prompts: list): mode = "train" if self.model.training else "eval" prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + + # Tool execution loop: check for tool calls and execute them, then regenerate completions with tool results + # appended to the prompt completion_contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) # parsed_completions = [] # for content in completion_contents: @@ -1479,7 +1471,9 @@ def _generate(self, prompts: list): while idxs_with_tool: prompts_for_generation = [prompts[i] for i in idxs_with_tool] - for idx, tool_call_list, prompt_for_generation in zip(idxs_with_tool, tool_calls, prompts_for_generation): + for idx, tool_call_list, prompt_for_generation in zip( + idxs_with_tool, tool_calls, prompts_for_generation, strict=True + ): prompt_for_generation.append({"role": "assistant", "content": completion_contents[idx]}) for tool_call in tool_call_list: if tool_call["type"] == "function": @@ -1507,7 +1501,7 @@ def _generate(self, prompts: list): if excess_length > 0: post_tool_ids[i] = post_tool_ids[i][:-excess_length] - for idx, pct, post_tool in zip(idxs_with_tool, prompt_completion_tool_ids, post_tool_ids): + for idx, pct, post_tool in zip(idxs_with_tool, prompt_completion_tool_ids, post_tool_ids, strict=True): completion_ids[idx] = pct[len(prompt_ids[idx]) :] + post_tool cc = self.processing_class.batch_decode(post_tool_ids, skip_special_tokens=True) @@ -1521,9 +1515,9 @@ def _generate(self, prompts: list): # tool_calls = [completion.get("tool_calls") if completion is not None else None for completion in parsed_completions] tool_calls = [extract_tool_calls(content) for content in cc] completion_contents = [None] * len(completion_contents) - for i, content in zip(idxs_with_tool, cc): + for i, content in zip(idxs_with_tool, cc, strict=True): completion_contents[i] = content - idxs_with_tool = [idx for idx, tc in zip(idxs_with_tool, tool_calls) if tc] + idxs_with_tool = [idx for idx, tc in zip(idxs_with_tool, tool_calls, strict=True) if tc] tool_calls = [tc for tc in tool_calls if tc] # Get completion length per sequence, used for logging @@ -1622,8 +1616,9 @@ def _generate_and_score_completions( # Get forward_kwargs for models with multimodal inputs if images is not None: prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs)["prompt"] - + apply_chat_template( + {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + )["prompt"] for prompt in prompts ] prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 02d4cc78073..69d0adc8d16 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -17,6 +17,7 @@ import json import os import random +import re import socket import warnings from collections.abc import Mapping, Sequence, Sized @@ -2027,3 +2028,21 @@ def get_config_model_id(config: PretrainedConfig) -> str: """ # Fall back to `config.text_config._name_or_path` if `config._name_or_path` is missing: Qwen2-VL and Qwen2.5-VL. See GH-4323 return getattr(config, "_name_or_path", "") or getattr(getattr(config, "text_config", None), "_name_or_path", "") + + +def extract_tool_calls(text: str) -> list[dict[str, Any]]: + """ + Extract JSON objects from ... blocks in `text` and return them in the format: `[{"type": + "function", "function": {...}}, ...]` + """ + # Find every block between and + blocks = re.findall(r"\s*(\{.*?\})\s*", text, flags=re.DOTALL) + + result = [] + for block in blocks: + try: + parsed = json.loads(block) + except json.JSONDecodeError: + continue + result.append({"type": "function", "function": parsed}) + return result or None From 65ad930fa0a332609dc2decfc31e3240bfc0153a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 15 Nov 2025 02:44:32 +0000 Subject: [PATCH 097/153] using response parser --- trl/trainer/grpo_trainer.py | 91 +++++++++------- trl/trainer/utils.py | 207 +++++++++++++++++++++++++++++++++--- 2 files changed, 243 insertions(+), 55 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index f6a05dbe953..9f5c4ab9c9f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import copy import inspect import os import textwrap @@ -22,6 +23,7 @@ from collections.abc import Callable from contextlib import nullcontext from functools import partial +from itertools import takewhile from pathlib import Path from typing import Any @@ -69,10 +71,10 @@ from .grpo_config import GRPOConfig from .utils import ( RepeatSampler, + add_response_schema, disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, - extract_tool_calls, get_config_model_id, identity, nanmax, @@ -394,6 +396,11 @@ def __init__( # Tools self.tools = tools or [] self._tool_dict = {tool.__name__: tool for tool in self.tools} + # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. + # While waiting for broader adoption, we provide this utility function to manually set the response schema for + # known chat templates. + if tools and not processing_class.response_schema: + processing_class = add_response_schema(processing_class) # Training arguments self.max_prompt_length = args.max_prompt_length @@ -1452,29 +1459,31 @@ def _generate(self, prompts: list): device = self.accelerator.device mode = "train" if self.model.training else "eval" + # Copy the prompts to avoid modifying the original list + prompts = copy.deepcopy(prompts) + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) # Tool execution loop: check for tool calls and execute them, then regenerate completions with tool results # appended to the prompt - completion_contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - # parsed_completions = [] - # for content in completion_contents: - # try: - # parsed_completions.append(self.processing_class.parse_response(content)) - # except Exception as e: - # logger.warning(f"Failed to parse model output: {content}\nError: {e}") - # parsed_completions.append(None) - # tool_calls = [completion.get("tool_calls") if completion is not None else None for completion in parsed_completions] - tool_calls = [extract_tool_calls(content) for content in completion_contents] + completions = self.processing_class.parse_response(completion_ids) + # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here + for completion in completions: + completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) + completions = [[completion] for completion in completions] # format as list of messages + + # Check for tool calls + tool_calls = [completion[-1].get("tool_calls") for completion in completions] idxs_with_tool = [i for i, t in enumerate(tool_calls) if t] # find indices that actually have a tool call tool_calls = [tool_calls[i] for i in idxs_with_tool] while idxs_with_tool: - prompts_for_generation = [prompts[i] for i in idxs_with_tool] + prompts_for_generation = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls for idx, tool_call_list, prompt_for_generation in zip( idxs_with_tool, tool_calls, prompts_for_generation, strict=True ): - prompt_for_generation.append({"role": "assistant", "content": completion_contents[idx]}) + # Call the tools, and build the new prompt for generation + prompt_for_generation.append(completions[idx][-1]) for tool_call in tool_call_list: if tool_call["type"] == "function": function = tool_call["function"] @@ -1488,6 +1497,7 @@ def _generate(self, prompts: list): tool_call["result"] = result tool_message = {"role": "tool", "name": function["name"], "content": str(result)} prompt_for_generation.append(tool_message) + completions[idx].append(tool_message) prompt_completion_tool_ids, post_tool_ids, _, _ = self._generate_single_turn(prompts_for_generation) @@ -1501,22 +1511,28 @@ def _generate(self, prompts: list): if excess_length > 0: post_tool_ids[i] = post_tool_ids[i][:-excess_length] + # Qwen3 inserts \n\n tokens only for the latest user message which can cause discrepancies + # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the + # common prefix between the two. In most cases, this is a no-op. + for idx, pct in zip(idxs_with_tool, prompt_completion_tool_ids, strict=True): + prompt_ids[idx] = [ + tok for tok, _ in takewhile(lambda x: x[0] == x[1], zip(prompt_ids[idx], pct, strict=False)) + ] + + # Update completion_ids with the new completions after tool execution for idx, pct, post_tool in zip(idxs_with_tool, prompt_completion_tool_ids, post_tool_ids, strict=True): completion_ids[idx] = pct[len(prompt_ids[idx]) :] + post_tool - cc = self.processing_class.batch_decode(post_tool_ids, skip_special_tokens=True) - # parsed_completions = [] - # for content in cc: - # try: - # parsed_completions.append(self.processing_class.parse_response(content)) - # except Exception as e: - # logger.warning(f"Failed to parse model output: {content}\nError: {e}") - # parsed_completions.append(None) - # tool_calls = [completion.get("tool_calls") if completion is not None else None for completion in parsed_completions] - tool_calls = [extract_tool_calls(content) for content in cc] - completion_contents = [None] * len(completion_contents) - for i, content in zip(idxs_with_tool, cc, strict=True): - completion_contents[i] = content + post_tool_completions = self.processing_class.parse_response(post_tool_ids) + for completion in post_tool_completions: + completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) + + # Add post-tool completions to the existing completions + for idx in range(len(idxs_with_tool)): + completions[idxs_with_tool[idx]].append(post_tool_completions[idx]) + + # Check for further tool calls + tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] idxs_with_tool = [idx for idx, tc in zip(idxs_with_tool, tool_calls, strict=True) if tc] tool_calls = [tc for tc in tool_calls if tc] @@ -1550,7 +1566,7 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs, extra_fields + return prompt_ids, completion_ids, total_completion_tokens, logprobs, extra_fields, completions def _generate_and_score_completions( self, inputs: list[dict[str, torch.Tensor | Any]] @@ -1579,9 +1595,14 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] - prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( - self._generate(prompts) - ) + ( + prompt_ids_list, + completion_ids_list, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + completions, + ) = self._generate(prompts) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -1694,16 +1715,6 @@ def _generate_and_score_completions( # Decode prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text, strict=True): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}] - assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text" - bootstrap = bootstrap[0]["text"] - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text # Merge extra_fields from rollout_func into inputs for reward functions if extra_fields: diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 69d0adc8d16..6714dc46070 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -17,7 +17,6 @@ import json import os import random -import re import socket import warnings from collections.abc import Mapping, Sequence, Sized @@ -39,12 +38,15 @@ from torch.utils.data import Sampler from transformers import ( AutoConfig, + AutoTokenizer, BitsAndBytesConfig, EvalPrediction, GenerationConfig, PretrainedConfig, PreTrainedModel, + PreTrainedTokenizer, PreTrainedTokenizerBase, + ProcessorMixin, TrainerState, TrainingArguments, is_comet_available, @@ -2030,19 +2032,194 @@ def get_config_model_id(config: PretrainedConfig) -> str: return getattr(config, "_name_or_path", "") or getattr(getattr(config, "text_config", None), "_name_or_path", "") -def extract_tool_calls(text: str) -> list[dict[str, Any]]: - """ - Extract JSON objects from ... blocks in `text` and return them in the format: `[{"type": - "function", "function": {...}}, ...]` +# These schemas are copy-pasted from https://github.com/huggingface/transformers/blob/main/tests/utils/test_chat_parsing_utils.py +cohere_schema = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"}, + "thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"}, + "tool_calls": { + "x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)", + "x-parser": "json", + "x-parser-args": { + "transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}" + }, + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +ernie_schema = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string", "x-regex": "\n(.*?)\n?"}, + "thinking": {"type": "string", "x-regex": r"(?:^|\s*)(.*?)\s*<\/think>"}, + "tool_calls": { + "x-regex-iterator": "(.*?)", + "type": "array", + "items": { + "type": "object", + "x-parser": "json", + "x-parser-args": {"transform": "{type: 'function', function: @}"}, + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +gpt_oss_schema = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"}, + "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"}, + "tool_calls": { + "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)", + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"}, + "arguments": { + "type": "object", + "x-regex": r"<\|message\|>(.*)", + "x-parser": "json", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +smollm_schema = { + "x-regex": r"(?:\n?(?P.+?)\n?)?\s*(?:(?P.+?))?\s*(?P.+?)?\s*(?:<\|im_end\|>|$)", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "thinking": {"type": "string"}, + "tool_calls": { + "x-parser": "json", + "x-parser-args": {"transform": "[{type: 'function', function: @}]"}, + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +qwen3_schema = { + "x-regex": r"^(?:(?:)?\s*(?P.+?)\s*)?\s*(?:(?P.*?)\s*)?\s*(?P.+?)?\s*$", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "thinking": {"type": "string"}, + "tool_calls": { + "x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string", "x-regex": r""}, + "arguments": { + "type": "object", + "x-regex-key-value": r"\w+)>\n(?P.*?)\n", + "additionalProperties": { + "x-parser": "json", + "x-parser-args": {"allow_non_json": True}, + }, + }, + }, + }, + }, + }, + }, + }, +} + + +TokenizerOrProcessor = TypeVar("TokenizerOrProcessor", PreTrainedTokenizer, ProcessorMixin) + + +def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: """ - # Find every block between and - blocks = re.findall(r"\s*(\{.*?\})\s*", text, flags=re.DOTALL) + Adds the appropriate response schema to the given tokenizer or processor based on its chat template. - result = [] - for block in blocks: - try: - parsed = json.loads(block) - except json.JSONDecodeError: - continue - result.append({"type": "function", "function": parsed}) - return result or None + At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While + waiting for broader adoption, we provide this utility function to manually set the response schema for known chat + templates. + + Args: + processor (`TokenizerOrProcessor`): + Tokenizer or processor to which the response schema will be added. + + Returns: + `TokenizerOrProcessor`: + Tokenizer or processor with the added response schema. + """ + qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template + if processor.chat_template == qwen3_chat_template: + # The qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See + # https://github.com/huggingface/transformers/issues/42220 + processor.response_schema = smollm_schema + return processor + raise ValueError( + "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " + "tokenizer or processor." + ) From 67e8f29c8c661cf33ff00589a60a673d09b4d72b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 15 Nov 2025 03:15:21 +0000 Subject: [PATCH 098/153] backward compat --- trl/trainer/grpo_trainer.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 9f5c4ab9c9f..2c87fe56a36 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -35,6 +35,7 @@ from accelerate import logging from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset +from packaging import version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler @@ -394,6 +395,11 @@ def __init__( self.rollout_func = rollout_func # Tools + if tools and not version.parse(transformers.__version__) >= version.parse("5.0.0.dev0"): + raise ImportError( + "Using tools with GRPOTrainer requires transformers version 5.0.0.dev0 or higher. Please upgrade " + "transformers to use this feature." + ) self.tools = tools or [] self._tool_dict = {tool.__name__: tool for tool in self.tools} # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. @@ -1466,7 +1472,12 @@ def _generate(self, prompts: list): # Tool execution loop: check for tool calls and execute them, then regenerate completions with tool results # appended to the prompt - completions = self.processing_class.parse_response(completion_ids) + if version.parse(transformers.__version__) >= version.parse("5.0.0.dev0"): + completions = self.processing_class.parse_response(completion_ids) + else: + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [{"role": "assistant", "content": content} for content in contents] + # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here for completion in completions: completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) From a4eac3c61bebac81eff1c48dfade53bd654aefef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 15 Nov 2025 06:13:24 +0000 Subject: [PATCH 099/153] fixes --- trl/trainer/grpo_trainer.py | 37 +++++++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 14 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 2c87fe56a36..0fd1861eec8 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1470,23 +1470,32 @@ def _generate(self, prompts: list): prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + if is_conversational({"prompt": prompts[0]}): + if ( + version.parse(transformers.__version__) >= version.parse("5.0.0.dev0") # parse_response added in v5 + and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors + and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + ): + completions = self.processing_class.parse_response(completion_ids) + # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here + for completion in completions: + completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) + completions = [[completion] for completion in completions] # format as list of messages + else: + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in contents] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + # Tool execution loop: check for tool calls and execute them, then regenerate completions with tool results # appended to the prompt - if version.parse(transformers.__version__) >= version.parse("5.0.0.dev0"): - completions = self.processing_class.parse_response(completion_ids) + if self.tools: + # Check for tool calls + tool_calls = [completion[0].get("tool_calls") for completion in completions] + idxs_with_tool = [i for i, t in enumerate(tool_calls) if t] # find indices that actually have a tool call + tool_calls = [tool_calls[i] for i in idxs_with_tool] else: - contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - completions = [{"role": "assistant", "content": content} for content in contents] - - # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here - for completion in completions: - completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) - completions = [[completion] for completion in completions] # format as list of messages - - # Check for tool calls - tool_calls = [completion[-1].get("tool_calls") for completion in completions] - idxs_with_tool = [i for i, t in enumerate(tool_calls) if t] # find indices that actually have a tool call - tool_calls = [tool_calls[i] for i in idxs_with_tool] + idxs_with_tool = [] while idxs_with_tool: prompts_for_generation = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls From 1e32b0afa4dbfcccb064b8f1f7012ec7a617f19b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Sat, 15 Nov 2025 06:53:49 +0000 Subject: [PATCH 100/153] don't truncate prompt --- trl/trainer/grpo_trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0fd1861eec8..0da6c92e700 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1223,7 +1223,7 @@ def _generate_single_turn(self, prompts: list): "top_k": -1 if self.top_k is None else self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - "truncate_prompt_tokens": self.max_prompt_length, + # "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding_regex": self.guided_decoding_regex, "generation_kwargs": self.args.generation_kwargs, } @@ -1309,7 +1309,7 @@ def _generate_single_turn(self, prompts: list): "top_k": -1 if self.top_k is None else self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - "truncate_prompt_tokens": self.max_prompt_length, + # "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding": guided_decoding, "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only } @@ -1366,7 +1366,7 @@ def _generate_single_turn(self, prompts: list): elif self.use_transformers_paged: processor_kwargs = { - "max_length": self.max_prompt_length, + # "max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False, } @@ -1414,7 +1414,7 @@ def _generate_single_turn(self, prompts: list): "return_tensors": "pt", "padding": True, "padding_side": "left", - "max_length": self.max_prompt_length, + # "max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False, } From e816ef49677d43f24285c3e72b02578967ae5e4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 17 Nov 2025 08:01:04 +0000 Subject: [PATCH 101/153] remove max_length --- docs/source/grpo_trainer.md | 10 -- docs/source/lora_without_regret.md | 2 - docs/source/paper_index.md | 1 - docs/source/rapidfire_integration.md | 1 - docs/source/rloo_trainer.md | 9 -- examples/scripts/grpo_vlm.py | 2 - examples/scripts/gspo.py | 1 - examples/scripts/gspo_vlm.py | 1 - tests/test_grpo_trainer.py | 4 - trl/trainer/grpo_config.py | 8 -- trl/trainer/grpo_trainer.py | 173 +++++++++++++++++---------- 11 files changed, 113 insertions(+), 99 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index bdc132e4115..6e6baeff8de 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -553,7 +553,6 @@ accelerate launch \ --learning_rate 1e-5 \ --gradient_checkpointing \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ --vllm_mode colocate \ @@ -564,15 +563,6 @@ accelerate launch \ ### Configuration Tips -> [!TIP] -> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_prompt_length=None` in the [`GRPOConfig`]. This allows the model to process the full sequence length without truncating image tokens. -> -> ```python -> GRPOConfig(max_prompt_length=None, ...) -> ``` -> -> Only use `max_prompt_length` when you've verified that truncation won't remove image tokens for the entire dataset. - - Use LoRA on vision-language projection layers - Enable 4-bit quantization to reduce memory usage - VLMs are memory-intensive — start with smaller batch sizes diff --git a/docs/source/lora_without_regret.md b/docs/source/lora_without_regret.md index f56e4ef9200..0b400f85e8b 100644 --- a/docs/source/lora_without_regret.md +++ b/docs/source/lora_without_regret.md @@ -291,7 +291,6 @@ hf jobs uv run \ --warmup_ratio 0.0 \ --max_grad_norm 1.0 \ --beta 0.0 \ - --max_prompt_length 1024 \ --max_completion_length 4096 \ --num_generations 16 \ --generation_batch_size 16 \ @@ -326,7 +325,6 @@ uv run "https://huggingface.co/datasets/burtenshaw/lora-without-regrets/resolve/ --warmup_ratio 0.0 \ --max_grad_norm 1.0 \ --beta 0.0 \ - --max_prompt_length 1024 \ --max_completion_length 4096 \ --num_generations 16 \ --generation_batch_size 16 \ diff --git a/docs/source/paper_index.md b/docs/source/paper_index.md index bdc41263013..671e7e174a6 100644 --- a/docs/source/paper_index.md +++ b/docs/source/paper_index.md @@ -111,7 +111,6 @@ training_args = GRPOConfig( loss_type="dr_grpo", per_device_train_batch_size=1, # train_batch_size_per_device in the Training section of the repository num_generations=8, # num_samples in the Training section of the repository - max_prompt_length=1024, # prompt_max_length in the Training section of the repository max_completion_length=3000, # generate_max_length in the Training section of the repository beta=0.0, # beta in the Training section of the repository ) diff --git a/docs/source/rapidfire_integration.md b/docs/source/rapidfire_integration.md index 6305495826a..7530c01847f 100644 --- a/docs/source/rapidfire_integration.md +++ b/docs/source/rapidfire_integration.md @@ -226,7 +226,6 @@ from rapidfireai.automl import RFGRPOConfig training_args = RFGRPOConfig( learning_rate=5e-6, num_generations=8, - max_prompt_length=256, max_completion_length=256, # ... all other GRPOConfig parameters supported ) diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index 1b8089337a9..05ad4a7be09 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -546,15 +546,6 @@ accelerate launch \ ### Configuration Tips -> [!TIP] -> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_prompt_length=None` in the [`RLOOConfig`]. This allows the model to process the full sequence length without truncating image tokens. -> -> ```python -> RLOOConfig(max_prompt_length=None, ...) -> ``` -> -> Only use `max_prompt_length` when you've verified that truncation won't remove image tokens for the entire dataset. - - Use LoRA on vision-language projection layers - Enable 4-bit quantization to reduce memory usage - VLMs are memory-intensive — start with smaller batch sizes diff --git a/examples/scripts/grpo_vlm.py b/examples/scripts/grpo_vlm.py index 62ddb975d5c..626945082ac 100644 --- a/examples/scripts/grpo_vlm.py +++ b/examples/scripts/grpo_vlm.py @@ -37,7 +37,6 @@ --learning_rate 1e-5 \ --gradient_checkpointing \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ --vllm_mode colocate \ @@ -55,7 +54,6 @@ --output_dir grpo-SmolVLM2-2.2B-Instruct \ --learning_rate 1e-5 \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ --lora_target_modules "q_proj", "v_proj" \ diff --git a/examples/scripts/gspo.py b/examples/scripts/gspo.py index 3c587fdae6e..8523037b2c4 100644 --- a/examples/scripts/gspo.py +++ b/examples/scripts/gspo.py @@ -36,7 +36,6 @@ --output_dir gspo-Qwen3-0.6B \ --learning_rate 1e-5 \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ --lora_target_modules "q_proj", "v_proj" \ diff --git a/examples/scripts/gspo_vlm.py b/examples/scripts/gspo_vlm.py index cff9e241ca5..8836c9facb3 100644 --- a/examples/scripts/gspo_vlm.py +++ b/examples/scripts/gspo_vlm.py @@ -36,7 +36,6 @@ --output_dir gspo-Qwen2.5-VL-3B-Instruct \ --learning_rate 1e-5 \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ --lora_target_modules "q_proj", "v_proj" \ diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index b3844a399c1..52abd51f96b 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1325,7 +1325,6 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=None, # disable prompt truncation, because usually, models don't support it report_to="none", ) trainer = GRPOTrainer( @@ -1567,7 +1566,6 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, num_generations=3, max_completion_length=8, - max_prompt_length=18, report_to="none", use_vllm=True, vllm_mode="server", @@ -1609,7 +1607,6 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=None, # disable prompt truncation, because usually, models don't support it report_to="none", ) trainer = GRPOTrainer( @@ -2059,7 +2056,6 @@ def reward_func(prompts, completions, **kwargs): gradient_accumulation_steps=2, # Maintain effective batch size num_generations=2, max_completion_length=8, # Much shorter completions - max_prompt_length=None, # Don't limit prompt length for VLM bf16=True, # Use bfloat16 precision max_steps=1, # Only do 1 training step to save time and memory report_to="none", diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 2d97d67bd8e..ab668d9f7a7 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -50,8 +50,6 @@ class GRPOConfig(TrainingArguments): remove_unused_columns (`bool`, *optional*, defaults to `False`): Whether to only keep the column `"prompt"` in the dataset. If you use a custom reward function that requires any column other than `"prompts"` and `"completions"`, you should keep this to `False`. - max_prompt_length (`int` or `None`, *optional*, defaults to `512`): - Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left. num_generations (`int` or `None`, *optional*, defaults to `8`): Number of generations per prompt to sample. The effective batch size (num_processes * per_device_batch_size * gradient_accumulation_steps) must be evenly divisible by this value. @@ -347,12 +345,6 @@ class GRPOConfig(TrainingArguments): "that requires any column other than 'prompts' and 'completions', you should keep this to `False`." }, ) - max_prompt_length: int | None = field( - default=512, - metadata={ - "help": "Maximum length of the prompt. If the prompt is longer than this value, it will be truncated left." - }, - ) num_generations: int | None = field( default=8, metadata={ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0da6c92e700..c2d5e38a476 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -14,6 +14,7 @@ import copy import inspect +import json import os import textwrap import time @@ -122,6 +123,20 @@ RolloutFunc = Callable[[list[str], "GRPOTrainer"], dict[str, Any]] +def parse_response(processing_class, ids): + outputs = [] + for seq in ids: + try: + parsed = processing_class.parse_response(seq) + # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here + parsed["content"] = parsed["content"].removesuffix(processing_class.eos_token) + except Exception: + content = processing_class.decode(seq, skip_special_tokens=True) + parsed = {"role": "assistant", "content": content} + outputs.append(parsed) + return outputs + + class GRPOTrainer(BaseTrainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -314,7 +329,9 @@ def __init__( # Processing class if processing_class is None: - processing_class = AutoProcessor.from_pretrained(get_config_model_id(model.config), truncation_side="left") + processing_class = AutoProcessor.from_pretrained( + get_config_model_id(model.config), truncation_side="left", padding_side="left" + ) # Handle pad token for processors or tokenizers if isinstance(processing_class, ProcessorMixin): @@ -409,7 +426,6 @@ def __init__( processing_class = add_response_schema(processing_class) # Training arguments - self.max_prompt_length = args.max_prompt_length self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper self.num_generations = args.num_generations # = G in the GRPO paper self.chat_template_kwargs = args.chat_template_kwargs or {} @@ -622,11 +638,6 @@ def cast_outputs_to_original_dtype(module, args, output): # Ensure distributed rendezvous variables are set without colliding across concurrent runs ensure_master_addr_port() - if self.max_prompt_length is not None and self.max_completion_length is not None: - max_model_len = self.max_prompt_length + self.max_completion_length - else: - max_model_len = None - vllm_quantization = None if is_bitsandbytes_available(): for _, module in model.named_modules(): @@ -642,7 +653,6 @@ def cast_outputs_to_original_dtype(module, args, output): max_num_seqs=self.args.per_device_train_batch_size * self.vllm_tensor_parallel_size * self.args.steps_per_generation, - max_model_len=max_model_len, distributed_executor_backend="external_launcher", # Feed identical seed for tp groups to ensure sampling results are the same across workers seed=self.accelerator.process_index // self.vllm_tensor_parallel_size, @@ -1223,7 +1233,6 @@ def _generate_single_turn(self, prompts: list): "top_k": -1 if self.top_k is None else self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - # "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding_regex": self.guided_decoding_regex, "generation_kwargs": self.args.generation_kwargs, } @@ -1309,7 +1318,6 @@ def _generate_single_turn(self, prompts: list): "top_k": -1 if self.top_k is None else self.top_k, "min_p": 0.0 if self.min_p is None else self.min_p, "max_tokens": self.max_completion_length, - # "truncate_prompt_tokens": self.max_prompt_length, "guided_decoding": guided_decoding, "logprobs": 0, # enable returning log probabilities; 0 means for the sampled tokens only } @@ -1332,7 +1340,21 @@ def _generate_single_turn(self, prompts: list): with profiling_context(self, "vLLM.generate"): if is_conversational({"prompt": prompts[0]}): - all_outputs = self.llm.chat(all_prompts, sampling_params=sampling_params, use_tqdm=False) + all_prompts = copy.deepcopy(all_prompts) + for conv in all_prompts: # iterate over each conversation + for msg in conv: # iterate over each message + if "tool_calls" in msg: # check if message has tool calls + for call in msg["tool_calls"]: + args = call["function"]["arguments"] + if isinstance(args, dict): # only convert dict → JSON string + call["function"]["arguments"] = json.dumps(args) + all_outputs = self.llm.chat( + all_prompts, + sampling_params=sampling_params, + use_tqdm=False, + chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, + ) else: all_outputs = self.llm.generate( all_prompts, sampling_params=sampling_params, use_tqdm=False @@ -1365,11 +1387,7 @@ def _generate_single_turn(self, prompts: list): self.llm.sleep(level=2) elif self.use_transformers_paged: - processor_kwargs = { - # "max_length": self.max_prompt_length, - "truncation": True, - "add_special_tokens": False, - } + processor_kwargs = {"truncation": True, "add_special_tokens": False} if is_conversational({"prompt": prompts[0]}): processor_outputs = self.processing_class.apply_chat_template( conversation=prompts, @@ -1414,7 +1432,6 @@ def _generate_single_turn(self, prompts: list): "return_tensors": "pt", "padding": True, "padding_side": "left", - # "max_length": self.max_prompt_length, "truncation": True, "add_special_tokens": False, } @@ -1469,17 +1486,16 @@ def _generate(self, prompts: list): prompts = copy.deepcopy(prompts) prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + completion_mask = [[1] * len(ids) for ids in completion_ids] + # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): if ( version.parse(transformers.__version__) >= version.parse("5.0.0.dev0") # parse_response added in v5 and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors and self.processing_class.response_schema is not None # only works if the tokenizer has a schema ): - completions = self.processing_class.parse_response(completion_ids) - # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here - for completion in completions: - completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) + completions = parse_response(self.processing_class, completion_ids) completions = [[completion] for completion in completions] # format as list of messages else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) @@ -1487,74 +1503,102 @@ def _generate(self, prompts: list): else: completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - # Tool execution loop: check for tool calls and execute them, then regenerate completions with tool results - # appended to the prompt + # Extract tool calls from the completions if self.tools: - # Check for tool calls tool_calls = [completion[0].get("tool_calls") for completion in completions] - idxs_with_tool = [i for i, t in enumerate(tool_calls) if t] # find indices that actually have a tool call - tool_calls = [tool_calls[i] for i in idxs_with_tool] + idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] + tool_calls = [tool_calls[idx] for idx in idxs_with_tool] else: idxs_with_tool = [] + # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt while idxs_with_tool: - prompts_for_generation = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls - for idx, tool_call_list, prompt_for_generation in zip( - idxs_with_tool, tool_calls, prompts_for_generation, strict=True - ): - # Call the tools, and build the new prompt for generation - prompt_for_generation.append(completions[idx][-1]) + prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls + + # Call the tools, and build the new prompt for generation + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + tool_call_list = tool_calls[idx] + prompt_completion_tool = prompt_completion_tools[idx] + prompt_completion_tool.append(completions[idx_with_tool][-1]) for tool_call in tool_call_list: if tool_call["type"] == "function": function = tool_call["function"] try: result = self._tool_dict[function["name"]](**function["arguments"]) except Exception as e: - # store the full traceback as a string in the result + # Store the full traceback as a string in the result result = {"error": str(e), "traceback": traceback.format_exc()} else: result = {"error": f"Unsupported tool call type: {tool_call['type']}"} tool_call["result"] = result tool_message = {"role": "tool", "name": function["name"], "content": str(result)} - prompt_for_generation.append(tool_message) - completions[idx].append(tool_message) + prompt_completion_tool.append(tool_message) + completions[idx_with_tool].append(tool_message) - prompt_completion_tool_ids, post_tool_ids, _, _ = self._generate_single_turn(prompts_for_generation) - - # Truncate post-tool completion so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length - for i in range(len(post_tool_ids)): - excess_length = ( - len(prompt_completion_tool_ids[i]) - + len(post_tool_ids[i]) - - (self.max_prompt_length + self.max_completion_length) - ) - if excess_length > 0: - post_tool_ids[i] = post_tool_ids[i][:-excess_length] + # Generate new completions after tool execution + prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn( + prompt_completion_tools + ) # Qwen3 inserts \n\n tokens only for the latest user message which can cause discrepancies # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the # common prefix between the two. In most cases, this is a no-op. - for idx, pct in zip(idxs_with_tool, prompt_completion_tool_ids, strict=True): - prompt_ids[idx] = [ - tok for tok, _ in takewhile(lambda x: x[0] == x[1], zip(prompt_ids[idx], pct, strict=False)) + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + prompt_ids[idx_with_tool] = [ + tok + for tok, _ in takewhile(lambda x: x[0] == x[1], zip(prompt_ids[idx_with_tool], pct, strict=False)) ] - # Update completion_ids with the new completions after tool execution - for idx, pct, post_tool in zip(idxs_with_tool, prompt_completion_tool_ids, post_tool_ids, strict=True): - completion_ids[idx] = pct[len(prompt_ids[idx]) :] + post_tool + # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_len = len(prompt_ids[idx_with_tool]) + completion_tool_ids = prompt_completion_tool_ids[idx][prompt_len:] + excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + # If exceeding max length, truncate post_tool_ids + post_tool_ids[idx] = post_tool_ids[idx][:-excess_length] + if logprobs is not None: + post_tool_logprobs[idx] = post_tool_logprobs[idx][:-excess_length] + excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length + if excess_length > 0: + # If still exceeding max length, truncate completion_tool_ids as well + prompt_completion_tool_ids[idx] = completion_tool_ids[:-excess_length] + + # Update completion_mask: the tool result should be 0 and the post-tool 1 + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) + prompt_length = len(prompt_ids[idx_with_tool]) + completion_length = len(completion_ids[idx_with_tool]) + post_tool_length = len(post_tool_ids[idx]) + tool_length = prompt_completion_tool_length - prompt_length - completion_length + completion_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] + + # Update completion_ids with the new completions (after tool execution) + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + prompt_length = len(prompt_ids[idx_with_tool]) + pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool + completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] - post_tool_completions = self.processing_class.parse_response(post_tool_ids) - for completion in post_tool_completions: - completion["content"] = completion["content"].removesuffix(self.processing_class.eos_token) + # Decode post-tool completions + post_tool_completions = parse_response(self.processing_class, post_tool_ids) # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): - completions[idxs_with_tool[idx]].append(post_tool_completions[idx]) + idx_with_tool = idxs_with_tool[idx] + completions[idx_with_tool].append(post_tool_completions[idx]) # Check for further tool calls tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] - idxs_with_tool = [idx for idx, tc in zip(idxs_with_tool, tool_calls, strict=True) if tc] - tool_calls = [tc for tc in tool_calls if tc] + idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] + tool_calls = [tool_call for tool_call in tool_calls if tool_call] # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) @@ -1586,7 +1630,15 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) - return prompt_ids, completion_ids, total_completion_tokens, logprobs, extra_fields, completions + return ( + prompt_ids, + completion_ids, + completion_mask, + total_completion_tokens, + logprobs, + extra_fields, + completions, + ) def _generate_and_score_completions( self, inputs: list[dict[str, torch.Tensor | Any]] @@ -1618,6 +1670,7 @@ def _generate_and_score_completions( ( prompt_ids_list, completion_ids_list, + completion_mask_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields, @@ -1630,7 +1683,7 @@ def _generate_and_score_completions( prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] + completion_mask = [torch.tensor(ids, device=device) for ids in completion_mask_list] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") completion_mask = pad(completion_mask, padding_value=0, padding_side="right") if sampling_per_token_logps_list is not None: From 400bee41fc4720105982bfc73ac008a4c2820850 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 17 Nov 2025 17:37:38 +0000 Subject: [PATCH 102/153] move to chat template utils --- chat_template_utils.py | 391 +++++++++++++++++++++++++++++++++++++++++ trl/trainer/utils.py | 196 --------------------- 2 files changed, 391 insertions(+), 196 deletions(-) create mode 100644 chat_template_utils.py diff --git a/chat_template_utils.py b/chat_template_utils.py new file mode 100644 index 00000000000..26ec4329366 --- /dev/null +++ b/chat_template_utils.py @@ -0,0 +1,391 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functools +from typing import TypeVar + +from transformers import AutoTokenizer, PreTrainedTokenizer, ProcessorMixin + + +# These schemas are copy-pasted from https://github.com/huggingface/transformers/blob/main/tests/utils/test_chat_parsing_utils.py +cohere_schema = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"}, + "thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"}, + "tool_calls": { + "x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)", + "x-parser": "json", + "x-parser-args": { + "transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}" + }, + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +ernie_schema = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string", "x-regex": "\n(.*?)\n?"}, + "thinking": {"type": "string", "x-regex": r"(?:^|\s*)(.*?)\s*<\/think>"}, + "tool_calls": { + "x-regex-iterator": "(.*?)", + "type": "array", + "items": { + "type": "object", + "x-parser": "json", + "x-parser-args": {"transform": "{type: 'function', function: @}"}, + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +gpt_oss_schema = { + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"}, + "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"}, + "tool_calls": { + "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)", + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"}, + "arguments": { + "type": "object", + "x-regex": r"<\|message\|>(.*)", + "x-parser": "json", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +smollm_schema = { + "x-regex": r"(?:\n?(?P.+?)\n?)?\s*(?:(?P.+?))?\s*(?P.+?)?\s*(?:<\|im_end\|>|$)", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "thinking": {"type": "string"}, + "tool_calls": { + "x-parser": "json", + "x-parser-args": {"transform": "[{type: 'function', function: @}]"}, + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "arguments": { + "type": "object", + "additionalProperties": {}, + }, + }, + }, + }, + }, + }, + }, +} + +qwen3_schema = { + "x-regex": r"^(?:(?:)?\s*(?P.+?)\s*)?\s*(?:(?P.*?)\s*)?\s*(?P.+?)?\s*$", + "type": "object", + "properties": { + "role": {"const": "assistant"}, + "content": {"type": "string"}, + "thinking": {"type": "string"}, + "tool_calls": { + "x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list + "type": "array", + "items": { + "type": "object", + "properties": { + "type": {"const": "function"}, + "function": { + "type": "object", + "properties": { + "name": {"type": "string", "x-regex": r""}, + "arguments": { + "type": "object", + "x-regex-key-value": r"\w+)>\n(?P.*?)\n", + "additionalProperties": { + "x-parser": "json", + "x-parser-args": {"allow_non_json": True}, + }, + }, + }, + }, + }, + }, + }, + }, +} + + +TokenizerOrProcessor = TypeVar("TokenizerOrProcessor", PreTrainedTokenizer, ProcessorMixin) + + +def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: + """ + Adds the appropriate response schema to the given tokenizer or processor based on its chat template. + + At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While + waiting for broader adoption, we provide this utility function to manually set the response schema for known chat + templates. + + Args: + processor (`TokenizerOrProcessor`): + Tokenizer or processor to which the response schema will be added. + + Returns: + `TokenizerOrProcessor`: + Tokenizer or processor with the added response schema. + """ + qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template + if processor.chat_template == qwen3_chat_template: + # The qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See + # https://github.com/huggingface/transformers/issues/42220 + processor.response_schema = smollm_schema + return processor + raise ValueError( + "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " + "tokenizer or processor." + ) + + +def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: + """ + Check whether the chat template preserves prefixes when applied. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer instance to check. + + Returns: + `bool`: + `True` if the chat template preserves prefixes, `False` otherwise. + """ + messages1 = [ + {"role": "user", "content": "What color is the sky?"}, + ] + messages2 = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + messages3 = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "And at night?"}, + ] + + text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False) + text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + + return text2.startswith(text1) and text3.startswith(text2) + + +qwen3_training_chat_template = r""" +{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} +{%- endif %} +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %} +""" + + +def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: + """ + Wrap `tokenizer.apply_chat_template()` to use a training-compatible chat template if needed. + + During training, we need a *prefix-preserving* template where each message strictly appends to previous ones. For + example: + + ```python + turn0 = {"role": "user", "content": "Hello!"} + turn1 = {"role": "assistant", "content": "Hi!"} + text0 = tokenizer.apply_chat_template([turn0], add_generation_prompt=True) + text1 = tokenizer.apply_chat_template([turn0, turn1]) + assert text1.startswith(text0) + ``` + + Tokenizers typically use inference-ready templates that may differ from the template used in training. The + inference template may not satisfy the prefix-preservation requirement. For example, Qwen3 and OpenAI GPT OSS drop + thinking blocks from non-final turns. + + This function first checks if the template is prefix-preserving. If it is, no patching is needed. If not, it + patches the `apply_chat_template()` method to temporarily swap in a training-compatible template during calls, then + restore the original afterward. This ensures the chat template complies with training needs while preserving the + original template for later inference. + + Currently supported: Qwen3 models only. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer instance to patch. + + Returns: + `PreTrainedTokenizer`: + The same tokenizer with `apply_chat_template()` patched (if needed and supported). + """ + # First check if patching is needed + if is_chat_template_prefix_preserving(tokenizer): + return tokenizer # No patching needed + + original_method = tokenizer.apply_chat_template + original_chat_template = tokenizer.chat_template + + qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template + if tokenizer.chat_template == qwen3_chat_template: + chat_template_for_training = qwen3_training_chat_template + else: + raise ValueError( + "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " + "Please manually modify the tokenizer's chat template for training." + ) + + @functools.wraps(original_method) + def wrapper(self, *args, **kwargs): + tokenizer.chat_template = chat_template_for_training + try: + result = original_method(self, *args, **kwargs) + finally: + tokenizer.chat_template = original_chat_template + return result + + tokenizer.apply_chat_template = wrapper + return tokenizer diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py index 6714dc46070..02d4cc78073 100644 --- a/trl/trainer/utils.py +++ b/trl/trainer/utils.py @@ -38,15 +38,12 @@ from torch.utils.data import Sampler from transformers import ( AutoConfig, - AutoTokenizer, BitsAndBytesConfig, EvalPrediction, GenerationConfig, PretrainedConfig, PreTrainedModel, - PreTrainedTokenizer, PreTrainedTokenizerBase, - ProcessorMixin, TrainerState, TrainingArguments, is_comet_available, @@ -2030,196 +2027,3 @@ def get_config_model_id(config: PretrainedConfig) -> str: """ # Fall back to `config.text_config._name_or_path` if `config._name_or_path` is missing: Qwen2-VL and Qwen2.5-VL. See GH-4323 return getattr(config, "_name_or_path", "") or getattr(getattr(config, "text_config", None), "_name_or_path", "") - - -# These schemas are copy-pasted from https://github.com/huggingface/transformers/blob/main/tests/utils/test_chat_parsing_utils.py -cohere_schema = { - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"}, - "thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"}, - "tool_calls": { - "x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)", - "x-parser": "json", - "x-parser-args": { - "transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}" - }, - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "arguments": { - "type": "object", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -ernie_schema = { - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string", "x-regex": "\n(.*?)\n?"}, - "thinking": {"type": "string", "x-regex": r"(?:^|\s*)(.*?)\s*<\/think>"}, - "tool_calls": { - "x-regex-iterator": "(.*?)", - "type": "array", - "items": { - "type": "object", - "x-parser": "json", - "x-parser-args": {"transform": "{type: 'function', function: @}"}, - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "arguments": { - "type": "object", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -gpt_oss_schema = { - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"}, - "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"}, - "tool_calls": { - "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)", - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"}, - "arguments": { - "type": "object", - "x-regex": r"<\|message\|>(.*)", - "x-parser": "json", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -smollm_schema = { - "x-regex": r"(?:\n?(?P.+?)\n?)?\s*(?:(?P.+?))?\s*(?P.+?)?\s*(?:<\|im_end\|>|$)", - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string"}, - "thinking": {"type": "string"}, - "tool_calls": { - "x-parser": "json", - "x-parser-args": {"transform": "[{type: 'function', function: @}]"}, - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "arguments": { - "type": "object", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -qwen3_schema = { - "x-regex": r"^(?:(?:)?\s*(?P.+?)\s*)?\s*(?:(?P.*?)\s*)?\s*(?P.+?)?\s*$", - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string"}, - "thinking": {"type": "string"}, - "tool_calls": { - "x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string", "x-regex": r""}, - "arguments": { - "type": "object", - "x-regex-key-value": r"\w+)>\n(?P.*?)\n", - "additionalProperties": { - "x-parser": "json", - "x-parser-args": {"allow_non_json": True}, - }, - }, - }, - }, - }, - }, - }, - }, -} - - -TokenizerOrProcessor = TypeVar("TokenizerOrProcessor", PreTrainedTokenizer, ProcessorMixin) - - -def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: - """ - Adds the appropriate response schema to the given tokenizer or processor based on its chat template. - - At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While - waiting for broader adoption, we provide this utility function to manually set the response schema for known chat - templates. - - Args: - processor (`TokenizerOrProcessor`): - Tokenizer or processor to which the response schema will be added. - - Returns: - `TokenizerOrProcessor`: - Tokenizer or processor with the added response schema. - """ - qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template - if processor.chat_template == qwen3_chat_template: - # The qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See - # https://github.com/huggingface/transformers/issues/42220 - processor.response_schema = smollm_schema - return processor - raise ValueError( - "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " - "tokenizer or processor." - ) From b86483c0d2eec4340be6b191c117b9d16e5a7c64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 17 Nov 2025 19:23:19 +0000 Subject: [PATCH 103/153] tool mask --- .../chat_template_utils.py | 39 +++++++- trl/extras/vllm_client.py | 6 ++ trl/trainer/grpo_trainer.py | 90 ++++++++++--------- 3 files changed, 89 insertions(+), 46 deletions(-) rename chat_template_utils.py => trl/chat_template_utils.py (90%) diff --git a/chat_template_utils.py b/trl/chat_template_utils.py similarity index 90% rename from chat_template_utils.py rename to trl/chat_template_utils.py index 26ec4329366..4a429b6b3a3 100644 --- a/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -371,7 +371,7 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template if tokenizer.chat_template == qwen3_chat_template: - chat_template_for_training = qwen3_training_chat_template + tokenizer._training_chat_template = qwen3_training_chat_template else: raise ValueError( "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " @@ -380,7 +380,7 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain @functools.wraps(original_method) def wrapper(self, *args, **kwargs): - tokenizer.chat_template = chat_template_for_training + tokenizer.chat_template = tokenizer._training_chat_template try: result = original_method(self, *args, **kwargs) finally: @@ -389,3 +389,38 @@ def wrapper(self, *args, **kwargs): tokenizer.apply_chat_template = wrapper return tokenizer + + +def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list[dict]: + """ + Parse token sequences into structured response dictionaries with fallback handling. + + Attempts to parse each sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool + calls like `{"type":"function"`), falls back to decoding as plain text. + + Also removes incorrectly appended EOS tokens from tool call content when present. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer with a `parse_response()` method. + ids (`list[list[int]]`): + List of token sequences. + + Returns: + `list[dict]`: + List of response dictionaries. + """ + + outputs = [] + for seq in ids: + try: + parsed = tokenizer.parse_response(seq) + # Hotfix: remove incorrectly appended EOS token from tool calls + # See https://github.com/huggingface/transformers/issues/42249 + parsed["content"] = parsed["content"].removesuffix(tokenizer.eos_token) + except Exception: + # Fallback: decode as plain text if parsing fails. This happens if the model outputs malformed tool calls. + content = tokenizer.decode(seq, skip_special_tokens=True) + parsed = {"role": "assistant", "content": content} + outputs.append(parsed) + return outputs diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index 5ddc6150a59..b38bbf16fee 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -281,6 +281,7 @@ def chat( guided_decoding_regex: str | None = None, generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, + tools: list | None = None, ) -> dict[str, list[list[int]]]: """ Generates model completions for the provided chat messages. @@ -315,6 +316,8 @@ def chat( will override them. chat_template_kwargs (`dict`, *optional*): Additional keyword arguments to customize the chat template used by the model. + tools (`list`, *optional*): + List of tool functions available for tool calling during chat generation. Returns: `dict` with keys: @@ -325,6 +328,9 @@ def chat( - `logprobs` (`list[list[float]]`): List of lists of log probabilities for each generated token. """ + if tools is not None: + raise NotImplementedError("Tool calling is not yet implemented in VLLMClient.chat().") + url = f"{self.base_url}/chat/" # Convert PIL images to base64 strings diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c2d5e38a476..d93098c4910 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -24,7 +24,6 @@ from collections.abc import Callable from contextlib import nullcontext from functools import partial -from itertools import takewhile from pathlib import Path from typing import Any @@ -57,6 +56,7 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available, is_rich_available +from ..chat_template_utils import add_response_schema, parse_response, patch_chat_template_for_training from ..data_utils import ( apply_chat_template, is_conversational, @@ -73,7 +73,6 @@ from .grpo_config import GRPOConfig from .utils import ( RepeatSampler, - add_response_schema, disable_dropout_in_model, ensure_master_addr_port, entropy_from_logits, @@ -123,20 +122,6 @@ RolloutFunc = Callable[[list[str], "GRPOTrainer"], dict[str, Any]] -def parse_response(processing_class, ids): - outputs = [] - for seq in ids: - try: - parsed = processing_class.parse_response(seq) - # Hotfix: when there is a tool call, the content wrongly includes the EOS token, so we remove it here - parsed["content"] = parsed["content"].removesuffix(processing_class.eos_token) - except Exception: - content = processing_class.decode(seq, skip_special_tokens=True) - parsed = {"role": "assistant", "content": content} - outputs.append(parsed) - return outputs - - class GRPOTrainer(BaseTrainer): """ Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the @@ -424,6 +409,7 @@ def __init__( # known chat templates. if tools and not processing_class.response_schema: processing_class = add_response_schema(processing_class) + processing_class = patch_chat_template_for_training(processing_class) # Training arguments self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper @@ -663,6 +649,7 @@ def cast_outputs_to_original_dtype(module, args, output): # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", quantization=vllm_quantization, + enforce_eager=True, ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) @@ -1199,6 +1186,7 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): def _generate_single_turn(self, prompts: list): device = self.accelerator.device + # all_prompts = copy.deepcopy(all_prompts) # to avoid modifying the input list # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1215,6 +1203,16 @@ def _generate_single_turn(self, prompts: list): if is_conversational({"prompt": prompts[0]}): prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] + # In vLLM, tool call arguments must be JSON strings. + # See https://github.com/vllm-project/vllm/pull/28820 + for prompt in prompts: # iterate over each conversation + for message in prompt: # iterate over each message + if "tool_calls" in message: # check if message has tool calls + for call in message["tool_calls"]: + args = call["function"]["arguments"] + if isinstance(args, dict): # only convert dict → JSON string + call["function"]["arguments"] = json.dumps(args) + # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": all_prompts = gather_object(prompts) @@ -1253,6 +1251,7 @@ def _generate_single_turn(self, prompts: list): messages=ordered_set_of_prompts, **sampling_params, chat_template_kwargs=self.chat_template_kwargs, + tools=self.tools, ) else: output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) @@ -1340,20 +1339,13 @@ def _generate_single_turn(self, prompts: list): with profiling_context(self, "vLLM.generate"): if is_conversational({"prompt": prompts[0]}): - all_prompts = copy.deepcopy(all_prompts) - for conv in all_prompts: # iterate over each conversation - for msg in conv: # iterate over each message - if "tool_calls" in msg: # check if message has tool calls - for call in msg["tool_calls"]: - args = call["function"]["arguments"] - if isinstance(args, dict): # only convert dict → JSON string - call["function"]["arguments"] = json.dumps(args) all_outputs = self.llm.chat( all_prompts, sampling_params=sampling_params, use_tqdm=False, chat_template_kwargs=self.chat_template_kwargs, tools=self.tools, + chat_template=self.processing_class._training_chat_template, ) else: all_outputs = self.llm.generate( @@ -1486,7 +1478,6 @@ def _generate(self, prompts: list): prompts = copy.deepcopy(prompts) prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) - completion_mask = [[1] * len(ids) for ids in completion_ids] # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): @@ -1508,8 +1499,10 @@ def _generate(self, prompts: list): tool_calls = [completion[0].get("tool_calls") for completion in completions] idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] + tool_mask = [[0] * len(ids) for ids in completion_ids] else: idxs_with_tool = [] + tool_mask = None # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt while idxs_with_tool: @@ -1547,10 +1540,12 @@ def _generate(self, prompts: list): for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool - prompt_ids[idx_with_tool] = [ - tok - for tok, _ in takewhile(lambda x: x[0] == x[1], zip(prompt_ids[idx_with_tool], pct, strict=False)) - ] + # prompt_ids[idx_with_tool] = [ + # tok + # for tok, _ in takewhile(lambda x: x[0] == x[1], zip(prompt_ids[idx_with_tool], pct, strict=False)) + # ] + # sanity check + assert prompt_ids[idx_with_tool] == pct[: len(prompt_ids[idx_with_tool])] # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length for idx in range(len(idxs_with_tool)): @@ -1568,7 +1563,7 @@ def _generate(self, prompts: list): # If still exceeding max length, truncate completion_tool_ids as well prompt_completion_tool_ids[idx] = completion_tool_ids[:-excess_length] - # Update completion_mask: the tool result should be 0 and the post-tool 1 + # Update tool_mask: the tool result should be 1 and the post-tool 0 for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) @@ -1576,7 +1571,7 @@ def _generate(self, prompts: list): completion_length = len(completion_ids[idx_with_tool]) post_tool_length = len(post_tool_ids[idx]) tool_length = prompt_completion_tool_length - prompt_length - completion_length - completion_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length + tool_mask[idx_with_tool] += [1] * tool_length + [0] * post_tool_length if logprobs is not None: logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] @@ -1633,11 +1628,11 @@ def _generate(self, prompts: list): return ( prompt_ids, completion_ids, - completion_mask, + tool_mask, + completions, total_completion_tokens, logprobs, extra_fields, - completions, ) def _generate_and_score_completions( @@ -1670,11 +1665,11 @@ def _generate_and_score_completions( ( prompt_ids_list, completion_ids_list, - completion_mask_list, + tool_mask_list, + completions, num_items_in_batch, sampling_per_token_logps_list, extra_fields, - completions, ) = self._generate(prompts) # Convert lists of token IDs to padded tensors @@ -1683,7 +1678,7 @@ def _generate_and_score_completions( prompt_ids = pad(prompt_ids, padding_value=self.pad_token_id, padding_side="left") prompt_mask = pad(prompt_mask, padding_value=0, padding_side="left") completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids_list] - completion_mask = [torch.tensor(ids, device=device) for ids in completion_mask_list] + completion_mask = [torch.ones_like(ids, dtype=torch.long) for ids in completion_ids] completion_ids = pad(completion_ids, padding_value=self.pad_token_id, padding_side="right") completion_mask = pad(completion_mask, padding_value=0, padding_side="right") if sampling_per_token_logps_list is not None: @@ -1691,6 +1686,9 @@ def _generate_and_score_completions( sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") else: sampling_per_token_logps = None + if self.tools: + tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] + tool_mask = pad(tool_mask, padding_value=0, padding_side="right") # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -1859,7 +1857,8 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - delta = delta[completion_mask.bool()] + mask = completion_mask.bool() if self.tools is None else (completion_mask * (1 - tool_mask)).bool() + delta = delta[mask] mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( @@ -1869,7 +1868,7 @@ def _generate_and_score_completions( self.accelerator.gather(max_delta).max().item() ) - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + flat_is_ratio = importance_sampling_ratio[mask] min_importance_sampling_ratio = ( torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) ) @@ -1915,6 +1914,8 @@ def _generate_and_score_completions( output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: output["num_images"] = num_images + if self.tools is not None: + output["tool_mask"] = tool_mask return output def compute_liger_loss(self, unwrapped_model, inputs): @@ -2059,31 +2060,32 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl + mask = completion_mask if self.tools is None else completion_mask * (1 - inputs["tool_mask"]) if self.loss_type == "grpo": - loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean() + loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "bnpo": - loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0) + loss = (per_token_loss * mask).sum() / mask.sum().clamp(min=1.0) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type == "dr_grpo": - loss = (per_token_loss * completion_mask).sum() / (per_token_loss.size(0) * self.max_completion_length) + loss = (per_token_loss * mask).sum() / (per_token_loss.size(0) * self.max_completion_length) loss = loss / self.current_gradient_accumulation_steps elif self.loss_type in ["cispo", "dapo"]: normalizer = inputs["num_items_in_batch"] / self.accelerator.num_processes - loss = (per_token_loss * completion_mask).sum() / normalizer + loss = (per_token_loss * mask).sum() / normalizer else: raise ValueError(f"Unknown loss type: {self.loss_type}") # Log the metrics mode = "train" if self.model.training else "eval" - completion_token_count = completion_mask.sum().clamp(min=1.0) + completion_token_count = mask.sum().clamp(min=1.0) def masked_batch_mean(x): if x.shape[1] == 1: # when importance_sampling_level == "sequence" return x.mean() else: - return (x * completion_mask).sum() / completion_token_count + return (x * mask).sum() / completion_token_count if self.beta != 0.0: mean_kl = masked_batch_mean(per_token_kl) From 93c79992c3f1a6295977b8cc62eff43f3e42f643 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Mon, 17 Nov 2025 19:33:43 +0000 Subject: [PATCH 104/153] hard coded chat template --- trl/chat_template_utils.py | 100 ++++++++++++++++++++++++++++++++++--- 1 file changed, 94 insertions(+), 6 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 4a429b6b3a3..74ee6f9509a 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -243,8 +243,7 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: return text2.startswith(text1) and text3.startswith(text2) -qwen3_training_chat_template = r""" -{%- if tools %} +qwen3_chat_template = r"""{%- if tools %} {{- '<|im_start|>system\n' }} {%- if messages[0].role == 'system' %} {{- messages[0].content + '\n\n' }} @@ -286,7 +285,15 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- set content = content.split('')[-1].lstrip('\n') %} {%- endif %} {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if (loop.first and content) or (not loop.first) %} @@ -324,8 +331,89 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- if enable_thinking is defined and enable_thinking is false %} {{- '\n\n\n\n' }} {%- endif %} +{%- endif %}""" + +qwen3_training_chat_template = r"""{%- if tools %} + {{- '<|im_start|>system\n' }} + {%- if messages[0].role == 'system' %} + {{- messages[0].content + '\n\n' }} + {%- endif %} + {{- "# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within XML tags:\n" }} + {%- for tool in tools %} + {{- "\n" }} + {{- tool | tojson }} + {%- endfor %} + {{- "\n\n\nFor each function call, return a json object with function name and arguments within XML tags:\n\n{\"name\": , \"arguments\": }\n<|im_end|>\n" }} +{%- else %} + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} {%- endif %} -""" +{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %} +{%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if ns.multi_step_tool and message.role == "user" and message.content is string and not(message.content.startswith('') and message.content.endswith('')) %} + {%- set ns.multi_step_tool = false %} + {%- set ns.last_query_index = index %} + {%- endif %} +{%- endfor %} +{%- for message in messages %} + {%- if message.content is string %} + {%- set content = message.content %} + {%- else %} + {%- set content = '' %} + {%- endif %} + {%- if (message.role == "user") or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- if message.tool_calls %} + {%- for tool_call in message.tool_calls %} + {%- if (loop.first and content) or (not loop.first) %} + {{- '\n' }} + {%- endif %} + {%- if tool_call.function %} + {%- set tool_call = tool_call.function %} + {%- endif %} + {{- '\n{"name": "' }} + {{- tool_call.name }} + {{- '", "arguments": ' }} + {%- if tool_call.arguments is string %} + {{- tool_call.arguments }} + {%- else %} + {{- tool_call.arguments | tojson }} + {%- endif %} + {{- '}\n' }} + {%- endfor %} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- elif message.role == "tool" %} + {%- if loop.first or (messages[loop.index0 - 1].role != "tool") %} + {{- '<|im_start|>user' }} + {%- endif %} + {{- '\n\n' }} + {{- content }} + {{- '\n' }} + {%- if loop.last or (messages[loop.index0 + 1].role != "tool") %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} +{%- endif %}""" def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: @@ -350,7 +438,8 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain This function first checks if the template is prefix-preserving. If it is, no patching is needed. If not, it patches the `apply_chat_template()` method to temporarily swap in a training-compatible template during calls, then restore the original afterward. This ensures the chat template complies with training needs while preserving the - original template for later inference. + original template for later inference. It also stores the training template in a `_training_chat_template` + attribute, which is useful when you need to access it—for example, when using vLLM inference. Currently supported: Qwen3 models only. @@ -369,7 +458,6 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain original_method = tokenizer.apply_chat_template original_chat_template = tokenizer.chat_template - qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template if tokenizer.chat_template == qwen3_chat_template: tokenizer._training_chat_template = qwen3_training_chat_template else: From 24ea4a455d827ce3a9b25d5cbb601f277955df60 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 18 Nov 2025 06:13:52 +0000 Subject: [PATCH 105/153] almost done!! --- trl/chat_template_utils.py | 156 ++++++++++++++++++------------------ trl/trainer/grpo_trainer.py | 62 +++++++++++--- 2 files changed, 127 insertions(+), 91 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 74ee6f9509a..d31b91994da 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -15,7 +15,7 @@ import functools from typing import TypeVar -from transformers import AutoTokenizer, PreTrainedTokenizer, ProcessorMixin +from transformers import PreTrainedTokenizer, ProcessorMixin # These schemas are copy-pasted from https://github.com/huggingface/transformers/blob/main/tests/utils/test_chat_parsing_utils.py @@ -145,15 +145,17 @@ }, } + qwen3_schema = { - "x-regex": r"^(?:(?:)?\s*(?P.+?)\s*)?\s*(?:(?P.*?)\s*)?\s*(?P.+?)?\s*$", + "x-regex": r"^(?:\n?(?P.+?)\n?\s*)?(?P.*?)(?=(?:|<\|im_end\|>|$))(?:(?P.+?))?\s*(?:<\|im_end\|>|$)", "type": "object", "properties": { "role": {"const": "assistant"}, "content": {"type": "string"}, "thinking": {"type": "string"}, "tool_calls": { - "x-regex-iterator": r"^(.*)$", # We have already extracted tool calls and there can only be one, so just make it a list + "x-parser": "json", + "x-parser-args": {"transform": "[{type: 'function', function: @}]"}, "type": "array", "items": { "type": "object", @@ -162,14 +164,10 @@ "function": { "type": "object", "properties": { - "name": {"type": "string", "x-regex": r""}, + "name": {"type": "string"}, "arguments": { "type": "object", - "x-regex-key-value": r"\w+)>\n(?P.*?)\n", - "additionalProperties": { - "x-parser": "json", - "x-parser-args": {"allow_non_json": True}, - }, + "additionalProperties": {}, }, }, }, @@ -179,70 +177,7 @@ }, } - -TokenizerOrProcessor = TypeVar("TokenizerOrProcessor", PreTrainedTokenizer, ProcessorMixin) - - -def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: - """ - Adds the appropriate response schema to the given tokenizer or processor based on its chat template. - - At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While - waiting for broader adoption, we provide this utility function to manually set the response schema for known chat - templates. - - Args: - processor (`TokenizerOrProcessor`): - Tokenizer or processor to which the response schema will be added. - - Returns: - `TokenizerOrProcessor`: - Tokenizer or processor with the added response schema. - """ - qwen3_chat_template = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B").chat_template - if processor.chat_template == qwen3_chat_template: - # The qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See - # https://github.com/huggingface/transformers/issues/42220 - processor.response_schema = smollm_schema - return processor - raise ValueError( - "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " - "tokenizer or processor." - ) - - -def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: - """ - Check whether the chat template preserves prefixes when applied. - - Args: - tokenizer (`PreTrainedTokenizer`): - Tokenizer instance to check. - - Returns: - `bool`: - `True` if the chat template preserves prefixes, `False` otherwise. - """ - messages1 = [ - {"role": "user", "content": "What color is the sky?"}, - ] - messages2 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - ] - messages3 = [ - {"role": "user", "content": "What color is the sky?"}, - {"role": "assistant", "content": "It is blue."}, - {"role": "user", "content": "And at night?"}, - ] - - text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) - text2 = tokenizer.apply_chat_template(messages2, tokenize=False) - text3 = tokenizer.apply_chat_template(messages3, tokenize=False) - - return text2.startswith(text1) and text3.startswith(text2) - - +# docstyle-ignore qwen3_chat_template = r"""{%- if tools %} {{- '<|im_start|>system\n' }} {%- if messages[0].role == 'system' %} @@ -333,6 +268,69 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %} {%- endif %}""" +TokenizerOrProcessor = TypeVar("TokenizerOrProcessor", PreTrainedTokenizer, ProcessorMixin) + + +def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: + """ + Adds the appropriate response schema to the given tokenizer or processor based on its chat template. + + At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While + waiting for broader adoption, we provide this utility function to manually set the response schema for known chat + templates. + + Args: + processor (`TokenizerOrProcessor`): + Tokenizer or processor to which the response schema will be added. + + Returns: + `TokenizerOrProcessor`: + Tokenizer or processor with the added response schema. + """ + if processor.chat_template == qwen3_chat_template: + # The qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See + # https://github.com/huggingface/transformers/issues/42220 + processor.response_schema = qwen3_schema + return processor + raise ValueError( + "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " + "tokenizer or processor." + ) + + +def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: + """ + Check whether the chat template preserves prefixes when applied. + + Args: + tokenizer (`PreTrainedTokenizer`): + Tokenizer instance to check. + + Returns: + `bool`: + `True` if the chat template preserves prefixes, `False` otherwise. + """ + messages1 = [ + {"role": "user", "content": "What color is the sky?"}, + ] + messages2 = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + ] + messages3 = [ + {"role": "user", "content": "What color is the sky?"}, + {"role": "assistant", "content": "It is blue."}, + {"role": "user", "content": "And at night?"}, + ] + + text1 = tokenizer.apply_chat_template(messages1, tokenize=False, add_generation_prompt=True) + text2 = tokenizer.apply_chat_template(messages2, tokenize=False) + text3 = tokenizer.apply_chat_template(messages3, tokenize=False) + + return text2.startswith(text1) and text3.startswith(text2) + + +# docstyle-ignore qwen3_training_chat_template = r"""{%- if tools %} {{- '<|im_start|>system\n' }} {%- if messages[0].role == 'system' %} @@ -420,7 +418,7 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain """ Wrap `tokenizer.apply_chat_template()` to use a training-compatible chat template if needed. - During training, we need a *prefix-preserving* template where each message strictly appends to previous ones. For + During training, we need a *prefix-preserving* template where each message strictly appends to previous ones. For example: ```python @@ -431,8 +429,8 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain assert text1.startswith(text0) ``` - Tokenizers typically use inference-ready templates that may differ from the template used in training. The - inference template may not satisfy the prefix-preservation requirement. For example, Qwen3 and OpenAI GPT OSS drop + Tokenizers typically use inference-ready templates that may differ from the template used in training. The + inference template may not satisfy the prefix-preservation requirement. For example, Qwen3 and OpenAI GPT OSS drop thinking blocks from non-final turns. This function first checks if the template is prefix-preserving. If it is, no patching is needed. If not, it @@ -467,10 +465,10 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain ) @functools.wraps(original_method) - def wrapper(self, *args, **kwargs): + def wrapper(*args, **kwargs): tokenizer.chat_template = tokenizer._training_chat_template try: - result = original_method(self, *args, **kwargs) + result = original_method(*args, **kwargs) finally: tokenizer.chat_template = original_chat_template return result @@ -483,7 +481,7 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list """ Parse token sequences into structured response dictionaries with fallback handling. - Attempts to parse each sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool + Attempts to parse each sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool calls like `{"type":"function"`), falls back to decoding as plain text. Also removes incorrectly appended EOS tokens from tool call content when present. diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d93098c4910..f130405f10c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1186,7 +1186,6 @@ def _calculate_rewards(self, inputs, prompts, completions, completion_ids_list): def _generate_single_turn(self, prompts: list): device = self.accelerator.device - # all_prompts = copy.deepcopy(all_prompts) # to avoid modifying the input list # Generate completions using either vLLM or regular generation if self.use_vllm: @@ -1345,7 +1344,7 @@ def _generate_single_turn(self, prompts: list): use_tqdm=False, chat_template_kwargs=self.chat_template_kwargs, tools=self.tools, - chat_template=self.processing_class._training_chat_template, + chat_template=self.processing_class._training_chat_template if self.tools else None, ) else: all_outputs = self.llm.generate( @@ -1500,6 +1499,8 @@ def _generate(self, prompts: list): idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] tool_mask = [[0] * len(ids) for ids in completion_ids] + tool_call_count = 0 + tool_failure_count = 0 else: idxs_with_tool = [] tool_mask = None @@ -1508,6 +1509,10 @@ def _generate(self, prompts: list): while idxs_with_tool: prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls + # Tokenize the current prompt. We will use this to filter out overlong samples later. + kwargs = dict(tools=self.tools, add_generation_prompt=True, tokenize=True, **self.chat_template_kwargs) + p_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"] + # Call the tools, and build the new prompt for generation for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] @@ -1515,6 +1520,7 @@ def _generate(self, prompts: list): prompt_completion_tool = prompt_completion_tools[idx] prompt_completion_tool.append(completions[idx_with_tool][-1]) for tool_call in tool_call_list: + tool_call_count += 1 if tool_call["type"] == "function": function = tool_call["function"] try: @@ -1522,6 +1528,8 @@ def _generate(self, prompts: list): except Exception as e: # Store the full traceback as a string in the result result = {"error": str(e), "traceback": traceback.format_exc()} + # keep track of how many times each tool failed + tool_failure_count += 1 else: result = {"error": f"Unsupported tool call type: {tool_call['type']}"} tool_call["result"] = result @@ -1529,6 +1537,21 @@ def _generate(self, prompts: list): prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) + # Tokenize and filter samples whose length exceeds max allowed length. This is important, because if vLLM + # is called with an input longer than its max model length, it will error out. + pct_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"] + overlong = [len(pct) - len(p) >= self.max_completion_length for p, pct in zip(p_ids, pct_ids, strict=True)] + if logprobs is not None: + for idx in range(len(idxs_with_tool)): + if overlong[idx]: + num_tokens = len(pct_ids[idx]) - len(p_ids[idx]) + logprobs[idxs_with_tool[idx]] += [0.0] * num_tokens + tool_mask[idxs_with_tool[idx]] += [0] * num_tokens + idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] + prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] + if not idxs_with_tool: + break # all overlong, exit tool loop + # Generate new completions after tool execution prompt_completion_tool_ids, post_tool_ids, post_tool_logprobs, _ = self._generate_single_turn( prompt_completion_tools @@ -1540,11 +1563,6 @@ def _generate(self, prompts: list): for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool - # prompt_ids[idx_with_tool] = [ - # tok - # for tok, _ in takewhile(lambda x: x[0] == x[1], zip(prompt_ids[idx_with_tool], pct, strict=False)) - # ] - # sanity check assert prompt_ids[idx_with_tool] == pct[: len(prompt_ids[idx_with_tool])] # Truncate so that pct[len(prompt_ids[idx]) :] + post_tool does not exceed max_completion_length @@ -1561,7 +1579,7 @@ def _generate(self, prompts: list): excess_length = len(completion_tool_ids) + len(post_tool_ids[idx]) - self.max_completion_length if excess_length > 0: # If still exceeding max length, truncate completion_tool_ids as well - prompt_completion_tool_ids[idx] = completion_tool_ids[:-excess_length] + prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] # Update tool_mask: the tool result should be 1 and the post-tool 0 for idx in range(len(idxs_with_tool)): @@ -1588,13 +1606,23 @@ def _generate(self, prompts: list): # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] - completions[idx_with_tool].append(post_tool_completions[idx]) + if ( + post_tool_completions[idx]["content"] or "tool_calls" in post_tool_completions[idx] + ): # when the post-tool if completly truncated, content is empty + completions[idx_with_tool].append(post_tool_completions[idx]) # Check for further tool calls tool_calls = [completion.get("tool_calls") for completion in post_tool_completions] idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] + if [len(ids) for ids in completion_ids] != [len(p) for p in logprobs]: + raise ValueError( + "Length mismatch between completion_ids and logprobs after tool execution. " + f"completion_ids lengths: {[len(ids) for ids in completion_ids]}, " + f"logprobs lengths: {[len(p) for p in logprobs]}" + ) + # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) @@ -1625,6 +1653,16 @@ def _generate(self, prompts: list): self._metrics[mode]["completions/min_terminated_length"].append(term_completion_lengths.float().min().item()) self._metrics[mode]["completions/max_terminated_length"].append(term_completion_lengths.float().max().item()) + if self.tools: + agg_tool_call_count = self.accelerator.gather(torch.tensor(tool_call_count, device=device)).sum() + tool_call_frequency = (agg_tool_call_count / len(agg_prompt_lengths)).item() + self._metrics[mode]["tools/call_frequency"].append(tool_call_frequency) + agg_tool_failure_count = self.accelerator.gather(torch.tensor(tool_failure_count, device=device)).sum() + failure_frequency = ( + (agg_tool_failure_count / agg_tool_call_count).item() if agg_tool_call_count > 0 else 0.0 + ) + self._metrics[mode]["tools/failure_frequency"].append(failure_frequency) + return ( prompt_ids, completion_ids, @@ -1857,7 +1895,7 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - mask = completion_mask.bool() if self.tools is None else (completion_mask * (1 - tool_mask)).bool() + mask = completion_mask.bool() if not self.tools else (completion_mask * (1 - tool_mask)).bool() delta = delta[mask] mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) @@ -1914,7 +1952,7 @@ def _generate_and_score_completions( output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: output["num_images"] = num_images - if self.tools is not None: + if self.tools: output["tool_mask"] = tool_mask return output @@ -2060,7 +2098,7 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl - mask = completion_mask if self.tools is None else completion_mask * (1 - inputs["tool_mask"]) + mask = completion_mask if not self.tools else completion_mask * (1 - inputs["tool_mask"]) if self.loss_type == "grpo": loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() loss = loss / self.current_gradient_accumulation_steps From 9dfc51177fa7f0c66133c42c4b783a9219ba0383 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 18 Nov 2025 19:20:42 +0000 Subject: [PATCH 106/153] fix chat template --- trl/chat_template_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index d31b91994da..0baccce9d18 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -368,12 +368,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- if message.reasoning_content is string %} {%- set reasoning_content = message.reasoning_content %} {%- else %} - {%- if '' in content %} + {%- if '' in content and '' in content %} # Modify this to always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} {%- set content = content.split('')[-1].lstrip('\n') %} {%- endif %} {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} # Always include thinking block during training {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if (loop.first and content) or (not loop.first) %} From 2542320fddd58bc46158fb1d6fe55dc67141adc7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 18 Nov 2025 19:21:08 +0000 Subject: [PATCH 107/153] just report error (not the traceback --- trl/trainer/grpo_trainer.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index ffad13805fc..73947dd0d5c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1521,7 +1521,7 @@ def _generate(self, prompts: list): result = self._tool_dict[function["name"]](**function["arguments"]) except Exception as e: # Store the full traceback as a string in the result - result = {"error": str(e), "traceback": traceback.format_exc()} + result = {"error": str(e)} # keep track of how many times each tool failed tool_failure_count += 1 else: @@ -1551,9 +1551,7 @@ def _generate(self, prompts: list): prompt_completion_tools ) - # Qwen3 inserts \n\n tokens only for the latest user message which can cause discrepancies - # between the prompt alone and the combined prompt+completion. To ensure consistency, we extract the - # common prefix between the two. In most cases, this is a no-op. + # Sanity check: from experience, this is useful to catch bugs in the chat template for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] pct = prompt_completion_tool_ids[idx] # = prompt-completion-tool From 1db53c1373e78b2ce48282a75c786174058d78d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 18 Nov 2025 19:21:29 +0000 Subject: [PATCH 108/153] style --- trl/trainer/grpo_trainer.py | 1 - 1 file changed, 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 73947dd0d5c..0b8b0173f9a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -18,7 +18,6 @@ import os import textwrap import time -import traceback import warnings from collections import defaultdict, deque from collections.abc import Callable From f31996a1f4d363d5fab946ff69c610d1743125ce Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 18 Nov 2025 21:50:46 +0000 Subject: [PATCH 109/153] deprecate max_length + chat utils doc --- docs/source/_toctree.yml | 16 ++++++++++------ docs/source/chat_template_utils.md | 17 +++++++++++++++++ trl/chat_template_utils.py | 12 +++++++++--- trl/trainer/grpo_config.py | 24 +++++++++++++++++++++--- trl/trainer/grpo_trainer.py | 26 ++++++++++++-------------- 5 files changed, 69 insertions(+), 26 deletions(-) create mode 100644 docs/source/chat_template_utils.md diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index a4ca28675bc..1c3b47e51f7 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -75,18 +75,22 @@ - local: sft_trainer title: SFT title: Trainers + - sections: + - local: chat_template_utils + title: Chat Template Utilities + - local: data_utils + title: Data Utilities + - local: model_utils + title: Model Utilities + - local: script_utils + title: Script Utilities + title: Utilities - local: models title: Model Classes - - local: model_utils - title: Model Utilities - local: callbacks title: Callbacks - - local: data_utils - title: Data Utilities - local: rewards title: Reward Functions - - local: script_utils - title: Script Utilities - local: others title: Others title: API diff --git a/docs/source/chat_template_utils.md b/docs/source/chat_template_utils.md new file mode 100644 index 00000000000..2900adf3781 --- /dev/null +++ b/docs/source/chat_template_utils.md @@ -0,0 +1,17 @@ +# Chat template utilities + +## add_response_schema + +[[autodoc]] chat_template_utils.add_response_schema + +## is_chat_template_prefix_preserving + +[[autodoc]] chat_template_utils.is_chat_template_prefix_preserving + +## patch_chat_template_for_training + +[[autodoc]] chat_template_utils.patch_chat_template_for_training + +## parse_response + +[[autodoc]] chat_template_utils.parse_response diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 0baccce9d18..877002721a2 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -329,7 +329,13 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: return text2.startswith(text1) and text3.startswith(text2) - +# Modifications: +# - {%- if '' in content %} +# + {%- if '' in content and '' in content %} +# Always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly +# - {%- if loop.index0 > ns.last_query_index %} ... {%- endif %} +# + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} +# Always include thinking block during training. It's important to have a prefix-preserving template. # docstyle-ignore qwen3_training_chat_template = r"""{%- if tools %} {{- '<|im_start|>system\n' }} @@ -368,12 +374,12 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- if message.reasoning_content is string %} {%- set reasoning_content = message.reasoning_content %} {%- else %} - {%- if '' in content and '' in content %} # Modify this to always check for both tags to avoid edge cases where the model generates only one tag, which would otherwise be parsed incorrectly + {%- if '' in content and '' in content %} {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} {%- set content = content.split('')[-1].lstrip('\n') %} {%- endif %} {%- endif %} - {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} # Always include thinking block during training + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} {%- if message.tool_calls %} {%- for tool_call in message.tool_calls %} {%- if (loop.first and content) or (not loop.first) %} diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index ab668d9f7a7..38a67fa5660 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -262,6 +262,15 @@ class GRPOConfig(TrainingArguments): > Deprecated arguments + max_prompt_length (`bool`, *optional*): + + + + Parameter `max_prompt_length` is deprecated and will be removed in version 0.28.0. You should instead + filter your dataset before training to ensure that prompts do not exceed your desired length. + + + wandb_log_unique_prompts (`bool`, *optional*): @@ -330,8 +339,9 @@ class GRPOConfig(TrainingArguments): default=False, metadata={ "help": "Whether to cast the language modeling head of the policy and reference, models to float32." - "As recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only supported when the model" - " has untied word embedding and language modeling head layers i.e. `tie_word_embeddings` in the model config is False." + "As recommended by the [ScaleRL](https://huggingface.co/papers/2510.13786) recipe. This flag is only " + "supported when the model has untied word embedding and language modeling head layers i.e. " + "`tie_word_embeddings` in the model config is False." }, ) @@ -687,11 +697,19 @@ class GRPOConfig(TrainingArguments): log_unique_prompts: bool = field( default=False, metadata={ - "help": "Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all prompts are logged." + "help": "Whether to log unique prompts. If `True`, only unique prompts are logged. If `False`, all " + "prompts are logged." }, ) # Deprecated arguments + max_prompt_length: int | None = field( + default=512, + metadata={ + "help": "Deprecated, filter your dataset before training to ensure that prompts do not exceed your " + "desired length." + }, + ) wandb_log_unique_prompts: bool | None = field( default=None, metadata={"help": "Deprecated, use `log_unique_prompts` instead."}, diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0b8b0173f9a..2dbf33cdc9a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -642,7 +642,7 @@ def cast_outputs_to_original_dtype(module, args, output): # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", quantization=vllm_quantization, - enforce_eager=True, + enforce_eager=True ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) @@ -1195,15 +1195,15 @@ def _generate_single_turn(self, prompts: list): if is_conversational({"prompt": prompts[0]}): prompts = [prepare_multimodal_messages_vllm(prompt) for prompt in prompts] - # In vLLM, tool call arguments must be JSON strings. - # See https://github.com/vllm-project/vllm/pull/28820 + # In vLLM, tool call arguments must be JSON strings. See https://github.com/vllm-project/vllm/pull/28820 for prompt in prompts: # iterate over each conversation - for message in prompt: # iterate over each message - if "tool_calls" in message: # check if message has tool calls - for call in message["tool_calls"]: - args = call["function"]["arguments"] - if isinstance(args, dict): # only convert dict → JSON string - call["function"]["arguments"] = json.dumps(args) + if is_conversational({"prompt": prompt}): + for message in prompt: # iterate over each message + if "tool_calls" in message: # check if message has tool calls + for call in message["tool_calls"]: + args = call["function"]["arguments"] + if isinstance(args, dict): # only convert dict → JSON string + call["function"]["arguments"] = json.dumps(args) # Generate completions using vLLM: gather all prompts and use them in a single call in the main process if self.vllm_mode == "server": @@ -1380,6 +1380,7 @@ def _generate_single_turn(self, prompts: list): tokenize=True, return_dict=True, **self.chat_template_kwargs, + tools=self.tools, ) else: processor_outputs = self.processing_class(text=prompts, **processor_kwargs) @@ -1519,9 +1520,7 @@ def _generate(self, prompts: list): try: result = self._tool_dict[function["name"]](**function["arguments"]) except Exception as e: - # Store the full traceback as a string in the result result = {"error": str(e)} - # keep track of how many times each tool failed tool_failure_count += 1 else: result = {"error": f"Unsupported tool call type: {tool_call['type']}"} @@ -1597,9 +1596,8 @@ def _generate(self, prompts: list): # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] - if ( - post_tool_completions[idx]["content"] or "tool_calls" in post_tool_completions[idx] - ): # when the post-tool if completly truncated, content is empty + # When the post-tool if completly truncated, content is empty. + if post_tool_completions[idx]["content"] or "tool_calls" in post_tool_completions[idx]: completions[idx_with_tool].append(post_tool_completions[idx]) # Check for further tool calls From 6f2524d54b95023b11819e7bf06c177df06ae93c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 18 Nov 2025 23:26:22 +0000 Subject: [PATCH 110/153] test chat template utils --- tests/test_chat_template_utils.py | 114 ++++++++++++++++++++++++++++++ trl/chat_template_utils.py | 4 +- 2 files changed, 116 insertions(+), 2 deletions(-) create mode 100644 tests/test_chat_template_utils.py diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py new file mode 100644 index 00000000000..17fc8b02406 --- /dev/null +++ b/tests/test_chat_template_utils.py @@ -0,0 +1,114 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import textwrap + +from transformers import AutoTokenizer + +from trl.chat_template_utils import ( + add_response_schema, + is_chat_template_prefix_preserving, + patch_chat_template_for_training, +) + + +class TestAddResponseSchema: + def test_add_response_schema(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = add_response_schema(tokenizer) + assistant_text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + parsed = tokenizer.parse_response(assistant_text) + expected = { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], + } + assert parsed == expected + + +class TestIsChatTemplatePrefixPreserving: + def test_prefix_preserving_template(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer.chat_template = textwrap.dedent(r""" + {%- for message in messages %} + + {%- if message.role == 'user' %} + {{- '<|im_start|>user\n' + message.content + '<|im_end|>\n' }} + {%- elif message.role == 'assistant' %} + {{- '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }} + {%- endif %} + + {%- endfor %} + + {%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- endif %}""") + assert is_chat_template_prefix_preserving(tokenizer) is True + + def test_non_prefix_preserving_template(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is + # only present for last assistant message, which makes it non-prefix-preserving. + tokenizer.chat_template = textwrap.dedent(r""" + {%- if messages[0].role == 'system' %} + {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} + {%- endif %} + {%- set ns = namespace(last_query_index=messages|length - 1) %} + {%- for message in messages[::-1] %} + {%- set index = (messages|length - 1) - loop.index0 %} + {%- if message.role == "user" and message.content is string %} + {%- set ns.last_query_index = index %} + {%- break %} + {%- endif %} + {%- endfor %} + {%- for message in messages %} + {%- set content = message.content if message.content is string else '' %} + {%- if message.role == "user" or (message.role == "system" and not loop.first) %} + {{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>\n' }} + {%- elif message.role == "assistant" %} + {%- set reasoning_content = '' %} + {%- if message.reasoning_content is string %} + {%- set reasoning_content = message.reasoning_content %} + {%- else %} + {%- if '' in content %} + {%- set reasoning_content = content.split('')[0].rstrip('\n').split('')[-1].lstrip('\n') %} + {%- set content = content.split('')[-1].lstrip('\n') %} + {%- endif %} + {%- endif %} + {%- if loop.index0 > ns.last_query_index %} + {%- if loop.last or (not loop.last and reasoning_content) %} + {{- '<|im_start|>' + message.role + '\n\n' + reasoning_content.strip('\n') + '\n\n\n' + content.lstrip('\n') }} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {%- else %} + {{- '<|im_start|>' + message.role + '\n' + content }} + {%- endif %} + {{- '<|im_end|>\n' }} + {%- endif %} + {%- endfor %} + {%- if add_generation_prompt %} + {{- '<|im_start|>assistant\n' }} + {%- if enable_thinking is defined and enable_thinking is false %} + {{- '\n\n\n\n' }} + {%- endif %} + {%- endif %}""") + assert is_chat_template_prefix_preserving(tokenizer) is False + + +class TestPatchChatTemplateForTraining: + def test_patch_qwen3(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = patch_chat_template_for_training(tokenizer) + assert is_chat_template_prefix_preserving(tokenizer) is True diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 877002721a2..960727c5fcb 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -145,7 +145,6 @@ }, } - qwen3_schema = { "x-regex": r"^(?:\n?(?P.+?)\n?\s*)?(?P.*?)(?=(?:|<\|im_end\|>|$))(?:(?P.+?))?\s*(?:<\|im_end\|>|$)", "type": "object", @@ -288,7 +287,7 @@ def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor Tokenizer or processor with the added response schema. """ if processor.chat_template == qwen3_chat_template: - # The qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See + # The Qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See # https://github.com/huggingface/transformers/issues/42220 processor.response_schema = qwen3_schema return processor @@ -329,6 +328,7 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: return text2.startswith(text1) and text3.startswith(text2) + # Modifications: # - {%- if '' in content %} # + {%- if '' in content and '' in content %} From eb9eca9f0ddd693a2a0c50b40f1aae9609d31be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 01:20:38 +0000 Subject: [PATCH 111/153] test --- tests/test_chat_template_utils.py | 1 + tests/test_grpo_trainer.py | 75 +++++++++++++++++++++++++++++++ trl/trainer/grpo_trainer.py | 3 +- 3 files changed, 77 insertions(+), 2 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 17fc8b02406..5c252d17aaf 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -60,6 +60,7 @@ def test_non_prefix_preserving_template(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") # The following template is quite typical of models like Qwen3 and GPT-OSS, where the thinking part is # only present for last assistant message, which makes it non-prefix-preserving. + # docstyle-ignore tokenizer.chat_template = textwrap.dedent(r""" {%- if messages[0].role == 'system' %} {{- '<|im_start|>system\n' + messages[0].content + '<|im_end|>\n' }} diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 52abd51f96b..c14ee13551f 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1690,6 +1690,81 @@ def test_training_with_chat_template_kwargs(self): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_training_with_tools(self): + def multiply(a: int, b: int) -> int: + """ + Multiplies two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The product of the two integers. + """ + return a * b + + dataset = load_dataset("trl-internal-testing/zen", "conversational_prompt_only", split="train") + + training_args = GRPOConfig( + output_dir=self.tmp_dir, + learning_rate=0.1, + per_device_train_batch_size=3, + num_generations=3, + max_completion_length=64, + report_to="none", + ) + trainer = GRPOTrainer( + model="trl-internal-testing/tiny-Qwen3MoeForCausalLM", + reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5", + args=training_args, + train_dataset=dataset, + tools=[multiply], + ) + + previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} + + def fake_generate(input_ids, **kwargs): + if input_ids.shape[0] == 3: # first call + # fmt: off + completion_ids = torch.tensor( + [ + # '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 65, 788, 220, 19, 11248, 151658, 151645], + # '\n{"name": "multiply", "arguments": {"a": 3, "c": 4}}\n<|im_end|>' + [151657, 198, 4913, 606, 788, 330, 64648, 497, 330, 16370, 788, 5212, 64, 788, 220, 18, 11, 330, 66, 788, 220, 19, 11248, 151658, 151645], + # "I don't know any tool<|im_end|>" + [40, 1513, 944, 1414, 894, 5392, 151645, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643, 151643], + ], + device=input_ids.device, + ) + # fmt: on + else: # second call will only have two inputs in the batch, because two examples haave a tool call. + completion_ids = torch.tensor( + [ + # 'Done!<|im_end|>' + [17453, 0, 151645], + # 'Done!<|im_end|>' + [17453, 0, 151645], + ], + device=input_ids.device, + ) + return torch.cat([input_ids, completion_ids], dim=-1) + + with patch.object(trainer.model, "generate", side_effect=fake_generate): + trainer.train() + + assert trainer.state.log_history[-1]["train_loss"] is not None + assert trainer.state.log_history[-1]["tools/call_frequency"] is not None + assert trainer.state.log_history[-1]["tools/call_frequency"] == pytest.approx(2 / 3) + assert trainer.state.log_history[-1]["tools/failure_frequency"] is not None + assert trainer.state.log_history[-1]["tools/failure_frequency"] == pytest.approx(1 / 2) + + # Check that the params have changed + for n, param in previous_trainable_params.items(): + new_param = trainer.model.get_parameter(n) + assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + def test_mismatched_reward_processing_classes_length(self): """Test that mismatched length between reward_funcs and reward_processing_classes raises error.""" dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 2dbf33cdc9a..0b7242db15c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -642,7 +642,6 @@ def cast_outputs_to_original_dtype(module, args, output): # Important so temperature scaling/logit tweaking affects the TIS log probs logprobs_mode="processed_logprobs", quantization=vllm_quantization, - enforce_eager=True ) if self.args.vllm_enable_sleep_mode: self.llm.sleep(level=2) @@ -1605,7 +1604,7 @@ def _generate(self, prompts: list): idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] - if [len(ids) for ids in completion_ids] != [len(p) for p in logprobs]: + if logprobs and [len(ids) for ids in completion_ids] != [len(p) for p in logprobs]: raise ValueError( "Length mismatch between completion_ids and logprobs after tool execution. " f"completion_ids lengths: {[len(ids) for ids in completion_ids]}, " From 19fa924e224dc2e5d8b31e85cf403ccdd2682212 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 16:19:03 +0000 Subject: [PATCH 112/153] remove max_prompt_length --- docs/source/rloo_trainer.md | 1 - examples/notebooks/grpo_qwen3_vl.ipynb | 1 - examples/scripts/rloo.py | 1 - examples/scripts/rloo_vlm.py | 2 -- tests/test_grpo_trainer.py | 1 - tests/test_rloo_trainer.py | 3 --- trl/experimental/openenv/utils.py | 1 - 7 files changed, 10 deletions(-) diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index c0d62e7cb17..c0298e97350 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -528,7 +528,6 @@ accelerate launch \ --learning_rate 1e-5 \ --gradient_checkpointing \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ --vllm_mode colocate \ diff --git a/examples/notebooks/grpo_qwen3_vl.ipynb b/examples/notebooks/grpo_qwen3_vl.ipynb index f46f61b5771..4940976ea6b 100644 --- a/examples/notebooks/grpo_qwen3_vl.ipynb +++ b/examples/notebooks/grpo_qwen3_vl.ipynb @@ -406,7 +406,6 @@ " per_device_train_batch_size=2,\n", " max_completion_length=1024, # default: 256 # Max completion length produced during training\n", " num_generations=2, # 2, # default: 8 # Number of generations produced during trainig for comparison\n", - " max_prompt_length=2048, # default: 512 # Max prompt lenght of the input prompt used for generation during training\n", "\n", " fp16=True,\n", "\n", diff --git a/examples/scripts/rloo.py b/examples/scripts/rloo.py index abeabb45b60..faa90df30be 100644 --- a/examples/scripts/rloo.py +++ b/examples/scripts/rloo.py @@ -73,7 +73,6 @@ def make_conversation(example): gradient_checkpointing_kwargs=dict(use_reentrant=False), log_completions=True, num_completions_to_print=2, - max_prompt_length=2048, max_completion_length=1024, gradient_accumulation_steps=2, steps_per_generation=8, diff --git a/examples/scripts/rloo_vlm.py b/examples/scripts/rloo_vlm.py index a98674db15c..48fc63250b6 100644 --- a/examples/scripts/rloo_vlm.py +++ b/examples/scripts/rloo_vlm.py @@ -37,7 +37,6 @@ --learning_rate 1e-5 \ --gradient_checkpointing \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ --vllm_mode colocate \ @@ -55,7 +54,6 @@ --output_dir rloo-SmolVLM2-2.2B-Instruct \ --learning_rate 1e-5 \ --dtype bfloat16 \ - --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ --lora_target_modules "q_proj", "v_proj" \ diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index c14ee13551f..1745ef7abb1 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -2211,7 +2211,6 @@ def test_vlm_processor_vllm_colocate_mode(self): gradient_accumulation_steps=2, # Make effective batch size 2, divisible by num_generations num_generations=2, max_completion_length=4, # Very short completions to reduce memory - max_prompt_length=32, # Very short prompts to reduce memory use_vllm=True, # Enable vLLM vllm_mode="colocate", # Use colocate mode to avoid server dependency vllm_gpu_memory_utilization=0.05, # Use minimal GPU memory (5%) diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index 82810b1cca5..dff1958e3ad 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1100,7 +1100,6 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=None, # disable prompt truncation, because usually, models don't support it report_to="none", ) trainer = RLOOTrainer( @@ -1247,7 +1246,6 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, num_generations=3, max_completion_length=8, - max_prompt_length=18, report_to="none", use_vllm=True, vllm_mode="server", @@ -1289,7 +1287,6 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage - max_prompt_length=None, # disable prompt truncation, because usually, models don't support it report_to="none", ) trainer = RLOOTrainer( diff --git a/trl/experimental/openenv/utils.py b/trl/experimental/openenv/utils.py index ca26fdc796c..87854cb538d 100644 --- a/trl/experimental/openenv/utils.py +++ b/trl/experimental/openenv/utils.py @@ -43,7 +43,6 @@ def _build_colocate_sampling_params( "top_k": -1 if trainer.top_k is None else trainer.top_k, "min_p": 0.0 if trainer.min_p is None else trainer.min_p, "max_tokens": trainer.max_completion_length, - "truncate_prompt_tokens": trainer.max_prompt_length, "guided_decoding": guided_decoding, } if trainer.repetition_penalty is not None: From 278703ecf0ebab6d81e2595bf08121414537f00d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 16:41:01 +0000 Subject: [PATCH 113/153] better doc --- trl/chat_template_utils.py | 66 +++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 23 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 960727c5fcb..7e709da18d0 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -285,6 +285,18 @@ def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor Returns: `TokenizerOrProcessor`: Tokenizer or processor with the added response schema. + + Examples: + + ```python + >>> from trl.chat_template_utils import add_response_schema + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + >>> tokenizer = add_response_schema(tokenizer) + >>> assistant_text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + >>> tokenizer.parse_response(assistant_text) + {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} + ``` """ if processor.chat_template == qwen3_chat_template: # The Qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See @@ -422,30 +434,11 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: """ - Wrap `tokenizer.apply_chat_template()` to use a training-compatible chat template if needed. - - During training, we need a *prefix-preserving* template where each message strictly appends to previous ones. For - example: - - ```python - turn0 = {"role": "user", "content": "Hello!"} - turn1 = {"role": "assistant", "content": "Hi!"} - text0 = tokenizer.apply_chat_template([turn0], add_generation_prompt=True) - text1 = tokenizer.apply_chat_template([turn0, turn1]) - assert text1.startswith(text0) - ``` + Ensure a tokenizer uses a prefix-preserving chat template during training. - Tokenizers typically use inference-ready templates that may differ from the template used in training. The - inference template may not satisfy the prefix-preservation requirement. For example, Qwen3 and OpenAI GPT OSS drop - thinking blocks from non-final turns. - - This function first checks if the template is prefix-preserving. If it is, no patching is needed. If not, it - patches the `apply_chat_template()` method to temporarily swap in a training-compatible template during calls, then - restore the original afterward. This ensures the chat template complies with training needs while preserving the - original template for later inference. It also stores the training template in a `_training_chat_template` - attribute, which is useful when you need to access it—for example, when using vLLM inference. - - Currently supported: Qwen3 models only. + If the tokenizer's template isn't prefix-preserving, temporarily swap in a training-compatible template (currently + only Qwen3 supported) for the duration of `apply_chat_template()` calls. The training template is saved as + `tokenizer._training_chat_template`. Args: tokenizer (`PreTrainedTokenizer`): @@ -454,6 +447,33 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain Returns: `PreTrainedTokenizer`: The same tokenizer with `apply_chat_template()` patched (if needed and supported). + + Example: + + ```python + >>> from trl.chat_template_utils import patch_chat_template_for_training + >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + >>> messages1 = [ + ... {"role": "user", "content": "What color is the sky?"}, + ... {"role": "assistant", "content": "It is blue."}, + ... ] + >>> messages2 = [ + ... {"role": "user", "content": "What color is the sky?"}, + ... {"role": "assistant", "content": "It is blue."}, + ... {"role": "user", "content": "And at night?"}, + ... ] + >>> tokenizer.apply_chat_template(messages1, tokenize=False) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + >>> # ^ think tags missing + >>> tokenizer = patch_chat_template_for_training(tokenizer) + >>> tokenizer.apply_chat_template(messages1, tokenize=False) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False) + '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' + ``` """ # First check if patching is needed if is_chat_template_prefix_preserving(tokenizer): From 6828ba2b65a02f39abb5b3af6d2eea294e382907 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 20:06:13 +0000 Subject: [PATCH 114/153] doc example and skip version below dev --- tests/test_grpo_trainer.py | 7 ++++++- trl/chat_template_utils.py | 9 +++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index 1745ef7abb1..ec1f7cae730 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1690,6 +1690,11 @@ def test_training_with_chat_template_kwargs(self): new_param = trainer.model.get_parameter(n) assert not torch.equal(param, new_param), f"Parameter {n} has not changed." + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0.dev0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0.dev0", + strict=True, + ) def test_training_with_tools(self): def multiply(a: int, b: int) -> int: """ @@ -1711,7 +1716,7 @@ def multiply(a: int, b: int) -> int: learning_rate=0.1, per_device_train_batch_size=3, num_generations=3, - max_completion_length=64, + max_completion_length=128, report_to="none", ) trainer = GRPOTrainer( diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 7e709da18d0..b077af9554b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -285,12 +285,13 @@ def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor Returns: `TokenizerOrProcessor`: Tokenizer or processor with the added response schema. - + Examples: ```python >>> from trl.chat_template_utils import add_response_schema >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> tokenizer = add_response_schema(tokenizer) >>> assistant_text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' @@ -453,6 +454,7 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain ```python >>> from trl.chat_template_utils import patch_chat_template_for_training >>> from transformers import AutoTokenizer + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> messages1 = [ ... {"role": "user", "content": "What color is the sky?"}, @@ -465,12 +467,15 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain ... ] >>> tokenizer.apply_chat_template(messages1, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' - >>> # ^ think tags missing + + >>> # ^ think tags missing >>> tokenizer = patch_chat_template_for_training(tokenizer) >>> tokenizer.apply_chat_template(messages1, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' + >>> tokenizer.apply_chat_template(messages2, tokenize=False) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' ``` From ae653d811ab02352fae4deadb1047d26e6774556 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 22:15:33 +0000 Subject: [PATCH 115/153] fix overlong case --- trl/trainer/grpo_trainer.py | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0b7242db15c..6df0646440c 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1532,12 +1532,15 @@ def _generate(self, prompts: list): # is called with an input longer than its max model length, it will error out. pct_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"] overlong = [len(pct) - len(p) >= self.max_completion_length for p, pct in zip(p_ids, pct_ids, strict=True)] - if logprobs is not None: - for idx in range(len(idxs_with_tool)): - if overlong[idx]: - num_tokens = len(pct_ids[idx]) - len(p_ids[idx]) - logprobs[idxs_with_tool[idx]] += [0.0] * num_tokens - tool_mask[idxs_with_tool[idx]] += [0] * num_tokens + for idx in range(len(idxs_with_tool)): + idx_with_tool = idxs_with_tool[idx] + if overlong[idx]: + prompt_length = len(prompt_ids[idx_with_tool]) + ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length] + completion_ids[idx_with_tool] = ct + tool_mask[idx_with_tool] += [0] * (len(ct) - len(tool_mask[idx_with_tool])) + if logprobs is not None: + logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] if not idxs_with_tool: @@ -1604,13 +1607,6 @@ def _generate(self, prompts: list): idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] - if logprobs and [len(ids) for ids in completion_ids] != [len(p) for p in logprobs]: - raise ValueError( - "Length mismatch between completion_ids and logprobs after tool execution. " - f"completion_ids lengths: {[len(ids) for ids in completion_ids]}, " - f"logprobs lengths: {[len(p) for p in logprobs]}" - ) - # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) From 96387b32a3ea251339f5731b795a1b8ba92c6e00 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 22:36:13 +0000 Subject: [PATCH 116/153] test parse --- tests/test_chat_template_utils.py | 64 +++++++++++++++++++++++++++++++ 1 file changed, 64 insertions(+) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 5c252d17aaf..05825e7f1f6 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -14,16 +14,25 @@ import textwrap +import pytest +import transformers +from packaging.version import Version from transformers import AutoTokenizer from trl.chat_template_utils import ( add_response_schema, is_chat_template_prefix_preserving, + parse_response, patch_chat_template_for_training, ) class TestAddResponseSchema: + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0.dev0"), + reason="Response parsing is not supported in transformers versions below 5.0.0.dev0", + strict=True, + ) def test_add_response_schema(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = add_response_schema(tokenizer) @@ -113,3 +122,58 @@ def test_patch_qwen3(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = patch_chat_template_for_training(tokenizer) assert is_chat_template_prefix_preserving(tokenizer) is True + + +class TestParseResponse: + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0.dev0"), + reason="Response parsing is not supported in transformers versions below 5.0.0.dev0", + strict=True, + ) + def test_parse_response(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = add_response_schema(tokenizer) + text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + assistant_text = tokenizer([text])["input_ids"] + parsed = parse_response(tokenizer, assistant_text) + expected = [ + { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], + } + ] + assert parsed == expected + + @pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0.dev0"), + reason="Response parsing is not supported in transformers versions below 5.0.0.dev0", + strict=True, + ) + def test_parse_response_no_tool_call(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = add_response_schema(tokenizer) + text = "Here is the answer to your question.<|im_end|>" + assistant_text = tokenizer([text])["input_ids"] + parsed = parse_response(tokenizer, assistant_text) + expected = [ + { + "role": "assistant", + "content": "Here is the answer to your question.", + } + ] + assert parsed == expected + + def test_parse_response_malformed_tool_call(self): + tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") + tokenizer = add_response_schema(tokenizer) + text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n<|im_end|>' + assistant_text = tokenizer([text])["input_ids"] + parsed = parse_response(tokenizer, assistant_text) + expected = [ + { + "role": "assistant", + "content": '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n', + } + ] + assert parsed == expected From 714b9ea48fbb511009ef45fe2d6f077fb8f06031 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 22:36:23 +0000 Subject: [PATCH 117/153] example in the doc --- trl/chat_template_utils.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index b077af9554b..6afca626ec9 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -526,6 +526,19 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list Returns: `list[dict]`: List of response dictionaries. + + Example: + ```python + >>> from trl.chat_template_utils import parse_response, add_response_schema + >>> from transformers import AutoTokenizer + + >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") + >>> tokenizer = add_response_schema(tokenizer) # temporary until built-in support + >>> text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' + >>> sequences = tokenizer([text])["input_ids"] + >>> parse_response(tokenizer, sequences) + [{'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}] + ``` """ outputs = [] From 3a1c7fb9a58ba33efb0c1289f2a5a5885e085e6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 22:40:17 +0000 Subject: [PATCH 118/153] comment in test --- tests/test_grpo_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index ec1f7cae730..eaf417d6457 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1696,6 +1696,9 @@ def test_training_with_chat_template_kwargs(self): strict=True, ) def test_training_with_tools(self): + # In this test, we define a simple tool that multiplies two integers. Regardless of the input prompt, + # the model will generate 3 completions, 2 of which will be valid tool calls. Among the 2 tool calls, one will + # succeed and the other will fail (because of a wrong argument name). def multiply(a: int, b: int) -> int: """ Multiplies two integers. From a1ebcba9c2c4e85f60b2403c739b76c3d56b5690 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 22:44:12 +0000 Subject: [PATCH 119/153] version.parse -> Version --- trl/trainer/grpo_trainer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 6df0646440c..d5def05df17 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -34,7 +34,7 @@ from accelerate import logging from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed from datasets import Dataset, IterableDataset -from packaging import version +from packaging.version import Version from torch import nn from torch.distributed.fsdp import FullyShardedDataParallel as FSDP from torch.utils.data import DataLoader, Sampler @@ -390,7 +390,7 @@ def __init__( self.rollout_func = rollout_func # Tools - if tools and not version.parse(transformers.__version__) >= version.parse("5.0.0.dev0"): + if tools and not Version(transformers.__version__) >= Version("5.0.0.dev0"): raise ImportError( "Using tools with GRPOTrainer requires transformers version 5.0.0.dev0 or higher. Please upgrade " "transformers to use this feature." @@ -1474,7 +1474,7 @@ def _generate(self, prompts: list): # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. if is_conversational({"prompt": prompts[0]}): if ( - version.parse(transformers.__version__) >= version.parse("5.0.0.dev0") # parse_response added in v5 + Version(transformers.__version__) >= Version("5.0.0.dev0") # parse_response added in v5 and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors and self.processing_class.response_schema is not None # only works if the tokenizer has a schema ): From c340f525a6e535da4ad40d69a69c18e467426599 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 22:54:50 +0000 Subject: [PATCH 120/153] comment chat template for vllm --- trl/extras/vllm_client.py | 6 ++++++ trl/trainer/grpo_trainer.py | 11 ++++++++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/trl/extras/vllm_client.py b/trl/extras/vllm_client.py index b38bbf16fee..e21df6d837e 100644 --- a/trl/extras/vllm_client.py +++ b/trl/extras/vllm_client.py @@ -282,6 +282,7 @@ def chat( generation_kwargs: dict | None = None, chat_template_kwargs: dict | None = None, tools: list | None = None, + chat_template: str | None = None, ) -> dict[str, list[list[int]]]: """ Generates model completions for the provided chat messages. @@ -318,6 +319,9 @@ def chat( Additional keyword arguments to customize the chat template used by the model. tools (`list`, *optional*): List of tool functions available for tool calling during chat generation. + chat_template (`str`, *optional*): + Template to use for structuring the chat. If not provided, the model's default chat template will be + used. Returns: `dict` with keys: @@ -330,6 +334,8 @@ def chat( """ if tools is not None: raise NotImplementedError("Tool calling is not yet implemented in VLLMClient.chat().") + if chat_template is not None: + raise NotImplementedError("Custom chat templates are not yet implemented in VLLMClient.chat().") url = f"{self.base_url}/chat/" diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d5def05df17..d8f463e257a 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1243,6 +1243,11 @@ def _generate_single_turn(self, prompts: list): **sampling_params, chat_template_kwargs=self.chat_template_kwargs, tools=self.tools, + # In multi-turn training, the chat template must be prefix-preserving. If the + # tokenizer's original template isn't, we replace it at initialization with a + # training-safe, prefix-preserving template (see patch_chat_template_for_training). + # In such cases, we must ensure vLLM uses the training template here. + chat_template=getattr(self.processing_class, "_training_chat_template"), ) else: output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) @@ -1336,7 +1341,11 @@ def _generate_single_turn(self, prompts: list): use_tqdm=False, chat_template_kwargs=self.chat_template_kwargs, tools=self.tools, - chat_template=self.processing_class._training_chat_template if self.tools else None, + # In multi-turn training, the chat template must be prefix-preserving. If the + # tokenizer's original template isn't, we replace it at initialization with a + # training-safe, prefix-preserving template (see patch_chat_template_for_training). + # In such cases, we must ensure vLLM uses the training template here. + chat_template=getattr(self.processing_class, "_training_chat_template"), ) else: all_outputs = self.llm.generate( From d338c84f340eefe45cc91cf92cdffff248780747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 23:15:42 +0000 Subject: [PATCH 121/153] qol --- trl/trainer/grpo_trainer.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index d8f463e257a..43bcb9dad20 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1247,7 +1247,7 @@ def _generate_single_turn(self, prompts: list): # tokenizer's original template isn't, we replace it at initialization with a # training-safe, prefix-preserving template (see patch_chat_template_for_training). # In such cases, we must ensure vLLM uses the training template here. - chat_template=getattr(self.processing_class, "_training_chat_template"), + chat_template=self.processing_class._training_chat_template, ) else: output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) @@ -1345,7 +1345,7 @@ def _generate_single_turn(self, prompts: list): # tokenizer's original template isn't, we replace it at initialization with a # training-safe, prefix-preserving template (see patch_chat_template_for_training). # In such cases, we must ensure vLLM uses the training template here. - chat_template=getattr(self.processing_class, "_training_chat_template"), + chat_template=self.chat_template, ) else: all_outputs = self.llm.generate( @@ -1379,19 +1379,18 @@ def _generate_single_turn(self, prompts: list): self.llm.sleep(level=2) elif self.use_transformers_paged: - processor_kwargs = {"truncation": True, "add_special_tokens": False} if is_conversational({"prompt": prompts[0]}): processor_outputs = self.processing_class.apply_chat_template( conversation=prompts, - **processor_kwargs, + tools=self.tools, + chat_template=self.chat_template, add_generation_prompt=True, tokenize=True, return_dict=True, **self.chat_template_kwargs, - tools=self.tools, ) else: - processor_outputs = self.processing_class(text=prompts, **processor_kwargs) + processor_outputs = self.processing_class(text=prompts) with ( profiling_context(self, "transformers.generate_batch"), @@ -1421,25 +1420,23 @@ def _generate_single_turn(self, prompts: list): else: # Regular generation path - processor_kwargs = { - "return_tensors": "pt", - "padding": True, - "padding_side": "left", - "truncation": True, - "add_special_tokens": False, - } if is_conversational({"prompt": prompts[0]}): generate_inputs = self.processing_class.apply_chat_template( conversation=prompts, - **processor_kwargs, + tools=self.tools, + chat_template=self._chat_template, add_generation_prompt=True, tokenize=True, - tools=self.tools, + padding=True, + padding_side="left", + return_tensors="pt", return_dict=True, **self.chat_template_kwargs, ) else: - generate_inputs = self.processing_class(text=prompts, **processor_kwargs) + generate_inputs = self.processing_class( + text=prompts, padding=True, padding_side="left", return_tensors="pt" + ) generate_inputs = super()._prepare_inputs(generate_inputs) with ( @@ -1512,7 +1509,12 @@ def _generate(self, prompts: list): prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls # Tokenize the current prompt. We will use this to filter out overlong samples later. - kwargs = dict(tools=self.tools, add_generation_prompt=True, tokenize=True, **self.chat_template_kwargs) + kwargs = { + "tools": self.tools, + "add_generation_prompt": True, + "tokenize": True, + **self.chat_template_kwargs, + } p_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"] # Call the tools, and build the new prompt for generation From f8444dfae4a1268e1b8f0a7a1dfe031e31e81587 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 23:44:52 +0000 Subject: [PATCH 122/153] use chat template arg instead of ugly patch --- tests/test_chat_template_utils.py | 9 ++++--- trl/chat_template_utils.py | 41 +++++++++---------------------- trl/trainer/grpo_trainer.py | 22 ++++++++--------- 3 files changed, 27 insertions(+), 45 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 05825e7f1f6..988ee526c98 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -21,9 +21,9 @@ from trl.chat_template_utils import ( add_response_schema, + get_training_chat_template, is_chat_template_prefix_preserving, parse_response, - patch_chat_template_for_training, ) @@ -117,10 +117,11 @@ def test_non_prefix_preserving_template(self): assert is_chat_template_prefix_preserving(tokenizer) is False -class TestPatchChatTemplateForTraining: - def test_patch_qwen3(self): +class TestGetTrainingChatTemplate: + def test_qwen3(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") - tokenizer = patch_chat_template_for_training(tokenizer) + assert is_chat_template_prefix_preserving(tokenizer) is False + tokenizer.chat_template = get_training_chat_template(tokenizer) assert is_chat_template_prefix_preserving(tokenizer) is True diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 6afca626ec9..bf9c780f13b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import functools from typing import TypeVar from transformers import PreTrainedTokenizer, ProcessorMixin @@ -433,26 +432,25 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: {%- endif %}""" -def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: +def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: """ - Ensure a tokenizer uses a prefix-preserving chat template during training. + Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, temporarily swap in a training-compatible template (currently - only Qwen3 supported) for the duration of `apply_chat_template()` calls. The training template is saved as - `tokenizer._training_chat_template`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template + (currently only Qwen3 supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): Tokenizer instance to patch. Returns: - `PreTrainedTokenizer`: - The same tokenizer with `apply_chat_template()` patched (if needed and supported). + `str` or `None`: + Training-compatible chat template, or `None` if no patching is needed. Example: ```python - >>> from trl.chat_template_utils import patch_chat_template_for_training + >>> from trl.chat_template_utils import get_training_chat_template >>> from transformers import AutoTokenizer >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") @@ -472,41 +470,26 @@ def patch_chat_template_for_training(tokenizer: PreTrainedTokenizer) -> PreTrain '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' >>> # ^ think tags missing - >>> tokenizer = patch_chat_template_for_training(tokenizer) - >>> tokenizer.apply_chat_template(messages1, tokenize=False) + >>> chat_template = get_training_chat_template(tokenizer) + >>> tokenizer.apply_chat_template(messages1, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n' - >>> tokenizer.apply_chat_template(messages2, tokenize=False) + >>> tokenizer.apply_chat_template(messages2, tokenize=False, chat_template=chat_template) '<|im_start|>user\nWhat color is the sky?<|im_end|>\n<|im_start|>assistant\n\n\n\n\nIt is blue.<|im_end|>\n<|im_start|>user\nAnd at night?<|im_end|>\n' ``` """ # First check if patching is needed if is_chat_template_prefix_preserving(tokenizer): - return tokenizer # No patching needed - - original_method = tokenizer.apply_chat_template - original_chat_template = tokenizer.chat_template + return None # No patching needed if tokenizer.chat_template == qwen3_chat_template: - tokenizer._training_chat_template = qwen3_training_chat_template + return qwen3_training_chat_template else: raise ValueError( "The tokenizer's chat template is not prefix-preserving and patching is not supported for this template. " "Please manually modify the tokenizer's chat template for training." ) - @functools.wraps(original_method) - def wrapper(*args, **kwargs): - tokenizer.chat_template = tokenizer._training_chat_template - try: - result = original_method(*args, **kwargs) - finally: - tokenizer.chat_template = original_chat_template - return result - - tokenizer.apply_chat_template = wrapper - return tokenizer - def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list[dict]: """ diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 43bcb9dad20..bd89472c182 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -55,7 +55,7 @@ from transformers.trainer_utils import seed_worker from transformers.utils import is_datasets_available, is_peft_available, is_rich_available -from ..chat_template_utils import add_response_schema, parse_response, patch_chat_template_for_training +from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response from ..data_utils import ( apply_chat_template, is_conversational, @@ -402,7 +402,12 @@ def __init__( # known chat templates. if tools and not processing_class.response_schema: processing_class = add_response_schema(processing_class) - processing_class = patch_chat_template_for_training(processing_class) + # In multi-turn training, the chat template *must* be prefix-preserving. If the tokenizer's original template + # isn't, we replace it at initialization with a training-safe, prefix-preserving template. + if tools: + self.chat_template = get_training_chat_template(processing_class) + else: + self.chat_template = None # Training arguments self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper @@ -1243,11 +1248,7 @@ def _generate_single_turn(self, prompts: list): **sampling_params, chat_template_kwargs=self.chat_template_kwargs, tools=self.tools, - # In multi-turn training, the chat template must be prefix-preserving. If the - # tokenizer's original template isn't, we replace it at initialization with a - # training-safe, prefix-preserving template (see patch_chat_template_for_training). - # In such cases, we must ensure vLLM uses the training template here. - chat_template=self.processing_class._training_chat_template, + chat_template=self.chat_template, ) else: output = self.vllm_client.generate(prompts=ordered_set_of_prompts, **sampling_params) @@ -1341,10 +1342,6 @@ def _generate_single_turn(self, prompts: list): use_tqdm=False, chat_template_kwargs=self.chat_template_kwargs, tools=self.tools, - # In multi-turn training, the chat template must be prefix-preserving. If the - # tokenizer's original template isn't, we replace it at initialization with a - # training-safe, prefix-preserving template (see patch_chat_template_for_training). - # In such cases, we must ensure vLLM uses the training template here. chat_template=self.chat_template, ) else: @@ -1424,7 +1421,7 @@ def _generate_single_turn(self, prompts: list): generate_inputs = self.processing_class.apply_chat_template( conversation=prompts, tools=self.tools, - chat_template=self._chat_template, + chat_template=self.chat_template, add_generation_prompt=True, tokenize=True, padding=True, @@ -1513,6 +1510,7 @@ def _generate(self, prompts: list): "tools": self.tools, "add_generation_prompt": True, "tokenize": True, + "chat_template": self.chat_template, **self.chat_template_kwargs, } p_ids = self.processing_class.apply_chat_template(prompt_completion_tools, **kwargs)["input_ids"] From 6ac02e04928f9c56425509f5cdc21d980ee15496 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Wed, 19 Nov 2025 23:59:14 +0000 Subject: [PATCH 123/153] refactor: simplify response parsing in tokenizer and trainer --- docs/source/chat_template_utils.md | 4 +-- tests/test_chat_template_utils.py | 40 +++++++++++-------------- trl/chat_template_utils.py | 48 ++++++++++++++---------------- trl/trainer/grpo_trainer.py | 5 ++-- 4 files changed, 44 insertions(+), 53 deletions(-) diff --git a/docs/source/chat_template_utils.md b/docs/source/chat_template_utils.md index 2900adf3781..38c15c25ae7 100644 --- a/docs/source/chat_template_utils.md +++ b/docs/source/chat_template_utils.md @@ -8,9 +8,9 @@ [[autodoc]] chat_template_utils.is_chat_template_prefix_preserving -## patch_chat_template_for_training +## get_training_chat_template -[[autodoc]] chat_template_utils.patch_chat_template_for_training +[[autodoc]] chat_template_utils.get_training_chat_template ## parse_response diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 988ee526c98..9edf0d1af0a 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -135,15 +135,13 @@ def test_parse_response(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = add_response_schema(tokenizer) text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' - assistant_text = tokenizer([text])["input_ids"] + assistant_text = tokenizer(text)["input_ids"] parsed = parse_response(tokenizer, assistant_text) - expected = [ - { - "role": "assistant", - "content": "", - "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], - } - ] + expected = { + "role": "assistant", + "content": "", + "tool_calls": [{"type": "function", "function": {"name": "multiply", "arguments": {"a": 3, "b": 4}}}], + } assert parsed == expected @pytest.mark.xfail( @@ -155,26 +153,24 @@ def test_parse_response_no_tool_call(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = add_response_schema(tokenizer) text = "Here is the answer to your question.<|im_end|>" - assistant_text = tokenizer([text])["input_ids"] + assistant_text = tokenizer(text)["input_ids"] parsed = parse_response(tokenizer, assistant_text) - expected = [ - { - "role": "assistant", - "content": "Here is the answer to your question.", - } - ] + expected = { + "role": "assistant", + "content": "Here is the answer to your question.", + } + assert parsed == expected def test_parse_response_malformed_tool_call(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = add_response_schema(tokenizer) text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n<|im_end|>' - assistant_text = tokenizer([text])["input_ids"] + assistant_text = tokenizer(text)["input_ids"] parsed = parse_response(tokenizer, assistant_text) - expected = [ - { - "role": "assistant", - "content": '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n', - } - ] + expected = { + "role": "assistant", + "content": '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}\n', + } + assert parsed == expected diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index bf9c780f13b..c19227cba9b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -436,12 +436,12 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: """ Get a prefix-preserving chat template for training, if needed. - If the tokenizer's template isn't prefix-preserving, returns a training-compatible template - (currently only Qwen3 supported). Otherwise, returns `None`. + If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently only Qwen3 + supported). Otherwise, returns `None`. Args: tokenizer (`PreTrainedTokenizer`): - Tokenizer instance to patch. + Tokenizer instance to check. Returns: `str` or `None`: @@ -491,11 +491,11 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: ) -def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list[dict]: +def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: """ - Parse token sequences into structured response dictionaries with fallback handling. + Parse a token sequence into structured response dictionaries with fallback handling. - Attempts to parse each sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool + Attempts to parse the sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool calls like `{"type":"function"`), falls back to decoding as plain text. Also removes incorrectly appended EOS tokens from tool call content when present. @@ -503,12 +503,12 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list Args: tokenizer (`PreTrainedTokenizer`): Tokenizer with a `parse_response()` method. - ids (`list[list[int]]`): + ids (`list[int]`): List of token sequences. Returns: - `list[dict]`: - List of response dictionaries. + `dict`: + Response dictionary. Example: ```python @@ -518,22 +518,18 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[list[int]]) -> list >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B") >>> tokenizer = add_response_schema(tokenizer) # temporary until built-in support >>> text = '\n{"name": "multiply", "arguments": {"a": 3, "b": 4}}\n<|im_end|>' - >>> sequences = tokenizer([text])["input_ids"] - >>> parse_response(tokenizer, sequences) - [{'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]}] + >>> ids = tokenizer(text)["input_ids"] + >>> parse_response(tokenizer, ids) + {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ - - outputs = [] - for seq in ids: - try: - parsed = tokenizer.parse_response(seq) - # Hotfix: remove incorrectly appended EOS token from tool calls - # See https://github.com/huggingface/transformers/issues/42249 - parsed["content"] = parsed["content"].removesuffix(tokenizer.eos_token) - except Exception: - # Fallback: decode as plain text if parsing fails. This happens if the model outputs malformed tool calls. - content = tokenizer.decode(seq, skip_special_tokens=True) - parsed = {"role": "assistant", "content": content} - outputs.append(parsed) - return outputs + try: + parsed = tokenizer.parse_response(ids) + # Hotfix: remove incorrectly appended EOS token from tool calls + # See https://github.com/huggingface/transformers/issues/42249 + parsed["content"] = parsed["content"].removesuffix(tokenizer.eos_token) + except Exception: + # Fallback: decode as plain text if parsing fails. This happens if the model outputs malformed tool calls. + content = tokenizer.decode(ids, skip_special_tokens=True) + parsed = {"role": "assistant", "content": content} + return parsed diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index bd89472c182..3ae6588822f 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1481,8 +1481,7 @@ def _generate(self, prompts: list): and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors and self.processing_class.response_schema is not None # only works if the tokenizer has a schema ): - completions = parse_response(self.processing_class, completion_ids) - completions = [[completion] for completion in completions] # format as list of messages + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] else: contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) completions = [[{"role": "assistant", "content": content}] for content in contents] @@ -1602,7 +1601,7 @@ def _generate(self, prompts: list): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions - post_tool_completions = parse_response(self.processing_class, post_tool_ids) + post_tool_completions = [parse_response(self.processing_class, ids) for ids in post_tool_ids] # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): From b8125bfc7f090147fb0bc12570aa0f7b83c03d6a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 00:27:28 +0000 Subject: [PATCH 124/153] why it doesn't render well? --- trl/chat_template_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c19227cba9b..990a0b8b5fb 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -433,7 +433,7 @@ def is_chat_template_prefix_preserving(tokenizer: PreTrainedTokenizer) -> bool: def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: - """ + r""" Get a prefix-preserving chat template for training, if needed. If the tokenizer's template isn't prefix-preserving, returns a training-compatible template (currently only Qwen3 From 37d77baef7ce6d431273efbfe9d94ae78ea2c63b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 00:36:54 +0000 Subject: [PATCH 125/153] raw --- trl/chat_template_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 990a0b8b5fb..ab2582b6f73 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -270,7 +270,7 @@ def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: - """ + r""" Adds the appropriate response schema to the given tokenizer or processor based on its chat template. At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While @@ -492,7 +492,7 @@ def get_training_chat_template(tokenizer: PreTrainedTokenizer) -> str | None: def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: - """ + r""" Parse a token sequence into structured response dictionaries with fallback handling. Attempts to parse the sequence using `tokenizer.parse_response()`. If parsing fails (e.g., due to malformed tool From a1365920669a608771deca902d9c5fdcdb6cb322 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 00:38:59 +0000 Subject: [PATCH 126/153] style --- examples/scripts/openenv/browsergym.py | 4 ++-- examples/scripts/openenv/wordle.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/examples/scripts/openenv/browsergym.py b/examples/scripts/openenv/browsergym.py index 2dabce6a371..8962518ec69 100644 --- a/examples/scripts/openenv/browsergym.py +++ b/examples/scripts/openenv/browsergym.py @@ -452,10 +452,10 @@ def main() -> None: print(f"🌍 Using existing BrowserGym Environment (Docker) at: {env_url}") elif args.env_mode == "docker-image": client = BrowserGymEnv.from_docker_image(args.env_image) - print(f"🌍 Using BrowserGym Environment (Docker) from local Image") + print("🌍 Using BrowserGym Environment (Docker) from local Image") elif args.env_mode == "docker-hub": client = BrowserGymEnv.from_hub(args.env_image) - print(f"🌍 Using existing BrowserGym Environment (Docker) from Hub Image") + print("🌍 Using existing BrowserGym Environment (Docker) from Hub Image") elif args.env_mode == "space": env_url = args.env_host print(f"🌍 Using Hugging Face Space environment at: {env_url}") diff --git a/examples/scripts/openenv/wordle.py b/examples/scripts/openenv/wordle.py index dfef53c8072..7cad1b6571a 100644 --- a/examples/scripts/openenv/wordle.py +++ b/examples/scripts/openenv/wordle.py @@ -93,7 +93,9 @@ def parse_args() -> argparse.Namespace: default="docker-image", help="Where to run the environment: 'docker-local' if already running locally, 'docker-image' to run from a Docker image, 'docker-hub' to run from Docker Hub, or 'space' to use a remote Space URL.", ) - parser.add_argument("--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment.") + parser.add_argument( + "--env-image", type=str, default="textarena-env:latest", help="Docker image for the TextArena environment." + ) parser.add_argument( "--system-prompt-path", default="wordle_prompt.txt", @@ -427,10 +429,10 @@ def main() -> None: print(f"🌍 Using existing TextArena Environment (Docker) at: {env_url}") elif args.env_mode == "docker-image": client = TextArenaEnv.from_docker_image(args.env_image) - print(f"🌍 Using TextArena Environment (Docker) from local Image") + print("🌍 Using TextArena Environment (Docker) from local Image") elif args.env_mode == "docker-hub": client = TextArenaEnv.from_hub(args.env_image) - print(f"🌍 Using existing TextArena Environment (Docker) from Hub Image") + print("🌍 Using existing TextArena Environment (Docker) from Hub Image") elif args.env_mode == "space": env_url = args.env_host print(f"🌍 Using Hugging Face Space environment at: {env_url}") From e63a46c47e26249ab0f58b9c6a97f784f7211f4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 00:46:40 +0000 Subject: [PATCH 127/153] fix: update xfail reason for tool parsing in TestParseResponse --- tests/test_chat_template_utils.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 9edf0d1af0a..7468538b499 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -128,7 +128,7 @@ def test_qwen3(self): class TestParseResponse: @pytest.mark.xfail( condition=Version(transformers.__version__) < Version("5.0.0.dev0"), - reason="Response parsing is not supported in transformers versions below 5.0.0.dev0", + reason="Tool parsing is not supported in transformers versions below 5.0.0.dev0", strict=True, ) def test_parse_response(self): @@ -144,11 +144,6 @@ def test_parse_response(self): } assert parsed == expected - @pytest.mark.xfail( - condition=Version(transformers.__version__) < Version("5.0.0.dev0"), - reason="Response parsing is not supported in transformers versions below 5.0.0.dev0", - strict=True, - ) def test_parse_response_no_tool_call(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = add_response_schema(tokenizer) From d0823099cd76e9cefb95b34e78af8cc2b0b29070 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 00:48:49 +0000 Subject: [PATCH 128/153] revert rloo for now --- docs/source/rloo_trainer.md | 10 ++++++++++ examples/scripts/rloo.py | 1 + examples/scripts/rloo_vlm.py | 2 ++ tests/test_rloo_trainer.py | 3 +++ 4 files changed, 16 insertions(+) diff --git a/docs/source/rloo_trainer.md b/docs/source/rloo_trainer.md index c0298e97350..68173d218da 100644 --- a/docs/source/rloo_trainer.md +++ b/docs/source/rloo_trainer.md @@ -528,6 +528,7 @@ accelerate launch \ --learning_rate 1e-5 \ --gradient_checkpointing \ --dtype bfloat16 \ + --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ --vllm_mode colocate \ @@ -538,6 +539,15 @@ accelerate launch \ ### Configuration Tips +> [!TIP] +> For VLMs, truncating may remove image tokens, leading to errors during training. To avoid this, set `max_prompt_length=None` in the [`RLOOConfig`]. This allows the model to process the full sequence length without truncating image tokens. +> +> ```python +> RLOOConfig(max_prompt_length=None, ...) +> ``` +> +> Only use `max_prompt_length` when you've verified that truncation won't remove image tokens for the entire dataset. + - Use LoRA on vision-language projection layers - Enable 4-bit quantization to reduce memory usage - VLMs are memory-intensive — start with smaller batch sizes diff --git a/examples/scripts/rloo.py b/examples/scripts/rloo.py index faa90df30be..abeabb45b60 100644 --- a/examples/scripts/rloo.py +++ b/examples/scripts/rloo.py @@ -73,6 +73,7 @@ def make_conversation(example): gradient_checkpointing_kwargs=dict(use_reentrant=False), log_completions=True, num_completions_to_print=2, + max_prompt_length=2048, max_completion_length=1024, gradient_accumulation_steps=2, steps_per_generation=8, diff --git a/examples/scripts/rloo_vlm.py b/examples/scripts/rloo_vlm.py index 48fc63250b6..a98674db15c 100644 --- a/examples/scripts/rloo_vlm.py +++ b/examples/scripts/rloo_vlm.py @@ -37,6 +37,7 @@ --learning_rate 1e-5 \ --gradient_checkpointing \ --dtype bfloat16 \ + --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_vllm \ --vllm_mode colocate \ @@ -54,6 +55,7 @@ --output_dir rloo-SmolVLM2-2.2B-Instruct \ --learning_rate 1e-5 \ --dtype bfloat16 \ + --max_prompt_length 2048 \ --max_completion_length 1024 \ --use_peft \ --lora_target_modules "q_proj", "v_proj" \ diff --git a/tests/test_rloo_trainer.py b/tests/test_rloo_trainer.py index dff1958e3ad..82810b1cca5 100644 --- a/tests/test_rloo_trainer.py +++ b/tests/test_rloo_trainer.py @@ -1100,6 +1100,7 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it report_to="none", ) trainer = RLOOTrainer( @@ -1246,6 +1247,7 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, num_generations=3, max_completion_length=8, + max_prompt_length=18, report_to="none", use_vllm=True, vllm_mode="server", @@ -1287,6 +1289,7 @@ def reward_func(completions, **kwargs): per_device_train_batch_size=3, # reduce the batch size to reduce memory usage num_generations=3, # reduce the number of generations to reduce memory usage max_completion_length=8, # reduce the completion length to reduce memory usage + max_prompt_length=None, # disable prompt truncation, because usually, models don't support it report_to="none", ) trainer = RLOOTrainer( From 0707baa4e3fd4d5270abee1b41a5f382eb6b9a62 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 00:55:24 +0000 Subject: [PATCH 129/153] grpo with replay buffer --- .../grpo_with_replay_buffer_trainer.py | 36 ++++++++++--------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index e5c44710123..0eb89965fab 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -88,9 +88,15 @@ def _generate_and_score_completions( for prompt, image_list in zip(prompts, images, strict=True) ] - prompt_ids_list, completion_ids_list, num_items_in_batch, sampling_per_token_logps_list, extra_fields = ( - self._generate(prompts) - ) + ( + prompt_ids_list, + completion_ids_list, + tool_mask_list, + completions, + num_items_in_batch, + sampling_per_token_logps_list, + extra_fields, + ) = self._generate(prompts) # Convert lists of token IDs to padded tensors prompt_ids = [torch.tensor(ids, device=device) for ids in prompt_ids_list] @@ -106,6 +112,9 @@ def _generate_and_score_completions( sampling_per_token_logps = pad(sampling_per_token_logps, padding_value=0.0, padding_side="right") else: sampling_per_token_logps = None + if self.tools: + tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] + tool_mask = pad(tool_mask, padding_value=0, padding_side="right") # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -125,7 +134,9 @@ def _generate_and_score_completions( # Get forward_kwargs for models with multimodal inputs if images is not None: prompts_text = [ - apply_chat_template({"prompt": prompt}, self.processing_class, **self.chat_template_kwargs)["prompt"] + apply_chat_template( + {"prompt": prompt}, self.processing_class, tools=self.tools, **self.chat_template_kwargs + )["prompt"] for prompt in prompts ] prompt_inputs = self.processing_class(images=images, text=prompts_text, padding=True, return_tensors="pt") @@ -201,16 +212,6 @@ def _generate_and_score_completions( # Decode prompts_text = self.processing_class.batch_decode(prompt_ids, skip_special_tokens=True) completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - if is_conversational(inputs[0]): - completions = [] - for prompt, completion in zip(prompts, completions_text, strict=True): - bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" - if isinstance(bootstrap, list): # for VLM, the format might be [{"type": "text", "text": "..."}] - assert len(bootstrap) == 1 and bootstrap[0]["type"] == "text" - bootstrap = bootstrap[0]["text"] - completions.append([{"role": "assistant", "content": bootstrap + completion}]) - else: - completions = completions_text # Merge extra_fields from rollout_func into inputs for reward functions if extra_fields: @@ -285,7 +286,8 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - delta = delta[completion_mask.bool()] + mask = completion_mask.bool() if not self.tools else (completion_mask * (1 - tool_mask)).bool() + delta = delta[mask] mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) self._metrics[mode]["sampling/sampling_logp_difference/mean"].append( @@ -295,7 +297,7 @@ def _generate_and_score_completions( self.accelerator.gather(max_delta).max().item() ) - flat_is_ratio = importance_sampling_ratio[completion_mask.bool()] + flat_is_ratio = importance_sampling_ratio[mask] min_importance_sampling_ratio = ( torch.min(flat_is_ratio) if flat_is_ratio.numel() > 0 else torch.tensor(0.0, device=device) ) @@ -356,6 +358,8 @@ def _generate_and_score_completions( output["token_type_ids"] = forward_kwargs["token_type_ids"] if images is not None: output["num_images"] = num_images + if self.tools: + output["tool_mask"] = tool_mask return output def slice_group_data( From 753d70d8ba0c9299f4a9f66de776a31605d12d09 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 01:45:51 +0000 Subject: [PATCH 130/153] jmespath dep --- pyproject.toml | 4 +++- trl/chat_template_utils.py | 2 +- .../grpo_with_replay_buffer_trainer.py | 2 +- trl/trainer/grpo_trainer.py | 18 ++++++++++++------ 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1b84ff50d7f..82d9bba7c67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -122,7 +122,9 @@ dev = [ # vlm "Pillow", "torchvision", - "num2words==0.5.14" + "num2words==0.5.14", + # for response parsing (required for training with tools) + "jmespath", ] [tool.setuptools] diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index ab2582b6f73..ba886ee29f8 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -528,7 +528,7 @@ def parse_response(tokenizer: PreTrainedTokenizer, ids: list[int]) -> dict: # Hotfix: remove incorrectly appended EOS token from tool calls # See https://github.com/huggingface/transformers/issues/42249 parsed["content"] = parsed["content"].removesuffix(tokenizer.eos_token) - except Exception: + except ValueError: # Fallback: decode as plain text if parsing fails. This happens if the model outputs malformed tool calls. content = tokenizer.decode(ids, skip_special_tokens=True) parsed = {"role": "assistant", "content": content} diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index 0eb89965fab..ae57836d156 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -18,7 +18,7 @@ import torch from accelerate.utils import gather_object -from ...data_utils import apply_chat_template, is_conversational, prepare_multimodal_messages +from ...data_utils import apply_chat_template, prepare_multimodal_messages from ...trainer.grpo_trainer import GRPOTrainer from ...trainer.utils import nanmax, nanmin, nanstd, pad from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 3ae6588822f..0dc6c257d72 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -53,7 +53,7 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_jmespath_available, is_peft_available, is_rich_available from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response from ..data_utils import ( @@ -390,11 +390,17 @@ def __init__( self.rollout_func = rollout_func # Tools - if tools and not Version(transformers.__version__) >= Version("5.0.0.dev0"): - raise ImportError( - "Using tools with GRPOTrainer requires transformers version 5.0.0.dev0 or higher. Please upgrade " - "transformers to use this feature." - ) + if tools: + if not Version(transformers.__version__) >= Version("5.0.0.dev0"): + raise ImportError( + "Using tools with GRPOTrainer requires transformers version 5.0.0.dev0 or higher. Please upgrade " + "transformers to use this feature." + ) + if not is_jmespath_available(): + raise ImportError( + "Using tools with GRPOTrainer requires the jmespath library for response parsing. Please install " + "it with `pip install jmespath` to use this feature." + ) self.tools = tools or [] self._tool_dict = {tool.__name__: tool for tool in self.tools} # At the time of initial implementation, most tokenizers do not have built-in support for response schemas. From 06414f2c9c4eb73e8eedcde2d7b72345f9b9d929 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 01:50:44 +0000 Subject: [PATCH 131/153] is_jmespath_available --- trl/import_utils.py | 5 +++++ trl/trainer/grpo_trainer.py | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/trl/import_utils.py b/trl/import_utils.py index 4d8a9c84ce0..0cde1cfc673 100644 --- a/trl/import_utils.py +++ b/trl/import_utils.py @@ -28,6 +28,7 @@ # Use same as transformers.utils.import_utils _deepspeed_available = _is_package_available("deepspeed") _fastapi_available = _is_package_available("fastapi") +_is_jmespath_available = _is_package_available("jmespath") _joblib_available = _is_package_available("joblib") _liger_kernel_available, _liger_kernel_version = _is_package_available("liger_kernel", return_version=True) _llm_blender_available = _is_package_available("llm_blender") @@ -50,6 +51,10 @@ def is_fastapi_available() -> bool: return _fastapi_available +def is_jmespath_available() -> bool: + return _is_jmespath_available + + def is_joblib_available() -> bool: return _joblib_available diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 0dc6c257d72..47c105106b7 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -53,7 +53,7 @@ is_wandb_available, ) from transformers.trainer_utils import seed_worker -from transformers.utils import is_datasets_available, is_jmespath_available, is_peft_available, is_rich_available +from transformers.utils import is_datasets_available, is_peft_available, is_rich_available from ..chat_template_utils import add_response_schema, get_training_chat_template, parse_response from ..data_utils import ( @@ -64,7 +64,7 @@ ) from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient -from ..import_utils import is_liger_kernel_available, is_vllm_available +from ..import_utils import is_liger_kernel_available, is_vllm_available, is_jmespath_available from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation from ..models.utils import _ForwardRedirection from .base_trainer import BaseTrainer From 21792daad569d4eaf4ccf32a77eb0d8b0d04e310 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 01:52:32 +0000 Subject: [PATCH 132/153] style --- trl/trainer/grpo_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 47c105106b7..03803f3f8f6 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -64,7 +64,7 @@ ) from ..extras.profiling import profiling_context, profiling_decorator from ..extras.vllm_client import VLLMClient -from ..import_utils import is_liger_kernel_available, is_vllm_available, is_jmespath_available +from ..import_utils import is_jmespath_available, is_liger_kernel_available, is_vllm_available from ..models import prepare_deepspeed, prepare_fsdp, prepare_peft_model, unwrap_model_for_generation from ..models.utils import _ForwardRedirection from .base_trainer import BaseTrainer From 850a9ebcedb6b22a29c0459441689f4f96d35caf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 02:07:07 +0000 Subject: [PATCH 133/153] new section --- docs/source/grpo_trainer.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 2a0978f7b9c..748270e4da4 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -514,6 +514,10 @@ and the reward will be computed as the sum of the rewards from each function, or Note that [`GRPOTrainer`] supports multiple reward functions of different types. See the parameters documentation for more details. +## Agent Training + +To write... + ## Vision-Language Model (VLM) Training GRPO supports training Vision-Language Models (VLMs) on multimodal datasets containing both text and images. From 438b58644cb931fe81107e10090d27db84bbb190 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 03:23:29 +0000 Subject: [PATCH 134/153] ignore TestParseResponse for transformers<5 --- tests/test_chat_template_utils.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_chat_template_utils.py b/tests/test_chat_template_utils.py index 7468538b499..3db180c64d6 100644 --- a/tests/test_chat_template_utils.py +++ b/tests/test_chat_template_utils.py @@ -125,12 +125,12 @@ def test_qwen3(self): assert is_chat_template_prefix_preserving(tokenizer) is True +@pytest.mark.xfail( + condition=Version(transformers.__version__) < Version("5.0.0.dev0"), + reason="Tool parsing is not supported in transformers versions below 5.0.0.dev0", + strict=True, +) class TestParseResponse: - @pytest.mark.xfail( - condition=Version(transformers.__version__) < Version("5.0.0.dev0"), - reason="Tool parsing is not supported in transformers versions below 5.0.0.dev0", - strict=True, - ) def test_parse_response(self): tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen3MoeForSequenceClassification") tokenizer = add_response_schema(tokenizer) From 1c026cebdbd640d3d24560c0fe3a4bd8be9bb6c9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 04:03:03 +0000 Subject: [PATCH 135/153] fix qwen schema --- trl/chat_template_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index ba886ee29f8..c1a7a67adcf 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -145,12 +145,12 @@ } qwen3_schema = { - "x-regex": r"^(?:\n?(?P.+?)\n?\s*)?(?P.*?)(?=(?:|<\|im_end\|>|$))(?:(?P.+?))?\s*(?:<\|im_end\|>|$)", + "x-regex": r"^(?:\n?(?P.+?)\n?\s*)?(?P.*?)(?=(?:|<\|im_end\|>|$))(?:(?P.+?))?\s*(?:<\|im_end\|>|$)", "type": "object", "properties": { "role": {"const": "assistant"}, "content": {"type": "string"}, - "thinking": {"type": "string"}, + "reasoning_content": {"type": "string"}, "tool_calls": { "x-parser": "json", "x-parser-args": {"transform": "[{type: 'function', function: @}]"}, From c54bf4fd48c60c3060fd902686464854f6660c2d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 04:30:46 +0000 Subject: [PATCH 136/153] another fix --- trl/trainer/grpo_trainer.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 03803f3f8f6..e91758c024b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1607,13 +1607,14 @@ def _generate(self, prompts: list): completion_ids[idx_with_tool] = pct[prompt_length:] + post_tool_ids[idx] # Decode post-tool completions - post_tool_completions = [parse_response(self.processing_class, ids) for ids in post_tool_ids] + post_tool_completions = [ + parse_response(self.processing_class, ids) if ids else {} for ids in post_tool_ids + ] # Add post-tool completions to the existing completions for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] - # When the post-tool if completly truncated, content is empty. - if post_tool_completions[idx]["content"] or "tool_calls" in post_tool_completions[idx]: + if post_tool_completions[idx]: # {} if post-tool completions completely truncated completions[idx_with_tool].append(post_tool_completions[idx]) # Check for further tool calls From 9f0aa3db1ddf244c9cefad45a984f18a5e71d7da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 06:52:45 +0000 Subject: [PATCH 137/153] remove unsused schemas --- trl/chat_template_utils.py | 141 ++----------------------------------- 1 file changed, 6 insertions(+), 135 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index c1a7a67adcf..3c4e6d7d852 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -12,138 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import TypeVar +from transformers import PreTrainedTokenizer -from transformers import PreTrainedTokenizer, ProcessorMixin - - -# These schemas are copy-pasted from https://github.com/huggingface/transformers/blob/main/tests/utils/test_chat_parsing_utils.py -cohere_schema = { - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string", "x-regex": r"<\|START_RESPONSE\|>(.*?)(?:<\|END_RESPONSE\|>|$)"}, - "thinking": {"type": "string", "x-regex": r"<\|START_THINKING\|>(.*?)(?:<\|END_THINKING\|>|$)"}, - "tool_calls": { - "x-regex": r"<\|START_ACTION\|>(.*?)(?:<\|END_ACTION\|>|$)", - "x-parser": "json", - "x-parser-args": { - "transform": "[*].{type: 'function', function: {name: tool_name, arguments: parameters}}" - }, - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "arguments": { - "type": "object", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -ernie_schema = { - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string", "x-regex": "\n(.*?)\n?"}, - "thinking": {"type": "string", "x-regex": r"(?:^|\s*)(.*?)\s*<\/think>"}, - "tool_calls": { - "x-regex-iterator": "(.*?)", - "type": "array", - "items": { - "type": "object", - "x-parser": "json", - "x-parser-args": {"transform": "{type: 'function', function: @}"}, - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "arguments": { - "type": "object", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -gpt_oss_schema = { - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string", "x-regex": r"<\|channel\|>final<\|message\|>(.*?)(?:<\|end\|>|$)"}, - "thinking": {"type": "string", "x-regex": r"<\|channel\|>analysis<\|message\|>(.*?)<\|end\|>"}, - "tool_calls": { - "x-regex-iterator": r"<\|channel\|>commentary (to=functions\..*?<\|message\|>.*?)(?:<\|call\|>|$)", - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string", "x-regex": r"^to=functions\.(\w+)"}, - "arguments": { - "type": "object", - "x-regex": r"<\|message\|>(.*)", - "x-parser": "json", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} - -smollm_schema = { - "x-regex": r"(?:\n?(?P.+?)\n?)?\s*(?:(?P.+?))?\s*(?P.+?)?\s*(?:<\|im_end\|>|$)", - "type": "object", - "properties": { - "role": {"const": "assistant"}, - "content": {"type": "string"}, - "thinking": {"type": "string"}, - "tool_calls": { - "x-parser": "json", - "x-parser-args": {"transform": "[{type: 'function', function: @}]"}, - "type": "array", - "items": { - "type": "object", - "properties": { - "type": {"const": "function"}, - "function": { - "type": "object", - "properties": { - "name": {"type": "string"}, - "arguments": { - "type": "object", - "additionalProperties": {}, - }, - }, - }, - }, - }, - }, - }, -} +# Adapted and corrected versions of the schemas from: +# https://github.com/huggingface/transformers/blob/main/tests/utils/test_chat_parsing_utils.py qwen3_schema = { "x-regex": r"^(?:\n?(?P.+?)\n?\s*)?(?P.*?)(?=(?:|<\|im_end\|>|$))(?:(?P.+?))?\s*(?:<\|im_end\|>|$)", "type": "object", @@ -266,10 +139,8 @@ {%- endif %} {%- endif %}""" -TokenizerOrProcessor = TypeVar("TokenizerOrProcessor", PreTrainedTokenizer, ProcessorMixin) - -def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor: +def add_response_schema(processor: PreTrainedTokenizer) -> PreTrainedTokenizer: r""" Adds the appropriate response schema to the given tokenizer or processor based on its chat template. @@ -278,11 +149,11 @@ def add_response_schema(processor: TokenizerOrProcessor) -> TokenizerOrProcessor templates. Args: - processor (`TokenizerOrProcessor`): + processor (`PreTrainedTokenizer`): Tokenizer or processor to which the response schema will be added. Returns: - `TokenizerOrProcessor`: + `PreTrainedTokenizer`: Tokenizer or processor with the added response schema. Examples: From fbb625f1904c79549c2cbff07706d341250dc143 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 07:08:35 +0000 Subject: [PATCH 138/153] rename processor to tokenizer in add_response_schema function --- trl/chat_template_utils.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/trl/chat_template_utils.py b/trl/chat_template_utils.py index 3c4e6d7d852..5bc601c2b3b 100644 --- a/trl/chat_template_utils.py +++ b/trl/chat_template_utils.py @@ -140,21 +140,21 @@ {%- endif %}""" -def add_response_schema(processor: PreTrainedTokenizer) -> PreTrainedTokenizer: +def add_response_schema(tokenizer: PreTrainedTokenizer) -> PreTrainedTokenizer: r""" - Adds the appropriate response schema to the given tokenizer or processor based on its chat template. + Adds the appropriate response schema to the given tokenizer based on its chat template. At the time of initial implementation, most tokenizers do not have built-in support for response schemas. While waiting for broader adoption, we provide this utility function to manually set the response schema for known chat templates. Args: - processor (`PreTrainedTokenizer`): - Tokenizer or processor to which the response schema will be added. + tokenizer (`PreTrainedTokenizer`): + Tokenizer to which the response schema will be added. Returns: `PreTrainedTokenizer`: - Tokenizer or processor with the added response schema. + Tokenizer with the added response schema. Examples: @@ -169,11 +169,9 @@ def add_response_schema(processor: PreTrainedTokenizer) -> PreTrainedTokenizer: {'role': 'assistant', 'content': '', 'tool_calls': [{'type': 'function', 'function': {'name': 'multiply', 'arguments': {'a': 3, 'b': 4}}}]} ``` """ - if processor.chat_template == qwen3_chat_template: - # The Qwen3 response schema seems to be smollm_schema, and not the qwen3_schema. See - # https://github.com/huggingface/transformers/issues/42220 - processor.response_schema = qwen3_schema - return processor + if tokenizer.chat_template == qwen3_chat_template: + tokenizer.response_schema = qwen3_schema + return tokenizer raise ValueError( "Unrecognized chat template, failed to add response schema. Please manually set the response schema on the " "tokenizer or processor." From ce6341b83053d9a01f19b57c2aa99e7bc4c4e723 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 07:14:50 +0000 Subject: [PATCH 139/153] deprecate max_prompt_length argument and add warning for future removal --- trl/trainer/grpo_config.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/trl/trainer/grpo_config.py b/trl/trainer/grpo_config.py index 38a67fa5660..b2c48342d04 100644 --- a/trl/trainer/grpo_config.py +++ b/trl/trainer/grpo_config.py @@ -704,7 +704,7 @@ class GRPOConfig(TrainingArguments): # Deprecated arguments max_prompt_length: int | None = field( - default=512, + default=None, metadata={ "help": "Deprecated, filter your dataset before training to ensure that prompts do not exceed your " "desired length." @@ -778,6 +778,15 @@ def __post_init__(self): if self.delta is not None and self.use_liger_kernel: raise ValueError("Liger kernel does not support two-sided GRPO loss yet.") + if self.max_prompt_length is not None: + warnings.warn( + "The `max_prompt_length` argument is deprecated and will be removed in version 0.28.0. You should " + "instead filter your dataset before training to ensure that prompts do not exceed your desired " + "length.", + FutureWarning, + stacklevel=2, + ) + if self.wandb_log_unique_prompts is not None: warnings.warn( "The `wandb_log_unique_prompts` argument is deprecated and will be removed in version 0.27.0. Please " From 493881f7b59e06dc6dc282ac77aa7d3ce7562b54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= <45557362+qgallouedec@users.noreply.github.com> Date: Thu, 20 Nov 2025 00:18:29 -0700 Subject: [PATCH 140/153] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tests/test_grpo_trainer.py | 2 +- trl/trainer/grpo_trainer.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_grpo_trainer.py b/tests/test_grpo_trainer.py index eaf417d6457..ebefada34c0 100644 --- a/tests/test_grpo_trainer.py +++ b/tests/test_grpo_trainer.py @@ -1747,7 +1747,7 @@ def fake_generate(input_ids, **kwargs): device=input_ids.device, ) # fmt: on - else: # second call will only have two inputs in the batch, because two examples haave a tool call. + else: # second call will only have two inputs in the batch, because two examples have a tool call. completion_ids = torch.tensor( [ # 'Done!<|im_end|>' diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index e91758c024b..8ab37452506 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1535,10 +1535,11 @@ def _generate(self, prompts: list): except Exception as e: result = {"error": str(e)} tool_failure_count += 1 + tool_message = {"role": "tool", "name": function["name"], "content": str(result)} else: result = {"error": f"Unsupported tool call type: {tool_call['type']}"} + tool_message = {"role": "tool", "name": tool_call.get("name", "unknown"), "content": str(result)} tool_call["result"] = result - tool_message = {"role": "tool", "name": function["name"], "content": str(result)} prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) From 4d6a064779198ec4466f7751162267ec8c23d0c6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 07:31:30 +0000 Subject: [PATCH 141/153] nit simplification --- trl/trainer/grpo_trainer.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 8ab37452506..94f5ebb2efe 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1530,16 +1530,17 @@ def _generate(self, prompts: list): tool_call_count += 1 if tool_call["type"] == "function": function = tool_call["function"] + name = function["name"] try: - result = self._tool_dict[function["name"]](**function["arguments"]) + result = self._tool_dict[name](**function["arguments"]) except Exception as e: - result = {"error": str(e)} tool_failure_count += 1 - tool_message = {"role": "tool", "name": function["name"], "content": str(result)} + result = {"error": str(e)} else: + tool_failure_count += 1 + name = tool_call.get("name", "unknown") result = {"error": f"Unsupported tool call type: {tool_call['type']}"} - tool_message = {"role": "tool", "name": tool_call.get("name", "unknown"), "content": str(result)} - tool_call["result"] = result + tool_message = {"role": "tool", "name": name, "content": str(result)} prompt_completion_tool.append(tool_message) completions[idx_with_tool].append(tool_message) From 5a9bb2038466aa6f48740bef1e128d0385c114d4 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 20 Nov 2025 17:31:22 +0100 Subject: [PATCH 142/153] Docs updated --- docs/source/grpo_trainer.md | 60 ++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 748270e4da4..77e43829aaa 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -516,7 +516,65 @@ Note that [`GRPOTrainer`] supports multiple reward functions of different types. ## Agent Training -To write... +GRPO supports **agent training** through the `tools` argument in [`GRPOTrainer`]. +This parameter expects a list of Python functions that define the tools available to the agent: + +```python +from trl import GRPOTrainer + +trainer = GRPOTrainer( + tools=[tool1, tool2], + ..., +) +``` + +Each tool must be a standard Python function with **type-hinted arguments and return types**, along with a **Google-style docstring** describing its purpose, arguments, and return value. +For more details, see the [Passing tools guide](https://huggingface.co/docs/transformers/en/chat_extras#passing-tools). + +Example: + +```python +from trl import GRPOTrainer + +def multiply(a: int, b: int) -> int: + """ + Multiplies two integers. + + Args: + a: The first integer. + b: The second integer. + + Returns: + The product of the two integers. + """ + return a * b + +trainer = GRPOTrainer( + tools=[multiply], + ..., +) +``` + +### Supported Models + +Tested with: + +- **Qwen3** — e.g., `Qwen/Qwen3-0.6B` + +> [!TIP] +> Compatibility with all LLMs is not guaranteed. If you believe a model should be supported, feel free to open an issue on GitHub — or better yet, submit a pull request with the required changes. + +### Quick Start + +Use [grpo\_agent.py](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_agent.py) to fine-tune a LLM for agentic workflows. + +```bash +accelerate launch \ + --config_file=examples/accelerate_configs/deepspeed_zero3.yaml \ + examples/scripts/grpo_agent.py \ + --model_name_or_path Qwen/Qwen3-0.6B + ... +``` ## Vision-Language Model (VLM) Training From 90a1ed103066927ff481e3dd9853faa351ad2975 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 18:08:45 +0000 Subject: [PATCH 143/153] Add monkey-patch for vLLM compatibility with TRL --- trl/__init__.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/trl/__init__.py b/trl/__init__.py index 44228d2092e..02d2bd5804c 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -196,3 +196,23 @@ module_spec=__spec__, extra_objects={"__version__": __version__}, ) + + +# Monkey-patch for vLLM. +# Bug introduced in https://github.com/vllm-project/vllm/pull/52 +# Fixed inhttps://github.com/vllm-project/vllm/pull/28471 (released in v0.11.1) +# Since TRL currently only supports vLLM v0.10.2, we patch it here. This can be removed when TRL requires vLLM >=0.11.1 +from .import_utils import is_vllm_available # noqa: E402 + + +if is_vllm_available(): + import vllm.model_executor.model_loader.weight_utils + from tqdm import tqdm + + class DisabledTqdm(tqdm): + def __init__(self, *args, **kwargs): + kwargs["disable"] = True + super().__init__(*args, **kwargs) + + # overwrite the class in the dependency + vllm.model_executor.model_loader.weight_utils.DisabledTqdm = DisabledTqdm From a584e4224392df1387b33f32bed130032c0a5c91 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Thu, 20 Nov 2025 18:44:20 +0000 Subject: [PATCH 144/153] VLLM_LOGGING_LEVEL", "ERROR --- trl/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/trl/__init__.py b/trl/__init__.py index 02d2bd5804c..f64e392a725 100644 --- a/trl/__init__.py +++ b/trl/__init__.py @@ -206,6 +206,9 @@ if is_vllm_available(): + import os + + os.environ["VLLM_LOGGING_LEVEL"] = os.getenv("VLLM_LOGGING_LEVEL", "ERROR") import vllm.model_executor.model_loader.weight_utils from tqdm import tqdm From caf1ad259f11a6d0571ce53f742b49ac002f0e4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 25 Nov 2025 05:03:52 +0000 Subject: [PATCH 145/153] flip tool mask --- .../grpo_with_replay_buffer_trainer.py | 4 ++-- trl/trainer/grpo_trainer.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py index ae57836d156..90af8177e83 100644 --- a/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py +++ b/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_trainer.py @@ -114,7 +114,7 @@ def _generate_and_score_completions( sampling_per_token_logps = None if self.tools: tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] - tool_mask = pad(tool_mask, padding_value=0, padding_side="right") + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -286,7 +286,7 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - mask = completion_mask.bool() if not self.tools else (completion_mask * (1 - tool_mask)).bool() + mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool() delta = delta[mask] mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index a54291ac975..c29b714a6e5 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1502,7 +1502,7 @@ def _generate(self, prompts: list): tool_calls = [completion[0].get("tool_calls") for completion in completions] idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] tool_calls = [tool_calls[idx] for idx in idxs_with_tool] - tool_mask = [[0] * len(ids) for ids in completion_ids] + tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere tool_call_count = 0 tool_failure_count = 0 else: @@ -1557,7 +1557,7 @@ def _generate(self, prompts: list): prompt_length = len(prompt_ids[idx_with_tool]) ct = pct_ids[idx][prompt_length : prompt_length + self.max_completion_length] completion_ids[idx_with_tool] = ct - tool_mask[idx_with_tool] += [0] * (len(ct) - len(tool_mask[idx_with_tool])) + tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) if logprobs is not None: logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] @@ -1592,7 +1592,7 @@ def _generate(self, prompts: list): # If still exceeding max length, truncate completion_tool_ids as well prompt_completion_tool_ids[idx] = prompt_completion_tool_ids[idx][:-excess_length] - # Update tool_mask: the tool result should be 1 and the post-tool 0 + # Update tool_mask: the tool result should be 0 and the post-tool 1 for idx in range(len(idxs_with_tool)): idx_with_tool = idxs_with_tool[idx] prompt_completion_tool_length = len(prompt_completion_tool_ids[idx]) @@ -1600,7 +1600,7 @@ def _generate(self, prompts: list): completion_length = len(completion_ids[idx_with_tool]) post_tool_length = len(post_tool_ids[idx]) tool_length = prompt_completion_tool_length - prompt_length - completion_length - tool_mask[idx_with_tool] += [1] * tool_length + [0] * post_tool_length + tool_mask[idx_with_tool] += [0] * tool_length + [1] * post_tool_length if logprobs is not None: logprobs[idx_with_tool] += [0.0] * tool_length + post_tool_logprobs[idx] @@ -1730,7 +1730,7 @@ def _generate_and_score_completions( sampling_per_token_logps = None if self.tools: tool_mask = [torch.tensor(mask, device=device) for mask in tool_mask_list] - tool_mask = pad(tool_mask, padding_value=0, padding_side="right") + tool_mask = pad(tool_mask, padding_value=1, padding_side="right") # 0 for tool result tokens, 1 elsewhere # If mask_truncated_completions is enabled, zero out truncated completions in completion_mask if self.mask_truncated_completions: @@ -1900,7 +1900,7 @@ def _generate_and_score_completions( if self.use_vllm and self.vllm_importance_sampling_correction: delta = torch.abs(old_per_token_logps - sampling_per_token_logps) - mask = completion_mask.bool() if not self.tools else (completion_mask * (1 - tool_mask)).bool() + mask = completion_mask.bool() if not self.tools else (completion_mask * tool_mask).bool() delta = delta[mask] mean_delta = torch.mean(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) max_delta = torch.max(delta) if delta.numel() > 0 else torch.tensor(0.0, device=device) @@ -2103,7 +2103,7 @@ def _compute_loss(self, model, inputs): if self.beta != 0.0: per_token_loss = per_token_loss + self.beta * per_token_kl - mask = completion_mask if not self.tools else completion_mask * (1 - inputs["tool_mask"]) + mask = completion_mask if not self.tools else completion_mask * inputs["tool_mask"] if self.loss_type == "grpo": loss = ((per_token_loss * mask).sum(-1) / mask.sum(-1).clamp(min=1.0)).mean() loss = loss / self.current_gradient_accumulation_steps From 94c2ff21e183068bf81e7036402809d5c2c3c8ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 25 Nov 2025 05:30:56 +0000 Subject: [PATCH 146/153] isolate tool call loop --- trl/trainer/grpo_trainer.py | 83 +++++++++++++++++++++---------------- 1 file changed, 48 insertions(+), 35 deletions(-) diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index c29b714a6e5..41f4422065b 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1474,42 +1474,15 @@ def _generate_single_turn(self, prompts: list): return prompt_ids, completion_ids, logprobs, extra_fields - def _generate(self, prompts: list): - device = self.accelerator.device - mode = "train" if self.model.training else "eval" - - # Copy the prompts to avoid modifying the original list - prompts = copy.deepcopy(prompts) - - prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) - - # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. - if is_conversational({"prompt": prompts[0]}): - if ( - Version(transformers.__version__) >= Version("5.0.0.dev0") # parse_response added in v5 - and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors - and self.processing_class.response_schema is not None # only works if the tokenizer has a schema - ): - completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] - else: - contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - completions = [[{"role": "assistant", "content": content}] for content in contents] - else: - completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) - - # Extract tool calls from the completions - if self.tools: - tool_calls = [completion[0].get("tool_calls") for completion in completions] - idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] - tool_calls = [tool_calls[idx] for idx in idxs_with_tool] - tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere - tool_call_count = 0 - tool_failure_count = 0 - else: - idxs_with_tool = [] - tool_mask = None - + def _tool_call_loop(self, prompts, prompt_ids, completion_ids, completions, logprobs): # Tool execution loop: execute tools, then regenerate completions with tool results appended to the prompt + tool_calls = [completion[0].get("tool_calls") for completion in completions] + idxs_with_tool = [idx for idx, tool_call in enumerate(tool_calls) if tool_call] + tool_calls = [tool_calls[idx] for idx in idxs_with_tool] + tool_mask = [[1] * len(ids) for ids in completion_ids] # 0 for tool result tokens, 1 elsewhere + tool_call_count = 0 + tool_failure_count = 0 + while idxs_with_tool: prompt_completion_tools = [prompts[i] for i in idxs_with_tool] # select only prompts that need tool calls @@ -1528,6 +1501,7 @@ def _generate(self, prompts: list): idx_with_tool = idxs_with_tool[idx] tool_call_list = tool_calls[idx] prompt_completion_tool = prompt_completion_tools[idx] + # Append the last assistant message (which triggered tool_calls) to the prompt prompt_completion_tool.append(completions[idx_with_tool][-1]) for tool_call in tool_call_list: tool_call_count += 1 @@ -1560,6 +1534,7 @@ def _generate(self, prompts: list): tool_mask[idx_with_tool] += [1] * (len(ct) - len(tool_mask[idx_with_tool])) if logprobs is not None: logprobs[idx_with_tool] += [0.0] * (len(ct) - len(logprobs[idx_with_tool])) + # Keep only non-overlong items for further processing idxs_with_tool = [idx for idx, o in zip(idxs_with_tool, overlong, strict=True) if not o] prompt_completion_tools = [pct for pct, o in zip(prompt_completion_tools, overlong, strict=True) if not o] if not idxs_with_tool: @@ -1627,6 +1602,44 @@ def _generate(self, prompts: list): idxs_with_tool = [idx for idx, tool_call in zip(idxs_with_tool, tool_calls, strict=True) if tool_call] tool_calls = [tool_call for tool_call in tool_calls if tool_call] + return tool_mask, completions, completion_ids, logprobs, tool_call_count, tool_failure_count + + def _generate(self, prompts: list): + device = self.accelerator.device + mode = "train" if self.model.training else "eval" + + # Copy the prompts to avoid modifying the original list + prompts = copy.deepcopy(prompts) + + prompt_ids, completion_ids, logprobs, extra_fields = self._generate_single_turn(prompts) + + # Decode completions. It's important to use `parse_response` when possible, because it handles tool calls. + if is_conversational({"prompt": prompts[0]}): + if ( + Version(transformers.__version__) >= Version("5.0.0.dev0") # parse_response added in v5 + and isinstance(self.processing_class, PreTrainedTokenizerBase) # doesn't work with processors + and self.processing_class.response_schema is not None # only works if the tokenizer has a schema + ): + completions = [[parse_response(self.processing_class, ids)] for ids in completion_ids] + else: + contents = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + completions = [[{"role": "assistant", "content": content}] for content in contents] + else: + completions = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) + + # Extract tool calls from the completions and (possibly) execute them + if self.tools: + ( + tool_mask, + completions, + completion_ids, + logprobs, + tool_call_count, + tool_failure_count, + ) = self._tool_call_loop(prompts, prompt_ids, completion_ids, completions, logprobs) + else: + tool_mask = None + # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) From 3cbb28ee796721522e8b257d32cf00aeba896993 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 25 Nov 2025 16:35:14 +0100 Subject: [PATCH 147/153] Add example script --- examples/scripts/grpo_agent.py | 238 +++++++++++++++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100644 examples/scripts/grpo_agent.py diff --git a/examples/scripts/grpo_agent.py b/examples/scripts/grpo_agent.py new file mode 100644 index 00000000000..ea24b3dbfee --- /dev/null +++ b/examples/scripts/grpo_agent.py @@ -0,0 +1,238 @@ +# Copyright 2020-2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# /// script +# dependencies = [ +# "trl", +# "peft", +# "trackio", +# "kernels", +# ] +# /// + +""" +# Full training +``` +python grpo_agent.py \ + --model_name_or_path Qwen/Qwen3-1.7B \ + --output_dir grpo_biogrid_qwen_3g-1.7b \ + --push_to_hub True \ + --use_vllm True \ + --vllm_mode colocate \ + --vllm_enable_sleep_mode False \ + --max_completion_length 1024 \ + --report_to trackio \ + --log_completions True \ + --max_steps 200 +``` +""" + +import os +import sqlite3 +import signal +from contextlib import contextmanager +import textwrap +from datasets import load_dataset +from trl import ( + GRPOConfig, + GRPOTrainer, + ModelConfig, + ScriptArguments, + TrlParser, +) + +# Enable logging in a Hugging Face Space +os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") + +# ------------------------ +# Reward functions +# ------------------------ + + +def correctness_reward(completions, answer, **kwargs): # measures Yes/No answer correctness + rewards = [] + for completion, ans in zip(completions, answer): + guess = completion[-1]["content"].strip() + reward = 0.0 + + if "*Yes*" not in guess and "*No*" not in guess: + reward -= 0.2 + elif ("*Yes*" in guess and ans == "Yes") or ("*No*" in guess and ans == "No"): + reward += 0.5 + elif ("*Yes*" in guess and ans == "No") or ("*No*" in guess and ans == "Yes"): + reward -= 0.2 + rewards.append(reward) + + return rewards + + +def tool_usage_reward(completions, **kwargs): # rewards correct tool usage + rewards = [] + for completion in completions: + tool_used = False + reward = 0.0 + + for turn in completion: + if turn["role"] == "tool": + tool_used = True + if "error" in turn["content"]: + reward -= 0.3 + + if not tool_used: + reward -= 0.3 + elif reward == 0.0: + reward += 0.25 + + rewards.append(reward) + return rewards + + +def structure_reward(completions, **kwargs): # rewards proper assistant structure + rewards = [] + + for completion in completions: + has_call = False + has_response = False + has_other = False + + for turn in completion: + if turn.get("role") == "assistant" and turn.get("tool_calls"): + has_call = True + elif turn.get("role") == "tool": + has_response = True + else: + content = turn.get("content") + if content and content.strip() not in ["", ""]: + has_other = True + + reward = 0.0 + if has_call and has_response and has_other: + reward = 0.25 + elif has_call and has_response and not has_other: + reward = -0.15 + elif has_call and not has_response: + reward = -0.15 + + rewards.append(reward) + + return rewards + + +# ------------------------ +# Database tool function +# ------------------------ +class TimeoutError(Exception): + """Raised when a function call times out.""" + pass + +@contextmanager +def timeout(seconds): + """Context manager that raises TimeoutError if execution exceeds time limit.""" + def timeout_handler(signum, frame): + raise TimeoutError(f"Operation timed out after {seconds} seconds") + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(seconds) + try: + yield + finally: + signal.alarm(0) + +def query_biogrid(sql_command: str) -> list[tuple]: + """ + Execute a read-only SQL command on the BioGRID database. + + BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics. + + Args: + sql_command: The SQL command to execute. + + Returns: + A list of tuples containing the query results. + """ + with timeout(5): + conn = sqlite3.connect("file:biogrid.db?mode=ro", uri=True) + cursor = conn.cursor() + try: + cursor.execute(sql_command) + results = cursor.fetchall() + finally: + conn.close() + return results + + +# ------------------------ +# Dataset formatting +# ------------------------ +def format_example(example): + question = example["question"] + preamble = textwrap.dedent("""\ + You may use the BioGRID database to answer the question. Feel free to run exploratory SQL queries to familiarize yourself with the database structure if needed (e.g., `SELECT * FROM interactions LIMIT 1;` or `PRAGMA table_info(interactions);`). + Provide your final answer enclosed in stars, such as `*Yes*` or `*No*`. + Facts: + - The NCBI Taxonomy identifier for humans is taxid:9606 + """) + content = f"{preamble}\nQuestion: {question}" + prompt = [{"role": "user", "content": content}] + return {"prompt": prompt} + + +# ------------------------ +# Main +# ------------------------ +if __name__ == "__main__": + parser = TrlParser((ScriptArguments, GRPOConfig, ModelConfig)) + script_args, training_args, model_args = parser.parse_args_and_config() + + # ------------------------ + # Create DB + # ------------------------ + print("Creating biogrid.db...") + biogrid_dataset = load_dataset("qgallouedec/biogrid", split="train") + biogrid_dataset.to_sql("interactions", "sqlite:///biogrid.db", if_exists="replace") + print("biogrid.db created.") + + # ------------------------ + # Load and format dataset + # ------------------------ + dataset = load_dataset("qgallouedec/biogrid_qa", split="train") + dataset = dataset.map(format_example, remove_columns=["question"]) + + train_dataset = dataset + eval_dataset = None # No eval by default, can be added if needed + + training_args.chat_template_kwargs={"enable_thinking": False} + + # ------------------------ + # Initialize trainer + # ------------------------ + trainer = GRPOTrainer( + model=model_args.model_name_or_path, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + tools=[query_biogrid], + reward_funcs=[correctness_reward, tool_usage_reward, structure_reward], + args=training_args + ) + + # ------------------------ + # Train + # ------------------------ + trainer.train() + + # ------------------------ + # Save and push + # ------------------------ + trainer.save_model(training_args.output_dir) + if training_args.push_to_hub: + trainer.push_to_hub(dataset_name=script_args.dataset_name) From 6074ade65ccc0c32321ba8b37b6ab0adc523e143 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 25 Nov 2025 16:37:27 +0100 Subject: [PATCH 148/153] code quality --- examples/scripts/grpo_agent.py | 32 ++++++++++++++++++++------------ 1 file changed, 20 insertions(+), 12 deletions(-) diff --git a/examples/scripts/grpo_agent.py b/examples/scripts/grpo_agent.py index ea24b3dbfee..aaecdba3ff1 100644 --- a/examples/scripts/grpo_agent.py +++ b/examples/scripts/grpo_agent.py @@ -39,11 +39,13 @@ """ import os -import sqlite3 import signal -from contextlib import contextmanager +import sqlite3 import textwrap +from contextlib import contextmanager + from datasets import load_dataset + from trl import ( GRPOConfig, GRPOTrainer, @@ -52,6 +54,7 @@ TrlParser, ) + # Enable logging in a Hugging Face Space os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio") @@ -60,12 +63,12 @@ # ------------------------ -def correctness_reward(completions, answer, **kwargs): # measures Yes/No answer correctness +def correctness_reward(completions, answer, **kwargs): # measures Yes/No answer correctness rewards = [] - for completion, ans in zip(completions, answer): + for completion, ans in zip(completions, answer, strict=False): guess = completion[-1]["content"].strip() reward = 0.0 - + if "*Yes*" not in guess and "*No*" not in guess: reward -= 0.2 elif ("*Yes*" in guess and ans == "Yes") or ("*No*" in guess and ans == "No"): @@ -77,28 +80,28 @@ def correctness_reward(completions, answer, **kwargs): # measures Yes/No answer return rewards -def tool_usage_reward(completions, **kwargs): # rewards correct tool usage +def tool_usage_reward(completions, **kwargs): # rewards correct tool usage rewards = [] for completion in completions: tool_used = False reward = 0.0 - + for turn in completion: if turn["role"] == "tool": tool_used = True if "error" in turn["content"]: reward -= 0.3 - + if not tool_used: reward -= 0.3 elif reward == 0.0: reward += 0.25 - + rewards.append(reward) return rewards -def structure_reward(completions, **kwargs): # rewards proper assistant structure +def structure_reward(completions, **kwargs): # rewards proper assistant structure rewards = [] for completion in completions: @@ -134,13 +137,17 @@ def structure_reward(completions, **kwargs): # rewards proper assistant structur # ------------------------ class TimeoutError(Exception): """Raised when a function call times out.""" + pass + @contextmanager def timeout(seconds): """Context manager that raises TimeoutError if execution exceeds time limit.""" + def timeout_handler(signum, frame): raise TimeoutError(f"Operation timed out after {seconds} seconds") + signal.signal(signal.SIGALRM, timeout_handler) signal.alarm(seconds) try: @@ -148,6 +155,7 @@ def timeout_handler(signum, frame): finally: signal.alarm(0) + def query_biogrid(sql_command: str) -> list[tuple]: """ Execute a read-only SQL command on the BioGRID database. @@ -211,7 +219,7 @@ def format_example(example): train_dataset = dataset eval_dataset = None # No eval by default, can be added if needed - training_args.chat_template_kwargs={"enable_thinking": False} + training_args.chat_template_kwargs = {"enable_thinking": False} # ------------------------ # Initialize trainer @@ -222,7 +230,7 @@ def format_example(example): eval_dataset=eval_dataset, tools=[query_biogrid], reward_funcs=[correctness_reward, tool_usage_reward, structure_reward], - args=training_args + args=training_args, ) # ------------------------ From fc3d7594e0547032ef65bf26806649b6caa5e41d Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 25 Nov 2025 17:07:30 +0100 Subject: [PATCH 149/153] Update to more strict reward funcs --- examples/scripts/grpo_agent.py | 78 ++++++++++++++++++++++------------ 1 file changed, 50 insertions(+), 28 deletions(-) diff --git a/examples/scripts/grpo_agent.py b/examples/scripts/grpo_agent.py index aaecdba3ff1..0ad68606b9b 100644 --- a/examples/scripts/grpo_agent.py +++ b/examples/scripts/grpo_agent.py @@ -63,45 +63,63 @@ # ------------------------ -def correctness_reward(completions, answer, **kwargs): # measures Yes/No answer correctness +def correctness_reward(completions, answer, **kwargs): + """ + Reward Yes/No correctness. + Robust to extra spaces, punctuation, or surrounding stars. + """ rewards = [] for completion, ans in zip(completions, answer, strict=False): - guess = completion[-1]["content"].strip() + guess = completion[-1]["content"].strip().lower() + guess_clean = guess.replace("*", "").replace("`", "").strip() reward = 0.0 - if "*Yes*" not in guess and "*No*" not in guess: - reward -= 0.2 - elif ("*Yes*" in guess and ans == "Yes") or ("*No*" in guess and ans == "No"): - reward += 0.5 - elif ("*Yes*" in guess and ans == "No") or ("*No*" in guess and ans == "Yes"): - reward -= 0.2 - rewards.append(reward) + if guess_clean not in ["yes", "no"]: + reward -= 0.2 # didn't produce a valid Yes/No + elif guess_clean == ans.lower(): + reward += 0.5 # correct answer + else: + reward -= 0.2 # incorrect answer + rewards.append(reward) return rewards -def tool_usage_reward(completions, **kwargs): # rewards correct tool usage +def tool_usage_reward(completions, **kwargs): + """ + Reward proper tool usage. + Looks for assistant tool_calls and corresponding tool responses. + """ rewards = [] for completion in completions: - tool_used = False + tool_called = False + tool_response_ok = False reward = 0.0 for turn in completion: - if turn["role"] == "tool": - tool_used = True - if "error" in turn["content"]: - reward -= 0.3 - - if not tool_used: - reward -= 0.3 - elif reward == 0.0: - reward += 0.25 + if turn.get("role") == "assistant" and turn.get("tool_calls"): + tool_called = True + if turn.get("role") == "tool" and turn.get("content"): + tool_response_ok = True + if "error" in turn["content"].lower(): + reward -= 0.3 # penalize errors + + if tool_called and tool_response_ok: + reward += 0.25 # reward correct tool usage + elif not tool_called: + reward -= 0.3 # penalize missing tool call + elif tool_called and not tool_response_ok: + reward -= 0.2 # called tool but no response rewards.append(reward) return rewards -def structure_reward(completions, **kwargs): # rewards proper assistant structure +def structure_reward(completions, **kwargs): + """ + Reward proper assistant structure. + Encourages a logical sequence: tool call + response + optional extra content. + """ rewards = [] for completion in completions: @@ -110,22 +128,26 @@ def structure_reward(completions, **kwargs): # rewards proper assistant structu has_other = False for turn in completion: - if turn.get("role") == "assistant" and turn.get("tool_calls"): + role = turn.get("role") + if role == "assistant" and turn.get("tool_calls"): has_call = True - elif turn.get("role") == "tool": + elif role == "tool": has_response = True else: content = turn.get("content") if content and content.strip() not in ["", ""]: has_other = True - reward = 0.0 - if has_call and has_response and has_other: - reward = 0.25 - elif has_call and has_response and not has_other: - reward = -0.15 + # Reward sequences + if has_call and has_response: + if has_other: + reward = 0.25 + else: + reward = 0.1 # still positive even without extra text elif has_call and not has_response: reward = -0.15 + else: + reward = 0.0 # neutral if no call rewards.append(reward) From e37508dedf6ba5d772995d423c6fa839f1355cd7 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Tue, 25 Nov 2025 17:14:24 +0100 Subject: [PATCH 150/153] Update steps --- examples/scripts/grpo_agent.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/scripts/grpo_agent.py b/examples/scripts/grpo_agent.py index 0ad68606b9b..92a55dd3210 100644 --- a/examples/scripts/grpo_agent.py +++ b/examples/scripts/grpo_agent.py @@ -34,7 +34,7 @@ --max_completion_length 1024 \ --report_to trackio \ --log_completions True \ - --max_steps 200 + --max_steps 400 ``` """ From af749c1303f149df61e93912d614847b091b5186 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Quentin=20Gallou=C3=A9dec?= Date: Tue, 25 Nov 2025 19:05:07 +0000 Subject: [PATCH 151/153] Clarify token counting in reward metrics and adjust completion length calculation to exclude tool tokens --- docs/source/grpo_trainer.md | 14 +++++++------- trl/trainer/grpo_trainer.py | 5 ++++- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/docs/source/grpo_trainer.md b/docs/source/grpo_trainer.md index 77e43829aaa..e47d166d3e5 100644 --- a/docs/source/grpo_trainer.md +++ b/docs/source/grpo_trainer.md @@ -141,14 +141,14 @@ This constant is recommended to be the maximum completion length. To use this fo While training and evaluating, we record the following reward metrics: -- `num_tokens`: The total number of tokens processed so far, including both prompts and completions. +- `num_tokens`: The total number of tokens processed so far, including both prompts and completions. When using tools, only non-tool tokens are counted. - `step_time`: The average time (in seconds) taken per training step (including generation). -- `completions/mean_length`: The average length of generated completions. -- `completions/min_length`: The minimum length of generated completions. -- `completions/max_length`: The maximum length of generated completions. -- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS. -- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS. -- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS. +- `completions/mean_length`: The average length of generated completions. When using tools, only non-tool tokens are counted. +- `completions/min_length`: The minimum length of generated completions. When using tools, only non-tool tokens are counted. +- `completions/max_length`: The maximum length of generated completions. When using tools, only non-tool tokens are counted. +- `completions/mean_terminated_length`: The average length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted. +- `completions/min_terminated_length`: The minimum length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted. +- `completions/max_terminated_length`: The maximum length of generated completions that terminate with EOS. When using tools, only non-tool tokens are counted. - `completions/clipped_ratio`: The ratio of truncated (clipped) completions. - `reward/{reward_func_name}/mean`: The average reward from a specific reward function. - `reward/{reward_func_name}/std`: The standard deviation of the reward from a specific reward function. diff --git a/trl/trainer/grpo_trainer.py b/trl/trainer/grpo_trainer.py index 41f4422065b..ffb5e525e60 100644 --- a/trl/trainer/grpo_trainer.py +++ b/trl/trainer/grpo_trainer.py @@ -1642,7 +1642,10 @@ def _generate(self, prompts: list): # Get completion length per sequence, used for logging prompt_lengths = torch.tensor([len(ids) for ids in prompt_ids], device=device) - completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) + if tool_mask is not None: # count only non-tool tokens (tool_mask=1) + completion_lengths = torch.tensor([sum(mask) for mask in tool_mask], device=device) + else: + completion_lengths = torch.tensor([len(ids) for ids in completion_ids], device=device) agg_prompt_lengths = self.accelerator.gather(prompt_lengths) agg_completion_lengths = self.accelerator.gather(completion_lengths) total_prompt_tokens = agg_prompt_lengths.sum() From 988efc10405cc75dc872ed745ae0b99f02f0fd6d Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Thu, 27 Nov 2025 17:48:23 +0100 Subject: [PATCH 152/153] Updated example script with elaborated reward funcs --- examples/scripts/grpo_agent.py | 147 ++++++++++++++++++++++++--------- 1 file changed, 108 insertions(+), 39 deletions(-) diff --git a/examples/scripts/grpo_agent.py b/examples/scripts/grpo_agent.py index 92a55dd3210..df0b600e104 100644 --- a/examples/scripts/grpo_agent.py +++ b/examples/scripts/grpo_agent.py @@ -24,7 +24,7 @@ """ # Full training ``` -python grpo_agent.py \ +python examples/scripts/grpo_agent.py \ --model_name_or_path Qwen/Qwen3-1.7B \ --output_dir grpo_biogrid_qwen_3g-1.7b \ --push_to_hub True \ @@ -39,6 +39,7 @@ """ import os +import re import signal import sqlite3 import textwrap @@ -63,55 +64,98 @@ # ------------------------ -def correctness_reward(completions, answer, **kwargs): +def query_reward(completions, answer, **kwargs): """ - Reward Yes/No correctness. - Robust to extra spaces, punctuation, or surrounding stars. + Reward query strategy: + - Penalize more than 2 queries + - Penalize generic queries (LIMIT 1 / PRAGMA) + - Reward usage of WHERE + - Reward evidence supporting the final answer """ rewards = [] + for completion, ans in zip(completions, answer, strict=False): - guess = completion[-1]["content"].strip().lower() - guess_clean = guess.replace("*", "").replace("`", "").strip() reward = 0.0 + sql_queries = [] + tool_results = [] - if guess_clean not in ["yes", "no"]: - reward -= 0.2 # didn't produce a valid Yes/No - elif guess_clean == ans.lower(): - reward += 0.5 # correct answer + # collect all SQL queries and tool results + for turn in completion: + if turn.get("tool_calls"): + for call in turn["tool_calls"]: + sql = call["function"]["arguments"].get("sql_command", "").lower() + sql_queries.append(sql) + if turn.get("role") == "tool" and turn.get("content"): + tool_results.append(turn["content"]) + + # --- penalize too many queries --- + if len(sql_queries) > 3: + reward -= 1.5 + + # --- check query quality --- + where_count = 0 + for q in sql_queries: + if "limit 1" in q: + reward -= 1.0 + if " where " not in q: + reward -= 0.5 + else: + where_count += 1 + reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage + + # --- evidence check: do queries support the answer? --- + combined_results = [] + error_detected = False + + for res in tool_results: + if isinstance(res, dict) and "error" in res: + error_detected = True + elif isinstance(res, list): + combined_results.extend(res) + + # if error detected, penalize heavily + if error_detected: + reward -= 2.0 + elif len(sql_queries) == 0: + reward -= 1.5 else: - reward -= 0.2 # incorrect answer + has_hits = len(combined_results) > 0 + correct_answer = ans.lower() + if (has_hits and correct_answer == "yes") or (not has_hits and correct_answer == "no"): + reward += 2.0 + else: + reward -= 1.5 rewards.append(reward) + return rewards -def tool_usage_reward(completions, **kwargs): +def correctness_reward(completions, answer, **kwargs): """ - Reward proper tool usage. - Looks for assistant tool_calls and corresponding tool responses. + Reward Yes/No correctness. + Model must provide final answer enclosed in stars — *yes* or *no*. + Does not reward informal yes/no buried in text. """ rewards = [] - for completion in completions: - tool_called = False - tool_response_ok = False - reward = 0.0 + for completion, ans in zip(completions, answer, strict=False): + raw = completion[-1]["content"].lower() - for turn in completion: - if turn.get("role") == "assistant" and turn.get("tool_calls"): - tool_called = True - if turn.get("role") == "tool" and turn.get("content"): - tool_response_ok = True - if "error" in turn["content"].lower(): - reward -= 0.3 # penalize errors + # detect form *yes* or *no* + match = re.search(r"\*(yes|no)\*", raw) + guess = match.group(1) if match else None - if tool_called and tool_response_ok: - reward += 0.25 # reward correct tool usage - elif not tool_called: - reward -= 0.3 # penalize missing tool call - elif tool_called and not tool_response_ok: - reward -= 0.2 # called tool but no response + reward = 0.0 + + if guess is None: + reward -= 0.5 # invalid format + elif guess == ans.lower(): + reward += 0.6 # correct under required format + else: + reward -= 1.0 # wrong answer rewards.append(reward) + return rewards @@ -141,9 +185,9 @@ def structure_reward(completions, **kwargs): # Reward sequences if has_call and has_response: if has_other: - reward = 0.25 + reward = 0.1 else: - reward = 0.1 # still positive even without extra text + reward = 0.05 # still positive even without extra text elif has_call and not has_response: reward = -0.15 else: @@ -207,10 +251,23 @@ def query_biogrid(sql_command: str) -> list[tuple]: def format_example(example): question = example["question"] preamble = textwrap.dedent("""\ - You may use the BioGRID database to answer the question. Feel free to run exploratory SQL queries to familiarize yourself with the database structure if needed (e.g., `SELECT * FROM interactions LIMIT 1;` or `PRAGMA table_info(interactions);`). - Provide your final answer enclosed in stars, such as `*Yes*` or `*No*`. + You have access to the BioGRID SQLite database. + Use SQL queries to retrieve only the information needed to answer the question. + + Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`, + and each entry can contain multiple gene names or synonyms separated by '|', for example: + 'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...' + So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings. + + If the database schema is unclear or you are unsure about column names: + - First inspect the schema with `PRAGMA table_info(interactions);` + - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;` + + Otherwise, directly query the required data. + + Final answer must be enclosed in stars, e.g. *Yes* or *No*. Facts: - - The NCBI Taxonomy identifier for humans is taxid:9606 + - The NCBI Taxonomy identifier for humans is taxid:9606. """) content = f"{preamble}\nQuestion: {question}" prompt = [{"role": "user", "content": content}] @@ -228,14 +285,26 @@ def format_example(example): # Create DB # ------------------------ print("Creating biogrid.db...") + # Load dataset biogrid_dataset = load_dataset("qgallouedec/biogrid", split="train") - biogrid_dataset.to_sql("interactions", "sqlite:///biogrid.db", if_exists="replace") - print("biogrid.db created.") + df = biogrid_dataset.to_pandas() + + # Normalize column names: remove spaces, replace with underscores + df.columns = [c.replace(" ", "_") for c in df.columns] + conn = sqlite3.connect("biogrid.db") + try: + df.to_sql("interactions", conn, if_exists="replace", index=False) + print(f"biogrid.db created. Rows stored: {len(df)}") + finally: + conn.close() # ------------------------ # Load and format dataset # ------------------------ dataset = load_dataset("qgallouedec/biogrid_qa", split="train") + dataset = dataset.filter( + lambda example: example["question"].startswith("Does the gene ") + ) # keep only simple questions for example dataset = dataset.map(format_example, remove_columns=["question"]) train_dataset = dataset @@ -251,7 +320,7 @@ def format_example(example): train_dataset=train_dataset, eval_dataset=eval_dataset, tools=[query_biogrid], - reward_funcs=[correctness_reward, tool_usage_reward, structure_reward], + reward_funcs=[correctness_reward, structure_reward, query_reward], args=training_args, ) From ce7d60781f9a92bf61e96071c9e8f1512c265772 Mon Sep 17 00:00:00 2001 From: sergiopaniego Date: Mon, 1 Dec 2025 18:27:37 +0100 Subject: [PATCH 153/153] Add example notebook and update docs --- docs/source/example_overview.md | 1 + examples/notebooks/README.md | 1 + examples/notebooks/grpo_agent.ipynb | 610 ++++++++++++++++++++++++++++ 3 files changed, 612 insertions(+) create mode 100644 examples/notebooks/grpo_agent.ipynb diff --git a/docs/source/example_overview.md b/docs/source/example_overview.md index 887e1b6914a..b48277b9ad2 100644 --- a/docs/source/example_overview.md +++ b/docs/source/example_overview.md @@ -47,6 +47,7 @@ Scripts are maintained in the [`trl/scripts`](https://github.com/huggingface/trl | [`examples/scripts/evals/judge_tldr.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/evals/judge_tldr.py) | This script shows how to use [`experimental.judges.HfPairwiseJudge`] or [`experimental.judges.OpenAIPairwiseJudge`] to judge model generations. | | [`examples/scripts/gkd.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gkd.py) | This script shows how to use the [`experimental.gkd.GKDTrainer`] to fine-tune a model. | | [`trl/scripts/grpo.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model. | +| [`trl/scripts/grpo_agent.py`](https://github.com/huggingface/trl/blob/main/trl/scripts/grpo_agent.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a model to enable agentic usage. | | [`examples/scripts/grpo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/grpo_vlm.py) | This script shows how to use the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | | [`examples/scripts/gspo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune model for reasoning using the [AI-MO/NuminaMath-TIR](https://huggingface.co/datasets/AI-MO/NuminaMath-TIR) dataset. | | [`examples/scripts/gspo_vlm.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/gspo_vlm.py) | This script shows how to use GSPO via the [`GRPOTrainer`] to fine-tune a multimodal model for reasoning using the [lmms-lab/multimodal-open-r1-8k-verified](https://huggingface.co/datasets/lmms-lab/multimodal-open-r1-8k-verified) dataset. | diff --git a/examples/notebooks/README.md b/examples/notebooks/README.md index 5ce4aa36d95..0bf45d675a3 100644 --- a/examples/notebooks/README.md +++ b/examples/notebooks/README.md @@ -4,6 +4,7 @@ This directory contains a collection of Jupyter notebooks that demonstrate how t | Notebook | Description | Open in Colab | | --- | --- | --- | +| [`grpo_agent.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_agent.ipynb) | GRPO for agent training | Not available due to OOM with Colab GPUs | | [`openenv_wordle_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_wordle_grpo.ipynb) | GRPO to play Worldle on an OpenEnv environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb) | | [`sft_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | Supervised Fine-Tuning (SFT) using QLoRA on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | | [`sft_qwen_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_qwen_vl.ipynb) | Supervised Fine-Tuning (SFT) Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_qwen_vl.ipynb) | diff --git a/examples/notebooks/grpo_agent.ipynb b/examples/notebooks/grpo_agent.ipynb new file mode 100644 index 00000000000..be6d5ff7629 --- /dev/null +++ b/examples/notebooks/grpo_agent.ipynb @@ -0,0 +1,610 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b", + "metadata": { + "id": "63ceecbc-87ad-4ad3-a317-f49267ffc93b" + }, + "source": [ + "# Agent Training with GRPO using TRL\n", + "\n", + "![trl banner](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/trl_banner_dark.png)\n", + "\n", + "\n", + "With [**Transformers Reinforcement Learning (TRL)**](https://github.com/huggingface/trl), you can train a language model to act as an **agent**. One that learns to reason, interact with external tools, and improve through reinforcement.\n", + "\n", + "- [TRL GitHub Repository](https://github.com/huggingface/trl) — star us to support the project! \n", + "- [Official TRL Examples](https://huggingface.co/docs/trl/example_overview) \n", + "- [Community Tutorials](https://huggingface.co/docs/trl/community_tutorials)\n", + "- [OpenEnv](https://github.com/meta-pytorch/OpenEnv)\n", + "\n", + "\n", + "TRL supports training agents that can use external tools as part of their decision process. \n", + "In this notebook, the agent has access to the **BioGRID database**, which it can query using **read-only SQL commands** to retrieve biological interaction data. The model learns when and how to use tools based on rewards.\n", + "\n", + "We'll fine-tune a model using GRPO (Group Relative Policy Optimization) via TRL. The agent will:\n", + "\n", + "1. Generate tool call to query the database if needed.\n", + "2. Receive the tool response and add it it to the context.\n", + "3. Learn to improve its tool usage and general capabilities over time through reward signals.\n", + "\n", + "## Install dependencies\n", + "\n", + "We'll start by installing **TRL**, which automatically includes the main dependencies like **Transformers**. \n", + "We'll also install **trackio** (for logging and monitoring training runs), **vLLM** (for efficient generation), and **jmespath** (needed for the tools capabilities)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b4812fbf-3f61-481e-9a64-95277eada9c9", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install -Uq git+https://github.com/huggingface/trl.git git+https://github.com/huggingface/transformers.git trackio vllm==0.11.2 jmespath" + ] + }, + { + "cell_type": "markdown", + "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148", + "metadata": { + "id": "ede8e566-a1b5-460f-9fe8-a6010bc56148" + }, + "source": [ + "### Log in to Hugging Face\n", + "\n", + "Log in to your **Hugging Face** account to save your fine-tuned model, track your experiment results directly on the Hub or access gated models. You can find your **access token** on your [account settings page](https://huggingface.co/settings/tokens)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "21756ac0-78b2-495d-8137-28dfa9faae6a", + "metadata": {}, + "outputs": [], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "markdown", + "id": "KVGklspLYlmz", + "metadata": { + "id": "KVGklspLYlmz" + }, + "source": [ + "## Create the database for the tool\n", + "\n", + "For this example, we will use the [BioGRID database](https://thebiogrid.org/), a curated resource containing **protein, genetic, and chemical interaction data**. We've already compiled and uploaded it to the Hub at [qgallouedec/biogrid](https://huggingface.co/datasets/qgallouedec/biogrid). The dataset is loaded and converted into an sqlite database.\n", + "\n", + "> 💡 We remove spaces in the column names to easen the model work. In real-world deployments, you may keep your original column names and rely on the agent to reason about them. Here, we simplify the schema to make training smoother." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rRzPMhfXBLkF", + "metadata": {}, + "outputs": [], + "source": [ + "import sqlite3\n", + "from datasets import load_dataset\n", + "\n", + "# Load dataset\n", + "biogrid_dataset = load_dataset(\"qgallouedec/biogrid\", split=\"train\")\n", + "df = biogrid_dataset.to_pandas()\n", + "\n", + "# Normalize column names: remove spaces, replace with underscores\n", + "df.columns = [c.replace(\" \", \"_\") for c in df.columns]\n", + "\n", + "# Save to SQLite\n", + "conn = sqlite3.connect(\"biogrid.db\")\n", + "try:\n", + " df.to_sql(\"interactions\", conn, if_exists=\"replace\", index=False)\n", + " print(f\"biogrid.db created. Rows stored: {len(df)}\")\n", + "finally:\n", + " conn.close()" + ] + }, + { + "cell_type": "markdown", + "id": "pSSGvLbmZyC2", + "metadata": { + "id": "pSSGvLbmZyC2" + }, + "source": [ + "## Load the QA dataset\n", + "\n", + "The training objective is to fine-tune a model to answer gene-related questions. The model should learn to use the database query tool to retrieve factual information when needed.\n", + "\n", + "We'll define a formatting function for each sample, adding instructions about the database and how to call it. The model must answer with **yes** or **no**. Let's implement the `format_example` function.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "asrv7LbaD71C", + "metadata": {}, + "outputs": [], + "source": [ + "import textwrap\n", + "\n", + "def format_example(example):\n", + " question = example[\"question\"]\n", + " preamble = textwrap.dedent(\"\"\"\\\n", + " You have access to the BioGRID SQLite database.\n", + " Use SQL queries to retrieve only the information needed to answer the question.\n", + "\n", + " Genes may appear in the database in columns `Alt_IDs_Interactor_A` `Alt_IDs_Interactor_B`, `Aliases_Interactor_A` and `Aliases_Interactor_B`,\n", + " and each entry can contain multiple gene names or synonyms separated by '|', for example:\n", + " 'entrez gene/locuslink:JNKK(gene name synonym)|entrez gene/locuslink:MAPKK4(gene name synonym)|...'\n", + " So a gene like 'JNKK' or 'MAPKK4' may appear inside one of these strings.\n", + "\n", + " If the database schema is unclear or you are unsure about column names:\n", + " - First inspect the schema with `PRAGMA table_info(interactions);`\n", + " - Or preview a few rows with `SELECT * FROM interactions LIMIT 1;`\n", + "\n", + " Otherwise, directly query the required data.\n", + "\n", + " Final answer must be enclosed in stars, e.g. *Yes* or *No*.\n", + " Facts:\n", + " - The NCBI Taxonomy identifier for humans is taxid:9606.\n", + " \"\"\")\n", + " content = f\"{preamble}\\nQuestion: {question}\"\n", + " prompt = [{\"role\": \"user\", \"content\": content}]\n", + " return {\"prompt\": prompt}" + ] + }, + { + "cell_type": "markdown", + "id": "UMnHXYZla_EO", + "metadata": { + "id": "UMnHXYZla_EO" + }, + "source": [ + "Now, let's load the database and call the previous function. \n", + "For simplicity, we will only use questions that start with **“Does the gene…”**. \n", + "In a real use case, the full dataset can be used.\n", + "\n", + "The QA dataset is available on the [Hub](https://huggingface.co/datasets/qgallouedec/biogrid_qa)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "jEs12KqwDnVl", + "metadata": {}, + "outputs": [], + "source": [ + "dataset = load_dataset(\"qgallouedec/biogrid_qa\", split=\"train\")\n", + "dataset = dataset.filter(\n", + " lambda example: example[\"question\"].startswith(\"Does the gene \")\n", + ") # keep only simple questions for example\n", + "dataset = dataset.map(format_example, remove_columns=[\"question\"])\n", + "\n", + "train_dataset = dataset\n", + "eval_dataset = None # No eval by default, can be added if needed" + ] + }, + { + "cell_type": "markdown", + "id": "m4GRjbHycM5L", + "metadata": { + "id": "m4GRjbHycM5L" + }, + "source": [ + "## Create tool for the agent\n", + "\n", + "The `query_biogrid` function is the tool the model will use to query the database and retrieve factual information. \n", + "Each tool must be a standard Python function with **type-hinted arguments and return types**, and a **Google-style docstring** describing its purpose, parameters, and return value." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "nLMH7hahGTyO", + "metadata": {}, + "outputs": [], + "source": [ + "from contextlib import contextmanager\n", + "import signal\n", + "\n", + "@contextmanager\n", + "def timeout(seconds):\n", + " \"\"\"Context manager that raises TimeoutError if execution exceeds time limit.\"\"\"\n", + "\n", + " def timeout_handler(signum, frame):\n", + " raise TimeoutError(f\"Operation timed out after {seconds} seconds\")\n", + "\n", + " signal.signal(signal.SIGALRM, timeout_handler)\n", + " signal.alarm(seconds)\n", + " try:\n", + " yield\n", + " finally:\n", + " signal.alarm(0)\n", + "\n", + "def query_biogrid(sql_command: str) -> list[tuple]:\n", + " \"\"\"\n", + " Execute a read-only SQL command on the BioGRID database.\n", + "\n", + " BioGRID is a curated biological database that compiles protein, genetic, and chemical interactions from multiple organisms. It provides researchers with experimentally verified interaction data to support studies in systems biology and functional genomics.\n", + "\n", + " Args:\n", + " sql_command: The SQL command to execute.\n", + "\n", + " Returns:\n", + " A list of tuples containing the query results.\n", + " \"\"\"\n", + " with timeout(5):\n", + " conn = sqlite3.connect(\"file:biogrid.db?mode=ro\", uri=True)\n", + " cursor = conn.cursor()\n", + " try:\n", + " cursor.execute(sql_command)\n", + " results = cursor.fetchall()\n", + " finally:\n", + " conn.close()\n", + " return results" + ] + }, + { + "cell_type": "markdown", + "id": "GiHtooTwci3B", + "metadata": { + "id": "GiHtooTwci3B" + }, + "source": [ + "## Define reward functions\n", + "\n", + "To guide the agent during training, we define a few simple reward functions:\n", + "\n", + "- **`query_reward`**: evaluates the model’s query strategy — penalizes more than two queries, penalizes generic database scans, and rewards use of `WHERE` and evidence supporting the final answer.\n", + "- **`correctness_reward`**: rewards Yes/No predictions that match the expected answer.\n", + "- **`structure_reward`**: rewards a proper assistant structure (tool call → response → optional explanation).\n", + "\n", + "Each function returns a list of floats used by the **GRPOTrainer** during optimization. \n", + "Combined, they encourage effective tool use and factual answers." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "sXyqC6cJGe3L", + "metadata": {}, + "outputs": [], + "source": [ + "import re\n", + "\n", + "def query_reward(completions, answer, **kwargs):\n", + " \"\"\"\n", + " Reward query strategy:\n", + " - Penalize more than 2 queries\n", + " - Penalize generic queries (LIMIT 1 / PRAGMA)\n", + " - Reward usage of WHERE\n", + " - Reward evidence supporting the final answer\n", + " \"\"\"\n", + " rewards = []\n", + "\n", + " for completion, ans in zip(completions, answer, strict=False):\n", + " reward = 0.0\n", + " sql_queries = []\n", + " tool_results = []\n", + "\n", + " # collect all SQL queries and tool results\n", + " for turn in completion:\n", + " if turn.get(\"tool_calls\"):\n", + " for call in turn[\"tool_calls\"]:\n", + " sql = call[\"function\"][\"arguments\"].get(\"sql_command\", \"\").lower()\n", + " sql_queries.append(sql)\n", + " if turn.get(\"role\") == \"tool\" and turn.get(\"content\"):\n", + " tool_results.append(turn[\"content\"])\n", + "\n", + " # --- penalize too many queries ---\n", + " if len(sql_queries) > 3:\n", + " reward -= 1.5\n", + "\n", + " # --- check query quality ---\n", + " where_count = 0\n", + " for q in sql_queries:\n", + " if \"limit 1\" in q:\n", + " reward -= 1.0\n", + " if \" where \" not in q:\n", + " reward -= 0.5\n", + " else:\n", + " where_count += 1\n", + " reward += min(where_count, 3) * 0.4 # small bonus for WHERE usage\n", + "\n", + " # --- evidence check: do queries support the answer? ---\n", + " combined_results = []\n", + " error_detected = False\n", + "\n", + " for res in tool_results:\n", + " if isinstance(res, dict) and \"error\" in res:\n", + " error_detected = True\n", + " elif isinstance(res, list):\n", + " combined_results.extend(res)\n", + "\n", + " # if error detected, penalize heavily\n", + " if error_detected:\n", + " reward -= 2.0\n", + " elif len(sql_queries) == 0:\n", + " reward -= 1.5\n", + " else:\n", + " has_hits = len(combined_results) > 0\n", + " correct_answer = ans.lower()\n", + " if (has_hits and correct_answer == \"yes\") or (not has_hits and correct_answer == \"no\"):\n", + " reward += 2.0\n", + " else:\n", + " reward -= 1.5\n", + "\n", + " rewards.append(reward)\n", + "\n", + " return rewards\n", + "\n", + "\n", + "def correctness_reward(completions, answer, **kwargs):\n", + " \"\"\"\n", + " Reward Yes/No correctness.\n", + " Model must provide final answer enclosed in stars — *yes* or *no*.\n", + " Does not reward informal yes/no buried in text.\n", + " \"\"\"\n", + " rewards = []\n", + " for completion, ans in zip(completions, answer, strict=False):\n", + " raw = completion[-1][\"content\"].lower()\n", + "\n", + " # detect form *yes* or *no*\n", + " match = re.search(r\"\\*(yes|no)\\*\", raw)\n", + " guess = match.group(1) if match else None\n", + "\n", + " reward = 0.0\n", + "\n", + " if guess is None:\n", + " reward -= 0.5 # invalid format\n", + " elif guess == ans.lower():\n", + " reward += 0.6 # correct under required format\n", + " else:\n", + " reward -= 1.0 # wrong answer\n", + "\n", + " rewards.append(reward)\n", + "\n", + " return rewards\n", + "\n", + "\n", + "def structure_reward(completions, **kwargs):\n", + " \"\"\"\n", + " Reward proper assistant structure.\n", + " Encourages a logical sequence: tool call + response + optional extra content.\n", + " \"\"\"\n", + " rewards = []\n", + "\n", + " for completion in completions:\n", + " has_call = False\n", + " has_response = False\n", + " has_other = False\n", + "\n", + " for turn in completion:\n", + " role = turn.get(\"role\")\n", + " if role == \"assistant\" and turn.get(\"tool_calls\"):\n", + " has_call = True\n", + " elif role == \"tool\":\n", + " has_response = True\n", + " else:\n", + " content = turn.get(\"content\")\n", + " if content and content.strip() not in [\"\", \"\"]:\n", + " has_other = True\n", + "\n", + " # Reward sequences\n", + " if has_call and has_response:\n", + " if has_other:\n", + " reward = 0.1\n", + " else:\n", + " reward = 0.05 # still positive even without extra text\n", + " elif has_call and not has_response:\n", + " reward = -0.15\n", + " else:\n", + " reward = 0.0 # neutral if no call\n", + "\n", + " rewards.append(reward)\n", + "\n", + " return rewards\n" + ] + }, + { + "cell_type": "markdown", + "id": "zcgkrKtTb4T9", + "metadata": { + "id": "zcgkrKtTb4T9" + }, + "source": [ + "## Set GRPO Config\n", + "\n", + "Next, we define the **GRPOConfig**, which controls the main training parameters. \n", + "This configuration specifies how the model interacts with **vLLM**, manages memory, and logs results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "t4ifJsNLElIN", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import GRPOConfig\n", + "\n", + "output_dir = \"grpo_biogrid_qwen_3g-1.7b\"\n", + "\n", + "grpo_config = GRPOConfig(\n", + " # Training schedule / optimization\n", + " max_steps=400, # Max number of training steps\n", + " chat_template_kwargs = {\"enable_thinking\": False}, # Disable thinking to reduce token generation\n", + "\n", + " # GRPO configuration\n", + " max_completion_length = 1024, # Maximum tokens generated per model response\n", + "\n", + " # vLLM configuration\n", + " use_vllm = True, # Enable vLLM for faster inference during rollouts\n", + " vllm_mode = \"colocate\", # Run vLLM in colocate mode (same process as training)\n", + " vllm_enable_sleep_mode=False,\n", + "\n", + " # Logging / reporting\n", + " output_dir = output_dir, # Directory for checkpoints and logs\n", + " report_to=\"trackio\", # Experiment tracking tool (integrates with HF Spaces)\n", + " trackio_space_id = output_dir, # HF Space where experiment tracking will be saved\n", + " save_steps = 10, # Interval for saving checkpoints\n", + " log_completions = True,\n", + "\n", + " # Memory optimization\n", + " gradient_checkpointing = True, # Enable activation recomputation to save memory\n", + " gradient_checkpointing_kwargs = {\"use_reentrant\": False}, # Use non-reentrant checkpointing\n", + "\n", + " # Hub integration\n", + " push_to_hub = True, # Set True to automatically push model to Hugging Face Hub\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "34I-Q2MJuf42", + "metadata": { + "id": "34I-Q2MJuf42" + }, + "source": [ + "## Create `GRPOTrainer` and Start Training\n", + "\n", + "Next, we initialize the **`GRPOTrainer`**, which handles the full reinforcement learning loop.\n", + "\n", + "It receives the model name, reward functions, tool(s), and dataset defined earlier. \n", + "\n", + "Finally, we call `trainer.train()` to begin fine-tuning, allowing the model to learn how to query the database effectively through iterative feedback." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "IysntAUOFvRn", + "metadata": {}, + "outputs": [], + "source": [ + "from trl import GRPOTrainer\n", + "\n", + "model_name=\"Qwen/Qwen3-1.7B\"\n", + "\n", + "trainer = GRPOTrainer(\n", + " model=model_name,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " tools=[query_biogrid],\n", + " reward_funcs=[correctness_reward, structure_reward, query_reward],\n", + " args=grpo_config,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "r_qJ5UwLuzCG", + "metadata": { + "id": "r_qJ5UwLuzCG" + }, + "source": [ + "Show memory stats before training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "DusT8JUaGmA6", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "gpu_stats = torch.cuda.get_device_properties(0)\n", + "start_gpu_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "max_memory = round(gpu_stats.total_memory / 1024 / 1024 / 1024, 3)\n", + "\n", + "print(f\"GPU = {gpu_stats.name}. Max memory = {max_memory} GB.\")\n", + "print(f\"{start_gpu_memory} GB of memory reserved.\")" + ] + }, + { + "cell_type": "markdown", + "id": "OTPkiz3fu0lp", + "metadata": { + "id": "OTPkiz3fu0lp" + }, + "source": [ + "And train!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "NwI3buPOFMFk", + "metadata": {}, + "outputs": [], + "source": [ + "trainer_stats = trainer.train()" + ] + }, + { + "cell_type": "markdown", + "id": "ITnLBLcTu2-p", + "metadata": { + "id": "ITnLBLcTu2-p" + }, + "source": [ + "Show memory stats after training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ftek6m4-GncK", + "metadata": {}, + "outputs": [], + "source": [ + "used_memory = round(torch.cuda.max_memory_reserved() / 1024 / 1024 / 1024, 3)\n", + "used_memory_for_lora = round(used_memory - start_gpu_memory, 3)\n", + "used_percentage = round(used_memory / max_memory * 100, 3)\n", + "lora_percentage = round(used_memory_for_lora / max_memory * 100, 3)\n", + "\n", + "print(f\"{trainer_stats.metrics['train_runtime']} seconds used for training.\")\n", + "print(f\"{round(trainer_stats.metrics['train_runtime']/60, 2)} minutes used for training.\")\n", + "print(f\"Peak reserved memory = {used_memory} GB.\")\n", + "print(f\"Peak reserved memory for training = {used_memory_for_lora} GB.\")\n", + "print(f\"Peak reserved memory % of max memory = {used_percentage} %.\")\n", + "print(f\"Peak reserved memory for training % of max memory = {lora_percentage} %.\")" + ] + }, + { + "cell_type": "markdown", + "id": "O6LAwznKu7mc", + "metadata": { + "id": "O6LAwznKu7mc" + }, + "source": [ + "Let's save the trained model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "idVgnNS1MWPr", + "metadata": {}, + "outputs": [], + "source": [ + "trainer.save_model(output_dir)\n", + "trainer.push_to_hub()" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}