diff --git a/keras_hub/src/utils/transformers/export/gemma3.py b/keras_hub/src/utils/transformers/export/gemma3.py
new file mode 100644
index 0000000000..8120de7f19
--- /dev/null
+++ b/keras_hub/src/utils/transformers/export/gemma3.py
@@ -0,0 +1,181 @@
+import keras.ops as ops
+
+
+def get_gemma3_config(backbone):
+    """Convert Keras Gemma3 config to Hugging Face config dictionary."""
+    token_embedding_layer = backbone.get_layer("token_embedding")
+    hf_config = {
+        "architectures": ["Gemma3ForCausalLM"],
+        "model_type": "gemma3_text",
+        "vocab_size": backbone.vocabulary_size,
+        "num_hidden_layers": backbone.num_layers,
+        "num_attention_heads": backbone.num_query_heads,
+        "num_key_value_heads": backbone.num_key_value_heads,
+        "hidden_size": backbone.hidden_dim,
+        "intermediate_size": backbone.intermediate_dim,
+        "head_dim": backbone.head_dim,
+        "max_position_embeddings": 32768,
+        "tie_word_embeddings": token_embedding_layer.tie_weights,
+        "rms_norm_eps": 1e-6,
+        "rope_theta": 10000.0,
+        "attention_bias": False,
+        "attention_dropout": 0.0,
+        "hidden_activation": "gelu_pytorch_tanh",
+    }
+    return hf_config
+
+
+def get_gemma3_weights_map(backbone, include_lm_head=False):
+    """Convert a Keras Gemma3 model to Hugging Face format.
+
+    include_lm_head: If True, exports for CausalLM (with "model." prefix).
+        If False, exports for backbone only (without prefix).
+    """
+
+    def _convert_qkv_kernel(kernel, hidden_dim):
+        """Helper to convert Q/K/V projection kernels to HF format.
+
+        Args:
+            kernel: The kernel weight tensor to convert.
+            hidden_dim: The hidden dimension size for reshaping.
+
+        Returns:
+            Converted kernel in HF format.
+        """
+        kernel = ops.transpose(kernel, axes=(1, 0, 2))  # permute(1, 0, 2)
+        kernel = ops.reshape(kernel, (hidden_dim, -1))
+        kernel = ops.transpose(kernel)  # .T
+        return kernel
+
+    weights_dict = {}
+
+    # For CausalLM export, use "model." prefix
+    # For backbone export, use no prefix
+    prefix = "model." if include_lm_head else ""
+
+    # Token embeddings - use .weights[0] to get backend tensor
+    token_embedding_layer = backbone.get_layer("token_embedding")
+    token_embedding = token_embedding_layer.weights[0]
+    weights_dict[f"{prefix}embed_tokens.weight"] = token_embedding
+
+    for i in range(backbone.num_layers):
+        block = backbone.get_layer(f"decoder_block_{i}")
+
+        # Attention query projection
+        q_kernel = _convert_qkv_kernel(
+            block.attention.query_dense.weights[0], backbone.hidden_dim
+        )
+        weights_dict[f"{prefix}layers.{i}.self_attn.q_proj.weight"] = q_kernel
+
+        # Attention key projection
+        k_kernel = _convert_qkv_kernel(
+            block.attention.key_dense.weights[0], backbone.hidden_dim
+        )
+        weights_dict[f"{prefix}layers.{i}.self_attn.k_proj.weight"] = k_kernel
+
+        # Attention value projection
+        v_kernel = _convert_qkv_kernel(
+            block.attention.value_dense.weights[0], backbone.hidden_dim
+        )
+        weights_dict[f"{prefix}layers.{i}.self_attn.v_proj.weight"] = v_kernel
+
+        # Attention output projection
+        o_kernel = block.attention.output_dense.weights[0]
+        o_kernel = ops.transpose(o_kernel, axes=(2, 0, 1))  # permute(2, 0, 1)
+        o_kernel = ops.reshape(o_kernel, (backbone.hidden_dim, -1))
+        weights_dict[f"{prefix}layers.{i}.self_attn.o_proj.weight"] = o_kernel
+
+        # Query and key normalization
+        q_norm = block.attention.query_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.self_attn.q_norm.weight"] = q_norm
+
+        k_norm = block.attention.key_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.self_attn.k_norm.weight"] = k_norm
+
+        # MLP gate projection
+        gate_kernel = block.gating_ffw.weights[0]
+        gate_kernel = ops.transpose(gate_kernel)  # .T
+        weights_dict[f"{prefix}layers.{i}.mlp.gate_proj.weight"] = gate_kernel
+
+        # MLP up projection
+        up_kernel = block.gating_ffw_2.weights[0]
+        up_kernel = ops.transpose(up_kernel)  # .T
+        weights_dict[f"{prefix}layers.{i}.mlp.up_proj.weight"] = up_kernel
+
+        # MLP down projection
+        down_kernel = block.ffw_linear.weights[0]
+        down_kernel = ops.transpose(down_kernel)  # .T
+        weights_dict[f"{prefix}layers.{i}.mlp.down_proj.weight"] = down_kernel
+
+        # Pre-attention normalization
+        input_layer_norm = block.pre_attention_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.input_layernorm.weight"] = (
+            input_layer_norm
+        )
+
+        # Post-attention normalization
+        if hasattr(block, "post_attention_norm"):
+            post_attn_norm = block.post_attention_norm.weights[0]
+            weights_dict[
+                f"{prefix}layers.{i}.post_attention_layernorm.weight"
+            ] = post_attn_norm
+        # Pre-feedforward normalization
+        pre_feedforward_layernorm = block.pre_ffw_norm.weights[0]
+        weights_dict[f"{prefix}layers.{i}.pre_feedforward_layernorm.weight"] = (
+            pre_feedforward_layernorm
+        )
+        # Post-feedforward normalization
+        if hasattr(block, "post_ffw_norm"):
+            post_feedforward_layernorm = block.post_ffw_norm.weights[0]
+            weights_dict[
+                f"{prefix}layers.{i}.post_feedforward_layernorm.weight"
+            ] = post_feedforward_layernorm
+
+    # Final normalization
+    final_norm = backbone.get_layer("final_normalization").weights[0]
+    weights_dict[f"{prefix}norm.weight"] = final_norm
+
+    if include_lm_head and not token_embedding_layer.tie_weights:
+        weights_dict["lm_head.weight"] = ops.transpose(
+            token_embedding_layer.reverse_embeddings
+        )
+
+    return weights_dict
+
+
+def get_gemma3_tokenizer_config(tokenizer):
+    tokenizer_config = {
+        "tokenizer_class": "GemmaTokenizer",
+        "clean_up_tokenization_spaces": False,
+        "bos_token": "<bos>",
+        "eos_token": "<eos>",
+        "pad_token": "<pad>",
+        "unk_token": "<unk>",
+        "add_bos_token": True,
"add_eos_token": False, + "model_max_length": 32768, + } + # Add added_tokens_decoder + added_tokens_decoder = {} + special_tokens = [ + "", + "", + "", + "", + "", + "", + "", + ] + for token in special_tokens: + token_id = tokenizer.token_to_id(token) + if token_id is not None: + added_tokens_decoder[str(token_id)] = { + "content": token, + "special": True, + "single_word": False, + "lstrip": False, + "rstrip": False, + "normalized": False, + } + tokenizer_config["added_tokens_decoder"] = added_tokens_decoder + return tokenizer_config diff --git a/keras_hub/src/utils/transformers/export/gemma3_test.py b/keras_hub/src/utils/transformers/export/gemma3_test.py new file mode 100644 index 0000000000..926d54edfc --- /dev/null +++ b/keras_hub/src/utils/transformers/export/gemma3_test.py @@ -0,0 +1,161 @@ +import os + +import numpy as np +from transformers import AutoModel +from transformers import AutoModelForCausalLM +from transformers import AutoTokenizer + +from keras_hub.src.models.gemma3.gemma3_backbone import Gemma3Backbone +from keras_hub.src.models.gemma3.gemma3_causal_lm import Gemma3CausalLM +from keras_hub.src.models.gemma3.gemma3_causal_lm_preprocessor import ( + Gemma3CausalLMPreprocessor, +) +from keras_hub.src.models.gemma3.gemma3_tokenizer import Gemma3Tokenizer +from keras_hub.src.tests.test_case import TestCase + + +class TestGemma3Export(TestCase): + def test_export_to_hf(self): + proto = os.path.join(self.get_test_data_dir(), "gemma3_test_vocab.spm") + tokenizer = Gemma3Tokenizer(proto=proto) + + # Create a small backbone (text-only, no vision encoder) + backbone = Gemma3Backbone( + vocabulary_size=tokenizer.vocabulary_size(), + image_size=896, # Default value even for text-only + num_layers=2, + num_query_heads=4, + num_key_value_heads=1, + hidden_dim=512, + intermediate_dim=1028, + head_dim=128, + query_head_dim_normalize=True, + use_query_key_norm=True, + use_post_ffw_norm=True, # Real Gemma3 models have these + use_post_attention_norm=True, # Real Gemma3 models have these + attention_logit_soft_cap=None, + final_logit_soft_cap=None, + use_sliding_window_attention=False, + sliding_window_size=4096, + vision_encoder=None, # Text-only model for testing + layer_norm_epsilon=1e-6, + dropout=0, + ) + + # Create preprocessor + preprocessor = Gemma3CausalLMPreprocessor(tokenizer=tokenizer) + + # Create the causal LM model + keras_model = Gemma3CausalLM( + backbone=backbone, preprocessor=preprocessor + ) + + # Set all weights to random values + rng = np.random.default_rng(42) + weights = keras_model.get_weights() + for i in range(len(weights)): + weights[i] = rng.random(weights[i].shape).astype(weights[i].dtype) + keras_model.set_weights(weights) + + # Export to Hugging Face format using the new methods + export_path_backbone = os.path.join( + self.get_temp_dir(), "export_backbone" + ) + backbone.export_to_transformers(export_path_backbone) + + export_path_tokenizer = os.path.join( + self.get_temp_dir(), "export_tokenizer" + ) + preprocessor.tokenizer.export_to_transformers(export_path_tokenizer) + + export_path_task = os.path.join(self.get_temp_dir(), "export_task") + keras_model.export_to_transformers(export_path_task) + + # Load Hugging Face models and tokenizer + # Note: We only test the slow tokenizer because the test vocab file + # may not be compatible with fast tokenizer conversion + hf_backbone = AutoModel.from_pretrained(export_path_backbone) + hf_tokenizer_slow = AutoTokenizer.from_pretrained( + export_path_tokenizer, use_fast=False + ) + hf_full_model = 
+
+        # Verify configuration
+        hf_config = hf_backbone.config
+        self.assertEqual(
+            hf_config.vocab_size,
+            backbone.vocabulary_size,
+            "Vocabulary sizes do not match",
+        )
+        self.assertEqual(
+            hf_config.num_hidden_layers,
+            backbone.num_layers,
+            "Number of layers does not match",
+        )
+        self.assertEqual(
+            hf_config.num_attention_heads,
+            backbone.num_query_heads,
+            "Number of query heads does not match",
+        )
+        self.assertEqual(
+            hf_config.num_key_value_heads,
+            backbone.num_key_value_heads,
+            "Number of key value heads does not match",
+        )
+        self.assertEqual(
+            hf_config.hidden_size,
+            backbone.hidden_dim,
+            "Hidden dimensions do not match",
+        )
+        self.assertEqual(
+            hf_config.intermediate_size,
+            backbone.intermediate_dim,
+            "Intermediate sizes do not match",
+        )
+        self.assertEqual(
+            hf_config.head_dim,
+            backbone.head_dim,
+            "Head dimensions do not match",
+        )
+        self.assertEqual(
+            hf_config.max_position_embeddings,
+            32768,
+            "Max position embeddings do not match",
+        )
+        self.assertEqual(
+            hf_config.tie_word_embeddings,
+            backbone.token_embedding.tie_weights,
+            "Tie word embeddings do not match",
+        )
+
+        # Verify tokenizer compatibility (using slow tokenizer)
+        self.assertEqual(
+            hf_tokenizer_slow.vocab_size,
+            tokenizer.vocabulary_size(),
+            "Tokenizer vocabulary sizes do not match",
+        )
+
+        # Compare generated outputs using full model
+        prompt = "the quick"
+
+        # Generate with Keras model
+        keras_output = keras_model.generate(prompt, max_length=20)
+
+        # Generate with HuggingFace model using slow tokenizer
+        input_ids_slow = hf_tokenizer_slow.encode(prompt, return_tensors="pt")
+        output_ids_slow = hf_full_model.generate(
+            input_ids_slow, max_length=20, do_sample=False
+        )
+        hf_slow_output = hf_tokenizer_slow.decode(
+            output_ids_slow[0], skip_special_tokens=True
+        )
+
+        # Debug print to see the actual outputs
+        print(f"Keras output: '{keras_output}'")
+        print(f"HF slow output: '{hf_slow_output}'")
+
+        self.assertEqual(
+            keras_output,
+            hf_slow_output,
+            "Generated outputs do not match",
+        )
diff --git a/keras_hub/src/utils/transformers/export/hf_exporter.py b/keras_hub/src/utils/transformers/export/hf_exporter.py
index 1593987ca9..b3a55fb27b 100644
--- a/keras_hub/src/utils/transformers/export/hf_exporter.py
+++ b/keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -10,19 +10,29 @@
     get_gemma_tokenizer_config,
 )
 from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map
+from keras_hub.src.utils.transformers.export.gemma3 import get_gemma3_config
+from keras_hub.src.utils.transformers.export.gemma3 import (
+    get_gemma3_tokenizer_config,
+)
+from keras_hub.src.utils.transformers.export.gemma3 import (
+    get_gemma3_weights_map,
+)
 
 MODEL_CONFIGS = {
     "GemmaBackbone": get_gemma_config,
+    "Gemma3Backbone": get_gemma3_config,
     # Add for future models, e.g., "MistralBackbone": get_mistral_config
 }
 
 MODEL_EXPORTERS = {
     "GemmaBackbone": get_gemma_weights_map,
+    "Gemma3Backbone": get_gemma3_weights_map,
     # Add for future models, e.g., "MistralBackbone": get_mistral_weights_map
 }
 
 MODEL_TOKENIZER_CONFIGS = {
     "GemmaTokenizer": get_gemma_tokenizer_config,
+    "Gemma3Tokenizer": get_gemma3_tokenizer_config,
     # Add for future models, e.g., "MistralTokenizer":
     # get_mistral_tokenizer_config
 }
diff --git a/tools/checkpoint_conversion/convert_gemma3_checkpoints.py b/tools/checkpoint_conversion/convert_gemma3_checkpoints.py
index 2105ae4ff1..c7f4c24603 100644
--- a/tools/checkpoint_conversion/convert_gemma3_checkpoints.py
+++ b/tools/checkpoint_conversion/convert_gemma3_checkpoints.py
@@ -10,8 +10,10 @@
 Usage:
 ```shell
 cd tools/checkpoint_conversion
-python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b
-python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b
+python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b \
+    --export_safetensors
+python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b \
+    --export_safetensors
 ```
 """
 
@@ -21,16 +23,202 @@
 # No GPU for conversion, makes memory management easier.
 os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
 
+import json
+import shutil
+
 import keras  # noqa: E402
+import numpy as np
 import tensorflow_datasets as tfds  # noqa: E402
+import torch
+import transformers
 from absl import app  # noqa: E402
 from absl import flags  # noqa: E402
 from checkpoint_conversion_utils import download_gcs_file
 from gemma import gm  # noqa: E402
 from keras import ops  # noqa: E402
+from safetensors.torch import save_file
+from transformers import AutoModelForCausalLM
+from transformers import AutoTokenizer
 
 import keras_hub  # noqa: E402
 
+
+def convert_to_hf_config(keras_config):
+    """Convert a Keras Gemma3 config to a Hugging Face Gemma3TextConfig.
+
+    Args:
+        keras_config: A Keras Gemma3 config object from the backbone.
+
+    Returns:
+        A `transformers.Gemma3TextConfig` instance.
+    """
+    hf_config = transformers.Gemma3TextConfig(
+        vocab_size=keras_config.vocabulary_size,
+        num_hidden_layers=keras_config.num_layers,
+        num_attention_heads=keras_config.num_query_heads,
+        num_key_value_heads=keras_config.num_key_value_heads,
+        hidden_size=keras_config.hidden_dim,
+        intermediate_size=keras_config.intermediate_dim,
+        head_dim=keras_config.head_dim,
+        max_position_embeddings=32768,
+    )
+    return hf_config
+
+
+def export_to_hf(backbone, keras_tokenizer, path):
+    """Convert a Keras Gemma3 model to Hugging Face format and save to path.
+
+    Args:
+        backbone: A `keras_hub.models.Gemma3Backbone` instance.
+        keras_tokenizer: A `keras_hub.models.Gemma3Tokenizer` instance.
+        path: str. The path to save the Hugging Face model to.
+    """
+    hf_config = convert_to_hf_config(backbone)
+    weights_dict = {}
+
+    # Helper function to convert weights to bfloat16 torch tensors
+    def to_torch(weight):
+        # Convert array-like weights (e.g., from JAX) to a float32 NumPy
+        # array before creating a bfloat16 torch tensor for compatibility.
+        np_weight = np.array(weight, dtype=np.float32)
+        return torch.from_numpy(np_weight).to(torch.bfloat16)
+
+    # Token embeddings
+    token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
+    weights_dict["model.embed_tokens.weight"] = to_torch(token_embedding)
+
+    for i in range(backbone.num_layers):
+        block = backbone.get_layer(f"decoder_block_{i}")
+        q_kernel = block.attention.query_dense.get_weights()[0]
+        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = (
+            to_torch(q_kernel)
+            .permute(1, 0, 2)
+            .reshape(backbone.hidden_dim, -1)
+            .T
+        )
+
+        k_kernel = block.attention.key_dense.get_weights()[0]
+        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
+            to_torch(k_kernel)
+            .permute(1, 0, 2)
+            .reshape(backbone.hidden_dim, -1)
+            .T
+        )
+
+        v_kernel = block.attention.value_dense.get_weights()[0]
+        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
+            to_torch(v_kernel)
+            .permute(1, 0, 2)
+            .reshape(backbone.hidden_dim, -1)
+            .T
+        )
+
+        o_kernel = block.attention.output_dense.get_weights()[0]
+        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = (
+            to_torch(o_kernel).permute(2, 0, 1).reshape(backbone.hidden_dim, -1)
+        )
+
+        q_norm = block.attention.query_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.self_attn.q_norm.weight"] = to_torch(
+            q_norm
+        )
+
+        k_norm = block.attention.key_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.self_attn.k_norm.weight"] = to_torch(
+            k_norm
+        )
+
+        gate_kernel = block.gating_ffw.get_weights()[0]
+        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = to_torch(
+            gate_kernel
+        ).T
+
+        up_kernel = block.gating_ffw_2.get_weights()[0]
+        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = to_torch(
+            up_kernel
+        ).T
+
+        down_kernel = block.ffw_linear.get_weights()[0]
+        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = to_torch(
+            down_kernel
+        ).T
+
+        input_layer_norm = block.pre_attention_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = to_torch(
+            input_layer_norm
+        )
+
+        post_attn_norm = block.post_attention_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
+            to_torch(post_attn_norm)
+        )
+
+        pre_feedforward_layernorm_weight = block.pre_ffw_norm.get_weights()[0]
+        weights_dict[f"model.layers.{i}.pre_feedforward_layernorm.weight"] = (
+            to_torch(pre_feedforward_layernorm_weight)
+        )
+
+        post_feedforward_layernorm_weight = block.post_ffw_norm.get_weights()[
+            0
+        ]
+        weights_dict[f"model.layers.{i}.post_feedforward_layernorm.weight"] = (
+            to_torch(post_feedforward_layernorm_weight)
+        )
+
+    final_norm = backbone.get_layer("final_normalization").get_weights()[0]
+    weights_dict["model.norm.weight"] = to_torch(final_norm)
+    weights_dict["lm_head.weight"] = weights_dict[
+        "model.embed_tokens.weight"
+    ].clone()
+
+    os.makedirs(path, exist_ok=True)
+    with open(os.path.join(path, "config.json"), "w") as f:
+        json.dump(hf_config.to_dict(), f)
+    weights_dict = {k: v.contiguous() for k, v in weights_dict.items()}
+    save_file(weights_dict, os.path.join(path, "model.safetensors"))
+    keras_tokenizer.save_assets(path)
+    vocab_spm = os.path.join(path, "vocabulary.spm")
+    tokenizer_model = os.path.join(path, "tokenizer.model")
+    if os.path.exists(vocab_spm):
+        shutil.move(vocab_spm, tokenizer_model)
+    print("Export complete! Files saved in:", path)
+
+
+def load_hf_model(model_name, device):
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForCausalLM.from_pretrained(model_name)
+    model.to(device)
+    model.eval()
+    return model, tokenizer
+
+
+def infer(
+    model,
+    tokenizer,
+    prompt,
+    device,
+    max_new_tokens=30,
+    temperature=1.0,
+    top_k=50,
+    top_p=1.0,
+):
+    # Tokenize input
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+
+    # Generate output
+    with torch.no_grad():
+        outputs = model.generate(
+            **inputs,
+            max_new_tokens=max_new_tokens,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            do_sample=False,
+        )
+
+    # Decode generated tokens
+    generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
+    return generated_text
+
+
 keras.utils.set_random_seed(42)
 
 FLAGS = flags.FLAGS
@@ -126,6 +314,13 @@
     required=True,
 )
 
+flags.DEFINE_bool(
+    "export_safetensors",
+    False,
+    "Export model to Safetensors format (HuggingFace-compatible). "
+    "Only for text-only models.",
+)
+
 
 def convert_model(flax_config, text_only):
     vision_encoder = None
@@ -558,6 +753,35 @@ def main(_):
     keras_tokenizer.save_to_preset(preset)
     print(f"🏁 Preset saved to ./{preset}")
 
+    if FLAGS.export_safetensors and text_only:
+        export_dir = f"./{preset}_safetensors_export"
+        print(
+            f"🏃 Exporting to Safetensors (HuggingFace format) at {export_dir}"
+        )
+        export_to_hf(keras_model, keras_tokenizer, export_dir)
+        print(f"🏁 Safetensors export complete: {export_dir}")
+
+        local_hf_model, local_hf_tokenizer = load_hf_model(
+            export_dir, device="cpu"
+        )
+        print("Local Hugging Face model loaded successfully!")
+
+        print(
+            "🔶 Safetensors output:",
+            infer(
+                local_hf_model,
+                local_hf_tokenizer,
+                "What is Keras?",
+                "cpu",
+                max_new_tokens=100,
+            ),
+        )
+    elif FLAGS.export_safetensors:
+        print(
+            "⚠️ Safetensors export is only supported for text-only models. "
+            "Skipping export."
+        )
+
 
 if __name__ == "__main__":
     app.run(main)