77from dataclasses import asdict , dataclass
88from inspect import signature
99from math import ceil
10- from typing import BinaryIO , Iterable , List , Optional , Tuple , Union
10+ from typing import Any , BinaryIO , Iterable , List , Optional , Tuple , Union
1111from warnings import warn
1212
1313import ctranslate2
@@ -81,11 +81,11 @@ class TranscriptionOptions:
8181 compression_ratio_threshold : Optional [float ]
8282 condition_on_previous_text : bool
8383 prompt_reset_on_temperature : float
84- temperatures : List [float ]
84+ temperatures : Union [ List [float ], Tuple [ float , ...] ]
8585 initial_prompt : Optional [Union [str , Iterable [int ]]]
8686 prefix : Optional [str ]
8787 suppress_blank : bool
88- suppress_tokens : Optional [List [int ]]
88+ suppress_tokens : Union [List [int ], Tuple [ int , ... ]]
8989 without_timestamps : bool
9090 max_initial_timestamp : float
9191 word_timestamps : bool
@@ -106,7 +106,7 @@ class TranscriptionInfo:
106106 duration_after_vad : float
107107 all_language_probs : Optional [List [Tuple [str , float ]]]
108108 transcription_options : TranscriptionOptions
109- vad_options : VadOptions
109+ vad_options : Optional [ VadOptions ]
110110
111111
112112class BatchedInferencePipeline :
@@ -121,7 +121,6 @@ def forward(self, features, tokenizer, chunks_metadata, options):
121121 encoder_output , outputs = self .generate_segment_batched (
122122 features , tokenizer , options
123123 )
124-
125124 segmented_outputs = []
126125 segment_sizes = []
127126 for chunk_metadata , output in zip (chunks_metadata , outputs ):
@@ -130,8 +129,8 @@ def forward(self, features, tokenizer, chunks_metadata, options):
130129 segment_sizes .append (segment_size )
131130 (
132131 subsegments ,
133- seek ,
134- single_timestamp_ending ,
132+ _ ,
133+ _ ,
135134 ) = self .model ._split_segments_by_timestamps (
136135 tokenizer = tokenizer ,
137136 tokens = output ["tokens" ],
@@ -295,7 +294,7 @@ def transcribe(
295294 hallucination_silence_threshold : Optional [float ] = None ,
296295 batch_size : int = 8 ,
297296 hotwords : Optional [str ] = None ,
298- language_detection_threshold : Optional [ float ] = 0.5 ,
297+ language_detection_threshold : float = 0.5 ,
299298 language_detection_segments : int = 1 ,
300299 ) -> Tuple [Iterable [Segment ], TranscriptionInfo ]:
301300 """transcribe audio in chunks in batched fashion and return with language info.
@@ -582,7 +581,7 @@ def __init__(
582581 num_workers : int = 1 ,
583582 download_root : Optional [str ] = None ,
584583 local_files_only : bool = False ,
585- files : dict = None ,
584+ files : Optional [ dict ] = None ,
586585 ** model_kwargs ,
587586 ):
588587 """Initializes the Whisper model.
@@ -731,7 +730,7 @@ def transcribe(
731730 clip_timestamps : Union [str , List [float ]] = "0" ,
732731 hallucination_silence_threshold : Optional [float ] = None ,
733732 hotwords : Optional [str ] = None ,
734- language_detection_threshold : Optional [ float ] = 0.5 ,
733+ language_detection_threshold : float = 0.5 ,
735734 language_detection_segments : int = 1 ,
736735 ) -> Tuple [Iterable [Segment ], TranscriptionInfo ]:
737736 """Transcribes an input file.
@@ -833,7 +832,7 @@ def transcribe(
833832 elif isinstance (vad_parameters , dict ):
834833 vad_parameters = VadOptions (** vad_parameters )
835834 speech_chunks = get_speech_timestamps (audio , vad_parameters )
836- audio_chunks , chunks_metadata = collect_chunks (audio , speech_chunks )
835+ audio_chunks , _ = collect_chunks (audio , speech_chunks )
837836 audio = np .concatenate (audio_chunks , axis = 0 )
838837 duration_after_vad = audio .shape [0 ] / sampling_rate
839838
@@ -925,7 +924,7 @@ def transcribe(
925924 condition_on_previous_text = condition_on_previous_text ,
926925 prompt_reset_on_temperature = prompt_reset_on_temperature ,
927926 temperatures = (
928- temperature if isinstance (temperature , (list , tuple )) else [temperature ]
927+ temperature if isinstance (temperature , (List , Tuple )) else [temperature ]
929928 ),
930929 initial_prompt = initial_prompt ,
931930 prefix = prefix ,
@@ -953,7 +952,8 @@ def transcribe(
953952
954953 if speech_chunks :
955954 segments = restore_speech_timestamps (segments , speech_chunks , sampling_rate )
956-
955+ if isinstance (vad_parameters , dict ):
956+ vad_parameters = VadOptions (** vad_parameters )
957957 info = TranscriptionInfo (
958958 language = language ,
959959 language_probability = language_probability ,
@@ -974,7 +974,7 @@ def _split_segments_by_timestamps(
974974 segment_size : int ,
975975 segment_duration : float ,
976976 seek : int ,
977- ) -> List [List [int ] ]:
977+ ) -> Tuple [List [Any ], int , bool ]:
978978 current_segments = []
979979 single_timestamp_ending = (
980980 len (tokens ) >= 2 and tokens [- 2 ] < tokenizer .timestamp_begin <= tokens [- 1 ]
@@ -1517,8 +1517,8 @@ def add_word_timestamps(
15171517 num_frames : int ,
15181518 prepend_punctuations : str ,
15191519 append_punctuations : str ,
1520- last_speech_timestamp : float ,
1521- ) -> float :
1520+ last_speech_timestamp : Union [ float , None ] ,
1521+ ) -> Optional [ float ] :
15221522 if len (segments ) == 0 :
15231523 return
15241524
@@ -1665,9 +1665,11 @@ def find_alignment(
16651665 text_indices = np .array ([pair [0 ] for pair in alignments ])
16661666 time_indices = np .array ([pair [1 ] for pair in alignments ])
16671667
1668- words , word_tokens = tokenizer .split_to_word_tokens (
1669- text_token + [tokenizer .eot ]
1670- )
1668+ if isinstance (text_token , int ):
1669+ tokens = [text_token ] + [tokenizer .eot ]
1670+ else :
1671+ tokens = text_token + [tokenizer .eot ]
1672+ words , word_tokens = tokenizer .split_to_word_tokens (tokens )
16711673 if len (word_tokens ) <= 1 :
16721674 # return on eot only
16731675 # >>> np.pad([], (1, 0))
@@ -1715,7 +1717,7 @@ def detect_language(
17151717 audio : Optional [np .ndarray ] = None ,
17161718 features : Optional [np .ndarray ] = None ,
17171719 vad_filter : bool = False ,
1718- vad_parameters : Union [dict , VadOptions ] = None ,
1720+ vad_parameters : Optional [ Union [dict , VadOptions ] ] = None ,
17191721 language_detection_segments : int = 1 ,
17201722 language_detection_threshold : float = 0.5 ,
17211723 ) -> Tuple [str , float , List [Tuple [str , float ]]]:
@@ -1747,18 +1749,24 @@ def detect_language(
17471749 if audio is not None :
17481750 if vad_filter :
17491751 speech_chunks = get_speech_timestamps (audio , vad_parameters )
1750- audio_chunks , chunks_metadata = collect_chunks (audio , speech_chunks )
1752+ audio_chunks , _ = collect_chunks (audio , speech_chunks )
17511753 audio = np .concatenate (audio_chunks , axis = 0 )
1752-
1754+ assert (
1755+ audio is not None
1756+ ), "Audio have a problem while concatanating the audio_chunks; return None"
17531757 audio = audio [
17541758 : language_detection_segments * self .feature_extractor .n_samples
17551759 ]
17561760 features = self .feature_extractor (audio )
1757-
1761+ assert (
1762+ features is not None
1763+ ), "No features extracted from audio file; return None"
17581764 features = features [
17591765 ..., : language_detection_segments * self .feature_extractor .nb_max_frames
17601766 ]
1761-
1767+ assert (
1768+ features is not None
1769+ ), "No features extracted when detectting language in audio segments; return None"
17621770 detected_language_info = {}
17631771 for i in range (0 , features .shape [- 1 ], self .feature_extractor .nb_max_frames ):
17641772 encoder_output = self .encode (
@@ -1828,13 +1836,13 @@ def get_compression_ratio(text: str) -> float:
18281836
18291837def get_suppressed_tokens (
18301838 tokenizer : Tokenizer ,
1831- suppress_tokens : Tuple [int ],
1832- ) -> Optional [List [int ]]:
1833- if - 1 in suppress_tokens :
1839+ suppress_tokens : Optional [List [int ]],
1840+ ) -> Tuple [int , ...]:
1841+ if suppress_tokens is None or len (suppress_tokens ) == 0 :
1842+ suppress_tokens = [] # interpret empty string as an empty list
1843+ elif - 1 in suppress_tokens :
18341844 suppress_tokens = [t for t in suppress_tokens if t >= 0 ]
18351845 suppress_tokens .extend (tokenizer .non_speech_tokens )
1836- elif suppress_tokens is None or len (suppress_tokens ) == 0 :
1837- suppress_tokens = [] # interpret empty string as an empty list
18381846 else :
18391847 assert isinstance (suppress_tokens , list ), "suppress_tokens must be a list"
18401848
0 commit comments