
Add CropOrPadAtCenter transform class #1233

Open
themantalope opened this issue Nov 5, 2024 · 4 comments
Labels
enhancement New feature or request

Comments

@themantalope

🚀 Feature
A crop or pad function that allows the user to crop/pad an image but specify where the center of the new image should be based on the input image. For example, let's say I have an organ centered at location [45,62,101] in a volume that is [368, 512, 128] in size and I want my new image to be [128,128,128] in size, and the center of the new image (i.e. [64,64,64]) to map to the point [45,62,101] in the original image.
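The requested mapping comes down to simple index arithmetic; a minimal sketch in plain Python (no torchio needed) using the numbers from the example above:

```python
# Sketch of the index arithmetic behind the request: map the target-volume
# center onto a chosen point in the source volume and see where the target
# window lands. Negative starts or stops past the source size mean padding
# is needed before cropping.
center = (45, 62, 101)          # desired center in the source image
source_shape = (368, 512, 128)  # source volume size
target_shape = (128, 128, 128)  # desired output size

bounds = []
for c, s, t in zip(center, source_shape, target_shape):
    start = c - t // 2          # negative -> pad below this axis
    stop = start + t            # greater than s -> pad above this axis
    bounds.append((start, stop))

print(bounds)  # [(-19, 109), (-2, 126), (37, 165)]
```

Here the first two axes need padding below (negative starts) and the last axis needs padding above (165 > 128).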

Motivation

I'm developing datasets with tumors and need to crop and pad around them.

Pitch

Described above. Proposed code below.

Alternatives

N/A

Additional context
Here is my proposed code. I've tested it locally, but it probably needs more robust testing.

import torchio as tio


class CropOrPadAtCenter(tio.Transform):

    def __init__(self, center, target_shape, image_or_mask_name=None, **kwargs):
        super().__init__(**kwargs)
        self.target_shape = target_shape
        self.center = center
        self.image_or_mask_name = image_or_mask_name

    def apply_transform(self, subject):
        # If no image or mask name is given, use the first image in the
        # subject. Keep it in a local variable so the transform stays
        # stateless and can be reused across subjects.
        name = self.image_or_mask_name
        if name is None:
            name = list(subject.keys())[0]

        image = subject[name]
        non_channel_im_shape = image.shape[-3:]

        # The requested center must lie inside the image
        assert all(
            0 <= c < s for c, s in zip(self.center, non_channel_im_shape)
        ), f'Center {self.center} is outside image of shape {non_channel_im_shape}'

        # Compute how many voxels to pad in each dimension: if the bounds of
        # the target shape fall outside the image, padding is needed.
        # Splitting t into t // 2 and t - t // 2 keeps odd target sizes exact.
        pad = []
        for c, s, t in zip(self.center, non_channel_im_shape, self.target_shape):
            lower = max(0, t // 2 - c)
            upper = max(0, c + (t - t // 2) - s)
            pad.extend([lower, upper])
        subject = tio.Pad(tuple(pad))(subject)

        # Now crop. tio.Crop expects the number of voxels to remove from each
        # side of each axis, i.e. (start, dim_size - end) per dimension.
        non_channel_im_shape = subject[name].shape[-3:]
        lower_bound_pads = pad[::2]
        center = [c + l for c, l in zip(self.center, lower_bound_pads)]
        crop = []
        for c, s, t in zip(center, non_channel_im_shape, self.target_shape):
            start = c - t // 2
            end = s - (start + t)
            crop.extend([start, end])
        subject = tio.Crop(tuple(crop))(subject)
        return subject
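For reference, a quick pure-Python sanity check of the pad/crop amounts the class should produce for the example in the feature request; `pad_and_crop_amounts` is a hypothetical standalone helper that mirrors the arithmetic in apply_transform, not part of torchio:

```python
# Pure-Python check of the pad/crop arithmetic (no torchio needed).
# pad_and_crop_amounts is a hypothetical helper mirroring apply_transform.
def pad_and_crop_amounts(center, shape, target):
    pad, crop = [], []
    for c, s, t in zip(center, shape, target):
        lo = max(0, t // 2 - c)                 # padding below this axis
        hi = max(0, c + (t - t // 2) - s)       # padding above this axis
        pad.extend([lo, hi])
        c_p, s_p = c + lo, s + lo + hi          # center and size after padding
        start = c_p - t // 2                    # voxels cropped from the start
        end = s_p - (start + t)                 # voxels cropped from the end
        crop.extend([start, end])
    return tuple(pad), tuple(crop)

pad, crop = pad_and_crop_amounts((45, 62, 101), (368, 512, 128), (128, 128, 128))
print(pad)   # (19, 0, 2, 0, 0, 37)
print(crop)  # (0, 259, 0, 386, 37, 0)
```

On each axis the padded size minus the two crop amounts comes out to 128, as intended.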

        
@themantalope themantalope added the enhancement New feature or request label Nov 5, 2024
@romainVala
Contributor

Hi,
thanks for sharing.
Actually, the same behavior can be achieved with torchio's LabelSampler.

Here is a code example:

import torch
import torchio as tio

center_patch = [132, 100, 170]
patch_size = 128
around_center = 0

suj = tio.datasets.Colin27()
label_proba = torch.zeros_like(suj.t1.data)
i, j, k = center_patch
a = around_center
label_proba[0, i - a:i + a + 1, j - a:j + a + 1, k - a:k + a + 1] = 1

img = tio.LabelMap(tensor=label_proba, affine=suj.t1.affine)
suj.add_image(img, 'label_proba')

# Pad to handle the worst-case scenario where the center is at the border of
# the image (so extend with patch_size // 2 + around_center // 2)
t_pad = tio.Pad(patch_size // 2 + around_center // 2)
suj = t_pad(suj)

lab_s = tio.LabelSampler(128, 'label_proba', {0: 0, 1: 1})
generator = lab_s(suj, num_patches=10)

for patch in generator:
    locations = patch[tio.LOCATION]
    print(f'loc {locations}')
# Note that the locations are 3 start indices and 3 end indices in the
# *padded* subject's index space (not the patch center), so they differ from
# the original center_patch because the volume was padded, but the result is
# as expected.
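Since the subject is padded before sampling, the reported locations are in padded-volume indices; a small sketch (plain Python) of mapping a sampled location back to original-image coordinates. The `locations` values below are illustrative, chosen to match what the sampler would return for this center:

```python
# Recover original-image coordinates from a sampled patch location.
# `locations` stands in for the 6-vector from patch[tio.LOCATION] above:
# (i_start, j_start, k_start, i_end, j_end, k_end) in padded indices.
patch_size = 128
around_center = 0
pad = patch_size // 2 + around_center // 2   # same padding as applied above

locations = [132, 100, 170, 260, 228, 298]   # illustrative sampled location
starts, ends = locations[:3], locations[3:]
center_in_original = [(a + b) // 2 - pad for a, b in zip(starts, ends)]
print(center_in_original)  # [132, 100, 170] -> matches center_patch
```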

Will this suit your needs?

@themantalope
Author

I see, yes, that would work for what I'm looking for. I may need some additional code to ensure that I'm getting all voxels in the sampled image, but this is basically what I want.

@romainVala
Contributor

What do you mean by all voxels? You will only get voxels that are less than patch_size/2 away from the center.

@themantalope
Author

To ensure that the sampled patch contains all positive voxels in the label.

I guess the difference here is that I'm specifying the center, so there is more control over where the sample is drawn from, which is important for my use case.
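A possible version of that check, sketched in plain Python; `lesion_voxels` and `patch_start` are illustrative values, not actual sampler output:

```python
# Sketch: check that a sampled patch covers every positive voxel of a label.
# In practice the voxel coordinates would come from the label's nonzero
# indices and the patch origin from patch[tio.LOCATION].
lesion_voxels = [(5, 6, 7), (8, 9, 10), (6, 7, 8)]  # illustrative positives
patch_start = (4, 5, 6)                              # hypothetical patch origin
patch_size = 8

def patch_covers(voxels, start, size):
    # Every coordinate of every positive voxel must fall inside the patch
    return all(
        s <= v < s + size
        for voxel in voxels
        for v, s in zip(voxel, start)
    )

print(patch_covers(lesion_voxels, patch_start, patch_size))  # True
```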
