Skip to content

add sharpness #1452

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Apr 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 56 additions & 2 deletions tensorflow_addons/image/color_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# ==============================================================================
"""Color operations.
equalize: Equalizes image histogram
sharpness: Sharpen image
"""

import tensorflow as tf

from tensorflow_addons.utils.types import TensorLike
from tensorflow_addons.utils.types import TensorLike, Number
from tensorflow_addons.image.utils import to_4D_image, from_4D_image
from tensorflow_addons.image.compose_ops import blend

from typing import Optional
from functools import partial
Expand Down Expand Up @@ -84,7 +86,7 @@ def equalize(
(num_images, num_rows, num_columns, num_channels) (NHWC), or
(num_images, num_channels, num_rows, num_columns) (NCHW), or
(num_rows, num_columns, num_channels) (HWC), or
(num_channels, num_rows, num_columns) (HWC), or
(num_channels, num_rows, num_columns) (CHW), or
(num_rows, num_columns) (HW). The rank must be statically known (the
shape is not `TensorShape(None)`).
data_format: Either 'channels_first' or 'channels_last'
Expand All @@ -98,3 +100,55 @@ def equalize(
fn = partial(equalize_image, data_format=data_format)
image = tf.map_fn(fn, image)
return from_4D_image(image, image_dims)


def sharpness_image(image: TensorLike, factor: Number) -> tf.Tensor:
"""Implements Sharpness function from PIL using TF ops."""
orig_image = image
image_dtype = image.dtype
# Make image 4D for conv operation.
image = tf.expand_dims(image, 0)
# SMOOTH PIL Kernel.
image = tf.cast(image, tf.float32)
kernel = (
tf.constant(
[[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, shape=[3, 3, 1, 1]
)
/ 13.0
)
# Tile across channel dimension.
kernel = tf.tile(kernel, [1, 1, 3, 1])
strides = [1, 1, 1, 1]
degenerate = tf.nn.depthwise_conv2d(
image, kernel, strides, padding="VALID", dilations=[1, 1]
)
degenerate = tf.clip_by_value(degenerate, 0.0, 255.0)
degenerate = tf.squeeze(tf.cast(degenerate, image_dtype), [0])

# For the borders of the resulting image, fill in the values of the
# original image.
mask = tf.ones_like(degenerate)
padded_mask = tf.pad(mask, [[1, 1], [1, 1], [0, 0]])
padded_degenerate = tf.pad(degenerate, [[1, 1], [1, 1], [0, 0]])
result = tf.where(tf.equal(padded_mask, 1), padded_degenerate, orig_image)
# Blend the final result.
blended = blend(result, orig_image, factor)
return tf.cast(blended, image_dtype)


def sharpness(image: TensorLike, factor: Number) -> tf.Tensor:
"""Change sharpness of image(s)

Args:
images: A tensor of shape
(num_images, num_rows, num_columns, num_channels) (NHWC), or
(num_rows, num_columns, num_channels) (HWC)
factor: A floating point value or Tensor above 0.0.
Returns:
Image(s) with the same type and shape as `images`, sharper.
"""
image_dims = tf.rank(image)
image = to_4D_image(image)
fn = partial(sharpness_image, factor=factor)
image = tf.map_fn(fn, image)
return from_4D_image(image, image_dims)
23 changes: 22 additions & 1 deletion tensorflow_addons/image/tests/color_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np

from tensorflow_addons.image import color_ops
from PIL import Image, ImageOps
from PIL import Image, ImageOps, ImageEnhance

_DTYPES = {
np.uint8,
Expand Down Expand Up @@ -53,3 +53,24 @@ def test_equalize_channel_first(shape):
image = tf.ones(shape=shape, dtype=tf.uint8)
equalized = color_ops.equalize(image, "channels_first")
np.testing.assert_equal(equalized.numpy(), image.numpy())


@pytest.mark.parametrize("dtype", _DTYPES)
@pytest.mark.parametrize("shape", [(5, 5, 3), (10, 5, 5, 3)])
def test_sharpness_dtype_shape(dtype, shape):
image = np.ones(shape=shape, dtype=dtype)
sharp = color_ops.sharpness(tf.constant(image), 0).numpy()
np.testing.assert_equal(sharp, image)
assert sharp.dtype == image.dtype


@pytest.mark.parametrize("factor", [0, 0.25, 0.5, 0.75, 1])
def test_sharpness_with_PIL(factor):
np.random.seed(0)
image = np.random.randint(low=0, high=255, size=(10, 5, 5, 3), dtype=np.uint8)
sharpened = np.stack(
[ImageEnhance.Sharpness(Image.fromarray(i)).enhance(factor) for i in image]
)
np.testing.assert_allclose(
color_ops.sharpness(tf.constant(image), factor).numpy(), sharpened, atol=1
)