diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index a44374c..cd33d10 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -3396,70 +3396,46 @@ def process_inputs(self,export=None): batBase = 'accelerate "launch" "--mixed_precision=no" "scripts/trainer.py"' if export == 'Linux': batBase = f'accelerate launch --mixed_precision="no" scripts/trainer.py' - - if self.shuffle_dataset_per_epoch == True: + + def platform_format_flag(flag, argument=None): if export == 'Linux': - batBase += ' --shuffle_per_epoch' + if argument is None: + return f' {flag}' + else: + return f' {flag}="{argument}"' else: - batBase += ' "--shuffle_per_epoch"' + if argument is None: + return f' "{flag}"' + else: + return f' "{flag}={argument}"' + if self.shuffle_dataset_per_epoch == True: + batBase += platform_format_flag('--shuffle_per_epoch') if self.batch_prompt_sampling != 0: - if export == 'Linux': - batBase += f' --sample_from_batch={self.batch_prompt_sampling}' - else: - batBase += f' "--sample_from_batch={self.batch_prompt_sampling}"' + batBase += platform_format_flag('--sample_from_batch', self.batch_prompt_sampling) if self.attention == 'xformers': - if export == 'Linux': - batBase += ' --attention="xformers"' - else: - batBase += ' "--attention=xformers" ' + batBase += platform_format_flag('--attention','xformers') elif self.attention == 'Flash Attention': - if export == 'Linux': - batBase += ' --attention="flash_attention"' - else: - batBase += ' "--attention=flash_attention" ' + batBase += platform_format_flag('--attention', 'flash_attention') if self.model_variant == 'Regular': - if export == 'Linux': - batBase += ' --model_variant="base"' - else: - batBase += ' "--model_variant=base" ' + batBase += platform_format_flag('--model_variant', 'base') elif self.model_variant == 'Inpaint': - if export == 'Linux': - batBase += ' --model_variant="inpainting"' - else: - batBase += ' "--model_variant=inpainting" ' + batBase += platform_format_flag('--model_variant', 'inpainting') elif self.model_variant == 'Depth2Img': - if export == 'Linux': - batBase += ' --model_variant="depth2img"' - else: - batBase += ' "--model_variant=depth2img" ' - + batBase += platform_format_flag('--model_variant', 'depth2img') if self.masked_training == True: - if export == 'Linux': - batBase += ' --masked_training ' - else: - batBase += ' "--masked_training" ' - + batBase += platform_format_flag('--masked_training') if self.normalize_masked_area_loss == True: - if export == 'Linux': - batBase += ' --normalize_masked_area_loss ' - else: - batBase += ' "--normalize_masked_area_loss" ' + batBase += platform_format_flag('--normalize_masked_area_loss') try: # if unmasked_probability is a percentage calculate what epoch to stop at if '%' in self.unmasked_probability: percent = float(self.unmasked_probability.replace('%', '')) fraction = percent / 100 - if export == 'Linux': - batBase += f' --unmasked_probability={fraction}' - else: - batBase += f' "--unmasked_probability={fraction}" ' + batBase += platform_format_flag('--unmasked_probability', fraction) elif '%' not in self.unmasked_probability and self.unmasked_probability.strip() != '' and self.unmasked_probability != '0': - if export == 'Linux': - batBase += f' --unmasked_probability={self.unmasked_probability}' - else: - batBase += f' "--unmasked_probability={self.unmasked_probability}" ' + batBase += platform_format_flag('--unmasked_probability', self.unmasked_probability) except: pass @@ -3468,212 +3444,107 @@ def process_inputs(self,export=None): if '%' in self.max_denoising_strength: percent = float(self.max_denoising_strength.replace('%', '')) fraction = percent / 100 - if export == 'Linux': - batBase += f' --max_denoising_strength={fraction}' - else: - batBase += f' "--max_denoising_strength={fraction}" ' + batBase += platform_format_flag('--max_denoising_strength', fraction) elif '%' not in self.max_denoising_strength and self.max_denoising_strength.strip() != '' and self.max_denoising_strength != '0': - if export == 'Linux': - batBase += f' --max_denoising_strength={self.max_denoising_strength}' - else: - batBase += f' "--max_denoising_strength={self.max_denoising_strength}" ' + batBase += platform_format_flag('--max_denoising_strength', self.max_denoising_strength) except: pass if self.fallback_mask_prompt != '': - if export == 'Linux': - batBase += f' --add_mask_prompt="{self.fallback_mask_prompt}"' - else: - batBase += f' "--add_mask_prompt={self.fallback_mask_prompt}" ' + batBase += platform_format_flag('--add_mask_prompt', self.fallback_mask_prompt) if self.disable_cudnn_benchmark == True: - if export == 'Linux': - batBase += ' --disable_cudnn_benchmark' - else: - batBase += ' "--disable_cudnn_benchmark" ' + batBase += platform_format_flag('--disable_cudnn_benchmark') if self.use_text_files_as_captions == True: - if export == 'Linux': - batBase += ' --use_text_files_as_captions' - else: - batBase += ' "--use_text_files_as_captions" ' + batBase += platform_format_flag('--use_text_files_as_captions') if int(self.sample_step_interval) != 0 or self.sample_step_interval != '' or self.sample_step_interval != ' ': - if export == 'Linux': - batBase += f' --sample_step_interval={self.sample_step_interval}' - else: - batBase += f' "--sample_step_interval={self.sample_step_interval}" ' + batBase += platform_format_flag('--sample_step_interval', self.sample_step_interval) try: #if limit_text_encoder is a percentage calculate what epoch to stop at if '%' in self.limit_text_encoder: percent = float(self.limit_text_encoder.replace('%','')) stop_epoch = int((int(self.train_epocs) * percent) / 100) - if export == 'Linux': - batBase += f' --stop_text_encoder_training={stop_epoch}' - else: - batBase += f' "--stop_text_encoder_training={stop_epoch}" ' + batBase += platform_format_flag('--stop_text_encoder_training', stop_epoch) elif '%' not in self.limit_text_encoder and self.limit_text_encoder.strip() != '' and self.limit_text_encoder != '0': - if export == 'Linux': - batBase += f' --stop_text_encoder_training={self.limit_text_encoder}' - else: - batBase += f' "--stop_text_encoder_training={self.limit_text_encoder}" ' + batBase += platform_format_flag('--stop_text_encoder_training', self.limit_text_encoder) except: pass - if export=='Linux': - batBase += f' --pretrained_model_name_or_path="{self.model_path}" ' - batBase += f' --pretrained_vae_name_or_path="{self.vae_path}" ' - batBase += f' --output_dir="{self.output_path}" ' - batBase += f' --seed={self.seed_number} ' - batBase += f' --resolution={self.resolution} ' - batBase += f' --train_batch_size={self.batch_size} ' - batBase += f' --num_train_epochs={self.train_epocs} ' - else: - batBase += f' "--pretrained_model_name_or_path={self.model_path}" ' - batBase += f' "--pretrained_vae_name_or_path={self.vae_path}" ' - batBase += f' "--output_dir={self.output_path}" ' - batBase += f' "--seed={self.seed_number}" ' - batBase += f' "--resolution={self.resolution}" ' - batBase += f' "--train_batch_size={self.batch_size}" ' - batBase += f' "--num_train_epochs={self.train_epocs}" ' + batBase += platform_format_flag('--pretrained_model_name_or_path', self.model_path) + batBase += platform_format_flag('--pretrained_vae_name_or_path', self.vae_path) + batBase += platform_format_flag('--output_dir', self.output_path) + batBase += platform_format_flag('--seed', self.seed_number) + batBase += platform_format_flag('--resolution', self.resolution) + batBase += platform_format_flag('--train_batch_size', self.batch_size) + batBase += platform_format_flag('--num_train_epochs', self.train_epocs) + if self.mixed_precision == 'fp16' or self.mixed_precision == 'bf16' or self.mixed_precision == 'tf32': - if export == 'Linux': - batBase += f' --mixed_precision="{self.mixed_precision}"' - else: - batBase += f' "--mixed_precision={self.mixed_precision}" ' + batBase += platform_format_flag('--mixed_precision', self.mixed_precision) if self.use_aspect_ratio_bucketing: - if export == 'Linux': - batBase += ' --use_bucketing' - else: - batBase += f' "--use_bucketing" ' + batBase += platform_format_flag('--use_bucketing') if self.aspect_ratio_bucketing_mode == 'Dynamic Fill': com = 'dynamic' if self.aspect_ratio_bucketing_mode == 'Drop Fill': com = 'truncate' if self.aspect_ratio_bucketing_mode == 'Duplicate Fill': com = 'add' - if export == 'Linux': - batBase += f' --aspect_mode="{com}"' - else: - batBase += f' "--aspect_mode={com}" ' + batBase += platform_format_flag('--aspect_mode', com) if self.dynamic_bucketing_mode == 'Duplicate': com = 'add' if self.dynamic_bucketing_mode == 'Drop': com = 'truncate' - if export == 'Linux': - batBase += f' --aspect_mode_action_preference="{com}"' - else: - batBase += f' "--aspect_mode_action_preference={com}" ' + batBase += platform_format_flag('--aspect_mode_action_preference', com) if self.use_8bit_adam == True: - if export == 'Linux': - batBase += ' --use_8bit_adam' - else: - batBase += f' "--use_8bit_adam" ' + batBase += platform_format_flag('--use_8bit_adam') if self.use_gradient_checkpointing == True: - if export == 'Linux': - batBase += ' --gradient_checkpointing' - else: - batBase += f' "--gradient_checkpointing" ' - - if export == 'Linux': - batBase += f' --gradient_accumulation_steps={self.accumulation_steps}' - batBase += f' --learning_rate={self.learning_rate}' - batBase += f' --lr_warmup_steps={self.warmup_steps}' - batBase += f' --lr_scheduler="{self.learning_rate_scheduler}"' - else: - batBase += f' "--gradient_accumulation_steps={self.accumulation_steps}" ' - batBase += f' "--learning_rate={self.learning_rate}" ' - batBase += f' "--lr_warmup_steps={self.warmup_steps}" ' - batBase += f' "--lr_scheduler={self.learning_rate_scheduler}" ' + batBase += platform_format_flag('--gradient_checkpointing') + + batBase += platform_format_flag('--gradient_accumulation_steps', self.accumulation_steps) + batBase += platform_format_flag('--learning_rate', self.learning_rate) + batBase += platform_format_flag('--lr_warmup_steps', self.warmup_steps) + batBase += platform_format_flag('--lr_scheduler', self.learning_rate_scheduler) if self.regenerate_latent_cache == True: - if export == 'Linux': - batBase += ' --regenerate_latent_cache' - else: - batBase += f' "--regenerate_latent_cache" ' + batBase += platform_format_flag('--regenerate_latent_cache') if self.train_text_encoder == True: - if export == 'Linux': - batBase += ' --train_text_encoder' - else: - batBase += f' "--train_text_encoder" ' + batBase += platform_format_flag('--train_text_encoder') if self.with_prior_loss_preservation == True and self.use_aspect_ratio_bucketing == False: - if export == 'Linux': - batBase += ' --with_prior_preservation' - batBase += f' --prior_loss_weight={self.prior_loss_preservation_weight}' - else: - batBase += f' "--with_prior_preservation" ' - batBase += f' "--prior_loss_weight={self.prior_loss_preservation_weight}" ' + batBase += platform_format_flag('--with_prior_preservation') + batBase += platform_format_flag('--prior_loss_weight', self.prior_loss_preservation_weight) elif self.with_prior_loss_preservation == True and self.use_aspect_ratio_bucketing == True: print('loss preservation isnt supported with aspect ratio bucketing yet, sorry!') if self.use_image_names_as_captions == True: - if export == 'Linux': - batBase += ' --use_image_names_as_captions' - else: - batBase += f' "--use_image_names_as_captions" ' + batBase += platform_format_flag('--use_image_names_as_captions') if self.use_offset_noise == True: - if export == 'Linux': - batBase += f' --with_offset_noise' - batBase += f' --offset_noise_weight={self.offset_noise_weight}' - else: - batBase += f' "--with_offset_noise" ' - batBase += f' "--offset_noise_weight={self.offset_noise_weight}" ' + batBase += platform_format_flag('--with_offset_noise') + batBase += platform_format_flag('--offset_noise_weight', self.offset_noise_weight) if self.auto_balance_concept_datasets == True: - if export == 'Linux': - batBase += ' --auto_balance_concept_datasets' - else: - batBase += f' "--auto_balance_concept_datasets" ' + batBase += platform_format_flag('--auto_balance_concept_datasets') if self.add_class_images_to_dataset == True and self.with_prior_loss_preservation == False: - if export == 'Linux': - batBase += ' --add_class_images_to_dataset' - else: - batBase += f' "--add_class_images_to_dataset" ' - if export == 'Linux': - batBase += f' --concepts_list="{self.concept_list_json_path}"' - batBase += f' --num_class_images={self.number_of_class_images}' - batBase += f' --save_every_n_epoch={self.save_every_n_epochs}' - batBase += f' --n_save_sample={self.number_of_samples_to_generate}' - batBase += f' --sample_height={self.sample_height}' - batBase += f' --sample_width={self.sample_width}' - batBase += f' --dataset_repeats={self.dataset_repeats}' - else: - batBase += f' "--concepts_list={self.concept_list_json_path}" ' - batBase += f' "--num_class_images={self.number_of_class_images}" ' - batBase += f' "--save_every_n_epoch={self.save_every_n_epochs}" ' - batBase += f' "--n_save_sample={self.number_of_samples_to_generate}" ' - batBase += f' "--sample_height={self.sample_height}" ' - batBase += f' "--sample_width={self.sample_width}" ' - batBase += f' "--dataset_repeats={self.dataset_repeats}" ' + batBase += platform_format_flag('--add_class_images_to_dataset') + batBase += platform_format_flag('--concepts_list', self.concept_list_json_path) + batBase += platform_format_flag('--num_class_images', self.number_of_class_images) + batBase += platform_format_flag('--save_every_n_epoch', self.save_every_n_epochs) + batBase += platform_format_flag('--n_save_sample', self.number_of_samples_to_generate) + batBase += platform_format_flag('--sample_height', self.sample_height) + batBase += platform_format_flag('--sample_width', self.sample_width) + batBase += platform_format_flag('--dataset_repeats', self.dataset_repeats) if self.sample_random_aspect_ratio == True: - if export == 'Linux': - batBase += ' --sample_aspect_ratios' - else: - batBase += f' "--sample_aspect_ratios" ' + batBase += platform_format_flag('--sample_aspect_ratios') if self.send_telegram_updates == True: - if export == 'Linux': - batBase += ' --send_telegram_updates' - batBase += f' --telegram_token="{self.telegram_token}"' - batBase += f' --telegram_chat_id="{self.telegram_chat_id}"' - else: - batBase += f' "--send_telegram_updates" ' - batBase += f' "--telegram_token={self.telegram_token}" ' - batBase += f' "--telegram_chat_id={self.telegram_chat_id}" ' + batBase += platform_format_flag(export,'--send_telegram_updates') + batBase += platform_format_flag('--telegram_token', self.telegram_token) + batBase += platform_format_flag('--telegram_chat_id', self.telegram_chat_id) #remove duplicates from self.sample_prompts self.sample_prompts = list(dict.fromkeys(self.sample_prompts)) #remove duplicates from self.add_controlled_seed_to_sample self.add_controlled_seed_to_sample = list(dict.fromkeys(self.add_controlled_seed_to_sample)) for i in range(len(self.sample_prompts)): - if export == 'Linux': - batBase += f' --add_sample_prompt="{self.sample_prompts[i]}"' - else: - batBase += f' "--add_sample_prompt={self.sample_prompts[i]}" ' + batBase += platform_format_flag('--add_sample_prompt', self.sample_prompts[i]) for i in range(len(self.add_controlled_seed_to_sample)): - if export == 'Linux': - batBase += f' --save_sample_controlled_seed={self.add_controlled_seed_to_sample[i]}' - else: - batBase += f' "--save_sample_controlled_seed={self.add_controlled_seed_to_sample[i]}" ' + batBase += platform_format_flag('--save_sample_controlled_seed',self.add_controlled_seed_to_sample[i]) if self.sample_on_training_start == True: - if export == 'Linux': - batBase += ' --sample_on_training_start' - else: - batBase += f' "--sample_on_training_start" ' + batBase += platform_format_flag('--sample_on_training_start') if len(self.conditional_dropout) > 0 and self.conditional_dropout != ' ' and self.conditional_dropout != '0': #if % is in the string, remove it if '%' in self.conditional_dropout: @@ -3694,22 +3565,13 @@ def process_inputs(self,export=None): #print(self.conditional_dropout) #if self.coniditional dropout is a float if isinstance(self.conditional_dropout, float): - if export == 'Linux': - batBase += f' --conditional_dropout={self.conditional_dropout}' - else: - batBase += f' "--conditional_dropout={self.conditional_dropout}" ' + batBase += platform_format_flag('--conditional_dropout', self.conditional_dropout) #save configure if self.clip_penultimate == True: - if export == 'Linux': - batBase += ' --clip_penultimate' - else: - batBase += f' "--clip_penultimate" ' + batBase += platform_format_flag('--clip_penultimate') if self.use_ema == True: - if export == 'Linux': - batBase += ' --use_ema' - else: - batBase += f' "--use_ema" ' + batBase += platform_format_flag('--use_ema') self.save_config('stabletune_last_run.json') #check if output folder exists