|
14 | 14 | import regex as re
|
15 | 15 | import tensorflow as tf
|
16 | 16 | import tensorflow_text as tf_text
|
| 17 | +from keras_cv.utils.conditional_imports import assert_keras_nlp_installed |
17 | 18 |
|
18 | 19 | try:
|
19 | 20 | from keras_nlp.tokenizers import BytePairTokenizer
|
20 | 21 | except ImportError:
|
21 |
| - BytePairTokenizer = None |
| 22 | + BytePairTokenizer = object |
22 | 23 |
|
23 | 24 | # As python and TF handles special spaces differently, we need to
|
24 | 25 | # manually handle special spaces during string split.
|
@@ -103,83 +104,80 @@ def remove_strings_from_inputs(tensor, string_to_remove):
|
103 | 104 | return result
|
104 | 105 |
|
105 | 106 |
|
106 |
| -if BytePairTokenizer: |
107 |
| - class CLIPTokenizer(BytePairTokenizer): |
108 |
| - def __init__(self, **kwargs): |
109 |
| - super().__init__(**kwargs) |
110 |
| - |
111 |
| - def _bpe_merge_and_update_cache(self, tokens): |
112 |
| - """Process unseen tokens and add to cache.""" |
113 |
| - words = self._transform_bytes(tokens) |
114 |
| - tokenized_words = self._bpe_merge(words) |
115 |
| - |
116 |
| - # For each word, join all its token by a whitespace, |
117 |
| - # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. |
118 |
| - tokenized_words = tf.strings.reduce_join( |
119 |
| - tokenized_words, |
120 |
| - axis=1, |
121 |
| - ) |
122 |
| - self.cache.insert(tokens, tokenized_words) |
123 |
| - |
124 |
| - def tokenize(self, inputs): |
125 |
| - self._check_vocabulary() |
126 |
| - if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): |
127 |
| - inputs = tf.convert_to_tensor(inputs) |
128 |
| - |
129 |
| - if self.add_prefix_space: |
130 |
| - inputs = tf.strings.join([" ", inputs]) |
131 |
| - |
132 |
| - scalar_input = inputs.shape.rank == 0 |
133 |
| - if scalar_input: |
134 |
| - inputs = tf.expand_dims(inputs, 0) |
135 |
| - |
136 |
| - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) |
137 |
| - token_row_splits = raw_tokens.row_splits |
138 |
| - flat_tokens = raw_tokens.flat_values |
139 |
| - # Check cache. |
140 |
| - cache_lookup = self.cache.lookup(flat_tokens) |
141 |
| - cache_mask = cache_lookup == "" |
142 |
| - |
143 |
| - has_unseen_words = tf.math.reduce_any( |
144 |
| - (cache_lookup == "") & (flat_tokens != "") |
145 |
| - ) |
146 |
| - |
147 |
| - def process_unseen_tokens(): |
148 |
| - unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) |
149 |
| - self._bpe_merge_and_update_cache(unseen_tokens) |
150 |
| - return self.cache.lookup(flat_tokens) |
151 |
| - |
152 |
| - # If `has_unseen_words == True`, it means not all tokens are, |
153 |
| - # in cache we will process the unseen tokens. Otherwise |
154 |
| - # return the cache lookup. |
155 |
| - tokenized_words = tf.cond( |
156 |
| - has_unseen_words, |
157 |
| - process_unseen_tokens, |
158 |
| - lambda: cache_lookup, |
159 |
| - ) |
160 |
| - tokens = tf.strings.split(tokenized_words, sep=" ") |
161 |
| - if self.compute_dtype != tf.string: |
162 |
| - # Encode merged tokens. |
163 |
| - tokens = self.token_to_id_map.lookup(tokens) |
164 |
| - |
165 |
| - # Unflatten to match input. |
166 |
| - tokens = tf.RaggedTensor.from_row_splits( |
167 |
| - tokens.flat_values, |
168 |
| - tf.gather(tokens.row_splits, token_row_splits), |
169 |
| - ) |
170 |
| - |
171 |
| - # Convert to a dense output if `sequence_length` is set. |
172 |
| - if self.sequence_length: |
173 |
| - output_shape = tokens.shape.as_list() |
174 |
| - output_shape[-1] = self.sequence_length |
175 |
| - tokens = tokens.to_tensor(shape=output_shape) |
176 |
| - |
177 |
| - # Convert to a dense output if input in scalar |
178 |
| - if scalar_input: |
179 |
| - tokens = tf.squeeze(tokens, 0) |
180 |
| - tf.ensure_shape(tokens, shape=[self.sequence_length]) |
181 |
| - |
182 |
| - return tokens |
183 |
| - |
184 |
| -else: |
185 |
| - CLIPTokenizer = None |
| 107 | +class CLIPTokenizer(BytePairTokenizer): |
| 108 | + def __init__(self, **kwargs): |
| 109 | + assert_keras_nlp_installed("CLIPTokenizer") |
| 110 | + super().__init__(**kwargs) |
| 111 | + |
| 112 | + def _bpe_merge_and_update_cache(self, tokens): |
| 113 | + """Process unseen tokens and add to cache.""" |
| 114 | + words = self._transform_bytes(tokens) |
| 115 | + tokenized_words = self._bpe_merge(words) |
| 116 | + |
| 117 | + # For each word, join all its token by a whitespace, |
| 118 | + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. |
| 119 | + tokenized_words = tf.strings.reduce_join( |
| 120 | + tokenized_words, |
| 121 | + axis=1, |
| 122 | + ) |
| 123 | + self.cache.insert(tokens, tokenized_words) |
| 124 | + |
| 125 | + def tokenize(self, inputs): |
| 126 | + self._check_vocabulary() |
| 127 | + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): |
| 128 | + inputs = tf.convert_to_tensor(inputs) |
| 129 | + |
| 130 | + if self.add_prefix_space: |
| 131 | + inputs = tf.strings.join([" ", inputs]) |
| 132 | + |
| 133 | + scalar_input = inputs.shape.rank == 0 |
| 134 | + if scalar_input: |
| 135 | + inputs = tf.expand_dims(inputs, 0) |
| 136 | + |
| 137 | + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) |
| 138 | + token_row_splits = raw_tokens.row_splits |
| 139 | + flat_tokens = raw_tokens.flat_values |
| 140 | + # Check cache. |
| 141 | + cache_lookup = self.cache.lookup(flat_tokens) |
| 142 | + cache_mask = cache_lookup == "" |
| 143 | + |
| 144 | + has_unseen_words = tf.math.reduce_any( |
| 145 | + (cache_lookup == "") & (flat_tokens != "") |
| 146 | + ) |
| 147 | + |
| 148 | + def process_unseen_tokens(): |
| 149 | + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) |
| 150 | + self._bpe_merge_and_update_cache(unseen_tokens) |
| 151 | + return self.cache.lookup(flat_tokens) |
| 152 | + |
| 153 | + # If `has_unseen_words == True`, it means not all tokens are, |
| 154 | + # in cache we will process the unseen tokens. Otherwise |
| 155 | + # return the cache lookup. |
| 156 | + tokenized_words = tf.cond( |
| 157 | + has_unseen_words, |
| 158 | + process_unseen_tokens, |
| 159 | + lambda: cache_lookup, |
| 160 | + ) |
| 161 | + tokens = tf.strings.split(tokenized_words, sep=" ") |
| 162 | + if self.compute_dtype != tf.string: |
| 163 | + # Encode merged tokens. |
| 164 | + tokens = self.token_to_id_map.lookup(tokens) |
| 165 | + |
| 166 | + # Unflatten to match input. |
| 167 | + tokens = tf.RaggedTensor.from_row_splits( |
| 168 | + tokens.flat_values, |
| 169 | + tf.gather(tokens.row_splits, token_row_splits), |
| 170 | + ) |
| 171 | + |
| 172 | + # Convert to a dense output if `sequence_length` is set. |
| 173 | + if self.sequence_length: |
| 174 | + output_shape = tokens.shape.as_list() |
| 175 | + output_shape[-1] = self.sequence_length |
| 176 | + tokens = tokens.to_tensor(shape=output_shape) |
| 177 | + |
| 178 | + # Convert to a dense output if input in scalar |
| 179 | + if scalar_input: |
| 180 | + tokens = tf.squeeze(tokens, 0) |
| 181 | + tf.ensure_shape(tokens, shape=[self.sequence_length]) |
| 182 | + |
| 183 | + return tokens |
0 commit comments