Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions torchvision/tv_tensors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TypeVar

import torch

from ._bounding_boxes import BoundingBoxes, BoundingBoxFormat, is_rotated_bounding_format
Expand All @@ -6,34 +8,30 @@
from ._mask import Mask
from ._torch_function_helpers import set_return_type
from ._tv_tensor import TVTensor
from torchvision.tv_tensors._tv_tensor import TVTensor

from ._video import Video


TVTensorType = TypeVar("TVTensorType", bound=TVTensor)


# TODO: Fix this. We skip this method as it leads to
# RecursionError: maximum recursion depth exceeded while calling a Python object
# Until `disable` is removed, there will be graph breaks after all calls to functional transforms
@torch.compiler.disable
def wrap(wrappee: torch.Tensor, *, like: TVTensorType, **kwargs) -> TVTensorType:
    """Convert a :class:`torch.Tensor` (``wrappee``) into the same :class:`~torchvision.tv_tensors.TVTensor` subclass as ``like``.

    If ``like`` is a :class:`~torchvision.tv_tensors.BoundingBoxes`, the ``format`` and ``canvas_size`` of
    ``like`` are assigned to ``wrappee``, unless they are passed as ``kwargs``.

    Args:
        wrappee (Tensor): The tensor to convert.
        like (:class:`~torchvision.tv_tensors.TVTensor`): The reference.
            ``wrappee`` will be converted into the same subclass as ``like``
            maintaining the same metadata as ``like``.
        kwargs: Optional overrides for metadata. For BoundingBoxes: ``format``, ``canvas_size``, ``clamping_mode``.
            For KeyPoints: ``canvas_size``.

    Raises:
        TypeError: If ``like`` does not implement the ``__wrap__`` hook, i.e. it is
            not a :class:`~torchvision.tv_tensors.TVTensor` (or compatible) instance.
    """
    # Dispatch to the subclass-defined hook so that each TVTensor type controls
    # how its own metadata is propagated onto the wrapped tensor.
    if not hasattr(like, "__wrap__"):
        raise TypeError(f"Expected `like` to have a `__wrap__` method, but got {type(like)}")

    return like.__wrap__(wrappee, **kwargs)
29 changes: 19 additions & 10 deletions torchvision/tv_tensors/_bounding_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,23 @@ def _wrap(cls, tensor: torch.Tensor, *, format: BoundingBoxFormat | str, canvas_
bounding_boxes.clamping_mode = clamping_mode
return bounding_boxes

def __wrap__(
    self,
    tensor: torch.Tensor,
    *,
    format: BoundingBoxFormat | str | None = None,
    canvas_size: tuple[int, int] | None = None,
    clamping_mode: CLAMPING_MODE_TYPE = None,
    check_dims: bool | None = None,
) -> BoundingBoxes:
    """Wrap ``tensor`` as a :class:`BoundingBoxes`, inheriting this instance's metadata.

    Each keyword argument, when not ``None``, overrides the corresponding
    attribute of ``self``.

    Args:
        tensor: The raw tensor to wrap.
        format: Optional override for ``self.format``.
        canvas_size: Optional override for ``self.canvas_size``.
        clamping_mode: Optional override for ``self.clamping_mode``.
            NOTE(review): ``None`` always means "inherit from self" here; if
            ``None`` is itself a valid clamping mode, it cannot be requested
            explicitly through this method -- confirm against CLAMPING_MODE_TYPE.
        check_dims: Whether ``_wrap`` should validate the tensor's dimensions.
            Defaults to ``False`` when not given, since wrapped op outputs
            (e.g. from ``chunk``/``unbind``) may legitimately change dims.
    """
    return BoundingBoxes._wrap(
        tensor,
        format=format if format is not None else self.format,
        canvas_size=canvas_size if canvas_size is not None else self.canvas_size,
        clamping_mode=clamping_mode if clamping_mode is not None else self.clamping_mode,
        # Bug fix: honor an explicitly passed check_dims instead of silently ignoring it.
        check_dims=check_dims if check_dims is not None else False,
    )

def __new__(
cls,
data: Any,
Expand Down Expand Up @@ -153,17 +170,9 @@ def _wrap_output(
)

if isinstance(output, torch.Tensor) and not isinstance(output, BoundingBoxes):
output = BoundingBoxes._wrap(
output, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False
)
output = first_bbox_from_args.__wrap__(output)
elif isinstance(output, (tuple, list)):
# This branch exists for chunk() and unbind()
output = type(output)(
BoundingBoxes._wrap(
part, format=format, canvas_size=canvas_size, clamping_mode=clamping_mode, check_dims=False
)
for part in output
)
output = type(output)(first_bbox_from_args.__wrap__(part) for part in output)
return output

def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
Expand Down
13 changes: 9 additions & 4 deletions torchvision/tv_tensors/_keypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,13 @@ def _wrap(cls, tensor: torch.Tensor, *, canvas_size: tuple[int, int], check_dims
points.canvas_size = canvas_size
return points

def __wrap__(self, tensor: torch.Tensor, *, canvas_size: tuple[int, int] | None = None) -> KeyPoints:
    """Wrap ``tensor`` as a :class:`KeyPoints` carrying this instance's metadata.

    ``canvas_size`` overrides ``self.canvas_size`` when provided; dimension
    checks are skipped since op outputs may legitimately differ in shape.
    """
    if canvas_size is None:
        canvas_size = self.canvas_size
    return KeyPoints._wrap(tensor, canvas_size=canvas_size, check_dims=False)

def __new__(
cls,
data: Any,
Expand All @@ -89,13 +96,11 @@ def _wrap_output(
# Similar to BoundingBoxes._wrap_output(), see comment there.
flat_params, _ = tree_flatten(args + (tuple(kwargs.values()) if kwargs else ())) # type: ignore[operator]
first_keypoints_from_args = next(x for x in flat_params if isinstance(x, KeyPoints))
canvas_size = first_keypoints_from_args.canvas_size

if isinstance(output, torch.Tensor) and not isinstance(output, KeyPoints):
output = KeyPoints._wrap(output, canvas_size=canvas_size, check_dims=False)
output = first_keypoints_from_args.__wrap__(output)
elif isinstance(output, (tuple, list)):
# This branch exists for chunk() and unbind()
output = type(output)(KeyPoints._wrap(part, canvas_size=canvas_size, check_dims=False) for part in output)
output = type(output)(first_keypoints_from_args.__wrap__(part) for part in output)
return output

def __repr__(self, *, tensor_contents: Any = None) -> str: # type: ignore[override]
Expand Down
4 changes: 4 additions & 0 deletions torchvision/tv_tensors/_tv_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from torch.types import _device, _dtype, _size

from torchvision.tv_tensors._torch_function_helpers import _FORCE_TORCHFUNCTION_SUBCLASS, _must_return_subclass
from typing_extensions import Self


D = TypeVar("D", bound="TVTensor")
Expand Down Expand Up @@ -49,6 +50,9 @@ def _wrap_output(
output = type(output)(cls._wrap_output(part, args, kwargs) for part in output)
return output

def __wrap__(self, tensor: torch.Tensor) -> Self:
    """Re-wrap ``tensor`` into this instance's TVTensor subclass (no extra metadata)."""
    target_cls = type(self)
    return tensor.as_subclass(target_cls)

@classmethod
def __torch_function__(
cls,
Expand Down
Loading