 from __future__ import annotations

 from pathlib import Path
+from typing import Literal

 from rich import print as rprint

     FinetuneListEvents,
     FinetuneRequest,
     FinetuneResponse,
+    FinetuneTrainingLimits,
     FullTrainingType,
     LoRATrainingType,
     TogetherClient,
     TogetherRequest,
     TrainingType,
 )
 from together.types.finetune import DownloadCheckpointType
-from together.utils import log_warn, normalize_key
+from together.utils import log_warn_once, normalize_key


 class FineTuning:
@@ -36,16 +38,17 @@ def create(
         validation_file: str | None = "",
         n_evals: int | None = 0,
         n_checkpoints: int | None = 1,
-        batch_size: int | None = 16,
+        batch_size: int | Literal["max"] = "max",
         learning_rate: float | None = 0.00001,
         lora: bool = False,
-        lora_r: int | None = 8,
+        lora_r: int | None = None,
         lora_dropout: float | None = 0,
-        lora_alpha: float | None = 8,
+        lora_alpha: float | None = None,
         lora_trainable_modules: str | None = "all-linear",
         suffix: str | None = None,
         wandb_api_key: str | None = None,
         verbose: bool = False,
+        model_limits: FinetuneTrainingLimits | None = None,
     ) -> FinetuneResponse:
         """
         Method to initiate a fine-tuning job
@@ -58,7 +61,7 @@ def create(
             n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
             n_checkpoints (int, optional): Number of checkpoints to save during fine-tuning.
                 Defaults to 1.
-            batch_size (int, optional): Batch size for fine-tuning. Defaults to 32.
+            batch_size (int or "max", optional): Batch size for fine-tuning. Defaults to "max".
             learning_rate (float, optional): Learning rate multiplier to use for training.
                 Defaults to 0.00001.
             lora (bool, optional): Whether to use LoRA adapters. Defaults to False.
@@ -72,24 +75,59 @@ def create(
                 Defaults to None.
             verbose (bool, optional): whether to print the job parameters before submitting a request.
                 Defaults to False.
+            model_limits (FinetuneTrainingLimits, optional): Limits for the fine-tuning hyperparameters of the selected model.
+                Defaults to None.

         Returns:
             FinetuneResponse: Object containing information about fine-tuning job.
         """

+        if batch_size == "max":
+            log_warn_once(
+                "Starting from together>=1.3.0, "
+                "the default batch size is set to the maximum allowed value for each model."
+            )
+
         requestor = api_requestor.APIRequestor(
             client=self._client,
         )

+        if model_limits is None:
+            model_limits = self.get_model_limits(model=model)
+
         training_type: TrainingType = FullTrainingType()
         if lora:
+            if model_limits.lora_training is None:
+                raise ValueError(
+                    "LoRA adapters are not supported for the selected model."
+                )
+            lora_r = (
+                lora_r if lora_r is not None else model_limits.lora_training.max_rank
+            )
+            lora_alpha = lora_alpha if lora_alpha is not None else lora_r * 2
             training_type = LoRATrainingType(
                 lora_r=lora_r,
                 lora_alpha=lora_alpha,
                 lora_dropout=lora_dropout,
                 lora_trainable_modules=lora_trainable_modules,
             )

+            batch_size = (
+                batch_size
+                if batch_size != "max"
+                else model_limits.lora_training.max_batch_size
+            )
+        else:
+            if model_limits.full_training is None:
+                raise ValueError(
+                    "Full training is not supported for the selected model."
+                )
+            batch_size = (
+                batch_size
+                if batch_size != "max"
+                else model_limits.full_training.max_batch_size
+            )
+
         finetune_request = FinetuneRequest(
             model=model,
             training_file=training_file,
@@ -121,12 +159,6 @@ def create(

         assert isinstance(response, TogetherResponse)

-        # TODO: Remove after next LoRA default change
-        log_warn(
-            "Some of the jobs run _directly_ from the together-python library might be trained using LoRA adapters. "
-            "The version range when this change occurred is from 1.2.3 to 1.2.6."
-        )
-
         return FinetuneResponse(**response.data)

     def list(self) -> FinetuneList:
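A minimal usage sketch of the new defaults (not part of the diff), assuming the resource is exposed as client.fine_tuning on a Together client and that TOGETHER_API_KEY is set; the model name and file ID below are illustrative placeholders:

from together import Together

client = Together()

# With batch_size="max" and lora_r/lora_alpha left as None, create() fetches
# FinetuneTrainingLimits for the model and resolves
# lora_r = lora_training.max_rank, lora_alpha = 2 * lora_r, and
# batch_size = lora_training.max_batch_size
# (or full_training.max_batch_size when lora=False).
job = client.fine_tuning.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference",  # illustrative model name
    training_file="file-abc123",  # illustrative file ID
    lora=True,
)
print(job.id)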
@@ -305,6 +337,34 @@ def download(
             size=file_size,
         )

+    def get_model_limits(self, *, model: str) -> FinetuneTrainingLimits:
+        """
+        Requests training limits for a specific model
+
+        Args:
+            model (str): Name of the model to get limits for
+
+        Returns:
+            FinetuneTrainingLimits: Object containing training limits for the model
+        """
+
+        requestor = api_requestor.APIRequestor(
+            client=self._client,
+        )
+
+        model_limits_response, _, _ = requestor.request(
+            options=TogetherRequest(
+                method="GET",
+                url="fine-tunes/models/limits",
+                params={"model_name": model},
+            ),
+            stream=False,
+        )
+
+        model_limits = FinetuneTrainingLimits(**model_limits_response.data)
+
+        return model_limits
+

 class AsyncFineTuning:
     def __init__(self, client: TogetherClient) -> None:
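The limits endpoint can also be queried directly; a short sketch under the same client.fine_tuning assumption, again with an illustrative model name:

limits = client.fine_tuning.get_model_limits(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference"  # illustrative
)

# These are the fields create() consults above.
if limits.lora_training is not None:
    print("max LoRA rank:", limits.lora_training.max_rank)
    print("max LoRA batch size:", limits.lora_training.max_batch_size)
if limits.full_training is not None:
    print("max full-training batch size:", limits.full_training.max_batch_size)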
@@ -493,3 +553,31 @@ async def download(
493553 "AsyncFineTuning.download not implemented. "
494554 "Please use FineTuning.download function instead."
495555 )
556+
557+ async def get_model_limits (self , * , model : str ) -> FinetuneTrainingLimits :
558+ """
559+ Requests training limits for a specific model
560+
561+ Args:
562+ model_name (str): Name of the model to get limits for
563+
564+ Returns:
565+ FinetuneTrainingLimits: Object containing training limits for the model
566+ """
567+
568+ requestor = api_requestor .APIRequestor (
569+ client = self ._client ,
570+ )
571+
572+ model_limits_response , _ , _ = await requestor .arequest (
573+ options = TogetherRequest (
574+ method = "GET" ,
575+ url = "fine-tunes/models/limits" ,
576+ params = {"model" : model },
577+ ),
578+ stream = False ,
579+ )
580+
581+ model_limits = FinetuneTrainingLimits (** model_limits_response .data )
582+
583+ return model_limits
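The async variant mirrors the sync call; a sketch assuming an AsyncTogether client that exposes the same fine_tuning resource:

import asyncio

from together import AsyncTogether


async def main() -> None:
    async_client = AsyncTogether()
    limits = await async_client.fine_tuning.get_model_limits(
        model="meta-llama/Meta-Llama-3.1-8B-Instruct-Reference"  # illustrative
    )
    print(limits)


asyncio.run(main())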