From cfc062d7ea3fb146a3c4c33ca4c5505a70d94b77 Mon Sep 17 00:00:00 2001 From: Smit Lunagariya Date: Tue, 19 Mar 2024 06:50:43 +0000 Subject: [PATCH 1/4] Use conditional keras_nlp imports --- .../feature_extractor/clip/clip_model.py | 7 +- .../feature_extractor/clip/clip_processor.py | 10 +- .../feature_extractor/clip/clip_tokenizer.py | 158 +++++++++--------- keras_cv/utils/conditional_imports.py | 14 ++ 4 files changed, 97 insertions(+), 92 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_model.py b/keras_cv/models/feature_extractor/clip/clip_model.py index 860739388e..67868f2d71 100644 --- a/keras_cv/models/feature_extractor/clip/clip_model.py +++ b/keras_cv/models/feature_extractor/clip/clip_model.py @@ -26,6 +26,7 @@ CLIPTextEncoder, ) from keras_cv.models.task import Task +from keras_cv.utils.conditional_imports import assert_keras_nlp_installed from keras_cv.utils.python_utils import classproperty try: @@ -98,11 +99,7 @@ def __init__( **kwargs, ): super().__init__(**kwargs) - if keras_nlp is None: - raise ValueError( - "ClipTokenizer requires keras-nlp. Please install " - "using pip `pip install -U keras-nlp && pip install -U keras`" - ) + assert_keras_nlp_installed("CLIP") self.embed_dim = embed_dim self.image_resolution = image_resolution self.vision_layers = vision_layers diff --git a/keras_cv/models/feature_extractor/clip/clip_processor.py b/keras_cv/models/feature_extractor/clip/clip_processor.py index 16d8d24222..1a07a3d710 100644 --- a/keras_cv/models/feature_extractor/clip/clip_processor.py +++ b/keras_cv/models/feature_extractor/clip/clip_processor.py @@ -16,10 +16,10 @@ from keras_cv.backend import keras from keras_cv.backend import ops from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer +from keras_cv.utils.conditional_imports import assert_keras_nlp_installed try: import keras_nlp - from keras_nlp.layers import StartEndPacker except ImportError: keras_nlp = None @@ -50,11 +50,7 @@ class CLIPProcessor: """ def __init__(self, input_resolution, vocabulary, merges, **kwargs): - if keras_nlp is None: - raise ValueError( - "ClipTokenizer requires keras-nlp. Please install " - "using pip `pip install -U keras-nlp && pip install -U keras`" - ) + assert_keras_nlp_installed("CLIPProcessor") self.input_resolution = input_resolution self.vocabulary = vocabulary self.merges = merges @@ -64,7 +60,7 @@ def __init__(self, input_resolution, vocabulary, merges, **kwargs): merges=self.merges, unsplittable_tokens=[""], ) - self.packer = StartEndPacker( + self.packer = keras_nlp.layers.StartEndPacker( start_value=self.tokenizer.token_to_id("<|startoftext|>"), end_value=self.tokenizer.token_to_id("<|endoftext|>"), pad_value=None, diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index 66b4d7cef6..9219d7cf15 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -16,10 +16,9 @@ import tensorflow_text as tf_text try: - import keras_nlp from keras_nlp.tokenizers import BytePairTokenizer except ImportError: - keras_nlp = None + BytePairTokenizer = None # As python and TF handles special spaces differently, we need to # manually handle special spaces during string split. 
@@ -104,83 +103,82 @@ def remove_strings_from_inputs(tensor, string_to_remove): return result -class CLIPTokenizer(BytePairTokenizer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - if keras_nlp is None: - raise ValueError( - "ClipTokenizer requires keras-nlp. Please install " - "using pip `pip install -U keras-nlp && pip install -U keras`" +if BytePairTokenizer: + class CLIPTokenizer(BytePairTokenizer): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def _bpe_merge_and_update_cache(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._transform_bytes(tokens) + tokenized_words = self._bpe_merge(words) + + # For each word, join all its token by a whitespace, + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. + tokenized_words = tf.strings.reduce_join( + tokenized_words, + axis=1, + ) + self.cache.insert(tokens, tokenized_words) + + def tokenize(self, inputs): + self._check_vocabulary() + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + if self.add_prefix_space: + inputs = tf.strings.join([" ", inputs]) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + token_row_splits = raw_tokens.row_splits + flat_tokens = raw_tokens.flat_values + # Check cache. + cache_lookup = self.cache.lookup(flat_tokens) + cache_mask = cache_lookup == "" + + has_unseen_words = tf.math.reduce_any( + (cache_lookup == "") & (flat_tokens != "") + ) + + def process_unseen_tokens(): + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) + self._bpe_merge_and_update_cache(unseen_tokens) + return self.cache.lookup(flat_tokens) + + # If `has_unseen_words == True`, it means not all tokens are in cache, + # we will process the unseen tokens. Otherwise return the cache lookup. + tokenized_words = tf.cond( + has_unseen_words, + process_unseen_tokens, + lambda: cache_lookup, + ) + tokens = tf.strings.split(tokenized_words, sep=" ") + if self.compute_dtype != tf.string: + # Encode merged tokens. + tokens = self.token_to_id_map.lookup(tokens) + + # Unflatten to match input. + tokens = tf.RaggedTensor.from_row_splits( + tokens.flat_values, + tf.gather(tokens.row_splits, token_row_splits), ) - def _bpe_merge_and_update_cache(self, tokens): - """Process unseen tokens and add to cache.""" - words = self._transform_bytes(tokens) - tokenized_words = self._bpe_merge(words) - - # For each word, join all its token by a whitespace, - # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. - tokenized_words = tf.strings.reduce_join( - tokenized_words, - axis=1, - ) - self.cache.insert(tokens, tokenized_words) - - def tokenize(self, inputs): - self._check_vocabulary() - if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): - inputs = tf.convert_to_tensor(inputs) - - if self.add_prefix_space: - inputs = tf.strings.join([" ", inputs]) - - scalar_input = inputs.shape.rank == 0 - if scalar_input: - inputs = tf.expand_dims(inputs, 0) - - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) - token_row_splits = raw_tokens.row_splits - flat_tokens = raw_tokens.flat_values - # Check cache. 
- cache_lookup = self.cache.lookup(flat_tokens) - cache_mask = cache_lookup == "" - - has_unseen_words = tf.math.reduce_any( - (cache_lookup == "") & (flat_tokens != "") - ) - - def process_unseen_tokens(): - unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) - self._bpe_merge_and_update_cache(unseen_tokens) - return self.cache.lookup(flat_tokens) - - # If `has_unseen_words == True`, it means not all tokens are in cache, - # we will process the unseen tokens. Otherwise return the cache lookup. - tokenized_words = tf.cond( - has_unseen_words, - process_unseen_tokens, - lambda: cache_lookup, - ) - tokens = tf.strings.split(tokenized_words, sep=" ") - if self.compute_dtype != tf.string: - # Encode merged tokens. - tokens = self.token_to_id_map.lookup(tokens) - - # Unflatten to match input. - tokens = tf.RaggedTensor.from_row_splits( - tokens.flat_values, - tf.gather(tokens.row_splits, token_row_splits), - ) - - # Convert to a dense output if `sequence_length` is set. - if self.sequence_length: - output_shape = tokens.shape.as_list() - output_shape[-1] = self.sequence_length - tokens = tokens.to_tensor(shape=output_shape) - - # Convert to a dense output if input in scalar - if scalar_input: - tokens = tf.squeeze(tokens, 0) - tf.ensure_shape(tokens, shape=[self.sequence_length]) - - return tokens + # Convert to a dense output if `sequence_length` is set. + if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + # Convert to a dense output if input in scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self.sequence_length]) + + return tokens + +else: + CLIPTokenizer = None diff --git a/keras_cv/utils/conditional_imports.py b/keras_cv/utils/conditional_imports.py index fc9cc32810..0f0ea2b890 100644 --- a/keras_cv/utils/conditional_imports.py +++ b/keras_cv/utils/conditional_imports.py @@ -33,6 +33,11 @@ except ImportError: pycocotools = None +try: + import keras_nlp +except ImportError: + keras_nlp = None + def assert_cv2_installed(symbol_name): if cv2 is None: @@ -70,3 +75,12 @@ def assert_pycocotools_installed(symbol_name): "Please install the package using " "`pip install pycocotools`." ) + + +def assert_keras_nlp_installed(symbol_name): + if keras_nlp is None: + raise ImportError( + f"{symbol_name} requires the `keras_nlp` package. " + "Please install the package using " + "`pip install keras_nlp`." + ) \ No newline at end of file From 658ded056bd82046302ea3ca0baf58076e555401 Mon Sep 17 00:00:00 2001 From: Smit Lunagariya Date: Tue, 19 Mar 2024 06:58:46 +0000 Subject: [PATCH 2/4] Linting fixes --- keras_cv/models/feature_extractor/clip/clip_tokenizer.py | 5 +++-- keras_cv/utils/conditional_imports.py | 2 +- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index 9219d7cf15..c0aab3cfa0 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -149,8 +149,9 @@ def process_unseen_tokens(): self._bpe_merge_and_update_cache(unseen_tokens) return self.cache.lookup(flat_tokens) - # If `has_unseen_words == True`, it means not all tokens are in cache, - # we will process the unseen tokens. Otherwise return the cache lookup. + # If `has_unseen_words == True`, it means not all tokens are, + # in cache we will process the unseen tokens. 
Otherwise + # return the cache lookup. tokenized_words = tf.cond( has_unseen_words, process_unseen_tokens, diff --git a/keras_cv/utils/conditional_imports.py b/keras_cv/utils/conditional_imports.py index 0f0ea2b890..2ae1ec88b0 100644 --- a/keras_cv/utils/conditional_imports.py +++ b/keras_cv/utils/conditional_imports.py @@ -83,4 +83,4 @@ def assert_keras_nlp_installed(symbol_name): f"{symbol_name} requires the `keras_nlp` package. " "Please install the package using " "`pip install keras_nlp`." - ) \ No newline at end of file + ) From 95e5ee514a8a659a370ecffd5e7dc255ac59f36d Mon Sep 17 00:00:00 2001 From: Smit Lunagariya Date: Tue, 19 Mar 2024 19:31:48 +0000 Subject: [PATCH 3/4] Fix BytePairTokenizer --- .../feature_extractor/clip/clip_tokenizer.py | 161 +++++++++--------- 1 file changed, 80 insertions(+), 81 deletions(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index c0aab3cfa0..a6774ab13c 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -15,10 +15,12 @@ import tensorflow as tf import tensorflow_text as tf_text +from keras_cv.utils.conditional_imports import assert_keras_nlp_installed + try: from keras_nlp.tokenizers import BytePairTokenizer except ImportError: - BytePairTokenizer = None + BytePairTokenizer = object # As python and TF handles special spaces differently, we need to # manually handle special spaces during string split. @@ -103,83 +105,80 @@ def remove_strings_from_inputs(tensor, string_to_remove): return result -if BytePairTokenizer: - class CLIPTokenizer(BytePairTokenizer): - def __init__(self, **kwargs): - super().__init__(**kwargs) - - def _bpe_merge_and_update_cache(self, tokens): - """Process unseen tokens and add to cache.""" - words = self._transform_bytes(tokens) - tokenized_words = self._bpe_merge(words) - - # For each word, join all its token by a whitespace, - # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. - tokenized_words = tf.strings.reduce_join( - tokenized_words, - axis=1, - ) - self.cache.insert(tokens, tokenized_words) - - def tokenize(self, inputs): - self._check_vocabulary() - if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): - inputs = tf.convert_to_tensor(inputs) - - if self.add_prefix_space: - inputs = tf.strings.join([" ", inputs]) - - scalar_input = inputs.shape.rank == 0 - if scalar_input: - inputs = tf.expand_dims(inputs, 0) - - raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) - token_row_splits = raw_tokens.row_splits - flat_tokens = raw_tokens.flat_values - # Check cache. - cache_lookup = self.cache.lookup(flat_tokens) - cache_mask = cache_lookup == "" - - has_unseen_words = tf.math.reduce_any( - (cache_lookup == "") & (flat_tokens != "") - ) - - def process_unseen_tokens(): - unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) - self._bpe_merge_and_update_cache(unseen_tokens) - return self.cache.lookup(flat_tokens) - - # If `has_unseen_words == True`, it means not all tokens are, - # in cache we will process the unseen tokens. Otherwise - # return the cache lookup. - tokenized_words = tf.cond( - has_unseen_words, - process_unseen_tokens, - lambda: cache_lookup, - ) - tokens = tf.strings.split(tokenized_words, sep=" ") - if self.compute_dtype != tf.string: - # Encode merged tokens. - tokens = self.token_to_id_map.lookup(tokens) - - # Unflatten to match input. 
- tokens = tf.RaggedTensor.from_row_splits( - tokens.flat_values, - tf.gather(tokens.row_splits, token_row_splits), - ) - - # Convert to a dense output if `sequence_length` is set. - if self.sequence_length: - output_shape = tokens.shape.as_list() - output_shape[-1] = self.sequence_length - tokens = tokens.to_tensor(shape=output_shape) - - # Convert to a dense output if input in scalar - if scalar_input: - tokens = tf.squeeze(tokens, 0) - tf.ensure_shape(tokens, shape=[self.sequence_length]) - - return tokens - -else: - CLIPTokenizer = None +class CLIPTokenizer(BytePairTokenizer): + def __init__(self, **kwargs): + assert_keras_nlp_installed("CLIPTokenizer") + super().__init__(**kwargs) + + def _bpe_merge_and_update_cache(self, tokens): + """Process unseen tokens and add to cache.""" + words = self._transform_bytes(tokens) + tokenized_words = self._bpe_merge(words) + + # For each word, join all its token by a whitespace, + # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose. + tokenized_words = tf.strings.reduce_join( + tokenized_words, + axis=1, + ) + self.cache.insert(tokens, tokenized_words) + + def tokenize(self, inputs): + self._check_vocabulary() + if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)): + inputs = tf.convert_to_tensor(inputs) + + if self.add_prefix_space: + inputs = tf.strings.join([" ", inputs]) + + scalar_input = inputs.shape.rank == 0 + if scalar_input: + inputs = tf.expand_dims(inputs, 0) + + raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens) + token_row_splits = raw_tokens.row_splits + flat_tokens = raw_tokens.flat_values + # Check cache. + cache_lookup = self.cache.lookup(flat_tokens) + cache_mask = cache_lookup == "" + + has_unseen_words = tf.math.reduce_any( + (cache_lookup == "") & (flat_tokens != "") + ) + + def process_unseen_tokens(): + unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask) + self._bpe_merge_and_update_cache(unseen_tokens) + return self.cache.lookup(flat_tokens) + + # If `has_unseen_words == True`, it means not all tokens are, + # in cache we will process the unseen tokens. Otherwise + # return the cache lookup. + tokenized_words = tf.cond( + has_unseen_words, + process_unseen_tokens, + lambda: cache_lookup, + ) + tokens = tf.strings.split(tokenized_words, sep=" ") + if self.compute_dtype != tf.string: + # Encode merged tokens. + tokens = self.token_to_id_map.lookup(tokens) + + # Unflatten to match input. + tokens = tf.RaggedTensor.from_row_splits( + tokens.flat_values, + tf.gather(tokens.row_splits, token_row_splits), + ) + + # Convert to a dense output if `sequence_length` is set. 
+ if self.sequence_length: + output_shape = tokens.shape.as_list() + output_shape[-1] = self.sequence_length + tokens = tokens.to_tensor(shape=output_shape) + + # Convert to a dense output if input in scalar + if scalar_input: + tokens = tf.squeeze(tokens, 0) + tf.ensure_shape(tokens, shape=[self.sequence_length]) + + return tokens From f48e9ba550c88a1a77e5175c276fbe798d68d770 Mon Sep 17 00:00:00 2001 From: Smit Lunagariya Date: Thu, 21 Mar 2024 08:51:05 +0000 Subject: [PATCH 4/4] Add conditional import for tensorflow_text --- .../feature_extractor/clip/clip_tokenizer.py | 10 +++++++++- keras_cv/utils/conditional_imports.py | 14 ++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py index a6774ab13c..6758296818 100644 --- a/keras_cv/models/feature_extractor/clip/clip_tokenizer.py +++ b/keras_cv/models/feature_extractor/clip/clip_tokenizer.py @@ -13,15 +13,20 @@ # limitations under the License. import regex as re import tensorflow as tf -import tensorflow_text as tf_text from keras_cv.utils.conditional_imports import assert_keras_nlp_installed +from keras_cv.utils.conditional_imports import assert_tf_text_installed try: from keras_nlp.tokenizers import BytePairTokenizer except ImportError: BytePairTokenizer = object +try: + import tensorflow_text as tf_text +except ImportError: + tf_text = None + # As python and TF handles special spaces differently, we need to # manually handle special spaces during string split. SPECIAL_WHITESPACES = r"\x{a0}\x{2009}\x{202f}\x{3000}" @@ -42,6 +47,9 @@ def split_strings_for_bpe(inputs, unsplittable_tokens=None): # support lookahead match, we are using an alternative insert a special # token "६" before leading space of non-space characters and after the # trailing space, e.g., " keras" will be "६ keras". + + assert_tf_text_installed("split_strings_for_bpe") + inputs = tf.strings.regex_replace( inputs, rf"( )([^\s{SPECIAL_WHITESPACES}])", r"६\1\2" ) diff --git a/keras_cv/utils/conditional_imports.py b/keras_cv/utils/conditional_imports.py index 2ae1ec88b0..d6eaf64299 100644 --- a/keras_cv/utils/conditional_imports.py +++ b/keras_cv/utils/conditional_imports.py @@ -38,6 +38,11 @@ except ImportError: keras_nlp = None +try: + import tensorflow_text +except ImportError: + tensorflow_text = None + def assert_cv2_installed(symbol_name): if cv2 is None: @@ -84,3 +89,12 @@ def assert_keras_nlp_installed(symbol_name): "Please install the package using " "`pip install keras_nlp`." ) + + +def assert_tf_text_installed(symbol_name): + if tensorflow_text is None: + raise ImportError( + f"{symbol_name} requires the `tensorflow_text` package. " + "Please install the package using " + "`pip install tensorflow_text`." + )
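The four patches above reduce to one reusable pattern: import the optional dependency at module level inside a `try`/`except`, fall back to a harmless placeholder so the module still imports, and raise a clear error from a shared helper only when the optional symbol is actually constructed. Below is a minimal sketch of that pattern; the names `optional_pkg`, `BaseTokenizer`, `assert_optional_pkg_installed`, and `MyTokenizer` are hypothetical stand-ins, and only the structure mirrors `keras_cv.utils.conditional_imports` and the `CLIPTokenizer` changes in PATCH 3.

```python
# Sketch of the conditional-import pattern, assuming a hypothetical
# dependency named `optional_pkg`.

try:
    import optional_pkg  # hypothetical optional dependency
except ImportError:
    optional_pkg = None

try:
    # Fall back to `object` so the subclass below is always definable,
    # mirroring the `BytePairTokenizer = object` fallback in PATCH 3.
    from optional_pkg.tokenizers import BaseTokenizer
except ImportError:
    BaseTokenizer = object


def assert_optional_pkg_installed(symbol_name):
    """Raise a clear error only when the optional symbol is actually used."""
    if optional_pkg is None:
        raise ImportError(
            f"{symbol_name} requires the `optional_pkg` package. "
            "Please install the package using `pip install optional_pkg`."
        )


class MyTokenizer(BaseTokenizer):
    def __init__(self, **kwargs):
        # The failure is deferred from module import time to object
        # construction, so the rest of the library stays importable.
        assert_optional_pkg_installed("MyTokenizer")
        super().__init__(**kwargs)
```

One design point worth noting: the assertion runs before `super().__init__(**kwargs)`. When the dependency is missing the base class is plain `object`, so calling its constructor with keyword arguments would raise a confusing `TypeError`; checking first guarantees the user sees the informative `ImportError` instead, which is why PATCH 3 orders the calls this way.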