|
| 1 | +import os |
| 2 | +import re |
| 3 | +from timeit import default_timer as timer |
| 4 | +import wave |
| 5 | +import argparse |
| 6 | + |
# Command-line interface for the inference test run.
parser = argparse.ArgumentParser(description="Running Whisper TFlite test inference.")
parser.add_argument("-f", "--folder", default="../test-files/", help="Folder with WAV input files")
parser.add_argument("-m", "--model", default="models/whisper-tiny-ct2", help="Path to model")
parser.add_argument("-l", "--lang", default="auto", help="Language used (default: auto)")
# type=int validates numeric options at parse time instead of failing later
# with a ValueError deep inside the transcription code.
parser.add_argument("-t", "--threads", type=int, default=2, help="Threads used (default: 2)")
parser.add_argument("-b", "--beamsize", type=int, default=1, help="Beam size used (default: 1)")
args = parser.parse_args()
| 14 | + |
# Import is deliberately delayed until after argument parsing so that
# `--help` stays fast and arg errors surface before the heavy import.
print('Importing WhisperModel')
from faster_whisper import WhisperModel

# Run on CPU with INT8 quantization; thread count comes from the CLI.
model_path = args.model
print(f'\nLoading model {model_path} ...')
model = WhisperModel(model_path, device="cpu", compute_type="int8", cpu_threads=int(args.threads))
# NOTE: alternative GPU configuration (kept for reference):
#model = WhisperModel(args.model, device="cuda", compute_type="float16")
print(f'Threads: {args.threads}')
print(f'Beam size: {args.beamsize}')
| 25 | + |
def transcribe(audio_file):
    """Transcribe a single WAV file with the module-level Whisper model.

    Validates that the file is 16 kHz mono 16-bit PCM, picks a language
    (from the file name prefix ``xx_``, the ``--lang`` option, or model
    auto-detection), prints the resulting segments, and reports timing.

    Parameters:
        audio_file: Path to a WAV file (mono, 16-bit PCM, 16 kHz).

    Raises:
        SystemExit: If the WAV file does not match the required format.
    """
    print(f'\nLoading audio file: {audio_file}')
    # `with` guarantees the wave handle is closed even when the format
    # check below aborts the run (the original leaked it on that path).
    with wave.open(audio_file, "rb") as wf:
        sample_rate_orig = wf.getframerate()
        audio_length = wf.getnframes() * (1 / sample_rate_orig)
        if (wf.getnchannels() != 1 or wf.getsampwidth() != 2
            or wf.getcomptype() != "NONE" or sample_rate_orig != 16000):
            print("Audio file must be WAV format mono PCM.")
            raise SystemExit(1)
    print(f'Samplerate: {sample_rate_orig}, length: {audio_length}s')

    # Optional language hint from the file name: a leading "xx_" component
    # (e.g. "en_test.wav" or ".../de_sample.wav"). NOTE(review): the
    # pattern matches "/" separators only — Windows "\" paths won't match.
    file_lang = None
    lang_search = re.findall(r"(?:^|/)(\w\w)_", audio_file)
    if len(lang_search) > 0:
        file_lang = lang_search.pop()

    inference_start = timer()

    print("\nTranscribing ...")
    beam_size = int(args.beamsize)  # cast once; arg may arrive as str
    segments = None
    info = None
    if "tiny.en" in model_path:
        # English-only model: refuse files whose name claims another language.
        if file_lang is not None and file_lang != "en":
            print(f"Language found in file name: {file_lang}")
            print("Skipped file to avoid issues with tiny.en model")
        else:
            segments, info = model.transcribe(audio_file, beam_size=beam_size)
            print("Model language fixed to 'en'")
    elif args.lang == "auto":
        if file_lang is not None:
            # Trust the file-name hint over auto-detection.
            segments, info = model.transcribe(audio_file, beam_size=beam_size, language=file_lang)
            print(f"Language found in file name: {file_lang}")
        else:
            segments, info = model.transcribe(audio_file, beam_size=beam_size)
            print("Detected language '%s' with probability %f" % (info.language, info.language_probability))
    else:
        segments, info = model.transcribe(audio_file, beam_size=beam_size, language=args.lang)
        print(f'Pre-defined language: {args.lang}')

    if segments is not None:
        print("Result:")
        # segments is a lazy generator; iterating it drives the decoding.
        for segment in segments:
            print("[%ds -> %ds] %s" % (segment.start, segment.end, segment.text))

    print("\nInference took {:.2f}s for {:.2f}s audio file.".format(
        timer() - inference_start, audio_length))
| 73 | + |
# Transcribe every .wav file in the input folder.
# os.path.join (instead of string concatenation) keeps this correct
# whether or not --folder was given with a trailing slash.
test_files = os.listdir(args.folder)
for file in test_files:
    if file.endswith(".wav"):
        transcribe(os.path.join(args.folder, file))
0 commit comments