
Commit cfc062d: Use conditional keras_nlp imports

Parent: e21312a

File tree: 4 files changed, +97 -92 lines

keras_cv/models/feature_extractor/clip/clip_model.py

Lines changed: 2 additions & 5 deletions

@@ -26,6 +26,7 @@
     CLIPTextEncoder,
 )
 from keras_cv.models.task import Task
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
 from keras_cv.utils.python_utils import classproperty
 
 try:
@@ -98,11 +99,7 @@ def __init__(
         **kwargs,
     ):
         super().__init__(**kwargs)
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
-            )
+        assert_keras_nlp_installed("CLIP")
         self.embed_dim = embed_dim
         self.image_resolution = image_resolution
         self.vision_layers = vision_layers
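With this change, importing the CLIP module no longer fails outright when keras_nlp is missing; the check runs when the model is constructed. A rough usage sketch, where the public import path and no-argument construction are assumptions for illustration:

from keras_cv.models import CLIP  # assumed public import path

try:
    model = CLIP()  # constructor arguments omitted for brevity
except ImportError as err:
    # Raised by assert_keras_nlp_installed("CLIP") when keras_nlp is absent.
    print(err)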

keras_cv/models/feature_extractor/clip/clip_processor.py

Lines changed: 3 additions & 7 deletions

@@ -16,10 +16,10 @@
 from keras_cv.backend import keras
 from keras_cv.backend import ops
 from keras_cv.models.feature_extractor.clip.clip_tokenizer import CLIPTokenizer
+from keras_cv.utils.conditional_imports import assert_keras_nlp_installed
 
 try:
     import keras_nlp
-    from keras_nlp.layers import StartEndPacker
 except ImportError:
     keras_nlp = None
 
@@ -50,11 +50,7 @@ class CLIPProcessor:
     """
 
     def __init__(self, input_resolution, vocabulary, merges, **kwargs):
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
-            )
+        assert_keras_nlp_installed("CLIPProcessor")
         self.input_resolution = input_resolution
         self.vocabulary = vocabulary
         self.merges = merges
@@ -64,7 +60,7 @@ def __init__(self, input_resolution, vocabulary, merges, **kwargs):
             merges=self.merges,
             unsplittable_tokens=["</w>"],
         )
-        self.packer = StartEndPacker(
+        self.packer = keras_nlp.layers.StartEndPacker(
             start_value=self.tokenizer.token_to_id("<|startoftext|>"),
             end_value=self.tokenizer.token_to_id("<|endoftext|>"),
             pad_value=None,
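The switch from `from keras_nlp.layers import StartEndPacker` to the fully qualified `keras_nlp.layers.StartEndPacker` matters because the from-import would raise at module load time when keras_nlp is missing, while the attribute lookup only happens inside `__init__`, after the installation check has run. A minimal sketch of the same pattern, with illustrative names and values:

try:
    import keras_nlp  # may be absent; the module still imports cleanly
except ImportError:
    keras_nlp = None


class NeedsPacker:  # illustrative class, not part of keras_cv
    def __init__(self):
        if keras_nlp is None:
            raise ImportError("keras_nlp is required: `pip install keras_nlp`")
        # Attribute access is deferred to construction time, so it is safe here.
        self.packer = keras_nlp.layers.StartEndPacker(
            sequence_length=8, start_value=1, end_value=2, pad_value=0
        )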

keras_cv/models/feature_extractor/clip/clip_tokenizer.py

Lines changed: 78 additions & 80 deletions

@@ -16,10 +16,9 @@
 import tensorflow_text as tf_text
 
 try:
-    import keras_nlp
     from keras_nlp.tokenizers import BytePairTokenizer
 except ImportError:
-    keras_nlp = None
+    BytePairTokenizer = None
 
 # As python and TF handles special spaces differently, we need to
 # manually handle special spaces during string split.
@@ -104,83 +103,82 @@ def remove_strings_from_inputs(tensor, string_to_remove):
     return result
 
 
-class CLIPTokenizer(BytePairTokenizer):
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        if keras_nlp is None:
-            raise ValueError(
-                "ClipTokenizer requires keras-nlp. Please install "
-                "using pip `pip install -U keras-nlp && pip install -U keras`"
-            )
-
-    def _bpe_merge_and_update_cache(self, tokens):
-        """Process unseen tokens and add to cache."""
-        words = self._transform_bytes(tokens)
-        tokenized_words = self._bpe_merge(words)
-
-        # For each word, join all its token by a whitespace,
-        # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
-        tokenized_words = tf.strings.reduce_join(
-            tokenized_words,
-            axis=1,
-        )
-        self.cache.insert(tokens, tokenized_words)
-
-    def tokenize(self, inputs):
-        self._check_vocabulary()
-        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
-            inputs = tf.convert_to_tensor(inputs)
-
-        if self.add_prefix_space:
-            inputs = tf.strings.join([" ", inputs])
-
-        scalar_input = inputs.shape.rank == 0
-        if scalar_input:
-            inputs = tf.expand_dims(inputs, 0)
-
-        raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
-        token_row_splits = raw_tokens.row_splits
-        flat_tokens = raw_tokens.flat_values
-        # Check cache.
-        cache_lookup = self.cache.lookup(flat_tokens)
-        cache_mask = cache_lookup == ""
-
-        has_unseen_words = tf.math.reduce_any(
-            (cache_lookup == "") & (flat_tokens != "")
-        )
-
-        def process_unseen_tokens():
-            unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
-            self._bpe_merge_and_update_cache(unseen_tokens)
-            return self.cache.lookup(flat_tokens)
-
-        # If `has_unseen_words == True`, it means not all tokens are in cache,
-        # we will process the unseen tokens. Otherwise return the cache lookup.
-        tokenized_words = tf.cond(
-            has_unseen_words,
-            process_unseen_tokens,
-            lambda: cache_lookup,
-        )
-        tokens = tf.strings.split(tokenized_words, sep=" ")
-        if self.compute_dtype != tf.string:
-            # Encode merged tokens.
-            tokens = self.token_to_id_map.lookup(tokens)
-
-        # Unflatten to match input.
-        tokens = tf.RaggedTensor.from_row_splits(
-            tokens.flat_values,
-            tf.gather(tokens.row_splits, token_row_splits),
-        )
-
-        # Convert to a dense output if `sequence_length` is set.
-        if self.sequence_length:
-            output_shape = tokens.shape.as_list()
-            output_shape[-1] = self.sequence_length
-            tokens = tokens.to_tensor(shape=output_shape)
-
-        # Convert to a dense output if input in scalar
-        if scalar_input:
-            tokens = tf.squeeze(tokens, 0)
-            tf.ensure_shape(tokens, shape=[self.sequence_length])
-
-        return tokens
+if BytePairTokenizer:
+    class CLIPTokenizer(BytePairTokenizer):
+        def __init__(self, **kwargs):
+            super().__init__(**kwargs)
+
+        def _bpe_merge_and_update_cache(self, tokens):
+            """Process unseen tokens and add to cache."""
+            words = self._transform_bytes(tokens)
+            tokenized_words = self._bpe_merge(words)
+
+            # For each word, join all its token by a whitespace,
+            # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+            tokenized_words = tf.strings.reduce_join(
+                tokenized_words,
+                axis=1,
+            )
+            self.cache.insert(tokens, tokenized_words)
+
+        def tokenize(self, inputs):
+            self._check_vocabulary()
+            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
+                inputs = tf.convert_to_tensor(inputs)
+
+            if self.add_prefix_space:
+                inputs = tf.strings.join([" ", inputs])
+
+            scalar_input = inputs.shape.rank == 0
+            if scalar_input:
+                inputs = tf.expand_dims(inputs, 0)
+
+            raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
+            token_row_splits = raw_tokens.row_splits
+            flat_tokens = raw_tokens.flat_values
+            # Check cache.
+            cache_lookup = self.cache.lookup(flat_tokens)
+            cache_mask = cache_lookup == ""
+
+            has_unseen_words = tf.math.reduce_any(
+                (cache_lookup == "") & (flat_tokens != "")
+            )
+
+            def process_unseen_tokens():
+                unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
+                self._bpe_merge_and_update_cache(unseen_tokens)
+                return self.cache.lookup(flat_tokens)
+
+            # If `has_unseen_words == True`, it means not all tokens are in cache,
+            # we will process the unseen tokens. Otherwise return the cache lookup.
+            tokenized_words = tf.cond(
+                has_unseen_words,
+                process_unseen_tokens,
+                lambda: cache_lookup,
+            )
+            tokens = tf.strings.split(tokenized_words, sep=" ")
+            if self.compute_dtype != tf.string:
+                # Encode merged tokens.
+                tokens = self.token_to_id_map.lookup(tokens)
+
+            # Unflatten to match input.
+            tokens = tf.RaggedTensor.from_row_splits(
+                tokens.flat_values,
+                tf.gather(tokens.row_splits, token_row_splits),
+            )
+
+            # Convert to a dense output if `sequence_length` is set.
+            if self.sequence_length:
+                output_shape = tokens.shape.as_list()
+                output_shape[-1] = self.sequence_length
+                tokens = tokens.to_tensor(shape=output_shape)
+
+            # Convert to a dense output if input in scalar
+            if scalar_input:
+                tokens = tf.squeeze(tokens, 0)
+                tf.ensure_shape(tokens, shape=[self.sequence_length])
+
+            return tokens
+
+else:
+    CLIPTokenizer = None
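Wrapping the class definition in `if BytePairTokenizer:` avoids subclassing None when keras_nlp is unavailable, while `CLIPTokenizer = None` keeps the name importable so clip_processor.py can still load and defer the failure to assert_keras_nlp_installed. The same pattern in miniature, with illustrative names:

try:
    from keras_nlp.tokenizers import BytePairTokenizer
except ImportError:
    BytePairTokenizer = None

if BytePairTokenizer:

    class MyTokenizer(BytePairTokenizer):  # illustrative subclass
        pass

else:
    # Keep the name defined so `from module import MyTokenizer` still works;
    # callers are expected to check for keras_nlp before using it.
    MyTokenizer = None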

keras_cv/utils/conditional_imports.py

Lines changed: 14 additions & 0 deletions

@@ -33,6 +33,11 @@
 except ImportError:
     pycocotools = None
 
+try:
+    import keras_nlp
+except ImportError:
+    keras_nlp = None
+
 
 def assert_cv2_installed(symbol_name):
     if cv2 is None:
@@ -70,3 +75,12 @@ def assert_pycocotools_installed(symbol_name):
             "Please install the package using "
             "`pip install pycocotools`."
         )
+
+
+def assert_keras_nlp_installed(symbol_name):
+    if keras_nlp is None:
+        raise ImportError(
+            f"{symbol_name} requires the `keras_nlp` package. "
+            "Please install the package using "
+            "`pip install keras_nlp`."
+        )
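A short usage sketch of the new helper: any keras_nlp-dependent entry point can guard itself and produce a uniform ImportError. The function below is illustrative, not part of this commit:

from keras_cv.utils.conditional_imports import assert_keras_nlp_installed


def build_bpe_tokenizer(vocabulary, merges):  # illustrative helper
    assert_keras_nlp_installed("build_bpe_tokenizer")
    from keras_nlp.tokenizers import BytePairTokenizer

    return BytePairTokenizer(vocabulary=vocabulary, merges=merges)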
