diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index 94c5150..73e15ba 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -888,6 +888,7 @@ def create_default_variables(self): self.save_sample_controlled_seed = [] self.delete_checkpoints_when_full_drive = True self.use_image_names_as_captions = True + self.shuffle_captions = False self.use_offset_noise = False self.offset_noise_weight = 0.1 self.num_samples_to_generate = 1 @@ -1201,6 +1202,10 @@ def dreambooth_mode(self): self.use_image_names_as_captions_checkbox.configure(state='disabled') self.use_image_names_as_captions_var.set(0) #self.use_image_names_as_captions_checkbox.set(0) + self.shuffle_captions_label.configure(state='disabled') + self.shuffle_captions_checkbox.configure(state='disabled') + self.shuffle_captions_var.set(0) + #self.shuffle_captions_checkbox.set(0) self.add_class_images_to_dataset_checkbox.configure(state='disabled') self.add_class_images_to_dataset_label.configure(state='disabled') self.add_class_images_to_dataset_var.set(0) @@ -1215,10 +1220,13 @@ def fine_tune_mode(self): self.use_text_files_as_captions_label.configure(state='normal') self.use_image_names_as_captions_label.configure(state='normal') self.use_image_names_as_captions_checkbox.configure(state='normal') + self.shuffle_captions_label.configure(state='normal') + self.shuffle_captions_checkbox.configure(state='normal') self.add_class_images_to_dataset_checkbox.configure(state='normal') self.add_class_images_to_dataset_label.configure(state='normal') self.use_text_files_as_captions_var.set(1) self.use_image_names_as_captions_var.set(1) + self.shuffle_captions_var.set(0) self.add_class_images_to_dataset_var.set(0) except: pass @@ -1629,39 +1637,49 @@ def create_dataset_settings_widgets(self): # create checkbox self.use_image_names_as_captions_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.use_image_names_as_captions_var) self.use_image_names_as_captions_checkbox.grid(row=2, column=1, sticky="nsew") + # create shuffle captions checkbox + self.shuffle_captions_var = tk.IntVar() + self.shuffle_captions_var.set(self.shuffle_captions) + # create label + self.shuffle_captions_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Shuffle Captions") + shuffle_captions_label_ttp = CreateToolTip(self.shuffle_captions_label, "Randomize the order of tags in a caption. Tags are separated by ','. Used for training with booru-style captions.") + self.shuffle_captions_label.grid(row=3, column=0, sticky="nsew") + # create checkbox + self.shuffle_captions_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.shuffle_captions_var) + self.shuffle_captions_checkbox.grid(row=3, column=1, sticky="nsew") # create auto balance dataset checkbox self.auto_balance_dataset_var = tk.IntVar() self.auto_balance_dataset_var.set(self.auto_balance_concept_datasets) # create label self.auto_balance_dataset_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Auto Balance Dataset") auto_balance_dataset_label_ttp = CreateToolTip(self.auto_balance_dataset_label, "Will use the concept with the least amount of images to balance the dataset by removing images from the other concepts.") - self.auto_balance_dataset_label.grid(row=3, column=0, sticky="nsew") + self.auto_balance_dataset_label.grid(row=4, column=0, sticky="nsew") # create checkbox self.auto_balance_dataset_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.auto_balance_dataset_var) - self.auto_balance_dataset_checkbox.grid(row=3, column=1, sticky="nsew") + self.auto_balance_dataset_checkbox.grid(row=4, column=1, sticky="nsew") #create add class images to dataset checkbox self.add_class_images_to_dataset_var = tk.IntVar() self.add_class_images_to_dataset_var.set(self.add_class_images_to_training) #create label self.add_class_images_to_dataset_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Add Class Images to Dataset") add_class_images_to_dataset_label_ttp = CreateToolTip(self.add_class_images_to_dataset_label, "Will add class images without prior preservation to the dataset.") - self.add_class_images_to_dataset_label.grid(row=4, column=0, sticky="nsew") + self.add_class_images_to_dataset_label.grid(row=5, column=0, sticky="nsew") #create checkbox self.add_class_images_to_dataset_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.add_class_images_to_dataset_var) - self.add_class_images_to_dataset_checkbox.grid(row=4, column=1, sticky="nsew") + self.add_class_images_to_dataset_checkbox.grid(row=5, column=1, sticky="nsew") #create number of class images entry self.number_of_class_images_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Number of Class Images") number_of_class_images_label_ttp = CreateToolTip(self.number_of_class_images_label, "The number of class images to add to the dataset, if they don't exist in the class directory they will be generated.") - self.number_of_class_images_label.grid(row=5, column=0, sticky="nsew") + self.number_of_class_images_label.grid(row=6, column=0, sticky="nsew") self.number_of_class_images_entry = ctk.CTkEntry(self.dataset_frame_subframe) - self.number_of_class_images_entry.grid(row=5, column=1, sticky="nsew") + self.number_of_class_images_entry.grid(row=6, column=1, sticky="nsew") self.number_of_class_images_entry.insert(0, self.num_class_images) #create dataset repeat entry self.dataset_repeats_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Dataset Repeats") dataset_repeat_label_ttp = CreateToolTip(self.dataset_repeats_label, "The number of times to repeat the dataset, this will increase the number of images in the dataset.") - self.dataset_repeats_label.grid(row=6, column=0, sticky="nsew") + self.dataset_repeats_label.grid(row=7, column=0, sticky="nsew") self.dataset_repeats_entry = ctk.CTkEntry(self.dataset_frame_subframe) - self.dataset_repeats_entry.grid(row=6, column=1, sticky="nsew") + self.dataset_repeats_entry.grid(row=7, column=1, sticky="nsew") self.dataset_repeats_entry.insert(0, self.dataset_repeats) #add use_aspect_ratio_bucketing checkbox @@ -1670,10 +1688,10 @@ def create_dataset_settings_widgets(self): #create label self.use_aspect_ratio_bucketing_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Use Aspect Ratio Bucketing") use_aspect_ratio_bucketing_label_ttp = CreateToolTip(self.use_aspect_ratio_bucketing_label, "Will use aspect ratio bucketing, may improve aspect ratio generations.") - self.use_aspect_ratio_bucketing_label.grid(row=7, column=0, sticky="nsew") + self.use_aspect_ratio_bucketing_label.grid(row=8, column=0, sticky="nsew") #create checkbox self.use_aspect_ratio_bucketing_checkbox = ctk.CTkSwitch(self.dataset_frame_subframe, variable=self.use_aspect_ratio_bucketing_var) - self.use_aspect_ratio_bucketing_checkbox.grid(row=7, column=1, sticky="nsew") + self.use_aspect_ratio_bucketing_checkbox.grid(row=8, column=1, sticky="nsew") #do something on checkbox click self.use_aspect_ratio_bucketing_checkbox.bind("", self.aspect_ratio_mode_toggles) @@ -1682,17 +1700,17 @@ def create_dataset_settings_widgets(self): self.aspect_ratio_bucketing_mode_var.set(self.aspect_ratio_bucketing_mode) self.aspect_ratio_bucketing_mode_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Aspect Ratio Bucketing Mode") aspect_ratio_bucketing_mode_label_ttp = CreateToolTip(self.aspect_ratio_bucketing_mode_label, "Select what the Auto Bucketing will do in case the bucket doesn't match the batch size, dynamic will choose the least amount of adding/removing of images per bucket.") - self.aspect_ratio_bucketing_mode_label.grid(row=8, column=0, sticky="nsew") + self.aspect_ratio_bucketing_mode_label.grid(row=9, column=0, sticky="nsew") self.aspect_ratio_bucketing_mode_option_menu = ctk.CTkOptionMenu(self.dataset_frame_subframe, variable=self.aspect_ratio_bucketing_mode_var, values=['Dynamic Fill', 'Drop Fill', 'Duplicate Fill']) - self.aspect_ratio_bucketing_mode_option_menu.grid(row=8, column=1, sticky="nsew") + self.aspect_ratio_bucketing_mode_option_menu.grid(row=9, column=1, sticky="nsew") #option menu to select dynamic bucketing mode (if enabled) self.dynamic_bucketing_mode_var = tk.StringVar() self.dynamic_bucketing_mode_var.set(self.dynamic_bucketing_mode) self.dynamic_bucketing_mode_label = ctk.CTkLabel(self.dataset_frame_subframe, text="Dynamic Preference") dynamic_bucketing_mode_label_ttp = CreateToolTip(self.dynamic_bucketing_mode_label, "If you're using dynamic mode, choose what you prefer in the case that dropping and duplicating are the same amount of images.") - self.dynamic_bucketing_mode_label.grid(row=9, column=0, sticky="nsew") + self.dynamic_bucketing_mode_label.grid(row=10, column=0, sticky="nsew") self.dynamic_bucketing_mode_option_menu = ctk.CTkOptionMenu(self.dataset_frame_subframe, variable=self.dynamic_bucketing_mode_var, values=['Duplicate', 'Drop']) - self.dynamic_bucketing_mode_option_menu.grid(row=9, column=1, sticky="nsew") + self.dynamic_bucketing_mode_option_menu.grid(row=10, column=1, sticky="nsew") #add shuffle dataset per epoch checkbox self.shuffle_dataset_per_epoch_var = tk.IntVar() self.shuffle_dataset_per_epoch_var.set(self.shuffle_dataset_per_epoch) @@ -3067,6 +3085,7 @@ def save_config(self, config_file=None): configure["with_prior_loss_preservation"] = self.with_prior_loss_preservation_var.get() configure["prior_loss_preservation_weight"] = self.prior_loss_preservation_weight_entry.get() configure["use_image_names_as_captions"] = self.use_image_names_as_captions_var.get() + configure["shuffle_captions"] = self.shuffle_captions_var.get() configure["auto_balance_concept_datasets"] = self.auto_balance_dataset_var.get() configure["add_class_images_to_dataset"] = self.add_class_images_to_dataset_var.get() configure["number_of_class_images"] = self.number_of_class_images_entry.get() @@ -3201,6 +3220,7 @@ def load_config(self,file_name=None): self.prior_loss_preservation_weight_entry.delete(0, tk.END) self.prior_loss_preservation_weight_entry.insert(0, configure["prior_loss_preservation_weight"]) self.use_image_names_as_captions_var.set(configure["use_image_names_as_captions"]) + self.shuffle_captions_var.set(configure["shuffle_captions"]) self.auto_balance_dataset_var.set(configure["auto_balance_concept_datasets"]) self.add_class_images_to_dataset_var.set(configure["add_class_images_to_dataset"]) self.number_of_class_images_entry.delete(0, tk.END) @@ -3296,6 +3316,7 @@ def process_inputs(self,export=None): self.with_prior_loss_preservation = self.with_prior_loss_preservation_var.get() self.prior_loss_preservation_weight = self.prior_loss_preservation_weight_entry.get() self.use_image_names_as_captions = self.use_image_names_as_captions_var.get() + self.shuffle_captions = self.shuffle_captions_var.get() self.auto_balance_concept_datasets = self.auto_balance_dataset_var.get() self.add_class_images_to_dataset = self.add_class_images_to_dataset_var.get() self.number_of_class_images = self.number_of_class_images_entry.get() @@ -3376,7 +3397,7 @@ def process_inputs(self,export=None): #check if resolution is the same try: #try because I keep adding stuff to the json file and it may error out for peeps - if self.last_run["resolution"] != self.resolution or self.use_text_files_as_captions != self.last_run['use_text_files_as_captions'] or self.last_run['dataset_repeats'] != self.dataset_repeats or self.last_run["batch_size"] != self.batch_size or self.last_run["train_text_encoder"] != self.train_text_encoder or self.last_run["use_image_names_as_captions"] != self.use_image_names_as_captions or self.last_run["auto_balance_concept_datasets"] != self.auto_balance_concept_datasets or self.last_run["add_class_images_to_dataset"] != self.add_class_images_to_dataset or self.last_run["number_of_class_images"] != self.number_of_class_images or self.last_run["aspect_ratio_bucketing"] != self.use_aspect_ratio_bucketing or self.last_run["masked_training"] != self.masked_training: + if self.last_run["resolution"] != self.resolution or self.use_text_files_as_captions != self.last_run['use_text_files_as_captions'] or self.last_run['dataset_repeats'] != self.dataset_repeats or self.last_run["batch_size"] != self.batch_size or self.last_run["train_text_encoder"] != self.train_text_encoder or self.last_run["use_image_names_as_captions"] != self.use_image_names_as_captions or self.last_run["shuffle_captions"] != self.shuffle_captions or self.last_run["auto_balance_concept_datasets"] != self.auto_balance_concept_datasets or self.last_run["add_class_images_to_dataset"] != self.add_class_images_to_dataset or self.last_run["number_of_class_images"] != self.number_of_class_images or self.last_run["aspect_ratio_bucketing"] != self.use_aspect_ratio_bucketing or self.last_run["masked_training"] != self.masked_training: self.regenerate_latent_cache = True #show message @@ -3624,6 +3645,11 @@ def process_inputs(self,export=None): batBase += ' --use_image_names_as_captions' else: batBase += f' "--use_image_names_as_captions" ' + if self.shuffle_captions == True: + if export == 'Linux': + batBase += ' --shuffle_captions' + else: + batBase += f' "--shuffle_captions" ' if self.use_offset_noise == True: if export == 'Linux': batBase += f' --with_offset_noise' diff --git a/scripts/dataloaders_util.py b/scripts/dataloaders_util.py index 0b64367..ef501b6 100644 --- a/scripts/dataloaders_util.py +++ b/scripts/dataloaders_util.py @@ -321,6 +321,7 @@ def __init__(self, resolution=512, center_crop=False, use_image_names_as_captions=True, + shuffle_captions=False, add_class_images_to_dataset=None, balance_datasets=False, crop_jitter=20, @@ -342,6 +343,7 @@ def __init__(self, self.batch_size = batch_size self.concepts_list = concepts_list self.use_image_names_as_captions = use_image_names_as_captions + self.shuffle_captions = shuffle_captions self.num_train_images = 0 self.num_reg_images = 0 self.image_train_items = [] @@ -447,6 +449,12 @@ def __get_image_for_trainer(self,image_train_item,debug_level=0,class_img=False) image_train_tmp = image_train_item.hydrate(crop=False, save=0, crop_jitter=self.crop_jitter) image_train_tmp_image = Image.fromarray(self.normalize8(image_train_tmp.image)).convert("RGB") + instance_prompt = image_train_tmp.caption + if self.shuffle_captions: + caption_parts = instance_prompt.split(",") + random.shuffle(caption_parts) + instance_prompt = ",".join(caption_parts) + example["instance_images"] = self.image_transforms(image_train_tmp_image) if image_train_tmp.mask is not None: image_train_tmp_mask = Image.fromarray(self.normalize8(image_train_tmp.mask)).convert("L") @@ -454,9 +462,9 @@ def __get_image_for_trainer(self,image_train_item,debug_level=0,class_img=False) if self.model_variant == 'depth2img': image_train_tmp_depth = Image.fromarray(self.normalize8(image_train_tmp.extra)).convert("L") example["instance_depth_images"] = self.depth_image_transforms(image_train_tmp_depth) - #print(image_train_tmp.caption) + #print(instance_prompt) example["instance_prompt_ids"] = self.tokenizer( - image_train_tmp.caption, + instance_prompt, padding="do_not_pad", truncation=True, max_length=self.tokenizer.model_max_length, @@ -1051,6 +1059,7 @@ def __init__( center_crop=False, num_class_images=None, use_image_names_as_captions=False, + shuffle_captions=False, repeats=1, use_text_files_as_captions=False, seed=555, @@ -1060,6 +1069,7 @@ def __init__( load_mask=None, ): self.use_image_names_as_captions = use_image_names_as_captions + self.shuffle_captions = shuffle_captions self.size = size self.center_crop = center_crop self.tokenizer = tokenizer @@ -1229,6 +1239,10 @@ def __getitem__(self, index): instance_prompt = f.readline().rstrip() f.close() + if self.shuffle_captions: + caption_parts = instance_prompt.split(",") + random.shuffle(caption_parts) + instance_prompt = ",".join(caption_parts) #print('identifier: ' + instance_prompt) instance_image = instance_image.convert("RGB") diff --git a/scripts/trainer.py b/scripts/trainer.py index 92f1b91..2f1e9f3 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -386,6 +386,7 @@ def parse_args(): parser.add_argument('--append_sample_controlled_seed_action', action='append') parser.add_argument('--add_sample_prompt', type=str, action='append') parser.add_argument('--use_image_names_as_captions', default=False, action="store_true") + parser.add_argument('--shuffle_captions', default=False, action="store_true") parser.add_argument("--masked_training", default=False, required=False, action='store_true', help="Whether to mask parts of the image during training") parser.add_argument("--normalize_masked_area_loss", default=False, required=False, action='store_true', help="Normalize the loss, to make it independent of the size of the masked area") parser.add_argument("--unmasked_probability", type=float, default=1, required=False, help="Probability of training a step without a mask") @@ -612,6 +613,7 @@ def main(): train_dataset = AutoBucketing( concepts_list=args.concepts_list, use_image_names_as_captions=args.use_image_names_as_captions, + shuffle_captions=args.shuffle_captions, batch_size=args.train_batch_size, tokenizer=tokenizer, add_class_images_to_dataset=args.add_class_images_to_dataset, @@ -637,6 +639,7 @@ def main(): center_crop=args.center_crop, num_class_images=args.num_class_images, use_image_names_as_captions=args.use_image_names_as_captions, + shuffle_captions=args.shuffle_captions, repeats=args.dataset_repeats, use_text_files_as_captions=args.use_text_files_as_captions, seed = args.seed,