@@ -72,6 +72,9 @@ def create_finetune_request(
     train_on_inputs: bool | Literal["auto"] | None = None,
     training_method: str = "sft",
     dpo_beta: float | None = None,
+    dpo_normalize_logratios_by_length: bool = False,
+    rpo_alpha: float | None = None,
+    simpo_gamma: float | None = None,
     from_checkpoint: str | None = None,
 ) -> FinetuneRequest:
     if model is not None and from_checkpoint is not None:
@@ -182,6 +185,21 @@ def create_finetune_request(
 
     if dpo_beta is not None and training_method != "dpo":
         raise ValueError("dpo_beta is only supported for DPO training")
+    if dpo_normalize_logratios_by_length and training_method != "dpo":
+        raise ValueError(
+            "dpo_normalize_logratios_by_length=True is only supported for DPO training"
+        )
+    if rpo_alpha is not None:
+        if training_method != "dpo":
+            raise ValueError("rpo_alpha is only supported for DPO training")
+        if not rpo_alpha >= 0.0:
+            raise ValueError(f"rpo_alpha should be non-negative (got {rpo_alpha})")
+
+    if simpo_gamma is not None:
+        if training_method != "dpo":
+            raise ValueError("simpo_gamma is only supported for DPO training")
+        if not simpo_gamma >= 0.0:
+            raise ValueError(f"simpo_gamma should be non-negative (got {simpo_gamma})")
 
     lr_scheduler: FinetuneLRScheduler
     if lr_scheduler_type == "cosine":
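The guards above reject the DPO-only parameters whenever `training_method` is not `"dpo"`, and additionally require `rpo_alpha` and `simpo_gamma` to be non-negative. A minimal standalone sketch of that behavior (illustrative only, not SDK code):

```python
# Illustrative sketch of the guard logic above; not part of the SDK itself.
def validate_dpo_args(training_method: str, rpo_alpha=None, simpo_gamma=None) -> None:
    for name, value in (("rpo_alpha", rpo_alpha), ("simpo_gamma", simpo_gamma)):
        if value is None:
            continue
        if training_method != "dpo":
            raise ValueError(f"{name} is only supported for DPO training")
        if not value >= 0.0:  # also rejects NaN, unlike `value < 0.0`
            raise ValueError(f"{name} should be non-negative (got {value})")

validate_dpo_args("dpo", rpo_alpha=0.5, simpo_gamma=0.5)   # passes
# validate_dpo_args("sft", rpo_alpha=0.5)                  # would raise ValueError
```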
@@ -204,7 +222,24 @@ def create_finetune_request(
     if training_method == "sft":
         training_method_cls = TrainingMethodSFT(train_on_inputs=train_on_inputs)
     elif training_method == "dpo":
-        training_method_cls = TrainingMethodDPO(dpo_beta=dpo_beta)
+        if simpo_gamma is not None and simpo_gamma > 0:
+            dpo_reference_free = True
+            dpo_normalize_logratios_by_length = True
+            rprint(
+                f"Parameter simpo_gamma was set to {simpo_gamma}. "
+                "SimPO training detected. Reference logits will not be used "
+                "and length normalization of log-probabilities will be enabled."
+            )
+        else:
+            dpo_reference_free = False
+
+        training_method_cls = TrainingMethodDPO(
+            dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            dpo_reference_free=dpo_reference_free,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
+        )
 
     finetune_request = FinetuneRequest(
         model=model,
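The branch above is the core behavioral change: a positive `simpo_gamma` turns the DPO request into a SimPO-style run, forcing reference-free training and length normalization of log-ratios regardless of what was passed for `dpo_normalize_logratios_by_length`. A small sketch of that decision logic, kept separate from the SDK code:

```python
# Sketch mirroring the SimPO switch above (illustrative only).
def resolve_dpo_flags(simpo_gamma, dpo_normalize_logratios_by_length):
    """Return (dpo_reference_free, dpo_normalize_logratios_by_length)."""
    if simpo_gamma is not None and simpo_gamma > 0:
        # SimPO-style run: no reference model, length-normalized log-probabilities.
        return True, True
    return False, dpo_normalize_logratios_by_length

assert resolve_dpo_flags(0.5, False) == (True, True)     # simpo_gamma > 0 overrides the flag
assert resolve_dpo_flags(None, False) == (False, False)  # plain DPO keeps the passed value
assert resolve_dpo_flags(None, True) == (False, True)
```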
@@ -302,6 +337,9 @@ def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -353,6 +391,9 @@ def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool, optional): Whether or not to normalize log-ratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -405,6 +446,9 @@ def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
 
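With the forwarding above, the public `create()` call simply accepts the three new parameters. A usage sketch following the usual Together SDK pattern; the file ID, model name, and parameter values below are placeholders rather than recommended settings:

```python
from together import Together

client = Together()  # assumes TOGETHER_API_KEY is set in the environment

job = client.fine_tuning.create(
    training_file="file-abc123",          # hypothetical uploaded preference dataset
    model="meta-llama/Meta-Llama-3-8B",   # hypothetical base model
    training_method="dpo",
    dpo_beta=0.1,
    rpo_alpha=0.5,    # weights an additional NLL term in the DPO loss
    simpo_gamma=0.5,  # > 0 switches to reference-free, length-normalized (SimPO-style) training
)
print(job.id)
```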
@@ -714,6 +758,9 @@ async def create(
         train_on_inputs: bool | Literal["auto"] | None = None,
         training_method: str = "sft",
         dpo_beta: float | None = None,
+        dpo_normalize_logratios_by_length: bool = False,
+        rpo_alpha: float | None = None,
+        simpo_gamma: float | None = None,
         from_checkpoint: str | None = None,
     ) -> FinetuneResponse:
         """
@@ -765,6 +812,9 @@ async def create(
             training_method (str, optional): Training method. Defaults to "sft".
                 Supported methods: "sft", "dpo".
             dpo_beta (float, optional): DPO beta parameter. Defaults to None.
+            dpo_normalize_logratios_by_length (bool, optional): Whether or not to normalize log-ratios by sample length. Defaults to False.
+            rpo_alpha (float, optional): RPO alpha parameter of DPO training to include NLL in the loss. Defaults to None.
+            simpo_gamma (float, optional): SimPO gamma parameter. Defaults to None.
             from_checkpoint (str, optional): The checkpoint identifier to continue training from a previous fine-tuning job.
                 The format: {$JOB_ID/$OUTPUT_MODEL_NAME}:{$STEP}.
                 The step value is optional, without it the final checkpoint will be used.
@@ -817,6 +867,9 @@ async def create(
             train_on_inputs=train_on_inputs,
             training_method=training_method,
             dpo_beta=dpo_beta,
+            dpo_normalize_logratios_by_length=dpo_normalize_logratios_by_length,
+            rpo_alpha=rpo_alpha,
+            simpo_gamma=simpo_gamma,
             from_checkpoint=from_checkpoint,
         )
 
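The async resource mirrors the same signature, so an equivalent sketch with `AsyncTogether` (again with placeholder identifiers):

```python
import asyncio

from together import AsyncTogether


async def main() -> None:
    client = AsyncTogether()  # assumes TOGETHER_API_KEY is set in the environment
    job = await client.fine_tuning.create(
        training_file="file-abc123",         # hypothetical file ID
        model="meta-llama/Meta-Llama-3-8B",  # hypothetical base model
        training_method="dpo",
        simpo_gamma=0.5,  # SimPO-style preference optimization
    )
    print(job.id)


asyncio.run(main())
```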