Skip to content

Commit f0d2812

Browse files
author
Lars Haalck
committed
fix: dtype might change during resize
1 parent 014047e commit f0d2812

File tree

2 files changed

+7
-0
lines changed

2 files changed

+7
-0
lines changed

src/transformers/image_transforms.py

+2
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,7 @@ def resize(
365365
# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
366366
# the pillow library to resize the image and then convert back to numpy
367367
do_rescale = False
368+
original_type = image.dtype
368369
if not isinstance(image, PIL.Image.Image):
369370
do_rescale = _rescale_for_pil_conversion(image)
370371
image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
@@ -384,6 +385,7 @@ def resize(
384385
# If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
385386
# rescale it back to the original range.
386387
resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
388+
resized_image = resized_image.astype(original_type)
387389
return resized_image
388390

389391

tests/test_image_transforms.py

+5
Original file line numberDiff line numberDiff line change
@@ -277,6 +277,11 @@ def test_resize(self):
277277
self.assertIsInstance(resized_image, np.ndarray)
278278
self.assertEqual(resized_image.shape, (4, 30, 40))
279279

280+
# check that resize keeps dtype
281+
image = np.zeros((1, 128, 128), dtype=np.float32)
282+
resized_image = resize(image, size=(64, 64))
283+
self.assertEqual(image.dtype, resized_image.dtype)
284+
280285
def test_normalize(self):
281286
image = np.random.randint(0, 256, (224, 224, 3)) / 255
282287

0 commit comments

Comments
 (0)