Skip to content

Commit 3090486

Browse files
authored
Export tokenizer.model for parakeet export (#16494)
1 parent 8e3db7f commit 3090486

File tree

1 file changed

+51
-0
lines changed

1 file changed

+51
-0
lines changed

examples/models/parakeet/export_parakeet_tdt.py

Lines changed: 51 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,9 @@
22

33
import argparse
44
import os
5+
import shutil
6+
import tarfile
7+
import tempfile
58

69
import torch
710
import torchaudio
@@ -188,6 +191,51 @@ def load_model():
188191
return model
189192

190193

194+
def extract_tokenizer(output_dir: str) -> str | None:
    """Extract tokenizer.model from the cached .nemo file.

    The .nemo checkpoint path is obtained from ``hf_hub_download`` and the
    file is opened as a tar archive. The first regular-file member whose name
    ends with ``tokenizer.model`` is streamed directly to
    ``<output_dir>/tokenizer.model``; nothing is unpacked to a path derived
    from the archive's own member names.

    Args:
        output_dir: Directory to save the tokenizer.model file. It must
            already exist.

    Returns:
        Path to the extracted tokenizer.model, or None if the archive does
        not contain a matching file.
    """
    from huggingface_hub import hf_hub_download

    # Download/get cached .nemo file path
    nemo_path = hf_hub_download(
        repo_id="nvidia/parakeet-tdt-0.6b-v3",
        filename="parakeet-tdt-0.6b-v3.nemo",
    )

    tokenizer_filename = "tokenizer.model"
    output_path = os.path.join(output_dir, tokenizer_filename)

    # .nemo files are tar archives - locate tokenizer.model inside.
    with tarfile.open(nemo_path, "r") as tar:
        # Suffix match: the file may be in the root or a subdirectory, and
        # endswith() also accepts names that carry a prefix. Iterating the
        # TarFile stops at the first hit instead of listing every member,
        # and isfile() skips directories/special entries that cannot be read.
        tokenizer_member = next(
            (
                member
                for member in tar
                if member.isfile() and member.name.endswith(tokenizer_filename)
            ),
            None,
        )

        if tokenizer_member is None:
            print(f"Warning: {tokenizer_filename} not found in .nemo archive")
            return None

        # Read the member's bytes through a file object rather than
        # tar.extract(): the destination is always output_path, so a hostile
        # member name (absolute path, "..") cannot write elsewhere, and no
        # temporary directory is needed.
        source = tar.extractfile(tokenizer_member)
        if source is None:
            # Defensive: extractfile() returns None for non-file members.
            print(f"Warning: {tokenizer_filename} not found in .nemo archive")
            return None

        # NOTE: only the file contents are copied; mode bits and mtime stored
        # in the archive are not applied to the output file.
        with source, open(output_path, "wb") as destination:
            shutil.copyfileobj(source, destination)

    print(f"Extracted tokenizer to: {output_path}")
    return output_path
237+
238+
191239
class JointAfterProjection(torch.nn.Module):
192240
def __init__(self, joint):
193241
super().__init__()
@@ -401,6 +449,9 @@ def main():
401449

402450
os.makedirs(args.output_dir, exist_ok=True)
403451

452+
print("Extracting tokenizer...")
453+
extract_tokenizer(args.output_dir)
454+
404455
print("Loading model...")
405456
model = load_model()
406457

0 commit comments

Comments
 (0)