diff --git a/.gitignore b/.gitignore index d062eb8..309c70d 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ - +out/ +training/ cudnn_windows/cudnn64_8.dll cudnn_windows/cudnn_adv_infer64_8.dll cudnn_windows/cudnn_adv_train64_8.dll diff --git a/scripts/configuration_gui.py b/scripts/configuration_gui.py index 94c5150..3e32fe4 100644 --- a/scripts/configuration_gui.py +++ b/scripts/configuration_gui.py @@ -75,8 +75,10 @@ def __init__(self, parent, concept=None,width=150,height=150, *args, **kwargs): self.concept_do_not_balance = False self.process_sub_dirs = False self.image_preview = self.default_image_preview + self.repeat_concept = '1' + self.separate_bucket = False #create concept - self.concept = Concept(self.concept_name, self.concept_data_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, None) + self.concept = Concept(self.concept_name, self.concept_data_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, None, self.repeat_concept, self.separate_bucket ) else: self.concept = concept self.concept.image_preview = self.make_image_preview() @@ -203,7 +205,7 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): self.parent = parent self.conceptWidget = conceptWidget self.concept = concept - self.geometry("576x297") + self.geometry("576x377") self.resizable(False, False) #self.protocol("WM_DELETE_WINDOW", self.on_close) self.wait_visibility() @@ -213,9 +215,9 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): #self.default_image_preview = ImageTk.PhotoImage(self.default_image_preview) #make a frame for the concept window - self.concept_frame = ctk.CTkFrame(self, width=600, height=300) + self.concept_frame = ctk.CTkFrame(self, width=600, height=380) self.concept_frame.grid(row=0, column=0, sticky="nsew",padx=10,pady=10) - 
self.concept_frame_subframe=ctk.CTkFrame(self.concept_frame, width=600, height=300) + self.concept_frame_subframe=ctk.CTkFrame(self.concept_frame, width=600, height=380) #4 column grid #self.concept_frame.grid_columnconfigure(0, weight=1) #self.concept_frame.grid_columnconfigure(1, weight=5) @@ -274,27 +276,45 @@ def __init__(self, parent,conceptWidget,concept,*args, **kwargs): if self.concept.flip_p != '': self.flip_probability_entry.insert(0, self.concept.flip_p) #self.flip_probability_entry.bind("", self.create_right_click_menu) + + #entry and label for concept repeats + self.repeat_concept_label = ctk.CTkLabel(self.concept_frame_subframe, text="Repeat Concept:") + self.repeat_concept_label.grid(row=5, column=0, sticky="nsew",padx=5,pady=5) + self.repeat_concept_entry = ctk.CTkEntry(self.concept_frame_subframe,width=200,placeholder_text="1") + self.repeat_concept_entry.grid(row=5, column=1, sticky="e",padx=5,pady=5) + if self.concept.repeat_concept != '': + self.repeat_concept_entry.insert(1, self.concept.repeat_concept) #make a label for dataset balancingprocess_sub_dirs self.balance_dataset_label = ctk.CTkLabel(self.concept_frame_subframe, text="Don't Balance Dataset") - self.balance_dataset_label.grid(row=5, column=0, sticky="nsew",padx=5,pady=5) + self.balance_dataset_label.grid(row=6, column=0, sticky="nsew",padx=5,pady=5) #make a switch to enable or disable dataset balancing self.balance_dataset_switch = ctk.CTkSwitch(self.concept_frame_subframe, text="", variable=tk.BooleanVar()) - self.balance_dataset_switch.grid(row=5, column=1, sticky="e",padx=5,pady=5) + self.balance_dataset_switch.grid(row=6, column=1, sticky="e",padx=5,pady=5) if self.concept.concept_do_not_balance == True: self.balance_dataset_switch.toggle() self.process_sub_dirs = ctk.CTkLabel(self.concept_frame_subframe, text="Search Sub-Directories") - self.process_sub_dirs.grid(row=6, column=0, sticky="nsew",padx=5,pady=5) + self.process_sub_dirs.grid(row=7, column=0, 
sticky="nsew",padx=5,pady=5) #make a switch to enable or disable dataset balancing self.process_sub_dirs_switch = ctk.CTkSwitch(self.concept_frame_subframe, text="", variable=tk.BooleanVar()) - self.process_sub_dirs_switch.grid(row=6, column=1, sticky="e",padx=5,pady=5) + self.process_sub_dirs_switch.grid(row=7, column=1, sticky="e",padx=5,pady=5) if self.concept.process_sub_dirs == True: self.process_sub_dirs_switch.toggle() #self.balance_dataset_switch.set(self.concept.concept_do_not_balance) + + #make a label for separate concept buckets + self.separate_bucket_label = ctk.CTkLabel(self.concept_frame_subframe, text="Separate Buckets") + self.separate_bucket_label.grid(row=8, column=0, sticky="nsew",padx=5,pady=5) + #make a switch to enable or disable creation of separate buckets for each concept + self.separate_bucket_switch = ctk.CTkSwitch(self.concept_frame_subframe, text="", variable=tk.BooleanVar()) + self.separate_bucket_switch.grid(row=8, column=1, sticky="e",padx=5,pady=5) + if self.concept.separate_bucket == True: + self.separate_bucket_switch.toggle() + #add image preview self.image_preview_label = ctk.CTkLabel(self.concept_frame_subframe,text='', width=150, height=150,image=ctk.CTkImage(self.default_image_preview,size=(150,150))) - self.image_preview_label.grid(row=0, column=4,rowspan=5, sticky="nsew",padx=5,pady=5) + self.image_preview_label.grid(row=0, column=4,rowspan=7, sticky="nsew",padx=5,pady=5) if self.concept.image_preview != None or self.concept.image_preview != "": #print(self.concept.image_preview) self.update_preview_image(entry=None,path=None,pil_image=self.concept.image_preview) @@ -452,6 +472,10 @@ def save(self): flip_p = self.flip_probability_entry.get() #get the dataset balancing balance_dataset = self.balance_dataset_switch.get() + #get concept repeats + repeat_concept = self.repeat_concept_entry.get() + #get the separate bucket switch + separate_bucket = self.separate_bucket_switch.get() #create the concept process_sub_dirs = 
self.process_sub_dirs_switch.get() #image preview @@ -459,14 +483,14 @@ def save(self): #get the main window image_preview_label = self.image_preview_label #update the concept - self.concept.update(concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs,image_preview,image_preview_label) + self.concept.update(concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs,image_preview,image_preview_label,repeat_concept,separate_bucket) self.conceptWidget.update_button() #close the window self.destroy() #class of the concept class Concept: - def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, balance_dataset=None,process_sub_dirs=None,image_preview=None, image_container=None): + def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, balance_dataset=None,process_sub_dirs=None,image_preview=None, image_container=None,repeat_concept=1, separate_bucket=None): if concept_name == None: concept_name = "" if concept_path == None: @@ -477,10 +501,14 @@ def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, ba class_path = "" if flip_p == None: flip_p = "" + if repeat_concept == None: + repeat_concept = "1" if balance_dataset == None: balance_dataset = False if process_sub_dirs == None: process_sub_dirs = False + if separate_bucket == None: + separate_bucket = False if image_preview == None: image_preview = "" if image_container == None: @@ -496,8 +524,10 @@ def __init__(self, concept_name, concept_path, class_name, class_path,flip_p, ba self.image_preview = image_preview self.image_container = image_container self.process_sub_dirs = process_sub_dirs + self.repeat_concept = repeat_concept + self.separate_bucket = separate_bucket #update the concept - def update(self, concept_name, concept_path, class_name, class_path,flip_p,balance_dataset,process_sub_dirs, image_preview, image_container): + def update(self, concept_name, concept_path, 
class_name, class_path,flip_p,balance_dataset,process_sub_dirs, image_preview, image_container,repeat_concept,separate_bucket): self.concept_name = concept_name self.concept_path = concept_path self.concept_class_name = class_name @@ -509,9 +539,11 @@ def update(self, concept_name, concept_path, class_name, class_path,flip_p,balan self.image_preview = image_preview self.image_container = image_container self.process_sub_dirs = process_sub_dirs + self.repeat_concept = repeat_concept + self.separate_bucket = separate_bucket #get the cocept details def get_details(self): - return self.concept_name, self.concept_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, self.image_container + return self.concept_name, self.concept_path, self.concept_class_name, self.concept_class_path,self.flip_p, self.concept_do_not_balance,self.process_sub_dirs, self.image_preview, self.image_container, self.repeat_concept, self.separate_bucket #class to make popup right click menu with select all, copy, paste, cut, and delete when right clicked on an entry box class DynamicGrid(ctk.CTkFrame): def __init__(self, parent, *args, **kwargs): @@ -2199,6 +2231,8 @@ def packageForCloud(self): new_concept['class_data_dir'] = 'datasets' + '/' + concept_class_name if concept_class_name != '' else '' new_concept['do_not_balance'] = concept['do_not_balance'] new_concept['use_sub_dirs'] = concept['use_sub_dirs'] + new_concept['repeat_concept'] = concept['repeat_concept'] + new_concept['separate_bucket'] = concept['separate_bucket'] new_concepts.append(new_concept) #make scripts folder self.save_concept_to_json(filename=self.full_export_path + os.sep + 'stabletune_concept_list.json', preMadeConcepts=new_concepts) @@ -2854,7 +2888,7 @@ def save_concept_to_json(self,filename=None,preMadeConcepts=None): concepts = [] for widget in self.concept_widgets: concept = widget.concept - concept_dict = {'instance_prompt' : 
concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs} + concept_dict = {'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept, 'separate_bucket' : concept.separate_bucket} concepts.append(concept_dict) if file != None: #write the json to the file @@ -2880,7 +2914,9 @@ def load_concept_from_json(self): #print(concept) if 'flip_p' not in concept: concept['flip_p'] = '' - concept = Concept(concept_name=concept["instance_prompt"], class_name=concept["class_prompt"], concept_path=concept["instance_data_dir"], class_path=concept["class_data_dir"],flip_p=concept['flip_p'],balance_dataset=concept["do_not_balance"], process_sub_dirs=concept["use_sub_dirs"]) + if 'repeat_concept' not in concept: + concept['repeat_concept'] = '1' + concept = Concept(concept_name=concept["instance_prompt"], class_name=concept["class_prompt"], concept_path=concept["instance_data_dir"], class_path=concept["class_data_dir"],flip_p=concept['flip_p'],balance_dataset=concept["do_not_balance"], process_sub_dirs=concept["use_sub_dirs"], repeat_concept=concept["repeat_concept"], separate_bucket=concept["separate_bucket"]) self.add_new_concept(concept) #self.canvas.configure(scrollregion=self.canvas.bbox("all")) self.update() return concept_json @@ -3023,7 +3059,7 @@ def update_concepts(self): self.concepts = [] for i in range(len(self.concept_widgets)): concept = self.concept_widgets[i].concept - self.concepts.append({'instance_prompt' : concept.concept_name, 'class_prompt' : 
concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs}) + self.concepts.append({'instance_prompt' : concept.concept_name, 'class_prompt' : concept.concept_class_name, 'instance_data_dir' : concept.concept_path, 'class_data_dir' : concept.concept_class_path,'flip_p' : concept.flip_p, 'do_not_balance' : concept.concept_do_not_balance, 'use_sub_dirs' : concept.process_sub_dirs, 'repeat_concept' : concept.repeat_concept, 'separate_bucket' : concept.separate_bucket}) def save_config(self, config_file=None): #save the configure file import json @@ -3140,7 +3176,16 @@ def load_config(self,file_name=None): flip_p = configure["concepts"][i]["flip_p"] balance_dataset = configure["concepts"][i]["do_not_balance"] process_sub_dirs = configure["concepts"][i]["use_sub_dirs"] - concept = Concept(concept_name=inst_prompt, class_name=class_prompt, concept_path=inst_data_dir, class_path=class_data_dir,flip_p=flip_p,balance_dataset=balance_dataset,process_sub_dirs=process_sub_dirs) + if 'repeat_concept' not in configure["concepts"][i]: + print(configure["concepts"][i].keys()) + configure["concepts"][i]['repeat_concept'] = '1' + repeat_concept = configure["concepts"][i]["repeat_concept"] + + if 'separate_bucket' not in configure["concepts"][i]: + print(configure["concepts"][i].keys()) + configure["concepts"][i]['separate_bucket'] = False + separate_bucket = configure["concepts"][i]["separate_bucket"] + concept = Concept(concept_name=inst_prompt, class_name=class_prompt, concept_path=inst_data_dir, class_path=class_data_dir,flip_p=flip_p,balance_dataset=balance_dataset,process_sub_dirs=process_sub_dirs,repeat_concept=repeat_concept,separate_bucket=separate_bucket) self.add_new_concept(concept) except Exception as e: print(e) diff --git a/scripts/dataloaders_util.py b/scripts/dataloaders_util.py index 
0b64367..55603ee 100644 --- a/scripts/dataloaders_util.py +++ b/scripts/dataloaders_util.py @@ -333,6 +333,8 @@ def __init__(self, extra_module=None, mask_prompts=None, load_mask=False, + repeat_concept=1, + separate_bucket=False, ): self.debug_level = debug_level @@ -392,6 +394,8 @@ def __init__(self, extra_module=self.extra_module, mask_prompts=mask_prompts, load_mask=load_mask, + repeat_concept=repeat_concept, + separate_bucket=separate_bucket, ) #print(self.image_train_items) @@ -491,7 +495,7 @@ class ImageTrainItem(): pathname: path to image file flip_p: probability of flipping image (0.0 to 1.0) """ - def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False): + def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target_wh: list, pathname: str, flip_p=0.0, model_variant='base', load_mask=False, separate_bucket_count=0): self.caption = caption self.target_wh = target_wh self.pathname = pathname @@ -504,7 +508,8 @@ def __init__(self, image: Image, mask: Image, extra: Image, caption: str, target self.load_mask=load_mask self.is_dupe = [] self.variant_warning = False - + self.separate_bucket_count = separate_bucket_count + self.image = image self.mask = mask self.extra = extra @@ -736,6 +741,7 @@ class DataLoaderMultiAspect(): data_root: root folder of training data batch_size: number of images per batch flip_p: probability of flipping image horizontally (i.e. 
0-0.5) + repeat_concept: How many times to repeat each concept in the dataset """ def __init__( self, @@ -756,6 +762,8 @@ def __init__( extra_module=None, mask_prompts=None, load_mask=False, + repeat_concept=1, + separate_bucket=False, ): self.resolution = resolution self.debug_level = debug_level @@ -771,29 +779,52 @@ def __init__( self.model_variant = model_variant self.extra_module = extra_module self.load_mask = load_mask + self.repeat_concept = repeat_concept + self.separate_bucket=separate_bucket, + separate_bucket_count = 0 prepared_train_data = [] + self.aspects = get_aspect_buckets(resolution) #print(f"* DLMA resolution {resolution}, buckets: {self.aspects}") #process sub directories flag print(f" {bcolors.WARNING} Preloading images...{bcolors.ENDC}") + + #Get concept repeat count + for concept in concept_list: + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + if repeat_concept > 1: + print(f" {bcolors.WARNING} Repeating concept {concept['instance_data_dir']} {repeat_concept} times...{bcolors.ENDC}") + if balance_datasets: print(f" {bcolors.WARNING} Balancing datasets...{bcolors.ENDC}") #get the concept with the least number of images in instance_data_dir for concept in concept_list: count = 0 + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == 1: tot = 0 for root, dirs, files in os.walk(concept['instance_data_dir']): tot += len(files) - count = tot + count = tot*repeat_concept else: - count = len(os.listdir(concept['instance_data_dir'])) + count = len(os.listdir(concept['instance_data_dir']))*repeat_concept else: - count = len(os.listdir(concept['instance_data_dir'])) + count = len(os.listdir(concept['instance_data_dir']))*repeat_concept 
print(f"{concept['instance_data_dir']} has count of {count}") concept['count'] = count @@ -802,16 +833,19 @@ def __init__( min_concept_num_images = min_concept['count'] print(" Min concept: ",min_concept['instance_data_dir']," with ",min_concept_num_images," images") - balance_cocnept_list = [] + balance_concept_list = [] for concept in concept_list: #if concept has a key do not balance it if 'do_not_balance' in concept: if concept['do_not_balance'] == True: - balance_cocnept_list.append(-1) + balance_concept_list.append(-1) else: - balance_cocnept_list.append(min_concept_num_images) + balance_concept_list.append(min_concept_num_images) else: - balance_cocnept_list.append(min_concept_num_images) + balance_concept_list.append(min_concept_num_images) + + total_separate_bucket_count = 0 + for concept in concept_list: if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: @@ -824,7 +858,7 @@ def __init__( #self.class_image_paths = [] min_concept_num_images = None if balance_datasets: - min_concept_num_images = balance_cocnept_list[concept_list.index(concept)] + min_concept_num_images = balance_concept_list[concept_list.index(concept)] data_root = concept['instance_data_dir'] data_root_class = concept['class_data_dir'] concept_prompt = concept['instance_prompt'] @@ -835,19 +869,40 @@ def __init__( flip_p = 0.0 else: flip_p = float(flip_p) + + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + + + + if concept['separate_bucket'] == True: + total_separate_bucket_count += 1 + separate_bucket_count = total_separate_bucket_count + else: + separate_bucket_count = 0 + + self.__recurse_data_root(self=self, recurse_root=data_root,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) if self.model_variant == 'depth2img': print(f" {bcolors.WARNING} ** Loading Depth2Img Pipeline To Process Dataset{bcolors.ENDC}") 
self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) - prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] + prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_prompt,repeat_concept,separate_bucket_count,use_text_files_as_captions=self.use_text_files_as_captions)[0:min_concept_num_images]) # ImageTrainItem[] if add_class_images_to_dataset: self.image_paths = [] self.__recurse_data_root(self=self, recurse_root=data_root_class,use_sub_dirs=use_sub_dirs) random.Random(self.seed).shuffle(self.image_paths) use_image_names_as_captions = False - prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] - + prepared_train_data.extend(self.__prescan_images(debug_level, self.image_paths, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,separate_bucket_count,use_text_files_as_captions=self.use_text_files_as_captions)) # ImageTrainItem[] + + + if total_separate_bucket_count > 0: + print(f" {bcolors.WARNING} There are {total_separate_bucket_count} concepts using separate buckets...{bcolors.ENDC}") + self.image_caption_pairs = self.__bucketize_images(prepared_train_data, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if self.with_prior_loss and add_class_images_to_dataset == False: self.class_image_caption_pairs = [] @@ -861,7 +916,7 @@ def __init__( print(f" {bcolors.WARNING} ** Depth2Img To Process Class Dataset{bcolors.ENDC}") self.vae_scale_factor = self.extra_module.depth_images(self.image_paths) use_image_names_as_captions = False - 
self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,use_text_files_as_captions=self.use_text_files_as_captions)) + self.class_image_caption_pairs.extend(self.__prescan_images(debug_level, self.class_images_path, flip_p,use_image_names_as_captions,concept_class_prompt,repeat_concept,separate_bucket_count,use_text_files_as_captions=self.use_text_files_as_captions)) self.class_image_caption_pairs = self.__bucketize_images(self.class_image_caption_pairs, batch_size=batch_size, debug_level=debug_level,aspect_mode=self.aspect_mode,action_preference=self.action_preference) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks...{bcolors.ENDC}") @@ -878,43 +933,43 @@ def get_all_images(self): return self.image_caption_pairs else: return self.image_caption_pairs, self.class_image_caption_pairs - def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,use_text_files_as_captions=False): + def __prescan_images(self,debug_level: int, image_paths: list, flip_p=0.0,use_image_names_as_captions=True,concept=None,repeat_concept=1,separate_bucket_count=0,use_text_files_as_captions=False): """ Create ImageTrainItem objects with metadata for hydration later """ decorated_image_train_items = [] - - for pathname in image_paths: - identifier = concept - if use_image_names_as_captions: - caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] - identifier = caption_from_filename - if use_text_files_as_captions: - txt_file_path = os.path.splitext(pathname)[0] + ".txt" - - if os.path.exists(txt_file_path): - try: - with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f: - identifier = f.readline().rstrip() - f.close() - if len(identifier) < 1: - raise ValueError(f" *** Could not find valid text in: {txt_file_path}") - - except Exception as e: - 
print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}") - print(e) - identifier = caption_from_filename - pass - #print("identifier: ",identifier) - image = Image.open(pathname) - width, height = image.size - image_aspect = width / height + for i in range(repeat_concept): + for pathname in image_paths: + identifier = concept + if use_image_names_as_captions: + caption_from_filename = os.path.splitext(os.path.basename(pathname))[0].split("_")[0] + identifier = caption_from_filename + if use_text_files_as_captions: + txt_file_path = os.path.splitext(pathname)[0] + ".txt" + + if os.path.exists(txt_file_path): + try: + with open(txt_file_path, 'r',encoding='utf-8',errors='ignore') as f: + identifier = f.readline().rstrip() + f.close() + if len(identifier) < 1: + raise ValueError(f" *** Could not find valid text in: {txt_file_path}") + + except Exception as e: + print(f" {bcolors.FAIL} *** Error reading {txt_file_path} to get caption, falling back to filename{bcolors.ENDC}") + print(e) + identifier = caption_from_filename + pass + #print("identifier: ",identifier) + image = Image.open(pathname) + width, height = image.size + image_aspect = width / height - target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) + target_wh = min(self.aspects, key=lambda aspects:abs(aspects[0]/aspects[1] - image_aspect)) - image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask) + image_train_item = ImageTrainItem(image=None, mask=None, extra=None, caption=identifier, target_wh=target_wh, pathname=pathname, flip_p=flip_p,model_variant=self.model_variant, load_mask=self.load_mask,separate_bucket_count=separate_bucket_count) - decorated_image_train_items.append(image_train_item) + decorated_image_train_items.append(image_train_item) return 
decorated_image_train_items @staticmethod @@ -927,10 +982,12 @@ def __bucketize_images(prepared_train_data: list, batch_size=1, debug_level=0,as buckets = {} for image_caption_pair in prepared_train_data: target_wh = image_caption_pair.target_wh + separate_bucket_count = image_caption_pair.separate_bucket_count - if (target_wh[0],target_wh[1]) not in buckets: - buckets[(target_wh[0],target_wh[1])] = [] - buckets[(target_wh[0],target_wh[1])].append(image_caption_pair) + #concept_bucket = image_caption_pair.concept_bucket + if (target_wh[0],target_wh[1],separate_bucket_count) not in buckets: + buckets[(target_wh[0],target_wh[1],separate_bucket_count)] = [] + buckets[(target_wh[0],target_wh[1],separate_bucket_count)].append(image_caption_pair) print(f" ** Number of buckets: {len(buckets)}") for bucket in buckets: bucket_len = len(buckets[bucket]) @@ -1058,6 +1115,7 @@ def __init__( extra_module=None, mask_prompts=None, load_mask=None, + repeat_concept=1, ): self.use_image_names_as_captions = use_image_names_as_captions self.size = size @@ -1072,7 +1130,19 @@ def __init__( self.variant_warning = False self.vae_scale_factor = None self.load_mask = load_mask + self.repeat_concept = repeat_concept for concept in concepts_list: + + #Get concept repeat count + if 'repeat_concept' in concept.keys(): + repeat_concept = concept['repeat_concept'] + if repeat_concept == '': + repeat_concept = 1 + else: + repeat_concept = int(repeat_concept) + if repeat_concept > 1: + print(f" {bcolors.WARNING} Repeating concept {concept['instance_data_dir']} {repeat_concept} times...{bcolors.ENDC}") + if 'use_sub_dirs' in concept: if concept['use_sub_dirs'] == True: use_sub_dirs = True @@ -1081,12 +1151,14 @@ def __init__( else: use_sub_dirs = False - for i in range(repeats): - self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs) + for i in range(repeat_concept): + for i in range(repeats): + self.__recurse_data_root(self, concept,use_sub_dirs=use_sub_dirs) if 
with_prior_preservation: - for i in range(repeats): - self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True) + for i in range(repeat_concept): + for i in range(repeats): + self.__recurse_data_root(self, concept,use_sub_dirs=False,class_images=True) if mask_prompts is not None: print(f" {bcolors.WARNING} Checking and generating missing masks{bcolors.ENDC}") clip_seg = ClipSeg()