235 changes: 233 additions & 2 deletions tools/checkpoint_conversion/convert_gemma3_checkpoints.py
@@ -10,8 +10,10 @@
Usage:
```shell
cd tools/checkpoint_conversion
python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b
python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b
python convert_gemma3_checkpoints.py --preset gemma3_instruct_1b \
--export_safetensors
python convert_gemma3_checkpoints.py --preset gemma3_instruct_4b \
--export_safetensors
```
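
The exported directory can then be loaded back with `transformers` as a quick
sanity check (a minimal sketch; `./<preset>_safetensors_export` is the export
directory this script creates):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

export_dir = "./gemma3_instruct_1b_safetensors_export"
model = AutoModelForCausalLM.from_pretrained(export_dir)
tokenizer = AutoTokenizer.from_pretrained(export_dir)
```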
"""

@@ -21,16 +23,209 @@
# No GPU for conversion, makes memory management easier.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import json
import shutil

import keras # noqa: E402
import numpy as np
import tensorflow_datasets as tfds # noqa: E402
import torch
import transformers
from absl import app # noqa: E402
from absl import flags # noqa: E402
from checkpoint_conversion_utils import download_gcs_file
from gemma import gm # noqa: E402
from keras import ops # noqa: E402
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM
from transformers import AutoTokenizer

import keras_hub # noqa: E402


def convert_to_hf_config(keras_config):
"""Convert Keras Gemma config to Hugging Face GemmaConfig."""
hf_config = transformers.Gemma3TextConfig(
vocab_size=keras_config.vocabulary_size,
num_hidden_layers=keras_config.num_layers,
num_attention_heads=keras_config.num_query_heads,
num_key_value_heads=keras_config.num_key_value_heads,
hidden_size=keras_config.hidden_dim,
intermediate_size=keras_config.intermediate_dim,
head_dim=keras_config.head_dim,
max_position_embeddings=32768,
)
return hf_config


def export_to_hf(backbone, keras_tokenizer, path):
"""Convert a Keras Gemma model to Hugging Face format and save to path."""

hf_config = convert_to_hf_config(backbone)
weights_dict = {}

# Helper function to convert bfloat16 weights to torch tensors
def to_torch(weight):
# Convert bfloat16 to float32 first, then to torch, then to bfloat16
if hasattr(weight.dtype, "name") and "bfloat16" in str(weight.dtype):
weight = np.array(weight, dtype=np.float32)
return torch.from_numpy(weight).to(torch.bfloat16)

# Token embeddings
token_embedding = backbone.get_layer("token_embedding").get_weights()[0]
weights_dict["model.embed_tokens.weight"] = to_torch(token_embedding)

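    # Per-layer weights. Keras keeps the attention kernels factored by head
    # (roughly (hidden_dim, num_heads, head_dim)), while HF `nn.Linear`
    # weights are (out_features, in_features); the permute/reshape/transpose
    # calls below collapse the head axes into that 2-D layout.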
for i in range(backbone.num_layers):
block = backbone.get_layer(f"decoder_block_{i}")
q_kernel = block.attention.query_dense.get_weights()[0]
q_kernel = (
torch.from_numpy(np.array(q_kernel, dtype=np.float32))
.to(torch.bfloat16)
.permute(1, 0, 2)
.reshape(backbone.hidden_dim, -1)
.T
)
weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = q_kernel

k_kernel = block.attention.key_dense.get_weights()[0]
k_kernel = (
torch.from_numpy(np.array(k_kernel, dtype=np.float32))
.to(torch.bfloat16)
.permute(1, 0, 2)
.reshape(backbone.hidden_dim, -1)
.T
)
weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = k_kernel

v_kernel = block.attention.value_dense.get_weights()[0]
v_kernel = (
torch.from_numpy(np.array(v_kernel, dtype=np.float32))
.to(torch.bfloat16)
.permute(1, 0, 2)
.reshape(backbone.hidden_dim, -1)
.T
)
weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = v_kernel

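        # The output projection maps the per-head dims back to hidden_dim,
        # so it needs a different permute and no final transpose to reach
        # the (out_features, in_features) layout.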
o_kernel = block.attention.output_dense.get_weights()[0]
o_kernel = (
torch.from_numpy(np.array(o_kernel, dtype=np.float32))
.to(torch.bfloat16)
.permute(2, 0, 1)
.reshape(backbone.hidden_dim, -1)
)
weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = o_kernel

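        # RMSNorm scales are 1-D vectors and copy over without reshaping.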
q_norm = block.attention.query_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.self_attn.q_norm.weight"] = to_torch(
q_norm
)

k_norm = block.attention.key_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.self_attn.k_norm.weight"] = to_torch(
k_norm
)

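        # MLP kernels are plain 2-D matrices stored as (in_features,
        # out_features) in Keras; a transpose gives the HF Linear layout.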
gate_kernel = block.gating_ffw.get_weights()[0]
gate_kernel = (
torch.from_numpy(np.array(gate_kernel, dtype=np.float32))
.to(torch.bfloat16)
.T
)
weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = gate_kernel

up_kernel = block.gating_ffw_2.get_weights()[0]
up_kernel = (
torch.from_numpy(np.array(up_kernel, dtype=np.float32))
.to(torch.bfloat16)
.T
)
weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = up_kernel

down_kernel = block.ffw_linear.get_weights()[0]
down_kernel = (
torch.from_numpy(np.array(down_kernel, dtype=np.float32))
.to(torch.bfloat16)
.T
)
weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = down_kernel

input_layer_norm = block.pre_attention_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.input_layernorm.weight"] = to_torch(
input_layer_norm
)

post_attn_norm = block.post_attention_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
to_torch(post_attn_norm)
)

pre_feedforward_layernorm_weight = block.pre_ffw_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.pre_feedforward_layernorm.weight"] = (
to_torch(pre_feedforward_layernorm_weight)
)

post_feedforward_layernorm_weight = block.post_ffw_norm.get_weights()[0]
weights_dict[f"model.layers.{i}.post_feedforward_layernorm.weight"] = (
to_torch(post_feedforward_layernorm_weight)
)

final_norm = backbone.get_layer("final_normalization").get_weights()[0]
weights_dict["model.norm.weight"] = to_torch(final_norm)
weights_dict["lm_head.weight"] = weights_dict[
"model.embed_tokens.weight"
].clone()

os.makedirs(path, exist_ok=True)
with open(os.path.join(path, "config.json"), "w") as f:
json.dump(hf_config.to_dict(), f)
weights_dict = {k: v.contiguous() for k, v in weights_dict.items()}
save_file(weights_dict, os.path.join(path, "model.safetensors"))
keras_tokenizer.save_assets(path)
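    # The Keras tokenizer assets include the SentencePiece model as
    # `vocabulary.spm`; rename it to `tokenizer.model`, the filename the
    # HF Gemma tokenizer expects.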
vocab_spm = os.path.join(path, "vocabulary.spm")
tokenizer_model = os.path.join(path, "tokenizer.model")
if os.path.exists(vocab_spm):
shutil.move(vocab_spm, tokenizer_model)
print("Export complete! Files saved in:", path)


def load_hf_model(model_name, device):
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model.to(device)
model.eval()
return model, tokenizer


def infer(
model,
tokenizer,
prompt,
device,
max_new_tokens=30,
temperature=1.0,
top_k=50,
top_p=1.0,
):
    # Tokenize input
inputs = tokenizer(prompt, return_tensors="pt").to(device)

# Generate output
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=max_new_tokens,
temperature=temperature,
top_k=top_k,
top_p=top_p,
do_sample=True,
)

# Decode generated tokens
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=False)
return generated_text


keras.utils.set_random_seed(42)

FLAGS = flags.FLAGS
@@ -126,6 +321,13 @@
required=True,
)

flags.DEFINE_bool(
"export_safetensors",
False,
"Export model to Safetensors format (HuggingFace-compatible). \
Only for text-only models.",
)


def convert_model(flax_config, text_only):
vision_encoder = None
@@ -558,6 +760,35 @@ def main(_):
keras_tokenizer.save_to_preset(preset)
print(f"🏁 Preset saved to ./{preset}")

if FLAGS.export_safetensors and text_only:
export_dir = f"./{preset}_safetensors_export"
print(
f"🏃 Exporting to Safetensors (HuggingFace format) at {export_dir}"
)
export_to_hf(keras_model, keras_tokenizer, export_dir)
print(f"🏁 Safetensors export complete: {export_dir}")

local_hf_model, local_hf_tokenizer = load_hf_model(
export_dir, device="cpu"
)
print("Local Hugging Face model loaded successfully!")

print(
"🔶 Safetensors output:",
infer(
local_hf_model,
local_hf_tokenizer,
"Hello, my name is",
"cpu",
max_new_tokens=100,
),
)
elif FLAGS.export_safetensors:
print(
"⚠️ Safetensors export is only supported for text-only models. \
Skipping export."
)


if __name__ == "__main__":
app.run(main)