From 5fcffaac14036805f5e11a4b639a033d47dbc288 Mon Sep 17 00:00:00 2001
From: Rafael Pardinas
Date: Wed, 23 Oct 2024 18:22:45 +0100
Subject: [PATCH 1/4] Initial scaffolding

---
 fast_llm/models/grpo/__init__.py    |  0
 fast_llm/models/grpo/config.py      | 63 +++++++++++++++++++++++++++++
 fast_llm/models/grpo/data.py        | 38 +++++++++++++++++
 fast_llm/models/grpo/head.py        |  6 +++
 fast_llm/models/grpo/huggingface.py | 18 +++++++++
 fast_llm/models/grpo/model.py       | 52 ++++++++++++++++++++++++
 fast_llm/models/grpo/readme.md      | 39 ++++++++++++++++++
 fast_llm/models/grpo/trainer.py     | 21 ++++++++++
 8 files changed, 237 insertions(+)
 create mode 100644 fast_llm/models/grpo/__init__.py
 create mode 100644 fast_llm/models/grpo/config.py
 create mode 100644 fast_llm/models/grpo/data.py
 create mode 100644 fast_llm/models/grpo/head.py
 create mode 100644 fast_llm/models/grpo/huggingface.py
 create mode 100644 fast_llm/models/grpo/model.py
 create mode 100644 fast_llm/models/grpo/readme.md
 create mode 100644 fast_llm/models/grpo/trainer.py

diff --git a/fast_llm/models/grpo/__init__.py b/fast_llm/models/grpo/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/fast_llm/models/grpo/config.py b/fast_llm/models/grpo/config.py
new file mode 100644
index 00000000..53fa895d
--- /dev/null
+++ b/fast_llm/models/grpo/config.py
@@ -0,0 +1,63 @@
+from fast_llm.config import FieldUpdate, config_class
+from fast_llm.data.config import DataConfig
+from fast_llm.models.gpt.config import (
+    GPTArchitectureConfig,
+    GPTBaseModelConfig,
+    GPTModelConfig,
+    GPTTrainerConfig,
+    PretrainedGPTModelConfig,
+)
+
+
+@config_class()
+class GRPODataConfig(DataConfig):
+    # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything.
+    pass
+
+
+@config_class()
+class GRPOArchitectureConfig(GPTArchitectureConfig):
+    # TODO: Add custom base model architecture config parameters, if any.
+    pass
+
+
+@config_class()
+class GRPOBaseModelConfig(GPTBaseModelConfig, GRPOArchitectureConfig):
+    # TODO: Add custom other base model config parameters, if any.
+    architecture_cls = GRPOArchitectureConfig
+
+
+@config_class()
+class GRPOModelConfig(GPTModelConfig):
+    # TODO: Add custom model config parameters, if any (typically none).
+    base_model: GRPOBaseModelConfig = FieldUpdate(default_factory=GRPOBaseModelConfig)
+
+    @classmethod
+    def get_model_class(cls):
+        from fast_llm.models.grpo.model import GRPOModel
+
+        return GRPOModel
+
+    @classmethod
+    def get_huggingface_model_class(cls):
+        from fast_llm.models.grpo.huggingface import HuggingfaceGRPOModelForCausalLM
+
+        return HuggingfaceGRPOModelForCausalLM
+
+
+@config_class()
+class PretrainedGRPOModelConfig(PretrainedGPTModelConfig):
+    model: GRPOModelConfig = FieldUpdate(default_factory=GRPOModelConfig)
+
+
+@config_class()
+class GRPOTrainerConfig(PretrainedGRPOModelConfig, GPTTrainerConfig):
+    # TODO: Add custom trainer config parameters, if any (typically none).
+
+    data: GRPODataConfig = FieldUpdate(default_factory=GRPODataConfig)
+
+    @classmethod
+    def get_trainer_class(cls):
+        from fast_llm.models.grpo.trainer import GRPOTrainer
+
+        return GRPOTrainer
diff --git a/fast_llm/models/grpo/data.py b/fast_llm/models/grpo/data.py
new file mode 100644
index 00000000..2c133af7
--- /dev/null
+++ b/fast_llm/models/grpo/data.py
@@ -0,0 +1,38 @@
+from fast_llm.data.data import Data
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.models.grpo.config import GRPODataConfig
+
+
+class GRPOData(Data):
+    # TODO: If needed, inherit from AbstractData instead and re-implement everything.
+    def __init__(
+        self,
+        config: GRPODataConfig,
+        distributed_config: DistributedConfig,
+        vocab_size: int,
+        max_sequence_length: int,
+    ):
+        # TODO: Adjust or reimplement.
+        super().__init__(config, distributed_config, vocab_size, max_sequence_length)
+
+    def setup(self, distributed, samples_per_phase):
+        # TODO: Adjust or reimplement.
+        return super().setup(distributed, samples_per_phase)
+
+    def get_iterator(
+        self,
+        batch_config,
+        phase,
+        *,
+        consumed_samples,
+        num_workers,
+        prefetch_factor=None,
+    ):
+        # TODO: Adjust or reimplement.
+        return super().get_iterator(
+            batch_config,
+            phase,
+            consumed_samples=consumed_samples,
+            num_workers=num_workers,
+            prefetch_factor=prefetch_factor,
+        )
diff --git a/fast_llm/models/grpo/head.py b/fast_llm/models/grpo/head.py
new file mode 100644
index 00000000..786e3692
--- /dev/null
+++ b/fast_llm/models/grpo/head.py
@@ -0,0 +1,6 @@
+from fast_llm.layers.language_model.head import LanguageModelHead
+
+
+class GRPOHead(LanguageModelHead):
+    # TODO: Implement custom parts
+    pass
diff --git a/fast_llm/models/grpo/huggingface.py b/fast_llm/models/grpo/huggingface.py
new file mode 100644
index 00000000..99a7cb20
--- /dev/null
+++ b/fast_llm/models/grpo/huggingface.py
@@ -0,0 +1,18 @@
+from fast_llm.models.grpo.config import GRPOModelConfig
+from fast_llm.models.grpo.model import GRPOModel
+from fast_llm.models.gpt.huggingface import HuggingfaceGPTModelConfig, HuggingfaceGPTModelForCausalLM
+
+
+class HuggingfaceGRPOModelConfig(HuggingfaceGPTModelConfig):
+    model_type = "fast_llm_grpo"
+    model_config_class = GRPOModelConfig
+    fast_llm_config: GRPOModelConfig
+
+
+class HuggingfaceGRPOModelForCausalLM(HuggingfaceGPTModelForCausalLM):
+    # TODO: Implement changes in huggingface interface, if any.
+    # Ex.: Return predictions instead of logits.
+    config_class = HuggingfaceGRPOModelConfig
+    config: HuggingfaceGRPOModelConfig
+    model_class = GRPOModel
+    _fast_llm_model: GRPOModel
diff --git a/fast_llm/models/grpo/model.py b/fast_llm/models/grpo/model.py
new file mode 100644
index 00000000..83318ade
--- /dev/null
+++ b/fast_llm/models/grpo/model.py
@@ -0,0 +1,52 @@
+from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
+from fast_llm.layers.transformer.transformer import TransformerLayer
+from fast_llm.models.grpo.config import GRPOBaseModelConfig, GRPOModelConfig
+from fast_llm.models.grpo.head import GRPOHead
+from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
+
+
+class GRPOBaseModel(GPTBaseModel):
+    _config: GRPOBaseModelConfig
+    config_cls = GRPOBaseModelConfig
+
+    def __init__(
+        self,
+        config: GRPOBaseModelConfig,
+        distributed_config: DistributedConfig,
+    ):
+        # TODO: Implement / update.
+        super().__init__(config, distributed_config)
+
+    def get_layers(self):
+        # TODO: Adjust as needed.
+        return [
+            LanguageModelEmbedding(self._config, self._tensor_space),
+            *[
+                TransformerLayer(
+                    self._config.transformer,
+                    self._tensor_space,
+                    layer_index=i + 1,
+                )
+                for i in range(self._config.transformer.num_layers)
+            ],
+            GRPOHead(self._config, self._tensor_space),
+        ]
+
+    def preprocess_meta(self, input_, phase):
+        # TODO: Adjust or reimplement.
+        return super().preprocess_meta(input_, phase)
+
+    def preprocess(self, batch, preprocessed_meta=None, *, phase, iteration, metrics=None):
+        # TODO: Adjust or reimplement.
+        return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics)
+
+    @property
+    def loss_defs(self):
+        # TODO: Adjust or reimplement.
+        return super().loss_defs
+
+
+class GRPOModel(GPTModel):
+    config_class = GRPOModelConfig
+    base_model_class = GRPOBaseModel
diff --git a/fast_llm/models/grpo/readme.md b/fast_llm/models/grpo/readme.md
new file mode 100644
index 00000000..bb3330a3
--- /dev/null
+++ b/fast_llm/models/grpo/readme.md
@@ -0,0 +1,39 @@
+# Custom model template
+
+The "custom" model is a template for customized training of a GPT-style model,
+for example to fine-tune it for a particular task.
+This is typically done as follows:
+
+1. Create a copy of the `custom` model and rename it appropriately, ex. `my_model`, `MyModelTrainer`, etc.
+2. If necessary, adjust the base classes to inherit from more abstract classes or from another model,
+ex. `MyModelData(AbstractData)` to re-implement data processing from scratch.
+3. Add custom configuration fields in `config.py`.
+4. Adapt or re-implement the data loading scheme in `MyModelData`.
+5. Adapt or re-implement the preprocessing scheme in `MyModelBaseModel`.
+6. Adapt or re-implement the model head, ex. change the task and/or add a custom loss.
+7. If needed, adapt the huggingface interface to return outputs for the desired task.
+8. Apply other changes as needed.
+9. Add the new model to the registry (`fast_llm/models/auto.py`) so it can be used through the CLI.
+10. Run training with the new model, ex. `fast-llm train my_model [...]`.
+
+
+## Preprocessing variables and kwargs
+
+To pass additional parameters to the model during preprocessing, ex. a target for the loss or a runtime parameter,
+simply add them to the returned `kwargs`.
+Those kwargs will be passed directly to the `forward` method of each layer and can be used as needed.
+
+In some cases, it may be desirable to modify the `kwargs` inside a layer,
+for example to pass additional data to other layers or to the backward pass.
+This is possible with certain caveats (see the sketch below):
+* There is no direct support for autograd. Detaching tensors is recommended to prevent memory leaks.
+* Such modifications may be incompatible with pipeline parallelism,
+as the data will not be transferred to pipeline-parallel devices.
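+
+As a minimal sketch of this pattern (the layer class and kwarg names below are
+illustrative only, not part of this repo), a layer reading and extending
+`kwargs` might look like:
+
+```python
+import torch
+
+
+class MyTargetLayer(torch.nn.Module):
+    def forward(self, input_: torch.Tensor, kwargs: dict) -> torch.Tensor:
+        # Read a value that `preprocess` added to `kwargs` (hypothetical key).
+        target = kwargs["my_target"]
+        # Pass extra data to later layers or the backward pass; detach it,
+        # since autograd is not supported through `kwargs`.
+        kwargs["my_hidden_copy"] = input_.detach()
+        return input_ - target
+```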
+
+
+## Disclaimer
+
+Model customization is a work in progress.
+Some abstractions may be missing or poorly implemented,
+and some methods and variables may be hard-coded or very difficult to override.
+We intend to address these issues in the future, but this will most likely involve some breaking changes to the interface.
diff --git a/fast_llm/models/grpo/trainer.py b/fast_llm/models/grpo/trainer.py
new file mode 100644
index 00000000..9151012f
--- /dev/null
+++ b/fast_llm/models/grpo/trainer.py
@@ -0,0 +1,21 @@
+from fast_llm.models.grpo.config import GRPOTrainerConfig
+from fast_llm.models.grpo.data import GRPOData
+from fast_llm.models.grpo.model import GRPOModel
+from fast_llm.models.gpt.trainer import GPTTrainer
+
+
+class GRPOTrainer(GPTTrainer):
+    # TODO: Implement changes in the training loop (or tflops computation), if any (typically none).
+    _abstract = False
+    _config: GRPOTrainerConfig
+    config_class = GRPOTrainerConfig
+    model_class = GRPOModel
+
+    def _get_data(self):
+        # TODO: Adjust signature if needed.
+        return GRPOData(
+            config=self._config.data,
+            distributed_config=self._config.distributed,
+            vocab_size=self._config.base_model.vocab_size,
+            max_sequence_length=self._config.batch.sequence_length,
+        )

From 0562ee2f0b3982b64dc4859625a757478ed62820 Mon Sep 17 00:00:00 2001
From: Rafael Pardinas
Date: Wed, 23 Oct 2024 18:23:43 +0100
Subject: [PATCH 2/4] Add model names

---
 fast_llm/models/auto.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/fast_llm/models/auto.py b/fast_llm/models/auto.py
index 2cafa0fa..25c58e75 100644
--- a/fast_llm/models/auto.py
+++ b/fast_llm/models/auto.py
@@ -1,4 +1,4 @@
-from fast_llm.models.custom.config import CustomModelConfig, CustomTrainerConfig
+from fast_llm.models.grpo.config import GRPOModelConfig, GRPOTrainerConfig
 from fast_llm.models.gpt.config import GPTModelConfig, GPTTrainerConfig
 from fast_llm.utils import Registry
 
@@ -6,7 +6,7 @@
     "Model",
     {
         "gpt": GPTModelConfig,
-        "gpt_custom": CustomModelConfig,
+        "grpo": GRPOModelConfig,
     },
 )
 
@@ -14,6 +14,6 @@
     "Model",
     {
         "gpt": GPTTrainerConfig,
-        "gpt_custom": CustomTrainerConfig,
+        "grpo": GRPOTrainerConfig,
     },
 )

From 74e7932bf816550b530f67be1f1a2576625add7f Mon Sep 17 00:00:00 2001
From: Rafael Pardinas
Date: Thu, 24 Oct 2024 13:48:12 +0100
Subject: [PATCH 3/4] Check correct layers

---
 fast_llm/models/grpo/model.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/fast_llm/models/grpo/model.py b/fast_llm/models/grpo/model.py
index 83318ade..d532c1ec 100644
--- a/fast_llm/models/grpo/model.py
+++ b/fast_llm/models/grpo/model.py
@@ -1,8 +1,8 @@
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
+from fast_llm.layers.language_model.head import LanguageModelHead
 from fast_llm.layers.transformer.transformer import TransformerLayer
 from fast_llm.models.grpo.config import GRPOBaseModelConfig, GRPOModelConfig
-from fast_llm.models.grpo.head import GRPOHead
 from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
 
 
@@ -12,11 +12,12 @@ class GRPOBaseModel(GPTBaseModel):
 
     def __init__(
         self,
-        config: GRPOBaseModelConfig,
+        config: GRPOModelConfig,
         distributed_config: DistributedConfig,
     ):
-        # TODO: Implement / update.
         super().__init__(config, distributed_config)
+        assert self._config.transformer.use_rotary_position_embeddings
+        assert not self._config.use_absolute_position_embeddings
 
     def get_layers(self):
         # TODO: Adjust as needed.
@@ -30,7 +31,7 @@ def get_layers(self):
                 )
                 for i in range(self._config.transformer.num_layers)
             ],
-            GRPOHead(self._config, self._tensor_space),
+            LanguageModelHead(self._config, self._tensor_space),
         ]
 
     def preprocess_meta(self, input_, phase):

From d8eb073683fcd5cb358168a16b9215c571ee84f2 Mon Sep 17 00:00:00 2001
From: Rafael Pardinas
Date: Thu, 14 Nov 2024 17:37:53 +0000
Subject: [PATCH 4/4] Fix state

---
 fast_llm/data/data.py          |  2 +-
 fast_llm/models/grpo/config.py | 11 ++++-
 fast_llm/models/grpo/data.py   | 66 +++++++++++++++++++++++++----
 fast_llm/models/grpo/head.py   | 77 ++++++++++++++++++++++++++++++++--
 fast_llm/models/grpo/model.py  | 46 +++++++++++++++-----
 5 files changed, 179 insertions(+), 23 deletions(-)

diff --git a/fast_llm/data/data.py b/fast_llm/data/data.py
index e58b62c4..8098aba7 100644
--- a/fast_llm/data/data.py
+++ b/fast_llm/data/data.py
@@ -163,7 +163,7 @@ def setup(self, distributed: Distributed, samples_per_phase: dict[PhaseType, int
                 data_sample_warn_time_ms=self._config.data_sample_warn_time_ms,
             )
         )
-        for phase, datasets in self._sampled_datasets.items()
+        for phase, datasets in self._sampled_datasets.items()  # check data/dataset.py
     }
 
     def get_iterator(
diff --git a/fast_llm/models/grpo/config.py b/fast_llm/models/grpo/config.py
index 53fa895d..186d95d6 100644
--- a/fast_llm/models/grpo/config.py
+++ b/fast_llm/models/grpo/config.py
@@ -1,4 +1,4 @@
-from fast_llm.config import FieldUpdate, config_class
+from fast_llm.config import Field, FieldUpdate, config_class
 from fast_llm.data.config import DataConfig
 from fast_llm.models.gpt.config import (
     GPTArchitectureConfig,
@@ -9,6 +9,14 @@
 )
 
 
+@config_class()
+class GRPOConfig:
+    epsilon: float = Field(default=0.2, desc="PPO clipping parameter")
+    kl_coef: float = Field(default=0.1, desc="KL divergence coefficient")
+    ratio_threshold: float = Field(default=1.5, desc="Early stopping ratio threshold")
+    use_advantages: bool = Field(default=True, desc="Use advantages instead of raw rewards")
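+
+    # Hypothetical YAML override for the fields above (the `base_model.grpo`
+    # nesting is assumed from the `grpo` field added to GRPOBaseModelConfig
+    # below; not verified against this repo):
+    #   base_model:
+    #     grpo:
+    #       epsilon: 0.1
+    #       kl_coef: 0.05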
+
+
 @config_class()
 class GRPODataConfig(DataConfig):
     # TODO: If needed, inherit from AbstractDataConfig instead and re-implement everything.
     pass
@@ -25,6 +33,7 @@ class GRPOBaseModelConfig(GPTBaseModelConfig, GRPOArchitectureConfig):
     # TODO: Add custom other base model config parameters, if any.
     architecture_cls = GRPOArchitectureConfig
+    grpo: GRPOConfig = Field(default_factory=GRPOConfig, desc="GRPO specific configuration")
 
 
 @config_class()
diff --git a/fast_llm/models/grpo/data.py b/fast_llm/models/grpo/data.py
index 2c133af7..4154dc8b 100644
--- a/fast_llm/models/grpo/data.py
+++ b/fast_llm/models/grpo/data.py
@@ -1,10 +1,48 @@
-from fast_llm.data.data import Data
+import json
+
+import torch
+
+from fast_llm.data.data import Data
 from fast_llm.engine.distributed.config import DistributedConfig
 from fast_llm.models.grpo.config import GRPODataConfig
+from fast_llm.data.dataset import BlendedDataset, SampledDataset
+
+
+class GRPODataset(SampledDataset):
+    """Dataset wrapper that adds the GRPO-specific fields (rewards, advantages, etc.)."""
+
+    def __init__(self, base_dataset: SampledDataset, data_path: str):
+        self.base_dataset = base_dataset
+        self.data_path = data_path
+
+        # Load the JSONL data, one JSON object per line.
+        self.data = []
+        with open(data_path, "r") as f:
+            for line in f:
+                self.data.append(json.loads(line))
+
+    def __len__(self):
+        return len(self.base_dataset)
+
+    def __getitem__(self, idx):
+        item = self.base_dataset[idx]
+        data_item = self.data[idx]
+
+        # Extract the fields from the JSONL data.
+        batch = {
+            "input_ids": item,  # Original input tokens.
+            "rewards": torch.tensor(data_item["reward"]),
+            "old_logprobs": torch.tensor(data_item["logprobs"]),  # Log-probs from the previous iteration.
+            "ref_logprobs": torch.tensor(data_item["ref_logprobs"]),
+        }
+
+        # Compute advantages if not provided in the data. Rewards are used as
+        # advantages here; proper advantage estimation may be preferable.
+        batch["advantages"] = batch["rewards"].clone()
+
+        return batch
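+
+
+# A made-up example of the JSONL record layout `GRPODataset.__getitem__`
+# expects (keys match the code above; values are illustrative only):
+#   {"reward": 0.5, "logprobs": [-1.2, -0.8], "ref_logprobs": [-1.1, -0.7]}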
+
+
 class GRPOData(Data):
-    # TODO: If needed, inherit from AbstractData instead and re-implement everything.
     def __init__(
         self,
         config: GRPODataConfig,
         distributed_config: DistributedConfig,
         vocab_size: int,
         max_sequence_length: int,
     ):
-        # TODO: Adjust or reimplement.
         super().__init__(config, distributed_config, vocab_size, max_sequence_length)
 
     def setup(self, distributed, samples_per_phase):
-        # TODO: Adjust or reimplement.
-        return super().setup(distributed, samples_per_phase)
+        # Set up the base data infrastructure.
+        super().setup(distributed, samples_per_phase)
+
+        # Wrap each dataset with the GRPO-specific functionality.
+        for phase in self._blended_datasets:
+            if isinstance(self._blended_datasets[phase], BlendedDataset):
+                # Blended dataset: wrap each underlying dataset and store it back.
+                for i, dataset in enumerate(self._blended_datasets[phase].datasets):
+                    self._blended_datasets[phase].datasets[i] = GRPODataset(
+                        dataset,
+                        data_path=self._dataset_prefixes[f"dataset_{i}"],
+                    )
+            else:
+                # Single dataset case.
+                self._blended_datasets[phase] = GRPODataset(
+                    self._blended_datasets[phase],
+                    data_path=next(iter(self._dataset_prefixes.values())),
+                )
 
     def get_iterator(
         self,
         batch_config,
         phase,
         *,
         consumed_samples,
         num_workers,
         prefetch_factor=None,
     ):
-        # TODO: Adjust or reimplement.
         return super().get_iterator(
             batch_config,
             phase,
             consumed_samples=consumed_samples,
             num_workers=num_workers,
             prefetch_factor=prefetch_factor,
         )
diff --git a/fast_llm/models/grpo/head.py b/fast_llm/models/grpo/head.py
index 786e3692..6408753b 100644
--- a/fast_llm/models/grpo/head.py
+++ b/fast_llm/models/grpo/head.py
@@ -1,6 +1,77 @@
+import torch
+import torch.nn.functional as F
+
 from fast_llm.layers.language_model.head import LanguageModelHead
+from fast_llm.layers.language_model.config import LanguageModelLossNames
+from fast_llm.models.grpo.config import GRPOConfig
+
+
+class GRPOHead(LanguageModelHead):
+    def masked_mean(self, values: torch.Tensor, masks: torch.Tensor) -> torch.Tensor:
+        """Calculate the mean of values with masks applied."""
+        return (values * masks).sum() / (masks.sum() + 1e-8)
-
-
-class GRPOHead(LanguageModelHead):
-    # TODO: Implement custom parts
-    pass
+
+    def compute_grpo_loss(
+        self,
+        logits: torch.Tensor,
+        labels: torch.Tensor,
+        rewards: torch.Tensor,
+        advantages: torch.Tensor,
+        ref_logprobs: torch.Tensor,
+        old_logprobs: torch.Tensor,
+        config: GRPOConfig,
+    ) -> torch.Tensor:
+        # Positions labeled -100 are ignored; shift to align with next-token predictions.
+        masks = labels != -100
+        masks = masks[:, 1:]
+
+        # Clamp the -100 padding labels to a valid index for `gather`;
+        # those positions are excluded by `masks` below.
+        new_log_probs = torch.gather(
+            F.log_softmax(logits[:, :-1, :], dim=-1),
+            dim=2,
+            index=labels[:, 1:].clamp(min=0).unsqueeze(2),
+        ).squeeze(2)
+
+        # Surrogate loss calculation.
+        log_ratio_new_old = new_log_probs - old_logprobs
+        ratio_new_old = torch.exp(log_ratio_new_old)
+        weights = advantages if config.use_advantages else rewards
+
+        surr1 = ratio_new_old * weights
+        clamped_ratio = torch.clamp(
+            ratio_new_old,
+            1 - config.epsilon,
+            1 + config.epsilon,
+        )
+        surr2 = clamped_ratio * weights
+        surrogate_loss = torch.min(surr1, surr2)
+
+        # KL divergence approximation.
+        log_ratio_ref_new = ref_logprobs - new_log_probs
+        approx_kl = torch.exp(log_ratio_ref_new) - log_ratio_ref_new - 1
+
+        # Final loss computation.
+        loss = -self.masked_mean(
+            surrogate_loss - config.kl_coef * approx_kl,
+            masks,
+        )
+
+        # Early stopping based on the ratio threshold.
+        if self.masked_mean(ratio_new_old, masks) > config.ratio_threshold:
+            loss = loss * 0
+
+        return loss
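+
+    # A sketch of the math implemented by `compute_grpo_loss` above
+    # (notation is ours, not from this repo):
+    #   r = exp(log_pi_new - log_pi_old)                      # probability ratio
+    #   L_surr = min(r * A, clamp(r, 1 - eps, 1 + eps) * A)   # clipped surrogate
+    #   kl_approx = exp(log_pi_ref - log_pi_new) - (log_pi_ref - log_pi_new) - 1
+    #   loss = -masked_mean(L_surr - kl_coef * kl_approx)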
+
+    def forward(self, input_: torch.Tensor, kwargs: dict):
+        # Regular language model forward pass.
+        output = super().forward(input_, kwargs)
+
+        # If we have GRPO inputs, compute the GRPO loss.
+        if all(k in kwargs for k in ["rewards", "advantages", "ref_logprobs", "old_logprobs"]):
+            grpo_loss = self.compute_grpo_loss(
+                logits=kwargs["logits"],
+                labels=kwargs["labels"],
+                rewards=kwargs["rewards"],
+                advantages=kwargs["advantages"],
+                ref_logprobs=kwargs["ref_logprobs"],
+                old_logprobs=kwargs["old_logprobs"],
+                config=kwargs["grpo_config"],
+            )
+            kwargs[LanguageModelLossNames.grpo_loss] = grpo_loss
+
+        return output
diff --git a/fast_llm/models/grpo/model.py b/fast_llm/models/grpo/model.py
index d532c1ec..935ab746 100644
--- a/fast_llm/models/grpo/model.py
+++ b/fast_llm/models/grpo/model.py
@@ -1,7 +1,7 @@
-from fast_llm.engine.distributed.config import DistributedConfig
+from fast_llm.engine.distributed.config import DistributedConfig, PhaseType
 from fast_llm.layers.language_model.embedding import LanguageModelEmbedding
-from fast_llm.layers.language_model.head import LanguageModelHead
 from fast_llm.layers.transformer.transformer import TransformerLayer
+from fast_llm.models.grpo.head import GRPOHead
 from fast_llm.models.grpo.config import GRPOBaseModelConfig, GRPOModelConfig
 from fast_llm.models.gpt.model import GPTBaseModel, GPTModel
 
@@ -20,7 +20,6 @@ def __init__(
         assert not self._config.use_absolute_position_embeddings
 
     def get_layers(self):
-        # TODO: Adjust as needed.
         return [
             LanguageModelEmbedding(self._config, self._tensor_space),
             *[
                 TransformerLayer(
                     self._config.transformer,
                     self._tensor_space,
                     layer_index=i + 1,
                 )
                 for i in range(self._config.transformer.num_layers)
             ],
-            LanguageModelHead(self._config, self._tensor_space),
+            GRPOHead(self._config, self._tensor_space),  # Use our custom head
         ]
 
-    def preprocess_meta(self, input_, phase):
-        # TODO: Adjust or reimplement.
-        return super().preprocess_meta(input_, phase)
-
-    def preprocess(self, batch, preprocessed_meta=None, *, phase, iteration, metrics=None):
-        # TODO: Adjust or reimplement.
-        return super().preprocess(batch, preprocessed_meta, phase=phase, iteration=iteration, metrics=metrics)
+    def preprocess(
+        self,
+        batch: dict,
+        preprocessed_meta=None,
+        *,
+        phase: PhaseType,
+        iteration: int,
+        metrics=None,
+    ):
+        # Extract GRPO specific inputs.
+        grpo_inputs = {
+            "rewards": batch.pop("rewards")[:, 1:],
+            "advantages": batch.pop("advantages")[:, 1:],
+            "ref_logprobs": batch.pop("ref_logprobs")[:, 1:],
+            "old_logprobs": batch.pop("old_logprobs")[:, 1:],
+            "grpo_config": self._config.grpo,
+        }
+
+        # Process the remaining inputs using the parent class.
+        preprocessed = super().preprocess(
+            batch["input_ids"],
+            preprocessed_meta,
+            phase=phase,
+            iteration=iteration,
+            metrics=metrics,
+        )
+
+        # Add GRPO inputs to the kwargs passed to each layer.
+        for tokens, kwargs in preprocessed:
+            kwargs.update(grpo_inputs)
+
+        return preprocessed
 
     @property
     def loss_defs(self):