diff --git a/src/transformers/image_transforms.py b/src/transformers/image_transforms.py
index aaadcb44586a..25448f950d53 100644
--- a/src/transformers/image_transforms.py
+++ b/src/transformers/image_transforms.py
@@ -366,9 +366,12 @@ def resize(
     # To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
     # the pillow library to resize the image and then convert back to numpy
     do_rescale = False
+    original_type = None
     if not isinstance(image, PIL.Image.Image):
+        original_type = image.dtype
         do_rescale = _rescale_for_pil_conversion(image)
         image = to_pil_image(image, do_rescale=do_rescale, input_data_format=input_data_format)
+
     height, width = size
     # PIL images are in the format (width, height)
     resized_image = image.resize((width, height), resample=resample, reducing_gap=reducing_gap)
@@ -385,6 +388,9 @@ def resize(
         # If an image was rescaled to be in the range [0, 255] before converting to a PIL image, then we need to
         # rescale it back to the original range.
         resized_image = rescale(resized_image, 1 / 255) if do_rescale else resized_image
+        # convert back to original type if original image was np.ndarray
+        if original_type is not None:
+            resized_image = resized_image.astype(original_type)
     return resized_image
diff --git a/tests/test_image_transforms.py b/tests/test_image_transforms.py
index 560ea6a36b40..403f811ef089 100644
--- a/tests/test_image_transforms.py
+++ b/tests/test_image_transforms.py
@@ -277,6 +277,11 @@ def test_resize(self):
         self.assertIsInstance(resized_image, np.ndarray)
         self.assertEqual(resized_image.shape, (4, 30, 40))
 
+        # check that resize keeps dtype
+        image = np.zeros((1, 128, 128), dtype=np.float32)
+        resized_image = resize(image, size=(64, 64))
+        self.assertEqual(image.dtype, resized_image.dtype)
+
     def test_normalize(self):
         image = np.random.randint(0, 256, (224, 224, 3)) / 255