Skip to content

Commit 7562ee0

Browse files
fix
1 parent c87748a commit 7562ee0

File tree

5 files changed

+17
-15
lines changed

5 files changed

+17
-15
lines changed

Utils/CelebaImageProcessor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from Utils.CImageProcessor import CImageProcessor
2+
3+
def CelebaImageProcessor(image_size, to_grayscale):
4+
return CImageProcessor(
5+
image_size=image_size,
6+
to_grayscale=to_grayscale,
7+
format='RGB', # in CelebA images are in RGB format
8+
range='0..255' # in CelebA images are in the 0..255 range
9+
)

Utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def ImageProcessor_from_config(config):
1717
config = dict(name='celeba', image_size=64, toGrayscale=True)
1818

1919
if isinstance(config, dict) and ('celeba' == config['name'].lower()):
20-
from Utils.celeba import CelebaImageProcessor
20+
from Utils.CelebaImageProcessor import CelebaImageProcessor
2121
return CelebaImageProcessor(
2222
image_size=config['image_size'],
2323
to_grayscale=config.get('toGrayscale', True),

Utils/celeba.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,6 @@
11
import tensorflow_datasets as tfds
22
from Utils.utils import masking_from_config
3-
from Utils.CImageProcessor import CImageProcessor
4-
5-
def CelebaImageProcessor(image_size, to_grayscale):
6-
return CImageProcessor(
7-
image_size=image_size,
8-
to_grayscale=to_grayscale,
9-
format='RGB', # in CelebA images are in RGB format
10-
range='0..255' # in CelebA images are in the 0..255 range
11-
)
3+
from Utils.CelebaImageProcessor import CelebaImageProcessor
124

135
class CCelebADataset:
146
def __init__(self, batch_size=32, image_size=64, toGrayscale=True):

Utils/colors.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ def f(x):
1313
def _isRangeUInt8(x):
1414
tf.debugging.assert_integer(x)
1515
# x is in the 0..255 range
16-
tf.debugging.assert_less_equal(tf.reduce_min(x), tf.cast(0, x.dtype))
17-
tf.debugging.assert_greater_equal(tf.reduce_max(x), tf.cast(255, x.dtype))
16+
tf.debugging.assert_greater_equal(tf.reduce_min(x), tf.cast(0, x.dtype))
17+
tf.debugging.assert_less_equal(tf.reduce_max(x), tf.cast(255, x.dtype))
1818
return
1919

2020
def _from01range(to_):
@@ -38,14 +38,14 @@ def _fromUInt8Range(to_):
3838
if ('-1..1' == to_):
3939
return CFakeObject(
4040
convert=lambda x: (tf.cast(x, tf.float32) / 127.5) - 1.0,
41-
convertBack=lambda x: (x + 1.0) * 127.5,
41+
convertBack=lambda x: tf.cast((x + 1.0) * 127.5, tf.uint8),
4242
check=_isRangeUInt8
4343
)
4444

4545
if ('0..1' == to_):
4646
return CFakeObject(
4747
convert=lambda x: tf.cast(x, tf.float32) / 255.0,
48-
convertBack=lambda x: x * 255.0,
48+
convertBack=lambda x: tf.cast(x * 255.0, tf.uint8),
4949
check=_isRangeUInt8
5050
)
5151

huggingface/app.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def _processImage(modelName, inputImage, **kwargs):
2929
assert 3 == inputImage.shape[-1], f'Invalid image channels: {inputImage.shape[-1]}'
3030
# should be 64x64, because of preprocessing
3131
assert (64, 64) == inputImage.shape[:2], f'Invalid image shape: {inputImage.shape}'
32-
input = image_processor.range.convert(inputImage[None])
32+
# inputImage has 3 channels, but its grayscale
33+
input = image_processor.range.convert(inputImage[None, ..., :1])
3334

3435
model = models.get(modelName, None)
3536
assert model is not None, f'Invalid model name: {modelName}'

0 commit comments

Comments
 (0)