Callidior

Context

This PR proposes to add a SanitizeKeyPoints transform, similar to the existing SanitizeBoundingBoxes (#7246). This transform removes keypoints lying outside of the valid image area, which can happen after geometric transformations with the new default clamping_mode proposed in #9234, which allows for disabling automatic clamping of keypoints.

This implementation follows the proposal in issue #9223. It addresses the problem that the previous default, clamping keypoints to the image edges, changes their positions and misaligns them with the actual locations in the transformed image.

This PR hence only makes sense in combination with new clamping modes such as the one proposed in #9234.

Implementation details

To understand the behavior of the proposed SanitizeKeyPoints transform, we need to distinguish two cases of keypoint formats:

  • tv_tensors.KeyPoints contains a set of keypoints of shape [N_points, 2] or [N_points, 1, 2]. In this case, the transform will remove all keypoints lying outside of the valid image region.
  • tv_tensors.KeyPoints contains groups of keypoints, i.e., several objects, each consisting of a certain number of keypoints (e.g., polygons, polygonal chains, skeletons etc.). It is a tensor of shape [N_objects, N_points, 2] or, in general, [N_objects, ..., 2]. In this case, the transform will remove all objects (first dimension) that have at least a certain number of keypoints lying outside of the valid image region.

The behavior of the transform can be controlled with the following arguments:

  • min_valid_edge_distance (int): The minimum distance that keypoints need to be away from the closest image edge along any axis in order to be considered valid. For example, setting this to 0 will only invalidate/remove keypoints outside of the image area, while a value of 1 will also remove keypoints lying exactly on the edge. Default is 0.
  • min_invalid_points (int or float): Minimum number or fraction of invalid keypoints required for a group of keypoints to be removed. For example, setting this to 1 will remove a group of keypoints if any of its keypoints is invalid, while setting it to 2 will only remove groups with at least 2 invalid keypoints. If a float in (0.0, 1.0] is passed, it represents a fraction of the total number of keypoints in the group. For example, setting this to 0.3 will remove groups of keypoints with at least 30% invalid keypoints. Note that a value of 1 (integer) is very different from 1.0 (float). The former will remove groups with any invalid keypoint, while the latter will only remove groups where all keypoints are invalid. Default is 1 (int).
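The interplay of the two arguments can be sketched without torchvision. The snippet below is only an illustration of the rules described above, not the PR's implementation: the helper names `point_is_valid` and `should_remove_group` are hypothetical, and the exact boundary test (a point counts as valid when `d <= x <= width - 1 - d`, likewise for `y`) is an assumption.

```python
def point_is_valid(x, y, canvas_size, min_valid_edge_distance=0):
    """Assumed rule: a point is valid if it is at least `d` pixels inside the canvas."""
    height, width = canvas_size
    d = min_valid_edge_distance
    return d <= x <= width - 1 - d and d <= y <= height - 1 - d


def should_remove_group(points, canvas_size, min_valid_edge_distance=0, min_invalid_points=1):
    """Decide whether a group of (x, y) keypoints would be removed."""
    num_invalid = sum(
        not point_is_valid(x, y, canvas_size, min_valid_edge_distance)
        for x, y in points
    )
    if isinstance(min_invalid_points, float):
        # A float is a fraction of the number of keypoints in the group.
        if not 0.0 < min_invalid_points <= 1.0:
            raise ValueError("a float min_invalid_points must be in (0.0, 1.0]")
        threshold = min_invalid_points * len(points)
    else:
        # An int is an absolute number of invalid keypoints.
        threshold = min_invalid_points
    return num_invalid >= threshold


# Hypothetical group with exactly one of four points outside a 256x256 canvas.
group = [(10, 10), (300, 10), (20, 20), (30, 30)]
canvas = (256, 256)  # (height, width)

print(should_remove_group(group, canvas, min_invalid_points=1))    # int 1: any invalid point
print(should_remove_group(group, canvas, min_invalid_points=1.0))  # float 1.0: all points invalid
```

With these inputs, the integer `1` removes the group while the float `1.0` keeps it, which is the distinction called out in the description of `min_invalid_points`.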

In addition, the transform can also remove labels associated with the keypoints (or elements from any other tensors with the same first dimension as the keypoints). This can be achieved by setting the labels_getter argument, which follows the same logic as the argument of the same name in SanitizeBoundingBoxes. The only difference is that the default for SanitizeKeyPoints is None, in order to avoid accidental conflicts with any bounding box labels that are also present.
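The idea of removing labels together with their keypoints can be illustrated as follows. This is a plain-Python sketch of the masking logic only, not of the labels_getter mechanism itself: the function `sanitize_with_labels` and the label names are hypothetical, and the check `0 <= x < width` is an assumed validity rule.

```python
def sanitize_with_labels(points, labels, canvas_size):
    """Drop invalid (x, y) points and the labels that share their index."""
    height, width = canvas_size
    keep = [0 <= x < width and 0 <= y < height for x, y in points]
    kept_points = [p for p, k in zip(points, keep) if k]
    kept_labels = [lab for lab, k in zip(labels, keep) if k]
    return kept_points, kept_labels


# Hypothetical labels, one per keypoint.
points = [(301, 27), (226, 107), (306, 147)]
labels = ["nose", "mouth_left", "mouth_bottom"]

kept_points, kept_labels = sanitize_with_labels(points, labels, canvas_size=(256, 256))
print(kept_points, kept_labels)
```

The same boolean mask is applied to both sequences, so keypoints and labels stay aligned after sanitization.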

Illustration of the changes

The following example additionally requires PR #9234.

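from torchvision.transforms import v2
from torchvision.tv_tensors import KeyPoints

# `orig_img` (assumed to be a PIL image) and `plot` (a plotting helper) are defined elsewhere.
orig_pts = KeyPoints(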
    [
        [[445, 700]],  # nose
        [[320, 660]],
        [[370, 660]],
        [[420, 660]],  # left eye
        [[300, 620]],
        [[420, 620]],  # left eyebrow
        [[475, 665]],
        [[515, 665]],
        [[555, 655]],  # right eye
        [[460, 625]],
        [[560, 600]],  # right eyebrow
        [[370, 780]],
        [[450, 760]],
        [[540, 780]],
        [[450, 820]],  # mouth
    ],
    canvas_size=(orig_img.size[1], orig_img.size[0]),
    clamping_mode="soft",
)
cropper = v2.RandomCrop(size=(256, 256))
crops = [cropper((orig_img, orig_pts)) for _ in range(4)]
plot([(orig_img, orig_pts)] + crops)
[Figure: sanitize-keypoints-example]

Unsanitized keypoint coordinates:

for _, pts in crops:
    print(pts)
KeyPoints([[[   1, -109]],
           [[-124, -149]],
           [[ -74, -149]],
           [[ -24, -149]],
           [[-144, -189]],
           [[ -24, -189]],
           [[  31, -144]],
           [[  71, -144]],
           [[ 111, -154]],
           [[  16, -184]],
           [[ 116, -209]],
           [[ -74,  -29]],
           [[   6,  -49]],
           [[  96,  -29]],
           [[   6,   11]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[ -65, -238]],
           [[-190, -278]],
           [[-140, -278]],
           [[ -90, -278]],
           [[-210, -318]],
           [[ -90, -318]],
           [[ -35, -273]],
           [[   5, -273]],
           [[  45, -283]],
           [[ -50, -313]],
           [[  50, -338]],
           [[-140, -158]],
           [[ -60, -178]],
           [[  30, -158]],
           [[ -60, -118]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[301,  27]],
           [[176, -13]],
           [[226, -13]],
           [[276, -13]],
           [[156, -53]],
           [[276, -53]],
           [[331,  -8]],
           [[371,  -8]],
           [[411, -18]],
           [[316, -48]],
           [[416, -73]],
           [[226, 107]],
           [[306,  87]],
           [[396, 107]],
           [[306, 147]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[ 372,  -27]],
           [[ 247,  -67]],
           [[ 297,  -67]],
           [[ 347,  -67]],
           [[ 227, -107]],
           [[ 347, -107]],
           [[ 402,  -62]],
           [[ 442,  -62]],
           [[ 482,  -72]],
           [[ 387, -102]],
           [[ 487, -127]],
           [[ 297,   53]],
           [[ 377,   33]],
           [[ 467,   53]],
           [[ 377,   93]]], canvas_size=(256, 256), clamping_mode=soft)

Sanitization:

sanitizer = v2.SanitizeKeyPoints()
sanitized_pts = [sanitizer(pts) for _, pts in crops]

for pts in sanitized_pts:
    print(pts)
KeyPoints([[[ 6, 11]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([], size=(0, 1, 2), dtype=torch.int64, canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([[[226, 107]]], canvas_size=(256, 256), clamping_mode=soft)
KeyPoints([], size=(0, 1, 2), dtype=torch.int64, canvas_size=(256, 256), clamping_mode=soft)
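As a cross-check, the sanitized result for the third crop can be reproduced by hand from the unsanitized coordinates printed above. The sketch below uses plain Python lists instead of KeyPoints tensors and assumes that a point is valid when `0 <= x < width` and `0 <= y < height`; the helper name `keep_valid` is hypothetical.

```python
def keep_valid(points, canvas_size):
    """Return the (x, y) points that lie inside a canvas given as (height, width)."""
    height, width = canvas_size
    return [(x, y) for x, y in points if 0 <= x < width and 0 <= y < height]


# Unsanitized coordinates of the third crop, copied from the output above.
third_crop = [
    (301, 27), (176, -13), (226, -13), (276, -13), (156, -53),
    (276, -53), (331, -8), (371, -8), (411, -18), (316, -48),
    (416, -73), (226, 107), (306, 87), (396, 107), (306, 147),
]

print(keep_valid(third_crop, canvas_size=(256, 256)))
```

Only `(226, 107)` has both coordinates inside the 256x256 canvas, which matches the third line of the sanitized output.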

Testing

Please run the following unit tests:

pytest test/test_transforms_v2.py -vvv -k "SanitizeKeyPoints"
...
219 passed, 9718 deselected, 8 xfailed in 1.27s


pytorch-bot bot commented Oct 7, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/vision/9235

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 1a3ac36 with merge base b851874 (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

Summary:
Fixing error `TypeError: unsupported operand type(s) for |: 'type' and 'type'`
@AntoineSimoulin
Member

AntoineSimoulin commented Oct 8, 2025

@Callidior, a huge thanks for the high-quality PR. After careful consideration, I made several changes. Let me know what you think and whether you are aligned with the current version. Here is a list of my modifications:

  • I simplified the _get_sanitize_keypoints_mask function by removing the min_valid_edge_distance and min_invalid_points parameters and hard-coding their default values. We can still add those parameters in the future, but I would first like to validate that the default behavior makes sense.
  • Besides those parameters, I kept the behavior you proposed. If we pass KeyPoints as a group (i.e. with at least 3 dimensions), we will only keep the group if all its elements are within the canvas.
  • I simplified the _get_sanitize_keypoints_mask implementation for clarity and to avoid reshaping operations.
  • I modified the tests to reflect the restricted set of arguments.
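The difference between grouped and ungrouped keypoints in this reduced version can be sketched in plain Python. This is an illustration of the rule described in the list above, not the code in the PR: the helper names and the sample points are hypothetical, and the bounds check `0 <= x < width` is an assumption.

```python
def in_canvas(point, canvas_size):
    """Assumed validity rule for a single (x, y) point on a (height, width) canvas."""
    x, y = point
    height, width = canvas_size
    return 0 <= x < width and 0 <= y < height


def sanitize_groups(groups, canvas_size):
    """Grouped case: keep a group only if all of its points are inside the canvas."""
    return [g for g in groups if all(in_canvas(p, canvas_size) for p in g)]


def sanitize_points(points, canvas_size):
    """Ungrouped case: keep each point that is inside the canvas."""
    return [p for p in points if in_canvas(p, canvas_size)]


canvas = (128, 128)
face = [(6, 76), (56, 76), (140, 76)]  # hypothetical group; the last point is outside

grouped_result = sanitize_groups([face], canvas)
ungrouped_result = sanitize_points(face, canvas)
print(grouped_result)
print(ungrouped_result)
```

A single out-of-canvas point is enough to drop the whole group, whereas the ungrouped variant keeps the two remaining points.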

Illustration of the changes

Below, I apply this logic to the example used for #9236.

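from torchvision import tv_tensors
from torchvision.transforms import v2

# `orig_img` (assumed to be a PIL image) and `plot` (a plotting helper) are defined elsewhere.
orig_pts = tv_tensors.KeyPoints(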
    [
        [
            [445, 700],  # nose
            [320, 660],
            [370, 660],
            [420, 660],  # left eye
            [300, 620],
            [420, 620],  # left eyebrow
            [475, 665],
            [515, 665],
            [555, 655],  # right eye
            [460, 625],
            [560, 600],  # right eyebrow
            [370, 780],
            [450, 760],
            [540, 780],
            [450, 820],  # mouth
        ],
    ],
    canvas_size=(orig_img.size[1], orig_img.size[0]),
    # clamping_mode="soft"
)
cropper = v2.RandomCrop(size=(128, 128))
crops = [cropper((orig_img, orig_pts)) for _ in range(4)]
plot([(orig_img, orig_pts)] + crops)
[Figure: output of the plot call above]

If we keep the points grouped:

sanitizer = v2.SanitizeKeyPoints()
sanitized_pts = [sanitizer(pts) for _, pts in crops]

for pts in sanitized_pts:
    print(pts)
# KeyPoints([], size=(0, 15, 2), dtype=torch.int64, canvas_size=(128, 128))
# KeyPoints([], size=(0, 15, 2), dtype=torch.int64, canvas_size=(128, 128))
# KeyPoints([], size=(0, 15, 2), dtype=torch.int64, canvas_size=(128, 128))
# KeyPoints([], size=(0, 15, 2), dtype=torch.int64, canvas_size=(128, 128))

If we ungroup them:

sanitizer = v2.SanitizeKeyPoints()
sanitized_pts = [sanitizer(tv_tensors.wrap(pts.squeeze(0), like=pts)) for _, pts in crops]

for pts in sanitized_pts:
    print(pts)

# KeyPoints([], size=(0, 2), dtype=torch.int64, canvas_size=(128, 128))
# KeyPoints([], size=(0, 2), dtype=torch.int64, canvas_size=(128, 128))
# KeyPoints([[  6,  76],
#            [ 56,  76],
#            [106,  76],
#            [106,  36]], canvas_size=(128, 128))
# KeyPoints([[ 16, 106],
#            [ 56, 106],
#            [ 96,  96],
#            [  1,  66],
#            [101,  41]], canvas_size=(128, 128))

Testing

pytest test/test_transforms_v2.py -vvv -k "SanitizeKeyPoints"
52 passed, 9718 deselected in 1.14s

@Callidior
Author

@AntoineSimoulin Thanks for your help with the PR!

I would also be fine with including the reduced version of the transform you proposed. It does simplify the code a lot, indeed.

I initially included those additional arguments because we have min_size and min_area in SanitizeBoundingBoxes and I thought, why not give similar controls for keypoints. But it's probably a small fraction of use cases that could benefit from this.
