1414from together .utils import finetune_price_to_dollars , log_warn , parse_timestamp
1515
1616
# Prompt shown before launching a fine-tuning job; skipped when the user
# passes -y/--confirm. Displayed via click.confirm() in the `create` command.
# NOTE(review): this file was recovered from a rendered diff; stray spaces that
# the extraction inserted after the embedded "\n" escapes ("estimate.\n ",
# "\n \n ") have been removed so the prompt renders without trailing/leading
# blanks — confirm against the original source.
_CONFIRMATION_MESSAGE = (
    "You are about to create a fine-tuning job. "
    "The cost of your job will be determined by the model size, the number of tokens "
    "in the training file, the number of tokens in the validation file, the number of epochs, and "
    "the number of evaluations. Visit https://www.together.ai/pricing to get a price estimate.\n"
    "You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
    "Do you want to proceed?"
)
25+
26+
1727class DownloadCheckpointTypeChoice (click .Choice ):
1828 def __init__ (self ) -> None :
1929 super ().__init__ ([ct .value for ct in DownloadCheckpointType ])
@@ -67,6 +77,14 @@ def fine_tuning(ctx: click.Context) -> None:
6777 "--suffix" , type = str , default = None , help = "Suffix for the fine-tuned model name"
6878)
6979@click .option ("--wandb-api-key" , type = str , default = None , help = "Wandb API key" )
80+ @click .option (
81+ "--confirm" ,
82+ "-y" ,
83+ type = bool ,
84+ is_flag = True ,
85+ default = False ,
86+ help = "Whether to skip the launch confirmation message" ,
87+ )
7088def create (
7189 ctx : click .Context ,
7290 training_file : str ,
@@ -84,6 +102,7 @@ def create(
84102 lora_trainable_modules : str ,
85103 suffix : str ,
86104 wandb_api_key : str ,
105+ confirm : bool ,
87106) -> None :
88107 """Start fine-tuning"""
89108 client : Together = ctx .obj
@@ -111,32 +130,37 @@ def create(
111130 "You have specified a number of evaluation loops but no validation file."
112131 )
113132
114- response = client .fine_tuning .create (
115- training_file = training_file ,
116- model = model ,
117- n_epochs = n_epochs ,
118- validation_file = validation_file ,
119- n_evals = n_evals ,
120- n_checkpoints = n_checkpoints ,
121- batch_size = batch_size ,
122- learning_rate = learning_rate ,
123- lora = lora ,
124- lora_r = lora_r ,
125- lora_dropout = lora_dropout ,
126- lora_alpha = lora_alpha ,
127- lora_trainable_modules = lora_trainable_modules ,
128- suffix = suffix ,
129- wandb_api_key = wandb_api_key ,
130- verbose = True ,
131- )
133+ if confirm or click .confirm (_CONFIRMATION_MESSAGE , default = True , show_default = True ):
134+ response = client .fine_tuning .create (
135+ training_file = training_file ,
136+ model = model ,
137+ n_epochs = n_epochs ,
138+ validation_file = validation_file ,
139+ n_evals = n_evals ,
140+ n_checkpoints = n_checkpoints ,
141+ batch_size = batch_size ,
142+ learning_rate = learning_rate ,
143+ lora = lora ,
144+ lora_r = lora_r ,
145+ lora_dropout = lora_dropout ,
146+ lora_alpha = lora_alpha ,
147+ lora_trainable_modules = lora_trainable_modules ,
148+ suffix = suffix ,
149+ wandb_api_key = wandb_api_key ,
150+ verbose = True ,
151+ )
132152
133- report_string = f"Successfully submitted a fine-tuning job { response .id } "
134- if response .created_at is not None :
135- created_time = datetime .strptime (response .created_at , "%Y-%m-%dT%H:%M:%S.%f%z" )
136- # created_at reports UTC time, we use .astimezone() to convert to local time
137- formatted_time = created_time .astimezone ().strftime ("%m/%d/%Y, %H:%M:%S" )
138- report_string += f" at { formatted_time } "
139- rprint (report_string )
153+ report_string = f"Successfully submitted a fine-tuning job { response .id } "
154+ if response .created_at is not None :
155+ created_time = datetime .strptime (
156+ response .created_at , "%Y-%m-%dT%H:%M:%S.%f%z"
157+ )
158+ # created_at reports UTC time, we use .astimezone() to convert to local time
159+ formatted_time = created_time .astimezone ().strftime ("%m/%d/%Y, %H:%M:%S" )
160+ report_string += f" at { formatted_time } "
161+ rprint (report_string )
162+ else :
163+ click .echo ("No confirmation received, stopping job launch" )
140164
141165
142166@fine_tuning .command ()
0 commit comments