From 6ffb8e36c02d1e9dfafb80b2f99c9606762412f4 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 10 Mar 2025 10:57:17 -0400 Subject: [PATCH 01/20] Create Dockerfile for MSK SMIT Lung GTV --- models/msk_smit_lung_gtv/dockerfiles/Dockerfile | 1 + 1 file changed, 1 insertion(+) create mode 100644 models/msk_smit_lung_gtv/dockerfiles/Dockerfile diff --git a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile @@ -0,0 +1 @@ + From 90e36a1aab164644ec960f3c219b16a374d57f6a Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 10 Mar 2025 11:55:39 -0400 Subject: [PATCH 02/20] 1st draft Dockerfile --- models/msk_smit_lung_gtv/dockerfiles/Dockerfile | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile index 8b137891..ff18a288 100644 --- a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile +++ b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile @@ -1 +1,16 @@ +FROM mhubai/base:latest +# Update authors label +LABEL authors="aptea@mskcc.org,deasyj@mskcc.org,iyera@mskcc.org,locastre@mskcc.org" + +RUN apt update +RUN mkdir -p /app/models/msk_smit_lung_gtv +RUN cd /app/models/msk_smit_lung_gtv && git clone https://github.com/cerr/model_installer.git && cd model_installer && source installer.sh -m 4 -d /app/models/msk_smit_lung_gtv -p C +RUN chmod -R 755 /app/models/msk_smit_lung_gtv/CT_Lung_SMIT + +ENV PYTHONPATH="/app:/app/models/msk_smit_lung_gtv/CT_Lung_SMIT/conda-pack" + +RUN source /app/models/msk_smit_lung_gtv/CT_Lung_SMIT/conda-pack/bin/activate + +ENTRYPOINT ["mhub.run"] +CMD ["--config", "/app/models/msk_smit_lung_gtv/config/default.yml"] From 4de5a52afa4503db6d153a597438ab4ba5854913 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:27:30 -0400 Subject: [PATCH 03/20] Create README.md --- models/msk_smit_lung_gtv/src/README.md | 4 ++++ 1 file changed, 4 insertions(+) create mode 100644 models/msk_smit_lung_gtv/src/README.md diff --git a/models/msk_smit_lung_gtv/src/README.md b/models/msk_smit_lung_gtv/src/README.md new file mode 100644 index 00000000..2f8ac1fd --- /dev/null +++ b/models/msk_smit_lung_gtv/src/README.md @@ -0,0 +1,4 @@ +## References +[1] Jiang, Jue, and Harini Veeraraghavan. "Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation." In Medical Imaging with Deep Learning. 2024. + +[2] Jiang, Jue, Neelam Tyagi, Kathryn Tringale, Christopher Crane, and Harini Veeraraghavan. "Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT)." In International Conference on Medical Image Computing and Computer-Assisted Intervention, pp. 556-566. Cham: Springer Nature Switzerland, 2022. From 54f36b762fd654058fc23b87327d09fd83ca1cef Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:37:27 -0400 Subject: [PATCH 04/20] draft SMITrunner.py --- models/msk_smit_lung_gtv/utils/SMITrunner.py | 52 ++++++++++++++++++++ 1 file changed, 52 insertions(+) create mode 100644 models/msk_smit_lung_gtv/utils/SMITrunner.py diff --git a/models/msk_smit_lung_gtv/utils/SMITrunner.py b/models/msk_smit_lung_gtv/utils/SMITrunner.py new file mode 100644 index 00000000..8acb6eaf --- /dev/null +++ b/models/msk_smit_lung_gtv/utils/SMITrunner.py @@ -0,0 +1,52 @@ +""" +------------------------------------------------- +MHub - Run Module for SMIT +------------------------------------------------- + +------------------------------------------------- +Author: Jue Jiang +Email: jiangj1@mskcc.org +------------------------------------------------- +""" + +import os, subprocess, shutil +from mhubio.core import Instance, InstanceData, IO +from mhubio.modules.runner.ModelRunner import ModelRunner + +# Optional config parameter/s examples noted below +# @IO.Config('a_min', int, -500, the='Min frequency of image') +# @IO.Config('a_max', int, 500, the='Max frequency of image') + +class SMITRunner(ModelRunner): + + # a_min: int + # a_max: int + + @IO.Instance() + @IO.Input('scan', 'nifti:mod=ct', the='input ct scan') + @IO.Output('gtv_mask', 'gtv_mask.nii.gz', 'nifti:mod=seg:model=SMIT:roi=GTV',data='scan', the='predicted lung gtv') + def task(self, instance: Instance, scan: InstanceData, gtv_mask: InstanceData) -> None: + + workDir = os.environ['WORK_DIR'] # Needs to be defined in docker file as ENV WORK_DIR=path_to_dir e.g. /app/models/SMIT/workDir + #wrapperInstallDir = os.path.join(workDir,'CT_Lung_SMIT') + #condaEnvDir = os.path.join(wrapperInstallDir,'conda-pack') + #condaEnvActivateScript = os.path.join(condaEnvDir, 'bin', 'activate') + wrapperPath = os.path.join(workDir,'bash_run_SMIT_Segmentation.sh') + load_weight_name = os.path.join(workDir,'trained_weights','model.pt') + + sessionPath = os.path.join(workDir, 'session') + os.makedirs(sessionPath, exist_ok = True) + + subj = os.path.basename(scan.abspath) # Was originally dcmdir so might want to change + sessiondir = os.path.join(sessionPath,subj) + os.makedirs(sessiondir,exist_ok=True) + + # bash command for SMIT + bash_command = f"source " + condaEnvActivateScript + " && source " + wrapperPath + " " + sessiondir + " " + sessiondir + " " + load_weight_name + " " + scan.abspath + + # Display command on terminal + self.log("Running SMIT") + self.log(">> ".join(bash_command)) + + # run SMIT + self.subprocess(bash_command, text=True) From 84ffd42318ce45f592b2291d36fd4c2f9326e9c5 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:38:29 -0400 Subject: [PATCH 05/20] add src files --- .../src/bash_run_SMIT_Segmentation.sh | 49 ++++ .../src/edit_inference_utils.py | 186 +++++++++++++ .../msk_smit_lung_gtv/src/run_segmentation.py | 252 ++++++++++++++++++ 3 files changed, 487 insertions(+) create mode 100644 models/msk_smit_lung_gtv/src/bash_run_SMIT_Segmentation.sh create mode 100644 models/msk_smit_lung_gtv/src/edit_inference_utils.py create mode 100644 models/msk_smit_lung_gtv/src/run_segmentation.py diff --git a/models/msk_smit_lung_gtv/src/bash_run_SMIT_Segmentation.sh b/models/msk_smit_lung_gtv/src/bash_run_SMIT_Segmentation.sh new file mode 100644 index 00000000..b1c79fad --- /dev/null +++ b/models/msk_smit_lung_gtv/src/bash_run_SMIT_Segmentation.sh @@ -0,0 +1,49 @@ +#!/bin/bash +# +# +# Input arguments: +# $1 data_dir +# $2 save_folder +# $3 load_weight_name +# $4 input_nifti + +source ./conda-pack/bin/activate + +#Use SMIT +use_smit=1 #Use SMIT not SMIT+ + +#Data folder and there need a 'data.json' file in the folder +data_dir="$1" + +#Segmentation output folder +save_folder="$2" + +#Some configrations for the model, no need to change +#Trained weight +load_weight_name="$3" + +input_nifti="$4" + +a_min=-500 +a_max=500 +space_x=1.5 +space_y=1.5 +space_z=2.0 +out_channels=2 + +python utils/gen_data_json.py $input_nifti + +python run_segmentation.py \ + --roi_x 128 \ + --roi_y 128 \ + --roi_z 128 \ + --space_x $space_x \ + --space_y $space_y \ + --space_z $space_z \ + --data_dir $data_dir \ + --out_channels $out_channels \ + --load_weight_name $load_weight_name \ + --save_folder $save_folder \ + --a_min=$a_min \ + --a_max=$a_max \ + --use_smit $use_smit diff --git a/models/msk_smit_lung_gtv/src/edit_inference_utils.py b/models/msk_smit_lung_gtv/src/edit_inference_utils.py new file mode 100644 index 00000000..1455228b --- /dev/null +++ b/models/msk_smit_lung_gtv/src/edit_inference_utils.py @@ -0,0 +1,186 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Callable, List, Sequence, Tuple, Union + +import torch +import torch.nn.functional as F + +from monai.data.utils import compute_importance_map, dense_patch_slices, get_valid_patch_size +from monai.utils import BlendMode, PytorchPadMode, fall_back_tuple, look_up_option, optional_import + +import time + +tqdm, _ = optional_import("tqdm", name="tqdm") + +__all__ = ["sliding_window_inference"] + + +def sliding_window_inference( + inputs: torch.Tensor, + roi_size: Union[Sequence[int], int], + sw_batch_size: int, + predictor: Callable[..., torch.Tensor], + overlap: float = 0.25, + mode: Union[BlendMode, str] = BlendMode.CONSTANT, + sigma_scale: Union[Sequence[float], float] = 0.125, + padding_mode: Union[PytorchPadMode, str] = PytorchPadMode.CONSTANT, + cval: float = 0.0, + sw_device: Union[torch.device, str, None] = None, + device: Union[torch.device, str, None] = None, + *args: Any, + **kwargs: Any, +) -> torch.Tensor: + """ + Sliding window inference on `inputs` with `predictor`. + + When roi_size is larger than the inputs' spatial size, the input image are padded during inference. + To maintain the same spatial sizes, the output image will be cropped to the original input size. + + Args: + inputs: input image to be processed (assuming NCHW[D]) + roi_size: the spatial window size for inferences. + When its components have None or non-positives, the corresponding inputs dimension will be used. + if the components of the `roi_size` are non-positive values, the transform will use the + corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted + to `(32, 64)` if the second spatial dimension size of img is `64`. + sw_batch_size: the batch size to run window slices. + predictor: given input tensor `patch_data` in shape NCHW[D], `predictor(patch_data)` + should return a prediction with the same spatial shape and batch_size, i.e. NMHW[D]; + where HW[D] represents the patch spatial size, M is the number of output channels, N is `sw_batch_size`. + overlap: Amount of overlap between scans. + mode: {``"constant"``, ``"gaussian"``} + How to blend output of overlapping windows. Defaults to ``"constant"``. + + - ``"constant``": gives equal weight to all predictions. + - ``"gaussian``": gives less weight to predictions on edges of windows. + + sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``. + Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``. + When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding + spatial dimensions. + padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``} + Padding mode for ``inputs``, when ``roi_size`` is larger than inputs. Defaults to ``"constant"`` + See also: https://pytorch.org/docs/stable/nn.functional.html#pad + cval: fill value for 'constant' padding mode. Default: 0 + sw_device: device for the window data. + By default the device (and accordingly the memory) of the `inputs` is used. + Normally `sw_device` should be consistent with the device where `predictor` is defined. + device: device for the stitched output prediction. + By default the device (and accordingly the memory) of the `inputs` is used. If for example + set to device=torch.device('cpu') the gpu memory consumption is less and independent of the + `inputs` and `roi_size`. Output is on the `device`. + args: optional args to be passed to ``predictor``. + kwargs: optional keyword args to be passed to ``predictor``. + + Note: + - input must be channel-first and have a batch dim, supports N-D sliding window. + + """ + num_spatial_dims = len(inputs.shape) - 2 + if overlap < 0 or overlap >= 1: + raise AssertionError("overlap must be >= 0 and < 1.") + + # determine image spatial size and batch size + # Note: all input images must have the same image size and batch size + image_size_ = list(inputs.shape[2:]) + batch_size = inputs.shape[0] + + if device is None: + device = inputs.device + if sw_device is None: + sw_device = inputs.device + + roi_size = fall_back_tuple(roi_size, image_size_) + # in case that image size is smaller than roi size + image_size = tuple(max(image_size_[i], roi_size[i]) for i in range(num_spatial_dims)) + pad_size = [] + for k in range(len(inputs.shape) - 1, 1, -1): + diff = max(roi_size[k - 2] - inputs.shape[k], 0) + half = diff // 2 + pad_size.extend([half, diff - half]) + inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode).value, value=cval) + + scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap) + + # Store all slices in list + slices = dense_patch_slices(image_size, roi_size, scan_interval) + num_win = len(slices) # number of windows per image + total_slices = num_win * batch_size # total number of windows + + # Create window-level importance map + importance_map = compute_importance_map( + get_valid_patch_size(image_size, roi_size), mode=mode, sigma_scale=sigma_scale, device=device + ) + importance_map=importance_map.cpu() + # Perform predictions + output_image, count_map = torch.tensor(0.0, device=device), torch.tensor(0.0, device=device) + _initialized = False + for slice_g in range(0, total_slices, sw_batch_size): + slice_range = range(slice_g, min(slice_g + sw_batch_size, total_slices)) + unravel_slice = [ + [slice(int(idx / num_win), int(idx / num_win) + 1), slice(None)] + list(slices[idx % num_win]) + for idx in slice_range + ] + window_data = torch.cat([inputs[win_slice] for win_slice in unravel_slice]).to(sw_device) + seg_prob = predictor(window_data, *args, **kwargs).to(device) # batched patch segmentation + + if not _initialized: # init. buffer at the first iteration + output_classes = seg_prob.shape[1] + output_shape = [batch_size, output_classes] + list(image_size) + # allocate memory to store the full output and the count for overlapping parts + #output_image = torch.zeros(output_shape, dtype=torch.float32, device=device) + #count_map = torch.zeros(output_shape, dtype=torch.float32, device=device) + + output_image = torch.zeros(output_shape, dtype=torch.float32, device='cpu') + count_map = torch.zeros(output_shape, dtype=torch.float32, device='cpu') + + _initialized = True + + # store the result in the proper location of the full output. Apply weights from importance map. + for idx, original_idx in zip(slice_range, unravel_slice): + output_image[original_idx] += importance_map * seg_prob[idx - slice_g].cpu() + count_map[original_idx] += importance_map + + # account for any overlapping sections + output_image = output_image / count_map + + final_slicing: List[slice] = [] + for sp in range(num_spatial_dims): + slice_dim = slice(pad_size[sp * 2], image_size_[num_spatial_dims - sp - 1] + pad_size[sp * 2]) + final_slicing.insert(0, slice_dim) + while len(final_slicing) < len(output_image.shape): + final_slicing.insert(0, slice(None)) + return output_image[final_slicing] + + +def _get_scan_interval( + image_size: Sequence[int], roi_size: Sequence[int], num_spatial_dims: int, overlap: float +) -> Tuple[int, ...]: + """ + Compute scan interval according to the image size, roi size and overlap. + Scan interval will be `int((1 - overlap) * roi_size)`, if interval is 0, + use 1 instead to make sure sliding window works. + + """ + if len(image_size) != num_spatial_dims: + raise ValueError("image coord different from spatial dims.") + if len(roi_size) != num_spatial_dims: + raise ValueError("roi coord different from spatial dims.") + + scan_interval = [] + for i in range(num_spatial_dims): + if roi_size[i] == image_size[i]: + scan_interval.append(int(roi_size[i])) + else: + interval = int(roi_size[i] * (1 - overlap)) + scan_interval.append(interval if interval > 0 else 1) + return tuple(scan_interval) \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/run_segmentation.py b/models/msk_smit_lung_gtv/src/run_segmentation.py new file mode 100644 index 00000000..84bd50f0 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/run_segmentation.py @@ -0,0 +1,252 @@ +import SimpleITK as sitk +import os +import torch +import numpy as np +from edit_inference_utils import sliding_window_inference +from torch.cuda.amp import GradScaler, autocast +import argparse +from monai import transforms, data +from monai.handlers.utils import from_engine +from monai.data import decollate_batch, load_decathlon_datalist +from monai.transforms import ( + AsDiscrete, + AsDiscreted, + EnsureChannelFirstd, + Compose, + CropForegroundd, + SpatialPadd, + LoadImaged, + Orientationd, + RandCropByPosNegLabeld, + ScaleIntensityRanged, + Spacingd, + EnsureTyped, + EnsureType, + Invertd, +) + +from smit_models.smit import CONFIGS as CONFIGS_SMIT +import smit_models.smit as smit +from smit_models import smit_plus + + +from skimage.measure import label +import scipy.ndimage.morphology as snm +import skimage + + +parser = argparse.ArgumentParser(description='UNETR segmentation pipeline') +parser.add_argument('--pretrained_dir', default='./pretrained_models/', type=str, + help='pretrained checkpoint directory') + +parser.add_argument('--data_dir', default='/scratch/input', type=str, + help='dataset directory') +parser.add_argument('--json_list', + default='/scratch/input/data.json', + type=str, help='dataset json file') + +parser.add_argument('--pretrained_model_name', default='model.pt', type=str, + help='pretrained model name') +parser.add_argument('--saved_checkpoint', default='ckpt', type=str, + help='Supports torchscript or ckpt pretrained checkpoint type') +parser.add_argument('--mlp_dim', default=3072, type=int, help='mlp dimention in ViT encoder') +parser.add_argument('--hidden_size', default=768, type=int, help='hidden size dimention in ViT encoder') +parser.add_argument('--feature_size', default=16, type=int, help='feature size dimention') +parser.add_argument('--infer_overlap', default=0.5, type=float, help='sliding window inference overlap') +parser.add_argument('--in_channels', default=1, type=int, help='number of input channels') +parser.add_argument('--out_channels', default=1 + 6, type=int, help='number of output channels') +parser.add_argument('--num_heads', default=12, type=int, help='number of attention heads in ViT encoder') +parser.add_argument('--res_block', action='store_true', help='use residual blocks') +parser.add_argument('--conv_block', action='store_true', help='use conv blocks') +parser.add_argument('--a_min', default=-140, type=float, help='a_min in ScaleIntensityRanged') +parser.add_argument('--a_max', default=260, type=float, help='a_max in ScaleIntensityRanged') +parser.add_argument('--b_min', default=0.0, type=float, help='b_min in ScaleIntensityRanged') +parser.add_argument('--b_max', default=1.0, type=float, help='b_max in ScaleIntensityRanged') +parser.add_argument('--space_x', default=1.0, type=float, help='spacing in x direction') +parser.add_argument('--space_y', default=1.0, type=float, help='spacing in y direction') +parser.add_argument('--space_z', default=1.0, type=float, help='spacing in z direction') +parser.add_argument('--roi_x', default=128, type=int, help='roi size in x direction') +parser.add_argument('--roi_y', default=128, type=int, help='roi size in y direction') +parser.add_argument('--roi_z', default=128, type=int, help='roi size in z direction') +parser.add_argument('--dropout_rate', default=0.0, type=float, help='dropout rate') +parser.add_argument('--distributed', action='store_true', help='start distributed training') +parser.add_argument('--workers', default=8, type=int, help='number of workers') +parser.add_argument('--RandFlipd_prob', default=0.8, type=float, help='RandFlipd aug probability') +parser.add_argument('--RandRotate90d_prob', default=0.2, type=float, help='RandRotate90d aug probability') +parser.add_argument('--RandScaleIntensityd_prob', default=0.1, type=float, help='RandScaleIntensityd aug probability') +parser.add_argument('--RandShiftIntensityd_prob', default=0.1, type=float, help='RandShiftIntensityd aug probability') +parser.add_argument('--pos_embed', default='perceptron', type=str, help='type of position embedding') +parser.add_argument('--norm_name', default='instance', type=str, help='normalization layer type in decoder') +parser.add_argument('--load_weight_name', default='a', type=str, help='trained_weight') +parser.add_argument('--save_folder', default='a', type=str, help='output_folder') +parser.add_argument('--model_feature', default=96, type=int, help='model_imbeding_feature size') +parser.add_argument('--scale_intensity', action='store_true', help='') +parser.add_argument('--use_smit', default=0, type=int, help='use smit model') + + +# copy spacing and orientation info between sitk objects +def copy_info(src, dst): + dst.SetSpacing(src.GetSpacing()) + dst.SetOrigin(src.GetOrigin()) + dst.SetDirection(src.GetDirection()) + + return dst + +#Additional functions to filter out the body + + +# thresholding the intensity values to get a binary mask of the patient +def fg_mask2d(img_2d, thresh): # + mask_map = np.float32(img_2d > thresh) + + def getLargestCC(segmentation): # largest connected components + labels = label(segmentation) + assert( labels.max() != 0 ) # assume at least 1 CC + largestCC = labels == np.argmax(np.bincount(labels.flat)[1:])+1 + return largestCC + if mask_map.max() < 0.999: + return mask_map + else: + post_mask = getLargestCC(mask_map) + fill_mask = snm.binary_fill_holes(post_mask) + return fill_mask + + +def Get_body_wrapper(img, verbose = False, fg_thresh = 1e-4): + + fg_mask_vol = np.zeros(img.shape) + for ii in range(fg_mask_vol.shape[0]): + if verbose: + print("doing {} slice".format(ii)) + _fgm = fg_mask2d(img[ii, ...], fg_thresh ) + + + fg_mask_vol[ii] = _fgm + + return fg_mask_vol + + +def main(): + args = parser.parse_args() + + img_folder = args.data_dir + save_folder = args.save_folder + if not os.path.exists(save_folder): + os.makedirs(save_folder) + + data_dir = args.data_dir + datalist_json = os.path.join(args.data_dir, 'data.json') + + + val_org_transforms = Compose( + [ + LoadImaged(keys=["image"]), + EnsureChannelFirstd(keys=["image"]), + Spacingd(keys=["image"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode="bilinear"), + Orientationd(keys=["image"], axcodes="RAS"), + ScaleIntensityRanged( + keys=["image"], a_min=args.a_min, a_max=args.a_max, + b_min=0.0, b_max=1.0, clip=True, + ), + CropForegroundd(keys=["image"], source_key="image"), + SpatialPadd(keys=["image"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + + EnsureTyped(keys=["image"]), + ] + ) + + + test_files = load_decathlon_datalist(datalist_json, + True, + "val", + base_dir=data_dir) + + val_org_ds = data.Dataset(data=test_files, transform=val_org_transforms) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, num_workers=4) + + print('val data size is ', len(val_org_loader)) + post_transforms = Compose([ + EnsureTyped(keys="pred"), + Invertd( + keys="pred", + transform=val_org_transforms, + orig_keys="image", + meta_keys="pred_meta_dict", + orig_meta_keys="image_meta_dict", + meta_key_postfix="meta_dict", + nearest_interp=True, + to_tensor=True, + ), + AsDiscreted(keys="pred", argmax=True), + + ]) + + args.test_mode = True + val_loader = val_org_loader + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + if args.use_smit == 1: + config = CONFIGS_SMIT['SMIT_config'] + model = smit.SMIT_3D_Seg(config, + out_channels=args.out_channels, + norm_name='instance') + else: + model = smit_plus.SMIT_Plus(out_channels=args.out_channels, + in_channels=args.in_channels, + norm_name='instance', + feature_size=args.model_feature) + + model_dict = torch.load(args.load_weight_name) + + print('info: started to load weight: ', args.load_weight_name) + print('info: model emb feature is : ', args.model_feature) + model.load_state_dict(model_dict['state_dict'], strict=True) + model.eval() + model.to(device) + print('info: Successfully loaded trained weights: ', args.load_weight_name) + + with torch.no_grad(): + for i, val_data in enumerate(val_loader): + val_inputs = val_data["image"].cuda() + + img_name = val_data['image_meta_dict']['filename_or_obj'][0].split('/')[-1] + + with autocast(enabled=True): + val_data["pred"] = sliding_window_inference(val_inputs, + (args.roi_x, args.roi_y, args.roi_z), + 4, + model, + overlap=args.infer_overlap) + + val_data = [post_transforms(i) for i in decollate_batch(val_data)] + + val_outputs = from_engine(["pred"])(val_data) + + val_outputs = val_outputs[0] + + seg_ori_size = val_outputs.numpy().astype(np.uint8) + seg_ori_size = np.squeeze(seg_ori_size) + + pred_sv_name = os.path.join(save_folder, os.path.split(args.load_weight_name)[-1].replace('.pt', '') + '_' + img_name) + + print('info: start get the info') + + #Start to filter the body + cur_rd_path = os.path.join(img_folder, img_name) + im_obj = sitk.ReadImage(cur_rd_path) + img_3d_data=sitk.GetArrayFromImage(im_obj) + threshold_= -150 + out_fg= Get_body_wrapper(img_3d_data, fg_thresh = threshold_) + out_fg=np.transpose(out_fg, (2, 1, 0)) + seg_ori_size[out_fg==0]=0 + seg_ori_size = np.transpose(seg_ori_size, (2, 1, 0)) + out_fg_o = sitk.GetImageFromArray(seg_ori_size) + seg_ori_size = copy_info(im_obj, out_fg_o) + sitk.WriteImage(seg_ori_size, pred_sv_name) + + +if __name__ == '__main__': + main() From 903d37ccbd51e6bff2cff22c35ae6a1a009f273b Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:39:57 -0400 Subject: [PATCH 06/20] add utils --- models/msk_smit_lung_gtv/src/utils/utils.py | 73 +++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 models/msk_smit_lung_gtv/src/utils/utils.py diff --git a/models/msk_smit_lung_gtv/src/utils/utils.py b/models/msk_smit_lung_gtv/src/utils/utils.py new file mode 100644 index 00000000..ee7de873 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/utils/utils.py @@ -0,0 +1,73 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np + +def dice(x, y): + intersect = np.sum(np.sum(np.sum(x * y))) + y_sum = np.sum(np.sum(np.sum(y))) + if y_sum == 0: + return 0.0 + x_sum = np.sum(np.sum(np.sum(x))) + return 2 * intersect / (x_sum + y_sum) + +class AverageMeter(object): + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = np.where(self.count > 0, + self.sum / self.count, + self.sum) + +def distributed_all_gather(tensor_list, + valid_batch_size=None, + out_numpy=False, + world_size=None, + no_barrier=False, + is_valid=None): + + if world_size is None: + world_size = torch.distributed.get_world_size() + if valid_batch_size is not None: + valid_batch_size = min(valid_batch_size, world_size) + elif is_valid is not None: + is_valid = torch.tensor(bool(is_valid), dtype=torch.bool, device=tensor_list[0].device) + if not no_barrier: + torch.distributed.barrier() + tensor_list_out = [] + with torch.no_grad(): + if is_valid is not None: + is_valid_list = [torch.zeros_like(is_valid) for _ in range(world_size)] + torch.distributed.all_gather(is_valid_list, is_valid) + is_valid = [x.item() for x in is_valid_list] + for tensor in tensor_list: + gather_list = [torch.zeros_like(tensor) for _ in range(world_size)] + torch.distributed.all_gather(gather_list, tensor) + if valid_batch_size is not None: + gather_list = gather_list[:valid_batch_size] + elif is_valid is not None: + gather_list = [g for g,v in zip(gather_list, is_valid_list) if v] + if out_numpy: + gather_list = [t.cpu().numpy() for t in gather_list] + tensor_list_out.append(gather_list) + return tensor_list_out From d38cfe4b69404db3ba70a2a93a8a48ff0224e4fc Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:40:18 -0400 Subject: [PATCH 07/20] Add files via upload --- .../msk_smit_lung_gtv/src/utils/data_utils.py | 825 ++++++++++++++++++ .../src/utils/gen_data_json.py | 28 + 2 files changed, 853 insertions(+) create mode 100644 models/msk_smit_lung_gtv/src/utils/data_utils.py create mode 100644 models/msk_smit_lung_gtv/src/utils/gen_data_json.py diff --git a/models/msk_smit_lung_gtv/src/utils/data_utils.py b/models/msk_smit_lung_gtv/src/utils/data_utils.py new file mode 100644 index 00000000..b39c6d41 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/utils/data_utils.py @@ -0,0 +1,825 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import math +import numpy as np +import torch +from monai import transforms, data +from monai.data import load_decathlon_datalist + +from monai.transforms import OneOf, RandCoarseDropoutd + + +class Sampler(torch.utils.data.Sampler): + def __init__(self, dataset, num_replicas=None, rank=None, + shuffle=True, make_even=True): + if num_replicas is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + num_replicas = torch.distributed.get_world_size() + if rank is None: + if not torch.distributed.is_available(): + raise RuntimeError("Requires distributed package to be available") + rank = torch.distributed.get_rank() + self.shuffle = shuffle + self.make_even = make_even + self.dataset = dataset + self.num_replicas = num_replicas + self.rank = rank + self.epoch = 0 + self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) + self.total_size = self.num_samples * self.num_replicas + indices = list(range(len(self.dataset))) + self.valid_length = len(indices[self.rank:self.total_size:self.num_replicas]) + + def __iter__(self): + if self.shuffle: + g = torch.Generator() + g.manual_seed(self.epoch) + indices = torch.randperm(len(self.dataset), generator=g).tolist() + else: + indices = list(range(len(self.dataset))) + if self.make_even: + if len(indices) < self.total_size: + if self.total_size - len(indices) < len(indices): + indices += indices[:(self.total_size - len(indices))] + else: + extra_ids = np.random.randint(low=0, high=len(indices), size=self.total_size - len(indices)) + indices += [indices[ids] for ids in extra_ids] + assert len(indices) == self.total_size + indices = indices[self.rank:self.total_size:self.num_replicas] + self.num_samples = len(indices) + return iter(indices) + + def __len__(self): + return self.num_samples + + def set_epoch(self, epoch): + self.epoch = epoch + + +def get_loader(args): + data_dir = args.data_dir + datalist_json = os.path.join(data_dir, args.json_list) + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.ScaleIntensityRanged(keys=["image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + transforms.RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_x), + pos=1, + neg=1, + num_samples=2, + image_key="image", + image_threshold=0, + ), + transforms.RandRotate90d( + keys=["image", "label"], + prob=args.RandRotate90d_prob, + max_k=3, + ), + transforms.RandScaleIntensityd(keys="image", + factors=0.1, + prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys="image", + offsets=0.1, + prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.ScaleIntensityRanged(keys=["image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + if args.test_mode: + test_files = load_decathlon_datalist(datalist_json, + True, + "validation", + base_dir=data_dir) + test_ds = data.Dataset(data=test_files, transform=val_transform) + test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None + test_loader = data.DataLoader(test_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=test_sampler, + pin_memory=True, + persistent_workers=True) + loader = test_loader + else: + datalist = load_decathlon_datalist(datalist_json, + True, + "training", + base_dir=data_dir) + if args.use_normal_dataset: + train_ds = data.Dataset(data=datalist, transform=train_transform) + else: + train_ds = data.CacheDataset( + data=datalist, + transform=train_transform, + cache_num=400, # 250, + cache_rate=1.0, + num_workers=args.workers, + ) + train_sampler = Sampler(train_ds) if args.distributed else None + train_loader = data.DataLoader(train_ds, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + sampler=train_sampler, + pin_memory=True, + persistent_workers=True) + val_files = load_decathlon_datalist(datalist_json, + True, + "val", + base_dir=data_dir) + val_ds = data.Dataset(data=val_files, transform=val_transform) + val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None + val_loader = data.DataLoader(val_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=val_sampler, + pin_memory=True, + persistent_workers=True) + loader = [train_loader, val_loader] + + return loader + + + +def get_loader_multi_modality(args): + data_dir = args.data_dir + datalist_json = os.path.join(data_dir, args.json_list) + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["primary_image", "secondary_image","label"]), + transforms.AddChanneld(keys=["primary_image", "secondary_image","label"]), + transforms.Orientationd(keys=["primary_image", "secondary_image","label"], + axcodes="RAS"), + + transforms.ScaleIntensityRanged(keys=["primary_image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + + transforms.CropForegroundd(keys=["primary_image", "secondary_image","label"], source_key="primary_image"), + + OneOf(transforms=[ + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x + 0.2, args.space_y + 0.2, args.space_z + 0.2), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x + 0.4, args.space_y + 0.4, args.space_z + 0.4), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x + 0.6, args.space_y + 0.6, args.space_z + 0.6), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x + 0.8, args.space_y + 0.8, args.space_z + 0.8), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x - 0.2, args.space_y - 0.2, args.space_z - 0.2), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x - 0.4, args.space_y - 0.4, args.space_z - 0.4), + mode=("bilinear","bilinear", "nearest")), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x - 0.6, args.space_y - 0.6, args.space_z - 0.6), + mode=("bilinear","bilinear", "nearest")), + ]), + + transforms.SpatialPadd(keys=["primary_image", "secondary_image","label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + + transforms.RandCropByPosNegLabeld( + keys=["primary_image", "secondary_image","label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_x), + pos=1, + neg=1, + num_samples=1, + image_key="primary_image", + image_threshold=0, + ), + + #transforms.RandFlipd(keys=["image", "label"], + # prob=args.RandFlipd_prob, + # spatial_axis=0), + #transforms.RandFlipd(keys=["image", "label"], + # prob=args.RandFlipd_prob, + # spatial_axis=1), + #transforms.RandFlipd(keys=["image", "label"], + # prob=args.RandFlipd_prob, + # spatial_axis=2), + transforms.RandRotate90d( + keys=["primary_image", "secondary_image","label"], + prob=args.RandRotate90d_prob, + max_k=3, + ), + transforms.RandScaleIntensityd(keys=["primary_image", "secondary_image"], + factors=0.2, + prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys=["primary_image", "secondary_image"], + offsets=0.2, + prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["primary_image", "secondary_image","label"]), + ] + ) + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["primary_image", "secondary_image","label"]), + transforms.AddChanneld(keys=["primary_image", "secondary_image","label"]), + transforms.Orientationd(keys=["primary_image", "secondary_image","label"], + axcodes="RAS"), + + transforms.ScaleIntensityRanged(keys=["primary_image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + transforms.CropForegroundd(keys=["primary_image", "secondary_image","label"], source_key="primary_image"), + transforms.Spacingd(keys=["primary_image", "secondary_image","label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear","bilinear", "nearest")), + transforms.SpatialPadd(keys=["primary_image", "secondary_image","label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + + #transforms.RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=(args.roi_x, args.roi_y, args.roi_x), + # pos=1, + # neg=1, + # num_samples=1, + # image_key="image", + # image_threshold=0, + #), + transforms.ToTensord(keys=["primary_image", "secondary_image","label"]), + ] + ) + + if args.test_mode: + test_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + test_ds = data.Dataset(data=test_files, transform=val_transform) + test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None + test_loader = data.DataLoader(test_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=test_sampler, + pin_memory=True, + persistent_workers=True) + loader = test_loader + else: + datalist = load_decathlon_datalist(datalist_json, + True, + "train_data", + base_dir=data_dir) + #datalist = datalist[0:600] + if args.use_normal_dataset: + train_ds = data.Dataset(data=datalist, transform=train_transform) + else: + train_ds = data.CacheDataset( + data=datalist, + transform=train_transform, + cache_num=args.cache_num, # 200,#1200,#400,#200,#args.cache_num,# 250, + cache_rate=1.0, + num_workers=args.workers, + ) + train_sampler = Sampler(train_ds) if args.distributed else None + train_loader = data.DataLoader(train_ds, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + sampler=train_sampler, + pin_memory=True, + persistent_workers=True) + val_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + #val_files=val_files[0] + val_ds = data.Dataset(data=val_files, transform=val_transform) + #val_ds=val_ds[0] + val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None + val_loader = data.DataLoader(val_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=val_sampler, + pin_memory=True, + persistent_workers=True) + loader = [train_loader, val_loader] + + return loader + + +def get_loader_ct(args): + data_dir = args.data_dir + datalist_json = os.path.join(data_dir, args.json_list) + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + + transforms.ScaleIntensityRanged(keys=["image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + + OneOf(transforms=[ + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.2, args.space_y + 0.2, args.space_z + 0.2), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.4, args.space_y + 0.4, args.space_z + 0.4), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.6, args.space_y + 0.6, args.space_z + 0.6), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.8, args.space_y + 0.8, args.space_z + 0.8), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x - 0.2, args.space_y - 0.2, args.space_z - 0.2), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x - 0.4, args.space_y - 0.4, args.space_z - 0.4), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x - 0.6, args.space_y - 0.6, args.space_z - 0.6), + mode=("bilinear", "nearest")), + ]), + + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + + transforms.RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_x), + pos=1, + neg=1, + num_samples=1, + image_key="image", + image_threshold=0, + ), + + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=0), + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=1), + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=2), + transforms.RandRotate90d( + keys=["image", "label"], + prob=args.RandRotate90d_prob, + max_k=3, + ), + transforms.RandScaleIntensityd(keys="image", + factors=0.2, + prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys="image", + offsets=0.2, + prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.ScaleIntensityRanged(keys=["image"], + a_min=args.a_min, + a_max=args.a_max, + b_min=args.b_min, + b_max=args.b_max, + clip=True), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + #transforms.RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=(args.roi_x, args.roi_y, args.roi_x), + # pos=1, + # neg=1, + # num_samples=1, + # image_key="image", + # image_threshold=0, + #), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + if args.test_mode: + test_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + test_ds = data.Dataset(data=test_files, transform=val_transform) + test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None + test_loader = data.DataLoader(test_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=test_sampler, + pin_memory=True, + persistent_workers=True) + loader = test_loader + else: + datalist = load_decathlon_datalist(datalist_json, + True, + "train_data", + base_dir=data_dir) + #datalist = datalist[0:600] + if args.use_normal_dataset: + train_ds = data.Dataset(data=datalist, transform=train_transform) + else: + train_ds = data.CacheDataset( + data=datalist, + transform=train_transform, + cache_num=args.cache_num, # 200,#1200,#400,#200,#args.cache_num,# 250, + cache_rate=1.0, + num_workers=args.workers, + ) + train_sampler = Sampler(train_ds) if args.distributed else None + train_loader = data.DataLoader(train_ds, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + sampler=train_sampler, + pin_memory=True, + persistent_workers=True) + val_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + #val_files=val_files[0] + val_ds = data.Dataset(data=val_files, transform=val_transform) + #val_ds=val_ds[0] + val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None + val_loader = data.DataLoader(val_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=val_sampler, + pin_memory=True, + persistent_workers=True) + loader = [train_loader, val_loader] + + return loader + + +def get_loader_mr(args): + data_dir = args.data_dir + datalist_json = os.path.join(data_dir, args.json_list) + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + + OneOf(transforms=[ + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.2, args.space_y + 0.2, args.space_z + 0.2), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.4, args.space_y + 0.4, args.space_z + 0.4), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.6, args.space_y + 0.6, args.space_z + 0.6), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x + 0.8, args.space_y + 0.8, args.space_z + 0.8), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x - 0.2, args.space_y - 0.2, args.space_z - 0.2), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x - 0.4, args.space_y - 0.4, args.space_z - 0.4), + mode=("bilinear", "nearest")), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x - 0.6, args.space_y - 0.6, args.space_z - 0.6), + mode=("bilinear", "nearest")), + ]), + + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + + transforms.RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_x), + pos=1, + neg=1, + num_samples=1, + image_key="image", + image_threshold=0, + ), + + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=0), + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=1), + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=2), + transforms.RandRotate90d( + keys=["image", "label"], + prob=args.RandRotate90d_prob, + max_k=3, + ), + transforms.RandScaleIntensityd(keys="image", + factors=0.2, + prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys="image", + offsets=0.2, + prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + # transforms.RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=(args.roi_x, args.roi_y, args.roi_x), + # pos=1, + # neg=1, + # num_samples=1, + # image_key="image", + # image_threshold=0, + # ), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + if args.test_mode: + test_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + test_ds = data.Dataset(data=test_files, transform=val_transform) + test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None + test_loader = data.DataLoader(test_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=test_sampler, + pin_memory=True, + persistent_workers=True) + loader = test_loader + else: + datalist = load_decathlon_datalist(datalist_json, + True, + "train_data", + base_dir=data_dir) + #datalist = datalist[0:600] + if args.use_normal_dataset: + train_ds = data.Dataset(data=datalist, transform=train_transform) + else: + train_ds = data.CacheDataset( + data=datalist, + transform=train_transform, + cache_num=args.cache_num, # 200,#1200,#400,#200,#args.cache_num,# 250, + cache_rate=1.0, + num_workers=args.workers, + ) + train_sampler = Sampler(train_ds) if args.distributed else None + train_loader = data.DataLoader(train_ds, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + sampler=train_sampler, + pin_memory=True, + persistent_workers=True) + val_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + val_ds = data.Dataset(data=val_files, transform=val_transform) + val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None + val_loader = data.DataLoader(val_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=val_sampler, + pin_memory=True, + persistent_workers=True) + loader = [train_loader, val_loader] + + return loader + + +def get_loader_no_spatial_aug(args): + data_dir = args.data_dir + datalist_json = os.path.join(data_dir, args.json_list) + train_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + + OneOf(transforms=[ + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + ]), + + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + + transforms.RandCropByPosNegLabeld( + keys=["image", "label"], + label_key="label", + spatial_size=(args.roi_x, args.roi_y, args.roi_x), + pos=1, + neg=1, + num_samples=1, + image_key="image", + image_threshold=0, + ), + + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=0), + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=1), + transforms.RandFlipd(keys=["image", "label"], + prob=args.RandFlipd_prob, + spatial_axis=2), + transforms.RandRotate90d( + keys=["image", "label"], + prob=args.RandRotate90d_prob, + max_k=3, + ), + transforms.RandScaleIntensityd(keys="image", + factors=0.2, + prob=args.RandScaleIntensityd_prob), + transforms.RandShiftIntensityd(keys="image", + offsets=0.2, + prob=args.RandShiftIntensityd_prob), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + val_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label"]), + transforms.AddChanneld(keys=["image", "label"]), + transforms.Orientationd(keys=["image", "label"], + axcodes="RAS"), + transforms.CropForegroundd(keys=["image", "label"], source_key="image"), + transforms.Spacingd(keys=["image", "label"], + pixdim=(args.space_x, args.space_y, args.space_z), + mode=("bilinear", "nearest")), + transforms.SpatialPadd(keys=["image", "label"], spatial_size=(args.roi_x, args.roi_y, args.roi_z)), + # transforms.RandCropByPosNegLabeld( + # keys=["image", "label"], + # label_key="label", + # spatial_size=(args.roi_x, args.roi_y, args.roi_x), + # pos=1, + # neg=1, + # num_samples=1, + # image_key="image", + # image_threshold=0, + # ), + transforms.ToTensord(keys=["image", "label"]), + ] + ) + + if args.test_mode: + test_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + test_ds = data.Dataset(data=test_files, transform=val_transform) + test_sampler = Sampler(test_ds, shuffle=False) if args.distributed else None + test_loader = data.DataLoader(test_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=test_sampler, + pin_memory=True, + persistent_workers=True) + loader = test_loader + else: + datalist = load_decathlon_datalist(datalist_json, + True, + "train_data", + base_dir=data_dir) + #datalist = datalist[0:600] + if args.use_normal_dataset: + train_ds = data.Dataset(data=datalist, transform=train_transform) + else: + train_ds = data.CacheDataset( + data=datalist, + transform=train_transform, + cache_num=args.cache_num,#10, # 200,#1200,#400,#200,#args.cache_num,# 250, + cache_rate=1.0, + num_workers=args.workers, + ) + train_sampler = Sampler(train_ds) if args.distributed else None + train_loader = data.DataLoader(train_ds, + batch_size=args.batch_size, + shuffle=(train_sampler is None), + num_workers=args.workers, + sampler=train_sampler, + pin_memory=True, + persistent_workers=True) + val_files = load_decathlon_datalist(datalist_json, + True, + "val_data", + base_dir=data_dir) + val_ds = data.Dataset(data=val_files, transform=val_transform) + val_sampler = Sampler(val_ds, shuffle=False) if args.distributed else None + val_loader = data.DataLoader(val_ds, + batch_size=1, + shuffle=False, + num_workers=args.workers, + sampler=val_sampler, + pin_memory=True, + persistent_workers=True) + loader = [train_loader, val_loader] + + return loader \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/utils/gen_data_json.py b/models/msk_smit_lung_gtv/src/utils/gen_data_json.py new file mode 100644 index 00000000..9579a2a3 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/utils/gen_data_json.py @@ -0,0 +1,28 @@ +import os, json, sys + + +data_file_path = sys.argv[1] + +data_dir = os.path.dirname(data_file_path) +nii_file = os.path.basename(data_file_path) + + +out_json = os.path.join(data_dir,'data.json') + +data_json = { + "val": + [ + { + "image": nii_file + } + + ] +} + + + +json_object = json.dumps(data_json, indent=4) + +# Writing to sample.json +with open(out_json, "w") as outfile: + outfile.write(json_object) From 1802f70fb567b74d732b25925447bade51baf890 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:43:29 -0400 Subject: [PATCH 08/20] Add files via upload --- .../src/smit_models/__init__.py | 1 + .../src/smit_models/_features.py | 367 + .../src/smit_models/_features_fx.py | 127 + .../src/smit_models/configs_smit.py | 63 + .../cross_swin_networks/SwinTransModels.py | 8184 ++++++++++++++++ .../cross_swin_networks/__init__.py | 0 .../smit_models/cross_swin_networks/unetr.py | 222 + .../src/smit_models/format.py | 59 + .../msk_smit_lung_gtv/src/smit_models/smit.py | 1160 +++ .../src/smit_models/smit_cross_attention.py | 8466 +++++++++++++++++ .../src/smit_models/smit_plus.py | 1938 ++++ 11 files changed, 20587 insertions(+) create mode 100644 models/msk_smit_lung_gtv/src/smit_models/__init__.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/_features.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/_features_fx.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/configs_smit.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/SwinTransModels.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/__init__.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/unetr.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/format.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/smit.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/smit_cross_attention.py create mode 100644 models/msk_smit_lung_gtv/src/smit_models/smit_plus.py diff --git a/models/msk_smit_lung_gtv/src/smit_models/__init__.py b/models/msk_smit_lung_gtv/src/smit_models/__init__.py new file mode 100644 index 00000000..2247c678 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/__init__.py @@ -0,0 +1 @@ +import numpy as np \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/smit_models/_features.py b/models/msk_smit_lung_gtv/src/smit_models/_features.py new file mode 100644 index 00000000..2ca9295b --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/_features.py @@ -0,0 +1,367 @@ +""" PyTorch Feature Extraction Helpers + +A collection of classes, functions, modules to help extract features from models +and provide a common interface for describing them. + +The return_layers, module re-writing idea inspired by torchvision IntermediateLayerGetter +https://github.com/pytorch/vision/blob/d88d8961ae51507d0cb680329d985b1488b1b76b/torchvision/models/_utils.py + +Hacked together by / Copyright 2020 Ross Wightman +""" +from collections import OrderedDict, defaultdict +from copy import deepcopy +from functools import partial +from typing import Dict, List, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from timm.models.layers import Format + + +__all__ = ['FeatureInfo', 'FeatureHooks', 'FeatureDictNet', 'FeatureListNet', 'FeatureHookNet'] + + +class FeatureInfo: + + def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]): + prev_reduction = 1 + for fi in feature_info: + # sanity check the mandatory fields, there may be additional fields depending on the model + assert 'num_chs' in fi and fi['num_chs'] > 0 + assert 'reduction' in fi and fi['reduction'] >= prev_reduction + prev_reduction = fi['reduction'] + assert 'module' in fi + self.out_indices = out_indices + self.info = feature_info + + def from_other(self, out_indices: Tuple[int]): + return FeatureInfo(deepcopy(self.info), out_indices) + + def get(self, key, idx=None): + """ Get value by key at specified index (indices) + if idx == None, returns value for key at each output index + if idx is an integer, return value for that feature module index (ignoring output indices) + if idx is a list/tupple, return value for each module index (ignoring output indices) + """ + if idx is None: + return [self.info[i][key] for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i][key] for i in idx] + else: + return self.info[idx][key] + + def get_dicts(self, keys=None, idx=None): + """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None) + """ + if idx is None: + if keys is None: + return [self.info[i] for i in self.out_indices] + else: + return [{k: self.info[i][k] for k in keys} for i in self.out_indices] + if isinstance(idx, (tuple, list)): + return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx] + else: + return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys} + + def channels(self, idx=None): + """ feature channels accessor + """ + return self.get('num_chs', idx) + + def reduction(self, idx=None): + """ feature reduction (output stride) accessor + """ + return self.get('reduction', idx) + + def module_name(self, idx=None): + """ feature module name accessor + """ + return self.get('module', idx) + + def __getitem__(self, item): + return self.info[item] + + def __len__(self): + return len(self.info) + + +class FeatureHooks: + """ Feature Hook Helper + + This module helps with the setup and extraction of hooks for extracting features from + internal nodes in a model by node name. + + FIXME This works well in eager Python but needs redesign for torchscript. + """ + + def __init__( + self, + hooks: Sequence[str], + named_modules: dict, + out_map: Sequence[Union[int, str]] = None, + default_hook_type: str = 'forward', + ): + # setup feature hooks + self._feature_outputs = defaultdict(OrderedDict) + modules = {k: v for k, v in named_modules} + for i, h in enumerate(hooks): + hook_name = h['module'] + m = modules[hook_name] + hook_id = out_map[i] if out_map else hook_name + hook_fn = partial(self._collect_output_hook, hook_id) + hook_type = h.get('hook_type', default_hook_type) + if hook_type == 'forward_pre': + m.register_forward_pre_hook(hook_fn) + elif hook_type == 'forward': + m.register_forward_hook(hook_fn) + else: + assert False, "Unsupported hook type" + + def _collect_output_hook(self, hook_id, *args): + x = args[-1] # tensor we want is last argument, output for fwd, input for fwd_pre + if isinstance(x, tuple): + x = x[0] # unwrap input tuple + self._feature_outputs[x.device][hook_id] = x + + def get_output(self, device) -> Dict[str, torch.tensor]: + output = self._feature_outputs[device] + self._feature_outputs[device] = OrderedDict() # clear after reading + return output + + +def _module_list(module, flatten_sequential=False): + # a yield/iter would be better for this but wouldn't be compatible with torchscript + ml = [] + for name, module in module.named_children(): + if flatten_sequential and isinstance(module, nn.Sequential): + # first level of Sequential containers is flattened into containing model + for child_name, child_module in module.named_children(): + combined = [name, child_name] + ml.append(('_'.join(combined), '.'.join(combined), child_module)) + else: + ml.append((name, name, module)) + return ml + + +def _get_feature_info(net, out_indices): + feature_info = getattr(net, 'feature_info') + if isinstance(feature_info, FeatureInfo): + return feature_info.from_other(out_indices) + elif isinstance(feature_info, (list, tuple)): + return FeatureInfo(net.feature_info, out_indices) + else: + assert False, "Provided feature_info is not valid" + + +def _get_return_layers(feature_info, out_map): + module_names = feature_info.module_name() + return_layers = {} + for i, name in enumerate(module_names): + return_layers[name] = out_map[i] if out_map is not None else feature_info.out_indices[i] + return return_layers + + +class FeatureDictNet(nn.ModuleDict): + """ Feature extractor with OrderedDict return + + Wrap a model and extract features as specified by the out indices, the network is + partially re-built from contained modules. + + There is a strong assumption that the modules have been registered into the model in the same + order as they are used. There should be no reuse of the same nn.Module more than once, including + trivial modules like `self.relu = nn.ReLU`. + + Only submodules that are directly assigned to the model class (`model.feature1`) or at most + one Sequential container deep (`model.features.1`, with flatten_sequent=True) can be captured. + All Sequential containers that are directly assigned to the original model will have their + modules assigned to this module with the name `model.features.1` being changed to `model.features_1` + """ + def __init__( + self, + model: nn.Module, + out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + out_map: Sequence[Union[int, str]] = None, + output_fmt: str = 'NCHW', + feature_concat: bool = False, + flatten_sequential: bool = False, + ): + """ + Args: + model: Model from which to extract features. + out_indices: Output indices of the model features to extract. + out_map: Return id mapping for each output index, otherwise str(index) is used. + feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting + first element e.g. `x[0]` + flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules) + """ + super(FeatureDictNet, self).__init__() + self.feature_info = _get_feature_info(model, out_indices) + self.output_fmt = Format(output_fmt) + self.concat = feature_concat + self.grad_checkpointing = False + self.return_layers = {} + + return_layers = _get_return_layers(self.feature_info, out_map) + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = set(return_layers.keys()) + layers = OrderedDict() + for new_name, old_name, module in modules: + layers[new_name] = module + if old_name in remaining: + # return id has to be consistently str type for torchscript + self.return_layers[new_name] = str(return_layers[old_name]) + remaining.remove(old_name) + if not remaining: + break + assert not remaining and len(self.return_layers) == len(return_layers), \ + f'Return layers ({remaining}) are not present in model' + self.update(layers) + + def set_grad_checkpointing(self, enable: bool = True): + self.grad_checkpointing = enable + + def _collect(self, x) -> (Dict[str, torch.Tensor]): + out = OrderedDict() + for i, (name, module) in enumerate(self.items()): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # Skipping checkpoint of first module because need a gradient at input + # Skipping last because networks with in-place ops might fail w/ checkpointing enabled + # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues + first_or_last_module = i == 0 or i == max(len(self) - 1, 0) + x = module(x) if first_or_last_module else checkpoint(module, x) + else: + x = module(x) + + if name in self.return_layers: + out_id = self.return_layers[name] + if isinstance(x, (tuple, list)): + # If model tap is a tuple or list, concat or select first element + # FIXME this may need to be more generic / flexible for some nets + out[out_id] = torch.cat(x, 1) if self.concat else x[0] + else: + out[out_id] = x + return out + + def forward(self, x) -> Dict[str, torch.Tensor]: + return self._collect(x) + + +class FeatureListNet(FeatureDictNet): + """ Feature extractor with list return + + A specialization of FeatureDictNet that always returns features as a list (values() of dict). + """ + def __init__( + self, + model: nn.Module, + out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + output_fmt: str = 'NCHW', + feature_concat: bool = False, + flatten_sequential: bool = False, + ): + """ + Args: + model: Model from which to extract features. + out_indices: Output indices of the model features to extract. + feature_concat: Concatenate intermediate features that are lists or tuples instead of selecting + first element e.g. `x[0]` + flatten_sequential: Flatten first two-levels of sequential modules in model (re-writes model modules) + """ + super().__init__( + model, + out_indices=out_indices, + output_fmt=output_fmt, + feature_concat=feature_concat, + flatten_sequential=flatten_sequential, + ) + + def forward(self, x) -> (List[torch.Tensor]): + return list(self._collect(x).values()) + + +class FeatureHookNet(nn.ModuleDict): + """ FeatureHookNet + + Wrap a model and extract features specified by the out indices using forward/forward-pre hooks. + + If `no_rewrite` is True, features are extracted via hooks without modifying the underlying + network in any way. + + If `no_rewrite` is False, the model will be re-written as in the + FeatureList/FeatureDict case by folding first to second (Sequential only) level modules into this one. + + FIXME this does not currently work with Torchscript, see FeatureHooks class + """ + def __init__( + self, + model: nn.Module, + out_indices: Tuple[int, ...] = (0, 1, 2, 3, 4), + out_map: Sequence[Union[int, str]] = None, + return_dict: bool = False, + output_fmt: str = 'NCHW', + no_rewrite: bool = False, + flatten_sequential: bool = False, + default_hook_type: str = 'forward', + ): + """ + + Args: + model: Model from which to extract features. + out_indices: Output indices of the model features to extract. + out_map: Return id mapping for each output index, otherwise str(index) is used. + return_dict: Output features as a dict. + no_rewrite: Enforce that model is not re-written if True, ie no modules are removed / changed. + flatten_sequential arg must also be False if this is set True. + flatten_sequential: Re-write modules by flattening first two levels of nn.Sequential containers. + default_hook_type: The default hook type to use if not specified in model.feature_info. + """ + super().__init__() + assert not torch.jit.is_scripting() + self.feature_info = _get_feature_info(model, out_indices) + self.return_dict = return_dict + self.output_fmt = Format(output_fmt) + self.grad_checkpointing = False + + layers = OrderedDict() + hooks = [] + if no_rewrite: + assert not flatten_sequential + if hasattr(model, 'reset_classifier'): # make sure classifier is removed? + model.reset_classifier(0) + layers['body'] = model + hooks.extend(self.feature_info.get_dicts()) + else: + modules = _module_list(model, flatten_sequential=flatten_sequential) + remaining = { + f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type + for f in self.feature_info.get_dicts() + } + for new_name, old_name, module in modules: + layers[new_name] = module + for fn, fm in module.named_modules(prefix=old_name): + if fn in remaining: + hooks.append(dict(module=fn, hook_type=remaining[fn])) + del remaining[fn] + if not remaining: + break + assert not remaining, f'Return layers ({remaining}) are not present in model' + self.update(layers) + self.hooks = FeatureHooks(hooks, model.named_modules(), out_map=out_map) + + def set_grad_checkpointing(self, enable: bool = True): + self.grad_checkpointing = enable + + def forward(self, x): + for i, (name, module) in enumerate(self.items()): + if self.grad_checkpointing and not torch.jit.is_scripting(): + # Skipping checkpoint of first module because need a gradient at input + # Skipping last because networks with in-place ops might fail w/ checkpointing enabled + # NOTE: first_or_last module could be static, but recalc in is_scripting guard to avoid jit issues + first_or_last_module = i == 0 or i == max(len(self) - 1, 0) + x = module(x) if first_or_last_module else checkpoint(module, x) + else: + x = module(x) + out = self.hooks.get_output(x.device) + return out if self.return_dict else list(out.values()) \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/smit_models/_features_fx.py b/models/msk_smit_lung_gtv/src/smit_models/_features_fx.py new file mode 100644 index 00000000..bfd5c9f4 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/_features_fx.py @@ -0,0 +1,127 @@ +""" PyTorch FX Based Feature Extraction Helpers +Using https://pytorch.org/vision/stable/feature_extraction.html +""" +from typing import Callable, List, Dict, Union, Type + +import torch +from torch import nn + +from ._features import _get_feature_info + +try: + from torchvision.models.feature_extraction import create_feature_extractor as _create_feature_extractor + has_fx_feature_extraction = True +except ImportError: + has_fx_feature_extraction = False + +# Layers we went to treat as leaf modules +from timm.models.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame +from timm.models.layers.non_local_attn import BilinearAttnTransform +from timm.models.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame + +__all__ = ['register_notrace_module', 'is_notrace_module', 'get_notrace_modules', + 'register_notrace_function', 'is_notrace_function', 'get_notrace_functions', + 'create_feature_extractor', 'FeatureGraphNet', 'GraphExtractNet'] + + +# NOTE: By default, any modules from timm.models.layers that we want to treat as leaf modules go here +# BUT modules from timm.models should use the registration mechanism below +_leaf_modules = { + BilinearAttnTransform, # reason: flow control t <= 1 + # Reason: get_same_padding has a max which raises a control flow error + Conv2dSame, MaxPool2dSame, ScaledStdConv2dSame, StdConv2dSame, AvgPool2dSame, + CondConv2d, # reason: TypeError: F.conv2d received Proxy in groups=self.groups * B (because B = x.shape[0]) +} + +try: + from timm.models.layers import InplaceAbn + _leaf_modules.add(InplaceAbn) +except ImportError: + pass + + +def register_notrace_module(module: Type[nn.Module]): + """ + Any module not under timm.models.layers should get this decorator if we don't want to trace through it. + """ + _leaf_modules.add(module) + return module + + +def is_notrace_module(module: Type[nn.Module]): + return module in _leaf_modules + + +def get_notrace_modules(): + return list(_leaf_modules) + + +# Functions we want to autowrap (treat them as leaves) +_autowrap_functions = set() + + +def register_notrace_function(func: Callable): + """ + Decorator for functions which ought not to be traced through + """ + _autowrap_functions.add(func) + return func + + +def is_notrace_function(func: Callable): + return func in _autowrap_functions + + +def get_notrace_functions(): + return list(_autowrap_functions) + + +def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str], List[str]]): + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + return _create_feature_extractor( + model, return_nodes, + tracer_kwargs={'leaf_modules': list(_leaf_modules), 'autowrap_functions': list(_autowrap_functions)} + ) + + +class FeatureGraphNet(nn.Module): + """ A FX Graph based feature extractor that works with the model feature_info metadata + """ + def __init__(self, model, out_indices, out_map=None): + super().__init__() + assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction' + self.feature_info = _get_feature_info(model, out_indices) + if out_map is not None: + assert len(out_map) == len(out_indices) + return_nodes = { + info['module']: out_map[i] if out_map is not None else info['module'] + for i, info in enumerate(self.feature_info) if i in out_indices} + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x): + return list(self.graph_module(x).values()) + + +class GraphExtractNet(nn.Module): + """ A standalone feature extraction wrapper that maps dict -> list or single tensor + NOTE: + * one can use feature_extractor directly if dictionary output is desired + * unlike FeatureGraphNet, this is intended to be used standalone and not with model feature_info + metadata for builtin feature extraction mode + * create_feature_extractor can be used directly if dictionary output is acceptable + + Args: + model: model to extract features from + return_nodes: node names to return features from (dict or list) + squeeze_out: if only one output, and output in list format, flatten to single tensor + """ + def __init__(self, model, return_nodes: Union[Dict[str, str], List[str]], squeeze_out: bool = True): + super().__init__() + self.squeeze_out = squeeze_out + self.graph_module = create_feature_extractor(model, return_nodes) + + def forward(self, x) -> Union[List[torch.Tensor], torch.Tensor]: + out = list(self.graph_module(x).values()) + if self.squeeze_out and len(out) == 1: + return out[0] + return out \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/smit_models/configs_smit.py b/models/msk_smit_lung_gtv/src/smit_models/configs_smit.py new file mode 100644 index 00000000..4922e6c6 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/configs_smit.py @@ -0,0 +1,63 @@ + + +import ml_collections + + +def get_SMIT_128_bias_True(): + ''' + A Large TransMorph Network + ''' + config = ml_collections.ConfigDict() + config.if_transskip = True + config.if_convskip = True + config.patch_size = 2 + config.in_chans = 1 + config.embed_dim = 128 + config.embed_dim = 48 + config.depths = (2, 2, 8, 2) + config.num_heads = (4, 4, 8, 16) + + config.window_size = (4, 4, 4) + + config.mlp_ratio = 4 + config.pat_merg_rf = 4 + config.qkv_bias = True + config.drop_rate = 0 + config.drop_path_rate = 0.3 + config.ape = False + config.spe = False + config.patch_norm = True + config.use_checkpoint = False + config.out_indices = (0, 1, 2, 3) + config.reg_head_chan = 16 + config.img_size = (128, 128, 128) + return config + + +def get_SMIT_128_bias_True_Cross(): + + config = ml_collections.ConfigDict() + config.if_transskip = True + config.if_convskip = True + config.patch_size = 2 + config.in_chans = 2 + config.embed_dim = 48 # change 128 or 192 + config.depths = (2, 2, 8, 2) # change 4 to 6,10 + config.num_heads = (4, 4, 8, 16) + config.window_size = (4, 4, 4) + config.mlp_ratio = 4 + config.pat_merg_rf = 4 + config.qkv_bias = True + config.drop_rate = 0 + config.drop_path_rate = 0.3 + config.ape = False + config.spe = False + config.patch_norm = True + config.use_checkpoint = False + config.out_indices = (0, 1, 2, 3) + config.seg_head_chan = config.embed_dim // 2 + config.img_size = (128, 128, 128) + config.pos_embed_method = 'relative' + return config + + diff --git a/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/SwinTransModels.py b/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/SwinTransModels.py new file mode 100644 index 00000000..d2836e13 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/SwinTransModels.py @@ -0,0 +1,8184 @@ +''' +Swin-Transformer with UNet + +Swin-Transformer code retrieved from: +https://github.com/SwinTransformer/Swin-Transformer-Semantic-Segmentation + +Original paper: +Liu, Z., Lin, Y., Cao, Y., Hu, H., Wei, Y., Zhang, Z., ... & Guo, B. (2021). +Swin transformer: Hierarchical vision transformer using shifted windows. +arXiv preprint arXiv:2103.14030. + +Modified and tested by: +Junyu Chen +jchen245@jhmi.edu +Johns Hopkins University +''' +from typing import Tuple, Union +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_, to_3tuple +from torch.distributions.normal import Normal +import torch.nn.functional as nnf +import numpy as np +import configs_sw as configs +import sys +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock,UnetrBasicBlock_No_DownSampling#,UnetrUpOnlyBlock +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, L, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], L // window_size[2], window_size[2], C) + + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0], window_size[1], window_size[2], C) + return windows + + +def window_reverse(windows, window_size, H, W, L): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W * L / window_size[0] / window_size[1] / window_size[2])) + x = windows.view(B, H // window_size[0], W // window_size[1], L // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, H, W, L, -1) + return x + + +class RelativeSinPosEmbed(nn.Module): + ''' + Rotary Position Embedding + ''' + def __init__(self,): + super(RelativeSinPosEmbed, self).__init__() + + def forward(self, attn): + batch_sz, _, n_patches, emb_dim = attn.shape + position_ids = torch.arange(0, n_patches).float().cuda() + indices = torch.arange(0, emb_dim//2).float().cuda() + indices = torch.pow(10000.0, -2 * indices / emb_dim) + embeddings = torch.einsum('b,d->bd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = torch.reshape(embeddings.view(n_patches, emb_dim), (1, 1, n_patches, emb_dim)) + #embeddings = embeddings.permute(0, 3, 1, 2) + return embeddings + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + + +class WindowAttention_crossModality(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, x_1, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv_mod1 = self.qkv(x_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv_mod1[1], qkv_mod1[2] # make torchscript happy (cannot use tensor as tuple) + q_mod1, k_mod1, v_mod1 = qkv_mod1[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + q_mod1 = q_mod1 * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + attn_mod1 = (q_mod1 @ k_mod1.transpose(-2, -1)) + + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + attn_mod1 = attn_mod1 + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn_mod1 = attn_mod1.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_mod1 = attn_mod1.view(-1, self.num_heads, N, N) + attn_mod1 = self.softmax(attn_mod1) + else: + attn = self.softmax(attn) + attn_mod1 = self.softmax(attn_mod1) + + + attn = self.attn_drop(attn) + attn_mod1 = self.attn_drop(attn_mod1) + + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) #mod 1 dot mod2 * mod2 + x_1 = (attn_mod1 @ v_mod1).transpose(1, 2).reshape(B_, N, C) #mod 2 dot mod 1 * mod1 + + x = self.proj(x) + x_1 = self.proj(x_1) + + x = self.proj_drop(x) + x_1 = self.proj_drop(x_1) + + return x,x_1 + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class WindowAttention_crossModality_4attns(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, x_1, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv_mod1 = self.qkv(x_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv_mod1[1], qkv_mod1[2] # make torchscript happy (cannot use tensor as tuple) + q_mod1, k_mod1, v_mod1 = qkv_mod1[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + q_mod1 = q_mod1 * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + attn_mod1 = (q_mod1 @ k_mod1.transpose(-2, -1)) + + #self_attn + attn_self = (q @ k_mod1.transpose(-2, -1)) + attn_self_mod1 = (q_mod1 @ k.transpose(-2, -1)) + + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + attn_mod1 = attn_mod1 + relative_position_bias.unsqueeze(0) + # self_attn + attn_self = attn_self + relative_position_bias.unsqueeze(0) + attn_self_mod1 = attn_self_mod1 + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn_mod1 = attn_mod1.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_mod1 = attn_mod1.view(-1, self.num_heads, N, N) + attn_mod1 = self.softmax(attn_mod1) + # self_attn + attn_self = attn_self.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_self = attn_self.view(-1, self.num_heads, N, N) + attn_self = self.softmax(attn_self) + attn_self_mod1 = attn_self_mod1.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_self_mod1 = attn_self_mod1.view(-1, self.num_heads, N, N) + attn_self_mod1 = self.softmax(attn_self_mod1) + + + else: + attn = self.softmax(attn) + attn_mod1 = self.softmax(attn_mod1) + # self_attn + attn_self = self.softmax(attn_self) + attn_self_mod1 = self.softmax(attn_self_mod1) + + + attn = self.attn_drop(attn) + attn_mod1 = self.attn_drop(attn_mod1) + # self_attn + attn_self = self.attn_drop(attn_self) + attn_self_mod1 = self.attn_drop(attn_self_mod1) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x_1 = (attn_mod1 @ v_mod1).transpose(1, 2).reshape(B_, N, C) + # self_attn + x_self = (attn_self @ v_mod1).transpose(1, 2).reshape(B_, N, C) + x_1_self = (attn_self_mod1 @ v).transpose(1, 2).reshape(B_, N, C) + + x = self.proj(x) + x_1 = self.proj(x_1) + # self_attn + x_self = self.proj(x_self) + x_1_self = self.proj(x_1_self) + + x = self.proj_drop(x) + x_1 = self.proj_drop(x_1) + # self_attn + x_self = self.proj_drop(x_self) + x_1_self = self.proj_drop(x_1_self) + + return x+x_self,x_1+x_1_self + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class WindowAttention_dualModality(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, x_1, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv_mod1 = self.qkv(x_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + q_mod1, k_mod1, v_mod1 = qkv_mod1[0], qkv_mod1[1], qkv_mod1[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x_1 = (attn @ v_mod1).transpose(1, 2).reshape(B_, N, C) + + x = self.proj(x) + x_1 = self.proj(x_1) + + x = self.proj_drop(x) + x_1 = self.proj_drop(x_1) + + return x,x_1 + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pos_embed_method='relative', concatenated_input=True): + super().__init__() + if concatenated_input: + self.dim = dim *2 + else: + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(self.dim) + self.attn = WindowAttention( + self.dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pos_embed_method=pos_embed_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(self.dim) + mlp_hidden_dim = int(self.dim * mlp_ratio) + self.mlp = Mlp(in_features=self.dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + #C = C * 2 + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, T, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' C + + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + + x = x.view(B, H * W * T, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class SwinTransformerBlock_dualModality(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pos_embed_method='relative', concatenated_input=True): + super().__init__() + if concatenated_input: + self.dim = dim *2 + else: + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(self.dim) + self.attn = WindowAttention_dualModality( + self.dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pos_embed_method=pos_embed_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(self.dim) + mlp_hidden_dim = int(self.dim * mlp_ratio) + self.mlp = Mlp(in_features=self.dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x,x_1, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + #C = C * 2 + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x_1 = self.norm1(x_1) + + x = x.view(B, H, W, T, C) + x_1 = x_1.view(B, H, W, T, C) + + + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + x_1 = nnf.pad(x_1, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + shifted_x_1 = torch.roll(x_1, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + shifted_x_1 = x_1 + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size, C + + x_1_windows = window_partition(shifted_x_1, self.window_size) # nW*B, window_size, window_size, C + x_1_windows = x_1_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], + C) # nW*B, window_size*window_size, C + + + # W-MSA/SW-MSA + attn_windows,attn_windows_x_1 = self.attn(x_windows,x_1_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' C + attn_windows_x_1 = attn_windows_x_1.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x_1 = window_reverse(attn_windows_x_1, self.window_size, Hp, Wp, Tp) # B H' W' C + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + x_1 = torch.roll(shifted_x_1, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + + else: + x = shifted_x + x_1 = shifted_x_1 + + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + x_1 = x_1[:, :H, :W, :T, :].contiguous() + + + x = x.view(B, H * W * T, C) + x_1 = x_1.view(B, H * W * T, C) + + + # FFN + x = shortcut + self.drop_path(x) + x_1 = shortcut + self.drop_path(x_1) + + x = x + self.drop_path(self.mlp(self.norm2(x))) + x_1 = x_1 + self.drop_path(self.mlp(self.norm2(x_1))) + + return x,x_1 + + + +class SwinTransformerBlock_crossModality(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pos_embed_method='relative', concatenated_input=True): + super().__init__() + if concatenated_input: + self.dim = dim *2 + else: + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(self.dim) + self.attn = WindowAttention_crossModality( + self.dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pos_embed_method=pos_embed_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(self.dim) + mlp_hidden_dim = int(self.dim * mlp_ratio) + self.mlp = Mlp(in_features=self.dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x,x_1, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + #C = C * 2 + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x_1 = self.norm1(x_1) + + x = x.view(B, H, W, T, C) + x_1 = x_1.view(B, H, W, T, C) + + + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + x_1 = nnf.pad(x_1, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + shifted_x_1 = torch.roll(x_1, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + shifted_x_1 = x_1 + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size, C + + x_1_windows = window_partition(shifted_x_1, self.window_size) # nW*B, window_size, window_size, C + x_1_windows = x_1_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], + C) # nW*B, window_size*window_size, C + + + # W-MSA/SW-MSA + attn_windows,attn_windows_x_1 = self.attn(x_windows,x_1_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' C + attn_windows_x_1 = attn_windows_x_1.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x_1 = window_reverse(attn_windows_x_1, self.window_size, Hp, Wp, Tp) # B H' W' C + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + x_1 = torch.roll(shifted_x_1, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + + else: + x = shifted_x + x_1 = shifted_x_1 + + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + x_1 = x_1[:, :H, :W, :T, :].contiguous() + + + x = x.view(B, H * W * T, C) + x_1 = x_1.view(B, H * W * T, C) + + + # FFN + x = shortcut + self.drop_path(x) + x_1 = shortcut + self.drop_path(x_1) + + x = x + self.drop_path(self.mlp(self.norm2(x))) + x_1 = x_1 + self.drop_path(self.mlp(self.norm2(x_1))) + + return x,x_1 + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2, concatenated_input=False): + super().__init__() + if concatenated_input: + self.dim = dim * 2 + else: + self.dim = dim + self.reduction = nn.Linear(8 * self.dim, (8//reduce_factor) * self.dim, bias=False) + self.norm = norm_layer(8 * self.dim) + + + def forward(self, x, H, W, T): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W * T, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0 and T % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, T, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) or (T % 2 == 1) + if pad_input: + x = nnf.pad(x, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, 0::2, :] # B H/2 W/2 C + x3 = x[:, 0::2, 0::2, 1::2, :] # B H/2 W/2 C + x4 = x[:, 1::2, 1::2, 0::2, :] # B H/2 W/2 C + x5 = x[:, 0::2, 1::2, 1::2, :] # B H/2 W/2 C + x6 = x[:, 1::2, 0::2, 1::2, :] # B H/2 W/2 C + x7 = x[:, 1::2, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) # B H/2 W/2 T/2 8*C + x = x.view(B, -1, 8 * C) # B H/2*W/2*T/2 8*C + + x = self.norm(x) + x = self.reduction(x) + + return x + +class PatchConvPool(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2, concatenated_input=False): + super().__init__() + if concatenated_input: + self.dim = dim * 2 + else: + self.dim = dim + #self.reduction = nn.Linear(8 * self.dim, (8//reduce_factor) * self.dim, bias=False) + #self.norm = norm_layer(8 * self.dim) + + self.conv_du = nn.Sequential( + nn.Conv3d(self.dim, 2 * self.dim, 1, stride=1, padding=0), + nn.ReLU(inplace=True), + nn.BatchNorm3d(2 * self.dim), + nn.Upsample(scale_factor=0.5, mode='trilinear', align_corners=False) + ) + + def forward(self, x, H, W, T): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W * T, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0 and T % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, C, H, W, T) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) or (T % 2 == 1) + if pad_input: + x = nnf.pad(x, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2)) + x = self.conv_du(x) + # x0 = x[:, 0::2, 0::2, 0::2, :] # B H/2 W/2 C + # x1 = x[:, 1::2, 0::2, 0::2, :] # B H/2 W/2 C + # x2 = x[:, 0::2, 1::2, 0::2, :] # B H/2 W/2 C + # x3 = x[:, 0::2, 0::2, 1::2, :] # B H/2 W/2 C + # x4 = x[:, 1::2, 1::2, 0::2, :] # B H/2 W/2 C + # x5 = x[:, 0::2, 1::2, 1::2, :] # B H/2 W/2 C + # x6 = x[:, 1::2, 0::2, 1::2, :] # B H/2 W/2 C + # x7 = x[:, 1::2, 1::2, 1::2, :] # B H/2 W/2 C + # x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) # B H/2 W/2 T/2 8*C + x = x.view(B, -1, 2 * C) # B H/2*W/2*T/2 8*C + + #x = self.norm(x) + #x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf,concatenated_input=concatenated_input) + else: + self.downsample = None + + def forward(self, x, H, W, T): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x, H, W, T, x_down, Wh, Ww, Wt + else: + return x, H, W, T, x, H, W, T + + + +class BasicLayer_dualModality(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_dualModality( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf,concatenated_input=concatenated_input) + else: + self.downsample = None + + def forward(self, x,x_1, H, W, T): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x,x_1 = blk(x, x_1, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + x_1_down = self.downsample(x_1, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x,x_1, H, W, T, x_down,x_1_down, Wh, Ww, Wt + else: + return x, x_1, H, W, T, x,x_1, H, W, T + + + +class BasicLayer_crossModality(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_crossModality( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf,concatenated_input=concatenated_input) + else: + self.downsample = None + + def forward(self, x, x_1, H, W, T): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x,x_1 = blk(x, x_1, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + x_1_down = self.downsample(x_1, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x,x_1, H, W, T, x_down,x_1_down, Wh, Ww, Wt + else: + return x, x_1, H, W, T, x,x_1, H, W, T + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_3tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W, T = x.size() + if T % self.patch_size[2] != 0: + x = nnf.pad(x, (0, self.patch_size[2] - T % self.patch_size[2])) + if W % self.patch_size[1] != 0: + x = nnf.pad(x, (0, 0, 0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = nnf.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww Wt + if self.norm is not None: + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww, Wt) + + return x + +class SinusoidalPositionEmbedding(nn.Module): + ''' + Rotary Position Embedding + ''' + def __init__(self,): + super(SinusoidalPositionEmbedding, self).__init__() + + def forward(self, x): + batch_sz, n_patches, hidden = x.shape + position_ids = torch.arange(0, n_patches).float().cuda() + indices = torch.arange(0, hidden//2).float().cuda() + indices = torch.pow(10000.0, -2 * indices / hidden) + embeddings = torch.einsum('b,d->bd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = torch.reshape(embeddings, (1, n_patches, hidden)) + return embeddings +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=96, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3, 4), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + # self.patch_embedding = PatchEmbeddingBlock( + # in_channels=in_chans, + # img_size=pretrain_img_size, + # patch_size=patch_size, + # hidden_size=embed_dim, + # num_heads=4, + # pos_embed='perceptron', + # dropout_rate=drop_path_rate, + # spatial_dims=3, + # ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(1,self.num_layers+1)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + #x = self.patch_embedding(x).transpose(1, 2) + x = self.patch_embed(x) + #x = self.norm(x) + #x = x.transpose(1, 2).view(-1, self.embed_dim, 48, 48, 48) + outs.append(x) + + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = nnf.interpolate(self.absolute_pos_embed, size=(Wh, Ww, Wt), mode='trilinear') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x = x.flatten(2).transpose(1, 2) + x += self.pos_embd(x) + else: + x = x.flatten(2).transpose(1, 2) + + x = self.pos_drop(x) + + + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Wh, Ww, Wt) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x = norm_layer(x) + + out = x.view(-1, Wh, Ww, Wt, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + #print(out.shape) + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +class SwinTransformer_dense(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=96, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3, 4), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + # self.patch_embedding = PatchEmbeddingBlock( + # in_channels=in_chans, + # img_size=pretrain_img_size, + # patch_size=patch_size, + # hidden_size=embed_dim, + # num_heads=4, + # pos_embed='perceptron', + # dropout_rate=drop_path_rate, + # spatial_dims=3, + # ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + self.layers.append(layer) + patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + num_features = [int(embed_dim * 2 ** i) for i in range(1,self.num_layers+1)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + #x = self.patch_embedding(x).transpose(1, 2) + x = self.patch_embed(x) + #x = self.norm(x) + #x = x.transpose(1, 2).view(-1, self.embed_dim, 48, 48, 48) + outs.append(x) + + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = nnf.interpolate(self.absolute_pos_embed, size=(Wh, Ww, Wt), mode='trilinear') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x = x.flatten(2).transpose(1, 2) + x += self.pos_embd(x) + else: + x = x.flatten(2).transpose(1, 2) + + x = self.pos_drop(x) + + for i in range(self.num_layers): + layer = self.layers[i] + #x_pre = x + x_pre_down = self.patch_merging_layers[i](x, Wh, Ww, Wt) + x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Wh, Ww, Wt) + x = x_pre_down + x + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x = norm_layer(x) + out = x.view(-1, Wh, Ww, Wt, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + #print(out.shape) + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_dense, self).train(mode) + self._freeze_stages() + +class SwinTransformer_wFeatureTalk(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + out_x0_x1_l0 = x_0_out + x_1_out + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_1[:,:,0:self.embed_dim*2] + x_0_down + x_1 = x_0_1[:,:,self.embed_dim*2:] + x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_merging_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_1[:, :, 0:self.embed_dim * 4] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 4:] + x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_merging_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_merging_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_1[:, :, 0:self.embed_dim * 8] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 8:] + x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk, self).train(mode) + self._freeze_stages() + +class SwinTransformer_wFeatureTalk_concat(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]*2) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_1[:,:,0:self.embed_dim*2] + x_0_down + x_1 = x_0_1[:,:,self.embed_dim*2:] + x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_merging_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_1[:, :, 0:self.embed_dim * 4] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 4:] + x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_merging_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_merging_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_1[:, :, 0:self.embed_dim * 8] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 8:] + x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*16).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat, self).train(mode) + self._freeze_stages() + +class SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]*2) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_merging_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_down + x_1 = x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_merging_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_merging_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*16).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating, self).train(mode) + self._freeze_stages() + + + + + +class SwinTransformer_wFeatureTalk_concat_PETUpdatingOnly_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int((embed_dim*2) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + # patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + outs.append((x_0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_out_x0_l0 = x_out_x0_1_l0[:, :, 0:self.embed_dim *2] + x_out_x1_l0 = x_out_x0_1_l0[:, :, self.embed_dim*2:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 # do not update x_0 + x_1_out = x_out_x1_l0 # do not update x_1 + out_x0_x1_l0 = x_0_out + x_0 + x_1 + x_1_out + #out_x0_x1_l0 = torch.concat((x_0_out, x_0), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, + Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l1 # updated x_0 + x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_0 + x_1 + x_1_out#should I use the sum or concat for decoder? + #out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1 = self.patch_merging_layers[2](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l2 # updated x_0 + x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_0 + x_1 + x_1_out + #out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1 = self.patch_merging_layers[3](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + x_0_out = x_out_x0_l3 # updated x_0 + x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_0 + x_1 + x_1_out + #out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat_PETUpdatingOnly_5stageOuts, self).train(mode) + self._freeze_stages() + + + + + +class SwinTransformer_wDualModalityFeatureTalk_OutConcat_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_dualModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*4) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + + out = torch.cat((x_0,x_1),dim=2) + #out = x_0+x_1 + + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0,x_1,int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1,x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0,x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0,x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0,x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 32).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wDualModalityFeatureTalk_OutConcat_5stageOuts, self).train(mode) + self._freeze_stages() + + + +class SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_dualModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + + #out = torch.cat((x_0,x_1),dim=2) + out = x_0 + x_1 + + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0,x_1,int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + out_x0_x1_l0 = x_0_out + x_1_out + #out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1,x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0,x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + #out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0,x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + #out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0,x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + #out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts, self).train(mode) + self._freeze_stages() + +from monai.utils import optional_import +rearrange, _ = optional_import("einops", name="rearrange") +import torch.nn.functional as F + +class SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_crossModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def proj_out(self, x, normalize=False): + if normalize: + x_shape = x.size() + if len(x_shape) == 5: + n, ch, d, h, w = x_shape + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + elif len(x_shape) == 4: + n, ch, h, w = x_shape + x = rearrange(x, "n c h w -> n h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n h w c -> n c h w") + return x + + def forward(self, x,normalize=True): + """Forward function.""" + outs = [] + + #print ('info,',x.shape) + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + #x_0 = self.proj_out(x_0, normalize) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #x_1 = self.proj_out(x_1, normalize) + + #print(x_0.size()) + + #out = torch.cat((x_0,x_1),dim=2) + out = x_0 + x_1 + out = self.proj_out(out, normalize) + + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0,x_1,int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + out_x0_x1_l0 = x_0_out + x_1_out + x_out_l0 = self.proj_out(out_x0_x1_l0, normalize) + + #out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + #norm_layer = getattr(self, f'norm{0}') + #x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1,x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0,x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + x_out_l1 = self.proj_out(out_x0_x1_l1, normalize) + + #out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + #norm_layer = getattr(self, f'norm{1}') + #x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0,x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + x_out_l2 = self.proj_out(out_x0_x1_l2, normalize) + + #out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + #norm_layer = getattr(self, f'norm{2}') + #x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0,x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = self.proj_out(out_x0_x1_l3, normalize) + #out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts, self).train(mode) + self._freeze_stages() + + + +class SwinTransformer_wCrossModalityFeatureTalk_wInputFusion_OutSum_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + self.res_fusionBlock = depthwise_separable_conv3d( + nin=2, + kernels_per_layer=48, + nout=48, + ) + # split image into non-overlapping patches + + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_crossModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + # patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim * 2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + # print(x_0.size()) + + # out = torch.cat((x_0,x_1),dim=2) + #out = x_0 + x_1 + out = self.res_fusionBlock(x) + outs.append(out) + #outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer( + x_0, x_1, int(Wh_x0 / 2), int(Ww_x0 / 2), + int(Wt_x0 / 2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + # x_0_out = x_out_x0_l0 # do not update x_0 + # x_1_out = x_out_x1_l0 # do not update x_1 + out_x0_x1_l0 = x_0_out + x_1_out + # out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0 / 2), int(Ww_x0 / 2), int(Wt_x0 / 2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1 / 2), int(Ww_x1 / 2), int(Wt_x1 / 2)) + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1, x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer( + x_0, x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + # x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + # x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + # x_0_out = x_out_x0_l1 # updated x_0 + # x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out # should I use the sum or concat for decoder? + # out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer( + x_0, x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + # x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + # x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + # x_0_out = x_out_x0_l2 # updated x_0 + # x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + # out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer( + x_0, x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + # x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + # x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + # x_0_out = x_out_x0_l3 # updated x_0 + # x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + # out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wCrossModalityFeatureTalk_wInputFusion_OutSum_5stageOuts, self).train(mode) + self._freeze_stages() + + +class SwinTransformer_wRandomSpatialFeatureTalk_wCrossModalUpdating_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*4) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def complement_idx(idx, dim): + """ + Compute the complement: set(range(dim)) - set(idx). + idx is a multi-dimensional tensor, find the complement for its trailing dimension, + all other dimension is considered batched. + Args: + idx: input index, shape: [N, *, K] + dim: the max index for complement + """ + a = torch.arange(dim, device=idx.device) + ndim = idx.ndim + dims = idx.shape + n_idx = dims[-1] + dims = dims[:-1] + (-1,) + for i in range(1, ndim): + a = a.unsqueeze(0) + a = a.expand(*dims) + masked = torch.scatter(a, -1, idx, 0) + compl, _ = torch.sort(masked, dim=-1, descending=False) + compl = compl.permute(-1, *tuple(range(ndim - 1))) + compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,))) + return compl + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + + out = torch.cat((x_0,x_1),dim=2) + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + + x_0_top50,x_0_top50_idx = torch.topk(x_0,int(x_0.size(dim=1)/2),dim=1) + x_1_top50 = torch.gather(x_1, 1, x_0_top50_idx) + + x_0_1_top50 = torch.cat((x_0_top50, x_1_top50), + dim=1) + + + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, int(Wh_x0), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_1_l0[:, :int(Wh_x0*Ww_x0/2*Wt_x0/2/2), :] + x_1_out = x_out_x0_1_l0[:, int(Wh_x0*Ww_x0/2*Wt_x0/2/2):, :] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_1_l1[:, :int(Wh_x0_1_l0 * Ww_x0_1_l0 * Wt_x0_1_l0 / 2), :] + x_1_out = x_out_x0_1_l1[:, int(Wh_x0_1_l0 * Ww_x0_1_l0 * Wt_x0_1_l0 / 2):, :] + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0/2), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0/2), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_1_l2[:, :int(Wh_x0_1_l1 * Ww_x0_1_l1 * Wt_x0_1_l1 / 2), :] + x_1_out = x_out_x0_1_l2[:, int(Wh_x0_1_l1 * Ww_x0_1_l1 * Wt_x0_1_l1 / 2):, :] + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1/2), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1/2), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_1_l3[:, :int(Wh_x0_1_l2 * Ww_x0_1_l2 * Wt_x0_1_l2 / 2), :] + x_1_out = x_out_x0_1_l3[:, int(Wh_x0_1_l2 * Ww_x0_1_l2 * Wt_x0_1_l2 / 2):, :] + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 32).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wRandomSpatialFeatureTalk_wCrossModalUpdating_5stageOuts, self).train(mode) + self._freeze_stages() + + +class SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating_ConvPoolDownsampling(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_downsampling_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_downsampling_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_downsampling_layers.append(patch_downsampling_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]*2) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer; use conv and pool and view, to downsample input of size of (1,64000,128) to (1,8000,256) + x_0_down = self.patch_downsampling_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_downsampling_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_downsampling_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_downsampling_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_down + x_1 = x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_downsampling_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_downsampling_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*16).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating_ConvPoolDownsampling, self).train(mode) + self._freeze_stages() + + + + +# feature 96 +class TransMorph_Unetr_CT_Lung_Tumor_Batch_Norm_Correction_Official_No_Unused_Parameters_Cross_Attention(nn.Module): + def __init__( + self, + config, + out_channels: int = 2, + feature_size: int = 48, + hidden_size: int = 768, + mlp_dim: int = 3072, + img_size: int = 128, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "batch", + conv_block: bool = False, + res_block: bool = True, + spatial_dims: int = 3, + in_channels: int=1, + #out_channels: int, + ) -> None: + ''' + TransMorph Model + ''' + + #super(TransMorph_Unetr, self).__init__() + super().__init__() + self.hidden_size = hidden_size + self.feat_size=(img_size//32,img_size//32,img_size//32) + + embed_dim = 96#config.embed_dim + + #SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts + #SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + pretrain_img_size=config.img_size[0], + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + ) + #below is the decoder from UnetR + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=config.in_chans, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + + #x, out_feats = self.transformer(x_in) + + out_feats = self.transformer(x_in) + + #for item in out_feats: + # print ('info: size is ',item.shape) + + #info: size is torch.Size([6, 48, 64, 64, 64]) + #info: size is torch.Size([6, 96, 32, 32, 32]) + #info: size is torch.Size([6, 192, 16, 16, 16]) + #info: size is torch.Size([6, 384, 8, 8, 8]) + #info: size is torch.Size([6, 768, 4, 4, 4]) + + enc44 = out_feats[3] # torch.Size([4, 384, 8, 8, 8]) + enc33 = out_feats[2] # torch.Size([4, 192, 16, 16, 16]) + enc22 = out_feats[1] # torch.Size([4, 96, 32, 32, 32]) + enc11 = out_feats[0] # torch.Size([4, 48, 64, 64, 64]) + #x=self.proj_feat(x, self.hidden_size, self.feat_size) # torch.Size([4, 768, 4, 4, 4]) + x=out_feats[4] + + #print ('encoder x after projection size is ',x.size()) + + #print ('input enc0 size ',x_in.size()) + enc0 = self.encoder1(x_in) + #print ('out enc0 size ',enc0.size()) + enc1 = self.encoder2(enc11) #input size torch.Size([4, 96, 64, 64, 64]) + #print ('enc1 size ',enc1.size()) + enc2 = self.encoder3(enc22) #input size torch.Size([4, 192, 32, 32, 32]) + #print ('enc2 size ',enc2.size()) + enc3 = self.encoder4(enc33) #torch.Size([4, 384, 16, 16, 16]) + #print ('enc3 size ',enc3.size()) + + dec4 = self.encoder10(x) + + dec3 = self.decoder5(dec4, enc44) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + + + + return logits + +class Conv3dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + relu = nn.LeakyReLU(inplace=True) + if not use_batchnorm: + nm = nn.InstanceNorm3d(out_channels) + else: + nm = nn.BatchNorm3d(out_channels) + + super(Conv3dReLU, self).__init__(conv, nm, relu) + + +# Residual block +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, downsample=None): + super(ResidualBlock, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv3d(in_channels, out_channels, stride), + nn.BatchNorm3d(out_channels), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, stride), + nn.BatchNorm3d(out_channels) + ) + self.conv_skip = nn.Sequential( + nn.Conv3d(in_channels, out_channels, stride), + nn.BatchNorm3d(out_channels), + ) + + def forward(self, x): + # residual = self.conv_skip(x) + # out = self.conv1(x) + # out = self.bn1(out) + # out = self.relu(out) + # out = self.conv2(out) + # out = self.bn2(out) + # + # out += residual + # out = self.relu(out) + return self.conv_block(x) + self.conv_skip(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + # self.conv1 = Conv3dReLU( + # out_channels+skip_channels, + # out_channels, + # kernel_size=3, + # padding=1, + # use_batchnorm=use_batchnorm, + # ) + # self.conv2 = Conv3dReLU( + # out_channels, + # out_channels, + # kernel_size=3, + # padding=1, + # use_batchnorm=use_batchnorm, + # ) + self.up = nn.ConvTranspose3d(in_channels,out_channels,kernel_size=2,stride=2) + + def forward(self, x, skip=None): + x = self.up(x) + #if skip is not None: + # x = torch.cat([x, skip], dim=1) + #x = self.conv1(x) + #x = self.conv2(x) + return x + +class RegistrationHead(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super().__init__(conv3d) + + +class SegmentationHead(nn.Sequential): + def __init__(self, in_channels, num_classes, image_size=(128,128,48), kernel_size=3, upsampling=1): + #conv3d = nn.Conv3d(in_channels, num_classes, kernel_size=1) + conv3d = nn.Conv3d(in_channels, num_classes, 1,1,0,1,1,False) + softmax = nn.Softmax(dim=1) + #Reshape = torch.reshape([np.prod(image_size),num_classes]) + #softmax = torch.nn.functional.softmax() + #conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + #conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super(SegmentationHead, self).__init__(conv3d,softmax) + +class SegmentationHead_new(nn.Sequential): + def __init__(self, in_channels, num_classes, kernel_size=1, upsampling=1): + #conv3d = nn.Conv3d(in_channels, num_classes, kernel_size=1) + conv3d = nn.Conv3d(in_channels, num_classes, 1,1,0,1,1, False) + sigmoid = nn.Sigmoid() + #conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + #conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super(SegmentationHead_new, self).__init__(conv3d,sigmoid) + +class SpatialTransformer(nn.Module): + """ + N-D Spatial Transformer + + Obtained from https://github.com/voxelmorph/voxelmorph + """ + + def __init__(self, size, mode='bilinear'): + super().__init__() + + self.mode = mode + + # create sampling grid + vectors = [torch.arange(0, s) for s in size] + grids = torch.meshgrid(vectors) + grid = torch.stack(grids) + grid = torch.unsqueeze(grid, 0) + grid = grid.type(torch.FloatTensor) + + # registering the grid as a buffer cleanly moves it to the GPU, but it also + # adds it to the state dict. this is annoying since everything in the state dict + # is included when saving weights to disk, so the model files are way bigger + # than they need to be. so far, there does not appear to be an elegant solution. + # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict + self.register_buffer('grid', grid) + + def forward(self, src, flow): + # new locations + new_locs = self.grid + flow + shape = flow.shape[2:] + + # need to normalize grid values to [-1, 1] for resampler + for i in range(len(shape)): + new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) + + # move channels dim to last position + # also not sure why, but the channels need to be reversed + if len(shape) == 2: + new_locs = new_locs.permute(0, 2, 3, 1) + new_locs = new_locs[..., [1, 0]] + elif len(shape) == 3: + new_locs = new_locs.permute(0, 2, 3, 4, 1) + new_locs = new_locs[..., [2, 1, 0]] + + return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode) + +class SwinVNetSkip(nn.Module): + def __init__(self, config): + super(SwinVNetSkip, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer(patch_size=config.patch_size, + in_chans=config.in_chans, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + self.up0 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim, embed_dim//2, skip_channels=embed_dim//2 if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim//2, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim//2, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead_new( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.transformer(x) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + +from monai.networks.blocks import UnetrBasicBlock,UnetResBlock,UnetrUpBlock,UnetrPrUpBlock +from monai.networks.blocks.dynunet_block import UnetOutBlock, get_conv_layer, UnetBasicBlock + +from typing import Sequence, Tuple, Union + +class SWINUnetrUpBlock(nn.Module): + """ + An upsampling module that can be used for UNETR: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>" + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + res_block: bool argument to determine if residual block is used. + + """ + + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + if res_block: + self.conv_block = UnetResBlock( + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + else: + self.conv_block = UnetBasicBlock( # type: ignore + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = torch.cat((inp, skip), dim=1) + out = self.conv_block(out) + out = self.transp_conv(out) + + return out + + +class SWINUnetrBlock(nn.Module): + """ + An upsampling module that can be used for UNETR: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>" + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + res_block: bool argument to determine if residual block is used. + + """ + + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + if res_block: + self.conv_block = UnetResBlock( + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + else: + self.conv_block = UnetBasicBlock( # type: ignore + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = torch.cat((inp, skip), dim=1) + out = self.conv_block(out) + #out = self.transp_conv(out) + + return out + +class SwinUNETR_self(nn.Module): + def __init__(self, config): + super(SwinUNETR_self, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer(patch_size=config.patch_size, + in_chans=config.in_chans, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + #print(out[-1].size()) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_inputsFusion(nn.Module): + def __init__(self, config): + super(SwinUNETR_inputsFusion, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer(patch_size=config.patch_size, + in_chans=1, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + + # self.res_fusionBlock = UnetResBlock( + # spatial_dims=3, + # in_channels=config.in_chans, + # out_channels=1, + # kernel_size=3, + # stride=1, + # norm_name='instance', + # ) + + self.res_fusionBlock = depthwise_separable_conv3d( + nin=config.in_chans, + kernels_per_layer=48, + nout=1, + ) + self.encoder0 = depthwise_separable_conv3d( + nin=1, + kernels_per_layer=48, + nout=embed_dim, + ) + + + # UnetrBasicBlock( + # spatial_dims=3, + # in_channels=1, + # out_channels=embed_dim, + # kernel_size=3, + # stride=1, + # norm_name='instance', + # res_block=True, + # ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + x = self.res_fusionBlock(x) + + out = self.transformer(x) # (B, n_patch, hidden) + # print(out[-1].size()) + + # stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) # B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) # skip features should be twice the di + + # stage 3 features + dec3 = self.decoder4(dec4, enc4) + enc3 = self.encoder3(out[-3]) # skip features + + # stage 2 features + dec2 = self.decoder3(dec3, enc3) + enc2 = self.encoder2(out[-4]) # skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class depthwise_separable_conv3d(nn.Module): + def __init__(self, nin, kernels_per_layer, nout): + super(depthwise_separable_conv3d, self).__init__() + self.depthwise = nn.Conv3d(nin, nin * kernels_per_layer, kernel_size=3, padding=1, stride=1, groups=nin) + self.pointwise = nn.Conv3d(nin * kernels_per_layer, nout, kernel_size=1) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class SwinUNETR_dense(nn.Module): + def __init__(self, config): + super(SwinUNETR_dense, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_dense(patch_size=config.patch_size, + in_chans=1, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + + self.res_fusionBlock = UnetResBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=1, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + # self.res_fusionBlock = depthwise_separable_conv3d( + # nin=config.in_chans, + # kernels_per_layer=48, + # nout=1, + # ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=1, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + x = self.res_fusionBlock(x) + print(x.size()) + out = self.transformer(x) # (B, n_patch, hidden) + # print(out[-1].size()) + + # stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) # B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) # skip features should be twice the di + + # stage 3 features + dec3 = self.decoder4(dec4, enc4) + enc3 = self.encoder3(out[-3]) # skip features + + # stage 2 features + dec2 = self.decoder3(dec3, enc3) + enc2 = self.encoder2(out[-4]) # skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + +class SwinVNetSkip_transfuser(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim, embed_dim//2, skip_channels=embed_dim//2 if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim//2, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim//2, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + + +class SwinVNetSkip_transfuser_concat(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser_concat, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk_concat(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*16, embed_dim*8, skip_channels=embed_dim*8 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + + +class SwinVNetSkip_transfuser_concat_noCrossModalUpdating(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser_concat_noCrossModalUpdating, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*16, embed_dim*8, skip_channels=embed_dim*8 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + +class SwinVNetSkip_transfuser_concat_noCrossModalUpdating_ConvPoolDownsampling(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser_concat_noCrossModalUpdating_ConvPoolDownsampling, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating_ConvPoolDownsampling(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*16, embed_dim*8, skip_channels=embed_dim*8 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead_new( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + +class SwinUNETR_fusion(nn.Module): + def __init__(self, config): + super(SwinUNETR_fusion, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wFeatureTalk_concat_PETUpdatingOnly_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + + + +class SwinUNETR_dualModalityFusion_OutConcat(nn.Module): + def __init__(self, config): + super(SwinUNETR_dualModalityFusion_OutConcat, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutConcat_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *2, + out_channels=embed_dim *2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*32, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*16, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim *2 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_CrossModalityFusion_inputFusion_OutSum(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_inputFusion_OutSum, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_wInputFusion_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + self.res_fusionBlock = depthwise_separable_conv3d( + nin=config.in_chans, + kernels_per_layer=48, + nout=1, + ) + + def forward(self, x): + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +########################################################################## +def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): + return nn.Conv3d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias, stride = stride) + +## Channel Attention Block (CAB) + +class CAB(nn.Module): + def __init__(self, n_feat, kernel_size, reduction=4, bias=False, act = nn.PReLU()): + super(CAB, self).__init__() + modules_body = [] + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + modules_body.append(act) + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + + self.CA = CALayer(n_feat, reduction, bias=bias) #n_feat = channel, noiseLevel_dim + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) #x.shape=[4,80,32,32,32] and res.shape=[4,80,32,32,32] + res = self.CA(res) + res += x + return res +## Channel Attention Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16, bias=False): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool3d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv3d(channel, channel // reduction, 1, padding=0, bias=bias), + nn.ReLU(inplace=True), + nn.Conv3d(channel // reduction, channel, 1, padding=0, bias=bias), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class SwinUNETR_CrossModalityFusion_OutSum(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_OutSum, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_CrossModalityFusion_OutSum_6stageOuts(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_OutSum_6stageOuts, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + # self.encoder0 = depthwise_separable_conv3d( + # nin=2, + # kernels_per_layer=96, + # nout=embed_dim, + # ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + #print(1) + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_CrossModalityFusion_OutSum_wChAttn(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_OutSum_wChAttn, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int( + config.in_chans / 2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim * 1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.CAB1 = CAB( + n_feat=embed_dim * 1, + kernel_size=3, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.CAB2 = CAB( + n_feat=embed_dim * 2, + kernel_size=3, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.CAB3 = CAB( + n_feat=embed_dim * 4, + kernel_size=3, + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.CAB4 = CAB( + n_feat=embed_dim * 8, + kernel_size=3, + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.CAB5 = CAB( + n_feat=embed_dim * 16, + kernel_size=3, + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + out = self.transformer(x) # (B, n_patch, hidden) + + # stage 4 features + cab5 = self.CAB5(out[-1]) + enc5 = self.res_botneck(cab5) # B, 5,5,5,2048 + + dec4 = self.decoder5(enc5) # B, 10,10,10,1024 + cab4 = self.CAB4(out[-2]) + enc4 = self.encoder4(cab4) # skip features should be twice the di + + # stage 3 features + dec3 = self.decoder4(dec4, enc4) + cab3 = self.CAB3(out[-3]) + enc3 = self.encoder3(cab3) # skip features + + # stage 2 features + dec2 = self.decoder3(dec3, enc3) + cab2 = self.CAB2(out[-4]) + enc2 = self.encoder2(cab2) # skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + cab1 = self.CAB1(out[-5]) + enc1 = self.encoder1(cab1) # skip features + + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + +class SwinUNETR_dualModalityFusion_OutSum(nn.Module): + def __init__(self, config): + super(SwinUNETR_dualModalityFusion_OutSum, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_RandomSpatialFusion(nn.Module): + def __init__(self, config): + super(SwinUNETR_RandomSpatialFusion, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wRandomSpatialFeatureTalk_wCrossModalUpdating_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *2, + out_channels=embed_dim *2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*32, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*16, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim *2 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + +CONFIGS = { + 'Swin-Net-v0': configs.get_3DSwinNetV0_config(), + #'Swin-Net-v01': configs.get_3DSwinNetV01_config(), + 'Swin-Net-v02': configs.get_3DSwinNetV02_config(), + 'Swin-Net-v03': configs.get_3DSwinNetV03_config(), + 'Swin-Net-v04': configs.get_3DSwinNetV04_config(), + 'Swin-Net-v05': configs.get_3DSwinNetV05_config(), + 'Swin-Net-v06': configs.get_3DSwinNetV06_config(), + 'Swin-Net-hecktor-v01': configs.get_3DSwinNet_hecktor2021_V01_config(), + 'Swin-Net-hecktor-v02': configs.get_3DSwinNet_hecktor2021_V02_config(), + 'Swin-Net-hecktor-v03': configs.get_3DSwinNet_hecktor2021_V03_config(), + 'Swin-Net-hecktor-v01-ape': configs.get_3DSwinNetNoPosEmbd_config(), + 'Swin-Net-MGHHNData-v01-ape': configs.get_3DSwinNetV01_NoPosEmd_config(), + 'SwinUNETR-hecktor-v01': configs.get_3DSwinUNETR_hecktor2021_V01_config(), + 'SwinUNETR-hecktor-v02': configs.get_3DSwinUNETR_hecktor2021_V02_config(), + 'SwinUNETR_CMFF-hecktor-v01': configs.get_3DSwinUNETR_CMFF_hecktor2021_V01_config(), + 'SwinUNETR_CMFF-hecktor-v02': configs.get_3DSwinUNETR_CMFF_hecktor2021_V02_config(), + 'SwinUNETR_CMFF-hecktor-v03': configs.get_3DSwinUNETR_CMFF_hecktor2021_V03_config(), + 'SwinUNETR_CMFF-hecktor-v04': configs.get_3DSwinUNETR_CMFF_hecktor2021_V04_config(), + 'SwinUNETR_CMFF-hecktor-v05': configs.get_3DSwinUNETR_CMFF_hecktor2021_V05_config(), + 'SwinUNETR_CMFF-hecktor-v06': configs.get_3DSwinUNETR_CMFF_hecktor2021_V06_config() + +} \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/__init__.py b/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/unetr.py b/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/unetr.py new file mode 100644 index 00000000..4631f9c1 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/cross_swin_networks/unetr.py @@ -0,0 +1,222 @@ +# Copyright 2020 - 2021 MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple, Union +import torch +import torch.nn as nn + +from monai.networks.blocks.dynunet_block import UnetOutBlock +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.nets import ViT + + +class UNETR(nn.Module): + """ + UNETR based on: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>" + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + img_size: Tuple[int, int, int], + feature_size: int = 16, + hidden_size: int = 768, + mlp_dim: int = 3072, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "instance", + conv_block: bool = False, + res_block: bool = True, + dropout_rate: float = 0.0, + ) -> None: + """ + Args: + in_channels: dimension of input channels. + out_channels: dimension of output channels. + img_size: dimension of input image. + feature_size: dimension of network feature size. + hidden_size: dimension of hidden layer. + mlp_dim: dimension of feedforward layer. + num_heads: number of attention heads. + pos_embed: position embedding layer type. + norm_name: feature normalization type and arguments. + conv_block: bool argument to determine if convolutional block is used. + res_block: bool argument to determine if residual block is used. + dropout_rate: faction of the input units to drop. + + Examples:: + + # for single channel input 4-channel output with patch size of (96,96,96), feature size of 32 and batch norm + >>> net = UNETR(in_channels=1, out_channels=4, img_size=(96,96,96), feature_size=32, norm_name='batch') + + # for 4-channel input 3-channel output with patch size of (128,128,128), conv position embedding and instance norm + >>> net = UNETR(in_channels=4, out_channels=3, img_size=(128,128,128), pos_embed='conv', norm_name='instance') + + """ + + super().__init__() + + if not (0 <= dropout_rate <= 1): + raise AssertionError("dropout_rate should be between 0 and 1.") + + if hidden_size % num_heads != 0: + raise AssertionError("hidden size should be divisible by num_heads.") + + if pos_embed not in ["conv", "perceptron"]: + raise KeyError(f"Position embedding layer of type {pos_embed} is not supported.") + + self.num_layers = 12 + self.patch_size = (8, 8, 8) + self.feat_size = ( + img_size[0] // self.patch_size[0], + img_size[1] // self.patch_size[1], + img_size[2] // self.patch_size[2], + ) + self.hidden_size = hidden_size + self.classification = False + self.vit = ViT( + in_channels=in_channels, + img_size=img_size, + patch_size=self.patch_size, + hidden_size=hidden_size, + mlp_dim=mlp_dim, + num_layers=self.num_layers, + num_heads=num_heads, + pos_embed=pos_embed, + classification=self.classification, + dropout_rate=dropout_rate, + ) + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=res_block, + ) + self.encoder2 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 2, + num_layer=2, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder3 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 4, + num_layer=1, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.encoder4 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name=norm_name, + conv_block=conv_block, + res_block=res_block, + ) + self.decoder5 = UnetrUpBlock( + spatial_dims=3, + in_channels=hidden_size, + out_channels=feature_size * 8, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder4 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder3 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=3, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=res_block, + ) + self.out = UnetOutBlock(spatial_dims=3, in_channels=feature_size*2, out_channels=out_channels) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def load_from(self, weights): + with torch.no_grad(): + res_weight = weights + # copy weights from patch embedding + for i in weights['state_dict']: + print(i) + self.vit.patch_embedding.position_embeddings.copy_(weights['state_dict']['module.transformer.patch_embedding.position_embeddings_3d']) + self.vit.patch_embedding.cls_token.copy_(weights['state_dict']['module.transformer.patch_embedding.cls_token']) + self.vit.patch_embedding.patch_embeddings[1].weight.copy_(weights['state_dict']['module.transformer.patch_embedding.patch_embeddings.1.weight']) + self.vit.patch_embedding.patch_embeddings[1].bias.copy_(weights['state_dict']['module.transformer.patch_embedding.patch_embeddings.1.bias']) + + # copy weights from encoding blocks (default: num of blocks: 12) + for bname, block in self.vit.blocks.named_children(): + print(block) + block.loadFrom(weights, n_block=bname) + # last norm layer of transformer + self.vit.norm.weight.copy_(weights['state_dict']['module.transformer.norm.weight']) + self.vit.norm.bias.copy_(weights['state_dict']['module.transformer.norm.bias']) + + + def forward(self, x_in): + x, hidden_states_out = self.vit(x_in) + enc1 = self.encoder1(x_in) + x2 = hidden_states_out[3] + enc2 = self.encoder2(self.proj_feat(x2, self.hidden_size, self.feat_size)) + x3 = hidden_states_out[6] + enc3 = self.encoder3(self.proj_feat(x3, self.hidden_size, self.feat_size)) + x4 = hidden_states_out[9] + enc4 = self.encoder4(self.proj_feat(x4, self.hidden_size, self.feat_size)) + dec4 = self.proj_feat(x, self.hidden_size, self.feat_size) + dec3 = self.decoder5(dec4, enc4) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + #out = self.decoder2(dec1, enc1) + logits = self.out(dec1) + return logits diff --git a/models/msk_smit_lung_gtv/src/smit_models/format.py b/models/msk_smit_lung_gtv/src/smit_models/format.py new file mode 100644 index 00000000..09239338 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/format.py @@ -0,0 +1,59 @@ +from enum import Enum +from typing import Union + +import torch + + +class Format(str, Enum): + NCHW = 'NCHW' + NHWC = 'NHWC' + NCL = 'NCL' + NLC = 'NLC' + + +FormatT = Union[str, Format] + + +def get_spatial_dim(fmt: FormatT): + fmt = Format(fmt) + if fmt is Format.NLC: + dim = (1,) + elif fmt is Format.NCL: + dim = (2,) + elif fmt is Format.NHWC: + dim = (1, 2) + else: + dim = (2, 3) + return dim + + +def get_channel_dim(fmt: FormatT): + fmt = Format(fmt) + if fmt is Format.NHWC: + dim = 3 + elif fmt is Format.NLC: + dim = 2 + else: + dim = 1 + return dim + + +def nchw_to(x: torch.Tensor, fmt: Format): + if fmt == Format.NHWC: + #print ('x size ',x.shape) + x = x.permute(0, 2, 3, 4,1) + elif fmt == Format.NLC: + x = x.flatten(2).transpose(1, 2) + elif fmt == Format.NCL: + x = x.flatten(2) + return x + + +def nhwc_to(x: torch.Tensor, fmt: Format): + if fmt == Format.NCHW: + x = x.permute(0, 3, 1, 2) + elif fmt == Format.NLC: + x = x.flatten(1, 2) + elif fmt == Format.NCL: + x = x.flatten(1, 2).transpose(1, 2) + return x \ No newline at end of file diff --git a/models/msk_smit_lung_gtv/src/smit_models/smit.py b/models/msk_smit_lung_gtv/src/smit_models/smit.py new file mode 100644 index 00000000..aadbc7a6 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/smit.py @@ -0,0 +1,1160 @@ + +from typing import Tuple, Union +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_, to_3tuple +from torch.distributions.normal import Normal +import torch.nn.functional as nnf +import numpy as np +import smit_models.configs_smit as configs + +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock#,UnetrUpOnlyBlock +from monai.networks.blocks.dynunet_block import UnetOutBlock + +from functools import partial + +from typing import Sequence, Tuple, Union + +from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer + +import numpy as np +import torch +import torch.nn as nn + + +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from typing import Optional + +class UnetrUpOnlyBlock(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + + + super(UnetrUpOnlyBlock, self).__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + if res_block: + self.conv_block = UnetResBlock( + spatial_dims, + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + else: + self.conv_block = UnetBasicBlock( # type: ignore + spatial_dims, + out_channels,# + out_channels, + out_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + + def forward(self, inp): + + out = self.transp_conv(inp) + + out = self.conv_block(out) + return out + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + + B, H, W, L, C = x.shape + + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], L // window_size[2], window_size[2], C) + + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0], window_size[1], window_size[2], C) + return windows + + +def window_reverse(windows, window_size, H, W, L): + + B = int(windows.shape[0] / (H * W * L / window_size[0] / window_size[1] / window_size[2])) + x = windows.view(B, H // window_size[0], W // window_size[1], L // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(B, H, W, L, -1) + return x + +class WindowAttention(nn.Module): + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + #print ('info: window_size sizie ',window_size) + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + + B_, N, C = x.shape #(num_windows*B, Wh*Ww*Wt, C) + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + attn = (q @ k.transpose(-2, -1)) + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class SwinTransformerBlock(nn.Module): + + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size, shift_sz: {}, win_size: {}".format(self.shift_size, self.window_size) + + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, T, C) + #print ('x size is ',x.size()) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' L' C + + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + + x = x.view(B, H * W * T, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class PatchMerging(nn.Module): + + + def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(8 * dim, (8//reduce_factor) * dim, bias=False) + self.norm = norm_layer(8 * dim) + + + def forward(self, x, H, W, T): + """ + x: B, H*W*T, C + """ + B, L, C = x.shape + assert L == H * W * T, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0 and T % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, T, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) or (T % 2 == 1) + if pad_input: + x = nnf.pad(x, (0, 0, 0, W % 2, 0, H % 2, 0, T % 2)) + + x0 = x[:, 0::2, 0::2, 0::2, :] # B H/2 W/2 T/2 C + x1 = x[:, 1::2, 0::2, 0::2, :] # B H/2 W/2 T/2 C + x2 = x[:, 0::2, 1::2, 0::2, :] # B H/2 W/2 T/2 C + x3 = x[:, 0::2, 0::2, 1::2, :] # B H/2 W/2 T/2 C + x4 = x[:, 1::2, 1::2, 0::2, :] # B H/2 W/2 T/2 C + x5 = x[:, 0::2, 1::2, 1::2, :] # B H/2 W/2 T/2 C + x6 = x[:, 1::2, 0::2, 1::2, :] # B H/2 W/2 T/2 C + x7 = x[:, 1::2, 1::2, 1::2, :] # B H/2 W/2 T/2 C + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) # B H/2 W/2 T/2 8*C + x = x.view(B, -1, 8 * C) # B H/2*W/2*T/2 8*C + + x = self.norm(x) + x = self.reduction(x) + + return x + +class BasicLayer(nn.Module): + + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2,): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer,) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf) + else: + self.downsample = None + + def forward(self, x, H, W, T): + + + + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x, H, W, T, x_down, Wh, Ww, Wt + else: + return x, H, W, T, x, H, W, T + + def forward_with_features(self, x, H, W, T): + + + + + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + fea=[] + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + fea.append(x) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x, H, W, T, x_down, Wh, Ww, Wt + else: + return x, H, W, T, x, H, W, T,fea + +class PatchEmbed(nn.Module): + + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_3tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W, T = x.size() + if W % self.patch_size[1] != 0: + x = nnf.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = nnf.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + if T % self.patch_size[0] != 0: + x = nnf.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - T % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww Wt + if self.norm is not None: + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww, Wt) + + return x + +class SinusoidalPositionEmbedding(nn.Module): + ''' + Rotary Position Embedding + ''' + def __init__(self,): + super(SinusoidalPositionEmbedding, self).__init__() + + def forward(self, x): + batch_sz, n_patches, hidden = x.shape + position_ids = torch.arange(0, n_patches).float().cuda() + indices = torch.arange(0, hidden//2).float().cuda() + indices = torch.pow(10000.0, -2 * indices / hidden) + embeddings = torch.einsum('b,d->bd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = torch.reshape(embeddings, (1, n_patches, hidden)) + return embeddings + +class SinPositionalEncoding3D(nn.Module): + def __init__(self, channels): + """ + :param channels: The last dimension of the tensor you want to apply pos emb to. + """ + super(SinPositionalEncoding3D, self).__init__() + channels = int(np.ceil(channels/6)*2) + if channels % 2: + channels += 1 + self.channels = channels + self.inv_freq = 1. / (10000 ** (torch.arange(0, channels, 2).float() / channels)) + #self.register_buffer('inv_freq', inv_freq) + + def forward(self, tensor): + """ + :param tensor: A 5d tensor of size (batch_size, x, y, z, ch) + :return: Positional Encoding Matrix of size (batch_size, x, y, z, ch) + """ + tensor = tensor.permute(0, 2, 3, 4, 1) + if len(tensor.shape) != 5: + raise RuntimeError("The input tensor has to be 5d!") + batch_size, x, y, z, orig_ch = tensor.shape + pos_x = torch.arange(x, device=tensor.device).type(self.inv_freq.type()) + pos_y = torch.arange(y, device=tensor.device).type(self.inv_freq.type()) + pos_z = torch.arange(z, device=tensor.device).type(self.inv_freq.type()) + sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq) + sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq) + sin_inp_z = torch.einsum("i,j->ij", pos_z, self.inv_freq) + emb_x = torch.cat((sin_inp_x.sin(), sin_inp_x.cos()), dim=-1).unsqueeze(1).unsqueeze(1) + emb_y = torch.cat((sin_inp_y.sin(), sin_inp_y.cos()), dim=-1).unsqueeze(1) + emb_z = torch.cat((sin_inp_z.sin(), sin_inp_z.cos()), dim=-1) + emb = torch.zeros((x,y,z,self.channels*3),device=tensor.device).type(tensor.type()) + emb[:,:,:,:self.channels] = emb_x + emb[:,:,:,self.channels:2*self.channels] = emb_y + emb[:,:,:,2*self.channels:] = emb_z + emb = emb[None,:,:,:,:orig_ch].repeat(batch_size, 1, 1, 1, 1) + return emb.permute(0, 4, 1, 2, 3) + +class SwinTransformer(nn.Module): + + + def __init__(self, pretrain_img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2,): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + #self.pos_embd = SinPositionalEncoding3D(96).cuda()#SinusoidalPositionEmbedding().cuda() + elif self.spe: + self.pos_embd = SinPositionalEncoding3D(embed_dim).cuda() + #self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf,) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + #print ('before patch embd x size is ',x.size()) + x = self.patch_embed(x) + + + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = nnf.interpolate(self.absolute_pos_embed, size=(Wh, Ww, Wt), mode='trilinear') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + x = (x + self.pos_embd(x)).flatten(2).transpose(1, 2) + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + #print ('there are steps numbers ',self.num_layers) + #print ('x size is ',x.size()) + for i in range(self.num_layers): + #print ('stage ',i) + layer = self.layers[i] + #print ('before X size is ', x.size()) + x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Wh, Ww, Wt) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, T, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + + #print ('after out size is ', out.size()) + outs.append(out) + #print ('after X size is ', x.size()) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + + +class UnetResBlock_No_Downsampleing(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.downsample = in_channels != out_channels + stride_np = np.atleast_1d(stride) + if not np.all(stride_np == 1): + self.downsample = True + + def forward(self, inp): + residual = inp + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + out += residual + out = self.lrelu(out) + return out + + +class UnetrBasicBlock_No_DownSampling(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + + + super().__init__() + + if res_block: + self.layer = UnetResBlock_No_Downsampleing( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + else: + self.layer = UnetBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + + def forward(self, inp): + return self.layer(inp) + + +class SwinTransformer_(nn.Module): + + + def __init__(self, pretrain_img_size=128, + patch_size=2, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2,): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + #self.pos_embd = SinPositionalEncoding3D(96).cuda()#SinusoidalPositionEmbedding().cuda() + elif self.spe: + self.pos_embd = SinPositionalEncoding3D(embed_dim).cuda() + #self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf,) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + #print ('before patch embd x size is ',x.size()) + x = self.patch_embed(x) + #print ('after patch embd x size is ',x.size()) + + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = nnf.interpolate(self.absolute_pos_embed, size=(Wh, Ww, Wt), mode='trilinear') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + x = (x + self.pos_embd(x)).flatten(2).transpose(1, 2) + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + outs = [] + #print ('there are steps numbers ',self.num_layers) + #print ('x size is ',x.size()) + for i in range(self.num_layers): + #print ('stage ',i) + layer = self.layers[i] + #print ('before X size is ', x.size()) + x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Wh, Ww, Wt) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x_out = norm_layer(x_out) + + out = x_out.view(-1, H, W, T, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + + #print ('after out size is ', out.size()) + outs.append(out) + #print ('bottle net X size is ', x.size()) + + return x,outs + + def train(self, mode=True): + + super(SwinTransformer_, self).train(mode) + self._freeze_stages() + +class SMIT_3D_Seg(nn.Module): + def __init__( + self, + config, + out_channels: int , + feature_size: int = 48, + hidden_size: int = 768, + mlp_dim: int = 3072, + img_size: int = 128, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "batch", + conv_block: bool = False, + res_block: bool = True, + spatial_dims: int = 3, + in_channels: int=1, + + ) -> None: + + super().__init__() + self.hidden_size = hidden_size + self.feat_size=(img_size//32,img_size//32,img_size//32) + + embed_dim = 96#config.embed_dim + self.transformer = SwinTransformer_(patch_size=config.patch_size, + pretrain_img_size=config.img_size[0], + in_chans=config.in_chans, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + ) + + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + + x, out_feats = self.transformer(x_in) + + + + + + + + + enc44 = out_feats[-1] # torch.Size([4, 384, 8, 8, 8]) + enc33 = out_feats[-2] # torch.Size([4, 192, 16, 16, 16]) + enc22 = out_feats[-3] # torch.Size([4, 96, 32, 32, 32]) + enc11 = out_feats[-4] # torch.Size([4, 48, 64, 64, 64]) + x=self.proj_feat(x, self.hidden_size, self.feat_size) # torch.Size([4, 768, 4, 4, 4]) + + + enc0 = self.encoder1(x_in) + + enc1 = self.encoder2(enc11) #input size torch.Size([4, 96, 64, 64, 64]) + + enc2 = self.encoder3(enc22) #input size torch.Size([4, 192, 32, 32, 32]) + + enc3 = self.encoder4(enc33) #torch.Size([4, 384, 16, 16, 16]) + + + dec4 = self.encoder10(x) + + dec3 = self.decoder5(dec4, enc44) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + + + + return logits + + + + +CONFIGS = { + 'SMIT_config':configs.get_SMIT_128_bias_True(), + 'SMIT_config_cross_attention':configs.get_SMIT_128_bias_True_Cross(), + +} diff --git a/models/msk_smit_lung_gtv/src/smit_models/smit_cross_attention.py b/models/msk_smit_lung_gtv/src/smit_models/smit_cross_attention.py new file mode 100644 index 00000000..1a2aefc3 --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/smit_cross_attention.py @@ -0,0 +1,8466 @@ + +from typing import Tuple, Union +import torch +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, trunc_normal_, to_3tuple +from torch.distributions.normal import Normal +import torch.nn.functional as nnf +import numpy as np + +import sys +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock#,#UnetrBasicBlock_No_DownSampling#,UnetrUpOnlyBlock + +from monai.networks.layers.utils import get_act_layer, get_norm_layer +from typing import Optional +from typing import Sequence, Tuple, Union + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, L, C = x.shape + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], L // window_size[2], window_size[2], C) + + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0], window_size[1], window_size[2], C) + return windows + + +def window_reverse(windows, window_size, H, W, L): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W * L / window_size[0] / window_size[1] / window_size[2])) + x = windows.view(B, H // window_size[0], W // window_size[1], L // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous().view(B, H, W, L, -1) + return x + + +class RelativeSinPosEmbed(nn.Module): + ''' + Rotary Position Embedding + ''' + def __init__(self,): + super(RelativeSinPosEmbed, self).__init__() + + def forward(self, attn): + batch_sz, _, n_patches, emb_dim = attn.shape + position_ids = torch.arange(0, n_patches).float().cuda() + indices = torch.arange(0, emb_dim//2).float().cuda() + indices = torch.pow(10000.0, -2 * indices / emb_dim) + embeddings = torch.einsum('b,d->bd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = torch.reshape(embeddings.view(n_patches, emb_dim), (1, 1, n_patches, emb_dim)) + #embeddings = embeddings.permute(0, 3, 1, 2) + return embeddings + +class WindowAttention(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + + + +class UnetResBlock_No_Downsampleing(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.downsample = in_channels != out_channels + stride_np = np.atleast_1d(stride) + if not np.all(stride_np == 1): + self.downsample = True + + def forward(self, inp): + residual = inp + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + out += residual + out = self.lrelu(out) + return out + + +class UnetrBasicBlock_No_DownSampling(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + + + super().__init__() + + if res_block: + self.layer = UnetResBlock_No_Downsampleing( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + else: + self.layer = UnetBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + + def forward(self, inp): + return self.layer(inp) + + +class WindowAttention_crossModality(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, x_1, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv_mod1 = self.qkv(x_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv_mod1[1], qkv_mod1[2] # make torchscript happy (cannot use tensor as tuple) + q_mod1, k_mod1, v_mod1 = qkv_mod1[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + q_mod1 = q_mod1 * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + attn_mod1 = (q_mod1 @ k_mod1.transpose(-2, -1)) + + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + attn_mod1 = attn_mod1 + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn_mod1 = attn_mod1.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_mod1 = attn_mod1.view(-1, self.num_heads, N, N) + attn_mod1 = self.softmax(attn_mod1) + else: + attn = self.softmax(attn) + attn_mod1 = self.softmax(attn_mod1) + + + attn = self.attn_drop(attn) + attn_mod1 = self.attn_drop(attn_mod1) + + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) #mod 1 dot mod2 * mod2 + x_1 = (attn_mod1 @ v_mod1).transpose(1, 2).reshape(B_, N, C) #mod 2 dot mod 1 * mod1 + + x = self.proj(x) + x_1 = self.proj(x_1) + + x = self.proj_drop(x) + x_1 = self.proj_drop(x_1) + + return x,x_1 + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class WindowAttention_crossModality_4attns(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, x_1, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv_mod1 = self.qkv(x_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv_mod1[1], qkv_mod1[2] # make torchscript happy (cannot use tensor as tuple) + q_mod1, k_mod1, v_mod1 = qkv_mod1[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + q_mod1 = q_mod1 * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + attn_mod1 = (q_mod1 @ k_mod1.transpose(-2, -1)) + + #self_attn + attn_self = (q @ k_mod1.transpose(-2, -1)) + attn_self_mod1 = (q_mod1 @ k.transpose(-2, -1)) + + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + attn_mod1 = attn_mod1 + relative_position_bias.unsqueeze(0) + # self_attn + attn_self = attn_self + relative_position_bias.unsqueeze(0) + attn_self_mod1 = attn_self_mod1 + relative_position_bias.unsqueeze(0) + + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + attn_mod1 = attn_mod1.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_mod1 = attn_mod1.view(-1, self.num_heads, N, N) + attn_mod1 = self.softmax(attn_mod1) + # self_attn + attn_self = attn_self.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_self = attn_self.view(-1, self.num_heads, N, N) + attn_self = self.softmax(attn_self) + attn_self_mod1 = attn_self_mod1.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn_self_mod1 = attn_self_mod1.view(-1, self.num_heads, N, N) + attn_self_mod1 = self.softmax(attn_self_mod1) + + + else: + attn = self.softmax(attn) + attn_mod1 = self.softmax(attn_mod1) + # self_attn + attn_self = self.softmax(attn_self) + attn_self_mod1 = self.softmax(attn_self_mod1) + + + attn = self.attn_drop(attn) + attn_mod1 = self.attn_drop(attn_mod1) + # self_attn + attn_self = self.attn_drop(attn_self) + attn_self_mod1 = self.attn_drop(attn_self_mod1) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x_1 = (attn_mod1 @ v_mod1).transpose(1, 2).reshape(B_, N, C) + # self_attn + x_self = (attn_self @ v_mod1).transpose(1, 2).reshape(B_, N, C) + x_1_self = (attn_self_mod1 @ v).transpose(1, 2).reshape(B_, N, C) + + x = self.proj(x) + x_1 = self.proj(x_1) + # self_attn + x_self = self.proj(x_self) + x_1_self = self.proj(x_1_self) + + x = self.proj_drop(x) + x_1 = self.proj_drop(x_1) + # self_attn + x_self = self.proj_drop(x_self) + x_1_self = self.proj_drop(x_1_self) + + return x+x_self,x_1+x_1_self + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class WindowAttention_dualModality(nn.Module): + """ Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + """ + + def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pos_embed_method='relative'): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + # define a parameter table of relative position bias + self.relative_position_bias_table = nn.Parameter( + torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1) * (2 * window_size[2] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1 * 2*Wt-1, nH + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.pos_embed_method = pos_embed_method + trunc_normal_(self.relative_position_bias_table, std=.02) + self.softmax = nn.Softmax(dim=-1) + self.sinposembed = RelativeSinPosEmbed() + + def forward(self, x, x_1, mask=None): + """ Forward function. + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + qkv_mod1 = self.qkv(x_1).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) + q_mod1, k_mod1, v_mod1 = qkv_mod1[0], qkv_mod1[1], qkv_mod1[2] # make torchscript happy (cannot use tensor as tuple) + + q = q * self.scale + if self.pos_embed_method == 'rotary': + pos_embed = self.sinposembed(q) + cos_pos = pos_embed[..., 1::2].repeat(1, 1, 1, 2).cuda() + sin_pos = pos_embed[..., ::2].repeat(1, 1, 1, 2).cuda() + qw2 = torch.stack([-q[..., 1::2], q[..., ::2]], 4) + qw2 = torch.reshape(qw2, q.shape) + q = q * cos_pos + qw2 * sin_pos + kw2 = torch.stack([-k[..., 1::2], k[..., ::2]], 4) + kw2 = torch.reshape(kw2, k.shape) + k = k * cos_pos + kw2 * sin_pos + + attn = (q @ k.transpose(-2, -1)) + if self.pos_embed_method == 'relative': + relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], + self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww*Wt,Wh*Ww*Wt,nH + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww*Wt, Wh*Ww*Wt + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x_1 = (attn @ v_mod1).transpose(1, 2).reshape(B_, N, C) + + x = self.proj(x) + x_1 = self.proj(x_1) + + x = self.proj_drop(x) + x_1 = self.proj_drop(x_1) + + return x,x_1 + + def extra_repr(self) -> str: + return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}' + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pos_embed_method='relative', concatenated_input=True): + super().__init__() + if concatenated_input: + self.dim = dim *2 + else: + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(self.dim) + self.attn = WindowAttention( + self.dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pos_embed_method=pos_embed_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(self.dim) + mlp_hidden_dim = int(self.dim * mlp_ratio) + self.mlp = Mlp(in_features=self.dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + #C = C * 2 + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, T, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' C + + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + + x = x.view(B, H * W * T, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + +class SwinTransformerBlock_dualModality(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pos_embed_method='relative', concatenated_input=True): + super().__init__() + if concatenated_input: + self.dim = dim *2 + else: + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(self.dim) + self.attn = WindowAttention_dualModality( + self.dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pos_embed_method=pos_embed_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(self.dim) + mlp_hidden_dim = int(self.dim * mlp_ratio) + self.mlp = Mlp(in_features=self.dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x,x_1, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + #C = C * 2 + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x_1 = self.norm1(x_1) + + x = x.view(B, H, W, T, C) + x_1 = x_1.view(B, H, W, T, C) + + + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + x_1 = nnf.pad(x_1, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + shifted_x_1 = torch.roll(x_1, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + shifted_x_1 = x_1 + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size, C + + x_1_windows = window_partition(shifted_x_1, self.window_size) # nW*B, window_size, window_size, C + x_1_windows = x_1_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], + C) # nW*B, window_size*window_size, C + + + # W-MSA/SW-MSA + attn_windows,attn_windows_x_1 = self.attn(x_windows,x_1_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' C + attn_windows_x_1 = attn_windows_x_1.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x_1 = window_reverse(attn_windows_x_1, self.window_size, Hp, Wp, Tp) # B H' W' C + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + x_1 = torch.roll(shifted_x_1, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + + else: + x = shifted_x + x_1 = shifted_x_1 + + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + x_1 = x_1[:, :H, :W, :T, :].contiguous() + + + x = x.view(B, H * W * T, C) + x_1 = x_1.view(B, H * W * T, C) + + + # FFN + x = shortcut + self.drop_path(x) + x_1 = shortcut + self.drop_path(x_1) + + x = x + self.drop_path(self.mlp(self.norm2(x))) + x_1 = x_1 + self.drop_path(self.mlp(self.norm2(x_1))) + + return x,x_1 + + + +class SwinTransformerBlock_crossModality(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm, pos_embed_method='relative', concatenated_input=True): + super().__init__() + if concatenated_input: + self.dim = dim *2 + else: + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size" + + self.norm1 = norm_layer(self.dim) + self.attn = WindowAttention_crossModality( + self.dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pos_embed_method=pos_embed_method) + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = norm_layer(self.dim) + mlp_hidden_dim = int(self.dim * mlp_ratio) + self.mlp = Mlp(in_features=self.dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x,x_1, mask_matrix): + H, W, T = self.H, self.W, self.T + B, L, C = x.shape + #C = C * 2 + assert L == H * W * T, "input feature has wrong size" + + shortcut = x + x = self.norm1(x) + x_1 = self.norm1(x_1) + + x = x.view(B, H, W, T, C) + x_1 = x_1.view(B, H, W, T, C) + + + + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + x_1 = nnf.pad(x_1, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + shifted_x_1 = torch.roll(x_1, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + shifted_x_1 = x_1 + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size, C + + x_1_windows = window_partition(shifted_x_1, self.window_size) # nW*B, window_size, window_size, C + x_1_windows = x_1_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], + C) # nW*B, window_size*window_size, C + + + # W-MSA/SW-MSA + attn_windows,attn_windows_x_1 = self.attn(x_windows,x_1_windows, mask=attn_mask) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' C + attn_windows_x_1 = attn_windows_x_1.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x_1 = window_reverse(attn_windows_x_1, self.window_size, Hp, Wp, Tp) # B H' W' C + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + x_1 = torch.roll(shifted_x_1, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + + else: + x = shifted_x + x_1 = shifted_x_1 + + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + x_1 = x_1[:, :H, :W, :T, :].contiguous() + + + x = x.view(B, H * W * T, C) + x_1 = x_1.view(B, H * W * T, C) + + + # FFN + x = shortcut + self.drop_path(x) + x_1 = shortcut + self.drop_path(x_1) + + x = x + self.drop_path(self.mlp(self.norm2(x))) + x_1 = x_1 + self.drop_path(self.mlp(self.norm2(x_1))) + + return x,x_1 + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2, concatenated_input=False): + super().__init__() + if concatenated_input: + self.dim = dim * 2 + else: + self.dim = dim + self.reduction = nn.Linear(8 * self.dim, (8//reduce_factor) * self.dim, bias=False) + self.norm = norm_layer(8 * self.dim) + + + def forward(self, x, H, W, T): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W * T, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0 and T % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, T, C) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) or (T % 2 == 1) + if pad_input: + x = nnf.pad(x, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2)) + + x0 = x[:, 0::2, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, 0::2, :] # B H/2 W/2 C + x3 = x[:, 0::2, 0::2, 1::2, :] # B H/2 W/2 C + x4 = x[:, 1::2, 1::2, 0::2, :] # B H/2 W/2 C + x5 = x[:, 0::2, 1::2, 1::2, :] # B H/2 W/2 C + x6 = x[:, 1::2, 0::2, 1::2, :] # B H/2 W/2 C + x7 = x[:, 1::2, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) # B H/2 W/2 T/2 8*C + x = x.view(B, -1, 8 * C) # B H/2*W/2*T/2 8*C + + x = self.norm(x) + x = self.reduction(x) + + return x + +class PatchConvPool(nn.Module): + r""" Patch Merging Layer. + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2, concatenated_input=False): + super().__init__() + if concatenated_input: + self.dim = dim * 2 + else: + self.dim = dim + #self.reduction = nn.Linear(8 * self.dim, (8//reduce_factor) * self.dim, bias=False) + #self.norm = norm_layer(8 * self.dim) + + self.conv_du = nn.Sequential( + nn.Conv3d(self.dim, 2 * self.dim, 1, stride=1, padding=0), + nn.ReLU(inplace=True), + nn.BatchNorm3d(2 * self.dim), + nn.Upsample(scale_factor=0.5, mode='trilinear', align_corners=False) + ) + + def forward(self, x, H, W, T): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W * T, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0 and T % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, C, H, W, T) + + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) or (T % 2 == 1) + if pad_input: + x = nnf.pad(x, (0, 0, 0, T % 2, 0, W % 2, 0, H % 2)) + x = self.conv_du(x) + # x0 = x[:, 0::2, 0::2, 0::2, :] # B H/2 W/2 C + # x1 = x[:, 1::2, 0::2, 0::2, :] # B H/2 W/2 C + # x2 = x[:, 0::2, 1::2, 0::2, :] # B H/2 W/2 C + # x3 = x[:, 0::2, 0::2, 1::2, :] # B H/2 W/2 C + # x4 = x[:, 1::2, 1::2, 0::2, :] # B H/2 W/2 C + # x5 = x[:, 0::2, 1::2, 1::2, :] # B H/2 W/2 C + # x6 = x[:, 1::2, 0::2, 1::2, :] # B H/2 W/2 C + # x7 = x[:, 1::2, 1::2, 1::2, :] # B H/2 W/2 C + # x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) # B H/2 W/2 T/2 8*C + x = x.view(B, -1, 2 * C) # B H/2*W/2*T/2 8*C + + #x = self.norm(x) + #x = self.reduction(x) + + return x + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf,concatenated_input=concatenated_input) + else: + self.downsample = None + + def forward(self, x, H, W, T): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x, H, W, T, x_down, Wh, Ww, Wt + else: + return x, H, W, T, x, H, W, T + + + +class BasicLayer_dualModality(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_dualModality( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf,concatenated_input=concatenated_input) + else: + self.downsample = None + + def forward(self, x,x_1, H, W, T): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x,x_1 = blk(x, x_1, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + x_1_down = self.downsample(x_1, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x,x_1, H, W, T, x_down,x_1_down, Wh, Ww, Wt + else: + return x, x_1, H, W, T, x,x_1, H, W, T + + + +class BasicLayer_crossModality(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock_crossModality( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + for i in range(depth)]) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf,concatenated_input=concatenated_input) + else: + self.downsample = None + + def forward(self, x, x_1, H, W, T): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x,x_1 = blk(x, x_1, attn_mask) + if self.downsample is not None: + x_down = self.downsample(x, H, W, T) + x_1_down = self.downsample(x_1, H, W, T) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x,x_1, H, W, T, x_down,x_1_down, Wh, Ww, Wt + else: + return x, x_1, H, W, T, x,x_1, H, W, T + + +class PatchEmbed(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_3tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W, T = x.size() + if T % self.patch_size[2] != 0: + x = nnf.pad(x, (0, self.patch_size[2] - T % self.patch_size[2])) + if W % self.patch_size[1] != 0: + x = nnf.pad(x, (0, 0, 0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = nnf.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww Wt + if self.norm is not None: + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww, Wt) + + return x + +class SinusoidalPositionEmbedding(nn.Module): + ''' + Rotary Position Embedding + ''' + def __init__(self,): + super(SinusoidalPositionEmbedding, self).__init__() + + def forward(self, x): + batch_sz, n_patches, hidden = x.shape + position_ids = torch.arange(0, n_patches).float().cuda() + indices = torch.arange(0, hidden//2).float().cuda() + indices = torch.pow(10000.0, -2 * indices / hidden) + embeddings = torch.einsum('b,d->bd', position_ids, indices) + embeddings = torch.stack([torch.sin(embeddings), torch.cos(embeddings)], dim=-1) + embeddings = torch.reshape(embeddings, (1, n_patches, hidden)) + return embeddings +from monai.networks.blocks.patchembedding import PatchEmbeddingBlock + +class SwinTransformer(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=96, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3, 4), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + # self.patch_embedding = PatchEmbeddingBlock( + # in_channels=in_chans, + # img_size=pretrain_img_size, + # patch_size=patch_size, + # hidden_size=embed_dim, + # num_heads=4, + # pos_embed='perceptron', + # dropout_rate=drop_path_rate, + # spatial_dims=3, + # ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + self.layers.append(layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(1,self.num_layers+1)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + #x = self.patch_embedding(x).transpose(1, 2) + x = self.patch_embed(x) + #x = self.norm(x) + #x = x.transpose(1, 2).view(-1, self.embed_dim, 48, 48, 48) + outs.append(x) + + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = nnf.interpolate(self.absolute_pos_embed, size=(Wh, Ww, Wt), mode='trilinear') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x = x.flatten(2).transpose(1, 2) + x += self.pos_embd(x) + else: + x = x.flatten(2).transpose(1, 2) + + x = self.pos_drop(x) + + + for i in range(self.num_layers): + layer = self.layers[i] + x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Wh, Ww, Wt) + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x = norm_layer(x) + + out = x.view(-1, Wh, Ww, Wt, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + #print(out.shape) + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer, self).train(mode) + self._freeze_stages() + + +class SwinTransformer_dense(nn.Module): + r""" Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=96, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3, 4), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative', + concatenated_input=True): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + # self.patch_embedding = PatchEmbeddingBlock( + # in_channels=in_chans, + # img_size=pretrain_img_size, + # patch_size=patch_size, + # hidden_size=embed_dim, + # num_heads=4, + # pos_embed='perceptron', + # dropout_rate=drop_path_rate, + # spatial_dims=3, + # ) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method, + concatenated_input=concatenated_input) + self.layers.append(layer) + patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + num_features = [int(embed_dim * 2 ** i) for i in range(1,self.num_layers+1)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + #x = self.patch_embedding(x).transpose(1, 2) + x = self.patch_embed(x) + #x = self.norm(x) + #x = x.transpose(1, 2).view(-1, self.embed_dim, 48, 48, 48) + outs.append(x) + + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = nnf.interpolate(self.absolute_pos_embed, size=(Wh, Ww, Wt), mode='trilinear') + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x = x.flatten(2).transpose(1, 2) + x += self.pos_embd(x) + else: + x = x.flatten(2).transpose(1, 2) + + x = self.pos_drop(x) + + for i in range(self.num_layers): + layer = self.layers[i] + #x_pre = x + x_pre_down = self.patch_merging_layers[i](x, Wh, Ww, Wt) + x_out, H, W, T, x, Wh, Ww, Wt = layer(x, Wh, Ww, Wt) + x = x_pre_down + x + if i in self.out_indices: + norm_layer = getattr(self, f'norm{i}') + x = norm_layer(x) + out = x.view(-1, Wh, Ww, Wt, self.num_features[i]).permute(0, 4, 1, 2, 3).contiguous() + #print(out.shape) + outs.append(out) + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_dense, self).train(mode) + self._freeze_stages() + +class SwinTransformer_wFeatureTalk(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + out_x0_x1_l0 = x_0_out + x_1_out + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_1[:,:,0:self.embed_dim*2] + x_0_down + x_1 = x_0_1[:,:,self.embed_dim*2:] + x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_merging_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_1[:, :, 0:self.embed_dim * 4] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 4:] + x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_merging_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_merging_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_1[:, :, 0:self.embed_dim * 8] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 8:] + x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk, self).train(mode) + self._freeze_stages() + +class SwinTransformer_wFeatureTalk_concat(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]*2) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_1[:,:,0:self.embed_dim*2] + x_0_down + x_1 = x_0_1[:,:,self.embed_dim*2:] + x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_merging_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_1[:, :, 0:self.embed_dim * 4] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 4:] + x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_merging_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_merging_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_1[:, :, 0:self.embed_dim * 8] + x_0_down + x_1 = x_0_1[:, :, self.embed_dim * 8:] + x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*16).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat, self).train(mode) + self._freeze_stages() + +class SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]*2) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_merging_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_merging_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_down + x_1 = x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_merging_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_merging_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*16).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating, self).train(mode) + self._freeze_stages() + + + + + +class SwinTransformer_wFeatureTalk_concat_PETUpdatingOnly_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int((embed_dim*2) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + # patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + outs.append((x_0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_out_x0_l0 = x_out_x0_1_l0[:, :, 0:self.embed_dim *2] + x_out_x1_l0 = x_out_x0_1_l0[:, :, self.embed_dim*2:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 # do not update x_0 + x_1_out = x_out_x1_l0 # do not update x_1 + out_x0_x1_l0 = x_0_out + x_0 + x_1 + x_1_out + #out_x0_x1_l0 = torch.concat((x_0_out, x_0), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, + Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l1 # updated x_0 + x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_0 + x_1 + x_1_out#should I use the sum or concat for decoder? + #out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1 = self.patch_merging_layers[2](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l2 # updated x_0 + x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_0 + x_1 + x_1_out + #out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1 = self.patch_merging_layers[3](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + x_0_out = x_out_x0_l3 # updated x_0 + x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_0 + x_1 + x_1_out + #out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, H_x0_1, W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat_PETUpdatingOnly_5stageOuts, self).train(mode) + self._freeze_stages() + + + + + +class SwinTransformer_wDualModalityFeatureTalk_OutConcat_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_dualModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*4) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + + out = torch.cat((x_0,x_1),dim=2) + #out = x_0+x_1 + + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0,x_1,int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1,x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0,x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0,x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0,x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 32).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wDualModalityFeatureTalk_OutConcat_5stageOuts, self).train(mode) + self._freeze_stages() + + + +class SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(nn.Module): + + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_dualModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + #downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + downsample=PatchMerging if (i_layer < self.num_layers) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # Prim Image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # Second Image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + + #out = torch.cat((x_0,x_1),dim=2) + out = x_0 + x_1 + + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + #print (' info: after patch merge layer size ',out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).shape) + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0,x_1,int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + out_x0_x1_l0 = x_0_out + x_1_out + #out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + #print (' info: after Stage 0 layer size ',out.shape) + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1,x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0,x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + #out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2,3).contiguous() + #print (' info: after Stage 1 layer size ',out.shape) + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0,x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + #out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, 3).contiguous() + + #print (' info: after Stage 2 layer size ',out.shape) + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0,x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + #out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2,3).contiguous() + + + #print (' info: after Stage 3 layer size ',out.shape) + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts, self).train(mode) + self._freeze_stages() + +from monai.utils import optional_import +rearrange, _ = optional_import("einops", name="rearrange") +import torch.nn.functional as F + +class SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_crossModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def proj_out(self, x, normalize=False): + if normalize: + x_shape = x.size() + if len(x_shape) == 5: + n, ch, d, h, w = x_shape + x = rearrange(x, "n c d h w -> n d h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n d h w c -> n c d h w") + elif len(x_shape) == 4: + n, ch, h, w = x_shape + x = rearrange(x, "n c h w -> n h w c") + x = F.layer_norm(x, [ch]) + x = rearrange(x, "n h w c -> n c h w") + return x + + def forward(self, x,normalize=True): + """Forward function.""" + outs = [] + + #print ('info,',x.shape) + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + #x_0 = self.proj_out(x_0, normalize) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #x_1 = self.proj_out(x_1, normalize) + + #print(x_0.size()) + + #out = torch.cat((x_0,x_1),dim=2) + out = x_0 + x_1 + out = self.proj_out(out, normalize) + + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0,x_1,int(Wh_x0/2), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + out_x0_x1_l0 = x_0_out + x_1_out + x_out_l0 = self.proj_out(out_x0_x1_l0, normalize) + + #out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + #norm_layer = getattr(self, f'norm{0}') + #x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1,x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0,x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + x_out_l1 = self.proj_out(out_x0_x1_l1, normalize) + + #out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + #norm_layer = getattr(self, f'norm{1}') + #x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0,x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + x_out_l2 = self.proj_out(out_x0_x1_l2, normalize) + + #out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + #norm_layer = getattr(self, f'norm{2}') + #x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + #x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3,H_x0_1, W_x0_1, T_x0_1, x_0_small,x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0,x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + #x_0_out = self.proj_out(x_0_out, normalize) + #x_1_out = self.proj_out(x_1_out, normalize) + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = self.proj_out(out_x0_x1_l3, normalize) + #out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts, self).train(mode) + self._freeze_stages() + + + +class SwinTransformer_wCrossModalityFeatureTalk_wInputFusion_OutSum_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + self.res_fusionBlock = depthwise_separable_conv3d( + nin=2, + kernels_per_layer=48, + nout=48, + ) + # split image into non-overlapping patches + + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer_crossModality(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + # patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim * 2) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) # B C, W, H ,D + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + # print(x_0.size()) + + # out = torch.cat((x_0,x_1),dim=2) + #out = x_0 + x_1 + out = self.res_fusionBlock(x) + outs.append(out) + #outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_l0, x_out_x1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer( + x_0, x_1, int(Wh_x0 / 2), int(Ww_x0 / 2), + int(Wt_x0 / 2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_l0 + x_1_out = x_out_x1_l0 + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + # x_0_out = x_out_x0_l0 # do not update x_0 + # x_1_out = x_out_x1_l0 # do not update x_1 + out_x0_x1_l0 = x_0_out + x_1_out + # out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0 / 2), int(Ww_x0 / 2), int(Wt_x0 / 2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1 / 2), int(Ww_x1 / 2), int(Wt_x1 / 2)) + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l1, x_out_x1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer( + x_0, x_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_l1 + x_1_out = x_out_x1_l1 + # x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + # x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + # x_0_out = x_out_x0_l1 # updated x_0 + # x_1_out = x_out_x1_l1 # updated x_1 + out_x0_x1_l1 = x_0_out + x_1_out # should I use the sum or concat for decoder? + # out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l2, x_out_x1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer( + x_0, x_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_l2 + x_1_out = x_out_x1_l2 + # x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + # x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + # x_0_out = x_out_x0_l2 # updated x_0 + # x_1_out = x_out_x1_l2 # updated x_1 + out_x0_x1_l2 = x_0_out + x_1_out + # out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + # x_0_1 = torch.cat((x_0, x_1), + # dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_l3, x_out_x1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_small, x_1_small, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer( + x_0, x_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + # x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_l3 + x_1_out = x_out_x1_l3 + # x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + # x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + # x_0_out = x_out_x0_l3 # updated x_0 + # x_1_out = x_out_x1_l3 # updated x_1 + out_x0_x1_l3 = x_0_out + x_1_out + # out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wCrossModalityFeatureTalk_wInputFusion_OutSum_5stageOuts, self).train(mode) + self._freeze_stages() + + +class SwinTransformer_wRandomSpatialFeatureTalk_wCrossModalUpdating_5stageOuts(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed_mod0 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + self.patch_embed_mod1 = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], + pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_merging_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int((embed_dim) * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers -1 ) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_merging_layer = PatchMerging(int(embed_dim * 2 ** i_layer), reduce_factor=4, concatenated_input=False) + # self.patch_merging_layers.append(patch_merging_layer) + #patch_merging_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_merging_layers.append(patch_merging_layer) + + num_features = [int((embed_dim*4) * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer] * 1) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def complement_idx(idx, dim): + """ + Compute the complement: set(range(dim)) - set(idx). + idx is a multi-dimensional tensor, find the complement for its trailing dimension, + all other dimension is considered batched. + Args: + idx: input index, shape: [N, *, K] + dim: the max index for complement + """ + a = torch.arange(dim, device=idx.device) + ndim = idx.ndim + dims = idx.shape + n_idx = dims[-1] + dims = dims[:-1] + (-1,) + for i in range(1, ndim): + a = a.unsqueeze(0) + a = a.expand(*dims) + masked = torch.scatter(a, -1, idx, 0) + compl, _ = torch.sort(masked, dim=-1, descending=False) + compl = compl.permute(-1, *tuple(range(ndim - 1))) + compl = compl[n_idx:].permute(*(tuple(range(1, ndim)) + (0,))) + return compl + + def forward(self, x): + """Forward function.""" + outs = [] + x_0 = torch.unsqueeze(x[:, 0, :, :, :], 1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :], 1) + # PET image + x_0 = self.patch_embed_mod0(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), + mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + # CT image + x_1 = self.patch_embed_mod1(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + #print(x_0.size()) + + out = torch.cat((x_0,x_1),dim=2) + outs.append((out.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous())) + + x_0 = self.patch_merging_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1 = self.patch_merging_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + + x_0_top50,x_0_top50_idx = torch.topk(x_0,int(x_0.size(dim=1)/2),dim=1) + x_1_top50 = torch.gather(x_1, 1, x_0_top50_idx) + + x_0_1_top50 = torch.cat((x_0_top50, x_1_top50), + dim=1) + + + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the spatial dimension so that the SWINTR is looking at the correlations spatially + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, int(Wh_x0), int(Ww_x0/2), + int(Wt_x0/2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l0 = x_0_1 + x_0_out = x_out_x0_1_l0[:, :int(Wh_x0*Ww_x0/2*Wt_x0/2/2), :] + x_1_out = x_out_x0_1_l0[:, int(Wh_x0*Ww_x0/2*Wt_x0/2/2):, :] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + #x_0_out = x_out_x0_l0 # do not update x_0 + #x_1_out = x_out_x1_l0 # do not update x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 0 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[1](x_0, int(Wh_x0/2), int(Ww_x0/2), int(Wt_x0/2)) + x_1 = self.patch_merging_layers[1](x_1, int(Wh_x1/2), int(Ww_x1/2), int(Wt_x1/2)) + + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, int(Wh_x0_1_l0), + int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l1 = x_0_1 + x_0_out = x_out_x0_1_l1[:, :int(Wh_x0_1_l0 * Ww_x0_1_l0 * Wt_x0_1_l0 / 2), :] + x_1_out = x_out_x0_1_l1[:, int(Wh_x0_1_l0 * Ww_x0_1_l0 * Wt_x0_1_l0 / 2):, :] + #x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim * 4] + #x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim * 4:] + #x_0_out = x_out_x0_l1 # updated x_0 + #x_1_out = x_out_x1_l1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 8).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 1 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[2](x_0, int(Wh_x0_1_l0/2), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + x_1 = self.patch_merging_layers[2](x_1, int(Wh_x0_1_l0/2), int(Ww_x0_1_l0), int(Wt_x0_1_l0)) + + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, int(Wh_x0_1_l1), + int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l2 = x_0_1 + x_0_out = x_out_x0_1_l2[:, :int(Wh_x0_1_l1 * Ww_x0_1_l1 * Wt_x0_1_l1 / 2), :] + x_1_out = x_out_x0_1_l2[:, int(Wh_x0_1_l1 * Ww_x0_1_l1 * Wt_x0_1_l1 / 2):, :] + #x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 8] + #x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 8:] + #x_0_out = x_out_x0_l2 # updated x_0 + #x_1_out = x_out_x1_l2 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 16).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + # add transformer encoded back to modality-specific branch + x_0 = x_0_out + x_0 + x_1 = x_1_out + x_1 + # construct the input for the next layer + x_0 = self.patch_merging_layers[3](x_0, int(Wh_x0_1_l1/2), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + x_1 = self.patch_merging_layers[3](x_1, int(Wh_x0_1_l1/2), int(Ww_x0_1_l1), int(Wt_x0_1_l1)) + + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=1) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, int(Wh_x0_1_l2), + int(Ww_x0_1_l2), int(Wt_x0_1_l2)) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + #x_out_x0_1_l3 = x_0_1 + x_0_out = x_out_x0_1_l3[:, :int(Wh_x0_1_l2 * Ww_x0_1_l2 * Wt_x0_1_l2 / 2), :] + x_1_out = x_out_x0_1_l3[:, int(Wh_x0_1_l2 * Ww_x0_1_l2 * Wt_x0_1_l2 / 2):, :] + #x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 16] + #x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 16:] + #x_0_out = x_out_x0_l3 # updated x_0 + #x_1_out = x_out_x1_l3 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out, x_1_out), dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, int(H_x0_1/2), W_x0_1, T_x0_1, self.embed_dim * 32).permute(0, 4, 1, 2, + 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wRandomSpatialFeatureTalk_wCrossModalUpdating_5stageOuts, self).train(mode) + self._freeze_stages() + + +class SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating_ConvPoolDownsampling(nn.Module): + r""" Swin Transformer modified to process images from two modalities; feature talks between two images are introduced in encoder + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (tuple): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + """ + + def __init__(self, pretrain_img_size=160, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.2, + norm_layer=nn.LayerNorm, + ape=False, + spe=False, + patch_norm=True, + out_indices=(0, 1, 2, 3), + frozen_stages=-1, + use_checkpoint=False, + pat_merg_rf=2, + pos_embed_method='relative'): + super().__init__() + self.pretrain_img_size = pretrain_img_size + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.ape = ape + self.spe = spe + self.patch_norm = patch_norm + self.out_indices = out_indices + self.frozen_stages = frozen_stages + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None) + + # absolute position embedding + if self.ape: + pretrain_img_size = to_3tuple(self.pretrain_img_size) + patch_size = to_3tuple(patch_size) + patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1], pretrain_img_size[2] // patch_size[2]] + + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1], patches_resolution[2])) + trunc_normal_(self.absolute_pos_embed, std=.02) + elif self.spe: + self.pos_embd = SinusoidalPositionEmbedding().cuda() + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + self.patch_downsampling_layers = nn.ModuleList() + + for i_layer in range(self.num_layers): + layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer), + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pat_merg_rf=pat_merg_rf, + pos_embed_method=pos_embed_method) + self.layers.append(layer) + + patch_downsampling_layer = PatchConvPool(int(embed_dim * 2 ** i_layer), concatenated_input=False) + self.patch_downsampling_layers.append(patch_downsampling_layer) + + num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + self.num_features = num_features + + + + # add a norm layer for each output + for i_layer in out_indices: + layer = norm_layer(num_features[i_layer]*2) + layer_name = f'norm{i_layer}' + self.add_module(layer_name, layer) + + self._freeze_stages() + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + + if self.frozen_stages >= 1 and self.ape: + self.absolute_pos_embed.requires_grad = False + + if self.frozen_stages >= 2: + self.pos_drop.eval() + for i in range(0, self.frozen_stages - 1): + m = self.layers[i] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def init_weights(self, pretrained=None): + """Initialize the weights in backbone. + Args: + pretrained (str, optional): Path to pre-trained weights. + Defaults to None. + """ + + def _init_weights(m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + if isinstance(pretrained, str): + self.apply(_init_weights) + elif pretrained is None: + self.apply(_init_weights) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x_0,x_1): + """Forward function.""" + #PET image + x_0 = self.patch_embed(x_0) + Wh_x0, Ww_x0, Wt_x0 = x_0.size(2), x_0.size(3), x_0.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x0 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x0, Ww_x0, Wt_x0), mode='trilinear') + x_0 = (x_0 + absolute_pos_embed_x0).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 += self.pos_embd(x_0) + else: + x_0 = x_0.flatten(2).transpose(1, 2) + x_0 = self.pos_drop(x_0) + + #CT image + x_1 = self.patch_embed(x_1) + Wh_x1, Ww_x1, Wt_x1 = x_1.size(2), x_1.size(3), x_1.size(4) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed_x1 = nnf.interpolate(self.absolute_pos_embed, size=(Wh_x1, Ww_x1, Wt_x1), + mode='trilinear') + x_1 = (x_1 + absolute_pos_embed_x1).flatten(2).transpose(1, 2) # B Wh*Ww*Wt C + elif self.spe: + print(self.spe) + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 += self.pos_embd(x_1) + else: + x_1 = x_1.flatten(2).transpose(1, 2) + x_1 = self.pos_drop(x_1) + + + outs = [] + + #############layer0################ + layer = self.layers[0] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0,x_1),dim=2) #concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l0, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0 = layer(x_0_1, Wh_x0, Ww_x0, Wt_x0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l0 = x_out_x0_1_l0[:,:,0:self.embed_dim] + x_out_x1_l0 = x_out_x0_1_l0[:,:,self.embed_dim:] + # add x_1_process and x_0_processed to x_1 and x_0 and start layer 1 + x_0_out = x_out_x0_l0 + x_0 # updated x_0 + x_1_out = x_out_x1_l0 + x_1 # updated x_1 + #out_x0_x1_l0 = x_0_out + x_1_out + out_x0_x1_l0 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{0}') + x_out_l0 = norm_layer(out_x0_x1_l0) + out = x_out_l0.view(-1, Wh_x0, Ww_x0, Wt_x0, self.embed_dim*2).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) #layer 0 output + + #construct the input for the next layer; use conv and pool and view, to downsample input of size of (1,64000,128) to (1,8000,256) + x_0_down = self.patch_downsampling_layers[0](x_0, Wh_x0, Ww_x0, Wt_x0) + x_1_down = self.patch_downsampling_layers[0](x_1, Wh_x1, Ww_x1, Wt_x1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer1################ + layer = self.layers[1] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l1, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1 = layer(x_0_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l1 = x_out_x0_1_l1[:, :, 0:self.embed_dim*2] + x_out_x1_l1 = x_out_x0_1_l1[:, :, self.embed_dim*2:] + x_0_out = x_out_x0_l1 + x_0 # updated x_0 + x_1_out = x_out_x1_l1 + x_1 # updated x_1 + #out_x0_x1_l1 = x_0_out + x_1_out #should I use the sum or concat for decoder? + out_x0_x1_l1 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{1}') + x_out_l1 = norm_layer(out_x0_x1_l1) + out = x_out_l1.view(-1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0, self.embed_dim*4).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 1 output + + #construct the input for the next layer + x_0_down = self.patch_downsampling_layers[1](x_0, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_1_down = self.patch_downsampling_layers[1](x_1, Wh_x0_1_l0, Ww_x0_1_l0, Wt_x0_1_l0) + x_0 = x_0_down + x_1 = x_1_down + + #############layer2################ + layer = self.layers[2] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l2, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2 = layer(x_0_1, Wh_x0_1_l1, + Ww_x0_1_l1, Wt_x0_1_l1) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l2 = x_out_x0_1_l2[:, :, 0:self.embed_dim * 4] + x_out_x1_l2 = x_out_x0_1_l2[:, :, self.embed_dim * 4:] + x_0_out = x_out_x0_l2 + x_0 # updated x_0 + x_1_out = x_out_x1_l2 + x_1 # updated x_1 + #out_x0_x1_l2 = x_0_out + x_1_out + out_x0_x1_l2 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{2}') + x_out_l2 = norm_layer(out_x0_x1_l2) + out = x_out_l2.view(-1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1, self.embed_dim*8).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + # construct the input for the next layer + x_0_down = self.patch_downsampling_layers[2](x_0, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_1_down = self.patch_downsampling_layers[2](x_1, Wh_x0_1_l1, Ww_x0_1_l1, Wt_x0_1_l1) + x_0 = x_0_down + x_1 = x_1_down + + #############layer3################ + layer = self.layers[3] + # concatenate x_0 and x_1 in dimension 1 + x_0_1 = torch.cat((x_0, x_1), + dim=2) # concatenate in the channel dimension so that the image dimension is still Wh_x0, Ww_x0, Wt_x0 which is needed to be compatible for SWINTR + # send the concatenated feature vector to transformer + x_out_x0_1_l3, H_x0_1, W_x0_1, T_x0_1, x_0_1, Wh_x0_1_l3, Ww_x0_1_l3, Wt_x0_1_l3 = layer(x_0_1, Wh_x0_1_l2, + Ww_x0_1_l2, Wt_x0_1_l2) + # split the resulting feature vector in dimension 1 to get processed x_1_processed, x_0_processed + x_out_x0_l3 = x_out_x0_1_l3[:, :, 0:self.embed_dim * 8] + x_out_x1_l3 = x_out_x0_1_l3[:, :, self.embed_dim * 8:] + x_0_out = x_out_x0_l3 + x_0 # updated x_0 + x_1_out = x_out_x1_l3 + x_1 # updated x_1 + #out_x0_x1_l3 = x_0_out + x_1_out + out_x0_x1_l3 = torch.concat((x_0_out , x_1_out),dim=2) + + norm_layer = getattr(self, f'norm{3}') + x_out_l3 = norm_layer(out_x0_x1_l3) + out = x_out_l3.view(-1, Wh_x0_1_l2, Ww_x0_1_l2, Wt_x0_1_l2, self.embed_dim*16).permute(0, 4, 1, 2, 3).contiguous() + outs.append(out) # layer 2 output + + return outs + + def train(self, mode=True): + """Convert the model into training mode while keep layers freezed.""" + super(SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating_ConvPoolDownsampling, self).train(mode) + self._freeze_stages() + + + + + +class SMIT_Cross_Attention(nn.Module): + def __init__( + self, + config, + out_channels: int = 2, + feature_size: int = 48, + hidden_size: int = 768, + mlp_dim: int = 3072, + img_size: int = 128, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "batch", + conv_block: bool = False, + res_block: bool = True, + spatial_dims: int = 3, + in_channels: int=1, + #out_channels: int, + ) -> None: + ''' + TransMorph Model + ''' + + #super(TransMorph_Unetr, self).__init__() + super().__init__() + self.hidden_size = hidden_size + self.feat_size=(img_size//32,img_size//32,img_size//32) + + + + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + pretrain_img_size=config.img_size[0], + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + ) + #below is the decoder from UnetR + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=config.in_chans, + #in_channels=1, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + + #x, out_feats = self.transformer(x_in) + + out_feats = self.transformer(x_in) + + #for item in out_feats: + # print ('info: size is ',item.shape) + + #info: size is torch.Size([6, 48, 64, 64, 64]) + #info: size is torch.Size([6, 96, 32, 32, 32]) + #info: size is torch.Size([6, 192, 16, 16, 16]) + #info: size is torch.Size([6, 384, 8, 8, 8]) + #info: size is torch.Size([6, 768, 4, 4, 4]) + + enc44 = out_feats[3] # torch.Size([4, 384, 8, 8, 8]) + enc33 = out_feats[2] # torch.Size([4, 192, 16, 16, 16]) + enc22 = out_feats[1] # torch.Size([4, 96, 32, 32, 32]) + enc11 = out_feats[0] # torch.Size([4, 48, 64, 64, 64]) + #x=self.proj_feat(x, self.hidden_size, self.feat_size) # torch.Size([4, 768, 4, 4, 4]) + x=out_feats[4] + + print ('encoder x after projection size is ',x.size()) + print ('encoder enc11 size is ',enc11.size()) + print ('encoder enc22 size is ',enc22.size()) + print ('encoder enc33 size is ',enc33.size()) + print ('encoder enc44 size is ',enc44.size()) + + #encoder x after projection size is torch.Size([1, 768, 4, 4, 4]) + #encoder enc11 size is torch.Size([1, 48, 64, 64, 64]) + #encoder enc22 size is torch.Size([1, 96, 32, 32, 32]) + #encoder enc33 size is torch.Size([1, 192, 16, 16, 16]) + #encoder enc44 size is torch.Size([1, 384, 8, 8, 8]) + + + + #print ('input enc0 size ',x_in.size()) + + print ('x_in size ',x_in.shape) + x_prim = torch.unsqueeze(x_in[:, 0, :, :, :], 1) + x_second = torch.unsqueeze(x_in[:, 1, :, :, :], 1) + + print ('x_prim size ',x_prim.shape) + + enc0 = self.encoder1(x_in) + #print ('out enc0 size ',enc0.size()) + enc1 = self.encoder2(enc11) #input size torch.Size([4, 96, 64, 64, 64]) + #print ('enc1 size ',enc1.size()) + enc2 = self.encoder3(enc22) #input size torch.Size([4, 192, 32, 32, 32]) + #print ('enc2 size ',enc2.size()) + enc3 = self.encoder4(enc33) #torch.Size([4, 384, 16, 16, 16]) + #print ('enc3 size ',enc3.size()) + + dec4 = self.encoder10(x) + + dec3 = self.decoder5(dec4, enc44) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + + + + return logits + + + +class SMIT_Cross_Attention_24(nn.Module): + def __init__( + self, + config, + out_channels: int = 2, + feature_size: int = 24, + hidden_size: int = 768, + mlp_dim: int = 3072, + img_size: int = 128, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "batch", + conv_block: bool = False, + res_block: bool = True, + spatial_dims: int = 3, + in_channels: int=1, + #out_channels: int, + ) -> None: + ''' + TransMorph Model + ''' + + #super(TransMorph_Unetr, self).__init__() + super().__init__() + self.hidden_size = hidden_size + self.feat_size=(img_size//32,img_size//32,img_size//32) + + + + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + pretrain_img_size=config.img_size[0], + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + ) + #below is the decoder from UnetR + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=config.in_chans, + #in_channels=1, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + + #x, out_feats = self.transformer(x_in) + + out_feats = self.transformer(x_in) + + #for item in out_feats: + # print ('info: size is ',item.shape) + + #info: size is torch.Size([6, 48, 64, 64, 64]) + #info: size is torch.Size([6, 96, 32, 32, 32]) + #info: size is torch.Size([6, 192, 16, 16, 16]) + #info: size is torch.Size([6, 384, 8, 8, 8]) + #info: size is torch.Size([6, 768, 4, 4, 4]) + + enc44 = out_feats[3] # torch.Size([4, 384, 8, 8, 8]) + enc33 = out_feats[2] # torch.Size([4, 192, 16, 16, 16]) + enc22 = out_feats[1] # torch.Size([4, 96, 32, 32, 32]) + enc11 = out_feats[0] # torch.Size([4, 48, 64, 64, 64]) + #x=self.proj_feat(x, self.hidden_size, self.feat_size) # torch.Size([4, 768, 4, 4, 4]) + x=out_feats[4] + + #print ('encoder x after projection size is ',x.size()) + #print ('encoder enc11 size is ',enc11.size()) + #print ('encoder enc22 size is ',enc22.size()) + #print ('encoder enc33 size is ',enc33.size()) + #print ('encoder enc44 size is ',enc44.size()) + + #encoder x after projection size is torch.Size([1, 768, 4, 4, 4]) + #encoder enc11 size is torch.Size([1, 48, 64, 64, 64]) + #encoder enc22 size is torch.Size([1, 96, 32, 32, 32]) + #encoder enc33 size is torch.Size([1, 192, 16, 16, 16]) + #encoder enc44 size is torch.Size([1, 384, 8, 8, 8]) + + + + #print ('input enc0 size ',x_in.size()) + + #print ('x_in size ',x_in.shape) + x_prim = torch.unsqueeze(x_in[:, 0, :, :, :], 1) + x_second = torch.unsqueeze(x_in[:, 1, :, :, :], 1) + + #print ('x_prim size ',x_prim.shape) + + enc0 = self.encoder1(x_in) + #print ('out enc0 size ',enc0.size()) + enc1 = self.encoder2(enc11) #input size torch.Size([4, 96, 64, 64, 64]) + #print ('enc1 size ',enc1.size()) + enc2 = self.encoder3(enc22) #input size torch.Size([4, 192, 32, 32, 32]) + #print ('enc2 size ',enc2.size()) + enc3 = self.encoder4(enc33) #torch.Size([4, 384, 16, 16, 16]) + #print ('enc3 size ',enc3.size()) + + dec4 = self.encoder10(x) + + dec3 = self.decoder5(dec4, enc44) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + + + + return logits +class Conv3dReLU(nn.Sequential): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + padding=0, + stride=1, + use_batchnorm=True, + ): + conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + relu = nn.LeakyReLU(inplace=True) + if not use_batchnorm: + nm = nn.InstanceNorm3d(out_channels) + else: + nm = nn.BatchNorm3d(out_channels) + + super(Conv3dReLU, self).__init__(conv, nm, relu) + + +# Residual block +class ResidualBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, downsample=None): + super(ResidualBlock, self).__init__() + self.conv_block = nn.Sequential( + nn.Conv3d(in_channels, out_channels, stride), + nn.BatchNorm3d(out_channels), + nn.ReLU(inplace=True), + nn.Conv3d(out_channels, out_channels, stride), + nn.BatchNorm3d(out_channels) + ) + self.conv_skip = nn.Sequential( + nn.Conv3d(in_channels, out_channels, stride), + nn.BatchNorm3d(out_channels), + ) + + def forward(self, x): + # residual = self.conv_skip(x) + # out = self.conv1(x) + # out = self.bn1(out) + # out = self.relu(out) + # out = self.conv2(out) + # out = self.bn2(out) + # + # out += residual + # out = self.relu(out) + return self.conv_block(x) + self.conv_skip(x) + + +class DecoderBlock(nn.Module): + def __init__( + self, + in_channels, + out_channels, + skip_channels=0, + use_batchnorm=True, + ): + super().__init__() + # self.conv1 = Conv3dReLU( + # out_channels+skip_channels, + # out_channels, + # kernel_size=3, + # padding=1, + # use_batchnorm=use_batchnorm, + # ) + # self.conv2 = Conv3dReLU( + # out_channels, + # out_channels, + # kernel_size=3, + # padding=1, + # use_batchnorm=use_batchnorm, + # ) + self.up = nn.ConvTranspose3d(in_channels,out_channels,kernel_size=2,stride=2) + + def forward(self, x, skip=None): + x = self.up(x) + #if skip is not None: + # x = torch.cat([x, skip], dim=1) + #x = self.conv1(x) + #x = self.conv2(x) + return x + +class RegistrationHead(nn.Sequential): + def __init__(self, in_channels, out_channels, kernel_size=3, upsampling=1): + conv3d = nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2) + conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super().__init__(conv3d) + + +class SegmentationHead(nn.Sequential): + def __init__(self, in_channels, num_classes, image_size=(128,128,48), kernel_size=3, upsampling=1): + #conv3d = nn.Conv3d(in_channels, num_classes, kernel_size=1) + conv3d = nn.Conv3d(in_channels, num_classes, 1,1,0,1,1,False) + softmax = nn.Softmax(dim=1) + #Reshape = torch.reshape([np.prod(image_size),num_classes]) + #softmax = torch.nn.functional.softmax() + #conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + #conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super(SegmentationHead, self).__init__(conv3d,softmax) + +class SegmentationHead_new(nn.Sequential): + def __init__(self, in_channels, num_classes, kernel_size=1, upsampling=1): + #conv3d = nn.Conv3d(in_channels, num_classes, kernel_size=1) + conv3d = nn.Conv3d(in_channels, num_classes, 1,1,0,1,1, False) + sigmoid = nn.Sigmoid() + #conv3d.weight = nn.Parameter(Normal(0, 1e-5).sample(conv3d.weight.shape)) + #conv3d.bias = nn.Parameter(torch.zeros(conv3d.bias.shape)) + super(SegmentationHead_new, self).__init__(conv3d,sigmoid) + +class SpatialTransformer(nn.Module): + """ + N-D Spatial Transformer + + Obtained from https://github.com/voxelmorph/voxelmorph + """ + + def __init__(self, size, mode='bilinear'): + super().__init__() + + self.mode = mode + + # create sampling grid + vectors = [torch.arange(0, s) for s in size] + grids = torch.meshgrid(vectors) + grid = torch.stack(grids) + grid = torch.unsqueeze(grid, 0) + grid = grid.type(torch.FloatTensor) + + # registering the grid as a buffer cleanly moves it to the GPU, but it also + # adds it to the state dict. this is annoying since everything in the state dict + # is included when saving weights to disk, so the model files are way bigger + # than they need to be. so far, there does not appear to be an elegant solution. + # see: https://discuss.pytorch.org/t/how-to-register-buffer-without-polluting-state-dict + self.register_buffer('grid', grid) + + def forward(self, src, flow): + # new locations + new_locs = self.grid + flow + shape = flow.shape[2:] + + # need to normalize grid values to [-1, 1] for resampler + for i in range(len(shape)): + new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5) + + # move channels dim to last position + # also not sure why, but the channels need to be reversed + if len(shape) == 2: + new_locs = new_locs.permute(0, 2, 3, 1) + new_locs = new_locs[..., [1, 0]] + elif len(shape) == 3: + new_locs = new_locs.permute(0, 2, 3, 4, 1) + new_locs = new_locs[..., [2, 1, 0]] + + return nnf.grid_sample(src, new_locs, align_corners=True, mode=self.mode) + +class SwinVNetSkip(nn.Module): + def __init__(self, config): + super(SwinVNetSkip, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer(patch_size=config.patch_size, + in_chans=config.in_chans, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + self.up0 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim, embed_dim//2, skip_channels=embed_dim//2 if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim//2, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim//2, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead_new( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.transformer(x) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + +from monai.networks.blocks import UnetrBasicBlock,UnetResBlock,UnetrUpBlock,UnetrPrUpBlock +from monai.networks.blocks.dynunet_block import UnetOutBlock, get_conv_layer, UnetBasicBlock + +from typing import Sequence, Tuple, Union + +class SWINUnetrUpBlock(nn.Module): + """ + An upsampling module that can be used for UNETR: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>" + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + res_block: bool argument to determine if residual block is used. + + """ + + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + if res_block: + self.conv_block = UnetResBlock( + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + else: + self.conv_block = UnetBasicBlock( # type: ignore + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = torch.cat((inp, skip), dim=1) + out = self.conv_block(out) + out = self.transp_conv(out) + + return out + + +class SWINUnetrBlock(nn.Module): + """ + An upsampling module that can be used for UNETR: "Hatamizadeh et al., + UNETR: Transformers for 3D Medical Image Segmentation <https://arxiv.org/abs/2103.10504>" + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + upsample_kernel_size: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + """ + Args: + spatial_dims: number of spatial dimensions. + in_channels: number of input channels. + out_channels: number of output channels. + kernel_size: convolution kernel size. + upsample_kernel_size: convolution kernel size for transposed convolution layers. + norm_name: feature normalization type and arguments. + res_block: bool argument to determine if residual block is used. + + """ + + super().__init__() + upsample_stride = upsample_kernel_size + self.transp_conv = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=upsample_kernel_size, + stride=upsample_stride, + conv_only=True, + is_transposed=True, + ) + + if res_block: + self.conv_block = UnetResBlock( + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + else: + self.conv_block = UnetBasicBlock( # type: ignore + spatial_dims, + in_channels + in_channels, + in_channels, + kernel_size=kernel_size, + stride=1, + norm_name=norm_name, + ) + + def forward(self, inp, skip): + # number of channels for skip should equals to out_channels + out = torch.cat((inp, skip), dim=1) + out = self.conv_block(out) + #out = self.transp_conv(out) + + return out + +class SwinUNETR_self(nn.Module): + def __init__(self, config): + super(SwinUNETR_self, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer(patch_size=config.patch_size, + in_chans=config.in_chans, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + #print(out[-1].size()) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_inputsFusion(nn.Module): + def __init__(self, config): + super(SwinUNETR_inputsFusion, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer(patch_size=config.patch_size, + in_chans=1, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + + # self.res_fusionBlock = UnetResBlock( + # spatial_dims=3, + # in_channels=config.in_chans, + # out_channels=1, + # kernel_size=3, + # stride=1, + # norm_name='instance', + # ) + + self.res_fusionBlock = depthwise_separable_conv3d( + nin=config.in_chans, + kernels_per_layer=48, + nout=1, + ) + self.encoder0 = depthwise_separable_conv3d( + nin=1, + kernels_per_layer=48, + nout=embed_dim, + ) + + + # UnetrBasicBlock( + # spatial_dims=3, + # in_channels=1, + # out_channels=embed_dim, + # kernel_size=3, + # stride=1, + # norm_name='instance', + # res_block=True, + # ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + x = self.res_fusionBlock(x) + + out = self.transformer(x) # (B, n_patch, hidden) + # print(out[-1].size()) + + # stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) # B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) # skip features should be twice the di + + # stage 3 features + dec3 = self.decoder4(dec4, enc4) + enc3 = self.encoder3(out[-3]) # skip features + + # stage 2 features + dec2 = self.decoder3(dec3, enc3) + enc2 = self.encoder2(out[-4]) # skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class depthwise_separable_conv3d(nn.Module): + def __init__(self, nin, kernels_per_layer, nout): + super(depthwise_separable_conv3d, self).__init__() + self.depthwise = nn.Conv3d(nin, nin * kernels_per_layer, kernel_size=3, padding=1, stride=1, groups=nin) + self.pointwise = nn.Conv3d(nin * kernels_per_layer, nout, kernel_size=1) + + def forward(self, x): + out = self.depthwise(x) + out = self.pointwise(out) + return out + + +class SwinUNETR_dense(nn.Module): + def __init__(self, config): + super(SwinUNETR_dense, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_dense(patch_size=config.patch_size, + in_chans=1, + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + concatenated_input=False) + + self.res_fusionBlock = UnetResBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=1, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + # self.res_fusionBlock = depthwise_separable_conv3d( + # nin=config.in_chans, + # kernels_per_layer=48, + # nout=1, + # ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=1, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim, + out_channels=embed_dim, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + x = self.res_fusionBlock(x) + print(x.size()) + out = self.transformer(x) # (B, n_patch, hidden) + # print(out[-1].size()) + + # stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) # B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) # skip features should be twice the di + + # stage 3 features + dec3 = self.decoder4(dec4, enc4) + enc3 = self.encoder3(out[-3]) # skip features + + # stage 2 features + dec2 = self.decoder3(dec3, enc3) + enc2 = self.encoder2(out[-4]) # skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + +class SwinVNetSkip_transfuser(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim, embed_dim//2, skip_channels=embed_dim//2 if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim//2, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim//2, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + + +class SwinVNetSkip_transfuser_concat(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser_concat, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk_concat(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*16, embed_dim*8, skip_channels=embed_dim*8 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + + +class SwinVNetSkip_transfuser_concat_noCrossModalUpdating(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser_concat_noCrossModalUpdating, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*16, embed_dim*8, skip_channels=embed_dim*8 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + +class SwinVNetSkip_transfuser_concat_noCrossModalUpdating_ConvPoolDownsampling(nn.Module): + def __init__(self, config): + super(SwinVNetSkip_transfuser_concat_noCrossModalUpdating_ConvPoolDownsampling, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.swinTransfuser = SwinTransformer_wFeatureTalk_concat_noCrossModalUpdating_ConvPoolDownsampling(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), # + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method) + self.up0 = DecoderBlock(embed_dim*16, embed_dim*8, skip_channels=embed_dim*8 if if_transskip else 0, use_batchnorm=False) + self.up1 = DecoderBlock(embed_dim*8, embed_dim*4, skip_channels=embed_dim*4 if if_transskip else 0, use_batchnorm=False) # 384, 20, 20, 64 + self.up2 = DecoderBlock(embed_dim*4, embed_dim*2, skip_channels=embed_dim*2 if if_transskip else 0, use_batchnorm=False) # 384, 40, 40, 64 + self.up3 = DecoderBlock(embed_dim*2, embed_dim, skip_channels=embed_dim if if_convskip else 0, use_batchnorm=False) # 384, 80, 80, 128 + self.up4 = DecoderBlock(embed_dim, config.seg_head_chan, skip_channels=config.seg_head_chan if if_convskip else 0, use_batchnorm=False) # 384, 160, 160, 256 + self.c1 = Conv3dReLU(2, embed_dim, 3, 1, use_batchnorm=False) + self.c2 = Conv3dReLU(2, config.seg_head_chan, 3, 1, use_batchnorm=False) + self.seg_head = SegmentationHead_new( + in_channels=config.seg_head_chan, + num_classes=2, + kernel_size=3, + ) + self.spatial_trans = SpatialTransformer(config.img_size) + self.avg_pool = nn.AvgPool3d(3, stride=2, padding=1) + + def forward(self, x): + #source = x[:, 0:1, :, :] + x_0 = torch.unsqueeze(x[:, 0, :, :, :],1) + x_1 = torch.unsqueeze(x[:, 1, :, :, :],1) + if self.if_convskip: + x_s0 = x.clone() + x_s1 = self.avg_pool(x) + f4 = self.c1(x_s1) + f5 = self.c2(x_s0) + else: + f4 = None + f5 = None + + out = self.swinTransfuser(x_0,x_1) # (B, n_patch, hidden) + + if self.if_transskip: + f1 = out[-2] + f2 = out[-3] + f3 = out[-4] + else: + f1 = None + f2 = None + f3 = None + x = self.up0(out[-1], f1) + x = self.up1(x, f2) + x = self.up2(x, f3) + x = self.up3(x, f4) + x = self.up4(x, f5) + out = self.seg_head(x) + #out = self.spatial_trans(source, flow) + return out + + +class SwinUNETR_fusion(nn.Module): + def __init__(self, config): + super(SwinUNETR_fusion, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wFeatureTalk_concat_PETUpdatingOnly_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + + + +class SwinUNETR_dualModalityFusion_OutConcat(nn.Module): + def __init__(self, config): + super(SwinUNETR_dualModalityFusion_OutConcat, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutConcat_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *2, + out_channels=embed_dim *2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*32, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*16, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim *2 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_CrossModalityFusion_inputFusion_OutSum(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_inputFusion_OutSum, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_wInputFusion_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + self.res_fusionBlock = depthwise_separable_conv3d( + nin=config.in_chans, + kernels_per_layer=48, + nout=1, + ) + + def forward(self, x): + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +########################################################################## +def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): + return nn.Conv3d( + in_channels, out_channels, kernel_size, + padding=(kernel_size//2), bias=bias, stride = stride) + +## Channel Attention Block (CAB) + +class CAB(nn.Module): + def __init__(self, n_feat, kernel_size, reduction=4, bias=False, act = nn.PReLU()): + super(CAB, self).__init__() + modules_body = [] + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + modules_body.append(act) + modules_body.append(conv(n_feat, n_feat, kernel_size, bias=bias)) + + self.CA = CALayer(n_feat, reduction, bias=bias) #n_feat = channel, noiseLevel_dim + self.body = nn.Sequential(*modules_body) + + def forward(self, x): + res = self.body(x) #x.shape=[4,80,32,32,32] and res.shape=[4,80,32,32,32] + res = self.CA(res) + res += x + return res +## Channel Attention Layer +class CALayer(nn.Module): + def __init__(self, channel, reduction=16, bias=False): + super(CALayer, self).__init__() + # global average pooling: feature --> point + self.avg_pool = nn.AdaptiveAvgPool3d(1) + # feature channel downscale and upscale --> channel weight + self.conv_du = nn.Sequential( + nn.Conv3d(channel, channel // reduction, 1, padding=0, bias=bias), + nn.ReLU(inplace=True), + nn.Conv3d(channel // reduction, channel, 1, padding=0, bias=bias), + nn.Sigmoid() + ) + + def forward(self, x): + y = self.avg_pool(x) + y = self.conv_du(y) + return x * y + +class SwinUNETR_CrossModalityFusion_OutSum(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_OutSum, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_CrossModalityFusion_OutSum_6stageOuts(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_OutSum_6stageOuts, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + # self.encoder0 = depthwise_separable_conv3d( + # nin=2, + # kernels_per_layer=96, + # nout=embed_dim, + # ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + #print(1) + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_CrossModalityFusion_OutSum_wChAttn(nn.Module): + def __init__(self, config): + super(SwinUNETR_CrossModalityFusion_OutSum_wChAttn, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wCrossModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int( + config.in_chans / 2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim * 1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.CAB1 = CAB( + n_feat=embed_dim * 1, + kernel_size=3, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.CAB2 = CAB( + n_feat=embed_dim * 2, + kernel_size=3, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.CAB3 = CAB( + n_feat=embed_dim * 4, + kernel_size=3, + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.CAB4 = CAB( + n_feat=embed_dim * 8, + kernel_size=3, + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.CAB5 = CAB( + n_feat=embed_dim * 16, + kernel_size=3, + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + def forward(self, x): + out = self.transformer(x) # (B, n_patch, hidden) + + # stage 4 features + cab5 = self.CAB5(out[-1]) + enc5 = self.res_botneck(cab5) # B, 5,5,5,2048 + + dec4 = self.decoder5(enc5) # B, 10,10,10,1024 + cab4 = self.CAB4(out[-2]) + enc4 = self.encoder4(cab4) # skip features should be twice the di + + # stage 3 features + dec3 = self.decoder4(dec4, enc4) + cab3 = self.CAB3(out[-3]) + enc3 = self.encoder3(cab3) # skip features + + # stage 2 features + dec2 = self.decoder3(dec3, enc3) + cab2 = self.CAB2(out[-4]) + enc2 = self.encoder2(cab2) # skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + cab1 = self.CAB1(out[-5]) + enc1 = self.encoder1(cab1) # skip features + + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + +class SwinUNETR_dualModalityFusion_OutSum(nn.Module): + def __init__(self, config): + super(SwinUNETR_dualModalityFusion_OutSum, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wDualModalityFeatureTalk_OutSum_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *1, + out_channels=embed_dim *1, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 2, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*8, + out_channels=embed_dim*4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim *1 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 1, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + + +class SwinUNETR_RandomSpatialFusion(nn.Module): + def __init__(self, config): + super(SwinUNETR_RandomSpatialFusion, self).__init__() + if_convskip = config.if_convskip + self.if_convskip = if_convskip + if_transskip = config.if_transskip + self.if_transskip = if_transskip + embed_dim = config.embed_dim + self.transformer = SwinTransformer_wRandomSpatialFeatureTalk_wCrossModalUpdating_5stageOuts(patch_size=config.patch_size, + in_chans=int(config.in_chans/2), + embed_dim=config.embed_dim, + depths=config.depths, + num_heads=config.num_heads, + window_size=config.window_size, + mlp_ratio=config.mlp_ratio, + qkv_bias=config.qkv_bias, + drop_rate=config.drop_rate, + drop_path_rate=config.drop_path_rate, + ape=config.ape, + spe=config.spe, + patch_norm=config.patch_norm, + use_checkpoint=config.use_checkpoint, + out_indices=config.out_indices, + pat_merg_rf=config.pat_merg_rf, + pos_embed_method=config.pos_embed_method, + ) + + self.encoder0 = UnetrBasicBlock( + spatial_dims=3, + in_channels=config.in_chans, + out_channels=embed_dim*2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim *2, + out_channels=embed_dim *2, + kernel_size=3, + stride=1, + norm_name='instance', + res_block=True, + ) + + self.encoder2 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 4, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder3 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 8, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.encoder4 = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim * 16, + out_channels=embed_dim * 16, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.res_botneck = UnetResBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*32, + kernel_size=3, + stride=1, + norm_name='instance', + ) + + self.decoder5 = UnetrPrUpBlock( + spatial_dims=3, + in_channels=embed_dim*32, + out_channels=embed_dim*16, + num_layer=0, + kernel_size=3, + stride=1, + upsample_kernel_size=2, + norm_name='instance', + conv_block=True, + res_block=True, + ) + + self.decoder4 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim*16, + out_channels=embed_dim*8, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder3 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 8, + out_channels=embed_dim * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + self.decoder2 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 4, + out_channels=embed_dim * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder1 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim *2 , + kernel_size=3, + upsample_kernel_size=2, + norm_name='instance', + res_block=True, + ) + + self.decoder0 = SWINUnetrUpBlock( + spatial_dims=3, + in_channels=embed_dim * 2, + out_channels=embed_dim * 1, + kernel_size=3, + upsample_kernel_size=1, + norm_name='instance', + res_block=True, + ) + + self.out = UnetOutBlock(spatial_dims=3, in_channels=embed_dim, out_channels=2) # type: ignore + + + def forward(self, x): + + + + out = self.transformer(x) # (B, n_patch, hidden) + + #stage 4 features + enc5 = self.res_botneck(out[-1]) # B, 5,5,5,2048 + dec4 = self.decoder5(enc5) #B, 10,10,10,1024 + enc4 = self.encoder4(out[-2]) #skip features should be twice the di + + #stage 3 features + dec3 = self.decoder4(dec4,enc4) + enc3 = self.encoder3(out[-3]) #skip features + + #stage 2 features + dec2 = self.decoder3(dec3,enc3) + enc2 = self.encoder2(out[-4]) #skip features + + # stage 1 features + dec1 = self.decoder2(dec2, enc2) + enc1 = self.encoder1(out[-5]) # skip features + + dec0 = self.decoder1(dec1, enc1) + enc0 = self.encoder0(x) + + head = self.decoder0(dec0, enc0) + + logits = self.out(head) + + return logits + diff --git a/models/msk_smit_lung_gtv/src/smit_models/smit_plus.py b/models/msk_smit_lung_gtv/src/smit_models/smit_plus.py new file mode 100644 index 00000000..333ad19b --- /dev/null +++ b/models/msk_smit_lung_gtv/src/smit_models/smit_plus.py @@ -0,0 +1,1938 @@ + +import math +from typing import Callable, Optional, Tuple, Union +import ml_collections +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from einops import rearrange + +from functools import partial +from timm.models.layers import Mlp, DropPath, to_2tuple, trunc_normal_, _assert, ClassifierHead,to_3tuple +from monai.networks.blocks import UnetrBasicBlock, UnetrPrUpBlock, UnetrUpBlock +from monai.networks.blocks.dynunet_block import UnetOutBlock + +#from ._features_fx import register_notrace_function + +import torch.nn.functional as nnf +#Jue added and changed +from timm.models.registry import register_model#,generate_default_cfgs +from timm.models.layers import Mlp, DropPath, trunc_normal_, lecun_normal_ +from timm.models.helpers import build_model_with_cfg, named_apply, adapt_input_conv +from .format import Format, nchw_to +import numpy as np + +from monai.networks.blocks.convolutions import Convolution +from monai.networks.layers.factories import Act, Norm +from monai.networks.layers.utils import get_act_layer, get_norm_layer + +__all__ = ['SwinTransformerV2'] # model_registry will add each entrypoint fn to this + +_int_or_tuple_2_t = Union[int, Tuple[int, int]] + + +_autowrap_functions = set() + +from typing import Optional, Sequence, Tuple, Union + +def get_padding( + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int] +) -> Union[Tuple[int, ...], int]: + + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = (kernel_size_np - stride_np + 1) / 2 + if np.min(padding_np) < 0: + raise AssertionError("padding value should not be negative, please change the kernel size and/or stride.") + padding = tuple(int(p) for p in padding_np) + + return padding if len(padding) > 1 else padding[0] + + +def get_output_padding( + kernel_size: Union[Sequence[int], int], stride: Union[Sequence[int], int], padding: Union[Sequence[int], int] +) -> Union[Tuple[int, ...], int]: + kernel_size_np = np.atleast_1d(kernel_size) + stride_np = np.atleast_1d(stride) + padding_np = np.atleast_1d(padding) + + out_padding_np = 2 * padding_np + stride_np - kernel_size_np + if np.min(out_padding_np) < 0: + raise AssertionError("out_padding value should not be negative, please change the kernel size and/or stride.") + out_padding = tuple(int(p) for p in out_padding_np) + + return out_padding if len(out_padding) > 1 else out_padding[0] + +from torch import Tensor, Size +from typing import Union, List +import numbers + +from torch.nn.parameter import Parameter + + +_shape_t = Union[int, List[int], Size] + +class LayerNormWithForceFP32(nn.Module): + __constants__ = ['normalized_shape', 'eps', 'elementwise_affine'] + normalized_shape: _shape_t + eps: float + elementwise_affine: bool + + def __init__(self, normalized_shape: _shape_t, eps: float = 1e-5, elementwise_affine: bool = True) -> None: + super(LayerNormWithForceFP32, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = tuple(normalized_shape) + self.eps = eps + self.elementwise_affine = elementwise_affine + if self.elementwise_affine: + self.weight = Parameter(torch.Tensor(*normalized_shape)) + self.bias = Parameter(torch.Tensor(*normalized_shape)) + else: + self.register_parameter('weight', None) + self.register_parameter('bias', None) + self.reset_parameters() + + def reset_parameters(self) -> None: + if self.elementwise_affine: + nn.init.ones_(self.weight) + nn.init.zeros_(self.bias) + + def forward(self, input: Tensor) -> Tensor: + return F.layer_norm( + input.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps).type_as(input) + + def extra_repr(self) -> Tensor: + return '{normalized_shape}, eps={eps}, ' \ + 'elementwise_affine={elementwise_affine}'.format(**self.__dict__) + +def get_conv_layer( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int] = 3, + stride: Union[Sequence[int], int] = 1, + act: Optional[Union[Tuple, str]] = Act.PRELU, + norm: Union[Tuple, str] = Norm.INSTANCE, + dropout: Optional[Union[Tuple, str, float]] = None, + bias: bool = False, + conv_only: bool = True, + is_transposed: bool = False, +): + padding = get_padding(kernel_size, stride) + output_padding = None + if is_transposed: + output_padding = get_output_padding(kernel_size, stride, padding) + return Convolution( + spatial_dims, + in_channels, + out_channels, + strides=stride, + kernel_size=kernel_size, + act=act, + norm=norm, + dropout=dropout, + bias=bias, + conv_only=conv_only, + is_transposed=is_transposed, + padding=padding, + output_padding=output_padding, + ) + + +class UnetResBlock_No_Downsampleing(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}), + dropout: Optional[Union[Tuple, str, float]] = None, + ): + super().__init__() + self.conv1 = get_conv_layer( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + dropout=dropout, + conv_only=True, + ) + self.conv2 = get_conv_layer( + spatial_dims, out_channels, out_channels, kernel_size=kernel_size, stride=1, dropout=dropout, conv_only=True + ) + + self.lrelu = get_act_layer(name=act_name) + self.norm1 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.norm2 = get_norm_layer(name=norm_name, spatial_dims=spatial_dims, channels=out_channels) + self.downsample = in_channels != out_channels + stride_np = np.atleast_1d(stride) + if not np.all(stride_np == 1): + self.downsample = True + + def forward(self, inp): + residual = inp + out = self.conv1(inp) + out = self.norm1(out) + out = self.lrelu(out) + out = self.conv2(out) + out = self.norm2(out) + out += residual + out = self.lrelu(out) + return out + + +class UnetrBasicBlock_No_DownSampling(nn.Module): + + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: Union[Sequence[int], int], + stride: Union[Sequence[int], int], + norm_name: Union[Tuple, str], + res_block: bool = False, + ) -> None: + + + super().__init__() + + if res_block: + self.layer = UnetResBlock_No_Downsampleing( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + else: + self.layer = UnetBasicBlock( # type: ignore + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + norm_name=norm_name, + ) + + def forward(self, inp): + return self.layer(inp) + +def register_notrace_function(func: Callable): + """ + Decorator for functions which ought not to be traced through + """ + _autowrap_functions.add(func) + return func + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + """ + output_fmt: Format + + def __init__( + self, + img_size: int = 224, + patch_size: int = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten: bool = True, + output_fmt: Optional[str] = None, + bias: bool = True, + ): + super().__init__() + img_size = to_3tuple(img_size) + patch_size = to_3tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1],img_size[2] // patch_size[2]) + self.num_patches = self.grid_size[0] * self.grid_size[1]* self.grid_size[2] + if output_fmt is not None: + self.flatten = False + self.output_fmt = Format(output_fmt) + else: + # flatten spatial dim and transpose to channels last, kept for bwd compat + self.flatten = flatten + self.output_fmt = Format.NCHW + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x): + B, C, H, W,T = x.shape + _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # NCHW -> NLC + elif self.output_fmt != Format.NCHW: + x = nchw_to(x, self.output_fmt) + x = self.norm(x) + B,h,w,t,C=x.shape + x=x.view(B,-1,C) + #print ('info: after patch emb ',x.shape) + return x + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, L, C) + window_size (int): window size + Returns: + windows: (num_windows*B, window_size, window_size, window_size, C) + """ + B, H, W, L, C = x.shape + + x = x.view(B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], L // window_size[2], window_size[2], C) + + windows = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(-1, window_size[0], window_size[1], window_size[2], C) + #print ('info: after window_parti window size ',windows.shape) + return windows + + +@register_notrace_function # reason: int argument is a Proxy +def window_reverse(windows, window_size, H, W, L): + """ + Args: + windows: (num_windows*B, window_size, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + L (int): Length of image + Returns: + x: (B, H, W, L, C) + """ + B = int(windows.shape[0] / (H * W * L / window_size[0] / window_size[1] / window_size[2])) + x = windows.view(B, H // window_size[0], W // window_size[1], L // window_size[2], window_size[0], window_size[1], window_size[2], -1) + x = x.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous().view(B, H, W, L, -1) + return x + + + + +class WindowAttention(nn.Module): + r""" Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. + """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0., + proj_drop=0., + pretrained_window_size=[0, 0], + ): + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1)))) + + # mlp to generate continuous relative position bias + self.cpb_mlp = nn.Sequential( + nn.Linear(3, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False) + ) + + # get relative_coords_table + relative_coords_h = torch.arange(-(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32) + relative_coords_w = torch.arange(-(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32) + relative_coords_t = torch.arange(-(self.window_size[2] - 1), self.window_size[2], dtype=torch.float32) + + + relative_coords_table = torch.stack(torch.meshgrid([ + relative_coords_h, + relative_coords_w, + relative_coords_t])).permute(1, 2, 3, 0).contiguous().unsqueeze(0) # 1, 2*Wh-1, 2*Ww-1,2*Wt-1, 3 + + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= (pretrained_window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (pretrained_window_size[1] - 1) + relative_coords_table[:, :, :, 2] /= (pretrained_window_size[2] - 1) + else: + relative_coords_table[:, :, :, 0] /= (self.window_size[0] - 1) + relative_coords_table[:, :, :, 1] /= (self.window_size[1] - 1) + relative_coords_table[:, :, :, 2] /= (self.window_size[2] - 1) + + relative_coords_table *= 8 # normalize to -8, 8 + + relative_coords_table = torch.sign(relative_coords_table) * torch.log2( + torch.abs(relative_coords_table) + 1.0) / math.log2(8) + + self.register_buffer("relative_coords_table", relative_coords_table, persistent=False) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords_t = torch.arange(self.window_size[2]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w, coords_t])) # 3, Wh, Ww, Wt + + coords_flatten = torch.flatten(coords, 1) # 3, Wh*Ww*Wt + relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wh*Ww*Wt, Wh*Ww*Wt + relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww*Wt, Wh*Ww*Wt, 3 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 2] += self.window_size[2] - 1 + relative_coords[:, :, 0] *= (2 * self.window_size[1] - 1) * (2 * self.window_size[2] - 1) + relative_coords[:, :, 1] *= 2 * self.window_size[2] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww*Wt, Wh*Ww*Wt + self.register_buffer("relative_position_index", relative_position_index) + + # ('info: relative_position_index ',relative_position_index.shape) + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.register_buffer('k_bias', torch.zeros(dim), persistent=False) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.k_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask: Optional[torch.Tensor] = None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) + + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + + q=q.float() + k=k.float() + attn = (F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)) + + logit_scale = torch.clamp(self.logit_scale, max=math.log(1. / 0.01)).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(-1, self.num_heads) + relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( + self.window_size[0] * self.window_size[1] * self.window_size[2], self.window_size[0] * self.window_size[1] * self.window_size[2], -1) # Wh*Ww,Wh*Ww,nH + + relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + num_win = mask.shape[0] + attn = attn.view(-1, num_win, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + attn = self.softmax(attn) + else: + attn = self.softmax(attn) + + attn = self.attn_drop(attn) + + + attn=attn.half() # only use when don't use deep seepd + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + + +class SwinTransformerBlock(nn.Module): + r""" Swin Transformer Block. + Args: + dim (int): Number of input channels. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, num_heads, window_size=(7, 7, 7), shift_size=(0, 0, 0), + mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0., + act_layer=nn.GELU, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + + + + assert 0 <= min(self.shift_size) < min(self.window_size), "shift_size must in 0-window_size, shift_sz: {}, win_size: {}".format(self.shift_size, self.window_size) + + norm_layer1=partial(LayerNormWithForceFP32, eps=1e-6) + self.norm1 = norm_layer1(dim) + + + + self.attn = WindowAttention( + dim, window_size=self.window_size, num_heads=num_heads, + qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) + + + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + norm_layer2=partial(LayerNormWithForceFP32, eps=1e-6) + self.norm2 = norm_layer2(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) + + self.H = None + self.W = None + self.T = None + + + def forward(self, x, mask_matrix=None): + H, W, T = self.H, self.W, self.T + + + B, L, C = x.shape + + + + + H=round(math.pow(L,1/3.)) + W=H + T=H + shortcut = x + + x = x.view(B, H, W, T, C) + #print ('x size is ',x.size()) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' L' C + + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + x = self.norm1(x) + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + + x = x.view(B, H * W * T, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + + +class SwinTransformerV2Block(nn.Module): + """ Swin Transformer Block. + """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + num_heads: Number of attention heads. + window_size: Window size. + shift_size: Shift size for SW-MSA. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate. + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + act_layer: Activation layer. + norm_layer: Normalization layer. + pretrained_window_size: Window size in pretraining. + """ + super().__init__() + self.dim = dim + self.input_resolution = to_3tuple(input_resolution) + self.num_heads = num_heads + + + ws, ss = self._calc_window_shift(window_size, shift_size) + + self.window_size: Tuple[int, int,int] = ws + self.shift_size: Tuple[int, int,int] = ss + + + self.window_area = self.window_size[0] * self.window_size[1]* self.window_size[2] + self.mlp_ratio = mlp_ratio + + self.attn = WindowAttention( + dim, + window_size=to_3tuple(self.window_size), + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=to_3tuple(pretrained_window_size), + ) + + + norm_layer1=partial(LayerNormWithForceFP32, eps=1e-6) + self.norm1 = norm_layer(dim) + self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + self.mlp = Mlp( + in_features=dim, + hidden_features=int(dim * mlp_ratio), + act_layer=act_layer, + drop=drop, + ) + norm_layer2=partial(LayerNormWithForceFP32, eps=1e-6) + self.norm2 = norm_layer2(dim) + + self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() + + if any(self.shift_size): + # calculate attention mask for SW-MSA + + + H, W,T = self.input_resolution + + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + + img_mask = torch.zeros((1, Hp, Wp, Tp, 1)) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_area) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + else: + attn_mask = None + + + self.register_buffer("attn_mask", attn_mask) + + def _calc_window_shift(self, target_window_size, target_shift_size) -> Tuple[Tuple[int, int], Tuple[int, int]]: + target_window_size = to_3tuple(target_window_size) + target_shift_size = to_3tuple(target_shift_size) + + + window_size = [r if r <= w else w for r, w in zip(self.input_resolution, target_window_size)] + shift_size = [0 if r <= w else s for r, w, s in zip(self.input_resolution, window_size, target_shift_size)] + + + return tuple(window_size), tuple(shift_size) + + def _attn(self, x): + + + B, H, W,T, C = x.shape + #print ('x size is ',x.size()) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + #attn_mask = mask_matrix + else: + shifted_x = x + #attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=self.attn_mask) # nW*B, window_size*window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' L' C + + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + return x + + def forward(self, x): + + + B, L, C = x.shape + + B_,L_,C_=x.shape + H_=round(math.pow(L_,1/3.)) + W_=H_ + T_=H_ + + #assert L == H * W * T, "input feature has wrong size" + + H=round(math.pow(L,1/3.)) + W=H + T=H + shortcut = x + x = self.norm1(x) + x = x.view(B, H, W, T, C) + #print ('x size is ',x.size()) + # pad feature maps to multiples of window size + pad_l = pad_t = pad_f = 0 + pad_r = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + pad_b = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_h = (self.window_size[2] - T % self.window_size[2]) % self.window_size[2] + x = nnf.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b, pad_f, pad_h)) + _, Hp, Wp, Tp, _ = x.shape + + # cyclic shift + if min(self.shift_size) > 0: + shifted_x = torch.roll(x, shifts=(-self.shift_size[0], -self.shift_size[1], -self.shift_size[2]), dims=(1, 2, 3)) + attn_mask = self.attn_mask + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, window_size, C + x_windows = x_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2], C) # nW*B, window_size*window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view(-1, self.window_size[0], self.window_size[1], self.window_size[2], C) + shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp, Tp) # B H' W' L' C + + # reverse cyclic shift + if min(self.shift_size) > 0: + x = torch.roll(shifted_x, shifts=(self.shift_size[0], self.shift_size[1], self.shift_size[2]), dims=(1, 2, 3)) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :T, :].contiguous() + + x = x.view(B, H * W * T, C) + + # FFN + x = shortcut + self.drop_path(x) + x = x + self.drop_path(self.mlp(self.norm2(x))) + + return x + + +class PatchMerging(nn.Module): + r""" Patch Merging Layer. + Args: + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm, reduce_factor=2): + super().__init__() + self.dim = dim + #self.reduction = nn.Linear(8 * dim, (4//reduce_factor) * dim, bias=False) + #self.reduction = nn.Linear(8 * dim, (4//reduce_factor) * dim, bias=False) + self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False) + self.norm = norm_layer( 2*dim) + + + def forward(self, x): + """ + x: B, H*W*T, C + """ + # ('x before merge ',x.shape) + B, L, C = x.shape + + H=round(math.pow(L,1/3.)) + W=H + T=H + #assert L == H * W * T, "input feature has wrong size" + #assert H % 2 == 0 and W % 2 == 0 and T % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, T, C) + #print ('x reshape ',x.shape) + # padding + pad_input = (H % 2 == 1) or (W % 2 == 1) or (T % 2 == 1) + if pad_input: + x = nnf.pad(x, (0, 0, 0, W % 2, 0, H % 2, 0, T % 2)) + + x0 = x[:, 0::2, 0::2, 0::2, :] # B H/2 W/2 T/2 C + x1 = x[:, 1::2, 0::2, 0::2, :] # B H/2 W/2 T/2 C + x2 = x[:, 0::2, 1::2, 0::2, :] # B H/2 W/2 T/2 C + x3 = x[:, 0::2, 0::2, 1::2, :] # B H/2 W/2 T/2 C + x4 = x[:, 1::2, 1::2, 0::2, :] # B H/2 W/2 T/2 C + x5 = x[:, 0::2, 1::2, 1::2, :] # B H/2 W/2 T/2 C + x6 = x[:, 1::2, 0::2, 1::2, :] # B H/2 W/2 T/2 C + x7 = x[:, 1::2, 1::2, 1::2, :] # B H/2 W/2 T/2 C + x = torch.cat([x0, x1, x2, x3, x4, x5, x6, x7], -1) # B H/2 W/2 T/2 8*C + x = x.view(B, -1, 8 * C) # B H/2*W/2*T/2 8*C + #print ('self.dim is ',self.dim) + #print ('error x2 before reduction resahpe ',x.shape) + x = self.reduction(x) + #print ('error x3 after reduction resahpe ',x.shape) + x = self.norm(x) + #print ('x after merge ',x.shape) #B,L,C + + return x + + +#V2 use BHWC input +class PatchMerging_V2(nn.Module): + """ Patch Merging Layer. + """ + + def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm): + """ + Args: + dim (int): Number of input channels. + out_dim (int): Number of output channels (or 2 * dim if None) + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + super().__init__() + self.dim = dim + self.out_dim = out_dim or 2 * dim + self.reduction = nn.Linear(8 * dim, self.out_dim, bias=False) + self.norm = norm_layer(self.out_dim) + + def forward(self, x): + + B, H, W, T,C = x.shape + print ('error: before patch merging size ',x.shape) + _assert(H % 2 == 0, f"x height ({H}) is not even.") + _assert(W % 2 == 0, f"x width ({W}) is not even.") + _assert(T % 2 == 0, f"x width ({T}) is not even.") + x = x.reshape(B, H // 2, 2, W // 2, 2, T//2, 2, C).permute(0, 1, 3,5, 4, 2, 6,7).flatten(4) + #print ('error x afer resahpe ',x.shape) + x = self.reduction(x) + x = self.norm(x) + #print ('error: after patch merging size ',x.shape) + B, H, W, T,C = x.shape + x=x.view(B,-1,C) + #x_out=torch.zeros(B, H, W, T,C ).cuda().half() + return x#x_out + + +#V2 use BLC input +class PatchMerging_V2_2(nn.Module): + """ Patch Merging Layer. + """ + + def __init__(self, dim, out_dim=None, norm_layer=nn.LayerNorm): + """ + Args: + dim (int): Number of input channels. + out_dim (int): Number of output channels (or 2 * dim if None) + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + super().__init__() + self.dim = dim + self.out_dim = out_dim or 2 * dim + self.reduction = nn.Linear(4* dim, self.out_dim, bias=False) + self.norm = norm_layer(self.out_dim) + + def forward(self, x): + + B, L,C = x.shape + H=round(math.pow(L,1/3.)) + W=H + T=H + # ('error: before patch merging size ',x.shape) + + x = x.reshape(B, H // 2, 2, W // 2, 2, T//2, 2, C).permute(0, 1, 3,5, 4, 2, 6,7).flatten(4) + #print ('error x afer resahpe ',x.shape) + x = self.reduction(x) + x = self.norm(x) + #print ('error: after patch merging size ',x.shape) + B, H, W, T,C = x.shape + x=x.view(B,-1,C) + #x_out=torch.zeros(B, H, W, T,C ).cuda().half() + return x#x_out + + +class BasicLayer(nn.Module): + """ A basic Swin Transformer layer for one stage. + Args: + dim (int): Number of feature channels + depth (int): Depths of this stage. + num_heads (int): Number of attention head. + window_size (int): Local window size. Default: 7. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set. + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + """ + + def __init__(self, + dim, + depth, + num_heads, + window_size=(7, 7, 7), + mlp_ratio=4., + qkv_bias=True, + qk_scale=None, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + downsample=None, + input_resolution=(1,1,1), + use_checkpoint=False, + pat_merg_rf=2,): + super().__init__() + self.window_size = window_size + self.shift_size = (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2) + self.depth = depth + self.use_checkpoint = use_checkpoint + self.pat_merg_rf = pat_merg_rf + + self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution + + # build blocks + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer,) + for i in range(depth)]) + + + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer, reduce_factor=self.pat_merg_rf) + else: + self.downsample = None + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + def forward(self, x): + """ Forward function. + Args: + x: Input feature, tensor size (B, H*W, C). + H, W: Spatial resolution of the input feature. + """ + + B,L,C=x.shape + #x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww, Wt) + #Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + #h=math.cbrt(L) + H=round(math.pow(L,1/3.)) + W=H + T=H + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + Tp = int(np.ceil(T / self.window_size[2])) * self.window_size[2] + img_mask = torch.zeros((1, Hp, Wp, Tp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = (slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None)) + w_slices = (slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None)) + t_slices = (slice(0, -self.window_size[2]), + slice(-self.window_size[2], -self.shift_size[2]), + slice(-self.shift_size[2], None)) + cnt = 0 + for h in h_slices: + for w in w_slices: + for t in t_slices: + img_mask[:, h, w, t, :] = cnt + cnt += 1 + + mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1] * self.window_size[2]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + + for blk in self.blocks: + blk.H, blk.W, blk.T = H, W, T + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + if self.downsample is not None: + #x_down = self.downsample(x, H, W, T) + x_down = self.downsample(x) + Wh, Ww, Wt = (H + 1) // 2, (W + 1) // 2, (T + 1) // 2 + return x_down#x, H, W, T, x_down, Wh, Ww, Wt + else: + return x#, H, W, T, x, H, W, T + +class SwinTransformerV2Stage(nn.Module): + """ A Swin Transformer V2 Stage. + """ + + def __init__( + self, + dim, + out_dim, + input_resolution, + depth, + num_heads, + window_size, + downsample=False, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + output_nchw=False, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + depth: Number of blocks. + num_heads: Number of attention heads. + window_size: Local window size. + downsample: Use downsample layer at start of the block. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. + pretrained_window_size: Local window size in pretraining. + output_nchw: Output tensors on NCHW format instead of NHWC. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution + self.depth = depth + self.output_nchw = output_nchw + self.grad_checkpointing = False + + + + # build blocks + qk_scale=None + window_size=(window_size,window_size,window_size) + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer,) + for i in range(depth)]) + + # patch merging / downsample layer + if downsample: + + self.downsample = PatchMerging(dim=dim, norm_layer=norm_layer) + + else: + assert dim == out_dim + self.downsample = nn.Identity() + + def forward(self, x): + # ('error: current dim is ',self.dim) + + + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + + + B,L,C=x.shape + + h=round(math.pow(L,1/3.)) + + x=blk(x) + + x = self.downsample(x) + + return x + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class SwinTransformerV2Stage_out_before_downsample(nn.Module): + """ A Swin Transformer V2 Stage. + """ + + def __init__( + self, + dim, + out_dim, + input_resolution, + depth, + num_heads, + window_size, + downsample=False, + mlp_ratio=4., + qkv_bias=True, + drop=0., + attn_drop=0., + drop_path=0., + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + output_nchw=False, + ): + """ + Args: + dim: Number of input channels. + input_resolution: Input resolution. + depth: Number of blocks. + num_heads: Number of attention heads. + window_size: Local window size. + downsample: Use downsample layer at start of the block. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop: Dropout rate + attn_drop: Attention dropout rate. + drop_path: Stochastic depth rate. + norm_layer: Normalization layer. + pretrained_window_size: Local window size in pretraining. + output_nchw: Output tensors on NCHW format instead of NHWC. + """ + super().__init__() + self.dim = dim + self.input_resolution = input_resolution + self.output_resolution = tuple(i // 2 for i in input_resolution) if downsample else input_resolution + self.depth = depth + self.output_nchw = output_nchw + self.grad_checkpointing = False + + + + + # build blocks + qk_scale=None + window_size=(window_size,window_size,window_size) + self.blocks = nn.ModuleList([ + SwinTransformerBlock( + dim=dim, + num_heads=num_heads, + window_size=window_size, + shift_size=(0, 0, 0) if (i % 2 == 0) else (window_size[0] // 2, window_size[1] // 2, window_size[2] // 2), + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop, + attn_drop=attn_drop, + drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, + norm_layer=norm_layer,) + for i in range(depth)]) + + # patch merging / downsample layer + if downsample: + #V2 use + #self.downsample = PatchMerging(dim=dim, out_dim=out_dim, norm_layer=norm_layer) + + #V1 use + self.downsample = PatchMerging(dim=dim, norm_layer=norm_layer) + + else: + assert dim == out_dim + self.downsample = nn.Identity() + + def forward(self, x): + # ('error: current dim is ',self.dim) + + + for blk in self.blocks: + if self.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint.checkpoint(blk, x) + else: + #print (' error x size before stage blk is ',x.shape) + + B,L,C=x.shape + #x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww, Wt) + #Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + #h=math.cbrt(L) + h=round(math.pow(L,1/3.)) + #x = blk(x,h,h,h) + x=blk(x) + x_bf=x + x = self.downsample(x) + + return x,x_bf + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class PatchEmbed_For_SSIM(nn.Module): + """ Image to Patch Embedding + Args: + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. + embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None): + super().__init__() + patch_size = to_3tuple(patch_size) + self.patch_size = patch_size + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W, T = x.size() + if W % self.patch_size[1] != 0: + x = nnf.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = nnf.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + if T % self.patch_size[0] != 0: + x = nnf.pad(x, (0, 0, 0, 0, 0, self.patch_size[0] - T % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww Wt + if self.norm is not None: + Wh, Ww, Wt = x.size(2), x.size(3), x.size(4) + x = x.flatten(2).transpose(1, 2) + x = self.norm(x) + x_reshape = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww, Wt) + + return x,x_reshape + + +class SwinTransformerV2_SMIT_For_FineTunning_Only(nn.Module): + """ Swin Transformer V2 with SMIT pretrained + + A PyTorch impl of : `Swin Transformer V2: Scaling Up Capacity and Resolution` + - https://arxiv.org/abs/2111.09883 + """ + + def __init__( + self, + img_size: _int_or_tuple_2_t = 128, + patch_size: int = 2, + in_chans: int = 1, + num_classes: int = 1000, + global_pool: str = 'avg', + embed_dim: int = 48, + depths: Tuple[int, ...] = (2, 2, 6, 2), + num_heads: Tuple[int, ...] = (3, 6, 12, 24), + window_size: _int_or_tuple_2_t = 5, + mlp_ratio: float = 4., + qkv_bias: bool = True, + drop_rate: float = 0., + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + norm_layer: Callable = nn.LayerNorm, + pretrained_window_sizes: Tuple[int, ...] = (0, 0, 0,0,0), + **kwargs, + ): + """ + Args: + img_size: Input image size. + patch_size: Patch size. + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + embed_dim: Patch embedding dimension. + depths: Depth of each Swin Transformer stage (layer). + num_heads: Number of attention heads in different layers. + window_size: Window size. + mlp_ratio: Ratio of mlp hidden dim to embedding dim. + qkv_bias: If True, add a learnable bias to query, key, value. + drop_rate: Dropout rate. + attn_drop_rate: Attention dropout rate. + drop_path_rate: Stochastic depth rate. + norm_layer: Normalization layer. + patch_norm: If True, add normalization after patch embedding. + pretrained_window_sizes: Pretrained window sizes of each layer. + output_fmt: Output tensor format if not None, otherwise output 'NHWC' by default. + """ + super().__init__() + + self.num_classes = num_classes + assert global_pool in ('', 'avg') + self.global_pool = global_pool + self.output_fmt = 'NHWC' + self.num_layers = len(depths) + self.embed_dim = embed_dim + self.num_features = int(embed_dim * 2 ** (self.num_layers)) + self.feature_info = [] + + + print ('error: self.num_features size is ',self.num_features) + if not isinstance(embed_dim, (tuple, list)): + + embed_dim = [int(embed_dim * 2 ** i) for i in range(self.num_layers)] + + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim[0], + norm_layer=norm_layer, + output_fmt='NHWC', + ) + + dpr = [x.tolist() for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)] + layers = [] + in_dim = embed_dim[0] + scale = 1 + + + for i in range(self.num_layers): + out_dim = embed_dim[i] + + + #V2 use + layers += [SwinTransformerV2Stage_out_before_downsample( + dim=int(embed_dim[0] * 2 ** i),#in_dim,#out_dim,#in_dim,#, + out_dim=out_dim, + input_resolution=( + self.patch_embed.grid_size[0] // scale, + self.patch_embed.grid_size[1] // scale, + self.patch_embed.grid_size[2] // scale,), + + + depth=depths[i], + downsample= True,#i > 0, + num_heads=num_heads[i], + window_size=window_size, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, attn_drop=attn_drop_rate, + drop_path=dpr[i], + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_sizes[i], + )] + in_dim = out_dim + if i > 0: + scale *= 2 + + + + self.feature_info += [dict(num_chs=out_dim, reduction=8 * scale, module=f'layers.{i}')] + + #only V2 use + self.layers = nn.Sequential(*layers) + + + + self.norm = norm_layer(self.num_features) + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + + + self.encoder_stride=32 + + print ('info patch_size size ',patch_size) + + self.hidden_size=embed_dim[-1]*2 + self.pt_size=self.encoder_stride + self.feat_size= [int(img_size/self.encoder_stride),int(img_size/self.encoder_stride),int(img_size/self.encoder_stride)] + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + @torch.jit.ignore + def no_weight_decay(self): + nod = set() + for n, m in self.named_modules(): + if any([kw in n for kw in ("cpb_mlp", "logit_scale", 'relative_position_bias_table')]): + nod.add(n) + return nod + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r'^absolute_pos_embed|patch_embed', # stem and embed + blocks=r'^layers\.(\d+)' if coarse else [ + (r'^layers\.(\d+).downsample', (0,)), + (r'^layers\.(\d+)\.\w+\.(\d+)', None), + (r'^norm', (99999,)), + ] + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for l in self.layers: + l.grad_checkpointing = enable + + @torch.jit.ignore + def get_classifier(self): + return self.head.fc + + def reset_classifier(self, num_classes, global_pool=None): + self.num_classes = num_classes + self.head.reset(num_classes, global_pool) + + def proj_feat(self, x, hidden_size, feat_size): + + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + + def forward_features(self, x_in): + B, nc, w, h, t = x_in.shape + #print ('info: x_in size ',x_in.shape) # 2,1,96,96,96 + x_reshape = self.patch_embed(x_in) + + x_out=[] + x=x_reshape + for layer in self.layers: + x,x_bf = layer(x) + B,L_,C_=x_bf.shape + + sq_3=round(math.pow(L_,1/3.)) + xout=x_bf.view(B, sq_3, sq_3, sq_3, -1).permute(0, 4, 1, 2, 3).contiguous() + + #print ('x shape ',xout.shape) + x_out.append(xout) + + + x = self.norm(x) + return x,x_out + + def forward_head(self, x, pre_logits: bool = False): + return self.head(x, pre_logits=True) if pre_logits else self.head(x) + + def forward(self, x_in): + x,x_feature = self.forward_features(x_in) + + + + + + return x,x_feature + + +def checkpoint_filter_fn(state_dict, model): + state_dict = state_dict.get('model', state_dict) + state_dict = state_dict.get('state_dict', state_dict) + if 'head.fc.weight' in state_dict: + return state_dict + out_dict = {} + import re + for k, v in state_dict.items(): + if any([n in k for n in ('relative_position_index', 'relative_coords_table')]): + continue # skip buffers that should not be persistent + k = re.sub(r'layers.(\d+).downsample', lambda x: f'layers.{int(x.group(1)) + 1}.downsample', k) + k = k.replace('head.', 'head.fc.') + out_dict[k] = v + return out_dict + + +# This is SMIT plus with feature 96 channel +class SMIT_Plus_feature_96(nn.Module): + def __init__( + self, + #config, + out_channels: int , + feature_size: int = 96, + hidden_size: int = 1536, + mlp_dim: int = 3072, + img_size: int = 128, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "batch", + conv_block: bool = False, + res_block: bool = True, + spatial_dims: int = 3, + in_channels: int=1, + #out_channels: int, + ) -> None: + ''' + TransMorph Model + ''' + + #super(TransMorph_Unetr, self).__init__() + super().__init__() + self.hidden_size = hidden_size + self.feat_size=(img_size//32,img_size//32,img_size//32) + self.ft_size=feature_size + embed_dim = 96#config.embed_dim + self.transformer = SwinTransformerV2_SMIT_For_FineTunning_Only(img_size=128,window_size=4, embed_dim=self.ft_size,patch_size=2, depths=(2,2,40,4), num_heads=(4,8,16,32),qkv_bias=True) + + + + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + + x, out_feats = self.transformer(x_in) + + + + enc44 = out_feats[-1] # torch.Size([4, 384, 8, 8, 8]) + enc33 = out_feats[-2] # torch.Size([4, 192, 16, 16, 16]) + enc22 = out_feats[-3] # torch.Size([4, 96, 32, 32, 32]) + enc11 = out_feats[-4] # torch.Size([4, 48, 64, 64, 64]) + x=self.proj_feat(x, self.hidden_size, self.feat_size) # torch.Size([4, 768, 4, 4, 4]) + + + + + enc0 = self.encoder1(x_in) + + enc1 = self.encoder2(enc11) #input size torch.Size([4, 96, 64, 64, 64]) + + enc2 = self.encoder3(enc22) #input size torch.Size([4, 192, 32, 32, 32]) + + enc3 = self.encoder4(enc33) #torch.Size([4, 384, 16, 16, 16]) + + + dec4 = self.encoder10(x) + + dec3 = self.decoder5(dec4, enc44) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + + + + return logits + + +# feature customized of SMIT Plus +class SMIT_Plus(nn.Module): + def __init__( + self, + #config, + out_channels: int , + feature_size: int = 96, + hidden_size: int = 1536, + mlp_dim: int = 3072, + img_size: int = 128, + num_heads: int = 12, + pos_embed: str = "perceptron", + norm_name: Union[Tuple, str] = "batch", + conv_block: bool = False, + res_block: bool = True, + spatial_dims: int = 3, + in_channels: int=1, + testing: bool = False, + #out_channels: int, + ) -> None: + + self.ft_size=feature_size + #self. + super().__init__() + self.hidden_size =self.ft_size*16 #hidden_size + self.feat_size=(img_size//32,img_size//32,img_size//32) + + + + # Default one use 40 depth + if testing: + self.transformer = SwinTransformerV2_SMIT_For_FineTunning_Only(img_size=128, + in_chans=in_channels, + window_size=4, + embed_dim=self.ft_size, + patch_size=2, + depths=(2, 2, 2, 2), + num_heads=(4, 8, 16, 32), + qkv_bias=True) + else: + self.transformer = SwinTransformerV2_SMIT_For_FineTunning_Only(img_size=128, + in_chans=in_channels, + window_size=4, + embed_dim=self.ft_size,patch_size=2, + depths=(2,2,40,4), + num_heads=(4,8,16,32), + qkv_bias=True) + + self.encoder1 = UnetrBasicBlock( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder2 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder3 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=2 * feature_size, + out_channels=2 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder4 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=4 * feature_size, + out_channels=4 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.encoder10 = UnetrBasicBlock_No_DownSampling( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=16 * feature_size, + kernel_size=3, + stride=1, + norm_name=norm_name, + res_block=True, + ) + + self.decoder5 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=16 * feature_size, + out_channels=8 * feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder4 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 8, + out_channels=feature_size * 4, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder3 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 4, + out_channels=feature_size * 2, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + self.decoder2 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size * 2, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.decoder1 = UnetrUpBlock( + spatial_dims=spatial_dims, + in_channels=feature_size, + out_channels=feature_size, + kernel_size=3, + upsample_kernel_size=2, + norm_name=norm_name, + res_block=True, + ) + + self.out = UnetOutBlock( + spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels + ) # type: ignore + + def proj_feat(self, x, hidden_size, feat_size): + x = x.view(x.size(0), feat_size[0], feat_size[1], feat_size[2], hidden_size) + x = x.permute(0, 4, 1, 2, 3).contiguous() + return x + + def forward(self, x_in): + + x, out_feats = self.transformer(x_in) + + + + enc44 = out_feats[-1] # torch.Size([4, 384, 8, 8, 8]) + enc33 = out_feats[-2] # torch.Size([4, 192, 16, 16, 16]) + enc22 = out_feats[-3] # torch.Size([4, 96, 32, 32, 32]) + enc11 = out_feats[-4] # torch.Size([4, 48, 64, 64, 64]) + x=self.proj_feat(x, self.hidden_size, self.feat_size) # torch.Size([4, 768, 4, 4, 4]) + + + + + enc0 = self.encoder1(x_in) + + enc1 = self.encoder2(enc11) #input size torch.Size([4, 96, 64, 64, 64]) + + enc2 = self.encoder3(enc22) #input size torch.Size([4, 192, 32, 32, 32]) + + enc3 = self.encoder4(enc33) #torch.Size([4, 384, 16, 16, 16]) + + + dec4 = self.encoder10(x) + + dec3 = self.decoder5(dec4, enc44) + dec2 = self.decoder4(dec3, enc3) + dec1 = self.decoder3(dec2, enc2) + dec0 = self.decoder2(dec1, enc1) + out = self.decoder1(dec0, enc0) + logits = self.out(out) + + + + return logits \ No newline at end of file From 8aa99897de2bd9135b79e9de4b29787be1be3062 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 12:46:59 -0400 Subject: [PATCH 09/20] Update Dockerfile --- models/msk_smit_lung_gtv/dockerfiles/Dockerfile | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile index ff18a288..cb58af28 100644 --- a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile +++ b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile @@ -4,13 +4,19 @@ FROM mhubai/base:latest LABEL authors="aptea@mskcc.org,deasyj@mskcc.org,iyera@mskcc.org,locastre@mskcc.org" RUN apt update -RUN mkdir -p /app/models/msk_smit_lung_gtv -RUN cd /app/models/msk_smit_lung_gtv && git clone https://github.com/cerr/model_installer.git && cd model_installer && source installer.sh -m 4 -d /app/models/msk_smit_lung_gtv -p C -RUN chmod -R 755 /app/models/msk_smit_lung_gtv/CT_Lung_SMIT -ENV PYTHONPATH="/app:/app/models/msk_smit_lung_gtv/CT_Lung_SMIT/conda-pack" +#ARG MHUB_MODELS_REPO= +ENV MHUB_MODELS_REPO=https://github.com/locastre/models.git +RUN buildutils/import_mhub_model.sh msk_smit_lung_gtv ${MHUB_MODELS_REPO} + +ENV WORK_DIR=/app/models/msk_smit_lung_gtv/src + +WORKDIR ${WORK_DIR} +ENV WEIGHTS_URL=https://mskcc.box.com/shared/static/sf7jic4m2dk67413cipbbq6hddvhpj61.gz +ENV CONDA_URL=https://mskcc.box.com/shared/static/d580gfjzzmt26v8klwp8pivb6wafomag.gz +RUN wget ${WEIGHTS_URL} -O weights.tar.gz && tar xvf weights.tar.gz && rm weights.tar.gz +RUN wget ${CONDA_URL} -O conda.tar.gz && tar xvf conda.tar.gz && rm conda.tar.gz -RUN source /app/models/msk_smit_lung_gtv/CT_Lung_SMIT/conda-pack/bin/activate ENTRYPOINT ["mhub.run"] CMD ["--config", "/app/models/msk_smit_lung_gtv/config/default.yml"] From ec9993033e78bd33161cef9eea26da8aca5a34cc Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 13:52:18 -0400 Subject: [PATCH 10/20] mkdir for conda-pack --- models/msk_smit_lung_gtv/dockerfiles/Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile index cb58af28..f5e2c918 100644 --- a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile +++ b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile @@ -15,7 +15,8 @@ WORKDIR ${WORK_DIR} ENV WEIGHTS_URL=https://mskcc.box.com/shared/static/sf7jic4m2dk67413cipbbq6hddvhpj61.gz ENV CONDA_URL=https://mskcc.box.com/shared/static/d580gfjzzmt26v8klwp8pivb6wafomag.gz RUN wget ${WEIGHTS_URL} -O weights.tar.gz && tar xvf weights.tar.gz && rm weights.tar.gz -RUN wget ${CONDA_URL} -O conda.tar.gz && tar xvf conda.tar.gz && rm conda.tar.gz +RUN mkdir conda-pack && chmod -R 777 conda-pack +RUN cd conda-pack && wget ${CONDA_URL} -O conda.tar.gz && tar xvf conda.tar.gz && rm conda.tar.gz ENTRYPOINT ["mhub.run"] From 483c92a4078cd7ef802463e9bd22f49cc6fbebfc Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 14:11:48 -0400 Subject: [PATCH 11/20] add default.yml --- models/msk_smit_lung_gtv/config/default.yml | 34 +++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 models/msk_smit_lung_gtv/config/default.yml diff --git a/models/msk_smit_lung_gtv/config/default.yml b/models/msk_smit_lung_gtv/config/default.yml new file mode 100644 index 00000000..58fb19b7 --- /dev/null +++ b/models/msk_smit_lung_gtv/config/default.yml @@ -0,0 +1,34 @@ +general: + data_base_dir: /app/data + version: 1.0.0 + description: Default configuration for SMIT model (dicom to dicom) + +execute: +- DicomImporter +- NiftiConverter +- SMITRunner +- DsegConverter +- DataOrganizer + +modules: + DicomImporter: + source_dir: input_data + import_dir: sorted_data + sort_data: true + meta: + mod: '%Modality' + + SMITRunner: + a_min: -500 + a_max: 500 + # Can add other config paremeters here + + DsegConverter: + model_name: SMIT + body_part_examined: CHEST + source_segs: nifti:mod=seg + skip_empty_slices: true + + DataOrganizer: + targets: + - dicomseg:mod=seg-->[i:sid]/smit.seg.dcm% From 5aa95319faba9672af20c696ed2c2db5683ae429 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 14:28:22 -0400 Subject: [PATCH 12/20] add config.json --- models/msk_smit_lung_gtv/config.json | 103 +++++++++++++++++++++++++++ 1 file changed, 103 insertions(+) create mode 100644 models/msk_smit_lung_gtv/config.json diff --git a/models/msk_smit_lung_gtv/config.json b/models/msk_smit_lung_gtv/config.json new file mode 100644 index 00000000..3c2c015e --- /dev/null +++ b/models/msk_smit_lung_gtv/config.json @@ -0,0 +1,103 @@ +{ + "id": "", + "name": "msk_smit_lung_gtv_seg", + "title": "CT Lung GTV SMIT Segmentation", + "summary": { + "description": "GTV segmentation from CT scan", + "inputs": [ + { + "label": "Input Image", + "description": "The CT scan of a patient.", + "format": "NIFTI", + "modality": "CT", + "bodypartexamined": "Chest", + "slicethickness": "5mm", + "contrast": true, + "noncontrast": true + } + ], + "outputs": [ + { + "label": "Segmentation of the lung GTV", + "description": "Segmentation of the lung GTV from NIfTI CT images.", + "type": "Segmentation", + "classes": [ + "GTV" + ] + } + ], + "model": { + "architecture": "Swin Transformer based segmentation, self-supervised pretrained with 10k CT data", + "training": "supervised", + "cmpapproach": "3D" + }, + "data": { + "training": { + "vol_samples": 377 + }, + "evaluation": { + "vol_samples": 139 + }, + "public": true, + "external": false + } + }, + "details": { + "name": "SMIT", + "version": "1.0.0", + "devteam": "", + "authors": ["Jue Jiang, Harini Veeraraghavan"], + "type": "it is a 3D Swin transformer based segmentation net", + "date": { + "code": "11.03.2025", + "weights": "11.03.2025", + "pub": "15.July.2024" + }, + "cite": "Jiang, Jue, and Harini Veeraraghavan. Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation. Proceedings of machine learning research 250 (2024): 708.", + "license": { + "code": "GNU General Public License", + "weights": "GNU General Public License" + }, + "publications": [ + { + "title": "Self-supervised pretraining in the wild imparts image acquisition robustness to medical image transformers: an application to lung cancer segmentation", + "url": "https://openreview.net/pdf?id=G9Te2IevNm" + }, + { + "title":"Self-supervised 3D anatomy segmentation using self-distilled masked image transformer (SMIT)", + "url":"https://link.springer.com/chapter/10.1007/978-3-031-16440-8_53" + } + ], + "github": "https://github.com/The-Veeraraghavan-Lab/CTRobust_Transformers.git" + }, + "info": { + "use": { + "title": "Intended use", + "text": "This model is intended to be used on CT images (with or without contrast)" + "references": [], + "tables": [] + + }, + "evaluation": { + "title": "Evaluation data", + "text": "To assess the model's segmentation performance in the NSCLC Radiogenomics dataset, we considered that the original input data is a full 3D volume. The model segmented not only the labeled tumor but also tumors that were not manually annotated. Therefore, we evaluated the model based on the manually labeled tumors. After applying the segmentation model, we extracted a 128*128*128 cubic region containing the manual segmentation to assess the model’s performance.", + "references": [], + "tables": ["validation_data_id and DSC value":{ + "Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/" + "AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598, + R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136, + R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088" + }], + "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" + }, + "training": { + "title": "Training data", + "text": "Training data was from 377 data in the TCIA NSCLC-Radiomics data, references: Aerts, H. J. W. L., Wee, L., Rios Velazquez, E., Leijenaar, R. T. H., Parmar, C., Grossmann, P., Carvalho, S., Bussink, J., Monshouwer, R., Haibe-Kains, B., Rietveld, D., Hoebers, F., Rietbergen, M. M., Leemans, C. R., Dekker, A., Quackenbush, J., Gillies, R. J., Lambin, P. (2014). Data From NSCLC-Radiomics (version 4) [Data set]. The Cancer Imaging Archive." + + }, + "analyses": { + "title": "Quantitative Analyses", + "text": "DSC was used to compute the accuracy of the model" + }, + + } From 874da7e29e920f765cc13338e4563ae9dc1cd827 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Sat, 15 Mar 2025 14:28:50 -0400 Subject: [PATCH 13/20] Rename config.json to meta.json --- models/msk_smit_lung_gtv/{config.json => meta.json} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename models/msk_smit_lung_gtv/{config.json => meta.json} (100%) diff --git a/models/msk_smit_lung_gtv/config.json b/models/msk_smit_lung_gtv/meta.json similarity index 100% rename from models/msk_smit_lung_gtv/config.json rename to models/msk_smit_lung_gtv/meta.json From 5ccb70c91c4a0ae1ed2bf79bb9b8f197f18dc9d5 Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:21:57 -0400 Subject: [PATCH 14/20] Update meta.json --- models/msk_smit_lung_gtv/meta.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/msk_smit_lung_gtv/meta.json b/models/msk_smit_lung_gtv/meta.json index 3c2c015e..6d5d0023 100644 --- a/models/msk_smit_lung_gtv/meta.json +++ b/models/msk_smit_lung_gtv/meta.json @@ -1,6 +1,6 @@ { "id": "", - "name": "msk_smit_lung_gtv_seg", + "name": "msk_smit_lung_gtv", "title": "CT Lung GTV SMIT Segmentation", "summary": { "description": "GTV segmentation from CT scan", From 20b0960d3454bbedda673576b4270524f8bd46be Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:51:14 -0400 Subject: [PATCH 15/20] Update meta.json --- models/msk_smit_lung_gtv/meta.json | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/models/msk_smit_lung_gtv/meta.json b/models/msk_smit_lung_gtv/meta.json index 6d5d0023..746e3b59 100644 --- a/models/msk_smit_lung_gtv/meta.json +++ b/models/msk_smit_lung_gtv/meta.json @@ -1,6 +1,6 @@ { "id": "", - "name": "msk_smit_lung_gtv", + "name": "msk_smit_lung_gtv_seg", "title": "CT Lung GTV SMIT Segmentation", "summary": { "description": "GTV segmentation from CT scan", @@ -73,7 +73,7 @@ "info": { "use": { "title": "Intended use", - "text": "This model is intended to be used on CT images (with or without contrast)" + "text": "This model is intended to be used on CT images (with or without contrast)", "references": [], "tables": [] @@ -82,12 +82,8 @@ "title": "Evaluation data", "text": "To assess the model's segmentation performance in the NSCLC Radiogenomics dataset, we considered that the original input data is a full 3D volume. The model segmented not only the labeled tumor but also tumors that were not manually annotated. Therefore, we evaluated the model based on the manually labeled tumors. After applying the segmentation model, we extracted a 128*128*128 cubic region containing the manual segmentation to assess the model’s performance.", "references": [], - "tables": ["validation_data_id and DSC value":{ - "Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/" - "AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598, - R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136, - R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088" - }], + "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088" + ], "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" }, "training": { @@ -95,9 +91,10 @@ "text": "Training data was from 377 data in the TCIA NSCLC-Radiomics data, references: Aerts, H. J. W. L., Wee, L., Rios Velazquez, E., Leijenaar, R. T. H., Parmar, C., Grossmann, P., Carvalho, S., Bussink, J., Monshouwer, R., Haibe-Kains, B., Rietveld, D., Hoebers, F., Rietbergen, M. M., Leemans, C. R., Dekker, A., Quackenbush, J., Gillies, R. J., Lambin, P. (2014). Data From NSCLC-Radiomics (version 4) [Data set]. The Cancer Imaging Archive." }, - "analyses": { - "title": "Quantitative Analyses", - "text": "DSC was used to compute the accuracy of the model" - }, - + "analyses": { + "title": "Quantitative Analyses", + "text": "DSC was used to compute the accuracy of the model" + }, + "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" } +} From d5733f636126c412a7ac2bba9b2fc29bda01e5ba Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 10:52:07 -0400 Subject: [PATCH 16/20] Update meta.json --- models/msk_smit_lung_gtv/meta.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/msk_smit_lung_gtv/meta.json b/models/msk_smit_lung_gtv/meta.json index 746e3b59..7d912891 100644 --- a/models/msk_smit_lung_gtv/meta.json +++ b/models/msk_smit_lung_gtv/meta.json @@ -82,8 +82,7 @@ "title": "Evaluation data", "text": "To assess the model's segmentation performance in the NSCLC Radiogenomics dataset, we considered that the original input data is a full 3D volume. The model segmented not only the labeled tumor but also tumors that were not manually annotated. Therefore, we evaluated the model based on the manually labeled tumors. After applying the segmentation model, we extracted a 128*128*128 cubic region containing the manual segmentation to assess the model’s performance.", "references": [], - "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088" - ], + "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088"], "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" }, "training": { From d346ec231999f86c7a6ce5f17e6dd3ff027d57eb Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:15:40 -0400 Subject: [PATCH 17/20] Update meta.json --- models/msk_smit_lung_gtv/meta.json | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/models/msk_smit_lung_gtv/meta.json b/models/msk_smit_lung_gtv/meta.json index 7d912891..1fcbb9fc 100644 --- a/models/msk_smit_lung_gtv/meta.json +++ b/models/msk_smit_lung_gtv/meta.json @@ -1,6 +1,6 @@ { "id": "", - "name": "msk_smit_lung_gtv_seg", + "name": "msk_smit_lung_gtv", "title": "CT Lung GTV SMIT Segmentation", "summary": { "description": "GTV segmentation from CT scan", @@ -82,7 +82,8 @@ "title": "Evaluation data", "text": "To assess the model's segmentation performance in the NSCLC Radiogenomics dataset, we considered that the original input data is a full 3D volume. The model segmented not only the labeled tumor but also tumors that were not manually annotated. Therefore, we evaluated the model based on the manually labeled tumors. After applying the segmentation model, we extracted a 128*128*128 cubic region containing the manual segmentation to assess the model’s performance.", "references": [], - "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088"], + "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088" + ], "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" }, "training": { @@ -94,6 +95,9 @@ "title": "Quantitative Analyses", "text": "DSC was used to compute the accuracy of the model" }, - "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" + "limitations": { + "title": "Limitations", + "text": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" + } } } From 28a2adf6d14d8c1ef9772dc0756c57089faaffca Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:16:02 -0400 Subject: [PATCH 18/20] Update meta.json --- models/msk_smit_lung_gtv/meta.json | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/models/msk_smit_lung_gtv/meta.json b/models/msk_smit_lung_gtv/meta.json index 1fcbb9fc..a334d497 100644 --- a/models/msk_smit_lung_gtv/meta.json +++ b/models/msk_smit_lung_gtv/meta.json @@ -82,8 +82,7 @@ "title": "Evaluation data", "text": "To assess the model's segmentation performance in the NSCLC Radiogenomics dataset, we considered that the original input data is a full 3D volume. The model segmented not only the labeled tumor but also tumors that were not manually annotated. Therefore, we evaluated the model based on the manually labeled tumors. After applying the segmentation model, we extracted a 128*128*128 cubic region containing the manual segmentation to assess the model’s performance.", "references": [], - "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088" - ], + "tables": ["validation_data_id and DSC value, Validation data is 139 data in the NSCLC Radiogenomics data:https://www.cancerimagingarchive.net/collection/nsclc-radiogenomics/, AMC-001:0.023977216,AMC-005:0.84385232,AMC-006:0.844950109,AMC-011:0.885911774,AMC-013:0.786724403,AMC-014:0.628335342,AMC-016:0.708633094,AMC-019:0.791600435,AMC-020:0.882119609,AMC-021:0.834135707,AMC-022:0.688767807,AMC-026:0.801595536,R01-001:0.738330143,R01-002:0.826459454,R01-003:0.724166437,R01-004:0.643794147,R01-005:0.8740986,R01-006:0.816578249,R01-007:0.736460458,R01-008:0.570397112,R01-010:0.901700554,R01-011:0.836905321,R01-012:0.26011073,R01-013:0.760693274,R01-014:0.605606001,R01-015:0.921568729,R01-016:0.748842593,R01-018:0.899090049,R01-019:0.777296896,R01-020:0.858735841,R01-021:0.674536904,R01-022:0.773468955,R01-023:0.851143174,R01-024:0.63791364,R01-025:0.667036976,R01-026:0.867828559,R01-027:0.849266954,R01-028:0.914362163,R01-029:0.796479193,R01-030:0.742501087,R01-031:0.771934798,R01-032:0.546395241,R01-033:0.668465959,R01-034:0.491623711,R01-035:0.861957664,R01-036:0.834929738,R01-039:0.640360767,R01-040:0.843040538,R01-041:0.255910987,R01-042:0.827863856,R01-043:0.358487119,R01-045:0.556983182,R01-046:0.798674399,R01-047:0.875100294,R01-048:0.86953796,R01-049:0.831395349,R01-050:0.736791014,R01-051:0.863763708,R01-052:0.853056081,R01-054:0.890185037,R01-055:0.721171698,R01-056:0.646278311,R01-057:0.819531018,R01-060:0.755168662,R01-061:0.831325301,R01-062:0.621616202,R01-063:0.887817849,R01-064:0.503693754,R01-065:0.900957261,R01-066:0.863084304,R01-067:0.793478908,R01-068:0.706467662,R01-069:0.652887756,R01-070:0.156561781,R01-071:0.794301598,R01-072:0.71873941,R01-073:0.656626506,R01-074:0.686797136,R01-075:0.769153952,R01-076:0.658746901,R01-077:0.515673556,R01-078:0.805609871,R01-079:0.768960982,R01-080:0.465984655,R01-082:0.764202063,R01-083:0.420652174,R01-084:0.679731288,R01-085:0.768992248,R01-086:0.493431042,R01-087:0.488001239,R01-088:0.593974567,R01-089:0.933253651,R01-090:0.891955114,R01-091:0.726296959,R01-092:0.557369092,R01-093:0.827921054,R01-094:0.809129332,R01-095:0.713630679,R01-096:0.728150443,R01-097:0.445709849,R01-099:0.786909219,R01-101:0.826549971,R01-103:0.818544249,R01-105:0.800283429,R01-107:0.77209806,R01-109:0.526077667,R01-110:0.497560976,R01-111:0.511410894,R01-112:0.907062065,R01-113:0.44661508,R01-114:0.902224058,R01-115:0.78721174,R01-116:0.561519405,R01-117:0.570513745,R01-118:0.594700407,R01-119:0.61825917,R01-121:0.839111393,R01-122:0.519057377,R01-124:0.594308036,R01-125:0.734829593,R01-126:0.426915017,R01-127:0.191945712,R01-128:0.781319407,R01-129:0.538877476,R01-131:0.544844598,R01-132:0.557804821,R01-133:0.491422557,R01-134:0.431908166,R01-135:0.554446119,R01-136:0.407775136,R01-137:0.248216534,R01-138:0.835014493,R01-139:0.680349407,R01-140:0.858731552,R01-141:0.081384615,R01-142:0.703421009,R01-144:0.657289694,R01-145:0.787378659,R01-146:0.850732088"], "limitations": "The model might produce minor false positives but this could be easilily removed by post-processing such as constrain the tumor segmentation only in lung slices" }, "training": { From c20d4890a6e7af233822b1c9db28f4f6190e7fdb Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:17:04 -0400 Subject: [PATCH 19/20] Update Dockerfile --- models/msk_smit_lung_gtv/dockerfiles/Dockerfile | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile index f5e2c918..be88c57f 100644 --- a/models/msk_smit_lung_gtv/dockerfiles/Dockerfile +++ b/models/msk_smit_lung_gtv/dockerfiles/Dockerfile @@ -5,13 +5,14 @@ LABEL authors="aptea@mskcc.org,deasyj@mskcc.org,iyera@mskcc.org,locastre@mskcc.o RUN apt update -#ARG MHUB_MODELS_REPO= -ENV MHUB_MODELS_REPO=https://github.com/locastre/models.git +ARG MHUB_MODELS_REPO +#ENV MHUB_MODELS_REPO=https://github.com/locastre/models.git RUN buildutils/import_mhub_model.sh msk_smit_lung_gtv ${MHUB_MODELS_REPO} -ENV WORK_DIR=/app/models/msk_smit_lung_gtv/src +#ENV WORK_DIR=/app/models/msk_smit_lung_gtv/src +ENV WORK_DIR=/app -WORKDIR ${WORK_DIR} +WORKDIR ${WORK_DIR}/msk_smit_lung_gtv/src ENV WEIGHTS_URL=https://mskcc.box.com/shared/static/sf7jic4m2dk67413cipbbq6hddvhpj61.gz ENV CONDA_URL=https://mskcc.box.com/shared/static/d580gfjzzmt26v8klwp8pivb6wafomag.gz RUN wget ${WEIGHTS_URL} -O weights.tar.gz && tar xvf weights.tar.gz && rm weights.tar.gz From 523ea46f45ac8f2961c4d6375553f9705fc3072b Mon Sep 17 00:00:00 2001 From: locastre <34068528+locastre@users.noreply.github.com> Date: Mon, 17 Mar 2025 11:19:05 -0400 Subject: [PATCH 20/20] Update SMITrunner.py --- models/msk_smit_lung_gtv/utils/SMITrunner.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/msk_smit_lung_gtv/utils/SMITrunner.py b/models/msk_smit_lung_gtv/utils/SMITrunner.py index 8acb6eaf..531a26ac 100644 --- a/models/msk_smit_lung_gtv/utils/SMITrunner.py +++ b/models/msk_smit_lung_gtv/utils/SMITrunner.py @@ -27,7 +27,7 @@ class SMITRunner(ModelRunner): @IO.Output('gtv_mask', 'gtv_mask.nii.gz', 'nifti:mod=seg:model=SMIT:roi=GTV',data='scan', the='predicted lung gtv') def task(self, instance: Instance, scan: InstanceData, gtv_mask: InstanceData) -> None: - workDir = os.environ['WORK_DIR'] # Needs to be defined in docker file as ENV WORK_DIR=path_to_dir e.g. /app/models/SMIT/workDir + workDir = os.path.join(os.environ['WORK_DIR'],'models','msk_smit_lung_gtv','src') # Needs to be defined in docker file as ENV WORK_DIR=path_to_dir e.g. /app/models/SMIT/workDir #wrapperInstallDir = os.path.join(workDir,'CT_Lung_SMIT') #condaEnvDir = os.path.join(wrapperInstallDir,'conda-pack') #condaEnvActivateScript = os.path.join(condaEnvDir, 'bin', 'activate')