diff --git a/.gitignore b/.gitignore index a58a0c5..aa3cebb 100644 --- a/.gitignore +++ b/.gitignore @@ -320,5 +320,6 @@ pip-selfcheck.json ### Allor ### config.json resources/timestamp.json +resources/logs/* # End of https://www.toptal.com/developers/gitignore/api/python,venv,visualstudiocode,pycharm diff --git a/Loader.py b/Loader.py deleted file mode 100644 index d2737bd..0000000 --- a/Loader.py +++ /dev/null @@ -1,371 +0,0 @@ -import json -import os -import platform -import time - -from pathlib import Path - -import folder_paths -import nodes - - -class Loader: - def __init__(self): - pass - - __ROOT_PATH = os.path.dirname(os.path.abspath(__file__)) - __TEMPLATE_PATH = os.path.join(__ROOT_PATH, "resources/template.json") - __TIMESTAMP_PATH = os.path.join(__ROOT_PATH, "resources/timestamp.json") - __CONFIG_PATH = os.path.join(__ROOT_PATH, "config.json") - __GIT_PATH = Path(os.path.join(__ROOT_PATH, ".git")) - - __DAY_SECONDS = 24 * 60 * 60 - __WEEK_SECONDS = 7 * __DAY_SECONDS - __MONTH_SECONDS = 30 * __DAY_SECONDS - - def __log(self, text): - print("\033[92m[Allor]\033[0m: " + text) - - def __error(self, text): - print("\033[91m[Allor]\033[0m: " + text) - - def __notification(self, text): - print("\033[94m[Allor]\033[0m: " + text) - - def __new_line(self): - print() - - def __warning_unstable_branch(self): - self.__new_line() - self.__error("Attention! You are currently using an unstable \"main\" update branch intended for the development of Allor 2.") - self.__error("Please be aware that changes made in Allor 2 may disrupt your current workflow.") - self.__error("Nodes may be renamed, parameters within them may be altered or even removed.") - self.__new_line() - self.__error("If backward compatibility of your workflow is important to you, " - "you can change the \"branch_name\" parameter to \"allor-1\" in your config.json.") - self.__error("Switch the \"confirm_unstable_agreement\" parameter in your config.json to \"true\", " - "if you are prepared for potential changes and are willing to modify your current workflow from time to time.") - self.__error("This will result in this warning no longer appearing.") - self.__new_line() - self.__notification("We appreciate your support and understanding during this transition period.") - self.__notification("Thank you for using Allor 2.\n") - - def __create_config(self): - with open(self.__CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(self.__template(), f, ensure_ascii=False, indent=4) - - def __create_timestamp(self): - with open(self.__TIMESTAMP_PATH, "w", encoding="utf-8") as f: - json.dump({"timestamp": 0}, f, ensure_ascii=False, indent=4) - - def __get_template(self): - with open(self.__TEMPLATE_PATH, "r") as f: - template = json.load(f) - - if "__comment" in template: - del template["__comment"] - - return template - - def __get_config(self): - with open(self.__CONFIG_PATH, "r") as f: - return json.load(f) - - def __get_timestamp(self): - with open(self.__TIMESTAMP_PATH, "r") as f: - return json.load(f) - - def __update_config(self, template, source): - def update_source(__template, __source): - for k, v in __template.items(): - if k not in __source: - if isinstance(v, dict): - __source[k] = {} - else: - __source[k] = v - - if isinstance(v, dict): - __source[k] = update_source(v, __source[k]) - - return __source - - def delete_keys(__template, __source): - keys_to_delete = [k for k in __source if k not in __template] - - for k in keys_to_delete: - del __source[k] - - return __source - - def sync_order(__template, __source): - new_source = {} - - for key in __template: - if key in __source: - if isinstance(__template[key], dict): - new_source[key] = sync_order(__template[key], __source[key]) - else: - new_source[key] = __source[key] - - return new_source - - source = update_source(template, source) - source = delete_keys(template, source) - source = sync_order(template, source) - - with open(self.__CONFIG_PATH, "w", encoding="utf-8") as f: - json.dump(source, f, ensure_ascii=False, indent=4) - - def __update_timestamp(self): - with open(self.__TIMESTAMP_PATH, "w", encoding="utf-8") as f: - json.dump({"timestamp": time.time()}, f, ensure_ascii=False, indent=4) - - __template = __get_template - __config = __get_config - __timestamp = __get_timestamp - - def __get_fonts_folder_path(self): - system = platform.system() - user_home = os.path.expanduser('~') - - config_font_path = os.path.join(folder_paths.base_path, *self.__config()["fonts"]["folder_path"].replace("\\", "/").split("/")) - - if not os.path.exists(config_font_path): - os.makedirs(config_font_path, exist_ok=True) - - paths = [config_font_path] - - if self.__config()["fonts"]["system_fonts"]: - if system == "Windows": - paths.append(os.path.join(os.environ["WINDIR"], "Fonts")) - elif system == "Darwin": - paths.append(os.path.join("/Library", "Fonts")) - elif system == "Linux": - paths.append(os.path.join("/usr", "share", "fonts")) - paths.append(os.path.join("/usr", "local", "share", "fonts")) - - if self.__config()["fonts"]["user_fonts"]: - if system == "Darwin": - paths.append(os.path.join(user_home, "Library", "Fonts")) - elif system == "Linux": - paths.append(os.path.join(user_home, ".fonts")) - - return [path for path in paths if os.path.exists(path)] - - def __get_keys(self, json_obj, prefix=''): - keys = [] - - for k, v in json_obj.items(): - if isinstance(v, dict): - keys.extend(self.__get_keys(v, prefix + k + '.')) - else: - keys.append(prefix + k) - - return set(keys) - - def __check_json_keys(self, json1, json2): - keys1 = self.__get_keys(json1) - keys2 = self.__get_keys(json2) - - return keys1 == keys2 - - def setup_config(self): - if not os.path.exists(self.__CONFIG_PATH): - self.__log("Creating config.json") - self.__create_config() - else: - if not self.__check_json_keys(self.__template(), self.__config()): - self.__log("Updating config.json") - self.__update_config(self.__template(), self.__config()) - - def setup_timestamp(self): - if not os.path.exists(self.__TIMESTAMP_PATH): - self.__log("Creating timestamp.json") - self.__create_timestamp() - - def check_updates(self): - # confirm_unstable_agreement = self.__config()["updates"]["confirm_unstable_agreement"] - confirm_unstable_agreement = True - branch_name = self.__config()["updates"]["branch_name"] - update_frequency = self.__config()["updates"]["update_frequency"].lower() - valid_frequencies = ["always", "day", "week", "month", "never"] - time_difference = time.time() - self.__timestamp()["timestamp"] - - if update_frequency == valid_frequencies[0]: - it_is_time_for_update = True - elif update_frequency == valid_frequencies[1]: - it_is_time_for_update = time_difference >= self.__DAY_SECONDS - elif update_frequency == valid_frequencies[2]: - it_is_time_for_update = time_difference >= self.__WEEK_SECONDS - elif update_frequency == valid_frequencies[3]: - it_is_time_for_update = time_difference >= self.__MONTH_SECONDS - elif update_frequency == valid_frequencies[4]: - it_is_time_for_update = False - else: - self.__error(f"Unknown update frequency - {update_frequency}, available: {valid_frequencies}") - - return - - if not confirm_unstable_agreement and branch_name == "main" and update_frequency != "never": - self.__warning_unstable_branch() - - if it_is_time_for_update: - if not (self.__GIT_PATH.exists() or self.__GIT_PATH.is_dir()): - self.__error("Root directory of Allor is not a git repository. Update canceled.") - - return - - try: - import git - - from git import Repo - from git import GitCommandError - - # noinspection PyTypeChecker, PyUnboundLocalVariable - repo = Repo(self.__ROOT_PATH, odbt=git.db.GitDB) - current_commit = repo.head.commit.hexsha - - repo.remotes.origin.fetch() - - latest_commit = getattr(repo.remotes.origin.refs, branch_name).commit.hexsha - - if current_commit == latest_commit: - if self.__config()["updates"]["notify_if_no_new_updates"]: - self.__notification("No new updates.") - else: - if self.__config()["updates"]["notify_if_has_new_updates"]: - self.__notification("New updates are available.") - - if self.__config()["updates"]["auto_update"]: - update_mode = self.__config()["updates"]["update_mode"].lower() - valid_modes = ["soft", "hard"] - - if repo.active_branch.name != branch_name: - try: - repo.git.checkout(branch_name) - except GitCommandError: - self.__error(f"An error occurred while switching to the branch {branch_name}.") - - return - - if update_mode == "soft": - try: - repo.git.pull() - except GitCommandError: - self.__error("An error occurred during the update. " - "It is recommended to use \"hard\" update mode. " - "But be careful, it erases all personal changes from Allor repository.") - - elif update_mode == "hard": - repo.git.reset('--hard', 'origin/' + branch_name) - else: - self.__error(f"Unknown update mode - {update_mode}, available: {valid_modes}") - - return - - self.__notification("Update complete.") - - self.__update_timestamp() - - except ImportError: - self.__error("GitPython is not installed.") - - def setup_rembg(self): - os.environ["U2NET_HOME"] = folder_paths.models_dir + "/onnx" - - def setup_paths(self): - fonts_folder_path = self.__get_fonts_folder_path() - - folder_paths.folder_names_and_paths["onnx"] = ([os.path.join(folder_paths.models_dir, "onnx")], {".onnx"}) - folder_paths.folder_names_and_paths["fonts"] = (fonts_folder_path, {".otf", ".ttf"}) - - def setup_override(self): - override_nodes_len = 0 - - def override(function): - start_len = nodes.NODE_CLASS_MAPPINGS.__len__() - - nodes.NODE_CLASS_MAPPINGS = dict( - filter(function, nodes.NODE_CLASS_MAPPINGS.items()) - ) - - return start_len - nodes.NODE_CLASS_MAPPINGS.__len__() - - if self.__config()["override"]["postprocessing"]: - override_nodes_len += override(lambda item: not item[1].CATEGORY.startswith("image/postprocessing")) - - if self.__config()["override"]["transform"]: - override_nodes_len += override(lambda item: not item[0] == "ImageScale" and not item[0] == "ImageScaleBy" and not item[0] == "ImageInvert") - - if self.__config()["override"]["debug"]: - nodes.VAEDecodeTiled.CATEGORY = "latent" - nodes.VAEEncodeTiled.CATEGORY = "latent" - - override_nodes_len += override(lambda item: not item[1].CATEGORY.startswith("_for_testing")) - - self.__log(str(override_nodes_len) + " standard nodes was overridden.") - - def get_modules(self): - modules = dict() - - if self.__config()["modules"]["AlphaChanel"]: - from .modules import AlphaChanel - modules.update(AlphaChanel.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["Clamp"]: - from .modules import Clamp - modules.update(Clamp.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageBatch"]: - from .modules import ImageBatch - modules.update(ImageBatch.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageComposite"]: - from .modules import ImageComposite - modules.update(ImageComposite.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageContainer"]: - from .modules import ImageContainer - modules.update(ImageContainer.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageDraw"]: - from .modules import ImageDraw - modules.update(ImageDraw.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageEffects"]: - from .modules import ImageEffects - modules.update(ImageEffects.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageFilter"]: - from .modules import ImageFilter - modules.update(ImageFilter.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageNoise"]: - from .modules import ImageNoise - modules.update(ImageNoise.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageSegmentation"]: - from .modules import ImageSegmentation - modules.update(ImageSegmentation.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageText"]: - from .modules import ImageText - modules.update(ImageText.NODE_CLASS_MAPPINGS) - - if self.__config()["modules"]["ImageTransform"]: - from .modules import ImageTransform - modules.update(ImageTransform.NODE_CLASS_MAPPINGS) - - modules_len = dict( - filter( - lambda item: item[1], - self.__config()["modules"].items() - ) - ).__len__() - - nodes_len = modules.__len__() - - self.__log(str(modules_len) + " modules enabled.") - self.__log(str(nodes_len) + " nodes was loaded.") - - return modules diff --git a/__init__.py b/__init__.py index 647f394..cf5d012 100644 --- a/__init__.py +++ b/__init__.py @@ -1,12 +1,17 @@ -from .Loader import Loader +from .boot.Logger import Logger +from .boot.Backends import Backends +from .boot.Config import Config +from .boot.Update import Update +from .boot.Paths import Paths +from .boot.Override import Override +from .boot.Modules import Modules -loader = Loader() +logger = Logger() +config = Config(logger).initiate() +backends = Backends(logger).initiate() -loader.setup_config() -loader.setup_timestamp() -loader.check_updates() -loader.setup_rembg() -loader.setup_paths() -loader.setup_override() +Update(logger, config, backends).initiate() +Paths(logger, config, backends).initiate() +Override(logger, config, backends).initiate() -NODE_CLASS_MAPPINGS = loader.get_modules() +NODE_CLASS_MAPPINGS = Modules(logger, config, backends).initiate() diff --git a/boot/Backends.py b/boot/Backends.py new file mode 100644 index 0000000..0a9480b --- /dev/null +++ b/boot/Backends.py @@ -0,0 +1,28 @@ +from importlib import import_module + + +class Backends: + COMFY_UI = "main" + TORCH = "torch" + NUMPY = "numpy" + CV2 = "cv2" + PIL = "PIL" + REMBG = "rembg" + GIT = "git" + + def __init__(self, logger): + self.__logger = logger + self.__backends = [Backends.COMFY_UI, Backends.TORCH, Backends.NUMPY, Backends.CV2, Backends.PIL, Backends.REMBG, Backends.GIT] + + def initiate(self): + dependencies = {} + + for backend in self.__backends: + try: + import_module(backend) + dependencies[backend] = True + except ImportError: + self.__logger.error(f"Loading {backend} library ended with an error.") + dependencies[backend] = False + + return dependencies diff --git a/boot/Config.py b/boot/Config.py new file mode 100644 index 0000000..f9f9ce7 --- /dev/null +++ b/boot/Config.py @@ -0,0 +1,97 @@ +import json + +from .Paths import Paths + + +class Config: + def __init__(self, logger): + self.__logger = logger + self.__template = self.__get_template() + + if not Paths.CONFIG_PATH.exists(): + self.__logger.info("Creating configuration file.") + self.__create_config() + + self.__config = self.__get_config() + + def initiate(self): + if not self.__verify_keys(self.__template, self.__config): + self.__logger.info("Updating configuration file.") + self.__update_config(self.__template, self.__config) + + return self.__get_config() + + def __create_config(self): + with open(Paths.CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(self.__template, f, ensure_ascii=False, indent=4) + + def __get_template(self): + with open(Paths.TEMPLATE_PATH, "r") as f: + template = json.load(f) + + if "__comment" in template: + del template["__comment"] + + return template + + def __get_config(self): + with open(Paths.CONFIG_PATH, "r") as f: + return json.load(f) + + def __verify_keys(self, json1, json2): + def get_keys(json_obj, prefix=''): + keys = [] + + for k, v in json_obj.items(): + if isinstance(v, dict): + keys.extend(get_keys(v, prefix + k + '.')) + else: + keys.append(prefix + k) + + return set(keys) + + keys1 = get_keys(json1) + keys2 = get_keys(json2) + + return keys1 == keys2 + + def __update_config(self, template, source): + def update_source(__template, __source): + for k, v in __template.items(): + if k not in __source: + if isinstance(v, dict): + __source[k] = {} + else: + __source[k] = v + + if isinstance(v, dict): + __source[k] = update_source(v, __source[k]) + + return __source + + def delete_keys(__template, __source): + keys_to_delete = [k for k in __source if k not in __template] + + for k in keys_to_delete: + del __source[k] + + return __source + + def sync_order(__template, __source): + new_source = {} + + for key in __template: + if key in __source: + if isinstance(__template[key], dict): + new_source[key] = sync_order(__template[key], __source[key]) + else: + new_source[key] = __source[key] + + return new_source + + source = update_source(template, source) + source = delete_keys(template, source) + source = sync_order(template, source) + + with open(Paths.CONFIG_PATH, "w", encoding="utf-8") as f: + json.dump(source, f, ensure_ascii=False, indent=4) diff --git a/boot/Logger.py b/boot/Logger.py new file mode 100644 index 0000000..2b91a2d --- /dev/null +++ b/boot/Logger.py @@ -0,0 +1,249 @@ +import glob +import json +import os +import re +import sys +from datetime import datetime + +from .Paths import Paths + + +class SingletonLogger(type): + _instances = {} + + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(SingletonLogger, cls).__call__(*args, **kwargs) + return cls._instances[cls] + + +class Logger(metaclass=SingletonLogger): + file = Paths.LOG_PATH + + def __init__(self): + self.formatting = _Formatting() + self.levels = { + "FATAL": "r", + "ERROR": "r", + "WARN": "y", + "INFO": "b", + "EMIT": "g", + "DEBUG": "m", + "TRACE": "c", + "LINE": "n" + } + + self.__length = max(len(key) for key in self.levels) + + info = self.__get_info() + + if info["branch"] == "main": + info["branch"] = "v.2" + + self.debug(f"{info['branch']} : {info['hex']}") + + def __log(self, text, level, display): + text = self.formatting.format(text) + + if display: + if level == "LINE": + print() + else: + prefix = self.formatting.format(f"!![f{self.levels.get(level)}b;" + re.escape("[Allor]") + ":" + "] ") + postfix = self.formatting.format(text, 9) + + print(f"{prefix}{postfix}") + + if Logger.file: + directory = Logger.file.parent + + if not directory.exists(): + directory.mkdir(parents=True) + + if not Logger.file.exists(): + log_files = list(glob.glob(str(directory / "log_*.log"))) + + if len(log_files) > 5: + log_files.sort(key=os.path.getmtime) + + os.remove(log_files[0]) + + Logger.file.touch() + + with open(Logger.file, "a") as f: + if level == "LINE": + f.write("\n") + else: + prefix = f"[{datetime.utcnow().strftime('%Y-%m-%d %H:%M:%S')}][{level:<{self.__length}}]: " + postfix = self.formatting.format(text, len(prefix)) + postfix = self.formatting.reset(postfix) + + f.write(f"{prefix}{postfix}\n") + + def fatal(self, text, display=True): + self.__log(text, "FATAL", display) + + def error(self, text, display=True): + self.__log(text, "ERROR", display) + + def warn(self, text, display=True): + self.__log(text, "WARN", display) + + def info(self, text, display=True): + self.__log(text, "INFO", display) + + def emit(self, text, display=True): + self.__log(text, "EMIT", display) + + def debug(self, text, display=True): + self.__log(text, "DEBUG", display) + + def trace(self, text, display=True): + self.__log(text, "TRACE", display) + + def line(self, count=1, display=True): + self.__log("\n" * count, "LINE", display) + + def warning_unstable_branch(self, branch_name="main"): + warn_messages = ( + f"Attention! You are currently using an unstable !![fyb;{branch_name}] update branch. \n" + f"This branch is intended for the development of !![fmb;Allor v.2]. \n\n" + f"Please be aware that changes made in !![fmb;Allor v.2] may disrupt your current workflow. \n" + "Nodes may be renamed, parameters within them may be altered or even removed. \n\n" + "If backward compatibility of your workflow is important to you, consider this. \n" + f"You can change the !![fcb;branch_name_param] parameter to !![fbb;v1] in your !![fgb;config_json]. \n\n" + "If you are prepared for potential changes, you can modify your configuration file. \n" + f"To accept, switch the !![fcb;confirm_unstable_param] parameter in your !![fgb;config_json] to !![fbb;true]. \n" + "This will result in this warning no longer appearing." + ) + + emit_messages = ( + "We appreciate your support and understanding during this transition period. \n" + f"Thank you and welcome to !![fmb;Allor v.2]." + ) + + self.warn(warn_messages) + self.line() + self.emit(emit_messages) + self.line() + + def __get_info(self): + with open(Paths.INFO_PATH, "r") as f: + return json.load(f) + + +class _Formatting: + COLOR = { + "d": 0, # DARK + "r": 1, # RED + "g": 2, # GREEN + "y": 3, # YELLOW + "b": 4, # BLUE + "m": 5, # MAGENTA + "c": 6, # CYAN + "w": 7, # WHITE + "n": 8 # NORMAL + } + + INTENSITY = { + "n": 30, # NORMAL + "b": 90 # BRIGHT + } + + ATTRIBUTE = { + "n": 0, # NORMAL + "b": 1, # BOLD + "f": 2, # FAINT + "i": 3, # ITALIC + "u": 4, # UNDERLINE + "l": 5, # BLINKING + "a": 6, # FAST_BLINKING + "r": 7, # REVERSE + "h": 8, # HIDE + "s": 9 # STRIKETHROUGH + } + + def __init__(self): + self.__ansi = self.__formatting_support() + + # noinspection PyPep8Naming + def __formatting_support(self): + platform = sys.platform + + if platform == "win32": + import ctypes + + STD_OUTPUT_HANDLE = -11 + ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004 + + kernel32 = ctypes.WinDLL("kernel32") + hStdOut = kernel32.GetStdHandle(STD_OUTPUT_HANDLE) + mode = ctypes.c_ulong() + + if not kernel32.GetConsoleMode(hStdOut, ctypes.byref(mode)): + return False + + if not mode.value & ENABLE_VIRTUAL_TERMINAL_PROCESSING: + return False + + if platform in ("linux", "darwin"): + term = os.getenv("TERM") + + if term not in ("xterm", "xterm-256color", "vt100", "ansi", "linux"): + return False + + return True + + def format(self, text, intend=0): + if self.__ansi: + text = self.__interpolation(text) + else: + text = self.__extraction(text) + + if intend > 0: + text = "\n".join(line if i == 0 or not line else " " * intend + line for i, line in enumerate(text.split("\n"))) + + return text + + def reset(self, text): + return re.compile(r"\033\[[0-?]*[ -/]*[@-~]").sub("", text) + + def __interpolation(self, input_string): + pattern = re.compile(r"!!\[(.*?)(?= self.DAY_SECONDS, + "week": time_difference >= self.WEEK_SECONDS, + "month": time_difference >= self.MONTH_SECONDS, + "never": False + } + + unstable_branches = ["main"] + + if branch_name in unstable_branches and search_frequency != "never": + if not confirm_unstable: + self.__logger.warning_unstable_branch(branch_name) + else: + self.__logger.debug("Unstable branch agreement confirmed.\nPlease be careful when using this version.", False) + + if self.__backends[Backends.GIT]: + if Paths.GIT_PATH.exists() and Paths.GIT_PATH.is_dir(): + repo = self.__repo() + + self.__checkout(repo, branch_name) + + try: + update_scheduled = valid_frequencies[search_frequency] + + if update_scheduled: + self.__pull(repo, branch_name) + except KeyError: + self.__logger.error(f"Unknown update frequency - {search_frequency}, available: {list(valid_frequencies.keys())}") + else: + self.__logger.error("Update canceled because Allor is not a git repository.") + else: + self.__logger.error("Update canceled because GitPython is not installed.") + + def __get_timestamp(self): + with open(Paths.TIMESTAMP_PATH, "r") as f: + return json.load(f) + + def __create_timestamp(self): + with open(Paths.TIMESTAMP_PATH, "w", encoding="utf-8") as f: + json.dump({"timestamp": 0}, f, ensure_ascii=False, indent=4) + + def __update_timestamp(self): + with open(Paths.TIMESTAMP_PATH, "w", encoding="utf-8") as f: + json.dump({"timestamp": time.time()}, f, ensure_ascii=False, indent=4) + + def __repo(self): + git = import_module("git") + repo = git.Repo + + return repo(Paths.ROOT_PATH, odbt=git.db.GitDB) + + def __checkout(self, repo, branch_name): + from git import GitCommandError + + if repo.active_branch.name != branch_name: + if any([branch.name == branch_name for branch in repo.branches]): + try: + update_mode = self.__config["updates"]["update_mode"].lower() + valid_modes = ["soft", "hard"] + + if update_mode == "soft": + repo.git.checkout(branch_name) + elif update_mode == "hard": + repo.git.checkout(branch_name, force=True) + else: + self.__logger.error(f"Unknown update mode - {update_mode}, available: {valid_modes}") + except GitCommandError: + self.__logger.error(f"An error occurred while switching to the branch {branch_name}.") + else: + self.__logger.error(f"Branch with name {branch_name} not exist.") + + def __pull(self, repo, branch_name): + from git import GitCommandError + + try: + repo.remotes.origin.fetch() + + local_commits = list(repo.iter_commits(branch_name))[::-1] + remote_commits = list(repo.iter_commits(f"origin/{branch_name}"))[::-1] + + incorrect_hex = any(lc != rc for lc, rc in zip(local_commits, remote_commits)) + + if incorrect_hex: + new_remote_commits = [] + else: + last_hex = next((i for i, (lc, rc) in enumerate(zip(local_commits, remote_commits)) if lc != rc), len(local_commits)) + new_remote_commits = remote_commits[last_hex:] if len(remote_commits) > len(local_commits) else [] + + if incorrect_hex or new_remote_commits: + if self.__config["updates"]["install_update"]: + update_mode = self.__config["updates"]["update_mode"].lower() + valid_modes = ["soft", "hard"] + + if update_mode == "soft": + if incorrect_hex: + self.__logger.warn("Incorrect hex of commits found in commits repository history.\n" + "Updating using \"soft\" update mode is unlikely to complete successfully.") + + repo.git.pull() + elif update_mode == "hard": + repo.git.reset('--hard', 'origin/' + branch_name) + else: + self.__logger.error(f"Unknown update mode - {update_mode}, available: {valid_modes}") + + return + + self.__update_timestamp() + self.__logger.info("Updates installed successfully.", self.__config["logger"]["install_complete"]) + except GitCommandError: + self.__logger.error("An error occurred during the updating.\n" + "It is recommended to use \"hard\" update mode.\n" + "But be careful, it erases all personal changes from Allor repository.") diff --git a/boot/__init__.py b/boot/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/resources/info.bat b/resources/info.bat new file mode 100644 index 0000000..db3c183 --- /dev/null +++ b/resources/info.bat @@ -0,0 +1,7 @@ +@echo off + +for /f %%i in ('git rev-parse --abbrev-ref HEAD') do set branch=%%i +for /f "delims=" %%i in ('git log -1 --pretty=%%B') do set commit=%%i +for /f %%i in ('git rev-parse --short HEAD') do set hex=%%i + +echo {\"branch\": \"%branch%\", \"commit\": \"%commit%\", \"hex\": \"%hex%\"} > info.json diff --git a/resources/info.json b/resources/info.json new file mode 100644 index 0000000..f89bb5b --- /dev/null +++ b/resources/info.json @@ -0,0 +1 @@ +{"branch": "main", "commit": "Split Loader.py", "hex": "14bff6a"} diff --git a/resources/info.sh b/resources/info.sh new file mode 100644 index 0000000..894ce40 --- /dev/null +++ b/resources/info.sh @@ -0,0 +1,7 @@ +#!/bin/bash + +branch=$(git rev-parse --abbrev-ref HEAD) +commit=$(git log -1 --pretty=%B) +hex=$(git rev-parse --short HEAD) + +echo "{\"branch\": \"$branch\", \"commit\": \"$commit\", \"hex\": \"$hex\"}" > info.json diff --git a/resources/template.json b/resources/template.json index 8c7a1d5..36ee74f 100644 --- a/resources/template.json +++ b/resources/template.json @@ -20,17 +20,22 @@ "debug": false }, "updates": { - "update_frequency": "day", - "notify_if_has_new_updates": true, - "notify_if_no_new_updates": true, - "auto_update": true, + "search_frequency": "day", + "install_update": true, "branch_name": "main", - "update_mode": "soft", - "confirm_unstable_agreement": false + "update_mode": "soft" }, "fonts": { "folder_path": "comfy_extras/fonts", "system_fonts": false, "user_fonts": false + }, + "logger": { + "confirm_unstable": false, + "updates_search": true, + "install_complete": true, + "modules_enabled": true, + "nodes_loaded": true, + "nodes_overridden": true } }