Skip to content

Commit b851874

Browse files
remove clamp_keypoints from transforms (#9236)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 7a13ad0 commit b851874

File tree

2 files changed

+13
-20
lines changed

2 files changed

+13
-20
lines changed

test/test_transforms_v2.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ def affine_rotated_bounding_boxes(bounding_boxes):
631631
)
632632

633633

634-
def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, clamp=True):
634+
def reference_affine_keypoints_helper(keypoints, *, affine_matrix, new_canvas_size=None, cast=True):
635635
canvas_size = new_canvas_size or keypoints.canvas_size
636636

637637
def affine_keypoints(keypoints):
@@ -650,10 +650,7 @@ def affine_keypoints(keypoints):
650650
float(transformed_points[0, 1]),
651651
]
652652
)
653-
654-
if clamp:
655-
output = F.clamp_keypoints(output, canvas_size=canvas_size)
656-
else:
653+
if not cast:
657654
dtype = output.dtype
658655

659656
return output.to(dtype=dtype, device=device)
@@ -2293,10 +2290,10 @@ def _reference_rotate_keypoints(self, keypoints, *, angle, expand, center):
22932290
keypoints,
22942291
affine_matrix=affine_matrix,
22952292
new_canvas_size=new_canvas_size,
2296-
clamp=False,
2293+
cast=False,
22972294
)
22982295

2299-
return F.clamp_keypoints(self._recenter_keypoints_after_expand(output, recenter_xy=recenter_xy)).to(keypoints)
2296+
return self._recenter_keypoints_after_expand(output, recenter_xy=recenter_xy).to(keypoints)
23002297

23012298
@pytest.mark.parametrize("angle", _CORRECTNESS_AFFINE_KWARGS["angle"])
23022299
@pytest.mark.parametrize("expand", [False, True])
@@ -5360,11 +5357,7 @@ def perspective_keypoints(keypoints):
53605357
]
53615358
)
53625359

5363-
# It is important to clamp before casting, especially for CXCYWH format, dtype=int64
5364-
return F.clamp_keypoints(
5365-
output,
5366-
canvas_size=canvas_size,
5367-
).to(dtype=dtype, device=device)
5360+
return output.to(dtype=dtype, device=device)
53685361

53695362
return tv_tensors.KeyPoints(
53705363
torch.cat([perspective_keypoints(k) for k in keypoints.reshape(-1, 2).unbind()], dim=0).reshape(

torchvision/transforms/v2/functional/_geometry.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
from torchvision.utils import _log_api_usage_once
2626

27-
from ._meta import _get_size_image_pil, clamp_bounding_boxes, clamp_keypoints, convert_bounding_box_format
27+
from ._meta import _get_size_image_pil, clamp_bounding_boxes, convert_bounding_box_format
2828

2929
from ._utils import _FillTypeJIT, _get_kernel, _register_five_ten_crop_kernel_internal, _register_kernel_internal
3030

@@ -71,7 +71,7 @@ def horizontal_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, i
7171
shape = keypoints.shape
7272
keypoints = keypoints.clone().reshape(-1, 2)
7373
keypoints[..., 0] = keypoints[..., 0].sub_(canvas_size[1] - 1).neg_()
74-
return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
74+
return keypoints.reshape(shape)
7575

7676

7777
@_register_kernel_internal(horizontal_flip, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -159,7 +159,7 @@ def vertical_flip_keypoints(keypoints: torch.Tensor, canvas_size: tuple[int, int
159159
shape = keypoints.shape
160160
keypoints = keypoints.clone().reshape(-1, 2)
161161
keypoints[..., 1] = keypoints[..., 1].sub_(canvas_size[0] - 1).neg_()
162-
return clamp_keypoints(keypoints.reshape(shape), canvas_size=canvas_size)
162+
return keypoints.reshape(shape)
163163

164164

165165
def vertical_flip_bounding_boxes(
@@ -1026,7 +1026,7 @@ def _affine_keypoints_with_expand(
10261026
new_width, new_height = _compute_affine_output_size(affine_vector, width, height)
10271027
canvas_size = (new_height, new_width)
10281028

1029-
out_keypoints = clamp_keypoints(transformed_points, canvas_size=canvas_size).reshape(original_shape)
1029+
out_keypoints = transformed_points.reshape(original_shape)
10301030
out_keypoints = out_keypoints.to(original_dtype)
10311031

10321032
return out_keypoints, canvas_size
@@ -1695,7 +1695,7 @@ def pad_keypoints(
16951695
left, right, top, bottom = _parse_pad_padding(padding)
16961696
pad = torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
16971697
canvas_size = (canvas_size[0] + top + bottom, canvas_size[1] + left + right)
1698-
return clamp_keypoints(keypoints + pad, canvas_size), canvas_size
1698+
return keypoints + pad, canvas_size
16991699

17001700

17011701
@_register_kernel_internal(pad, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -1817,7 +1817,7 @@ def crop_keypoints(
18171817
keypoints = keypoints - torch.tensor([left, top], dtype=keypoints.dtype, device=keypoints.device)
18181818
canvas_size = (height, width)
18191819

1820-
return clamp_keypoints(keypoints, canvas_size=canvas_size), canvas_size
1820+
return keypoints, canvas_size
18211821

18221822

18231823
@_register_kernel_internal(crop, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2047,7 +2047,7 @@ def perspective_keypoints(
20472047
numer_points = torch.matmul(points, theta1.T)
20482048
denom_points = torch.matmul(points, theta2.T)
20492049
transformed_points = numer_points.div_(denom_points)
2050-
return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size).reshape(original_shape)
2050+
return transformed_points.to(keypoints.dtype).reshape(original_shape)
20512051

20522052

20532053
@_register_kernel_internal(perspective, tv_tensors.KeyPoints, tv_tensor_wrapper=False)
@@ -2376,7 +2376,7 @@ def elastic_keypoints(
23762376
t_size = torch.tensor(canvas_size[::-1], device=displacement.device, dtype=displacement.dtype)
23772377
transformed_points = inv_grid[0, index_y, index_x, :].add_(1).mul_(0.5 * t_size).sub_(0.5)
23782378

2379-
return clamp_keypoints(transformed_points.to(keypoints.dtype), canvas_size=canvas_size).reshape(original_shape)
2379+
return transformed_points.to(keypoints.dtype).reshape(original_shape)
23802380

23812381

23822382
@_register_kernel_internal(elastic, tv_tensors.KeyPoints, tv_tensor_wrapper=False)

0 commit comments

Comments
 (0)