3 changes: 3 additions & 0 deletions docs/source/transforms.rst
@@ -413,6 +413,7 @@ Miscellaneous
v2.RandomErasing
v2.Lambda
v2.SanitizeBoundingBoxes
v2.SanitizeKeyPoints
v2.ClampBoundingBoxes
v2.ClampKeyPoints
v2.UniformTemporalSubsample
@@ -427,6 +428,7 @@ Functionals
v2.functional.normalize
v2.functional.erase
v2.functional.sanitize_bounding_boxes
v2.functional.sanitize_keypoints
v2.functional.clamp_bounding_boxes
v2.functional.clamp_keypoints
v2.functional.uniform_temporal_subsample
@@ -530,6 +532,7 @@ Developer tools
v2.query_size
v2.query_chw
v2.get_bounding_boxes
v2.get_keypoints


V1 API Reference
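The new documentation entries mirror the existing bounding-box sanitization utilities. As a quick orientation, here is a minimal usage sketch of the keypoint variants, inferred from the tests added below rather than from the implementation files (which are not part of this diff):

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2
from torchvision.transforms.v2 import functional as F

# Keypoints that fall outside the canvas are flagged invalid and dropped.
keypoints = tv_tensors.KeyPoints(
    torch.tensor([[5.0, 5.0], [120.0, 5.0], [-1.0, 30.0]]), canvas_size=(100, 100)
)

# Functional form: returns the kept keypoints plus a boolean validity mask.
kept, valid = F.sanitize_keypoints(keypoints)
assert valid.tolist() == [True, False, False]
assert kept.shape == (1, 2)

# Transform form: filters the keypoints and, via labels_getter, any matching labels.
sample = {"keypoints": keypoints, "labels": torch.tensor([0, 1, 2])}
out = v2.SanitizeKeyPoints(labels_getter="default")(sample)
assert out["keypoints"].shape == (1, 2) and out["labels"].tolist() == [0]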
320 changes: 320 additions & 0 deletions test/test_transforms_v2.py
@@ -7397,6 +7397,326 @@ def test_errors_functional(self):
F.sanitize_bounding_boxes(good_bbox.tolist())


class TestSanitizeKeyPoints:
def _make_keypoints_with_validity(
self,
canvas_size=(100, 100),
shape="2d", # "2d", "3d", "4d" for different keypoint shapes
):
"""Create keypoints with known validity for testing."""
canvas_h, canvas_w = canvas_size

if shape == "2d": # [N_points, 2]
keypoints_data = [
([5, 5], True), # Valid point inside image
([canvas_w - 6, canvas_h - 6], True), # Valid point near corner
([canvas_w // 2, canvas_h // 2], True), # Valid point in center
([-1, canvas_h // 2], False), # Invalid: x < 0
([canvas_w // 2, -1], False), # Invalid: y < 0
([canvas_w, canvas_h // 2], False), # Invalid: x >= canvas_w
([canvas_w // 2, canvas_h], False), # Invalid: y >= canvas_h
([0, 0], True), # Edge case: exactly on edge
([canvas_w - 1, canvas_h - 1], True), # Edge case: exactly on edge
]
points, validity = zip(*keypoints_data)
keypoints = torch.tensor(points, dtype=torch.float32)

elif shape == "3d": # [N_objects, N_points, 2]
# Create groups of keypoints with different validity patterns
keypoints_data = [
# Group 1: All points valid
([[10, 10], [20, 20], [30, 30]], True),
# Group 2: one out-of-bounds point invalidates the whole group
([[10, 10], [20, 20], [-5, 30]], False),
# Group 3: All points invalid
([[-1, -1], [-2, -2], [-3, -3]], False),
# Group 4: mix of valid and invalid points; the group is invalid
([[10, 10], [-1, 20], [-2, 30]], False),
]
groups, validity = zip(*keypoints_data)
keypoints = torch.tensor(groups, dtype=torch.float32)

elif shape == "4d": # [N_objects, N_bones, 2, 2]
# Create bone-like structures (pairs of points)
keypoints_data = [
# Object 1: All bones valid
([[[10, 10], [15, 15]], [[20, 20], [25, 25]]], True),
# Object 2: One bone with invalid point
([[[10, 10], [15, 15]], [[-1, 20], [25, 25]]], False),
# Object 3: All bones invalid
([[[-1, -1], [-2, -2]], [[-3, -3], [-4, -4]]], False),
]
objects, validity = zip(*keypoints_data)
keypoints = torch.tensor(objects, dtype=torch.float32)

else:
raise ValueError(f"Unsupported shape: {shape}")

return keypoints, validity

@pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
@pytest.mark.parametrize("input_type", [torch.Tensor, tv_tensors.KeyPoints])
def test_functional(self, shape, input_type):
"""Test the sanitize_keypoints functional interface."""

# Create inputs
canvas_size = (50, 50)
keypoints, expected_validity = self._make_keypoints_with_validity(
canvas_size=canvas_size,
shape=shape,
)

if input_type is tv_tensors.KeyPoints:
keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)
canvas_size_arg = None
else:
canvas_size_arg = canvas_size

# Apply function to be tested
result_keypoints, valid_mask = F.sanitize_keypoints(
keypoints,
canvas_size=canvas_size_arg,
)

# Check return types
assert isinstance(result_keypoints, input_type)
assert isinstance(valid_mask, torch.Tensor)
assert valid_mask.dtype == torch.bool

# Check that valid mask matches expected validity
assert_equal(valid_mask, torch.tensor(expected_validity))

# Check that result has correct number of valid keypoints
assert result_keypoints.shape[0] == valid_mask.sum().item()

# Check that remaining keypoints shape is preserved
assert result_keypoints.shape[1:] == keypoints.shape[1:]

@pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
def test_kernel(self, shape):
"""Test kernel functionality."""
canvas_size = (30, 30)
keypoints, _ = self._make_keypoints_with_validity(canvas_size=canvas_size, shape=shape)

check_kernel(
F.sanitize_keypoints,
input=keypoints,
canvas_size=canvas_size,
check_batched_vs_unbatched=False, # This function doesn't support batching
)

@pytest.mark.parametrize("shape", ["2d", "3d", "4d"])
@pytest.mark.parametrize(
"labels_getter",
(
"default",
lambda inputs: inputs["labels"],
lambda inputs: (inputs["labels"], inputs["other_labels"]),
lambda inputs: [inputs["labels"], inputs["other_labels"]],
None,
lambda inputs: None,
),
)
@pytest.mark.parametrize("sample_type", (tuple, dict))
def test_transform(self, shape, labels_getter, sample_type):
"""Test the SanitizeKeyPoints transform class."""
if sample_type is tuple and not isinstance(labels_getter, str):
# Lambda-based labels_getter doesn't work with tuple input
return

canvas_size = (40, 40)
keypoints, expected_validity = self._make_keypoints_with_validity(
canvas_size=canvas_size,
shape=shape,
)

keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)
num_keypoints = keypoints.shape[0]

# Create associated labels and other data
labels = torch.arange(num_keypoints)
other_labels = torch.arange(num_keypoints) * 2
masks = tv_tensors.Mask(torch.randint(0, 2, size=(num_keypoints, *canvas_size)))
whatever = torch.rand(10)
input_img = torch.randint(0, 256, size=(1, 3, *canvas_size), dtype=torch.uint8)

sample = {
"image": input_img,
"labels": labels,
"keypoints": keypoints,
"other_labels": other_labels,
"whatever": whatever,
"None": None,
"masks": masks,
}

if sample_type is tuple:
img = sample.pop("image")
sample = (img, sample)

# Apply transform
transform = transforms.SanitizeKeyPoints(
labels_getter=labels_getter,
)
out = transform(sample)

# Extract outputs
if sample_type is tuple:
out_image = out[0]
out_labels = out[1]["labels"]
out_other_labels = out[1]["other_labels"]
out_keypoints = out[1]["keypoints"]
out_masks = out[1]["masks"]
out_whatever = out[1]["whatever"]
else:
out_image = out["image"]
out_labels = out["labels"]
out_other_labels = out["other_labels"]
out_keypoints = out["keypoints"]
out_masks = out["masks"]
out_whatever = out["whatever"]

# Verify unchanged elements
assert_equal(out_image, input_img)
assert_equal(out_whatever, whatever)
assert_equal(out_masks, masks)

# Verify types
assert isinstance(out_keypoints, tv_tensors.KeyPoints)
assert isinstance(out_masks, tv_tensors.Mask)

# Calculate expected valid indices
valid_indices = [i for i, is_valid in enumerate(expected_validity) if is_valid]

# Test label handling
if labels_getter is None or (callable(labels_getter) and labels_getter(sample) is None):
# Labels should be unchanged
assert out_labels is labels
assert out_other_labels is other_labels
else:
# Labels should be filtered
assert isinstance(out_labels, torch.Tensor)
assert out_keypoints.shape[0] == out_labels.shape[0]
assert out_labels.tolist() == valid_indices

if callable(labels_getter) and isinstance(labels_getter(sample), (tuple, list)):
# other_labels should also be filtered
assert_equal(out_other_labels, out_labels * 2) # Since other_labels = labels * 2
else:
# other_labels should be unchanged
assert_equal(out_other_labels, other_labels)

def test_edge_cases(self):
"""Test edge cases and boundary conditions."""
canvas_size = (10, 10)

# Test empty keypoints
empty_keypoints = tv_tensors.KeyPoints(torch.empty(0, 2), canvas_size=canvas_size)
result, valid_mask = F.sanitize_keypoints(empty_keypoints)
assert tuple(result.shape) == (0, 2)
assert valid_mask.shape[0] == 0

# Test single valid keypoint
single_valid = tv_tensors.KeyPoints([[5, 5]], canvas_size=canvas_size)
result, valid_mask = F.sanitize_keypoints(single_valid)
assert tuple(result.shape) == (1, 2)
assert valid_mask.all()

# Test single invalid keypoint
single_invalid = tv_tensors.KeyPoints([[-1, -1]], canvas_size=canvas_size)
result, valid_mask = F.sanitize_keypoints(single_invalid)
assert tuple(result.shape) == (0, 2)
assert not valid_mask.any()

def test_errors_functional(self):
"""Test error conditions for the functional interface."""
good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10))

# Test missing canvas_size for pure tensor
with pytest.raises(ValueError, match="canvas_size cannot be None"):
F.sanitize_keypoints(good_keypoints.as_subclass(torch.Tensor), canvas_size=None)

# Test canvas_size provided for tv_tensor
with pytest.raises(ValueError, match="canvas_size must be None"):
F.sanitize_keypoints(good_keypoints, canvas_size=(10, 10))

def test_errors_transform(self):
"""Test error conditions for the transform class."""
good_keypoints = tv_tensors.KeyPoints([[5, 5]], canvas_size=(10, 10))

# Test invalid labels_getter
with pytest.raises(ValueError, match="labels_getter should either be"):
transforms.SanitizeKeyPoints(labels_getter="invalid_type") # type: ignore

# Test missing labels key
with pytest.raises(ValueError, match="Could not infer where the labels are"):
bad_sample = {"keypoints": good_keypoints, "BAD_KEY": torch.tensor([0])}
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)

# Test labels not a tensor
with pytest.raises(ValueError, match="must be a tensor"):
bad_sample = {"keypoints": good_keypoints, "labels": [0]}
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)

# Test mismatched sizes
with pytest.raises(ValueError, match="Number of"):
bad_sample = {"keypoints": good_keypoints, "labels": torch.tensor([0, 1, 2])}
transforms.SanitizeKeyPoints(labels_getter="default")(bad_sample)

def test_no_label(self):
"""Test transform without labels."""
img = make_image()
keypoints = make_keypoints()

# Should raise error without labels_getter=None
with pytest.raises(ValueError, match="or a two-tuple whose second item is a dict"):
transforms.SanitizeKeyPoints(labels_getter="default")(img, keypoints)

# Should work with labels_getter=None
out_img, out_keypoints = transforms.SanitizeKeyPoints(labels_getter=None)(img, keypoints)
assert isinstance(out_img, tv_tensors.Image)
assert isinstance(out_keypoints, tv_tensors.KeyPoints)

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_device_and_dtype_consistency(self, device):
"""Test that device and dtype are preserved."""
canvas_size = (20, 20)
keypoints = torch.tensor([[5, 5], [15, 15], [-1, -1]], dtype=torch.float32, device=device)
keypoints = tv_tensors.KeyPoints(keypoints, canvas_size=canvas_size)

result, valid_mask = F.sanitize_keypoints(keypoints)

assert result.device == keypoints.device
assert result.dtype == keypoints.dtype
assert valid_mask.device == keypoints.device

def test_keypoint_shapes_consistency(self):
"""Test that different keypoint shapes are handled correctly."""
canvas_size = (50, 50)

# Test 2D shape [N_points, 2]
kp_2d = torch.tensor([[10, 10], [20, 20], [-1, -1]], dtype=torch.float32)
kp_2d = tv_tensors.KeyPoints(kp_2d, canvas_size=canvas_size)
result_2d, valid_2d = F.sanitize_keypoints(kp_2d)
assert result_2d.ndim == 2
assert result_2d.shape[1:] == kp_2d.shape[1:]

# Test 3D shape [N_objects, N_points, 2]
kp_3d = torch.tensor([[[10, 10], [20, 20]], [[-1, -1], [30, 30]]], dtype=torch.float32)
kp_3d = tv_tensors.KeyPoints(kp_3d, canvas_size=canvas_size)
result_3d, valid_3d = F.sanitize_keypoints(kp_3d)
assert result_3d.ndim == 3
assert result_3d.shape[1:] == kp_3d.shape[1:]

# Test 4D shape [N_objects, N_bones, 2, 2]
kp_4d = torch.tensor([[[[10, 10], [20, 20]]], [[[-1, -1], [30, 30]]]], dtype=torch.float32)
kp_4d = tv_tensors.KeyPoints(kp_4d, canvas_size=canvas_size)
result_4d, valid_4d = F.sanitize_keypoints(kp_4d)
assert result_4d.ndim == 4
assert result_4d.shape[1:] == kp_4d.shape[1:]


class TestJPEG:
@pytest.mark.parametrize("quality", [5, 75])
@pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
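The fixtures in TestSanitizeKeyPoints above encode the validity rule the tests expect. Below is a compact reference restatement of that rule, distilled from the test data as an assumption; it is not torchvision's actual kernel and is only meant to make the expected semantics explicit.

import torch

def expected_keypoint_validity(keypoints: torch.Tensor, canvas_size: tuple[int, int]) -> torch.Tensor:
    # Reference check distilled from the test fixtures, not the torchvision kernel.
    # A point (x, y) is in-bounds iff 0 <= x < canvas_w and 0 <= y < canvas_h.
    # For [N_objects, N_points, 2] or [N_objects, N_bones, 2, 2] inputs, an entry
    # along dim 0 counts as valid only if every point it contains is in-bounds.
    canvas_h, canvas_w = canvas_size
    x, y = keypoints[..., 0], keypoints[..., 1]
    in_bounds = (x >= 0) & (x < canvas_w) & (y >= 0) & (y < canvas_h)
    if in_bounds.ndim == 1:  # plain [N_points, 2] input
        return in_bounds
    # Collapse all trailing point dimensions: every point of an entry must be valid.
    return in_bounds.flatten(start_dim=1).all(dim=1)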
3 changes: 2 additions & 1 deletion torchvision/transforms/v2/__init__.py
@@ -51,10 +51,11 @@
LinearTransformation,
Normalize,
SanitizeBoundingBoxes,
SanitizeKeyPoints,
ToDtype,
)
from ._temporal import UniformTemporalSubsample
from ._type_conversion import PILToTensor, ToImage, ToPILImage, ToPureTensor
from ._utils import check_type, get_bounding_boxes, has_all, has_any, query_chw, query_size
from ._utils import check_type, get_bounding_boxes, get_keypoints, has_all, has_any, query_chw, query_size

from ._deprecated import ToTensor # usort: skip
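
The final hunk also re-exports get_keypoints alongside get_bounding_boxes. A small sketch of how it would be used, assuming it mirrors the get_bounding_boxes developer helper (it receives the flattened sample inputs and returns the single KeyPoints instance among them); that calling convention is an assumption, since only the export itself is shown in this diff:

import torch
from torchvision import tv_tensors
from torchvision.transforms import v2

image = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
keypoints = tv_tensors.KeyPoints(torch.tensor([[4.0, 4.0], [10.0, 20.0]]), canvas_size=(32, 32))

# Assumed to mirror v2.get_bounding_boxes: scan the flattened inputs and return
# the KeyPoints entry they contain.
flat_inputs = [image, keypoints, torch.tensor([0, 1])]
kp = v2.get_keypoints(flat_inputs)
assert isinstance(kp, tv_tensors.KeyPoints)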