
import argparse
import os
+import shutil
+import tarfile
+import tempfile

import torch
import torchaudio
@@ -188,6 +191,51 @@ def load_model():
    return model


+def extract_tokenizer(output_dir: str) -> str | None:
+    """Extract tokenizer.model from the cached .nemo file.
+
+    Args:
+        output_dir: Directory to save the tokenizer.model file.
+
+    Returns:
+        Path to the extracted tokenizer.model, or None if extraction failed.
+    """
+    from huggingface_hub import hf_hub_download
+
+    # Download/get cached .nemo file path
+    nemo_path = hf_hub_download(
+        repo_id="nvidia/parakeet-tdt-0.6b-v3",
+        filename="parakeet-tdt-0.6b-v3.nemo",
+    )
+
+    # .nemo files are tar archives - extract tokenizer.model
+    tokenizer_filename = "tokenizer.model"
+    output_path = os.path.join(output_dir, tokenizer_filename)
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        with tarfile.open(nemo_path, "r") as tar:
+            # Find tokenizer.model in the archive (may be in root or subdirectory)
+            tokenizer_member = None
+            for member in tar.getmembers():
+                if member.name.endswith(tokenizer_filename):
+                    tokenizer_member = member
+                    break
+
+            if tokenizer_member is None:
+                print(f"Warning: {tokenizer_filename} not found in .nemo archive")
+                return None
+
+            # Extract to temp directory
+            tar.extract(tokenizer_member, tmpdir)
+            extracted_path = os.path.join(tmpdir, tokenizer_member.name)
+
+            # Copy to output directory
+            shutil.copy2(extracted_path, output_path)
+
+            print(f"Extracted tokenizer to: {output_path}")
+            return output_path
+
+
class JointAfterProjection(torch.nn.Module):
    def __init__(self, joint):
        super().__init__()
@@ -401,6 +449,9 @@ def main():

    os.makedirs(args.output_dir, exist_ok=True)

+    print("Extracting tokenizer...")
+    extract_tokenizer(args.output_dir)
+
    print("Loading model...")
    model = load_model()

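As a quick sanity check on the extracted file: tokenizer.model in NeMo ASR checkpoints is normally a SentencePiece model, so, assuming that holds for parakeet-tdt-0.6b-v3, it can be loaded and inspected with the sentencepiece package. The path below is a placeholder for whatever directory the script received as args.output_dir.

import sentencepiece as spm

# Placeholder path: the directory the script received as args.output_dir.
sp = spm.SentencePieceProcessor(model_file="output/tokenizer.model")
print("vocab size:", sp.get_piece_size())
print("sample pieces:", [sp.id_to_piece(i) for i in range(5)])
print("encoded:", sp.encode("hello world", out_type=str))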