From cfde0a3768bf274c5be9eb068adfae1670395ec5 Mon Sep 17 00:00:00 2001 From: raghupas Date: Mon, 24 Aug 2020 11:11:14 -0400 Subject: [PATCH 1/6] Naive approach to multi-level masking --- infer.py | 57 ++++++++++---- omnigan/data.py | 2 +- omnigan/generator.py | 20 ++++- omnigan/losses.py | 4 +- omnigan/trainer.py | 175 +++++++++++++++++++++++++++++++++--------- omnigan/transforms.py | 3 +- tests/test_gen.py | 3 +- 7 files changed, 207 insertions(+), 57 deletions(-) diff --git a/infer.py b/infer.py index 449d67f0..36a48d0c 100644 --- a/infer.py +++ b/infer.py @@ -52,14 +52,31 @@ def parsed_args(): type=str, help="Directory to write images to", ) + parser.add_argument( + "--path_to_masks", + type=str, + help="Path of masks to be used for painting", + required=False, + ) return parser.parse_args() -def eval_folder(path_to_images, output_dir, paint=False): - images = [path_to_images / Path(i) for i in os.listdir(path_to_images)] - for img_path in images: +def eval_folder( + output_dir, path_to_images, path_to_masks=None, paint=False, masker=False +): + + image_list = os.listdir(path_to_images) + image_list.sort() + images = [path_to_images / Path(i) for i in image_list] + if not masker: + mask_list = os.listdir(path_to_masks) + mask_list.sort() + masks = [path_to_masks / Path(i) for i in mask_list] + + for i, img_path in enumerate(images): img = tensor_loader(img_path, task="x", domain="val") + # Resize img: img = F.interpolate(img, (new_size, new_size), mode="nearest") img = img.squeeze(0) @@ -67,10 +84,19 @@ def eval_folder(path_to_images, output_dir, paint=False): img = tf(img) img = img.unsqueeze(0).to(device) - z = model.encode(img) - mask = model.decoders["m"](z) - vutils.save_image(mask, output_dir / ("mask_" + img_path.name), normalize=True) + if not masker: + mask = tensor_loader(masks[i], task="m", domain="val") + mask = F.interpolate(mask, (new_size, new_size), mode="nearest") + mask = mask.squeeze() + mask = mask.unsqueeze(0).to(device) + + if masker: + z = model.encode(img) + mask = model.decoders["m"](z) + vutils.save_image( + mask, output_dir / ("mask_" + img_path.name), normalize=True + ) if paint: z_painter = trainer.sample_z(1) @@ -113,10 +139,12 @@ def isimg(path_file): else: new_size = args.new_size - if "m" in opts.tasks and "p" in opts.tasks: + paint = False + masker = False + if "p" in opts.tasks: paint = True - else: - paint = False + if "m" in opts.tasks: + masker = True # ------------------------ # ----- Define model ----- # ------------------------ @@ -142,6 +170,7 @@ def isimg(path_file): # eval_folder(args.path_to_images, output_dir) rootdir = args.path_to_images + maskdir = args.path_to_masks writedir = args.output_dir for root, subdirs, files in tqdm(os.walk(rootdir)): @@ -151,16 +180,16 @@ def isimg(path_file): has_imgs = False for f in files: if isimg(f): - # read_path = root / f - # rel_path = read_path.relative_to(rootdir) - # write_path = writedir / rel_path - # write_path.mkdir(parents=True, exist_ok=True) has_imgs = True break if has_imgs: print(f"Eval on {root}") rel_path = root.relative_to(rootdir) + mask_path = maskdir / rel_path write_path = writedir / rel_path write_path.mkdir(parents=True, exist_ok=True) - eval_folder(root, write_path, paint) + print("root: ", root) + eval_folder( + write_path, root, path_to_masks=maskdir, paint=paint, masker=masker + ) diff --git a/omnigan/data.py b/omnigan/data.py index f3439f9e..d62edbc3 100644 --- a/omnigan/data.py +++ b/omnigan/data.py @@ -263,7 +263,7 @@ def tensor_loader(path, task, domain): arr = np.moveaxis(arr, 2, 0) elif task == "s": arr = np.moveaxis(arr, 2, 0) - elif task == "m": + elif task == "m" or task == "m2": arr[arr != 0] = 1 # Make sure mask is single-channel if len(arr.shape) >= 3: diff --git a/omnigan/generator.py b/omnigan/generator.py index 39de3c79..fa62aef3 100644 --- a/omnigan/generator.py +++ b/omnigan/generator.py @@ -83,7 +83,10 @@ def __init__(self, opts, latent_shape=None, verbose=None): self.decoders["s"] = SegmentationDecoder(opts) if "m" in opts.tasks and not opts.gen.m.ignore: - self.decoders["m"] = MaskDecoder(opts) + if "m2" in opts.tasks: + self.decoders["m"] = ConditionalMasker(opts) + else: + self.decoders["m"] = MaskDecoder(opts) self.decoders = nn.ModuleDict(self.decoders) @@ -115,6 +118,21 @@ def __init__(self, opts): ) +class ConditionalMasker(BaseDecoder): + def __init__(self, opts): + super().__init__( + n_upsample=opts.gen.m.n_upsample, + n_res=opts.gen.m.n_res, + input_dim=opts.gen.encoder.res_dim + 1, + proj_dim=opts.gen.m.proj_dim, + output_dim=opts.gen.m.output_dim, + res_norm=opts.gen.m.res_norm, + activ=opts.gen.m.activ, + pad_type=opts.gen.m.pad_type, + output_activ="sigmoid", + ) + + class DepthDecoder(BaseDecoder): def __init__(self, opts): super().__init__( diff --git a/omnigan/losses.py b/omnigan/losses.py index 03e7cfd3..14025ed7 100644 --- a/omnigan/losses.py +++ b/omnigan/losses.py @@ -430,9 +430,7 @@ def __init__(self): def __call__(self, prediction, target): return self.loss( prediction, - torch.FloatTensor(prediction.size()) - .fill_(target) - .to(prediction.get_device()), + torch.FloatTensor(prediction.size()).fill_(target).to(prediction.device), ) diff --git a/omnigan/trainer.py b/omnigan/trainer.py index ca013f7b..1ad7a858 100644 --- a/omnigan/trainer.py +++ b/omnigan/trainer.py @@ -199,6 +199,26 @@ def setup(self): self.input_shape = self.compute_input_shape() self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up) self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up) + if "m2" in self.opts.tasks: + self.label_1 = torch.zeros( + ( + self.opts.data.loaders.batch_size, + 1, + self.latent_shape[1], + self.latent_shape[2], + ) + ).to(self.device) + self.label_2 = torch.ones( + ( + self.opts.data.loaders.batch_size, + 1, + self.latent_shape[1], + self.latent_shape[2], + ) + ).to(self.device) + print("label sizes: ", self.label_1.shape) + print(self.label_2.shape) + self.D: OmniDiscriminator = get_dis(self.opts, verbose=self.verbose).to( self.device ) @@ -411,7 +431,25 @@ def log_comet_images(self, mode, domain): if update_task != "x": if update_task not in save_images: save_images[update_task] = [] - prediction = self.G.decoders[update_task](self.z) + if update_task == "m2": + prediction = self.G.decoders["m"]( + torch.cat( + (self.z, self.label_2[0, :, :, :].unsqueeze(0)), + dim=1, + ) + ) + + elif update_task == "m": + if "m2" in self.opts.tasks: + prediction = self.G.decoders[update_task]( + torch.cat( + (self.z, self.label_1[0, :, :, :].unsqueeze(0)), + dim=1, + ) + ) + else: + prediction = self.G.decoders[update_task](self.z) + if update_task == "s": if domain == "s": target = ( @@ -425,7 +463,7 @@ def log_comet_images(self, mode, domain): .to(self.device) ) task_saves.append(target) - if update_task == "m": + if update_task == "m" or update_task == "m2": prediction = prediction.repeat(1, 3, 1, 1) task_saves.append(x * (1.0 - prediction)) task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1))) @@ -538,7 +576,6 @@ def train(self): for self.logger.epoch in range( self.logger.epoch, self.logger.epoch + self.opts.train.epochs ): - # self.infer() self.run_epoch() self.infer(verbose=1) if ( @@ -623,6 +660,7 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings x = batch["data"]["x"] self.z = self.G.encode(x) + # --------------------------------- # ----- classifier loss (1) ----- # --------------------------------- @@ -644,7 +682,7 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings # ----- task-specific regression losses (2) ----- # ------------------------------------------------- for update_task, update_target in batch["data"].items(): - if update_task not in {"m", "p", "x", "s"}: + if update_task not in {"m", "p", "x", "s", "m2"}: prediction = self.G.decoders[update_task](self.z) update_loss = self.losses["G"]["tasks"][update_task]( prediction, update_target @@ -699,31 +737,42 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings self.logger.losses.generator.task_loss[update_task][ "advent" ][batch_domain] = update_loss.item() - elif update_task == "m": + elif update_task in {"m", "m2"}: # ? output features classifier - prediction = self.G.decoders[update_task](self.z) - if batch_domain == "s": + if "m2" in self.opts.tasks: + # Get label + if update_task == "m": + prediction = self.G.decoders["m"]( + torch.cat((self.z, self.label_1), dim=1) + ) + elif update_task == "m2": + prediction = self.G.decoders["m"]( + torch.cat((self.z, self.label_2), dim=1) + ) + + else: + prediction = self.G.decoders["m"](self.z) + + if batch_domain == "s": # Main loss first: update_loss = ( - self.losses["G"]["tasks"][update_task]["main"]( + self.losses["G"]["tasks"]["m"]["main"]( prediction, update_target ) - * lambdas.G[update_task]["main"] + * lambdas.G["m"]["main"] ) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task]["main"][ + self.logger.losses.generator.task_loss["m"]["main"][ batch_domain ] = update_loss.item() # Then TV loss - update_loss = self.losses["G"]["tasks"][update_task]["tv"]( - prediction - ) + update_loss = self.losses["G"]["tasks"]["m"]["tv"](prediction) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task]["tv"][ + self.logger.losses.generator.task_loss["m"]["tv"][ batch_domain ] = update_loss.item() @@ -733,20 +782,20 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings if self.opts.gen.m.use_minent: # Then Minent loss update_loss = ( - self.losses["G"]["tasks"][update_task]["minent"]( + self.losses["G"]["tasks"]["m"]["minent"]( prob.to(self.device) ) * self.opts.train.lambdas.advent.ent_main ) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task][ - "minent" - ][batch_domain] = update_loss.item() + self.logger.losses.generator.task_loss["m"]["minent"][ + batch_domain + ] = update_loss.item() if self.opts.gen.m.use_advent: # Then Advent loss update_loss = ( - self.losses["G"]["tasks"][update_task]["advent"]( + self.losses["G"]["tasks"]["m"]["advent"]( prob.to(self.device), self.domain_labels["s"], self.D["m"]["Advent"], @@ -754,9 +803,9 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings * self.opts.train.lambdas.advent.adv_main ) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task][ - "advent" - ][batch_domain] = update_loss.item() + self.logger.losses.generator.task_loss["m"]["advent"][ + batch_domain + ] = update_loss.item() return step_loss def sample_z(self, batch_size): @@ -1000,20 +1049,65 @@ def get_d_loss(self, multi_domain_batch, verbose=0): if self.opts.gen.m.use_advent: if verbose > 0: print("Now training the ADVENT discriminator!") - fake_mask = self.G.decoders["m"](z) - fake_complementary_mask = 1 - fake_mask - prob = torch.cat([fake_mask, fake_complementary_mask], dim=1) - prob = prob.detach() - loss_main = self.losses["D"]["advent"]( - prob.to(self.device), - self.domain_labels[batch_domain], - self.D["m"]["Advent"], - ) + if "m2" in self.opts.tasks: + # Compute loss for flood level 1 + fake_mask = self.G.decoders["m"]( + torch.cat((z, self.label_1), dim=1) + ) + fake_complementary_mask = 1 - fake_mask + prob = torch.cat( + [fake_mask, fake_complementary_mask], dim=1 + ) + prob = prob.detach() - disc_loss["m"]["Advent"] += ( - self.opts.train.lambdas.advent.adv_main * loss_main - ) + loss_main = self.losses["D"]["advent"]( + prob.to(self.device), + self.domain_labels[batch_domain], + self.D["m"]["Advent"], + ) + + disc_loss["m"]["Advent"] += ( + self.opts.train.lambdas.advent.adv_main * loss_main + ) + + # Compute again for flood level 2 + fake_mask = self.G.decoders["m"]( + torch.cat((z, self.label_2), dim=1) + ) + fake_complementary_mask = 1 - fake_mask + prob = torch.cat( + [fake_mask, fake_complementary_mask], dim=1 + ) + prob = prob.detach() + + loss_main += self.losses["D"]["advent"]( + prob.to(self.device), + self.domain_labels[batch_domain], + self.D["m"]["Advent"], + ) + + disc_loss["m"]["Advent"] += ( + self.opts.train.lambdas.advent.adv_main * loss_main + ) + else: + fake_mask = self.G.decoders["m"](z) + + fake_complementary_mask = 1 - fake_mask + prob = torch.cat( + [fake_mask, fake_complementary_mask], dim=1 + ) + prob = prob.detach() + + loss_main = self.losses["D"]["advent"]( + prob.to(self.device), + self.domain_labels[batch_domain], + self.D["m"]["Advent"], + ) + + disc_loss["m"]["Advent"] += ( + self.opts.train.lambdas.advent.adv_main * loss_main + ) if "s" in self.opts.tasks: if self.opts.gen.s.use_advent: preds = self.G.decoders["s"](z) @@ -1214,7 +1308,18 @@ def eval_images(self, mode, domain): x = im_set["data"]["x"].unsqueeze(0).to(self.device) m = im_set["data"]["m"].unsqueeze(0).detach().cpu().numpy() z = self.G.encode(x) - pred_mask = self.G.decoders["m"](z).detach().cpu().numpy() + if "m2" in self.opts.tasks: + pred_mask = ( + self.G.decoders["m"]( + torch.cat((z, self.label_1[0, :, :, :].unsqueeze(0)), dim=1) + ) + .detach() + .cpu() + .numpy() + ) + + else: + pred_mask = self.G.decoders["m"](z).detach().cpu().numpy() # Binarize mask pred_mask[pred_mask > 0.5] = 1.0 diff --git a/omnigan/transforms.py b/omnigan/transforms.py index 31097827..7bedf2a1 100644 --- a/omnigan/transforms.py +++ b/omnigan/transforms.py @@ -78,7 +78,7 @@ def __call__(self, data): for task, im in data.items(): if task in {"x", "a"}: new_data[task] = self.ImagetoTensor(im) - elif task in {"m"}: + elif task in {"m", "m2"}: new_data[task] = self.MaptoTensor(im) elif task == "s": new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to( @@ -103,6 +103,7 @@ def __init__(self): "s": self.normSeg, "d": self.normDepth, "m": self.normMask, + "m2": self.normMask, } def __call__(self, data): diff --git a/tests/test_gen.py b/tests/test_gen.py index 5d0a94ea..17e7e09d 100644 --- a/tests/test_gen.py +++ b/tests/test_gen.py @@ -15,7 +15,7 @@ parser = argparse.ArgumentParser() -parser.add_argument("-c", "--config", default="config/trainer/maskgen_v0.yaml") +parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") args = parser.parse_args() root = Path(__file__).parent.parent opts = load_test_opts(args.config) @@ -61,7 +61,6 @@ print(sum(p.numel() for p in G.decoders.parameters())) G = get_gen(opts).to(device) - # ------------------------------- # ----- Test Architecture ----- # ------------------------------- From d37faddb6a1d22690110159774c7327c0d64a93c Mon Sep 17 00:00:00 2001 From: raghupas Date: Mon, 21 Sep 2020 22:04:33 -0400 Subject: [PATCH 2/6] Working multi-level masker --- infer.py | 66 ++++++++++++--- omnigan/discriminator.py | 125 ++++++++++++++++++++++++++++ omnigan/losses.py | 3 +- omnigan/trainer.py | 153 +++++++++++++++++++++++++++++++++++ shared/trainer/defaults.yaml | 20 +++-- 5 files changed, 350 insertions(+), 17 deletions(-) diff --git a/infer.py b/infer.py index 36a48d0c..bb7352b3 100644 --- a/infer.py +++ b/infer.py @@ -44,7 +44,9 @@ def parsed_args(): required=True, ) parser.add_argument( - "--new_size", type=int, help="Size of generated masks", + "--new_size", + type=int, + help="Size of generated masks", ) parser.add_argument( "--output_dir", @@ -58,12 +60,22 @@ def parsed_args(): help="Path of masks to be used for painting", required=False, ) + parser.add_argument( + "--apply_mask", + action="store_true", + help="Apply mask to image to save", + ) return parser.parse_args() def eval_folder( - output_dir, path_to_images, path_to_masks=None, paint=False, masker=False + output_dir, + path_to_images, + path_to_masks=None, + paint=False, + masker=False, + apply_mask=False, ): image_list = os.listdir(path_to_images) @@ -92,11 +104,43 @@ def eval_folder( mask = mask.unsqueeze(0).to(device) if masker: - z = model.encode(img) - mask = model.decoders["m"](z) - vutils.save_image( - mask, output_dir / ("mask_" + img_path.name), normalize=True - ) + if "m2" in opts.tasks: + z = model.encode(img) + z_aug_1 = torch.cat( + (z, trainer.label_1[0, :, :, :].unsqueeze(0)), + dim=1, + ) + z_aug_2 = torch.cat( + (z, trainer.label_2[0, :, :, :].unsqueeze(0)), + dim=1, + ) + mask_1 = model.decoders["m"](z_aug_1) + mask_2 = model.decoders["m"](z_aug_2) + vutils.save_image( + mask_1, output_dir / ("mask1_" + img_path.name), normalize=True + ) + vutils.save_image( + mask_2, output_dir / ("mask2_" + img_path.name), normalize=True + ) + + if apply_mask: + vutils.save_image( + img * (1.0 - mask_1) + mask_1, + output_dir / (img_path.stem + "img_masked_1" + ".jpg"), + normalize=True, + ) + vutils.save_image( + img * (1.0 - mask_2) + mask_2, + output_dir / (img_path.stem + "img_masked_2" + ".jpg"), + normalize=True, + ) + + else: + z = model.encode(img) + mask = model.decoders["m"](z) + vutils.save_image( + mask, output_dir / ("mask_" + img_path.name), normalize=True + ) if paint: z_painter = trainer.sample_z(1) @@ -186,10 +230,14 @@ def isimg(path_file): if has_imgs: print(f"Eval on {root}") rel_path = root.relative_to(rootdir) - mask_path = maskdir / rel_path write_path = writedir / rel_path write_path.mkdir(parents=True, exist_ok=True) print("root: ", root) eval_folder( - write_path, root, path_to_masks=maskdir, paint=paint, masker=masker + write_path, + root, + path_to_masks=maskdir, + paint=paint, + masker=masker, + apply_mask=args.apply_mask, ) diff --git a/omnigan/discriminator.py b/omnigan/discriminator.py index 753b2215..6529b850 100644 --- a/omnigan/discriminator.py +++ b/omnigan/discriminator.py @@ -233,6 +233,116 @@ def forward(self, input): return result +class AuxiliaryClassifier(nn.Module): + def __init__( + self, + input_size=640, + input_nc=1, + ndf=64, + n_layers=3, + norm_layer=nn.BatchNorm2d, + use_sigmoid=False, + ): + super(AuxiliaryClassifier, self).__init__() + self.input_nc = input_nc + self.ndf = ndf + self.n_layers = n_layers + self.norm_layer = norm_layer + self.use_sigmoid = use_sigmoid + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 3 + padw = 1 + sequence = [ + # Use spectral normalization + SpectralNorm( + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw) + ), + nn.LeakyReLU(0.2, True), + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + # Use spectral normalization + SpectralNorm( # TODO replace with Conv2dBlock + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ) + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + # Use spectral normalization + SpectralNorm( + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ) + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + self.shared_layers = nn.Sequential(*sequence) + + proj_dim = ndf + + self.projection = SpectralNorm( + nn.Conv2d( + ndf * nf_mult + 2048, proj_dim, kernel_size=1, stride=1, padding=0 + ) + ) + + latent_size = int(input_size / (2 ** n_layers)) + self.linear_size = int(proj_dim * latent_size * latent_size) + self.gan_layer = nn.Linear(self.linear_size, 1) + self.ac_layer = nn.Linear(self.linear_size, 2) + + def forward(self, mask, z): + x = self.shared_layers(mask) + x = torch.cat((x, z), dim=1) + x = self.projection(x) + x = x.view(-1, self.linear_size) + + return [self.gan_layer(x), self.ac_layer(x)] + + +def get_AC( + input_size, input_nc, ndf, n_layers=3, norm="batch", use_sigmoid=False, +): + norm_layer = get_norm_layer(norm_type=norm) + net = AuxiliaryClassifier( + input_size=input_size, + input_nc=input_nc, + ndf=ndf, + n_layers=n_layers, + norm_layer=norm_layer, + use_sigmoid=use_sigmoid, + ) + return net + + class OmniDiscriminator(nn.ModuleDict): def __init__(self, opts): super().__init__() @@ -281,6 +391,21 @@ def __init__(self, opts): ) else: raise Exception("This Discriminator is currently not supported!") + if "m2" in opts.tasks: + # Create a flood-level discriminator / classifier + self["m2"] = nn.ModuleDict( + { + "FloodLevel": get_AC( + input_size=640, + input_nc=1, + ndf=opts.dis.m2.ndf, + n_layers=opts.dis.m2.n_layers, + norm=opts.dis.m2.norm, + use_sigmoid=opts.dis.m2.use_sigmoid, + ) + } + ) + if "s" in opts.tasks: if opts.gen.s.use_advent: self["s"] = nn.ModuleDict( diff --git a/omnigan/losses.py b/omnigan/losses.py index 14025ed7..f52814ba 100644 --- a/omnigan/losses.py +++ b/omnigan/losses.py @@ -347,7 +347,7 @@ def get_losses(opts, verbose, device=None): losses = { "G": {"a": {}, "p": {}, "tasks": {}}, - "D": {"default": {}, "advent": {}}, + "D": {"default": {}, "advent": {}, "multilevel": {}}, "C": {}, } @@ -406,6 +406,7 @@ def get_losses(opts, verbose, device=None): soft_shift=opts.dis.soft_shift, flip_prob=opts.dis.flip_prob, verbose=verbose ) losses["D"]["advent"] = ADVENTAdversarialLoss(opts) + losses["D"]["multilevel"] = CrossEntropy() return losses diff --git a/omnigan/trainer.py b/omnigan/trainer.py index 1ad7a858..ad62d0e2 100644 --- a/omnigan/trainer.py +++ b/omnigan/trainer.py @@ -196,6 +196,7 @@ def setup(self): self.G: OmniGenerator = get_gen(self.opts, verbose=self.verbose).to(self.device) if self.G.encoder is not None: self.latent_shape = self.compute_latent_shape() + self.input_shape = self.compute_input_shape() self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up) self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up) @@ -740,17 +741,80 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings elif update_task in {"m", "m2"}: # ? output features classifier + # If multi-level flooding if "m2" in self.opts.tasks: # Get label if update_task == "m": prediction = self.G.decoders["m"]( torch.cat((self.z, self.label_1), dim=1) ) + + # AC-GAN loss + validity, auxiliary = self.D["m2"]["FloodLevel"]( + prediction, self.z + ) + + # GAN loss + update_loss = ( + self.losses["D"]["default"](validity, True) + * self.opts.train.lambdas.G.m2.gan + ) + self.logger.losses.generator.task_loss["m"]["gan"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + # Auxiliary loss + update_loss = ( + self.losses["D"]["multilevel"]( + auxiliary, self.label_1[:, 0, 0, 0].squeeze() + ) + * self.opts.train.lambdas.G.m2.aux + ) + + self.logger.losses.generator.task_loss["m"]["aux"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + elif update_task == "m2": prediction = self.G.decoders["m"]( torch.cat((self.z, self.label_2), dim=1) ) + # AC GAN loss + validity, auxiliary = self.D["m2"]["FloodLevel"]( + prediction, self.z + ) + + # GAN loss + update_loss = ( + self.losses["D"]["default"](validity, True) + * self.opts.train.lambdas.G.m2.gan + ) + self.logger.losses.generator.task_loss["m2"]["gan"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + # Auxiliary loss + update_loss = ( + self.losses["D"]["multilevel"]( + auxiliary, self.label_2[:, 0, 0, 0].squeeze() + ) + * self.opts.train.lambdas.G.m2.aux + ) + + self.logger.losses.generator.task_loss["m2"]["aux"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + # No multi-level flooding else: prediction = self.G.decoders["m"](self.z) @@ -777,6 +841,7 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings ] = update_loss.item() if batch_domain == "r": + # -----------ADVENT losses-------------- pred_complementary = 1 - prediction prob = torch.cat([prediction, pred_complementary], dim=1) if self.opts.gen.m.use_minent: @@ -806,6 +871,8 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings self.logger.losses.generator.task_loss["m"]["advent"][ batch_domain ] = update_loss.item() + # ------------------------------------- + return step_loss def sample_z(self, batch_size): @@ -1010,11 +1077,14 @@ def get_d_loss(self, multi_domain_batch, verbose=0): "m": {"Advent": 0}, "s": {"Advent": 0}, "p": {"global": 0, "local": 0}, + "m2": {"gan": 0, "aux": 0}, } for batch_domain, batch in multi_domain_batch.items(): x = batch["data"]["x"] m = batch["data"]["m"] + if "m2" in self.opts.tasks: + m2 = batch["data"]["m2"] if batch_domain == "rf": # sample vector @@ -1051,6 +1121,7 @@ def get_d_loss(self, multi_domain_batch, verbose=0): print("Now training the ADVENT discriminator!") if "m2" in self.opts.tasks: + # --------ADVENT LOSS--------------- # Compute loss for flood level 1 fake_mask = self.G.decoders["m"]( torch.cat((z, self.label_1), dim=1) @@ -1071,6 +1142,28 @@ def get_d_loss(self, multi_domain_batch, verbose=0): self.opts.train.lambdas.advent.adv_main * loss_main ) + # ---------------------------------- + # ------------AC GAN LOSS------------- + fake_validity, fake_auxiliary = self.D["m2"]["FloodLevel"]( + fake_mask.detach(), self.z.detach() + ) + + gan_loss = self.losses["D"]["default"](fake_validity, False) + + disc_loss["m2"]["gan"] += ( + self.opts.train.lambdas.G.m2.gan * gan_loss + ) + + # AC auxiliary loss + aux_loss = self.losses["D"]["multilevel"]( + fake_auxiliary, self.label_1[:, 0, 0, 0].squeeze() + ) + + disc_loss["m2"]["aux"] += ( + self.opts.train.lambdas.G.m2.aux * aux_loss + ) + # ---------------------------------- + # Compute again for flood level 2 fake_mask = self.G.decoders["m"]( torch.cat((z, self.label_2), dim=1) @@ -1090,6 +1183,66 @@ def get_d_loss(self, multi_domain_batch, verbose=0): disc_loss["m"]["Advent"] += ( self.opts.train.lambdas.advent.adv_main * loss_main ) + + # Fake AC GAN loss + fake_validity, fake_auxiliary = self.D["m2"]["FloodLevel"]( + fake_mask.detach(), self.z.detach() + ) + + gan_loss = self.losses["D"]["default"](fake_validity, False) + + disc_loss["m2"]["gan"] += ( + self.opts.train.lambdas.G.m2.gan * gan_loss + ) + + # AC auxiliary loss + aux_loss = self.losses["D"]["multilevel"]( + fake_auxiliary, self.label_1[:, 0, 0, 0].squeeze() + ) + + disc_loss["m2"]["aux"] += ( + self.opts.train.lambdas.G.m2.aux * aux_loss + ) + + # AC-GAN loss for groundtruth (if in simulated domain) + if batch_domain == "s": + # For flood level 1: + validity, auxiliary = self.D["m2"]["FloodLevel"]( + m, self.z.detach() + ) + gan_loss = self.losses["D"]["default"](validity, True) + + disc_loss["m2"]["gan"] += ( + self.opts.train.lambdas.G.m2.gan * gan_loss + ) + + # AC auxiliary loss + aux_loss = self.losses["D"]["multilevel"]( + auxiliary, self.label_1[:, 0, 0, 0].squeeze() + ) + + disc_loss["m2"]["aux"] += ( + self.opts.train.lambdas.G.m2.aux * aux_loss + ) + + # For flood level 2 + validity, auxiliary = self.D["m2"]["FloodLevel"]( + m2, self.z.detach() + ) + gan_loss = self.losses["D"]["default"](validity, True) + + disc_loss["m2"]["gan"] += ( + self.opts.train.lambdas.G.m2.gan * gan_loss + ) + + # AC auxiliary loss + aux_loss = self.losses["D"]["multilevel"]( + auxiliary, self.label_2[:, 0, 0, 0].squeeze() + ) + + disc_loss["m2"]["aux"] += ( + self.opts.train.lambdas.G.m2.aux * aux_loss + ) else: fake_mask = self.G.decoders["m"](z) diff --git a/shared/trainer/defaults.yaml b/shared/trainer/defaults.yaml index b507a490..9a70963f 100644 --- a/shared/trainer/defaults.yaml +++ b/shared/trainer/defaults.yaml @@ -32,14 +32,14 @@ data: p: 0.5 - name: resize ignore: false - new_size: 256 + new_size: 640 - name: crop - ignore: false - height: 224 - width: 224 + ignore: true + height: 560 + width: 560 - name: resize # ? this or change generator's output? Or resize larger then crop to 256? ignore: false - new_size: 256 + new_size: 640 # --------------------- # ----- Generator ----- @@ -130,12 +130,15 @@ dis: get_intermediate_features: false p: <<: *default-dis + num_D: 3 + get_intermediate_features: true m: <<: *default-dis multi_level: false architecture: base # can be [base | OmniDiscriminator] - num_D: 1 #Number of discriminators to use (>1 means multi-scale) - get_intermediate_features: True + m2: + <<: *default-dis + input_nc: 1 # ------------------------------- # ----- Domain Classifier ----- @@ -173,6 +176,9 @@ train: m: main: 1 # Main prediction loss, i.e. GAN or BCE tv: 1 # Total variational loss (for smoothing) + m2: + gan: 0.01 + aux: 0.01 p: gan: 1 # gan loss sm: 1 # semantic matching From 84203df512d7d38a87d2810048ac586edd408abc Mon Sep 17 00:00:00 2001 From: raghupas Date: Mon, 21 Sep 2020 22:07:45 -0400 Subject: [PATCH 3/6] Removing prints --- omnigan/trainer.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/omnigan/trainer.py b/omnigan/trainer.py index ad62d0e2..0264e591 100644 --- a/omnigan/trainer.py +++ b/omnigan/trainer.py @@ -36,8 +36,7 @@ class Trainer: - """Main trainer class - """ + """Main trainer class""" def __init__(self, opts, comet_exp=None, verbose=0): """Trainer class to gather various model training procedures @@ -183,9 +182,9 @@ def print_num_parameters(self): def setup(self): """Prepare the trainer before it can be used to train the models: - * initialize G and D - * compute latent space dims and create classifier accordingly - * creates 3 optimizers + * initialize G and D + * compute latent space dims and create classifier accordingly + * creates 3 optimizers """ self.logger.global_step = 0 start_time = time() @@ -217,8 +216,6 @@ def setup(self): self.latent_shape[2], ) ).to(self.device) - print("label sizes: ", self.label_1.shape) - print(self.label_2.shape) self.D: OmniDiscriminator = get_dis(self.opts, verbose=self.verbose).to( self.device From 332d155fc03b69a531fd58df5cafa6d470acc04c Mon Sep 17 00:00:00 2001 From: raghupas Date: Mon, 28 Sep 2020 12:00:15 -0400 Subject: [PATCH 4/6] Interpolate between flood levels --- infer.py | 33 ++++++++++++++++++++++++++++++--- 1 file changed, 30 insertions(+), 3 deletions(-) diff --git a/infer.py b/infer.py index bb7352b3..6b92e14f 100644 --- a/infer.py +++ b/infer.py @@ -1,5 +1,5 @@ import torch -import numpy +import numpy as np from omnigan.utils import load_opts from pathlib import Path from argparse import ArgumentParser @@ -98,14 +98,34 @@ def eval_folder( img = img.unsqueeze(0).to(device) if not masker: - mask = tensor_loader(masks[i], task="m", domain="val") - mask = F.interpolate(mask, (new_size, new_size), mode="nearest") + mask = tensor_loader(masks[i], task="m", domain="val", binarize=False) + # mask = F.interpolate(mask, (new_size, new_size), mode="nearest") mask = mask.squeeze() mask = mask.unsqueeze(0).to(device) if masker: if "m2" in opts.tasks: z = model.encode(img) + num_masks = 10 + label_vals = np.linspace(start=0, stop=1, num=num_masks) + for label_val in label_vals: + z_aug = torch.cat( + (z, label_val * trainer.label_2[0, :, :, :].unsqueeze(0)), + dim=1, + ) + mask = model.decoders["m"](z_aug) + + vutils.save_image( + mask, output_dir / (f"mask_{label_val}_" + img_path.name), normalize=True + ) + if apply_mask: + vutils.save_image( + img * (1.0 - mask) + mask, + output_dir / (img_path.stem + f"img_masked_{label_val}" + ".jpg"), + normalize=True, + ) + + """ z_aug_1 = torch.cat( (z, trainer.label_1[0, :, :, :].unsqueeze(0)), dim=1, @@ -134,6 +154,7 @@ def eval_folder( output_dir / (img_path.stem + "img_masked_2" + ".jpg"), normalize=True, ) + """ else: z = model.encode(img) @@ -146,6 +167,12 @@ def eval_folder( z_painter = trainer.sample_z(1) fake_flooded = model.painter(z_painter, img * (1.0 - mask)) vutils.save_image(fake_flooded, output_dir / img_path.name, normalize=True) + if apply_mask: + vutils.save_image( + img * (1.0 - mask) + mask, + output_dir / (img_path.stem + "_masked" + ".jpg"), + normalize=True, + ) def isimg(path_file): From 0fff5ea63f611d004acdda9723b9ef6853f77aa3 Mon Sep 17 00:00:00 2001 From: raghupas Date: Mon, 28 Sep 2020 12:03:04 -0400 Subject: [PATCH 5/6] Interpolating between flood levels --- infer.py | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/infer.py b/infer.py index 6b92e14f..5e85eca3 100644 --- a/infer.py +++ b/infer.py @@ -125,37 +125,6 @@ def eval_folder( normalize=True, ) - """ - z_aug_1 = torch.cat( - (z, trainer.label_1[0, :, :, :].unsqueeze(0)), - dim=1, - ) - z_aug_2 = torch.cat( - (z, trainer.label_2[0, :, :, :].unsqueeze(0)), - dim=1, - ) - mask_1 = model.decoders["m"](z_aug_1) - mask_2 = model.decoders["m"](z_aug_2) - vutils.save_image( - mask_1, output_dir / ("mask1_" + img_path.name), normalize=True - ) - vutils.save_image( - mask_2, output_dir / ("mask2_" + img_path.name), normalize=True - ) - - if apply_mask: - vutils.save_image( - img * (1.0 - mask_1) + mask_1, - output_dir / (img_path.stem + "img_masked_1" + ".jpg"), - normalize=True, - ) - vutils.save_image( - img * (1.0 - mask_2) + mask_2, - output_dir / (img_path.stem + "img_masked_2" + ".jpg"), - normalize=True, - ) - """ - else: z = model.encode(img) mask = model.decoders["m"](z) From b056cde5c171c559d3fd043192df17e004badd78 Mon Sep 17 00:00:00 2001 From: raghupas Date: Sat, 3 Oct 2020 18:33:17 -0400 Subject: [PATCH 6/6] Cleaning up multi-level masker calls --- omnigan/trainer.py | 226 +++++++++++++++++++++------------------------ 1 file changed, 103 insertions(+), 123 deletions(-) diff --git a/omnigan/trainer.py b/omnigan/trainer.py index c9b23a26..9b1ed36c 100644 --- a/omnigan/trainer.py +++ b/omnigan/trainer.py @@ -223,7 +223,6 @@ def setup(self): if self.G.encoder is not None: self.latent_shape = self.compute_latent_shape() - self.input_shape = self.compute_input_shape() if "s" in self.opts.tasks: self.G.decoders["s"].set_target_size(self.input_shape[-2:]) @@ -462,26 +461,32 @@ def log_comet_images(self, mode, domain): if update_task not in save_images: save_images[update_task] = [] - if update_task == "m2": - prediction = self.G.decoders["m"]( - torch.cat( - (self.z, self.label_2[0, :, :, :].unsqueeze(0)), - dim=1, - ) - ) - elif update_task == "m": - if "m2" in self.opts.tasks: - prediction = self.G.decoders[update_task]( + if update_task == "m" or update_task == "m2": + if update_task == "m": + if "m2" in self.opts.tasks: + prediction = self.G.decoders[update_task]( + torch.cat( + (self.z, self.label_1[0, :, :, :].unsqueeze(0)), + dim=1, + ) + ) + else: + prediction = self.G.decoders[update_task](self.z) + + if update_task == "m2": + prediction = self.G.decoders["m"]( torch.cat( - (self.z, self.label_1[0, :, :, :].unsqueeze(0)), + (self.z, self.label_2[0, :, :, :].unsqueeze(0)), dim=1, ) ) - else: - prediction = self.G.decoders[update_task](self.z) - if update_task == "s": + prediction = prediction.repeat(1, 3, 1, 1) + task_saves.append(x * (1.0 - prediction)) + task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1))) + + elif update_task == "s": if domain == "s": target = ( decode_segmap_unity_labels(target, domain, True) @@ -494,12 +499,8 @@ def log_comet_images(self, mode, domain): .to(self.device) ) task_saves.append(target) - if update_task == "m" or update_task == "m2": - prediction = prediction.repeat(1, 3, 1, 1) - task_saves.append(x * (1.0 - prediction)) - task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1))) - if update_task == "d": + elif update_task == "d": # prediction is a log depth tensor target = (norm_tensor(target)) * 255 prediction = (norm_tensor(prediction)) * 255 @@ -619,7 +620,7 @@ def write_images( image_grid = vutils.make_grid( ims, nrow=im_per_row, normalize=True, scale_each=True ) - image_grid = image_grid.permute(1, 2, 0).numpy() + image_grid = image_grid.permute(1, 2, 0).cpu().numpy() if comet_exp is not None: comet_exp.log_image( @@ -809,85 +810,26 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings "advent" ][batch_domain] = update_loss.item() elif update_task in {"m", "m2"}: - # ? output features classifier - # If multi-level flooding if "m2" in self.opts.tasks: # Get label if update_task == "m": - prediction = self.G.decoders["m"]( - torch.cat((self.z, self.label_1), dim=1) - ) - - # AC-GAN loss - validity, auxiliary = self.D["m2"]["FloodLevel"]( - prediction, self.z - ) - - # GAN loss - update_loss = ( - self.losses["D"]["default"](validity, True) - * self.opts.train.lambdas.G.m2.gan + prediction, m_step_loss = self.multi_level_loss( + label_idx=1, batch_domain=batch_domain ) - self.logger.losses.generator.task_loss["m"]["gan"][ - batch_domain - ] = update_loss.item() - - step_loss += update_loss - - # Auxiliary loss - update_loss = ( - self.losses["D"]["multilevel"]( - auxiliary, self.label_1[:, 0, 0, 0].squeeze() - ) - * self.opts.train.lambdas.G.m2.aux - ) - - self.logger.losses.generator.task_loss["m"]["aux"][ - batch_domain - ] = update_loss.item() - - step_loss += update_loss elif update_task == "m2": - prediction = self.G.decoders["m"]( - torch.cat((self.z, self.label_2), dim=1) - ) - - # AC GAN loss - validity, auxiliary = self.D["m2"]["FloodLevel"]( - prediction, self.z - ) - - # GAN loss - update_loss = ( - self.losses["D"]["default"](validity, True) - * self.opts.train.lambdas.G.m2.gan + prediction, m_step_loss = self.multi_level_loss( + label_idx=2, batch_domain=batch_domain ) - self.logger.losses.generator.task_loss["m2"]["gan"][ - batch_domain - ] = update_loss.item() - - step_loss += update_loss - - # Auxiliary loss - update_loss = ( - self.losses["D"]["multilevel"]( - auxiliary, self.label_2[:, 0, 0, 0].squeeze() - ) - * self.opts.train.lambdas.G.m2.aux - ) - - self.logger.losses.generator.task_loss["m2"]["aux"][ - batch_domain - ] = update_loss.item() - - step_loss += update_loss + + step_loss += m_step_loss # No multi-level flooding else: prediction = self.G.decoders["m"](self.z) + # Simulated domain --> compare w/ ground truth masks if batch_domain == "s": # Main loss first: update_loss = ( @@ -902,14 +844,7 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings batch_domain ] = update_loss.item() - # Then TV loss - update_loss = self.losses["G"]["tasks"]["m"]["tv"](prediction) - step_loss += update_loss - - self.logger.losses.generator.task_loss["m"]["tv"][ - batch_domain - ] = update_loss.item() - + # Real domain --> ADVENT domain adaptation if batch_domain == "r": # -----------ADVENT losses-------------- pred_complementary = 1 - prediction @@ -943,8 +878,52 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings ] = update_loss.item() # ------------------------------------- + # Then TV loss + update_loss = self.losses["G"]["tasks"]["m"]["tv"](prediction) + step_loss += update_loss + + self.logger.losses.generator.task_loss["m"]["tv"][ + batch_domain + ] = update_loss.item() + return step_loss + def multi_level_loss(self, batch_domain, label_idx): + step_loss = 0 + labels = {1: self.label_1, 2: self.label_2} + + prediction = self.G.decoders["m"](torch.cat((self.z, labels[label_idx]), dim=1)) + + # AC-GAN loss + validity, auxiliary = self.D["m2"]["FloodLevel"](prediction, self.z) + + # GAN loss + update_loss = ( + self.losses["D"]["default"](validity, True) + * self.opts.train.lambdas.G.m2.gan + ) + self.logger.losses.generator.task_loss["m"]["gan"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + # Auxiliary loss + update_loss = ( + self.losses["D"]["multilevel"]( + auxiliary, labels[label_idx][:, 0, 0, 0].squeeze() + ) + * self.opts.train.lambdas.G.m2.aux + ) + + self.logger.losses.generator.task_loss["m"]["aux"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + return prediction, step_loss + def sample_z(self, batch_size): return ( torch.empty( @@ -1407,39 +1386,40 @@ def get_classifier_loss(self, multi_domain_batch): def run_evaluation(self, verbose=0): print("******************* Running Evaluation ***********************") self.eval_mode() - val_logger = None - nb_of_batches = None - for i, multi_batch_tuple in enumerate(self.val_loaders): - # create a dictionnary (domain => batch) from tuple - # (batch_domain_0, ..., batch_domain_i) - # and send it to self.device - nb_of_batches = i + 1 - multi_domain_batch = { - batch["domain"][0]: self.batch_to_device(batch) - for batch in multi_batch_tuple - } - self.get_g_loss(multi_domain_batch, verbose) + with torch.no_grad(): + val_logger = None + nb_of_batches = None + for i, multi_batch_tuple in enumerate(self.val_loaders): + # create a dictionnary (domain => batch) from tuple + # (batch_domain_0, ..., batch_domain_i) + # and send it to self.device + nb_of_batches = i + 1 + multi_domain_batch = { + batch["domain"][0]: self.batch_to_device(batch) + for batch in multi_batch_tuple + } + self.get_g_loss(multi_domain_batch, verbose) - if val_logger is None: - val_logger = deepcopy(self.logger.losses.generator) - else: - val_logger = sum_dict(val_logger, self.logger.losses.generator) + if val_logger is None: + val_logger = deepcopy(self.logger.losses.generator) + else: + val_logger = sum_dict(val_logger, self.logger.losses.generator) - val_logger = div_dict(val_logger, nb_of_batches) - self.logger.losses.generator = val_logger - self.log_losses(model_to_update="G", mode="val") + val_logger = div_dict(val_logger, nb_of_batches) + self.logger.losses.generator = val_logger + self.log_losses(model_to_update="G", mode="val") - for d in self.opts.domains: - self.log_comet_images("train", d) - self.log_comet_images("val", d) + for d in self.opts.domains: + self.log_comet_images("train", d) + self.log_comet_images("val", d) - if "m" in self.opts.tasks and "p" in self.opts.tasks: - self.log_comet_combined_images("train", "r") - self.log_comet_combined_images("val", "r") + if "m" in self.opts.tasks and "p" in self.opts.tasks: + self.log_comet_combined_images("train", "r") + self.log_comet_combined_images("val", "r") - if "m" in self.opts.tasks: - self.eval_images("val", "r") - self.eval_images("val", "s") + if "m" in self.opts.tasks: + self.eval_images("val", "r") + self.eval_images("val", "s") self.train_mode() print("****************** Done *********************")