Commit 6919b83

Fix BytePairTokenizer

1 parent 658ded0
1 file changed (+79 −81 lines)

keras_cv/models/feature_extractor/clip/clip_tokenizer.py

Lines changed: 79 additions & 81 deletions
@@ -14,11 +14,12 @@
 import regex as re
 import tensorflow as tf
 import tensorflow_text as tf_text
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed

 try:
     from keras_nlp.tokenizers import BytePairTokenizer
 except ImportError:
-    BytePairTokenizer = None
+    BytePairTokenizer = object

 # As python and TF handles special spaces differently, we need to
 # manually handle special spaces during string split.
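The hunk above swaps the import-failure fallback from None to object, so the CLIPTokenizer subclass below can still be defined when keras_nlp is absent, and assert_keras_nlp_installed moves the failure to construction time. A minimal sketch of that pattern, using an illustrative stand-in for the helper (the real one lives in keras_cv.utils.conditional_imports):

# Sketch of the optional-dependency pattern; the helper below is a stand-in,
# not the actual implementation in keras_cv.utils.conditional_imports.
try:
    from keras_nlp.tokenizers import BytePairTokenizer
except ImportError:
    # Fall back to `object` so the subclass can still be defined at import time.
    BytePairTokenizer = object


def assert_keras_nlp_installed(symbol_name):
    """Raise a clear error at call time if keras_nlp is unavailable."""
    if BytePairTokenizer is object:
        raise ImportError(
            f"{symbol_name} requires the `keras_nlp` package. "
            "Install it with `pip install keras-nlp`."
        )


class CLIPTokenizer(BytePairTokenizer):
    def __init__(self, **kwargs):
        # Fails here, at construction time, instead of at module import.
        assert_keras_nlp_installed("CLIPTokenizer")
        super().__init__(**kwargs)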
@@ -103,83 +104,80 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result


-if BytePairTokenizer:
-    class CLIPTokenizer(BytePairTokenizer):
-        def __init__(self, **kwargs):
-            super().__init__(**kwargs)
-
-        def _bpe_merge_and_update_cache(self, tokens):
-            """Process unseen tokens and add to cache."""
-            words = self._transform_bytes(tokens)
-            tokenized_words = self._bpe_merge(words)
-
-            # For each word, join all its token by a whitespace,
-            # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
-            tokenized_words = tf.strings.reduce_join(
-                tokenized_words,
-                axis=1,
-            )
-            self.cache.insert(tokens, tokenized_words)
-
-        def tokenize(self, inputs):
-            self._check_vocabulary()
-            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
-                inputs = tf.convert_to_tensor(inputs)
-
-            if self.add_prefix_space:
-                inputs = tf.strings.join([" ", inputs])
-
-            scalar_input = inputs.shape.rank == 0
-            if scalar_input:
-                inputs = tf.expand_dims(inputs, 0)
-
-            raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
-            token_row_splits = raw_tokens.row_splits
-            flat_tokens = raw_tokens.flat_values
-            # Check cache.
-            cache_lookup = self.cache.lookup(flat_tokens)
-            cache_mask = cache_lookup == ""
-
-            has_unseen_words = tf.math.reduce_any(
-                (cache_lookup == "") & (flat_tokens != "")
-            )
-
-            def process_unseen_tokens():
-                unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
-                self._bpe_merge_and_update_cache(unseen_tokens)
-                return self.cache.lookup(flat_tokens)
-
-            # If `has_unseen_words == True`, it means not all tokens are,
-            # in cache we will process the unseen tokens. Otherwise
-            # return the cache lookup.
-            tokenized_words = tf.cond(
-                has_unseen_words,
-                process_unseen_tokens,
-                lambda: cache_lookup,
-            )
-            tokens = tf.strings.split(tokenized_words, sep=" ")
-            if self.compute_dtype != tf.string:
-                # Encode merged tokens.
-                tokens = self.token_to_id_map.lookup(tokens)
-
-            # Unflatten to match input.
-            tokens = tf.RaggedTensor.from_row_splits(
-                tokens.flat_values,
-                tf.gather(tokens.row_splits, token_row_splits),
-            )
-
-            # Convert to a dense output if `sequence_length` is set.
-            if self.sequence_length:
-                output_shape = tokens.shape.as_list()
-                output_shape[-1] = self.sequence_length
-                tokens = tokens.to_tensor(shape=output_shape)
-
-            # Convert to a dense output if input in scalar
-            if scalar_input:
-                tokens = tf.squeeze(tokens, 0)
-                tf.ensure_shape(tokens, shape=[self.sequence_length])
-
-            return tokens
-
-else:
-    CLIPTokenizer = None
+class CLIPTokenizer(BytePairTokenizer):
+    def __init__(self, **kwargs):
+        assert_keras_nlp_installed("CLIPTokenizer")
+        super().__init__(**kwargs)
+
+    def _bpe_merge_and_update_cache(self, tokens):
+        """Process unseen tokens and add to cache."""
+        words = self._transform_bytes(tokens)
+        tokenized_words = self._bpe_merge(words)
+
+        # For each word, join all its token by a whitespace,
+        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+        tokenized_words = tf.strings.reduce_join(
+            tokenized_words,
+            axis=1,
+        )
+        self.cache.insert(tokens, tokenized_words)
+
+    def tokenize(self, inputs):
+        self._check_vocabulary()
+        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+            inputs = tf.convert_to_tensor(inputs)
+
+        if self.add_prefix_space:
+            inputs = tf.strings.join([" ", inputs])
+
+        scalar_input = inputs.shape.rank == 0
+        if scalar_input:
+            inputs = tf.expand_dims(inputs, 0)
+
+        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+        token_row_splits = raw_tokens.row_splits
+        flat_tokens = raw_tokens.flat_values
+        # Check cache.
+        cache_lookup = self.cache.lookup(flat_tokens)
+        cache_mask = cache_lookup == ""
+
+        has_unseen_words = tf.math.reduce_any(
+            (cache_lookup == "") & (flat_tokens != "")
+        )
+
+        def process_unseen_tokens():
+            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+            self._bpe_merge_and_update_cache(unseen_tokens)
+            return self.cache.lookup(flat_tokens)
+
+        # If `has_unseen_words == True`, it means not all tokens are,
+        # in cache we will process the unseen tokens. Otherwise
+        # return the cache lookup.
+        tokenized_words = tf.cond(
+            has_unseen_words,
+            process_unseen_tokens,
+            lambda: cache_lookup,
+        )
+        tokens = tf.strings.split(tokenized_words, sep=" ")
+        if self.compute_dtype != tf.string:
+            # Encode merged tokens.
+            tokens = self.token_to_id_map.lookup(tokens)
+
+        # Unflatten to match input.
+        tokens = tf.RaggedTensor.from_row_splits(
+            tokens.flat_values,
+            tf.gather(tokens.row_splits, token_row_splits),
+        )
+
+        # Convert to a dense output if `sequence_length` is set.
+        if self.sequence_length:
+            output_shape = tokens.shape.as_list()
+            output_shape[-1] = self.sequence_length
+            tokens = tokens.to_tensor(shape=output_shape)
+
+        # Convert to a dense output if input in scalar
+        if scalar_input:
+            tokens = tf.squeeze(tokens, 0)
+            tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+        return tokens
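With the class now defined unconditionally, a missing keras_nlp surfaces as an informative error when the tokenizer is instantiated, rather than CLIPTokenizer being None (which previously failed with "TypeError: 'NoneType' object is not callable" at the call site). A hedged usage sketch, assuming the standard BytePairTokenizer constructor arguments (vocabulary, merges, sequence_length) and illustrative file paths:

from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer

# Illustrative paths; the real vocabulary/merges files come from a CLIP preset.
tokenizer = CLIPTokenizer(
    vocabulary="vocab.json",
    merges="merges.txt",
    sequence_length=77,
)

# With `sequence_length` set, `tokenize` returns a dense tensor of that length
# (see the `to_tensor` branch in `tokenize` above).
token_ids = tokenizer.tokenize(["a photo of a cat"])

# Without keras_nlp installed, construction raises via
# assert_keras_nlp_installed("CLIPTokenizer") instead of the class being None.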
