Skip to content

Commit 5f9bf53

Browse files
committed
feat: solve some type hints issues
1 parent 8327d8c commit 5f9bf53

File tree

2 files changed

+42
-31
lines changed

2 files changed

+42
-31
lines changed

faster_whisper/transcribe.py

Lines changed: 37 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
from dataclasses import asdict, dataclass
88
from inspect import signature
99
from math import ceil
10-
from typing import BinaryIO, Iterable, List, Optional, Tuple, Union
10+
from typing import Any, BinaryIO, Iterable, List, Optional, Tuple, Union
1111
from warnings import warn
1212

1313
import ctranslate2
@@ -81,11 +81,11 @@ class TranscriptionOptions:
8181
compression_ratio_threshold: Optional[float]
8282
condition_on_previous_text: bool
8383
prompt_reset_on_temperature: float
84-
temperatures: List[float]
84+
temperatures: Union[List[float], Tuple[float, ...]]
8585
initial_prompt: Optional[Union[str, Iterable[int]]]
8686
prefix: Optional[str]
8787
suppress_blank: bool
88-
suppress_tokens: Optional[List[int]]
88+
suppress_tokens: Union[List[int], Tuple[int, ...]]
8989
without_timestamps: bool
9090
max_initial_timestamp: float
9191
word_timestamps: bool
@@ -106,7 +106,7 @@ class TranscriptionInfo:
106106
duration_after_vad: float
107107
all_language_probs: Optional[List[Tuple[str, float]]]
108108
transcription_options: TranscriptionOptions
109-
vad_options: VadOptions
109+
vad_options: Optional[VadOptions]
110110

111111

112112
class BatchedInferencePipeline:
@@ -121,7 +121,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
121121
encoder_output, outputs = self.generate_segment_batched(
122122
features, tokenizer, options
123123
)
124-
125124
segmented_outputs = []
126125
segment_sizes = []
127126
for chunk_metadata, output in zip(chunks_metadata, outputs):
@@ -130,8 +129,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
130129
segment_sizes.append(segment_size)
131130
(
132131
subsegments,
133-
seek,
134-
single_timestamp_ending,
132+
_,
133+
_,
135134
) = self.model._split_segments_by_timestamps(
136135
tokenizer=tokenizer,
137136
tokens=output["tokens"],
@@ -295,7 +294,7 @@ def transcribe(
295294
hallucination_silence_threshold: Optional[float] = None,
296295
batch_size: int = 8,
297296
hotwords: Optional[str] = None,
298-
language_detection_threshold: Optional[float] = 0.5,
297+
language_detection_threshold: float = 0.5,
299298
language_detection_segments: int = 1,
300299
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
301300
"""transcribe audio in chunks in batched fashion and return with language info.
@@ -582,7 +581,7 @@ def __init__(
582581
num_workers: int = 1,
583582
download_root: Optional[str] = None,
584583
local_files_only: bool = False,
585-
files: dict = None,
584+
files: Optional[dict] = None,
586585
**model_kwargs,
587586
):
588587
"""Initializes the Whisper model.
@@ -731,7 +730,7 @@ def transcribe(
731730
clip_timestamps: Union[str, List[float]] = "0",
732731
hallucination_silence_threshold: Optional[float] = None,
733732
hotwords: Optional[str] = None,
734-
language_detection_threshold: Optional[float] = 0.5,
733+
language_detection_threshold: float = 0.5,
735734
language_detection_segments: int = 1,
736735
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
737736
"""Transcribes an input file.
@@ -833,7 +832,7 @@ def transcribe(
833832
elif isinstance(vad_parameters, dict):
834833
vad_parameters = VadOptions(**vad_parameters)
835834
speech_chunks = get_speech_timestamps(audio, vad_parameters)
836-
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
835+
audio_chunks, _ = collect_chunks(audio, speech_chunks)
837836
audio = np.concatenate(audio_chunks, axis=0)
838837
duration_after_vad = audio.shape[0] / sampling_rate
839838

@@ -925,7 +924,7 @@ def transcribe(
925924
condition_on_previous_text=condition_on_previous_text,
926925
prompt_reset_on_temperature=prompt_reset_on_temperature,
927926
temperatures=(
928-
temperature if isinstance(temperature, (list, tuple)) else [temperature]
927+
temperature if isinstance(temperature, (List, Tuple)) else [temperature]
929928
),
930929
initial_prompt=initial_prompt,
931930
prefix=prefix,
@@ -953,7 +952,8 @@ def transcribe(
953952

954953
if speech_chunks:
955954
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
956-
955+
if isinstance(vad_parameters, dict):
956+
vad_parameters = VadOptions(**vad_parameters)
957957
info = TranscriptionInfo(
958958
language=language,
959959
language_probability=language_probability,
@@ -974,7 +974,7 @@ def _split_segments_by_timestamps(
974974
segment_size: int,
975975
segment_duration: float,
976976
seek: int,
977-
) -> List[List[int]]:
977+
) -> Tuple[List[Any], int, bool]:
978978
current_segments = []
979979
single_timestamp_ending = (
980980
len(tokens) >= 2 and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
@@ -1517,8 +1517,8 @@ def add_word_timestamps(
15171517
num_frames: int,
15181518
prepend_punctuations: str,
15191519
append_punctuations: str,
1520-
last_speech_timestamp: float,
1521-
) -> float:
1520+
last_speech_timestamp: Union[float, None],
1521+
) -> Optional[float]:
15221522
if len(segments) == 0:
15231523
return
15241524

@@ -1665,9 +1665,11 @@ def find_alignment(
16651665
text_indices = np.array([pair[0] for pair in alignments])
16661666
time_indices = np.array([pair[1] for pair in alignments])
16671667

1668-
words, word_tokens = tokenizer.split_to_word_tokens(
1669-
text_token + [tokenizer.eot]
1670-
)
1668+
if isinstance(text_token, int):
1669+
tokens = [text_token] + [tokenizer.eot]
1670+
else:
1671+
tokens = text_token + [tokenizer.eot]
1672+
words, word_tokens = tokenizer.split_to_word_tokens(tokens)
16711673
if len(word_tokens) <= 1:
16721674
# return on eot only
16731675
# >>> np.pad([], (1, 0))
@@ -1715,7 +1717,7 @@ def detect_language(
17151717
audio: Optional[np.ndarray] = None,
17161718
features: Optional[np.ndarray] = None,
17171719
vad_filter: bool = False,
1718-
vad_parameters: Union[dict, VadOptions] = None,
1720+
vad_parameters: Optional[Union[dict, VadOptions]] = None,
17191721
language_detection_segments: int = 1,
17201722
language_detection_threshold: float = 0.5,
17211723
) -> Tuple[str, float, List[Tuple[str, float]]]:
@@ -1747,18 +1749,24 @@ def detect_language(
17471749
if audio is not None:
17481750
if vad_filter:
17491751
speech_chunks = get_speech_timestamps(audio, vad_parameters)
1750-
audio_chunks, chunks_metadata = collect_chunks(audio, speech_chunks)
1752+
audio_chunks, _ = collect_chunks(audio, speech_chunks)
17511753
audio = np.concatenate(audio_chunks, axis=0)
1752-
1754+
assert (
1755+
audio is not None
1756+
), "Audio have a problem while concatanating the audio_chunks; return None"
17531757
audio = audio[
17541758
: language_detection_segments * self.feature_extractor.n_samples
17551759
]
17561760
features = self.feature_extractor(audio)
1757-
1761+
assert (
1762+
features is not None
1763+
), "No features extracted from audio file; return None"
17581764
features = features[
17591765
..., : language_detection_segments * self.feature_extractor.nb_max_frames
17601766
]
1761-
1767+
assert (
1768+
features is not None
1769+
), "No features extracted when detectting language in audio segments; return None"
17621770
detected_language_info = {}
17631771
for i in range(0, features.shape[-1], self.feature_extractor.nb_max_frames):
17641772
encoder_output = self.encode(
@@ -1828,13 +1836,13 @@ def get_compression_ratio(text: str) -> float:
18281836

18291837
def get_suppressed_tokens(
18301838
tokenizer: Tokenizer,
1831-
suppress_tokens: Tuple[int],
1832-
) -> Optional[List[int]]:
1833-
if -1 in suppress_tokens:
1839+
suppress_tokens: Optional[List[int]],
1840+
) -> Tuple[int, ...]:
1841+
if suppress_tokens is None or len(suppress_tokens) == 0:
1842+
suppress_tokens = [] # interpret empty string as an empty list
1843+
elif -1 in suppress_tokens:
18341844
suppress_tokens = [t for t in suppress_tokens if t >= 0]
18351845
suppress_tokens.extend(tokenizer.non_speech_tokens)
1836-
elif suppress_tokens is None or len(suppress_tokens) == 0:
1837-
suppress_tokens = [] # interpret empty string as an empty list
18381846
else:
18391847
assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"
18401848

faster_whisper/vad.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
import os
44

55
from dataclasses import dataclass
6-
from typing import Dict, List, Optional, Tuple
6+
from typing import Dict, List, Optional, Tuple, Union
77

88
import numpy as np
99

@@ -44,7 +44,7 @@ class VadOptions:
4444

4545
def get_speech_timestamps(
4646
audio: np.ndarray,
47-
vad_options: Optional[VadOptions] = None,
47+
vad_options: Optional[Union[dict, VadOptions]] = None,
4848
sampling_rate: int = 16000,
4949
**kwargs,
5050
) -> List[dict]:
@@ -62,6 +62,9 @@ def get_speech_timestamps(
6262
if vad_options is None:
6363
vad_options = VadOptions(**kwargs)
6464

65+
if isinstance(vad_options, dict):
66+
vad_options = VadOptions(**vad_options)
67+
6568
threshold = vad_options.threshold
6669
min_speech_duration_ms = vad_options.min_speech_duration_ms
6770
max_speech_duration_s = vad_options.max_speech_duration_s

0 commit comments

Comments
 (0)