diff --git a/infer.py b/infer.py index 449d67f0..5e85eca3 100644 --- a/infer.py +++ b/infer.py @@ -1,5 +1,5 @@ import torch -import numpy +import numpy as np from omnigan.utils import load_opts from pathlib import Path from argparse import ArgumentParser @@ -44,7 +44,9 @@ def parsed_args(): required=True, ) parser.add_argument( - "--new_size", type=int, help="Size of generated masks", + "--new_size", + type=int, + help="Size of generated masks", ) parser.add_argument( "--output_dir", @@ -52,14 +54,41 @@ def parsed_args(): type=str, help="Directory to write images to", ) + parser.add_argument( + "--path_to_masks", + type=str, + help="Path of masks to be used for painting", + required=False, + ) + parser.add_argument( + "--apply_mask", + action="store_true", + help="Apply mask to image to save", + ) return parser.parse_args() -def eval_folder(path_to_images, output_dir, paint=False): - images = [path_to_images / Path(i) for i in os.listdir(path_to_images)] - for img_path in images: +def eval_folder( + output_dir, + path_to_images, + path_to_masks=None, + paint=False, + masker=False, + apply_mask=False, +): + + image_list = os.listdir(path_to_images) + image_list.sort() + images = [path_to_images / Path(i) for i in image_list] + if not masker: + mask_list = os.listdir(path_to_masks) + mask_list.sort() + masks = [path_to_masks / Path(i) for i in mask_list] + + for i, img_path in enumerate(images): img = tensor_loader(img_path, task="x", domain="val") + # Resize img: img = F.interpolate(img, (new_size, new_size), mode="nearest") img = img.squeeze(0) @@ -67,15 +96,52 @@ def eval_folder(path_to_images, output_dir, paint=False): img = tf(img) img = img.unsqueeze(0).to(device) - z = model.encode(img) - mask = model.decoders["m"](z) - vutils.save_image(mask, output_dir / ("mask_" + img_path.name), normalize=True) + if not masker: + mask = tensor_loader(masks[i], task="m", domain="val", binarize=False) + # mask = F.interpolate(mask, (new_size, new_size), mode="nearest") + 
mask = mask.squeeze() + mask = mask.unsqueeze(0).to(device) + + if masker: + if "m2" in opts.tasks: + z = model.encode(img) + num_masks = 10 + label_vals = np.linspace(start=0, stop=1, num=num_masks) + for label_val in label_vals: + z_aug = torch.cat( + (z, label_val * trainer.label_2[0, :, :, :].unsqueeze(0)), + dim=1, + ) + mask = model.decoders["m"](z_aug) + + vutils.save_image( + mask, output_dir / (f"mask_{label_val}_" + img_path.name), normalize=True + ) + if apply_mask: + vutils.save_image( + img * (1.0 - mask) + mask, + output_dir / (img_path.stem + f"img_masked_{label_val}" + ".jpg"), + normalize=True, + ) + + else: + z = model.encode(img) + mask = model.decoders["m"](z) + vutils.save_image( + mask, output_dir / ("mask_" + img_path.name), normalize=True + ) if paint: z_painter = trainer.sample_z(1) fake_flooded = model.painter(z_painter, img * (1.0 - mask)) vutils.save_image(fake_flooded, output_dir / img_path.name, normalize=True) + if apply_mask: + vutils.save_image( + img * (1.0 - mask) + mask, + output_dir / (img_path.stem + "_masked" + ".jpg"), + normalize=True, + ) def isimg(path_file): @@ -113,10 +179,12 @@ def isimg(path_file): else: new_size = args.new_size - if "m" in opts.tasks and "p" in opts.tasks: + paint = False + masker = False + if "p" in opts.tasks: paint = True - else: - paint = False + if "m" in opts.tasks: + masker = True # ------------------------ # ----- Define model ----- # ------------------------ @@ -142,6 +210,7 @@ def isimg(path_file): # eval_folder(args.path_to_images, output_dir) rootdir = args.path_to_images + maskdir = args.path_to_masks writedir = args.output_dir for root, subdirs, files in tqdm(os.walk(rootdir)): @@ -151,10 +220,6 @@ def isimg(path_file): has_imgs = False for f in files: if isimg(f): - # read_path = root / f - # rel_path = read_path.relative_to(rootdir) - # write_path = writedir / rel_path - # write_path.mkdir(parents=True, exist_ok=True) has_imgs = True break @@ -163,4 +228,12 @@ def isimg(path_file): 
rel_path = root.relative_to(rootdir) write_path = writedir / rel_path write_path.mkdir(parents=True, exist_ok=True) - eval_folder(root, write_path, paint) + print("root: ", root) + eval_folder( + write_path, + root, + path_to_masks=maskdir, + paint=paint, + masker=masker, + apply_mask=args.apply_mask, + ) diff --git a/omnigan/data.py b/omnigan/data.py index 2a4c998a..39fa585b 100644 --- a/omnigan/data.py +++ b/omnigan/data.py @@ -264,7 +264,7 @@ def tensor_loader(path, task, domain): arr = np.moveaxis(arr, 2, 0) elif task == "s": arr = np.moveaxis(arr, 2, 0) - elif task == "m": + elif task == "m" or task == "m2": arr[arr != 0] = 1 # Make sure mask is single-channel if len(arr.shape) >= 3: diff --git a/omnigan/discriminator.py b/omnigan/discriminator.py index 753b2215..6529b850 100644 --- a/omnigan/discriminator.py +++ b/omnigan/discriminator.py @@ -233,6 +233,116 @@ def forward(self, input): return result +class AuxiliaryClassifier(nn.Module): + def __init__( + self, + input_size=640, + input_nc=1, + ndf=64, + n_layers=3, + norm_layer=nn.BatchNorm2d, + use_sigmoid=False, + ): + super(AuxiliaryClassifier, self).__init__() + self.input_nc = input_nc + self.ndf = ndf + self.n_layers = n_layers + self.norm_layer = norm_layer + self.use_sigmoid = use_sigmoid + + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + + kw = 3 + padw = 1 + sequence = [ + # Use spectral normalization + SpectralNorm( + nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw) + ), + nn.LeakyReLU(0.2, True), + ] + + nf_mult = 1 + nf_mult_prev = 1 + for n in range(1, n_layers): + nf_mult_prev = nf_mult + nf_mult = min(2 ** n, 8) + sequence += [ + # Use spectral normalization + SpectralNorm( # TODO replace with Conv2dBlock + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=2, + padding=padw, + bias=use_bias, + ) + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, 
True), + ] + + nf_mult_prev = nf_mult + nf_mult = min(2 ** n_layers, 8) + sequence += [ + # Use spectral normalization + SpectralNorm( + nn.Conv2d( + ndf * nf_mult_prev, + ndf * nf_mult, + kernel_size=kw, + stride=1, + padding=padw, + bias=use_bias, + ) + ), + norm_layer(ndf * nf_mult), + nn.LeakyReLU(0.2, True), + ] + + self.shared_layers = nn.Sequential(*sequence) + + proj_dim = ndf + + self.projection = SpectralNorm( + nn.Conv2d( + ndf * nf_mult + 2048, proj_dim, kernel_size=1, stride=1, padding=0 + ) + ) + + latent_size = int(input_size / (2 ** n_layers)) + self.linear_size = int(proj_dim * latent_size * latent_size) + self.gan_layer = nn.Linear(self.linear_size, 1) + self.ac_layer = nn.Linear(self.linear_size, 2) + + def forward(self, mask, z): + x = self.shared_layers(mask) + x = torch.cat((x, z), dim=1) + x = self.projection(x) + x = x.view(-1, self.linear_size) + + return [self.gan_layer(x), self.ac_layer(x)] + + +def get_AC( + input_size, input_nc, ndf, n_layers=3, norm="batch", use_sigmoid=False, +): + norm_layer = get_norm_layer(norm_type=norm) + net = AuxiliaryClassifier( + input_size=input_size, + input_nc=input_nc, + ndf=ndf, + n_layers=n_layers, + norm_layer=norm_layer, + use_sigmoid=use_sigmoid, + ) + return net + + class OmniDiscriminator(nn.ModuleDict): def __init__(self, opts): super().__init__() @@ -281,6 +391,21 @@ def __init__(self, opts): ) else: raise Exception("This Discriminator is currently not supported!") + if "m2" in opts.tasks: + # Create a flood-level discriminator / classifier + self["m2"] = nn.ModuleDict( + { + "FloodLevel": get_AC( + input_size=640, + input_nc=1, + ndf=opts.dis.m2.ndf, + n_layers=opts.dis.m2.n_layers, + norm=opts.dis.m2.norm, + use_sigmoid=opts.dis.m2.use_sigmoid, + ) + } + ) + if "s" in opts.tasks: if opts.gen.s.use_advent: self["s"] = nn.ModuleDict( diff --git a/omnigan/generator.py b/omnigan/generator.py index e4691a56..0a4e5a47 100644 --- a/omnigan/generator.py +++ b/omnigan/generator.py @@ -84,7 +84,10 @@ def 
__init__(self, opts, latent_shape=None, verbose=None): self.decoders["s"] = SegmentationDecoder(opts) if "m" in opts.tasks and not opts.gen.m.ignore: - self.decoders["m"] = MaskDecoder(opts) + if "m2" in opts.tasks: + self.decoders["m"] = ConditionalMasker(opts) + else: + self.decoders["m"] = MaskDecoder(opts) self.decoders = nn.ModuleDict(self.decoders) @@ -116,6 +119,21 @@ def __init__(self, opts): ) +class ConditionalMasker(BaseDecoder): + def __init__(self, opts): + super().__init__( + n_upsample=opts.gen.m.n_upsample, + n_res=opts.gen.m.n_res, + input_dim=opts.gen.encoder.res_dim + 1, + proj_dim=opts.gen.m.proj_dim, + output_dim=opts.gen.m.output_dim, + res_norm=opts.gen.m.res_norm, + activ=opts.gen.m.activ, + pad_type=opts.gen.m.pad_type, + output_activ="sigmoid", + ) + + class DepthDecoder(BaseDecoder): def __init__(self, opts): super().__init__( diff --git a/omnigan/losses.py b/omnigan/losses.py index 04c83d79..7a8c9c64 100644 --- a/omnigan/losses.py +++ b/omnigan/losses.py @@ -358,7 +358,7 @@ def get_losses(opts, verbose, device=None): losses = { "G": {"a": {}, "p": {}, "tasks": {}}, - "D": {"default": {}, "advent": {}}, + "D": {"default": {}, "advent": {}, "multilevel": {}}, "C": {}, } @@ -417,6 +417,7 @@ def get_losses(opts, verbose, device=None): soft_shift=opts.dis.soft_shift, flip_prob=opts.dis.flip_prob, verbose=verbose ) losses["D"]["advent"] = ADVENTAdversarialLoss(opts) + losses["D"]["multilevel"] = CrossEntropy() return losses @@ -441,9 +442,7 @@ def __init__(self): def __call__(self, prediction, target): return self.loss( prediction, - torch.FloatTensor(prediction.size()) - .fill_(target) - .to(prediction.get_device()), + torch.FloatTensor(prediction.size()).fill_(target).to(prediction.device), ) diff --git a/omnigan/trainer.py b/omnigan/trainer.py index b86a4b35..9b1ed36c 100644 --- a/omnigan/trainer.py +++ b/omnigan/trainer.py @@ -37,8 +37,7 @@ class Trainer: - """Main trainer class - """ + """Main trainer class""" def __init__(self, opts, 
comet_exp=None, verbose=0): """Trainer class to gather various model training procedures @@ -208,9 +207,9 @@ def print_num_parameters(self): def setup(self): """Prepare the trainer before it can be used to train the models: - * initialize G and D - * compute latent space dims and create classifier accordingly - * creates 3 optimizers + * initialize G and D + * compute latent space dims and create classifier accordingly + * creates 3 optimizers """ self.logger.global_step = 0 start_time = time() @@ -221,6 +220,9 @@ def setup(self): self.G: OmniGenerator = get_gen(self.opts, verbose=self.verbose).to(self.device) print("Generator OK. Computing latent & input shapes...", end="", flush=True) + if self.G.encoder is not None: + self.latent_shape = self.compute_latent_shape() + self.input_shape = self.compute_input_shape() if "s" in self.opts.tasks: self.G.decoders["s"].set_target_size(self.input_shape[-2:]) @@ -228,6 +230,24 @@ def setup(self): print("OK.") self.painter_z_h = self.input_shape[-2] // (2 ** self.opts.gen.p.spade_n_up) self.painter_z_w = self.input_shape[-1] // (2 ** self.opts.gen.p.spade_n_up) + if "m2" in self.opts.tasks: + self.label_1 = torch.zeros( + ( + self.opts.data.loaders.batch_size, + 1, + self.latent_shape[1], + self.latent_shape[2], + ) + ).to(self.device) + self.label_2 = torch.ones( + ( + self.opts.data.loaders.batch_size, + 1, + self.latent_shape[1], + self.latent_shape[2], + ) + ).to(self.device) + self.D: OmniDiscriminator = get_dis(self.opts, verbose=self.verbose).to( self.device ) @@ -442,9 +462,31 @@ def log_comet_images(self, mode, domain): if update_task not in save_images: save_images[update_task] = [] - prediction = self.G.decoders[update_task](self.z) + if update_task == "m" or update_task == "m2": + if update_task == "m": + if "m2" in self.opts.tasks: + prediction = self.G.decoders[update_task]( + torch.cat( + (self.z, self.label_1[0, :, :, :].unsqueeze(0)), + dim=1, + ) + ) + else: + prediction = 
self.G.decoders[update_task](self.z) + + if update_task == "m2": + prediction = self.G.decoders["m"]( + torch.cat( + (self.z, self.label_2[0, :, :, :].unsqueeze(0)), + dim=1, + ) + ) + + prediction = prediction.repeat(1, 3, 1, 1) + task_saves.append(x * (1.0 - prediction)) + task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1))) - if update_task == "s": + elif update_task == "s": if domain == "s": target = ( decode_segmap_unity_labels(target, domain, True) @@ -458,23 +500,17 @@ def log_comet_images(self, mode, domain): ) task_saves.append(target) - elif update_task == "m": - prediction = prediction.repeat(1, 3, 1, 1) - task_saves.append(x * (1.0 - prediction)) - task_saves.append(x * (1.0 - target.repeat(1, 3, 1, 1))) - elif update_task == "d": # prediction is a log depth tensor target = (norm_tensor(target)) * 255 prediction = (norm_tensor(prediction)) * 255 prediction = prediction.repeat(1, 3, 1, 1) task_saves.append(target.repeat(1, 3, 1, 1)) - task_saves.append(prediction) - save_images[update_task].append(x.cpu().detach()) - + # ! 
This assumes the output is some kind of image + save_images[update_task].append(x) for im in task_saves: - save_images[update_task].append(im.cpu().detach()) + save_images[update_task].append(im) for task in save_images.keys(): # Write images: @@ -584,7 +620,7 @@ def write_images( image_grid = vutils.make_grid( ims, nrow=im_per_row, normalize=True, scale_each=True ) - image_grid = image_grid.permute(1, 2, 0).numpy() + image_grid = image_grid.permute(1, 2, 0).cpu().numpy() if comet_exp is not None: comet_exp.log_image( @@ -689,6 +725,7 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings x = batch["data"]["x"] self.z = self.G.encode(x) + # --------------------------------- # ----- classifier loss (1) ----- # --------------------------------- @@ -710,7 +747,7 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings # ----- task-specific regression losses (2) ----- # ------------------------------------------------- for update_task, update_target in batch["data"].items(): - if update_task not in {"m", "p", "x", "s"}: + if update_task not in {"m", "p", "x", "s", "m2"}: prediction = self.G.decoders[update_task](self.z) update_loss = self.losses["G"]["tasks"][update_task]( prediction, update_target @@ -772,54 +809,63 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings self.logger.losses.generator.task_loss[update_task][ "advent" ][batch_domain] = update_loss.item() - elif update_task == "m": - # ? 
output features classifier - prediction = self.G.decoders[update_task](self.z) - if batch_domain == "s": + elif update_task in {"m", "m2"}: + # If multi-level flooding + if "m2" in self.opts.tasks: + # Get label + if update_task == "m": + prediction, m_step_loss = self.multi_level_loss( + label_idx=1, batch_domain=batch_domain + ) + elif update_task == "m2": + prediction, m_step_loss = self.multi_level_loss( + label_idx=2, batch_domain=batch_domain + ) + + step_loss += m_step_loss + + # No multi-level flooding + else: + prediction = self.G.decoders["m"](self.z) + + # Simulated domain --> compare w/ ground truth masks + if batch_domain == "s": # Main loss first: update_loss = ( - self.losses["G"]["tasks"][update_task]["main"]( + self.losses["G"]["tasks"]["m"]["main"]( prediction, update_target ) - * lambdas.G[update_task]["main"] + * lambdas.G["m"]["main"] ) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task]["main"][ + self.logger.losses.generator.task_loss["m"]["main"][ batch_domain ] = update_loss.item() - # Then TV loss - update_loss = self.losses["G"]["tasks"][update_task]["tv"]( - prediction - ) - step_loss += update_loss - - self.logger.losses.generator.task_loss[update_task]["tv"][ - batch_domain - ] = update_loss.item() - + # Real domain --> ADVENT domain adaptation if batch_domain == "r": + # -----------ADVENT losses-------------- pred_complementary = 1 - prediction prob = torch.cat([prediction, pred_complementary], dim=1) if self.opts.gen.m.use_minent: # Then Minent loss update_loss = ( - self.losses["G"]["tasks"][update_task]["minent"]( + self.losses["G"]["tasks"]["m"]["minent"]( prob.to(self.device) ) * self.opts.train.lambdas.advent.ent_main ) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task][ - "minent" - ][batch_domain] = update_loss.item() + self.logger.losses.generator.task_loss["m"]["minent"][ + batch_domain + ] = update_loss.item() if self.opts.gen.m.use_advent: # Then Advent loss 
update_loss = ( - self.losses["G"]["tasks"][update_task]["advent"]( + self.losses["G"]["tasks"]["m"]["advent"]( prob.to(self.device), self.domain_labels["s"], self.D["m"]["Advent"], @@ -827,11 +873,57 @@ def get_masker_loss(self, multi_domain_batch): # TODO update docstrings * self.opts.train.lambdas.advent.adv_main ) step_loss += update_loss - self.logger.losses.generator.task_loss[update_task][ - "advent" - ][batch_domain] = update_loss.item() + self.logger.losses.generator.task_loss["m"]["advent"][ + batch_domain + ] = update_loss.item() + # ------------------------------------- + + # Then TV loss + update_loss = self.losses["G"]["tasks"]["m"]["tv"](prediction) + step_loss += update_loss + + self.logger.losses.generator.task_loss["m"]["tv"][ + batch_domain + ] = update_loss.item() + return step_loss + def multi_level_loss(self, batch_domain, label_idx): + step_loss = 0 + labels = {1: self.label_1, 2: self.label_2} + + prediction = self.G.decoders["m"](torch.cat((self.z, labels[label_idx]), dim=1)) + + # AC-GAN loss + validity, auxiliary = self.D["m2"]["FloodLevel"](prediction, self.z) + + # GAN loss + update_loss = ( + self.losses["D"]["default"](validity, True) + * self.opts.train.lambdas.G.m2.gan + ) + self.logger.losses.generator.task_loss["m"]["gan"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + # Auxiliary loss + update_loss = ( + self.losses["D"]["multilevel"]( + auxiliary, labels[label_idx][:, 0, 0, 0].squeeze() + ) + * self.opts.train.lambdas.G.m2.aux + ) + + self.logger.losses.generator.task_loss["m"]["aux"][ + batch_domain + ] = update_loss.item() + + step_loss += update_loss + + return prediction, step_loss + def sample_z(self, batch_size): return ( torch.empty( @@ -1034,11 +1126,14 @@ def get_d_loss(self, multi_domain_batch, verbose=0): "m": {"Advent": 0}, "s": {"Advent": 0}, "p": {"global": 0, "local": 0}, + "m2": {"gan": 0, "aux": 0}, } for batch_domain, batch in multi_domain_batch.items(): x = batch["data"]["x"] m = 
batch["data"]["m"] + if "m2" in self.opts.tasks: + m2 = batch["data"]["m2"] if batch_domain == "rf": # sample vector @@ -1073,20 +1168,148 @@ def get_d_loss(self, multi_domain_batch, verbose=0): if self.opts.gen.m.use_advent: if verbose > 0: print("Now training the ADVENT discriminator!") - fake_mask = self.G.decoders["m"](z) - fake_complementary_mask = 1 - fake_mask - prob = torch.cat([fake_mask, fake_complementary_mask], dim=1) - prob = prob.detach() - loss_main = self.losses["D"]["advent"]( - prob.to(self.device), - self.domain_labels[batch_domain], - self.D["m"]["Advent"], - ) + if "m2" in self.opts.tasks: + # --------ADVENT LOSS--------------- + # Compute loss for flood level 1 + fake_mask = self.G.decoders["m"]( + torch.cat((z, self.label_1), dim=1) + ) + fake_complementary_mask = 1 - fake_mask + prob = torch.cat( + [fake_mask, fake_complementary_mask], dim=1 + ) + prob = prob.detach() - disc_loss["m"]["Advent"] += ( - self.opts.train.lambdas.advent.adv_main * loss_main - ) + loss_main = self.losses["D"]["advent"]( + prob.to(self.device), + self.domain_labels[batch_domain], + self.D["m"]["Advent"], + ) + + disc_loss["m"]["Advent"] += ( + self.opts.train.lambdas.advent.adv_main * loss_main + ) + + # ---------------------------------- + # ------------AC GAN LOSS------------- + fake_validity, fake_auxiliary = self.D["m2"]["FloodLevel"]( + fake_mask.detach(), self.z.detach() + ) + + gan_loss = self.losses["D"]["default"](fake_validity, False) + + disc_loss["m2"]["gan"] += ( + self.opts.train.lambdas.G.m2.gan * gan_loss + ) + + # AC auxiliary loss + aux_loss = self.losses["D"]["multilevel"]( + fake_auxiliary, self.label_1[:, 0, 0, 0].squeeze() + ) + + disc_loss["m2"]["aux"] += ( + self.opts.train.lambdas.G.m2.aux * aux_loss + ) + # ---------------------------------- + + # Compute again for flood level 2 + fake_mask = self.G.decoders["m"]( + torch.cat((z, self.label_2), dim=1) + ) + fake_complementary_mask = 1 - fake_mask + prob = torch.cat( + [fake_mask, 
fake_complementary_mask], dim=1
+                            )
+                            prob = prob.detach()
+
+                            loss_main = self.losses["D"]["advent"](
+                                prob.to(self.device),
+                                self.domain_labels[batch_domain],
+                                self.D["m"]["Advent"],
+                            )
+
+                            disc_loss["m"]["Advent"] += (
+                                self.opts.train.lambdas.advent.adv_main * loss_main
+                            )
+
+                            # Fake AC GAN loss
+                            fake_validity, fake_auxiliary = self.D["m2"]["FloodLevel"](
+                                fake_mask.detach(), self.z.detach()
+                            )
+
+                            gan_loss = self.losses["D"]["default"](fake_validity, False)
+
+                            disc_loss["m2"]["gan"] += (
+                                self.opts.train.lambdas.G.m2.gan * gan_loss
+                            )
+
+                            # AC auxiliary loss
+                            aux_loss = self.losses["D"]["multilevel"](
+                                fake_auxiliary, self.label_2[:, 0, 0, 0].squeeze()
+                            )
+
+                            disc_loss["m2"]["aux"] += (
+                                self.opts.train.lambdas.G.m2.aux * aux_loss
+                            )
+
+                            # AC-GAN loss for groundtruth (if in simulated domain)
+                            if batch_domain == "s":
+                                # For flood level 1:
+                                validity, auxiliary = self.D["m2"]["FloodLevel"](
+                                    m, self.z.detach()
+                                )
+                                gan_loss = self.losses["D"]["default"](validity, True)
+
+                                disc_loss["m2"]["gan"] += (
+                                    self.opts.train.lambdas.G.m2.gan * gan_loss
+                                )
+
+                                # AC auxiliary loss
+                                aux_loss = self.losses["D"]["multilevel"](
+                                    auxiliary, self.label_1[:, 0, 0, 0].squeeze()
+                                )
+
+                                disc_loss["m2"]["aux"] += (
+                                    self.opts.train.lambdas.G.m2.aux * aux_loss
+                                )
+
+                                # For flood level 2
+                                validity, auxiliary = self.D["m2"]["FloodLevel"](
+                                    m2, self.z.detach()
+                                )
+                                gan_loss = self.losses["D"]["default"](validity, True)
+
+                                disc_loss["m2"]["gan"] += (
+                                    self.opts.train.lambdas.G.m2.gan * gan_loss
+                                )
+
+                                # AC auxiliary loss
+                                aux_loss = self.losses["D"]["multilevel"](
+                                    auxiliary, self.label_2[:, 0, 0, 0].squeeze()
+                                )
+
+                                disc_loss["m2"]["aux"] += (
+                                    self.opts.train.lambdas.G.m2.aux * aux_loss
+                                )
+                        else:
+                            fake_mask = self.G.decoders["m"](z)
+
+                            fake_complementary_mask = 1 - fake_mask
+                            prob = torch.cat(
+                                [fake_mask, fake_complementary_mask], dim=1
+                            )
+                            prob = prob.detach()
+
+                            loss_main = self.losses["D"]["advent"](
+                                prob.to(self.device),
+
self.domain_labels[batch_domain], + self.D["m"]["Advent"], + ) + + disc_loss["m"]["Advent"] += ( + self.opts.train.lambdas.advent.adv_main * loss_main + ) if "s" in self.opts.tasks: if self.opts.gen.s.use_advent: preds = self.G.decoders["s"](z) @@ -1163,39 +1386,40 @@ def get_classifier_loss(self, multi_domain_batch): def run_evaluation(self, verbose=0): print("******************* Running Evaluation ***********************") self.eval_mode() - val_logger = None - nb_of_batches = None - for i, multi_batch_tuple in enumerate(self.val_loaders): - # create a dictionnary (domain => batch) from tuple - # (batch_domain_0, ..., batch_domain_i) - # and send it to self.device - nb_of_batches = i + 1 - multi_domain_batch = { - batch["domain"][0]: self.batch_to_device(batch) - for batch in multi_batch_tuple - } - self.get_g_loss(multi_domain_batch, verbose) + with torch.no_grad(): + val_logger = None + nb_of_batches = None + for i, multi_batch_tuple in enumerate(self.val_loaders): + # create a dictionnary (domain => batch) from tuple + # (batch_domain_0, ..., batch_domain_i) + # and send it to self.device + nb_of_batches = i + 1 + multi_domain_batch = { + batch["domain"][0]: self.batch_to_device(batch) + for batch in multi_batch_tuple + } + self.get_g_loss(multi_domain_batch, verbose) - if val_logger is None: - val_logger = deepcopy(self.logger.losses.generator) - else: - val_logger = sum_dict(val_logger, self.logger.losses.generator) + if val_logger is None: + val_logger = deepcopy(self.logger.losses.generator) + else: + val_logger = sum_dict(val_logger, self.logger.losses.generator) - val_logger = div_dict(val_logger, nb_of_batches) - self.logger.losses.generator = val_logger - self.log_losses(model_to_update="G", mode="val") + val_logger = div_dict(val_logger, nb_of_batches) + self.logger.losses.generator = val_logger + self.log_losses(model_to_update="G", mode="val") - for d in self.opts.domains: - self.log_comet_images("train", d) - self.log_comet_images("val", d) + for d 
in self.opts.domains: + self.log_comet_images("train", d) + self.log_comet_images("val", d) - if "m" in self.opts.tasks and "p" in self.opts.tasks: - self.log_comet_combined_images("train", "r") - self.log_comet_combined_images("val", "r") + if "m" in self.opts.tasks and "p" in self.opts.tasks: + self.log_comet_combined_images("train", "r") + self.log_comet_combined_images("val", "r") - if "m" in self.opts.tasks: - self.eval_images("val", "r") - self.eval_images("val", "s") + if "m" in self.opts.tasks: + self.eval_images("val", "r") + self.eval_images("val", "s") self.train_mode() print("****************** Done *********************") @@ -1293,7 +1517,18 @@ def eval_images(self, mode, domain): x = im_set["data"]["x"].unsqueeze(0).to(self.device) m = im_set["data"]["m"].unsqueeze(0).detach().cpu().numpy() z = self.G.encode(x) - pred_mask = self.G.decoders["m"](z).detach().cpu().numpy() + if "m2" in self.opts.tasks: + pred_mask = ( + self.G.decoders["m"]( + torch.cat((z, self.label_1[0, :, :, :].unsqueeze(0)), dim=1) + ) + .detach() + .cpu() + .numpy() + ) + + else: + pred_mask = self.G.decoders["m"](z).detach().cpu().numpy() # Binarize mask pred_mask[pred_mask > 0.5] = 1.0 diff --git a/omnigan/transforms.py b/omnigan/transforms.py index b6de53fd..cb11bcc5 100644 --- a/omnigan/transforms.py +++ b/omnigan/transforms.py @@ -91,7 +91,7 @@ def __call__(self, data): for task, im in data.items(): if task in {"x", "a"}: new_data[task] = self.ImagetoTensor(im) - elif task in {"m"}: + elif task in {"m", "m2"}: new_data[task] = self.MaptoTensor(im) elif task == "s": new_data[task] = torch.squeeze(torch.from_numpy(np.array(im))).to( @@ -116,6 +116,7 @@ def __init__(self): "s": self.normSeg, "d": self.normDepth, "m": self.normMask, + "m2": self.normMask, } def __call__(self, data): diff --git a/shared/trainer/defaults.yaml b/shared/trainer/defaults.yaml index c0d16802..c52b623e 100644 --- a/shared/trainer/defaults.yaml +++ b/shared/trainer/defaults.yaml @@ -138,6 +138,10 @@ dis: 
<<: *default-dis multi_level: false architecture: base # can be [base | OmniDiscriminator] + m2: + <<: *default-dis + input_nc: 1 + # ------------------------------- # ----- Domain Classifier ----- @@ -176,6 +180,9 @@ train: m: main: 1 # Main prediction loss, i.e. GAN or BCE tv: 1 # Total variational loss (for smoothing) + m2: + gan: 0.01 + aux: 0.01 p: gan: 1 # gan loss sm: 1 # semantic matching diff --git a/tests/test_gen.py b/tests/test_gen.py index 5d0a94ea..17e7e09d 100644 --- a/tests/test_gen.py +++ b/tests/test_gen.py @@ -15,7 +15,7 @@ parser = argparse.ArgumentParser() -parser.add_argument("-c", "--config", default="config/trainer/maskgen_v0.yaml") +parser.add_argument("-c", "--config", default="config/trainer/local_tests.yaml") args = parser.parse_args() root = Path(__file__).parent.parent opts = load_test_opts(args.config) @@ -61,7 +61,6 @@ print(sum(p.numel() for p in G.decoders.parameters())) G = get_gen(opts).to(device) - # ------------------------------- # ----- Test Architecture ----- # -------------------------------