diff --git a/.gitignore b/.gitignore index c390d5982..4850e8961 100644 --- a/.gitignore +++ b/.gitignore @@ -47,3 +47,5 @@ slurm*.out #lightning_logs directory lightning_logs/ + +applications/dynacell/test \ No newline at end of file diff --git a/applications/dynacell/fid.py b/applications/dynacell/fid.py new file mode 100644 index 000000000..d53392dc4 --- /dev/null +++ b/applications/dynacell/fid.py @@ -0,0 +1,207 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm +from iohub.ngff import open_ome_zarr +from torch import Tensor + +from vae_3d.vae_3d_config import VAE3DConfig +from vae_3d.vae_3d_model import VAE3DModel + +# ----------------------------------------------------------------------------- # +# Helper functions # +# ----------------------------------------------------------------------------- # + +def read_zarr(zarr_path: str): + plate = open_ome_zarr(zarr_path, mode="r") + return [pos for _, pos in plate.positions()] + +def normalise(volume: torch.Tensor) -> torch.Tensor: + """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" + v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) + v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) + volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] + return volume * 2.0 - 1.0 # → [-1,1] + +@torch.jit.script_if_tracing +def sqrtm(sigma: Tensor) -> Tensor: + r"""Returns the square root of a positive semi-definite matrix. + + .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + + where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. + + Args: + sigma: A positive semi-definite matrix, :math:`(*, D, D)`. + + Example: + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True + """ + + L, Q = torch.linalg.eigh(sigma) + L = L.relu().sqrt() + + return Q @ (L[..., None] * Q.mT) + +@torch.jit.script_if_tracing +def frechet_distance( + mu_x: Tensor, + sigma_x: Tensor, + mu_y: Tensor, + sigma_y: Tensor, +) -> Tensor: + r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + + .. math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) + + Wikipedia: + https://wikipedia.org/wiki/Frechet_distance + + Args: + mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. + sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. + mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. + sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. 
+ + Example: + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) + """ + + sigma_y_12 = sqrtm(sigma_y) + + a = (mu_x - mu_y).square().sum(dim=-1) + b = sigma_x.trace() + sigma_y.trace() + c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() + + return a + b - 2 * c + +@torch.no_grad() +def fid_from_features(f1, f2, eps=1e-6): + mu1, sigma1 = f1.mean(0), torch.cov(f1.T) + mu2, sigma2 = f2.mean(0), torch.cov(f2.T) + + eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) + sigma1 = sigma1 + eps * eye + sigma2 = sigma2 + eps * eye + + return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() + +@torch.no_grad() +def encode_fovs( + fovs, + vae, + channel_name: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb: (N, latent_dim) tensors + """ + emb = [] + + for pos in tqdm(fovs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v = torch.as_tensor( + pos.data[:, pos.get_channel_index(channel_name)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + + v = normalise(v) # still (T, D, H, W) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v.shape[0], batch_size): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + + # resize to input spatial size + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat = vae.encode(slice).mean # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) + + return torch.cat(emb, 0) + +# ----------------------------------------------------------------------------- # +# Main # +# ----------------------------------------------------------------------------- # + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--data_path1", type=Path, required=True) + p.add_argument("--data_path2", type=Path, required=True) + p.add_argument("--channel_name", type=str, default=None) + p.add_argument("--channel_name1", type=str, default=None) + p.add_argument("--channel_name2", type=str, default=None) + p.add_argument("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. 
'32,512,512'.") + p.add_argument("--loadcheck_path", type=Path, default=None, + help="Path to the VAE model checkpoint for loading.") + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") + return p + +def main(args) -> None: + device = args.device + + # ----------------- VAE ----------------- # + model_cfg = VAE3DConfig() + model_cfg.loadcheck_path = args.loadcheck_path + vae = VAE3DModel(config=model_cfg).to(device).eval() + + # ----------------- FOV list ------------ # + fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) + if args.max_fov: + fovs1 = fovs1[:args.max_fov] + fovs2 = fovs2[:args.max_fov] + + # ----------------- Embeddings ----------- # + input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] + + if args.channel_name is not None: + args.channel_name1 = args.channel_name2 = args.channel_name + + emb1 = encode_fovs( + fovs1, vae, + args.channel_name1, + device, args.batch_size, + input_spatial_size, + ) + + emb2 = encode_fovs( + fovs2, vae, + args.channel_name2, + device, args.batch_size, + input_spatial_size, + ) + + # ----------------- FID ------------------ # + fid_val = fid_from_features(emb1, emb2) + print(f"\nFID: {fid_val:.6f}") + +if __name__ == "__main__": + parser = build_argparser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/dynacell/fid_ts.py b/applications/dynacell/fid_ts.py new file mode 100644 index 000000000..8c8f272cf --- /dev/null +++ b/applications/dynacell/fid_ts.py @@ -0,0 +1,203 @@ +# -*- coding: utf-8 -*- +import argparse +from pathlib import Path + +import torch +from tqdm import tqdm +from iohub.ngff import open_ome_zarr +from torch import Tensor + +# ----------------------------------------------------------------------------- # +# Helper functions # +# ----------------------------------------------------------------------------- # + +def read_zarr(zarr_path: str): + plate = open_ome_zarr(zarr_path, mode="r") + return [pos for _, pos in plate.positions()] + +def normalise(volume: torch.Tensor) -> torch.Tensor: + """Per-sample min max → [-1,1]. Shape: (D, H, W) or (B, D, H, W).""" + v_min = volume.amin(dim=(-3, -2, -1), keepdim=True) + v_max = volume.amax(dim=(-3, -2, -1), keepdim=True) + volume = (volume - v_min) / (v_max - v_min + 1e-6) # → [0,1] + return volume * 2.0 - 1.0 # → [-1,1] + +@torch.jit.script_if_tracing +def sqrtm(sigma: Tensor) -> Tensor: + r"""Returns the square root of a positive semi-definite matrix. + + .. math:: \sqrt{\Sigma} = Q \sqrt{\Lambda} Q^T + + where :math:`Q \Lambda Q^T` is the eigendecomposition of :math:`\Sigma`. + + Args: + sigma: A positive semi-definite matrix, :math:`(*, D, D)`. + + Example: + >>> V = torch.randn(4, 4, dtype=torch.double) + >>> A = V @ V.T + >>> B = sqrtm(A @ A) + >>> torch.allclose(A, B) + True + """ + + L, Q = torch.linalg.eigh(sigma) + L = L.relu().sqrt() + + return Q @ (L[..., None] * Q.mT) + +@torch.jit.script_if_tracing +def frechet_distance( + mu_x: Tensor, + sigma_x: Tensor, + mu_y: Tensor, + sigma_y: Tensor, +) -> Tensor: + r"""Returns the Fréchet distance between two multivariate Gaussian distributions. + + .. 
math:: d^2 = \left\| \mu_x - \mu_y \right\|_2^2 + + \operatorname{tr} \left( \Sigma_x + \Sigma_y - 2 \sqrt{\Sigma_y^{\frac{1}{2}} \Sigma_x \Sigma_y^{\frac{1}{2}}} \right) + + Wikipedia: + https://wikipedia.org/wiki/Frechet_distance + + Args: + mu_x: The mean :math:`\mu_x` of the first distribution, :math:`(*, D)`. + sigma_x: The covariance :math:`\Sigma_x` of the first distribution, :math:`(*, D, D)`. + mu_y: The mean :math:`\mu_y` of the second distribution, :math:`(*, D)`. + sigma_y: The covariance :math:`\Sigma_y` of the second distribution, :math:`(*, D, D)`. + + Example: + >>> mu_x = torch.arange(3).float() + >>> sigma_x = torch.eye(3) + >>> mu_y = 2 * mu_x + 1 + >>> sigma_y = 2 * sigma_x + 1 + >>> frechet_distance(mu_x, sigma_x, mu_y, sigma_y) + tensor(15.8710) + """ + + sigma_y_12 = sqrtm(sigma_y) + + a = (mu_x - mu_y).square().sum(dim=-1) + b = sigma_x.trace() + sigma_y.trace() + c = sqrtm(sigma_y_12 @ sigma_x @ sigma_y_12).trace() + + return a + b - 2 * c + +@torch.no_grad() +def fid_from_features(f1, f2, eps=1e-6): + mu1, sigma1 = f1.mean(0), torch.cov(f1.T) + mu2, sigma2 = f2.mean(0), torch.cov(f2.T) + + eye = torch.eye(sigma1.size(0), device=sigma1.device, dtype=sigma1.dtype) + sigma1 = sigma1 + eps * eye + sigma2 = sigma2 + eps * eye + + return frechet_distance(mu1, sigma1, mu2, sigma2).clamp_min_(0).item() + +@torch.no_grad() +def encode_fovs( + fovs, + vae, + channel_name: str, + device: str = "cuda", + batch_size: int = 4, + input_spatial_size: tuple = (32, 512, 512), +): + """ + For each FOV pair: + • take all T time-frames (shape: T, D, H, W) + • normalise to [-1, 1] + • feed through VAE in chunks of ≤ batch_size frames + • average the resulting T latent vectors → one embedding / FOV + Returns + emb: (N, latent_dim) tensors + """ + emb = [] + + for pos in tqdm(fovs, desc="Encoding FOVs"): + # ---------------- load & normalise ---------------- # + v = torch.as_tensor( + pos.data[:, pos.get_channel_index(channel_name)], + dtype=torch.float32, device=device, + ) # (T, D, H, W) + + v = normalise(v) # still (T, D, H, W) + + # ---------------- chunked VAE inference ----------- # + for t0 in range(0, v.shape[0], batch_size): + slice = v[t0 : t0 + batch_size].unsqueeze(1) # (b, 1, D, H, W) + + # resize to input spatial size + slice = torch.nn.functional.interpolate( + slice, size=input_spatial_size, mode="trilinear", align_corners=False, + ) # (b, 1, D, H, W) + + feat = vae.encode(slice)[0] # mean, + feat = feat.flatten(start_dim=1) # (b, latent_dim) + emb.append(feat) + + return torch.cat(emb, 0) + +# ----------------------------------------------------------------------------- # +# Main # +# ----------------------------------------------------------------------------- # + +def build_argparser() -> argparse.ArgumentParser: + p = argparse.ArgumentParser(add_help=False) + p.add_argument("--data_path1", type=Path, required=True) + p.add_argument("--data_path2", type=Path, required=True) + p.add_argument("--channel_name", type=str, default=None) + p.add_argument("--channel_name1", type=str, default=None) + p.add_argument("--channel_name2", type=str, default=None) + p.add_argument("--input_spatial_size", type=str, default="32,512,512", + help="Input spatial size for the VAE, e.g. 
'32,512,512'.") + p.add_argument("--loadcheck_path", type=Path, default=None, + help="Path to the VAE model checkpoint for loading.") + p.add_argument("--batch_size", type=int, default=4) + p.add_argument("--device", type=str, default="cuda") + p.add_argument("--max_fov", type=int, default=None, + help="Limit number of FOV pairs (for quick tests).") + return p + +def main(args) -> None: + device = args.device + + # ----------------- VAE ----------------- # + vae = torch.jit.load(args.loadcheck_path).to(device) + vae.eval() + + # ----------------- FOV list ------------ # + fovs1, fovs2 = read_zarr(args.data_path1), read_zarr(args.data_path2) + if args.max_fov: + fovs1 = fovs1[:args.max_fov] + fovs2 = fovs2[:args.max_fov] + + # ----------------- Embeddings ----------- # + input_spatial_size = [int(dim) for dim in args.input_spatial_size.split(",")] + + if args.channel_name is not None: + args.channel_name1 = args.channel_name2 = args.channel_name + + emb1 = encode_fovs( + fovs1, vae, + args.channel_name1, + device, args.batch_size, + input_spatial_size, + ) + + emb2 = encode_fovs( + fovs2, vae, + args.channel_name2, + device, args.batch_size, + input_spatial_size, + ) + + # ----------------- FID ------------------ # + fid_val = fid_from_features(emb1, emb2) + print(f"\nFID: {fid_val:.6f}") + +if __name__ == "__main__": + parser = build_argparser() + args = parser.parse_args() + main(args) \ No newline at end of file diff --git a/applications/dynacell/test_fid.sh b/applications/dynacell/test_fid.sh new file mode 100644 index 000000000..84c330e5d --- /dev/null +++ b/applications/dynacell/test_fid.sh @@ -0,0 +1,17 @@ +python fid.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Nuclei-prediction \ + --channel_name2 Organelle \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae.pth \ + --batch_size 4 \ + --device cuda + +python fid.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Membrane-prediction \ + --channel_name2 Membrane \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae.pth \ + --batch_size 4 \ + --device cuda \ No newline at end of file diff --git a/applications/dynacell/test_fid_ts.sh b/applications/dynacell/test_fid_ts.sh new file mode 100644 index 000000000..012cb23b1 --- /dev/null +++ b/applications/dynacell/test_fid_ts.sh @@ -0,0 +1,17 @@ +python fid_ts.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Nuclei-prediction \ + --channel_name2 Organelle \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/nucleus_vae_ts.pt \ + --batch_size 4 \ + --device cuda + +python fid_ts.py \ + --data_path1 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --data_path2 /hpc/projects/virtual_staining/datasets/huang-lab/crops/mantis_figure_4.zarr \ + --channel_name1 Membrane-prediction \ + --channel_name2 Membrane \ + --loadcheck_path /hpc/projects/virtual_staining/models/huang-lab/fid/membrane_vae_ts.pt \ + --batch_size 4 \ + --device cuda \ No newline at end of file diff --git 
a/applications/dynacell/vae_3d/__init__.py b/applications/dynacell/vae_3d/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/dynacell/vae_3d/modules/__init__.py b/applications/dynacell/vae_3d/modules/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/applications/dynacell/vae_3d/modules/autoencoders.py b/applications/dynacell/vae_3d/modules/autoencoders.py new file mode 100644 index 000000000..9c3fb927a --- /dev/null +++ b/applications/dynacell/vae_3d/modules/autoencoders.py @@ -0,0 +1,160 @@ +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + +from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders.single_file_model import FromOriginalModelMixin +from diffusers.utils.accelerate_utils import apply_forward_hook +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.modeling_utils import ModelMixin + + +class Autoencoder3DKL(ModelMixin, ConfigMixin, FromOriginalModelMixin): + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + num_up_blocks: int = 2, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + num_down_blocks=num_down_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + num_up_blocks=num_up_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + + enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. 
+ """ + h = self._encode(x) + posterior = DiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[DecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.DecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is + returned. + + """ + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/applications/dynacell/vae_3d/modules/autoencoders_ts.py b/applications/dynacell/vae_3d/modules/autoencoders_ts.py new file mode 100644 index 000000000..9438ddd6d --- /dev/null +++ b/applications/dynacell/vae_3d/modules/autoencoders_ts.py @@ -0,0 +1,82 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .decoder import Decoder +from .encoder import Encoder + + +class Autoencoder3DKL(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + num_up_blocks: int = 2, + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + use_quant_conv: bool = True, + use_post_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + num_down_blocks=num_down_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + num_up_blocks=num_up_blocks, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None + self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, 1) if use_post_quant_conv else None + + def _encode(self, x: torch.Tensor) -> torch.Tensor: + 
+ enc = self.encoder(x) + if self.quant_conv is not None: + enc = self.quant_conv(enc) + + return enc + + def encode(self, x: torch.Tensor): + h = self._encode(x) + mean, logvar = torch.chunk(h, 2, dim=1) + + return mean, logvar + + def _decode(self, z: torch.Tensor): + + if self.post_quant_conv is not None: + z = self.post_quant_conv(z) + dec = self.decoder(z) + + return dec + + def decode(self, z: torch.FloatTensor): + decoded = self._decode(z) + + return decoded + + def forward(self, x): + # placeholder forward + return x \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/blocks.py b/applications/dynacell/vae_3d/modules/blocks.py new file mode 100644 index 000000000..569d66e9a --- /dev/null +++ b/applications/dynacell/vae_3d/modules/blocks.py @@ -0,0 +1,353 @@ +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from diffusers.models.normalization import RMSNorm +from diffusers.models.activations import get_activation + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float 
= 0.0, + groups: int = 32, + groups_out: Optional[int] = None, + eps: float = 1e-6, + non_linearity: str = "swish", + output_scale_factor: float = 1.0, + use_in_shortcut: Optional[bool] = None, + conv_shortcut_bias: bool = True, + ): + super().__init__() + + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + + self.dropout = torch.nn.Dropout(dropout) + conv_3d_out_channels = out_channels + self.conv2 = nn.Conv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1, padding=1) + + self.nonlinearity = get_activation(non_linearity) + self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d( + in_channels, + conv_3d_out_channels, + kernel_size=1, + stride=1, + padding=0, + bias=conv_shortcut_bias, + ) + + def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + hidden_states = self.norm2(hidden_states) + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + +class Downsample3D(nn.Module): + """A 3D downsampling layer with an optional convolution. + """ + + def __init__( + self, + channels: int, + use_conv: bool = False, + out_channels: Optional[int] = None, + padding: int = 1, + kernel_size: int = 3, + norm_type: Optional[str] = None, + eps: Optional[float] = 1e-5, + elementwise_affine: Optional[bool] = True, + bias: bool = True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + self.kernel_size = kernel_size + stride = 2 # Downsampling stride is fixed to 2 + + # Initialize normalization + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # Choose between convolutional or pooling downsampling + if use_conv: + self.conv = nn.Conv3d( + self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias + ) + else: + assert self.channels == self.out_channels, "out_channels must match channels when using pooling" + self.conv = nn.AvgPool3d(kernel_size=stride, stride=stride) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the downsampling block. 
+ + Args: + hidden_states (torch.Tensor): Input feature map of shape (B, C, D, H, W). + + Returns: + torch.Tensor: Downsampled feature map. + """ + assert hidden_states.shape[1] == self.channels, \ + f"Expected input channels {self.channels}, but got {hidden_states.shape[1]}" + + # Apply normalization if specified + if self.norm is not None: + # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) + + # Apply padding if using conv downsampling and no padding was specified + if self.use_conv and self.padding == 0: + pad = (0, 1, 0, 1, 0, 1) # Padding for 3D tensor: (D, H, W) + hidden_states = F.pad(hidden_states, pad, mode="constant", value=0.0) + + # Apply downsampling + hidden_states = self.conv(hidden_states) + + return hidden_states + +class Upsample3D(nn.Module): + """A 3D upsampling layer with a convolution. + """ + + def __init__( + self, + channels: int, + out_channels: Optional[int] = None, + kernel_size: Optional[int] = None, + padding=1, + norm_type=None, + eps=None, + elementwise_affine=None, + bias=True, + ): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + + if norm_type == "ln_norm": + self.norm = nn.LayerNorm(self.channels, eps=eps, elementwise_affine=elementwise_affine) + elif norm_type == "rms_norm": + self.norm = RMSNorm(channels, eps, elementwise_affine) + elif norm_type is None: + self.norm = None + else: + raise ValueError(f"unknown norm_type: {norm_type}") + + conv = None + if kernel_size is None: + kernel_size = 3 + conv = nn.Conv3d(self.channels, self.out_channels, kernel_size=kernel_size, padding=padding, bias=bias) + + # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed + self.conv = conv + + def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None) -> torch.Tensor: + assert hidden_states.shape[1] == self.channels, f"Expected {self.channels} channels, got {hidden_states.shape[1]}" + + # Apply normalization if specified + if self.norm is not None: + # LayerNorm expects (B, C, D, H, W), but normalizes over C. Permute to (B, D, H, W, C) + hidden_states = self.norm(hidden_states.permute(0, 2, 3, 4, 1)) + hidden_states = hidden_states.permute(0, 4, 1, 2, 3) # Back to (B, C, D, H, W) + + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + hidden_states = self.conv(hidden_states) + + return hidden_states + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. 
+ """ + + def __init__( + self, + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + if attn_groups is None: + attn_groups = resnet_groups + + self.resnets = nn.ModuleList([ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + for _ in range(num_layers + 1) + ]) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states) + + return hidden_states \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/decoder.py b/applications/dynacell/vae_3d/modules/decoder.py new file mode 100644 index 000000000..19ff8725b --- /dev/null +++ b/applications/dynacell/vae_3d/modules/decoder.py @@ -0,0 +1,98 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .blocks import UNetMidBlock3D, UpDecoderBlock3D +from diffusers.models.attention_processor import SpatialNorm + +class Decoder(nn.Module): + + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_up_blocks: int = 2, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + norm_type: str = "group", # group, spatial + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv3d( + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.up_blocks = nn.ModuleList([]) + + temb_channels = in_channels if norm_type == "spatial" else None + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_groups=norm_num_groups, + ) + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(num_up_blocks): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + in_channels=prev_output_channel, + out_channels=output_channel, + num_layers=self.layers_per_block + 1, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + add_upsample=not is_final_block, + ) + + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_type == "spatial": + self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels) + else: + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + self.conv_out = nn.Conv3d( + block_out_channels[0], + out_channels, + kernel_size=3, + padding=1, + padding_mode='reflect', + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.conv_in(sample) + + # middle + sample = self.mid_block(sample) + + # up + for up_block in self.up_blocks: + sample = up_block(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample \ No newline at end of file diff --git 
a/applications/dynacell/vae_3d/modules/encoder.py b/applications/dynacell/vae_3d/modules/encoder.py new file mode 100644 index 000000000..bc93f857b --- /dev/null +++ b/applications/dynacell/vae_3d/modules/encoder.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + +from typing import Tuple +from .blocks import DownEncoderBlock3D, UNetMidBlock3D + +class Encoder(nn.Module): + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + num_down_blocks: int = 2, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + act_fn: str = "silu", + double_z: bool = True, + ): + super().__init__() + self.layers_per_block = layers_per_block + + self.conv_in = nn.Conv3d( + in_channels, + block_out_channels[0], + kernel_size=3, + stride=1, + padding=1, + padding_mode='reflect' + ) + + self.down_blocks = nn.ModuleList([]) + + # down + output_channel = block_out_channels[0] + for i in range(num_down_blocks): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + in_channels=input_channel, + out_channels=output_channel, + dropout=0.0, + num_layers=self.layers_per_block, + add_downsample=not is_final_block, + resnet_eps=1e-6, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + downsample_padding=0, + ) + + self.down_blocks.append(down_block) + + # mid + self.mid_block = UNetMidBlock3D( + in_channels=block_out_channels[-1], + resnet_eps=1e-6, + resnet_act_fn=act_fn, + output_scale_factor=1, + resnet_groups=norm_num_groups, + ) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6) + self.conv_act = nn.SiLU() + + conv_out_channels = 2 * out_channels if double_z else out_channels + self.conv_out = nn.Conv3d(block_out_channels[-1], conv_out_channels, 3, padding=1) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.Tensor) -> torch.Tensor: + sample = self.conv_in(sample) + + # down + for down_block in self.down_blocks: + sample = down_block(sample) + + # middle + sample = self.mid_block(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + return sample \ No newline at end of file diff --git a/applications/dynacell/vae_3d/modules/utils.py b/applications/dynacell/vae_3d/modules/utils.py new file mode 100644 index 000000000..0c2a8185e --- /dev/null +++ b/applications/dynacell/vae_3d/modules/utils.py @@ -0,0 +1,9 @@ +import torch +from dataclasses import dataclass +from transformers.utils import ModelOutput + +@dataclass +class VAEOutput(ModelOutput): + loss: torch.FloatTensor = None + recon_loss: torch.FloatTensor = None + kl_loss: torch.FloatTensor = None \ No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_config.py b/applications/dynacell/vae_3d/vae_3d_config.py new file mode 100644 index 000000000..d30883ed2 --- /dev/null +++ b/applications/dynacell/vae_3d/vae_3d_config.py @@ -0,0 +1,16 @@ +# -*- coding: utf-8 -*- +from dataclasses import dataclass +from transformers import PretrainedConfig +from dataclasses import field + +@dataclass +class VAE3DConfig(PretrainedConfig): + model_type: str = 'vae' + + # Model parameters + in_channels: int = 1 + out_channels: int = 1 + num_down_blocks: int = 5 + latent_channels: int = 2 + vae_block_out_channels: list = field(default_factory=lambda: [32, 64, 128, 256, 256]) + loadcheck_path: str = "" \ 
No newline at end of file diff --git a/applications/dynacell/vae_3d/vae_3d_model.py b/applications/dynacell/vae_3d/vae_3d_model.py new file mode 100644 index 000000000..d01138b1e --- /dev/null +++ b/applications/dynacell/vae_3d/vae_3d_model.py @@ -0,0 +1,138 @@ +import os +import torch +import torch.nn as nn +from .modules.autoencoders import Autoencoder3DKL +from .vae_3d_config import VAE3DConfig +from transformers import PreTrainedModel +from .modules.utils import VAEOutput + + +class VAE3DModel(PreTrainedModel): + config_class = VAE3DConfig + + def __init__(self, config: VAE3DConfig): + super().__init__(config) + self.config = config + + self.num_down_blocks = config.num_down_blocks + self.num_up_blocks = self.num_down_blocks + + # Initialize Autoencoder3DKL + self.vae = Autoencoder3DKL( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_down_blocks=self.num_down_blocks, + num_up_blocks=self.num_up_blocks, + block_out_channels=config.vae_block_out_channels, + latent_channels=config.latent_channels, + ) + + self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) + + def load_pretrained_weights(self, checkpoint_path): + """ + Load pretrained weights from a given state_dict. + """ + + if os.path.splitext(checkpoint_path)[1] == '.safetensors': + from safetensors.torch import load_file + checkpoints_state = load_file(checkpoint_path) + else: + checkpoints_state = torch.load(checkpoint_path, map_location="cpu") + + if "model" in checkpoints_state: + checkpoints_state = checkpoints_state["model"] + elif "module" in checkpoints_state: + checkpoints_state = checkpoints_state["module"] + + IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) + IncompatibleKeys = IncompatibleKeys._asdict() + + missing_keys = [] + for keys in IncompatibleKeys["missing_keys"]: + if keys.find("dummy") == -1: + missing_keys.append(keys) + + unexpected_keys = [] + for keys in IncompatibleKeys["unexpected_keys"]: + if keys.find("dummy") == -1: + unexpected_keys.append(keys) + + if len(missing_keys) > 0: + print( + "Missing keys in {}: {}".format( + checkpoint_path, + missing_keys, + ) + ) + + if len(unexpected_keys) > 0: + print( + "Unexpected keys {}: {}".format( + checkpoint_path, + unexpected_keys, + ) + ) + + def encode(self, x): + """Encodes input into latent space.""" + return self.vae.encode(x).latent_dist + + def decode(self, latents): + """Decodes latent space into reconstructed input.""" + return self.vae.decode(latents) + + def forward(self, batched_data): + x = batched_data['data'] + + """Forward pass through the VAE.""" + latent_dist = self.encode(x) + latents = latent_dist.sample() + recon_x = self.decode(latents).sample + + total_loss, recon_loss, kl_loss = self.compute_loss(x, recon_x, latent_dist) + + return VAEOutput(total_loss, recon_loss, kl_loss) + + def compute_loss(self, x, recon_x, latent_dist): + """Compute reconstruction and KL divergence loss.""" + if self.config.vae_recon_loss_type == 'mse': + recon_loss = nn.MSELoss()(recon_x, x) + elif self.config.vae_recon_loss_type == 'poisson': + x = x.clip(-1, 1) + recon_x = recon_x.clip(-1, 1) + peak = self.config.poisson_peak if hasattr(self.config, 'poisson_peak') else 1.0 + target = (x + 1) / 2.0 * peak + lam = (recon_x + 1) / 2.0 * peak + recon_loss = torch.mean(lam - target * torch.log(lam + 1e-8)) + + kl_loss = -0.5 * torch.mean(1 + latent_dist.logvar - latent_dist.mean.pow(2) - latent_dist.logvar.exp()) + total_loss = self.config.recon_loss_coeff * recon_loss + self.config.kl_loss_coeff * 
kl_loss + return total_loss, recon_loss, kl_loss + + def sample(self, num_samples=1, latent_size=32, device="cpu"): + """ + Generate samples from the latent space. + + Args: + num_samples (int): Number of samples to generate. + device (str): Device to perform sampling on. + + Returns: + torch.Tensor: Generated images. + """ + # Sample from a standard normal distribution in latent space + latents = torch.randn((num_samples, self.config.latent_channels, latent_size, latent_size, latent_size), device=device) # Shape matches latent dimensions + + # Decode latents to generate images + with torch.no_grad(): + generated_images = self.decode(latents).sample + + return generated_images + + def reconstruct(self, x): + latent_dist = self.encode(x) + latents = latent_dist.sample() # Reparameterization trick + recon_x = self.decode(latents).sample + + return recon_x diff --git a/applications/dynacell/vae_3d/vae_3d_model_ts.py b/applications/dynacell/vae_3d/vae_3d_model_ts.py new file mode 100644 index 000000000..92a53ac24 --- /dev/null +++ b/applications/dynacell/vae_3d/vae_3d_model_ts.py @@ -0,0 +1,90 @@ +import os +import torch +import torch.nn as nn +from .modules.autoencoders_ts import Autoencoder3DKL +from .vae_3d_config import VAE3DConfig + + +class VAE3DModel(nn.Module): + def __init__(self, config: VAE3DConfig): + super().__init__() + self.config = config + + self.num_down_blocks = config.num_down_blocks + self.num_up_blocks = self.num_down_blocks + + # Initialize Autoencoder3DKL + self.vae = Autoencoder3DKL( + in_channels=config.in_channels, + out_channels=config.out_channels, + num_down_blocks=self.num_down_blocks, + num_up_blocks=self.num_up_blocks, + block_out_channels=config.vae_block_out_channels, + latent_channels=config.latent_channels, + ) + + self.load_pretrained_weights(checkpoint_path=config.loadcheck_path) + + def load_pretrained_weights(self, checkpoint_path): + """ + Load pretrained weights from a given state_dict. + """ + + if os.path.splitext(checkpoint_path)[1] == '.safetensors': + from safetensors.torch import load_file + checkpoints_state = load_file(checkpoint_path) + else: + checkpoints_state = torch.load(checkpoint_path, map_location="cpu") + + if "model" in checkpoints_state: + checkpoints_state = checkpoints_state["model"] + elif "module" in checkpoints_state: + checkpoints_state = checkpoints_state["module"] + + IncompatibleKeys = self.load_state_dict(checkpoints_state, strict=True) + IncompatibleKeys = IncompatibleKeys._asdict() + + missing_keys = [] + for keys in IncompatibleKeys["missing_keys"]: + if keys.find("dummy") == -1: + missing_keys.append(keys) + + unexpected_keys = [] + for keys in IncompatibleKeys["unexpected_keys"]: + if keys.find("dummy") == -1: + unexpected_keys.append(keys) + + if len(missing_keys) > 0: + print( + "Missing keys in {}: {}".format( + checkpoint_path, + missing_keys, + ) + ) + + if len(unexpected_keys) > 0: + print( + "Unexpected keys {}: {}".format( + checkpoint_path, + unexpected_keys, + ) + ) + + def encode(self, x): + """Encodes input into latent space.""" + return self.vae.encode(x) + + def decode(self, latents): + """Decodes latent space into reconstructed input.""" + return self.vae.decode(latents) + + def forward(self, x): + # placeholder forward + return x + + def reconstruct(self, x): + mean, logvar = self.encode(x) + latents = mean + torch.exp(0.5 * logvar) * torch.randn_like(logvar) # Reparameterization trick + recon_x = self.decode(latents) + + return recon_x
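
A minimal sanity check for the FID helpers added above — a sketch, not part of the patch. It assumes `iohub` and `tqdm` are installed so that `fid_ts.py` imports cleanly when run from `applications/dynacell/`; the feature tensors are random stand-ins for the per-FOV VAE embeddings that `encode_fovs` would produce, not real data.

    import torch
    from fid_ts import fid_from_features  # the same helper is defined in fid.py

    torch.manual_seed(0)
    feats_a = torch.randn(256, 64)                  # 256 mock FOV embeddings, 64-dim
    feats_b = feats_a + 0.1 * torch.randn(256, 64)  # slightly perturbed copy

    print(fid_from_features(feats_a, feats_a))  # ~0 for identical feature sets
    print(fid_from_features(feats_a, feats_b))  # small positive distance

Identical feature sets should give a distance of (numerically) zero, and the perturbed copy a small positive value; full runs against the zarr stores and VAE checkpoints go through `test_fid.sh` / `test_fid_ts.sh` as listed in this diff.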