 import gc
 import math
 import logging
+import tempfile
+import librosa
 import numpy as np
 import multiprocessing as mp
+import soundfile as sf
 from typing import Tuple, List, Optional, Dict, Any, Iterable, Union
+from copy import deepcopy
 from pysrt import SubRipTime, SubRipItem, SubRipFile
 from sklearn.metrics import log_loss
-from copy import deepcopy
 from .network import Network
 from .embedder import FeatureEmbedder
 from .media_helper import MediaHelper
 from .subtitle import Subtitle
 from .hyperparameters import Hyperparameters
-from .exception import TerminalException
-from .exception import NoFrameRateException
+from .lib.language import Language
+from .utils import Utils
+from .exception import TerminalException, NoFrameRateException
 from .logger import Logger


@@ -445,7 +449,7 @@ def _predict_in_multithreads( |
             gc.collect()

             if stretch:
-                subs_new = self.__adjust_durations(subs_new, audio_file_path, stretch_in_lang, lock)
+                subs_new = self.__compress_and_stretch(subs_new, audio_file_path, stretch_in_lang, lock)
                 self.__LOGGER.info("[{}] Segment {} stretched".format(os.getpid(), segment_index))
             return subs_new
         except Exception as e:
@@ -715,6 +719,111 @@ def __adjust_durations(self, subs: List[SubRipItem], audio_file_path: str, stret |
             if task.sync_map_file_path_absolute is not None and os.path.exists(task.sync_map_file_path_absolute):
                 os.remove(task.sync_map_file_path_absolute)

+    def __compress_and_stretch(self, subs: List[SubRipItem], audio_file_path: str, stretch_in_lang: str, lock: threading.RLock) -> List[SubRipItem]:
+        from dtw import dtw
+        text_file_path = None  # initialised up front so the finally block never sees an unbound name
+        try:
+            with lock:
+                segment_path, _ = self.__media_helper.extract_audio_from_start_to_end(
+                    audio_file_path,
+                    str(subs[0].start),
+                    str(subs[-1].end),
+                )
+
+            # Create a text file for DTW alignments
+            root, _ = os.path.splitext(segment_path)
+            text_file_path = "{}.txt".format(root)
+
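+            # The "*****" delimiter joins the cue texts into one script that can be split back per cue after synthesis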
+            with open(text_file_path, "w", encoding="utf8") as text_file:
+                text_file.write("*****".join([sub_new.text for sub_new in subs]))
+
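+            # Reuse the feature embedder's audio parameters so the synthetic and reference MFCCs are directly comparable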
+            sample_rate = self.__feature_embedder.frequency
+            hop_length = self.__feature_embedder.hop_len
+            n_mfcc = self.__feature_embedder.n_mfcc
+
+            file_script_duration_mapping = []
+            with tempfile.TemporaryDirectory() as temp_dir:
+                with open(text_file_path, "r", encoding="utf8") as f:
+                    script_lines = f.read().split("*****")
+                wav_data = []
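+                # Synthesise each subtitle line with espeak and resample the output to the embedder's sample rate via ffmpeg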
+                for i, line in enumerate(script_lines):
+                    normalised_line = line.replace('"', "'")
+                    espeak_output_file = f"espeak_part_{i}.wav"
+                    espeak_cmd = f"espeak -v {Language.LANGUAGE_TO_VOICE_CODE[stretch_in_lang]} --stdout -- \"{normalised_line}\" | ffmpeg -y -i - -af 'aresample={sample_rate}' {os.path.join(temp_dir, espeak_output_file)}"
+                    os.system(espeak_cmd)
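+                    # sr=None keeps the clip's native rate, which ffmpeg has already resampled to sample_rate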
+                    y, sr = librosa.load(os.path.join(temp_dir, espeak_output_file), sr=None)
+                    wav_data.append(y)
+                    duration = librosa.get_duration(y=y, sr=sr)
+                    file_script_duration_mapping.append((os.path.join(temp_dir, espeak_output_file), line, duration))
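+                # Concatenate the per-line clips into a single query waveform covering the whole segment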
+                data = np.concatenate(wav_data)
+                sf.write(os.path.join(temp_dir, "espeak-all.wav"), data, sr)
+
+                y_query, sr_query = librosa.load(os.path.join(temp_dir, "espeak-all.wav"), sr=None)
+                query_mfcc_features = librosa.feature.mfcc(y=y_query, sr=sr_query, n_mfcc=n_mfcc, hop_length=hop_length).T
+                y_reference, sr_reference = librosa.load(segment_path, sr=sample_rate)
+                reference_mfcc_features = librosa.feature.mfcc(y=y_reference, sr=sr_reference, n_mfcc=n_mfcc, hop_length=hop_length).T
+
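+                # DTW aligns the two MFCC sequences and returns a monotonic warping path:
+                # index1 walks the query (synthetic) frames and index2 the reference frames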
+                alignment = dtw(query_mfcc_features, reference_mfcc_features, keep_internals=False)
+                assert len(alignment.index1) == len(alignment.index2), "Mismatch in lengths of alignment indices"
+                assert sr_query == sr_reference
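+                # Each MFCC frame advances hop_length samples, i.e. hop_length / sr seconds of audio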
+                frame_duration = hop_length / sr_query
+
+                mapped_times = []
+                start_frame_index = 0
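+                # For each synthesised line, take its frame span in the query and project it through the warping path onto reference frames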
+                for index, (wav_file, line_text, duration) in enumerate(file_script_duration_mapping):
+                    num_frames_in_query = int(np.ceil(duration / frame_duration))
+
+                    query_start_frame = start_frame_index
+                    query_end_frame = start_frame_index + num_frames_in_query - 1
+                    reference_frame_indices = [
+                        r for q, r in zip(alignment.index1, alignment.index2)
+                        if query_start_frame <= q <= query_end_frame
+                    ]
+                    reference_start_frame = min(reference_frame_indices)
+                    reference_end_frame = max(reference_frame_indices)
+
+                    # TODO: Handle cases where mapped frames are not found in the reference audio
+
+                    new_reference_start_time = reference_start_frame * frame_duration
+                    new_reference_end_time = (reference_end_frame + 1) * frame_duration
+
+                    mapped_times.append({
+                        "new_reference_start_time": new_reference_start_time,
+                        "new_reference_end_time": new_reference_end_time
+                    })
+
+                    start_frame_index = query_end_frame + 1
+
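+                # Persist the mapped times as an intermediate SRT and reload it through the Subtitle helper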
+                with open(os.path.join(temp_dir, "synced_subtitles.srt"), "w", encoding="utf-8") as f:
+                    for index, entry in enumerate(mapped_times):
+                        start_srt = Utils.format_timestamp(entry["new_reference_start_time"])
+                        end_srt = Utils.format_timestamp(entry["new_reference_end_time"])
+                        f.write(f"{index + 1}\n")
+                        f.write(f"{start_srt} --> {end_srt}\n")
+                        f.write(f"{script_lines[index]}\n")
+                        f.write("\n")
+                    f.flush()
+
+                adjusted_subs = Subtitle._get_srt_subs(
+                    subrip_file_path=os.path.join(temp_dir, "synced_subtitles.srt"),
+                    encoding="utf-8"
+                )
+
+                for index, sub_new_loaded in enumerate(adjusted_subs):
+                    sub_new_loaded.index = subs[index].index
+
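+                # Mapped times are relative to the extracted segment, so shift them back by the segment's absolute start offset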
+                adjusted_subs.shift(
+                    seconds=self.__media_helper.get_duration_in_seconds(
+                        start=None, end=str(subs[0].start)
+                    )
+                )
+                return adjusted_subs
+        except KeyboardInterrupt:
+            raise TerminalException("Subtitle compress and stretch interrupted by the user")
+        finally:
+            # Housekeep intermediate files
+            if text_file_path is not None and os.path.exists(text_file_path):
+                os.remove(text_file_path)
+
     def __predict(
         self,
         video_file_path: Optional[str],