diff --git a/tensorflow_addons/image/BUILD b/tensorflow_addons/image/BUILD index a1a09d66bf..fd5577e213 100644 --- a/tensorflow_addons/image/BUILD +++ b/tensorflow_addons/image/BUILD @@ -18,6 +18,11 @@ py_library( "connected_components.py", "resampler_ops.py", "compose_ops.py", + "solarize_ops.py", + "posterize_ops.py", + "auto_contrast_ops.py", + "to_grayscale_ops.py", + "color_jitter_ops.py", ]), data = [ ":sparse_image_warp_test_data", @@ -177,3 +182,63 @@ py_test( ":image", ], ) + +py_test( + name = "solarize_ops_test", + size = "medium", + srcs = [ + "solarize_ops_test.py", + ], + main = "solarize_ops_test.py", + deps = [ + ":image", + ], +) + +py_test( + name = "posterize_ops_test", + size = "medium", + srcs = [ + "posterize_ops_test.py", + ], + main = "posterize_ops_test.py", + deps = [ + ":image", + ], +) + +py_test( + name = "auto_contrast_ops_test", + size = "medium", + srcs = [ + "auto_contrast_ops_test.py", + ], + main = "auto_contrast_ops_test.py", + deps = [ + ":image", + ], +) + +py_test( + name = "to_grayscale_ops_test", + size = "medium", + srcs = [ + "to_grayscale_ops_test.py", + ], + main = "to_grayscale_ops_test.py", + deps = [ + ":image", + ], +) + +py_test( + name = "color_jitter_ops", + size = "medium", + srcs = [ + "color_jitter_ops.py", + ], + main = "color_jitter_ops.py", + deps = [ + ":image", + ], +) diff --git a/tensorflow_addons/image/__init__.py b/tensorflow_addons/image/__init__.py index fbd5cda029..efa01bde3c 100644 --- a/tensorflow_addons/image/__init__.py +++ b/tensorflow_addons/image/__init__.py @@ -29,3 +29,9 @@ from tensorflow_addons.image.transform_ops import transform from tensorflow_addons.image.translate_ops import translate from tensorflow_addons.image.compose_ops import blend +from tensorflow_addons.image.solarize_ops import solarize +from tensorflow_addons.image.solarize_ops import solarize_add +from tensorflow_addons.image.posterize_ops import posterize +from tensorflow_addons.image.auto_contrast_ops import autocontrast +from tensorflow_addons.image.to_grayscale_ops import to_grayscale +from tensorflow_addons.image.color_jitter_ops import color_jitter diff --git a/tensorflow_addons/image/auto_contrast_ops.py b/tensorflow_addons/image/auto_contrast_ops.py new file mode 100644 index 0000000000..674d02e5a8 --- /dev/null +++ b/tensorflow_addons/image/auto_contrast_ops.py @@ -0,0 +1,60 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" Maximize (normalize) image contrast. +This function calculates a histogram of the input image, +removes cutoff percent of the lightest and darkest pixels + from the histogram, and remaps the image so +that the darkest pixel becomes black (0), +and the lightest becomes white (255). """ + +import tensorflow as tf +from tensorflow_addons.utils.types import TensorLike + + +def autocontrast(image: TensorLike) -> TensorLike: + """Implements Autocontrast function from PIL using TF ops. + Args: + image: A 3D uint8 tensor. + Returns: + The image after it has had autocontrast applied to it and will be of type + uint8. + """ + + def scale_channel(image: TensorLike) -> TensorLike: + """Scale the 2D image using the autocontrast rule.""" + # A possibly cheaper version can be done using cumsum/unique_with_counts + # over the histogram values, rather than iterating over the entire image. + # to compute mins and maxes. + lo = tf.cast(tf.reduce_min(image), dtype=tf.float32) + hi = tf.cast(tf.reduce_max(image), dtype=tf.float32) + + # Scale the image, making the lowest value 0 and the highest value 255. + def scale_values(im: TensorLike) -> TensorLike: + scale = 255.0 / (hi - lo) + offset = -lo * scale + im = tf.cast(im, dtype=tf.float32) * scale + offset + im = tf.clip_by_value(im, 0.0, 255.0) + return tf.cast(im, tf.uint8) + + result = tf.cond(hi > lo, lambda: scale_values(image), lambda: image) + return result + + # Assumes RGB for now. Scales each channel independently + # and then stacks the result. + s1 = scale_channel(image[:, :, 0]) + s2 = scale_channel(image[:, :, 1]) + s3 = scale_channel(image[:, :, 2]) + image = tf.stack([s1, s2, s3], 2) + return image diff --git a/tensorflow_addons/image/auto_contrast_ops_test.py b/tensorflow_addons/image/auto_contrast_ops_test.py new file mode 100644 index 0000000000..6ac2ec2f8f --- /dev/null +++ b/tensorflow_addons/image/auto_contrast_ops_test.py @@ -0,0 +1,39 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" Test auto_contrast_ops """ +import sys +import pytest +import tensorflow as tf +from tensorflow_addons.image import auto_contrast_ops +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@test_utils.run_all_in_graph_and_eager_modes +class AutoContrastTest(tf.test.TestCase, parameterized.TestCase): + """AutoContrastTest class to test the working of + methods images""" + + def test_contrast(self): + """ Method to test the auto_contrast technique on images """ + if tf.executing_eagerly(): + image = tf.constant([[1, 1], [1, 1]], dtype=tf.uint8) + stacked_img = tf.stack([image] * 3, 2) + contrast_image = auto_contrast_ops.autocontrast(stacked_img) + self.assertAllEqual(tf.shape(contrast_image), tf.shape(stacked_img)) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/tensorflow_addons/image/color_jitter_ops.py b/tensorflow_addons/image/color_jitter_ops.py new file mode 100644 index 0000000000..9b4e67df7e --- /dev/null +++ b/tensorflow_addons/image/color_jitter_ops.py @@ -0,0 +1,142 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" This method is used to distorts the color of the image """ +import tensorflow as tf +from tensorflow_addons.utils.types import TensorLike + + +def color_jitter( + image: TensorLike, strength: float, random_order: bool = True +) -> TensorLike: + """Distorts the color of the image. + Args: + image: The input image tensor. + strength: the floating number for the strength of the color augmentation <= 2.0. + random_order: A bool, specifying whether to randomize the jittering order. + Returns: + The distorted image tensor. + """ + image = tf.cast(image, dtype=tf.dtypes.float32) + brightness = 0.8 * strength + contrast = 0.8 * strength + saturation = 0.8 * strength + hue = 0.2 * strength + if random_order: + return color_jitter_rand(image, brightness, contrast, saturation, hue) + else: + return color_jitter_nonrand(image, brightness, contrast, saturation, hue) + + +def color_jitter_nonrand( + image: TensorLike, + brightness: float = 0, + contrast: float = 0, + saturation: float = 0, + hue: float = 0, +) -> TensorLike: + """Distorts the color of the image (jittering order is fixed). + Args: + image: The input image tensor. + brightness: A float, specifying the brightness for color jitter. + contrast: A float, specifying the contrast for color jitter. + saturation: A float, specifying the saturation for color jitter. + hue: A float, specifying the hue for color jitter. + Returns: + The distorted image tensor. + """ + with tf.name_scope("distort_color"): + + def apply_transform(i, x, brightness, contrast, saturation, hue): + """Apply the i-th transformation.""" + if brightness != 0 and i == 0: + x = tf.image.random_brightness(x, max_delta=brightness) + elif contrast != 0 and i == 1: + x = tf.image.random_contrast(x, lower=1 - contrast, upper=1 + contrast) + elif saturation != 0 and i == 2: + x = tf.image.random_saturation( + x, lower=1 - saturation, upper=1 + saturation + ) + elif hue != 0: + x = tf.image.random_hue(x, max_delta=hue) + return x + + for i in range(4): + image = apply_transform(i, image, brightness, contrast, saturation, hue) + image = tf.clip_by_value(image, 0.0, 1.0) + return image + + +def color_jitter_rand( + image: TensorLike, + brightness: float = 0, + contrast: float = 0, + saturation: float = 0, + hue: float = 0, +) -> TensorLike: + """Distorts the color of the image (jittering order is random). + Args: + image: The input image tensor. + brightness: A float, specifying the brightness for color jitter. + contrast: A float, specifying the contrast for color jitter. + saturation: A float, specifying the saturation for color jitter. + hue: A float, specifying the hue for color jitter. + Returns: + The distorted image tensor. + """ + with tf.name_scope("distort_color"): + + def apply_transform(i, x): + """Apply the i-th transformation.""" + + def brightness_foo(): + if brightness == 0: + return x + else: + return tf.image.random_brightness(x, max_delta=brightness) + + def contrast_foo(): + if contrast == 0: + return x + else: + return tf.image.random_contrast( + x, lower=tf.math.abs(1 - contrast), upper=1 + contrast + ) + + def saturation_foo(): + if saturation == 0: + return x + else: + return tf.image.random_saturation( + x, lower=tf.math.abs(1 - saturation), upper=1 + saturation + ) + + def hue_foo(): + if hue == 0: + return x + else: + return tf.image.random_hue(x, max_delta=hue) + + x = tf.cond( + tf.less(i, 2), + lambda: tf.cond(tf.less(i, 1), brightness_foo, contrast_foo), + lambda: tf.cond(tf.less(i, 3), saturation_foo, hue_foo), + ) + return x + + perm = tf.random.shuffle(tf.range(4)) + for i in range(4): + image = apply_transform(perm[i], image) + image = tf.clip_by_value(image, 0.0, 1.0) + return tf.cast(image, dtype=tf.uint8) diff --git a/tensorflow_addons/image/color_jitter_ops_test.py b/tensorflow_addons/image/color_jitter_ops_test.py new file mode 100644 index 0000000000..68d6c77c31 --- /dev/null +++ b/tensorflow_addons/image/color_jitter_ops_test.py @@ -0,0 +1,41 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Test of color_jitter method""" + +import sys +import pytest +import tensorflow as tf +from tensorflow_addons.image import color_jitter_ops +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@test_utils.run_all_in_graph_and_eager_modes +class ColorJitterTest(tf.test.TestCase, parameterized.TestCase): + """ColorJitterTest class to test the color distortion image operation""" + + def test_color_jitter(self): + """ Method to test the color distortion technique on images """ + if tf.executing_eagerly(): + image = tf.constant([[1, 2], [5, 3]], dtype=tf.uint8) + stacked_img = tf.stack([image] * 3, 2) + strength = 0.3 + jitter_image = color_jitter_ops.color_jitter(stacked_img, strength) + self.assertAllEqual(tf.shape(jitter_image), tf.shape(stacked_img)) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/tensorflow_addons/image/posterize_ops.py b/tensorflow_addons/image/posterize_ops.py new file mode 100644 index 0000000000..0b83df757f --- /dev/null +++ b/tensorflow_addons/image/posterize_ops.py @@ -0,0 +1,32 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" Posterize method used to reduce the + number of bits foreach color channel""" + +import tensorflow as tf +from tensorflow_addons.utils.types import TensorLike + + +def posterize(image: TensorLike, bits: int) -> TensorLike: + """Reduce the number of bits for each color channel. + Args: + image: The image to posterize + bits: The number of bits to keep for each channel(1-8) + + Returns: + An image + """ + shift = 8 - bits + return tf.bitwise.left_shift(tf.bitwise.right_shift(image, shift), shift) diff --git a/tensorflow_addons/image/posterize_ops_test.py b/tensorflow_addons/image/posterize_ops_test.py new file mode 100644 index 0000000000..18c9de62f4 --- /dev/null +++ b/tensorflow_addons/image/posterize_ops_test.py @@ -0,0 +1,40 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test of solarize_ops""" + +import sys +import pytest +import tensorflow as tf +from tensorflow_addons.image import posterize_ops +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@test_utils.run_all_in_graph_and_eager_modes +class PosterizeOpsTest(tf.test.TestCase, parameterized.TestCase): + """PosterizeOpsTest class to test the working of + methods images""" + + def test_posterize(self): + """ Method to test the posterize technique on images """ + if tf.executing_eagerly(): + image = tf.constant(tf.ones([4, 4], dtype=tf.dtypes.uint8)) * 255 + bits = 2 + posterize_image = posterize_ops.posterize(image, bits) + self.assertAllEqual(tf.shape(image), tf.shape(posterize_image)) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/tensorflow_addons/image/solarize_ops.py b/tensorflow_addons/image/solarize_ops.py new file mode 100644 index 0000000000..7eaec699d6 --- /dev/null +++ b/tensorflow_addons/image/solarize_ops.py @@ -0,0 +1,54 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" This module is used to invert all pixel values above a threshold + which simply means segmentation. """ + +import tensorflow as tf +from tensorflow_addons.utils.types import TensorLike + + +def solarize(image: TensorLike, threshold: float = 128) -> TensorLike: + + """Method to solarize the image + image: input image + threshold: threshold value to solarize the image + + Returns: + A solarized image + """ + # For each pixel in the image, select the pixel + # if the value is less than the threshold. + # Otherwise, subtract 255 from the pixel. + return tf.where(image < threshold, image, 255 - image) + + +def solarize_add( + image: TensorLike, addition: int = 0, threshold: float = 128 +) -> TensorLike: + """Method to add solarize to the image + image: input image + addition: addition amount to add in image + threshold: threshold value to solarize the image + + Returns: + Solarized image with addition values + """ + # For each pixel in the image less than threshold + # we add 'addition' amount to it and then clip the + # pixel value to be between 0 and 255. The value + # of 'addition' is between -128 and 128. + added_image = tf.cast(image, tf.int64) + addition + added_image = tf.cast(tf.clip_by_value(added_image, 0, 255), tf.uint8) + return tf.where(image < threshold, added_image, image) diff --git a/tensorflow_addons/image/solarize_ops_test.py b/tensorflow_addons/image/solarize_ops_test.py new file mode 100644 index 0000000000..1c3c9ed230 --- /dev/null +++ b/tensorflow_addons/image/solarize_ops_test.py @@ -0,0 +1,38 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test of solarize_ops""" + +import sys +import pytest +import tensorflow as tf +from tensorflow_addons.image import solarize_ops +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@test_utils.run_all_in_graph_and_eager_modes +class SolarizeOPSTest(tf.test.TestCase, parameterized.TestCase): + """SolarizeOPSTest class to test the solarize images""" + + def test_solarize(self): + if tf.executing_eagerly(): + image2 = tf.constant(tf.ones([4, 4], dtype=tf.uint8)) * 255 + threshold = 10 + solarize_img = solarize_ops.solarize(image2, threshold) + self.assertAllEqual(tf.shape(solarize_img), tf.shape(image2)) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__])) diff --git a/tensorflow_addons/image/to_grayscale_ops.py b/tensorflow_addons/image/to_grayscale_ops.py new file mode 100644 index 0000000000..cd70612dea --- /dev/null +++ b/tensorflow_addons/image/to_grayscale_ops.py @@ -0,0 +1,33 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" method to convert the color images into grayscale +by keeping the channel same""" +import tensorflow as tf +from tensorflow_addons.utils.types import TensorLike + + +def to_grayscale(image: TensorLike, keep_channels: bool = True) -> TensorLike: + """ Method to convert the color image into grayscale + by keeping the channels same. + + Args: + image: color image to convert into grayscale + keep_channels: boolean parameter for channels + Returns: + Image""" + image = tf.image.rgb_to_grayscale(image) + if keep_channels: + image = tf.tile(image, [1, 1, 3]) + return image diff --git a/tensorflow_addons/image/to_grayscale_ops_test.py b/tensorflow_addons/image/to_grayscale_ops_test.py new file mode 100644 index 0000000000..a82a3f71e2 --- /dev/null +++ b/tensorflow_addons/image/to_grayscale_ops_test.py @@ -0,0 +1,39 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Test of grayscale method""" + +import sys +import pytest +import tensorflow as tf +from tensorflow_addons.image import to_grayscale_ops +from tensorflow_addons.utils import test_utils +from absl.testing import parameterized + + +@test_utils.run_all_in_graph_and_eager_modes +class ToGrayScaleOpsTest(tf.test.TestCase, parameterized.TestCase): + """ToGrayScaleOpsTest class to test the grayscale image operation""" + + def test_grayscale(self): + """ Method to test the grayscale technique on images """ + if tf.executing_eagerly(): + image = tf.constant([[1, 2], [5, 3]], dtype=tf.uint8) + stacked_img = tf.stack([image] * 3, 2) + grayscale_image = to_grayscale_ops.to_grayscale(stacked_img) + self.assertAllEqual(tf.shape(grayscale_image), tf.shape(stacked_img)) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__]))