|
@@ -5,22 +5,23 @@
 import traceback
 
 import numpy as np
-import requests
 from absl import app
 from absl import flags
 from keras import ops
 from transformers import AutoTokenizer
 from transformers import MistralForCausalLM
 
 from keras_hub.models import MistralBackbone
+from keras_hub.models import MistralCausalLM
 from keras_hub.models import MistralCausalLMPreprocessor
 from keras_hub.models import MistralTokenizer
-from keras_hub.utils.preset_utils import save_to_preset
 
 PRESET_MAP = {
     "mistral_7b_en": "mistralai/Mistral-7B-v0.1",
+    "mistral_0.3_7b_en": "mistralai/Mistral-7B-v0.3",
     "mistral_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.1",
     "mistral_0.2_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.2",
+    "mistral_0.3_instruct_7b_en": "mistralai/Mistral-7B-Instruct-v0.3",
 }
 
 FLAGS = flags.FLAGS
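With the tokenizer change in the hunk below, the manual download via `requests` is gone: `MistralTokenizer.from_preset` can pull the vocabulary straight from the Hugging Face Hub using an `hf://` handle. A minimal sketch of that loading path, using one of the newly added v0.3 checkpoints as an illustrative handle:

```python
from keras_hub.models import MistralTokenizer

# The "hf://" prefix tells KerasHub to resolve the handle against the
# Hugging Face Hub rather than the built-in preset registry.
tokenizer = MistralTokenizer.from_preset("hf://mistralai/Mistral-7B-v0.3")

# Round-trip a string as a quick sanity check of the vocabulary.
token_ids = tokenizer("The quick brown fox")
print(tokenizer.detokenize(token_ids))
```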
@@ -236,49 +237,43 @@ def main(_):
             rope_max_wavelength=hf_model.config.rope_theta,
             dtype="float32",
         )
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)
 
-        # === Download the tokenizer from Huggingface model card ===
-        spm_path = (
-            f"https://huggingface.co/{hf_preset}/resolve/main/tokenizer.model"
-        )
-        response = requests.get(spm_path)
-        if not response.ok:
-            raise ValueError(f"Couldn't fetch {preset}'s tokenizer.")
-        tokenizer_path = os.path.join(temp_dir, "vocabulary.spm")
-        with open(tokenizer_path, "wb") as tokenizer_file:
-            tokenizer_file.write(response.content)
-        keras_hub_tokenizer = MistralTokenizer(tokenizer_path)
+        keras_hub_tokenizer = MistralTokenizer.from_preset(f"hf://{hf_preset}")
         print("\n-> Keras 3 model and tokenizer loaded.")
 
         # === Port the weights ===
-        convert_checkpoints(keras_hub_model, hf_model)
+        convert_checkpoints(keras_hub_backbone, hf_model)
         print("\n-> Weight transfer done.")
 
         # === Check that the models and tokenizers outputs match ===
         test_tokenizer(keras_hub_tokenizer, hf_tokenizer)
-        test_model(keras_hub_model, keras_hub_tokenizer, hf_model, hf_tokenizer)
+        test_model(
+            keras_hub_backbone, keras_hub_tokenizer, hf_model, hf_tokenizer
+        )
         print("\n-> Tests passed!")
 
         # === Save the model weights in float32 format ===
-        keras_hub_model.save_weights(os.path.join(temp_dir, "model.weights.h5"))
+        keras_hub_backbone.save_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
         print("\n-> Saved the model weights in float32")
 
-        del keras_hub_model, hf_model
+        del keras_hub_backbone, hf_model
        gc.collect()
 
         # === Save the weights again in float16 ===
         backbone_kwargs["dtype"] = "float16"
-        keras_hub_model = MistralBackbone(**backbone_kwargs)
-        keras_hub_model.load_weights(os.path.join(temp_dir, "model.weights.h5"))
-        save_to_preset(keras_hub_model, preset)
+        keras_hub_backbone = MistralBackbone(**backbone_kwargs)
+        keras_hub_backbone.load_weights(
+            os.path.join(temp_dir, "model.weights.h5")
+        )
+
+        preprocessor = MistralCausalLMPreprocessor(keras_hub_tokenizer)
+        keras_hub_model = MistralCausalLM(keras_hub_backbone, preprocessor)
+        keras_hub_model.save_to_preset(f"./{preset}")
         print("\n-> Saved the model preset in float16")
 
-        # === Save the tokenizer ===
-        save_to_preset(
-            keras_hub_tokenizer, preset, config_filename="tokenizer.json"
-        )
-        print("\n-> Saved the tokenizer")
     finally:
         shutil.rmtree(temp_dir)
 
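Instead of calling the standalone `save_to_preset` utility on the backbone and tokenizer separately, the script now wraps the float16 backbone in a `MistralCausalLM` task with its preprocessor and saves everything through the model's own `save_to_preset` method, so the output directory is a single self-contained preset. A minimal sketch of reloading and exercising the converted preset; the local path and prompt are illustrative assumptions (the path matches whatever `--preset` value was used for the conversion):

```python
from keras_hub.models import MistralCausalLM

# Load the task model back from the directory the conversion script wrote.
causal_lm = MistralCausalLM.from_preset("./mistral_7b_en")

# The bundled preprocessor lets generation run directly on raw strings.
print(causal_lm.generate("Keras is a", max_length=30))
```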
|
|