Skip to content

Commit 386429d

Browse files
No public description
PiperOrigin-RevId: 904496953
1 parent b4f2098 commit 386429d

2 files changed

Lines changed: 20 additions & 6 deletions

File tree

official/nlp/tools/tokenization.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,16 +428,18 @@ def preprocess_text(inputs, remove_space=True, lower=False):
428428
The preprocessed text.
429429
430430
"""
431+
# Byte strings need to be explicitly decoded to Unicode text,
432+
# typically using UTF-8. A latin-1 fallback is included for
433+
# backward compatibility with legacy SentencePiece models.
434+
if isinstance(inputs, six.binary_type):
435+
try:
436+
inputs = six.ensure_text(inputs, "utf-8")
437+
except UnicodeDecodeError:
438+
inputs = six.ensure_text(inputs, "latin-1")
431439
outputs = inputs
432440
if remove_space:
433441
outputs = " ".join(inputs.strip().split())
434442

435-
if six.PY2 and isinstance(outputs, str):
436-
try:
437-
outputs = six.ensure_text(outputs, "utf-8")
438-
except UnicodeDecodeError:
439-
outputs = six.ensure_text(outputs, "latin-1")
440-
441443
outputs = unicodedata.normalize("NFKD", outputs)
442444
outputs = "".join([c for c in outputs if not unicodedata.combining(c)])
443445
if lower:

official/nlp/tools/tokenization_test.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,18 @@ def test_is_punctuation(self):
151151
self.assertFalse(tokenization._is_punctuation(u"A"))
152152
self.assertFalse(tokenization._is_punctuation(u" "))
153153

154+
def test_preprocess_text(self):
155+
self.assertEqual(tokenization.preprocess_text("hello world"), "hello world")
156+
self.assertEqual(tokenization.preprocess_text(b"hello \xc3\xa9"), "hello e")
157+
self.assertEqual(tokenization.preprocess_text(b"hello \xe9"), "hello e")
158+
self.assertEqual(
159+
tokenization.preprocess_text(b"hello world", remove_space=True),
160+
"hello world",
161+
)
162+
self.assertEqual(
163+
tokenization.preprocess_text("Hello World", lower=True), "hello world"
164+
)
165+
154166

155167
if __name__ == "__main__":
156168
tf.test.main()

0 commit comments

Comments
 (0)