# Developer Guide: Lazy Transforms
Transforms are refactored into multiple layers:
- functional
- array-based
  - deterministic
  - random
- dictionary-based
  - deterministic
  - random
Functional transforms are the base of any transform implementation. They are stateless implementations of the actual transform operation to be carried out. They are all capable of operating either in immediate mode (the operation is defined and then immediately applied) or in lazy mode (the operation is added to the metatensor pending list).
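Either mode can be exercised directly. A usage sketch, assuming `img` is a MetaTensor and using an illustrative `rotate` functional transform:

```python
# illustrative sketch; `rotate` stands in for any functional transform
result = rotate(img, angle=0.5, lazy_evaluation=False)  # immediate: applied now
result = rotate(img, angle=0.5, lazy_evaluation=True)   # lazy: op pushed onto the
                                                        # MetaTensor pending list
```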
Functional transforms have the following pattern:
```python
def functional_operation(
    img: torch.Tensor,
    ..., # operation specific parameters
    shape_override: Optional[Sequence[int]] = None,
    lazy_evaluation: Optional[bool] = True
):
    img_ = convert_to_tensor(img, track_meta=get_track_meta())
    # the effective shape of the image can differ from the actual current shape,
    # when an image has one or more pending transforms to be applied. Transforms
    # typically need the shape that the image will have at the point this transform
    # is carried out rather than the shape of the image at the point this transform
    # is defined
    input_shape = img_.shape if shape_override is None else shape_override
    # this is typically needed to fully specify the transform
    input_ndim = len(input_shape) - 1
    transform = get_a_specific_homogeneous_matrix_or_grid_describing_the_operation(...)
    # this might be needed if the transform is known to change the shape of the
    # resulting image
    im_extents = extents_from_shape(input_shape)
    im_extents = [transform @ e for e in im_extents]
    output_shape = shape_from_extents(input_shape, im_extents)
    # everything required to specify the transform at the point that it is applied
    # note that shape_override should always be set as this is how chains of lazy
    # transforms pass the correct image shape on to the next transform
    metadata = {
        ...,
        "shape_override": output_shape
    }
    # either apply the operation immediately or just append it to the pending list
    return lazily_apply_op(img_, MetaMatrix(transform, metadata), lazy_evaluation)
```

`lazily_apply_op` is defined as follows:

```python
def lazily_apply_op(
        tensor, op, lazy_evaluation
) -> Union[MetaTensor, Tuple[torch.Tensor, Optional[MetaMatrix]]]:
    if isinstance(tensor, MetaTensor):
        tensor.push_pending_operation(op)
        if lazy_evaluation is False:
            result = apply(tensor)
            return result
        else:
            return tensor
    else:
        if lazy_evaluation is False:
            result = apply(tensor, [op])
            return result, None
        else:
            return tensor, op
```

As a rule, functional transform metadata should include:
- the parameters specific to the transform (e.g. angles for rotation)
- parameters influencing the operation, such as `mode`, `padding_mode`, etc.
- `shape_override`, if the resulting shape differs from the shape passed in
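For a rotation, for example, the metadata might look like the following (a sketch; the field names other than `shape_override` are illustrative rather than a fixed schema):

```python
metadata = {
    "angle": angle,                  # transform-specific parameter
    "mode": mode,                    # interpolation mode used at apply time
    "padding_mode": padding_mode,    # how out-of-range values are filled
    "shape_override": output_shape,  # shape the image will have after this op
}
```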
Note: this section discusses a design option; there is currently no plan to implement it.

Instead of passing the overridden shape from functional transform to functional transform, it is possible to make each functional transform a functor that can be called at the point the transform is actually applied. A functor transform would have the following implementation:
```python
def functional_operation_functor(
    img: torch.Tensor,
    ..., # operation specific parameters
    lazy_evaluation: Optional[bool] = True
):
    def _inner(inner_img):
        img_ = convert_to_tensor(inner_img, track_meta=get_track_meta())
        # because the functor is not evaluated until the transform is actually
        # applied, inner_img already has its final shape at that point, so no
        # shape_override parameter is needed
        input_shape = img_.shape
        # this is typically needed to fully specify the transform
        input_ndim = len(input_shape) - 1
        transform = get_a_specific_homogeneous_matrix_or_grid_describing_the_operation(...)
        # this might be needed if the transform is known to change the shape of the
        # resulting image
        im_extents = extents_from_shape(input_shape)
        im_extents = [transform @ e for e in im_extents]
        output_shape = shape_from_extents(input_shape, im_extents)
        # everything required to specify the transform at the point that it is applied
        metadata = {
            ...
        }
        return MetaMatrix(transform, metadata)
    return lazily_apply_op(img, _inner, lazy_evaluation)
```
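Under this option, `apply` would call each pending functor at the point the chain is resolved. A rough sketch, assuming pending functors are stored on the tensor and that a hypothetical `resample_with_matrices` helper fuses the produced matrices into a single resample:

```python
# rough sketch only; resample_with_matrices is a hypothetical fused-resample helper
def apply(tensor):
    # note: a later functor may need the shape produced by earlier matrices
    # rather than the raw tensor shape; resolving this is part of the design
    # discussion above
    matrices = [op(tensor) for op in tensor.pending_operations]
    return resample_with_matrices(tensor, matrices)
```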
Deterministic array transforms call the corresponding functional transform and have the following pattern:

```python
class ADeterministicArrayTransform(InvertibleTransform, LazyTransform):
    backend = [TransformBackends.TORCH]
    def __init__(
            self,
            ..., # transform-specific arguments
            lazy_evaluation: Optional[bool] = False
    ):
        LazyTransform.__init__(self, lazy_evaluation)
        # set member variables for transform-specific arguments
        ...
    def __call__(
            self,
            img: NdarrayOrTensor,
            ..., # call-time transform-specific arguments
            shape_override: Optional[Sequence] = None
    ) -> NdarrayOrTensor:
        # determine transform-specific parameters to pass to function
        ...
        shape_override_ = shape_override
        if (shape_override_ is None and isinstance(img, MetaTensor) and
            img.has_pending_transforms):
            tx = img.peek_pending_transform()
            shape_override_ = tx.metadata.get("shape_override", None)
        img_t, _ = functional_operation(img, ..., shape_override_)
        return img_t
    def inverse(self, data):
        raise NotImplementedError()
```
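With `lazy_evaluation=True`, chained transforms accumulate on the pending list until `apply` resolves them. A usage sketch (transform arguments elided):

```python
# hypothetical usage sketch
t1 = ADeterministicArrayTransform(..., lazy_evaluation=True)
t2 = ADeterministicArrayTransform(..., lazy_evaluation=True)

img = t1(img)     # pushes a pending op; no resampling happens yet
img = t2(img)     # picks up shape_override from the op pushed by t1
img = apply(img)  # both pending ops are fused and applied in one resample
```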
Random array transforms wrap the deterministic version of the transform:

```python
class ARandomArrayTransform(RandomizableTransform, InvertibleTransform, LazyTrait):
    def __init__(
            self,
            ..., # transform-specific args
            lazy_evaluation: Optional[bool] = True
    ):
        RandomizableTransform.__init__(self, prob)
        self.op = ADeterministicArrayTransform(...)
        self.random_params = 0 # some default value
    def randomize(self, data: Optional[Any] = None) -> None:
        super().randomize(None)
        if self._do_transform:
            self.random_params = self.R.some_random_parameterized_value()
        else:
            self.random_params = 0 # the default value again
    def __call__(
            self,
            img: NdarrayOrTensor,
            ..., # call-time transform-specific args
            randomize: Optional[bool] = True,
            shape_override: Optional[Sequence] = None
    ) -> NdarrayOrTensor:
        if randomize:
            self.randomize(data=img)
        params = self.random_params
        return self.op(img, params, ..., shape_override)
    @property
    def lazy_evaluation(self):
        return self.op.lazy_evaluation
    @lazy_evaluation.setter
    def lazy_evaluation(self, value):
        self.op.lazy_evaluation = value
    def inverse(
            self,
            data: NdarrayOrTensor,
    ):
        raise NotImplementedError()
```
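Because `randomize` is exposed separately from `__call__`, a caller that needs the same sampled parameters across several tensors (as the dictionary wrappers below do) can sample once and pass `randomize=False` on each call. A usage sketch:

```python
# hypothetical usage sketch
t = ARandomArrayTransform(..., lazy_evaluation=True)
t.randomize()                      # sample parameters once
out_a = t(img_a, randomize=False)  # both calls reuse the same parameters
out_b = t(img_b, randomize=False)
```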
An alternative formulation delegates parameter sampling to a separate randomizer object:

```python
class ARandomArrayTransform(InvertibleTransform, LazyTrait, RandomizableTrait):
    def __init__(
            self,
            ..., # transform-specific args
            lazy_evaluation: Optional[bool] = True
    ):
        self.randomizer = ARandomizer(..., prob)
        self.op = ADeterministicArrayTransform(0, ..., lazy_evaluation)
    def __call__(
            self,
            img: NdarrayOrTensor,
            ..., # call-time transform-specific args
            shape_override: Optional[Sequence] = None
    ) -> NdarrayOrTensor:
        params = self.randomizer.sample(img)
        # TODO: the random transforms have been implemented to make use of Array ops,
        # which creates a problem if the operation name for "RandRotate" needs
        # to be "RandRotate" instead of "Rotate". This can be done via several
        # approaches:
        # 1. Use the functional op directly
        # 2. Pass an override to the array op for the name
        return self.op(img, params, ..., shape_override)
    @property
    def lazy_evaluation(self):
        return self.op.lazy_evaluation
    @lazy_evaluation.setter
    def lazy_evaluation(self, value):
        self.op.lazy_evaluation = value
    def inverse(
            self,
            data: NdarrayOrTensor,
    ):
        raise NotImplementedError()
```
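The formulation above assumes a randomizer object exposing a `sample` method; it is not otherwise specified here, but a minimal sketch built on MONAI's `Randomizable` might look like this (the parameter names are illustrative):

```python
from monai.transforms import Randomizable

class ARandomizer(Randomizable):
    def __init__(self, low: float, high: float, prob: float = 1.0):
        self.low, self.high, self.prob = low, high, prob

    def sample(self, data=None):
        # return the default value when the transform does not fire
        if self.R.uniform() >= self.prob:
            return 0
        return self.R.uniform(self.low, self.high)
```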
Deterministic dictionary transforms wrap the corresponding deterministic array transform:

```python
class ADeterministicDictionaryTransform(MapTransform, InvertibleTransform, LazyTrait):
    backend = ADeterministicArrayTransform.backend
    def __init__(
        self,
        keys: KeysCollection,
        ..., # transform-specific args
        lazy_evaluation: Optional[bool] = True,
        allow_missing_keys: bool = False,
    ) -> None:
        super().__init__(keys, allow_missing_keys)
        # operation-specific member variables
        ...
        self.op = ADeterministicArrayTransform(..., lazy_evaluation=lazy_evaluation)
    def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        rd = dict(data)
        for key, mode, padding_mode, align_corners, dtype in self.key_iterator(
            rd, self.mode, self.padding_mode, self.align_corners, self.dtype
        ):
            rd[key] = self.op(rd[key], ...)
        return rd
    def inverse(self, data: Mapping[Hashable, torch.Tensor]) -> Dict[Hashable, torch.Tensor]:
        d = dict(data)
        for key in self.key_iterator(d):
            d[key] = self.op.inverse(d[key])
        return d
```
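In use, the dictionary transform applies the wrapped array transform to each selected key, so every selected value accumulates its own pending operations. A usage sketch with illustrative keys:

```python
# hypothetical usage sketch
t = ADeterministicDictionaryTransform(keys=["image", "label"], lazy_evaluation=True)
data = t({"image": image, "label": label})  # each value gains a pending op
```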
Random dictionary transforms wrap the corresponding random array transform:

```python
class ARandomDictionaryTransform(MapTransform, InvertibleTransform,
                                 LazyTrait, RandomizableTrait):
    def __init__(
            self,
            keys: KeysCollection,
            ..., # transform-specific args
            allow_missing_keys: Optional[bool] = False,
            lazy_evaluation: Optional[bool] = True,
    ):
        super().__init__(keys, allow_missing_keys)
        self.op = ARandomArrayTransform(..., lazy_evaluation=lazy_evaluation)
    def __call__(self, data: Mapping[Hashable, torch.Tensor]):
        rd = dict(data)
        first_key = self.first_key(rd)
        if first_key == ():
            out = convert_to_tensor(rd, track_meta=get_track_meta())
            return out
        self.op.randomize(rd[first_key])
        it = self.key_iterator(rd, self.mode, self.padding_mode, self.align_corners)
        for key, mode, padding_mode, align_corners in it:
            rd[key] = self.op(rd[key], mode=mode, padding_mode=padding_mode,
                              align_corners=align_corners, randomize=False)
        return rd
    def inverse(self, data):
        raise NotImplementedError()
```
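Putting the layers together, a lazy pipeline defers all pending operations until a single `apply` call at the end. A sketch (in practice this sequencing would be coordinated by `Compose` or a similar executor):

```python
# hypothetical end-to-end sketch
pipeline = [
    ADeterministicDictionaryTransform(keys=["image"], lazy_evaluation=True),
    ARandomDictionaryTransform(keys=["image"], lazy_evaluation=True),
]
data = {"image": image}
for t in pipeline:
    data = t(data)                    # pending ops accumulate lazily
data["image"] = apply(data["image"])  # one fused resample at the end
```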