diff --git a/scripts/trainer.py b/scripts/trainer.py index 92f1b91..ba79724 100644 --- a/scripts/trainer.py +++ b/scripts/trainer.py @@ -34,6 +34,7 @@ from accelerate.utils import set_seed from diffusers import AutoencoderKL, DDIMScheduler, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel,DiffusionPipeline, DPMSolverMultistepScheduler,EulerDiscreteScheduler from diffusers.optimization import get_scheduler +from diffusers.training_utils import EMAModel from torchvision.transforms import functional from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer @@ -1077,6 +1078,7 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True): scheduler = DPMSolverMultistepScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") unwrapped_unet = accelerator.unwrap_model(unet,True) if args.use_ema: + ema_unet.store(unwrapped_unet.parameters()) ema_unet.copy_to(unwrapped_unet.parameters()) pipeline = DiffusionPipeline.from_pretrained( @@ -1229,6 +1231,9 @@ def save_and_sample_weights(step,context='checkpoint',save_model=True): elif save_model == False and len(imgs) > 0: del imgs print(f"{bcolors.OKGREEN}Samples saved to {sample_dir}{bcolors.ENDC}") + if args.use_ema: + ema_unet.restore(unwrapped_unet.parameters()) + except Exception as e: print(e) print(f"{bcolors.FAIL} Error occured during sampling, skipping.{bcolors.ENDC}") diff --git a/scripts/trainer_util.py b/scripts/trainer_util.py index 7463065..3b356fd 100644 --- a/scripts/trainer_util.py +++ b/scripts/trainer_util.py @@ -425,61 +425,3 @@ def get_depth_image_path(self,image_path): image_path = Path(image_path) return image_path.parent / f"{image_path.stem}-depth.png" -# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14 and taken from harubaru's implementation https://github.com/harubaru/waifu-diffusion -class EMAModel: - """ - Exponential Moving Average of models weights - """ - def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999): - parameters = list(parameters) - self.shadow_params = [p.clone().detach() for p in parameters] - - self.decay = decay - self.optimization_step = 0 - - def get_decay(self, optimization_step): - """ - Compute the decay factor for the exponential moving average. - """ - value = (1 + optimization_step) / (10 + optimization_step) - return 1 - min(self.decay, value) - - @torch.no_grad() - def step(self, parameters): - parameters = list(parameters) - - self.optimization_step += 1 - self.decay = self.get_decay(self.optimization_step) - - for s_param, param in zip(self.shadow_params, parameters): - if param.requires_grad: - tmp = self.decay * (s_param - param) - s_param.sub_(tmp) - else: - s_param.copy_(param) - - torch.cuda.empty_cache() - - def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None: - """ - Copy current averaged parameters into given collection of parameters. - Args: - parameters: Iterable of `torch.nn.Parameter`; the parameters to be - updated with the stored moving averages. If `None`, the - parameters with which this `ExponentialMovingAverage` was - initialized will be used. - """ - parameters = list(parameters) - for s_param, param in zip(self.shadow_params, parameters): - param.data.copy_(s_param.data) - - def to(self, device=None, dtype=None) -> None: - r"""Move internal buffers of the ExponentialMovingAverage to `device`. - Args: - device: like `device` argument to `torch.Tensor.to` - """ - # .to() on the tensors handles None correctly - self.shadow_params = [ - p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device) - for p in self.shadow_params - ] \ No newline at end of file diff --git a/scripts/windows_install.py b/scripts/windows_install.py index dd3c7e9..da7d25a 100644 --- a/scripts/windows_install.py +++ b/scripts/windows_install.py @@ -198,7 +198,7 @@ def check_versions(): status = shutil.copy2(src_file, cudnn_dest) if status: print("Copied CUDNN 8.6 files to destination") - d_commit = '8178c84' + d_commit = 'f727863' diffusers_cmd = f"git+https://github.com/huggingface/diffusers.git@{d_commit}#egg=diffusers --force-reinstall" run(f'"{python}" -m pip install {diffusers_cmd}', f"Installing Diffusers {d_commit} commit", "Couldn't install diffusers") #install requirements file