Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PARSeq Model #2089

Draft
wants to merge 46 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
528d3a4
Base for parseq model
sineeli Jan 31, 2025
3bf11cd
make it vit compatiable with diff height and width sizes
sineeli Jan 31, 2025
a8fb177
correct vit conv scripts
sineeli Jan 31, 2025
6f4363a
make class token optional in backbone by default its included
sineeli Jan 31, 2025
d1cece0
add flags to adjust vit network
sineeli Jan 31, 2025
92b2745
add test case for without class_token
sineeli Jan 31, 2025
ed00b73
Merge branch 'master' into parseq
sineeli Feb 3, 2025
25f661c
decoder file
sineeli Feb 6, 2025
f97fab1
parseq tokenizer base
sineeli Feb 10, 2025
d424210
add api for parseq tokenizer
sineeli Feb 10, 2025
3f3ad0d
Add missing arg max_label_length.
sineeli Feb 10, 2025
bb4457e
nit
sineeli Feb 10, 2025
68829f8
Merge branch 'master' into parseq
sineeli Feb 10, 2025
1bde466
add missing normalization step using tf_text
sineeli Feb 11, 2025
e6c5379
add missing config for preprocessor
sineeli Feb 12, 2025
5b08c93
add default start, pad and end tokens
sineeli Feb 12, 2025
49260ef
nit
sineeli Feb 12, 2025
b4150ed
correct special token order
sineeli Feb 12, 2025
ed8b9d7
return padding mask as well
sineeli Feb 18, 2025
4e4511c
use proper keras ops
sineeli Feb 18, 2025
9222331
nit
sineeli Feb 18, 2025
78a07a0
add decoder for parseq
sineeli Mar 3, 2025
decc12c
Build unbuilt layers for model validation
sineeli Mar 14, 2025
7aa2b67
fix forward pass and decoder
sineeli Mar 14, 2025
82be527
add missing mlp forward pass
sineeli Mar 25, 2025
c0bf528
add generate prprocess and generate step
sineeli Mar 29, 2025
3a862bb
Merge remote-tracking branch 'origin/master' into parseq
sineeli Mar 29, 2025
b6991be
nit
sineeli Mar 29, 2025
40df2ea
add generate_step to parseq causal lm
sineeli Mar 30, 2025
9ce7c62
minor fixes for jax backend and config fix
sineeli Apr 1, 2025
b1cb2ca
update decoder layer with caching mechanism which is used for generat…
sineeli Apr 7, 2025
3cd87cd
modify generate step including cache
sineeli Apr 7, 2025
57a5054
re structure code to make jax backend compatiable
sineeli Apr 8, 2025
3adad55
add postprocess step into preprocessor
sineeli Apr 8, 2025
b7be4dd
test only forward pass
sineeli Apr 8, 2025
103ee5c
nit
sineeli Apr 8, 2025
c9487ae
test build cache
sineeli Apr 8, 2025
d0b3906
test generate step only build cache
sineeli Apr 8, 2025
9dfecc1
correct class name
sineeli Apr 8, 2025
a7619c6
correct dropout
sineeli Apr 8, 2025
4cb3c65
remove slicing in forward pass
sineeli Apr 8, 2025
dd4f8aa
nit
sineeli Apr 8, 2025
c473f6d
use python style slicing
sineeli Apr 8, 2025
456ba1d
support jax for generate step
sineeli Apr 8, 2025
78f319a
Merge branch 'master' into parseq
sineeli Apr 10, 2025
ac30b4b
Merge branch 'master' into parseq
sineeli Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions keras_hub/api/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@
from keras_hub.src.models.pali_gemma.pali_gemma_image_converter import (
PaliGemmaImageConverter,
)
from keras_hub.src.models.parseq.parseq_image_converter import (
PARSeqImageConverter,
)
from keras_hub.src.models.resnet.resnet_image_converter import (
ResNetImageConverter,
)
Expand Down
8 changes: 8 additions & 0 deletions keras_hub/api/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,10 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.parseq.parseq_causal_lm import ParSeqCausalLM
from keras_hub.src.models.parseq.parseq_preprocessor import PARSeqPreprocessor
from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
from keras_hub.src.models.phi3.phi3_backbone import Phi3Backbone
from keras_hub.src.models.phi3.phi3_causal_lm import Phi3CausalLM
from keras_hub.src.models.phi3.phi3_causal_lm_preprocessor import (
Expand Down Expand Up @@ -389,6 +393,10 @@
from keras_hub.src.models.text_classifier_preprocessor import (
TextClassifierPreprocessor,
)
from keras_hub.src.models.text_recognition import TextRecognition
from keras_hub.src.models.text_recognition_preprocessor import (
TextRecognitionPreprocessor,
)
from keras_hub.src.models.text_to_image import TextToImage
from keras_hub.src.models.text_to_image_preprocessor import (
TextToImagePreprocessor,
Expand Down
1 change: 1 addition & 0 deletions keras_hub/api/tokenizers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from keras_hub.src.models.pali_gemma.pali_gemma_tokenizer import (
PaliGemmaTokenizer,
)
from keras_hub.src.models.parseq.parseq_tokenizer import PARSeqTokenizer
from keras_hub.src.models.phi3.phi3_tokenizer import Phi3Tokenizer
from keras_hub.src.models.qwen.qwen_tokenizer import QwenTokenizer
from keras_hub.src.models.qwen.qwen_tokenizer import (
Expand Down
Empty file.
102 changes: 102 additions & 0 deletions keras_hub/src/models/parseq/parseq_backbone.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import keras

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.parseq.parseq_decoder import PARSeqDecoder


@keras_hub_export("keras_hub.models.PARSeqBackbone")
class PARSeqBackbone(Backbone):
"""Scene Text Detection with PARSeq.

Performs OCR in natural scenes using the PARSeq model described in [Scene
Text Recognition with Permuted Autoregressive Sequence Models](
https://arxiv.org/abs/2207.06966). PARSeq is a ViT-based model that allows
iterative decoding by performing an autoregressive decoding phase, followed
by a refinement phase.
"""

def __init__(
self,
image_encoder,
vocabulary_size,
max_label_length,
decoder_hidden_dim,
num_decoder_layers,
num_decoder_heads,
decoder_mlp_dim,
dropout_rate=0.1,
attention_dropout=0.1,
dtype=None,
**kwargs,
):
# === Layers ===
self.image_encoder = image_encoder
self.decoder = PARSeqDecoder(
vocabulary_size=vocabulary_size,
max_label_length=max_label_length,
num_layers=num_decoder_layers,
num_heads=num_decoder_heads,
hidden_dim=decoder_hidden_dim,
mlp_dim=decoder_mlp_dim,
dropout_rate=dropout_rate,
attention_dropout=attention_dropout,
name="decoder",
)
self.head = keras.layers.Dense(
vocabulary_size - 2, # We don't predict <bos> nor <pad>
)

# === Functional Model ===
image_input = self.image_encoder.input

token_id_input = keras.Input(
shape=(None,), dtype="int32", name="token_ids"
)
padding_mask_input = keras.Input(
shape=(None,), dtype="int32", name="padding_mask"
)

memory = self.image_encoder(image_input)
target_out = self.decoder(
token_id_input, memory, padding_mask=padding_mask_input
)
logits = self.head(target_out)

# === Config ===
self.vocabulary_size = vocabulary_size
self.max_label_length = max_label_length
self.decoder_hidden_dim = decoder_hidden_dim
self.num_decoder_layers = num_decoder_layers
self.num_decoder_heads = num_decoder_heads
self.deocder_head_dim = decoder_hidden_dim // num_decoder_heads
self.decoder_mlp_dim = decoder_mlp_dim
self.dropout_rate = dropout_rate
self.attention_dropout = attention_dropout

super().__init__(
inputs={
"images": image_input,
"token_ids": token_id_input,
"padding_mask": padding_mask_input,
},
outputs=logits,
dtype=dtype,
**kwargs,
)

def get_config(self):
config = super().get_config()
config.update(
{
"encoder": keras.layers.serialize(self.image_encoder),
"vocabulary_size": self.vocabulary_size,
"max_label_length": self.max_label_length,
"decoder_hidden_dim": self.decoder_hidden_dim,
"num_decoder_layers": self.num_decoder_layers,
"num_decoder_heads": self.num_decoder_heads,
"deocder_head_dim": self.deocder_head_dim,
"dropout_rate": self.dropout_rate,
"attention_dropout": self.attention_dropout,
}
)
216 changes: 216 additions & 0 deletions keras_hub/src/models/parseq/parseq_causal_lm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
from keras import ops

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.causal_lm import CausalLM
from keras_hub.src.models.parseq.parseq_backbone import PARSeqBackbone
from keras_hub.src.models.parseq.parseq_preprocessor import PARSeqPreprocessor
from keras_hub.src.utils.tensor_utils import any_equal


@keras_hub_export("keras_hub.models.ParSeqCausalLM")
class ParSeqCausalLM(CausalLM):
backbone_cls = PARSeqBackbone
preprocessor_cls = PARSeqPreprocessor

def __init__(
self,
preprocessor,
backbone,
**kwargs,
):
# === Layers ===
self.preprocessor = preprocessor
self.backbone = backbone

# === Functional Model ===
# This must be "backbone.input" i.e. the full input structure,
# rather than "backbone.inputs" which is the flattened list of inputs.
inputs = backbone.input
outputs = backbone(inputs=inputs)
super().__init__(
inputs=inputs,
outputs=outputs,
**kwargs,
)

def compile(
self,
optimizer="auto",
loss="auto",
*,
weighted_metrics="auto",
sampler="greedy",
**kwargs,
):
super().compile(
optimizer=optimizer,
loss=loss,
weighted_metrics=weighted_metrics,
sampler=sampler,
**kwargs,
)

def call_with_cache(
self,
token_ids,
cache,
cache_update_index,
img_embeddings,
padding_mask=None,
):
bs = ops.shape(token_ids)[0]
# <bos> stands for the null context. We only supply position information
# for characters after <bos>.
content = ops.where(
cache_update_index == 0,
self.backbone.decoder_hidden_dim**0.5
* self.backbone.decoder.token_embedding(token_ids),
ops.expand_dims(
self.backbone.decoder.pos_query_embeddings[
:, cache_update_index - 1, :
],
axis=0,
)
+ self.backbone.decoder_hidden_dim**0.5
* self.backbone.decoder.token_embedding(token_ids),
)
content = self.backbone.decoder.dropout(content)

query = ops.ones((bs, 1, 1)) * ops.expand_dims(
self.backbone.decoder.pos_query_embeddings[
:, cache_update_index, :
],
axis=0,
)
query = self.backbone.decoder.dropout(query)

query_cache = []
content_cache = []
for i, decoder_layer in enumerate(self.backbone.decoder.decoder_layers):
last = i == self.backbone.num_decoder_layers - 1
current_query_cache = cache[:, i, 0, ...]
current_content_cache = cache[:, i, 1, ...]
(
query,
content,
query_self_attention_new_cache,
content_self_attention_cache,
) = decoder_layer(
query=query,
content=content,
memory=img_embeddings,
padding_mask=padding_mask,
update_content=not last,
query_self_attention_cache=current_query_cache,
query_self_attention_cache_update_index=cache_update_index,
content_self_attention_cache=current_content_cache,
content_self_attention_cache_update_index=cache_update_index,
)
query_cache.append(query_self_attention_new_cache)
content_cache.append(content_self_attention_cache)

query_cache = ops.stack(query_cache, axis=1)
content_cache = ops.stack(content_cache, axis=1)
cache = ops.stack([query_cache, content_cache], axis=2)
hidden_states = self.backbone.decoder.layer_norm(query)
logits = self.backbone.head(hidden_states)
return logits, hidden_states, cache

def _build_cache(self, token_ids, img_embeddings, padding_mask):
batch_size = ops.shape(token_ids)[0]
max_length = ops.shape(token_ids)[1]
num_layers = self.backbone.num_decoder_layers
head_dim = self.backbone.deocder_head_dim
num_heads = self.backbone.num_decoder_heads
shape = [batch_size, num_layers, 2, 2, max_length, num_heads, head_dim]
cache = ops.zeros(shape)

# Seed the cache.
logits, hidden_states, cache = self.call_with_cache(
token_ids=token_ids,
img_embeddings=img_embeddings,
cache=cache,
cache_update_index=0,
padding_mask=padding_mask,
)
return hidden_states, cache

def generate_step(self, inputs, stop_token_ids=None):
token_ids, padding_mask, images = (
inputs["token_ids"],
inputs["padding_mask"],
inputs["images"],
)
images_shape = ops.shape(images)
if len(images_shape) == 3:
# Handle an unbatched image. Unlike `token_ids` and `padding_mask`
# this will not automatically be upranked.
images = ops.expand_dims(images, axis=0)

img_embeddings = self.backbone.image_encoder(images)
# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(
token_ids=token_ids,
img_embeddings=img_embeddings,
padding_mask=padding_mask,
)

# Create and seed cache with a single forward pass.
hidden_states, cache = self._build_cache(
token_ids, img_embeddings, padding_mask
)
# Compute the lengths of all user inputted tokens ids.
row_lengths = ops.sum(ops.cast(padding_mask, "int32"), axis=-1)
# Start at the first index that has no user inputted id.
index = ops.min(row_lengths)

def next(prompt, cache, index):
# The cache index is the index of our previous token.
cache_update_index = index - 1
batch_size = ops.shape(prompt)[0]
prompt = ops.slice(prompt, [0, index - 1], [batch_size, 1])
logits, hidden_states, cache = self.call_with_cache(
token_ids=prompt,
cache=cache,
cache_update_index=cache_update_index,
img_embeddings=img_embeddings,
)
return (
ops.squeeze(logits, axis=1),
ops.squeeze(hidden_states, axis=1),
cache,
)

token_ids = self.sampler(
next=next,
prompt=token_ids,
cache=cache,
index=index,
mask=padding_mask,
stop_token_ids=stop_token_ids,
hidden_states=hidden_states,
model=self,
)

# Compute an output padding mask with the token ids we updated.
if stop_token_ids is not None:
# Build a mask of `stop_token_ids` locations not in the original
# prompt (not in locations where `padding_mask` is True).
end_locations = any_equal(
token_ids, stop_token_ids, ops.logical_not(padding_mask)
)

end_locations = ops.cast(end_locations, "int32")
# Use cumsum to get ones in all locations after end_locations.
cumsum = ops.cast(ops.cumsum(end_locations, axis=-1), "int32")
overflow = cumsum - end_locations
# Our padding mask is the inverse of these overflow locations.
padding_mask = ops.logical_not(ops.cast(overflow, "bool"))
else:
# Without early stopping, all locations will have been updated.
padding_mask = ops.ones_like(token_ids, dtype="bool")
return {
"token_ids": token_ids,
"padding_mask": padding_mask,
"images": images,
}
Loading
Loading