diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 076f0fc..571558c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -37,7 +37,7 @@ jobs:
         run: sudo apt-get update && sudo apt-get install -y git ffmpeg
 
       - name: Install dependencies
-        run: pip install -r requirements.txt pytest
+        run: pip install -r requirements.txt pytest jiwer
 
       - name: Run test
         run: python -m pytest -rs tests
\ No newline at end of file
diff --git a/README.md b/README.md
index df50889..8d85c13 100644
--- a/README.md
+++ b/README.md
@@ -128,3 +128,6 @@ This is Whisper's original VRAM usage table for models.
 - [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
 - [ ] Add fast api script
 - [ ] Support real-time transcription for microphone
+
+### Translation 🌐
+Any PRs translating Japanese, Spanish, French, German, Chinese, or any other language into [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!
diff --git a/app.py b/app.py
index ea094c5..175522b 100644
--- a/app.py
+++ b/app.py
@@ -7,17 +7,14 @@ import yaml
 from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
                                  INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                  UVR_MODELS_DIR, I18N_YAML_PATH)
-from modules.utils.constants import AUTOMATIC_DETECTION
 from modules.utils.files_manager import load_yaml
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.faster_whisper_inference import FasterWhisperInference
-from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
 from modules.ui.htmls import *
 from modules.utils.cli_manager import str2bool
 from modules.utils.youtube_manager import get_ytmetas
 from modules.translation.deepl_api import DeepLAPI
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
 
 
 class App:
@@ -44,7 +41,7 @@ class App:
         print(f"Use \"{self.args.whisper_type}\" implementation\n"
               f"Device \"{self.whisper_inf.device}\" is detected")
 
-    def create_whisper_parameters(self):
+    def create_pipeline_inputs(self):
         whisper_params = self.default_params["whisper"]
         vad_params = self.default_params["vad"]
         diarization_params = self.default_params["diarization"]
@@ -56,7 +53,7 @@ class App:
             dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
                                   value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
                                   else whisper_params["lang"], label=_("Language"))
-            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
+            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
         with gr.Row():
             cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
                                        interactive=True)
@@ -66,158 +63,31 @@ class App:
                                        interactive=True)
 
         with gr.Accordion(_("Advanced Parameters"), open=False):
-            nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0,
-                                     interactive=True,
-                                     info="Beam size to use for decoding.")
-            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold",
-                                              value=whisper_params["log_prob_threshold"], interactive=True,
-                                              info="If the average log probability over sampled tokens is below this value, treat as 
failed.") - nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"], - interactive=True, - info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.") - dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, - value=self.whisper_inf.current_compute_type, interactive=True, - allow_custom_value=True, - info="Select the type of computation to perform.") - nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True, - info="Number of candidates when sampling with non-zero temperature.") - nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True, - info="Beam search patience factor.") - cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", - value=whisper_params["condition_on_previous_text"], - interactive=True, - info="Condition on previous text during decoding.") - sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", - value=whisper_params["prompt_reset_on_temperature"], - minimum=0, maximum=1, step=0.01, interactive=True, - info="Resets prompt if temperature is above this value." - " Arg has effect only if 'Condition On Previous Text' is True.") - tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True, - info="Initial prompt to use for decoding.") - sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0, - step=0.01, maximum=1.0, interactive=True, - info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.") - nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", - value=whisper_params["compression_ratio_threshold"], - interactive=True, - info="If the gzip compression ratio is above this value, treat as failed.") - nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"], - precision=0, - info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.") - with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)): - nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"], - info="Exponential length penalty constant.") - nb_repetition_penalty = gr.Number(label="Repetition Penalty", - value=whisper_params["repetition_penalty"], - info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).") - nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", - value=whisper_params["no_repeat_ngram_size"], - precision=0, - info="Prevent repetitions of n-grams with this size (set 0 to disable).") - tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"], - info="Optional text to provide as a prefix for the first window.") - cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"], - info="Suppress blank outputs at the beginning of the sampling.") - tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"], - info="List of token IDs to suppress. 
-1 will suppress a default set of symbols as defined in the model config.json file.") - nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", - value=whisper_params["max_initial_timestamp"], - info="The initial timestamp cannot be later than this.") - cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"], - info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.") - tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", - value=whisper_params["prepend_punctuations"], - info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.") - tb_append_punctuations = gr.Textbox(label="Append Punctuations", - value=whisper_params["append_punctuations"], - info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.") - nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"], - precision=0, - info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.") - nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)", - value=lambda: whisper_params[ - "hallucination_silence_threshold"], - info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.") - tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"], - info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.") - nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", - value=lambda: whisper_params[ - "language_detection_threshold"], - info="If the maximum probability of the language tokens is higher than this value, the language is detected.") - nb_language_detection_segments = gr.Number(label="Language Detection Segments", - value=lambda: whisper_params["language_detection_segments"], - precision=0, - info="Number of segments to consider for the language detection.") - with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)): - nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0) + whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True, + whisper_type=self.args.whisper_type, + available_compute_types=self.whisper_inf.available_compute_types, + compute_type=self.whisper_inf.current_compute_type) with gr.Accordion(_("Background Music Remover Filter"), open=False): - cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"), - value=uvr_params["is_separate_bgm"], - interactive=True, - info=_("Enabling this will remove background music")) - dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device, - choices=self.whisper_inf.music_separator.available_devices) - dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"], - choices=self.whisper_inf.music_separator.available_models) - nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0) - cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"]) - cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"), - value=uvr_params["enable_offload"]) + uvr_inputs = 
BGMSeparationParams.to_gradio_input(defaults=uvr_params, + available_models=self.whisper_inf.music_separator.available_models, + available_devices=self.whisper_inf.music_separator.available_devices, + device=self.whisper_inf.music_separator.device) with gr.Accordion(_("Voice Detection Filter"), open=False): - cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"], - interactive=True, - info=_("Enable this to transcribe only detected voice")) - sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", - value=vad_params["threshold"], - info="Lower it to be more sensitive to small sounds.") - nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, - value=vad_params["min_speech_duration_ms"], - info="Final speech chunks shorter than this time are thrown out") - nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", - value=vad_params["max_speech_duration_s"], - info="Maximum duration of speech chunks in \"seconds\".") - nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, - value=vad_params["min_silence_duration_ms"], - info="In the end of each speech chunk wait for this time" - " before separating it") - nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"], - info="Final speech chunks are padded by this time each side") + vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params) with gr.Accordion(_("Diarization"), open=False): - cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"]) - tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"], - info=_("This is only needed the first time you download the model")) - dd_diarization_device = gr.Dropdown(label=_("Device"), - choices=self.whisper_inf.diarizer.get_available_device(), - value=self.whisper_inf.diarizer.get_device()) + diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params, + available_devices=self.whisper_inf.diarizer.available_device, + device=self.whisper_inf.diarizer.device) dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate]) + pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs + return ( - WhisperParameters( - model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size, - log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold, - compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience, - condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt, - temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold, - vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms, - max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms, - speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size, - is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device, - length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty, - no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank, - suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp, - word_timestamps=cb_word_timestamps, 
prepend_punctuations=tb_prepend_punctuations, - append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, - hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords, - language_detection_threshold=nb_language_detection_threshold, - language_detection_segments=nb_language_detection_segments, - prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation, - uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size, - uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload - ), + pipeline_inputs, dd_file_format, cb_timestamp ) @@ -243,7 +113,7 @@ class App: visible=self.args.colab, value="") - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -254,7 +124,7 @@ class App: params = [input_file, tb_input_folder, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_file, - inputs=params + whisper_params.as_list(), + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) @@ -268,7 +138,7 @@ class App: tb_title = gr.Label(label=_("Youtube Title")) tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15) - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -280,7 +150,7 @@ class App: params = [tb_youtubelink, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_youtube, - inputs=params + whisper_params.as_list(), + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink], outputs=[img_thumbnail, tb_title, tb_description]) @@ -290,7 +160,7 @@ class App: with gr.Row(): mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True) - whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters() + pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs() with gr.Row(): btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary") @@ -302,7 +172,7 @@ class App: params = [mic_input, dd_file_format, cb_timestamp] btn_run.click(fn=self.whisper_inf.transcribe_mic, - inputs=params + whisper_params.as_list(), + inputs=params + pipeline_params, outputs=[tb_indicator, files_subtitles]) btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None) @@ -417,7 +287,6 @@ class App: # Launch the app with optional gradio settings args = self.args - self.app.queue( api_open=args.api_open ).launch( @@ -447,8 +316,8 @@ class App: parser = argparse.ArgumentParser() -parser.add_argument('--whisper_type', type=str, default="faster-whisper", - choices=["whisper", "faster-whisper", "insanely-fast-whisper"], +parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value, + choices=[item.value for item in WhisperImpl], help='A type of the whisper implementation (Github repo name)') parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value') parser.add_argument('--server_name', type=str, 
default=None, help='Gradio server host')
diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml
index 8339eda..483d17a 100644
--- a/configs/default_parameters.yaml
+++ b/configs/default_parameters.yaml
@@ -1,5 +1,6 @@
 whisper:
   model_size: "medium.en"
+  file_format: "SRT"
   lang: "english"
   is_translate: false
   beam_size: 5
diff --git a/configs/translation.yaml b/configs/translation.yaml
index 1386867..68e35e3 100644
--- a/configs/translation.yaml
+++ b/configs/translation.yaml
@@ -411,3 +411,49 @@ ru: # Russian
   Instrumental: Инструментал
   Vocals: Вокал
   SEPARATE BACKGROUND MUSIC: РАЗДЕЛИТЬ ФОНОВУЮ МУЗЫКУ
+
+tr: # Turkish
+  Language: Dil
+  File: Dosya
+  Youtube: Youtube
+  Mic: Mikrofon
+  T2T Translation: T2T Çeviri
+  BGM Separation: Arka Plan Müziği Ayırma
+  GENERATE SUBTITLE FILE: ALTYAZI DOSYASI OLUŞTUR
+  Output: Çıktı
+  Downloadable output file: İndirilebilir çıktı dosyası
+  Upload File here: Dosya Yükle
+  Model: Model
+  Automatic Detection: Otomatik Algılama
+  File Format: Dosya Formatı
+  Translate to English?: İngilizceye Çevir?
+  Add a timestamp to the end of the filename: Dosya adının sonuna zaman damgası ekle
+  Advanced Parameters: Gelişmiş Parametreler
+  Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresi
+  Enabling this will remove background music: Bunu etkinleştirmek, arka plan müziğini alt model tarafından transkripsiyondan önce kaldıracaktır
+  Enable Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresini Etkinleştir
+  Save separated files to output: Ayrılmış dosyaları çıktıya kaydet
+  Offload sub model after removing background music: Arka plan müziği kaldırıldıktan sonra alt modeli devre dışı bırak
+  Voice Detection Filter: Ses Algılama Filtresi
+  Enable this to transcribe only detected voice: Bunu etkinleştirerek yalnızca alt model tarafından algılanan ses kısımlarını transkribe et
+  Enable Silero VAD Filter: Silero VAD Filtresini Etkinleştir
+  Diarization: Konuşmacı Ayrımı
+  Enable Diarization: Konuşmacı Ayrımını Etkinleştir
+  HuggingFace Token: HuggingFace Anahtarı
+  This is only needed the first time you download the model: Bu, modeli ilk kez indirirken gereklidir. Zaten modelleriniz varsa girmenize gerek yok. Modeli indirmek için "https://huggingface.co/pyannote/speaker-diarization-3.1" ve "https://huggingface.co/pyannote/segmentation-3.0" adreslerine gidip gereksinimlerini kabul etmeniz gerekiyor
+  Device: Cihaz
+  Youtube Link: Youtube Bağlantısı
+  Youtube Thumbnail: Youtube Küçük Resmi
+  Youtube Title: Youtube Başlığı
+  Youtube Description: Youtube Açıklaması
+  Record with Mic: Mikrofonla Kaydet
+  Upload Subtitle Files to translate here: Çeviri için altyazı dosyalarını buraya yükle
+  Your Auth Key (API KEY): Yetki Anahtarınız (API ANAHTARI)
+  Source Language: Kaynak Dil
+  Target Language: Hedef Dil
+  Pro User?: Pro Kullanıcı?
+  TRANSLATE SUBTITLE FILE: ALTYAZI DOSYASINI ÇEVİR
+  Upload Audio Files to separate background music: Arka plan müziğini ayırmak için ses dosyalarını yükle
+  Instrumental: Enstrümantal
+  Vocals: Vokal
+  SEPARATE BACKGROUND MUSIC: ARKA PLAN MÜZİĞİNİ AYIR
diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py
index b4109e8..4313f5c 100644
--- a/modules/diarize/diarize_pipeline.py
+++ b/modules/diarize/diarize_pipeline.py
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch
 
+from modules.whisper.data_classes import *
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
@@ -43,6 +44,8 @@ class DiarizationPipeline:
 
 def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
     transcript_segments = transcript_result["segments"]
+    if transcript_segments and isinstance(transcript_segments[0], Segment):
+        transcript_segments = [seg.model_dump() for seg in transcript_segments]
     for seg in transcript_segments:
         # assign speaker to segment (if any)
         diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
@@ -63,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
             seg["speaker"] = speaker
 
         # assign speaker to words
-        if 'words' in seg:
+        if 'words' in seg and seg['words'] is not None:
             for word in seg['words']:
                 if 'start' in word:
                     diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
@@ -85,10 +88,10 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
                     if word_speaker is not None:
                         word["speaker"] = word_speaker
 
-    return transcript_result
+    return {"segments": transcript_segments}
 
 
-class Segment:
+class DiarizationSegment:
     def __init__(self, start, end, speaker=None):
         self.start = start
         self.end = end
diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py
index 2dd3f94..38e150a 100644
--- a/modules/diarize/diarizer.py
+++ b/modules/diarize/diarizer.py
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import List, Union, BinaryIO, Optional
+from typing import List, Union, BinaryIO, Optional, Tuple
 import numpy as np
 import time
 import logging
@@ -8,6 +8,7 @@ import logging
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *
 
 
 class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
-            transcribed_result: List[dict],
+            transcribed_result: List[Segment],
             use_auth_token: str,
             device: Optional[str] = None
-            ):
+            ) -> Tuple[List[Segment], float]:
         """
         Diarize transcribed result as a post-processing
 
         Parameters
         ----------
         audio: Union[str, BinaryIO, np.ndarray]
             Audio input. This can be file path or binary type.
-        transcribed_result: List[dict]
+        transcribed_result: List[Segment]
             transcribed result through whisper.
         use_auth_token: str
             Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer: Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for running """ @@ -68,14 +69,20 @@ class Diarizer: {"segments": transcribed_result} ) + segments_result = [] for segment in diarized_result["segments"]: speaker = "None" if "speaker" in segment: speaker = segment["speaker"] - segment["text"] = speaker + "|" + segment["text"].strip() + diarized_text = speaker + "|" + segment["text"].strip() + segments_result.append(Segment( + start=segment["start"], + end=segment["end"], + text=diarized_text + )) elapsed_time = time.time() - start_time - return diarized_result["segments"], elapsed_time + return segments_result, elapsed_time def update_pipe(self, use_auth_token: str, diff --git a/modules/translation/deepl_api.py b/modules/translation/deepl_api.py index 35f245b..0814fb7 100644 --- a/modules/translation/deepl_api.py +++ b/modules/translation/deepl_api.py @@ -139,37 +139,27 @@ class DeepLAPI: ) files_info = {} - for fileobj in fileobjs: - file_path = fileobj - file_name, file_ext = os.path.splitext(os.path.basename(fileobj)) - - if file_ext == ".srt": - parsed_dicts = parse_srt(file_path=file_path) - - elif file_ext == ".vtt": - parsed_dicts = parse_vtt(file_path=file_path) + for file_path in fileobjs: + file_name, file_ext = os.path.splitext(os.path.basename(file_path)) + writer = get_writer(file_ext, self.output_dir) + segments = writer.to_segments(file_path) batch_size = self.max_text_batch_size - for batch_start in range(0, len(parsed_dicts), batch_size): - batch_end = min(batch_start + batch_size, len(parsed_dicts)) - sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]] + for batch_start in range(0, len(segments), batch_size): + progress(batch_start / len(segments), desc="Translating..") + sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]] translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang, target_lang, is_pro) for i, translated_text in enumerate(translated_texts): - parsed_dicts[batch_start + i]["sentence"] = translated_text["text"] - progress(batch_end / len(parsed_dicts), desc="Translating..") + segments[batch_start + i].text = translated_text["text"] - if file_ext == ".srt": - subtitle = get_serialized_srt(parsed_dicts) - elif file_ext == ".vtt": - subtitle = get_serialized_vtt(parsed_dicts) - - if add_timestamp: - timestamp = datetime.now().strftime("%m%d%H%M%S") - file_name += f"-{timestamp}" - - output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}") - write_file(subtitle, output_path) + subtitle, output_path = generate_file( + output_dir=self.output_dir, + output_file_name=file_name, + output_format=file_ext, + result=segments, + add_timestamp=add_timestamp + ) files_info[file_name] = {"subtitle": subtitle, "path": output_path} diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py index abc7f44..6087767 100644 --- a/modules/translation/translation_base.py +++ b/modules/translation/translation_base.py @@ -6,7 +6,7 @@ from typing import List from datetime import datetime import modules.translation.nllb_inference as nllb -from modules.whisper.whisper_parameter import * +from modules.whisper.data_classes import * from modules.utils.subtitle_manager import * from 
modules.utils.files_manager import load_yaml, save_yaml from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR @@ -95,32 +95,22 @@ class TranslationBase(ABC): files_info = {} for fileobj in fileobjs: file_name, file_ext = os.path.splitext(os.path.basename(fileobj)) - if file_ext == ".srt": - parsed_dicts = parse_srt(file_path=fileobj) - total_progress = len(parsed_dicts) - for index, dic in enumerate(parsed_dicts): - progress(index / total_progress, desc="Translating..") - translated_text = self.translate(dic["sentence"], max_length=max_length) - dic["sentence"] = translated_text - subtitle = get_serialized_srt(parsed_dicts) + writer = get_writer(file_ext, self.output_dir) + segments = writer.to_segments(fileobj) + for i, segment in enumerate(segments): + progress(i / len(segments), desc="Translating..") + translated_text = self.translate(segment.text, max_length=max_length) + segment.text = translated_text - elif file_ext == ".vtt": - parsed_dicts = parse_vtt(file_path=fileobj) - total_progress = len(parsed_dicts) - for index, dic in enumerate(parsed_dicts): - progress(index / total_progress, desc="Translating..") - translated_text = self.translate(dic["sentence"], max_length=max_length) - dic["sentence"] = translated_text - subtitle = get_serialized_vtt(parsed_dicts) + subtitle, file_path = generate_file( + output_dir=self.output_dir, + output_file_name=file_name, + output_format=file_ext, + result=segments, + add_timestamp=add_timestamp + ) - if add_timestamp: - timestamp = datetime.now().strftime("%m%d%H%M%S") - file_name += f"-{timestamp}" - - output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}") - write_file(subtitle, output_path) - - files_info[file_name] = {"subtitle": subtitle, "path": output_path} + files_info[file_name] = {"subtitle": subtitle, "path": file_path} total_result = '' for file_name, info in files_info.items(): @@ -133,7 +123,8 @@ class TranslationBase(ABC): return [gr_str, output_file_paths] except Exception as e: - print(f"Error: {str(e)}") + print(f"Error translating file: {e}") + raise finally: self.release_cuda_memory() diff --git a/modules/utils/constants.py b/modules/utils/constants.py index e9309bc..49b45c8 100644 --- a/modules/utils/constants.py +++ b/modules/utils/constants.py @@ -1,3 +1,6 @@ from gradio_i18n import Translate, gettext as _ AUTOMATIC_DETECTION = _("Automatic Detection") +GRADIO_NONE_STR = "" +GRADIO_NONE_NUMBER_MAX = 9999 +GRADIO_NONE_NUMBER_MIN = 0 diff --git a/modules/utils/files_manager.py b/modules/utils/files_manager.py index 4ac0c63..29b5242 100644 --- a/modules/utils/files_manager.py +++ b/modules/utils/files_manager.py @@ -67,3 +67,9 @@ def is_video(file_path): video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp'] extension = os.path.splitext(file_path)[1].lower() return extension in video_extensions + + +def read_file(file_path): + with open(file_path, "r", encoding="utf-8") as f: + subtitle_content = f.read() + return subtitle_content diff --git a/modules/utils/subtitle_manager.py b/modules/utils/subtitle_manager.py index 4b48425..1a6ad12 100644 --- a/modules/utils/subtitle_manager.py +++ b/modules/utils/subtitle_manager.py @@ -1,121 +1,425 @@ +# Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py + +import json +import os import re +import sys +import zlib +from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple +from datetime import datetime + +from 
modules.whisper.data_classes import Segment, Word +from .files_manager import read_file -def timeformat_srt(time): - hours = time // 3600 - minutes = (time - hours * 3600) // 60 - seconds = time - hours * 3600 - minutes * 60 - milliseconds = (time - int(time)) * 1000 - return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}" +def format_timestamp( + seconds: float, always_include_hours: bool = True, decimal_marker: str = "," +) -> str: + assert seconds >= 0, "non-negative timestamp expected" + milliseconds = round(seconds * 1000.0) + + hours = milliseconds // 3_600_000 + milliseconds -= hours * 3_600_000 + + minutes = milliseconds // 60_000 + milliseconds -= minutes * 60_000 + + seconds = milliseconds // 1_000 + milliseconds -= seconds * 1_000 + + hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else "" + return ( + f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}" + ) -def timeformat_vtt(time): - hours = time // 3600 - minutes = (time - hours * 3600) // 60 - seconds = time - hours * 3600 - minutes * 60 - milliseconds = (time - int(time)) * 1000 - return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}" +def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float: + times = time_str.split(":") + + if len(times) == 3: + hours, minutes, rest = times + hours = int(hours) + else: + hours = 0 + minutes, rest = times + + seconds, fractional = rest.split(decimal_marker) + + minutes = int(minutes) + seconds = int(seconds) + fractional_seconds = float("0." + fractional) + + return hours * 3600 + minutes * 60 + seconds + fractional_seconds -def write_file(subtitle, output_file): - with open(output_file, 'w', encoding='utf-8') as f: - f.write(subtitle) +def get_start(segments: List[dict]) -> Optional[float]: + return next( + (w["start"] for s in segments for w in s["words"]), + segments[0]["start"] if segments else None, + ) -def get_srt(segments): - output = "" - for i, segment in enumerate(segments): - output += f"{i + 1}\n" - output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n" - if segment['text'].startswith(' '): - segment['text'] = segment['text'][1:] - output += f"{segment['text']}\n\n" - return output +def get_end(segments: List[dict]) -> Optional[float]: + return next( + (w["end"] for s in reversed(segments) for w in reversed(s["words"])), + segments[-1]["end"] if segments else None, + ) -def get_vtt(segments): - output = "WebVTT\n\n" - for i, segment in enumerate(segments): - output += f"{i + 1}\n" - output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n" - if segment['text'].startswith(' '): - segment['text'] = segment['text'][1:] - output += f"{segment['text']}\n\n" - return output +class ResultWriter: + extension: str + + def __init__(self, output_dir: str): + self.output_dir = output_dir + + def __call__( + self, result: Union[dict, List[Segment]], output_file_name: str, + options: Optional[dict] = None, **kwargs + ): + if isinstance(result, List) and result and isinstance(result[0], Segment): + result = {"segments": [seg.model_dump() for seg in result]} + + output_path = os.path.join( + self.output_dir, output_file_name + "." 
+ self.extension + ) + + with open(output_path, "w", encoding="utf-8") as f: + self.write_result(result, file=f, options=options, **kwargs) + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + raise NotImplementedError -def get_txt(segments): - output = "" - for i, segment in enumerate(segments): - if segment['text'].startswith(' '): - segment['text'] = segment['text'][1:] - output += f"{segment['text']}\n" - return output +class WriteTXT(ResultWriter): + extension: str = "txt" + + def write_result( + self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs + ): + for segment in result["segments"]: + print(segment["text"].strip(), file=file, flush=True) -def parse_srt(file_path): - """Reads SRT file and returns as dict""" - with open(file_path, 'r', encoding='utf-8') as file: - srt_data = file.read() +class SubtitlesWriter(ResultWriter): + always_include_hours: bool + decimal_marker: str - data = [] - blocks = srt_data.split('\n\n') + def iterate_result( + self, + result: dict, + options: Optional[dict] = None, + *, + max_line_width: Optional[int] = None, + max_line_count: Optional[int] = None, + highlight_words: bool = False, + align_lrc_words: bool = False, + max_words_per_line: Optional[int] = None, + ): + options = options or {} + max_line_width = max_line_width or options.get("max_line_width") + max_line_count = max_line_count or options.get("max_line_count") + highlight_words = highlight_words or options.get("highlight_words", False) + align_lrc_words = align_lrc_words or options.get("align_lrc_words", False) + max_words_per_line = max_words_per_line or options.get("max_words_per_line") + preserve_segments = max_line_count is None or max_line_width is None + max_line_width = max_line_width or 1000 + max_words_per_line = max_words_per_line or 1000 - for block in blocks: - if block.strip() != '': - lines = block.strip().split('\n') - index = lines[0] - timestamp = lines[1] - sentence = ' '.join(lines[2:]) + def iterate_subtitles(): + line_len = 0 + line_count = 1 + # the next subtitle to yield (a list of word timings with whitespace) + subtitle: List[dict] = [] + last: float = get_start(result["segments"]) or 0.0 + for segment in result["segments"]: + chunk_index = 0 + words_count = max_words_per_line + while chunk_index < len(segment["words"]): + remaining_words = len(segment["words"]) - chunk_index + if max_words_per_line > len(segment["words"]) - chunk_index: + words_count = remaining_words + for i, original_timing in enumerate( + segment["words"][chunk_index : chunk_index + words_count] + ): + timing = original_timing.copy() + long_pause = ( + not preserve_segments and timing["start"] - last > 3.0 + ) + has_room = line_len + len(timing["word"]) <= max_line_width + seg_break = i == 0 and len(subtitle) > 0 and preserve_segments + if ( + line_len > 0 + and has_room + and not long_pause + and not seg_break + ): + # line continuation + line_len += len(timing["word"]) + else: + # new line + timing["word"] = timing["word"].strip() + if ( + len(subtitle) > 0 + and max_line_count is not None + and (long_pause or line_count >= max_line_count) + or seg_break + ): + # subtitle break + yield subtitle + subtitle = [] + line_count = 1 + elif line_len > 0: + # line break + line_count += 1 + timing["word"] = "\n" + timing["word"] + line_len = len(timing["word"].strip()) + subtitle.append(timing) + last = timing["start"] + chunk_index += max_words_per_line + if len(subtitle) > 0: + yield subtitle - 
data.append({ - "index": index, - "timestamp": timestamp, - "sentence": sentence - }) - return data + if len(result["segments"]) > 0 and "words" in result["segments"][0] and result["segments"][0]["words"]: + for subtitle in iterate_subtitles(): + subtitle_start = self.format_timestamp(subtitle[0]["start"]) + subtitle_end = self.format_timestamp(subtitle[-1]["end"]) + subtitle_text = "".join([word["word"] for word in subtitle]) + if highlight_words: + last = subtitle_start + all_words = [timing["word"] for timing in subtitle] + for i, this_word in enumerate(subtitle): + start = self.format_timestamp(this_word["start"]) + end = self.format_timestamp(this_word["end"]) + if last != start: + yield last, start, subtitle_text + + yield start, end, "".join( + [ + re.sub(r"^(\s*)(.*)$", r"\1\2", word) + if j == i + else word + for j, word in enumerate(all_words) + ] + ) + last = end + + if align_lrc_words: + lrc_aligned_words = [f"[{self.format_timestamp(sub['start'])}]{sub['word']}" for sub in subtitle] + l_start, l_end = self.format_timestamp(subtitle[-1]['start']), self.format_timestamp(subtitle[-1]['end']) + lrc_aligned_words[-1] = f"[{l_start}]{subtitle[-1]['word']}[{l_end}]" + lrc_aligned_words = ' '.join(lrc_aligned_words) + yield None, None, lrc_aligned_words + + else: + yield subtitle_start, subtitle_end, subtitle_text + else: + for segment in result["segments"]: + segment_start = self.format_timestamp(segment["start"]) + segment_end = self.format_timestamp(segment["end"]) + segment_text = segment["text"].strip().replace("-->", "->") + yield segment_start, segment_end, segment_text + + def format_timestamp(self, seconds: float): + return format_timestamp( + seconds=seconds, + always_include_hours=self.always_include_hours, + decimal_marker=self.decimal_marker, + ) -def parse_vtt(file_path): - """Reads WebVTT file and returns as dict""" - with open(file_path, 'r', encoding='utf-8') as file: - webvtt_data = file.read() +class WriteVTT(SubtitlesWriter): + extension: str = "vtt" + always_include_hours: bool = False + decimal_marker: str = "." 
- data = [] - blocks = webvtt_data.split('\n\n') + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + print("WEBVTT\n", file=file) + for start, end, text in self.iterate_result(result, options, **kwargs): + print(f"{start} --> {end}\n{text}\n", file=file, flush=True) - for block in blocks: - if block.strip() != '' and not block.strip().startswith("WebVTT"): - lines = block.strip().split('\n') - index = lines[0] - timestamp = lines[1] - sentence = ' '.join(lines[2:]) + def to_segments(self, file_path: str) -> List[Segment]: + segments = [] - data.append({ - "index": index, - "timestamp": timestamp, - "sentence": sentence - }) + blocks = read_file(file_path).split('\n\n') - return data + for block in blocks: + if block.strip() != '' and not block.strip().startswith("WEBVTT"): + lines = block.strip().split('\n') + time_line = lines[0].split(" --> ") + start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker) + sentence = ' '.join(lines[1:]) + + segments.append(Segment( + start=start, + end=end, + text=sentence + )) + + return segments -def get_serialized_srt(dicts): - output = "" - for dic in dicts: - output += f'{dic["index"]}\n' - output += f'{dic["timestamp"]}\n' - output += f'{dic["sentence"]}\n\n' - return output +class WriteSRT(SubtitlesWriter): + extension: str = "srt" + always_include_hours: bool = True + decimal_marker: str = "," + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options, **kwargs), start=1 + ): + print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True) + + def to_segments(self, file_path: str) -> List[Segment]: + segments = [] + + blocks = read_file(file_path).split('\n\n') + + for block in blocks: + if block.strip() != '': + lines = block.strip().split('\n') + index = lines[0] + time_line = lines[1].split(" --> ") + start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker) + sentence = ' '.join(lines[2:]) + + segments.append(Segment( + start=start, + end=end, + text=sentence + )) + + return segments -def get_serialized_vtt(dicts): - output = "WebVTT\n\n" - for dic in dicts: - output += f'{dic["index"]}\n' - output += f'{dic["timestamp"]}\n' - output += f'{dic["sentence"]}\n\n' - return output +class WriteLRC(SubtitlesWriter): + extension: str = "lrc" + always_include_hours: bool = False + decimal_marker: str = "." 
+ + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + for i, (start, end, text) in enumerate( + self.iterate_result(result, options, **kwargs), start=1 + ): + if "align_lrc_words" in kwargs and kwargs["align_lrc_words"]: + print(f"{text}\n", file=file, flush=True) + else: + print(f"[{start}]{text}[{end}]\n", file=file, flush=True) + + def to_segments(self, file_path: str) -> List[Segment]: + segments = [] + + blocks = read_file(file_path).split('\n') + + for block in blocks: + if block.strip() != '': + lines = block.strip() + pattern = r'(\[.*?\])' + parts = re.split(pattern, lines) + parts = [part.strip() for part in parts if part] + + for i, part in enumerate(parts): + sentence_i = i%2 + if sentence_i == 1: + start_str, text, end_str = parts[sentence_i-1], parts[sentence_i], parts[sentence_i+1] + start_str, end_str = start_str.replace("[", "").replace("]", ""), end_str.replace("[", "").replace("]", "") + start, end = time_str_to_seconds(start_str, self.decimal_marker), time_str_to_seconds(end_str, self.decimal_marker) + + segments.append(Segment( + start=start, + end=end, + text=text, + )) + + return segments + + +class WriteTSV(ResultWriter): + """ + Write a transcript to a file in TSV (tab-separated values) format containing lines like: + \t\t + + Using integer milliseconds as start and end times means there's no chance of interference from + an environment setting a language encoding that causes the decimal in a floating point number + to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++. + """ + + extension: str = "tsv" + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + print("start", "end", "text", sep="\t", file=file) + for segment in result["segments"]: + print(round(1000 * segment["start"]), file=file, end="\t") + print(round(1000 * segment["end"]), file=file, end="\t") + print(segment["text"].strip().replace("\t", " "), file=file, flush=True) + + +class WriteJSON(ResultWriter): + extension: str = "json" + + def write_result( + self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + json.dump(result, file) + + +def get_writer( + output_format: str, output_dir: str +) -> Callable[[dict, TextIO, dict], None]: + output_format = output_format.strip().lower().replace(".", "") + + writers = { + "txt": WriteTXT, + "vtt": WriteVTT, + "srt": WriteSRT, + "tsv": WriteTSV, + "json": WriteJSON, + "lrc": WriteLRC + } + + if output_format == "all": + all_writers = [writer(output_dir) for writer in writers.values()] + + def write_all( + result: dict, file: TextIO, options: Optional[dict] = None, **kwargs + ): + for writer in all_writers: + writer(result, file, options, **kwargs) + + return write_all + + return writers[output_format](output_dir) + + +def generate_file( + output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str, + add_timestamp: bool = True, **kwargs +) -> Tuple[str, str]: + output_format = output_format.strip().lower().replace(".", "") + output_format = "vtt" if output_format == "webvtt" else output_format + + if add_timestamp: + timestamp = datetime.now().strftime("%m%d%H%M%S") + output_file_name += f"-{timestamp}" + + file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}") + file_writer = get_writer(output_format=output_format, output_dir=output_dir) + + if isinstance(file_writer, WriteLRC) and kwargs.get("highlight_words", False): + 
kwargs["highlight_words"], kwargs["align_lrc_words"] = False, True + + file_writer(result=result, output_file_name=output_file_name, **kwargs) + content = read_file(file_path) + return content, file_path def safe_filename(name): diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py index bb5c919..cb6da93 100644 --- a/modules/vad/silero_vad.py +++ b/modules/vad/silero_vad.py @@ -5,7 +5,8 @@ import numpy as np from typing import BinaryIO, Union, List, Optional, Tuple import warnings import faster_whisper -from faster_whisper.transcribe import SpeechTimestampsMap, Segment +from modules.whisper.data_classes import * +from faster_whisper.transcribe import SpeechTimestampsMap import gradio as gr @@ -247,18 +248,18 @@ class SileroVAD: def restore_speech_timestamps( self, - segments: List[dict], + segments: List[Segment], speech_chunks: List[dict], sampling_rate: Optional[int] = None, - ) -> List[dict]: + ) -> List[Segment]: if sampling_rate is None: sampling_rate = self.sampling_rate ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate) for segment in segments: - segment["start"] = ts_map.get_original_time(segment["start"]) - segment["end"] = ts_map.get_original_time(segment["end"]) + segment.start = ts_map.get_original_time(segment.start) + segment.end = ts_map.get_original_time(segment.end) return segments diff --git a/modules/whisper/whisper_base.py b/modules/whisper/base_transcription_pipeline.py similarity index 66% rename from modules/whisper/whisper_base.py rename to modules/whisper/base_transcription_pipeline.py index 51c87dd..2791dc6 100644 --- a/modules/whisper/whisper_base.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -1,5 +1,4 @@ import os -import torch import whisper import ctranslate2 import gradio as gr @@ -9,21 +8,20 @@ from typing import BinaryIO, Union, Tuple, List import numpy as np from datetime import datetime from faster_whisper.vad import VadOptions -from dataclasses import astuple from modules.uvr.music_separator import MusicSeparator from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH, UVR_MODELS_DIR) -from modules.utils.constants import AUTOMATIC_DETECTION -from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename +from modules.utils.constants import * +from modules.utils.subtitle_manager import * from modules.utils.youtube_manager import get_ytdata, get_ytaudio -from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml -from modules.whisper.whisper_parameter import * +from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file +from modules.whisper.data_classes import * from modules.diarize.diarizer import Diarizer from modules.vad.silero_vad import SileroVAD -class WhisperBase(ABC): +class BaseTranscriptionPipeline(ABC): def __init__(self, model_dir: str = WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -73,13 +71,15 @@ class WhisperBase(ABC): def run(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), + file_format: str = "SRT", add_timestamp: bool = True, - *whisper_params, - ) -> Tuple[List[dict], float]: + *pipeline_params, + ) -> Tuple[List[Segment], float]: """ Run transcription with conditional pre-processing and post-processing. The VAD will be performed to remove noise from the audio input in pre-processing, if enabled. 
The diarization will be performed in post-processing, if enabled. + Due to the integration with gradio, the parameters have to be specified with a `*` wildcard. Parameters ---------- @@ -87,40 +87,33 @@ class WhisperBase(ABC): Audio input. This can be file path or binary type. progress: gr.Progress Indicator to show progress directly in gradio. + file_format: str + Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"] add_timestamp: bool Whether to add a timestamp at the end of the filename. - *whisper_params: tuple - Parameters related with whisper. This will be dealt with "WhisperParameters" data class + *pipeline_params: tuple + Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class. + This must be provided as a List with * wildcard because of the integration with gradio. + See more info at : https://github.com/gradio-app/gradio/issues/2471 Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for running """ - params = WhisperParameters.as_value(*whisper_params) + params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + params = self.validate_gradio_values(params) + bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization - self.cache_parameters( - whisper_params=params, - add_timestamp=add_timestamp - ) - - if params.lang is None: - pass - elif params.lang == AUTOMATIC_DETECTION: - params.lang = None - else: - language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} - params.lang = language_code_dict[params.lang] - - if params.is_bgm_separate: + if bgm_params.is_separate_bgm: music, audio, _ = self.music_separator.separate( audio=audio, - model_name=params.uvr_model_size, - device=params.uvr_device, - segment_size=params.uvr_segment_size, - save_file=params.uvr_save_file, + model_name=bgm_params.model_size, + device=bgm_params.device, + segment_size=bgm_params.segment_size, + save_file=bgm_params.save_file, progress=progress ) @@ -132,47 +125,55 @@ class WhisperBase(ABC): origin_sample_rate = self.music_separator.audio_info.sample_rate audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate) - if params.uvr_enable_offload: + if bgm_params.enable_offload: self.music_separator.offload() - if params.vad_filter: - # Explicit value set for float('inf') from gr.Number() - if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999: - params.max_speech_duration_s = float('inf') - + if vad_params.vad_filter: vad_options = VadOptions( - threshold=params.threshold, - min_speech_duration_ms=params.min_speech_duration_ms, - max_speech_duration_s=params.max_speech_duration_s, - min_silence_duration_ms=params.min_silence_duration_ms, - speech_pad_ms=params.speech_pad_ms + threshold=vad_params.threshold, + min_speech_duration_ms=vad_params.min_speech_duration_ms, + max_speech_duration_s=vad_params.max_speech_duration_s, + min_silence_duration_ms=vad_params.min_silence_duration_ms, + speech_pad_ms=vad_params.speech_pad_ms ) - audio, speech_chunks = self.vad.run( + vad_processed, speech_chunks = self.vad.run( audio=audio, vad_parameters=vad_options, progress=progress ) + if vad_processed.size > 0: + audio = vad_processed + else: + vad_params.vad_filter = False + result, 
elapsed_time = self.transcribe( audio, progress, - *astuple(params) + *whisper_params.to_list() ) - if params.vad_filter: + if vad_params.vad_filter: result = self.vad.restore_speech_timestamps( segments=result, speech_chunks=speech_chunks, ) - if params.is_diarize: + if diarization_params.is_diarize: result, elapsed_time_diarization = self.diarizer.run( audio=audio, - use_auth_token=params.hf_token, + use_auth_token=diarization_params.hf_token, transcribed_result=result, + device=diarization_params.device ) elapsed_time += elapsed_time_diarization + + self.cache_parameters( + params=params, + file_format=file_format, + add_timestamp=add_timestamp + ) return result, elapsed_time def transcribe_file(self, @@ -181,8 +182,8 @@ class WhisperBase(ABC): file_format: str = "SRT", add_timestamp: bool = True, progress=gr.Progress(), - *whisper_params, - ) -> list: + *pipeline_params, + ) -> Tuple[str, List]: """ Write subtitle file from Files @@ -199,8 +200,8 @@ class WhisperBase(ABC): Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename. progress: gr.Progress Indicator to show progress directly in gradio. - *whisper_params: tuple - Parameters related with whisper. This will be dealt with "WhisperParameters" data class + *pipeline_params: tuple + Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class Returns ---------- @@ -210,6 +211,11 @@ class WhisperBase(ABC): Output file path to return to gr.Files() """ try: + params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + writer_options = { + "highlight_words": True if params.whisper.word_timestamps else False + } + if input_folder_path: files = get_media_files(input_folder_path) if isinstance(files, str): @@ -222,19 +228,21 @@ class WhisperBase(ABC): transcribed_segments, time_for_task = self.run( file, progress, + file_format, add_timestamp, - *whisper_params, + *pipeline_params, ) file_name, file_ext = os.path.splitext(os.path.basename(file)) - subtitle, file_path = self.generate_and_write_file( - file_name=file_name, - transcribed_segments=transcribed_segments, + subtitle, file_path = generate_file( + output_dir=self.output_dir, + output_file_name=file_name, + output_format=file_format, + result=transcribed_segments, add_timestamp=add_timestamp, - file_format=file_format, - output_dir=self.output_dir + **writer_options ) - files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path} + files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path} total_result = '' total_time = 0 @@ -247,10 +255,11 @@ class WhisperBase(ABC): result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}" result_file_path = [info['path'] for info in files_info.values()] - return [result_str, result_file_path] + return result_str, result_file_path except Exception as e: print(f"Error transcribing file: {e}") + raise finally: self.release_cuda_memory() @@ -259,8 +268,8 @@ class WhisperBase(ABC): file_format: str = "SRT", add_timestamp: bool = True, progress=gr.Progress(), - *whisper_params, - ) -> list: + *pipeline_params, + ) -> Tuple[str, str]: """ Write subtitle file from microphone @@ -274,7 +283,7 @@ class WhisperBase(ABC): Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. progress: gr.Progress Indicator to show progress directly in gradio. 
- *whisper_params: tuple + *pipeline_params: tuple Parameters related with whisper. This will be dealt with "WhisperParameters" data class Returns @@ -285,27 +294,36 @@ class WhisperBase(ABC): Output file path to return to gr.Files() """ try: + params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + writer_options = { + "highlight_words": True if params.whisper.word_timestamps else False + } + progress(0, desc="Loading Audio..") transcribed_segments, time_for_task = self.run( mic_audio, progress, + file_format, add_timestamp, - *whisper_params, + *pipeline_params, ) progress(1, desc="Completed!") - subtitle, result_file_path = self.generate_and_write_file( - file_name="Mic", - transcribed_segments=transcribed_segments, + file_name = "Mic" + subtitle, file_path = generate_file( + output_dir=self.output_dir, + output_file_name=file_name, + output_format=file_format, + result=transcribed_segments, add_timestamp=add_timestamp, - file_format=file_format, - output_dir=self.output_dir + **writer_options ) result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}" - return [result_str, result_file_path] + return result_str, file_path except Exception as e: - print(f"Error transcribing file: {e}") + print(f"Error transcribing mic: {e}") + raise finally: self.release_cuda_memory() @@ -314,8 +332,8 @@ class WhisperBase(ABC): file_format: str = "SRT", add_timestamp: bool = True, progress=gr.Progress(), - *whisper_params, - ) -> list: + *pipeline_params, + ) -> Tuple[str, str]: """ Write subtitle file from Youtube @@ -329,7 +347,7 @@ class WhisperBase(ABC): Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename. progress: gr.Progress Indicator to show progress directly in gradio. - *whisper_params: tuple + *pipeline_params: tuple Parameters related with whisper. This will be dealt with "WhisperParameters" data class Returns @@ -340,6 +358,11 @@ class WhisperBase(ABC): Output file path to return to gr.Files() """ try: + params = TranscriptionPipelineParams.from_list(list(pipeline_params)) + writer_options = { + "highlight_words": True if params.whisper.word_timestamps else False + } + progress(0, desc="Loading Audio from Youtube..") yt = get_ytdata(youtube_link) audio = get_ytaudio(yt) @@ -347,29 +370,33 @@ class WhisperBase(ABC): transcribed_segments, time_for_task = self.run( audio, progress, + file_format, add_timestamp, - *whisper_params, + *pipeline_params, ) progress(1, desc="Completed!") file_name = safe_filename(yt.title) - subtitle, result_file_path = self.generate_and_write_file( - file_name=file_name, - transcribed_segments=transcribed_segments, + subtitle, file_path = generate_file( + output_dir=self.output_dir, + output_file_name=file_name, + output_format=file_format, + result=transcribed_segments, add_timestamp=add_timestamp, - file_format=file_format, - output_dir=self.output_dir + **writer_options ) + result_str = f"Done in {self.format_time(time_for_task)}! 
Subtitle file is in the outputs folder.\n\n{subtitle}" if os.path.exists(audio): os.remove(audio) - return [result_str, result_file_path] + return result_str, file_path except Exception as e: - print(f"Error transcribing file: {e}") + print(f"Error transcribing youtube: {e}") + raise finally: self.release_cuda_memory() @@ -387,58 +414,6 @@ class WhisperBase(ABC): else: return list(ctranslate2.get_supported_compute_types("cpu")) - @staticmethod - def generate_and_write_file(file_name: str, - transcribed_segments: list, - add_timestamp: bool, - file_format: str, - output_dir: str - ) -> str: - """ - Writes subtitle file - - Parameters - ---------- - file_name: str - Output file name - transcribed_segments: list - Text segments transcribed from audio - add_timestamp: bool - Determines whether to add a timestamp to the end of the filename. - file_format: str - File format to write. Supported formats: [SRT, WebVTT, txt] - output_dir: str - Directory path of the output - - Returns - ---------- - content: str - Result of the transcription - output_path: str - output file path - """ - if add_timestamp: - timestamp = datetime.now().strftime("%m%d%H%M%S") - output_path = os.path.join(output_dir, f"{file_name}-{timestamp}") - else: - output_path = os.path.join(output_dir, f"{file_name}") - - file_format = file_format.strip().lower() - if file_format == "srt": - content = get_srt(transcribed_segments) - output_path += '.srt' - - elif file_format == "webvtt": - content = get_vtt(transcribed_segments) - output_path += '.vtt' - - elif file_format == "txt": - content = get_txt(transcribed_segments) - output_path += '.txt' - - write_file(content, output_path) - return content, output_path - @staticmethod def format_time(elapsed_time: float) -> str: """ @@ -471,7 +446,7 @@ class WhisperBase(ABC): if torch.cuda.is_available(): return "cuda" elif torch.backends.mps.is_available(): - if not WhisperBase.is_sparse_api_supported(): + if not BaseTranscriptionPipeline.is_sparse_api_supported(): # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886 return "cpu" return "mps" @@ -513,17 +488,64 @@ class WhisperBase(ABC): os.remove(file_path) @staticmethod - def cache_parameters( - whisper_params: WhisperValues, - add_timestamp: bool - ): - """cache parameters to the yaml file""" - cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) - cached_whisper_param = whisper_params.to_yaml() - cached_yaml = {**cached_params, **cached_whisper_param} - cached_yaml["whisper"]["add_timestamp"] = add_timestamp + def validate_gradio_values(params: TranscriptionPipelineParams): + """ + Validate gradio specific values that can't be displayed as None in the UI. 
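
Because Gradio components cannot display `None` (see the issue linked below), the UI feeds sentinel values that `validate_gradio_values()` converts back before inference. The actual sentinels are defined in `modules/utils/constants.py` and are not shown in this diff; the sketch below uses assumed placeholder values purely to illustrate the mapping.

```python
# Assumed placeholder sentinels -- the real values live in modules/utils/constants.py.
GRADIO_NONE_STR = ""              # textbox left empty
GRADIO_NONE_NUMBER_MIN = 0        # number field left empty ("no value")
GRADIO_NONE_NUMBER_MAX = 9999     # number field meaning "no upper bound"

def restore_gradio_none(initial_prompt, max_new_tokens, max_speech_duration_s):
    """Mirror of the mapping validate_gradio_values() applies to UI inputs."""
    initial_prompt = None if initial_prompt == GRADIO_NONE_STR else initial_prompt
    max_new_tokens = None if max_new_tokens == GRADIO_NONE_NUMBER_MIN else max_new_tokens
    if max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
        max_speech_duration_s = float("inf")
    return initial_prompt, max_new_tokens, max_speech_duration_s
```
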
+ Related issue : https://github.com/gradio-app/gradio/issues/8723 + """ + if params.whisper.lang is None: + pass + elif params.whisper.lang == AUTOMATIC_DETECTION: + params.whisper.lang = None + else: + language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()} + params.whisper.lang = language_code_dict[params.whisper.lang] - save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) + if params.whisper.initial_prompt == GRADIO_NONE_STR: + params.whisper.initial_prompt = None + if params.whisper.prefix == GRADIO_NONE_STR: + params.whisper.prefix = None + if params.whisper.hotwords == GRADIO_NONE_STR: + params.whisper.hotwords = None + if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN: + params.whisper.max_new_tokens = None + if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN: + params.whisper.hallucination_silence_threshold = None + if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN: + params.whisper.language_detection_threshold = None + if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX: + params.vad.max_speech_duration_s = float('inf') + return params + + @staticmethod + def cache_parameters( + params: TranscriptionPipelineParams, + file_format: str = "SRT", + add_timestamp: bool = True + ): + """Cache parameters to the yaml file""" + cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) + param_to_cache = params.to_dict() + + cached_yaml = {**cached_params, **param_to_cache} + cached_yaml["whisper"]["add_timestamp"] = add_timestamp + cached_yaml["whisper"]["file_format"] = file_format + + supress_token = cached_yaml["whisper"].get("suppress_tokens", None) + if supress_token and isinstance(supress_token, list): + cached_yaml["whisper"]["suppress_tokens"] = str(supress_token) + + if cached_yaml["whisper"].get("lang", None) is None: + cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap() + else: + language_dict = whisper.tokenizer.LANGUAGES + cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]] + + if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'): + cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX + + if cached_yaml is not None and cached_yaml: + save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH) @staticmethod def resample_audio(audio: Union[str, np.ndarray], diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py new file mode 100644 index 0000000..2bd9c31 --- /dev/null +++ b/modules/whisper/data_classes.py @@ -0,0 +1,608 @@ +import faster_whisper.transcribe +import gradio as gr +import torch +from typing import Optional, Dict, List, Union, NamedTuple +from pydantic import BaseModel, Field, field_validator, ConfigDict +from gradio_i18n import Translate, gettext as _ +from enum import Enum +from copy import deepcopy + +import yaml + +from modules.utils.constants import * + + +class WhisperImpl(Enum): + WHISPER = "whisper" + FASTER_WHISPER = "faster-whisper" + INSANELY_FAST_WHISPER = "insanely_fast_whisper" + + +class Segment(BaseModel): + id: Optional[int] = Field(default=None, description="Incremental id for the segment") + seek: Optional[int] = Field(default=None, description="Seek of the segment from chunked audio") + text: Optional[str] = Field(default=None, description="Transcription text of the segment") + start: Optional[float] = Field(default=None, description="Start time of the segment") + end: Optional[float] = Field(default=None, description="End time of the segment") + 
tokens: Optional[List[int]] = Field(default=None, description="List of token IDs") + temperature: Optional[float] = Field(default=None, description="Temperature used during the decoding process") + avg_logprob: Optional[float] = Field(default=None, description="Average log probability of the tokens") + compression_ratio: Optional[float] = Field(default=None, description="Compression ratio of the segment") + no_speech_prob: Optional[float] = Field(default=None, description="Probability that it's not speech") + words: Optional[List['Word']] = Field(default=None, description="List of words contained in the segment") + + @classmethod + def from_faster_whisper(cls, + seg: faster_whisper.transcribe.Segment): + if seg.words is not None: + words = [ + Word( + start=w.start, + end=w.end, + word=w.word, + probability=w.probability + ) for w in seg.words + ] + else: + words = None + + return cls( + id=seg.id, + seek=seg.seek, + text=seg.text, + start=seg.start, + end=seg.end, + tokens=seg.tokens, + temperature=seg.temperature, + avg_logprob=seg.avg_logprob, + compression_ratio=seg.compression_ratio, + no_speech_prob=seg.no_speech_prob, + words=words + ) + + +class Word(BaseModel): + start: Optional[float] = Field(default=None, description="Start time of the word") + end: Optional[float] = Field(default=None, description="Start time of the word") + word: Optional[str] = Field(default=None, description="Word text") + probability: Optional[float] = Field(default=None, description="Probability of the word") + + +class BaseParams(BaseModel): + model_config = ConfigDict(protected_namespaces=()) + + def to_dict(self) -> Dict: + return self.model_dump() + + def to_list(self) -> List: + return list(self.model_dump().values()) + + @classmethod + def from_list(cls, data_list: List) -> 'BaseParams': + field_names = list(cls.model_fields.keys()) + return cls(**dict(zip(field_names, data_list))) + + +class VadParams(BaseParams): + """Voice Activity Detection parameters""" + vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts") + threshold: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Speech threshold for Silero VAD. Probabilities above this value are considered speech" + ) + min_speech_duration_ms: int = Field( + default=250, + ge=0, + description="Final speech chunks shorter than this are discarded" + ) + max_speech_duration_s: float = Field( + default=float("inf"), + gt=0, + description="Maximum duration of speech chunks in seconds" + ) + min_silence_duration_ms: int = Field( + default=2000, + ge=0, + description="Minimum silence duration between speech chunks" + ) + speech_pad_ms: int = Field( + default=400, + ge=0, + description="Padding added to each side of speech chunks" + ) + + @classmethod + def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]: + return [ + gr.Checkbox( + label=_("Enable Silero VAD Filter"), + value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default), + interactive=True, + info=_("Enable this to transcribe only detected voice") + ), + gr.Slider( + minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", + value=defaults.get("threshold", cls.__fields__["threshold"].default), + info="Lower it to be more sensitive to small sounds." 
+ ), + gr.Number( + label="Minimum Speech Duration (ms)", precision=0, + value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default), + info="Final speech chunks shorter than this time are thrown out" + ), + gr.Number( + label="Maximum Speech Duration (s)", + value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX), + info="Maximum duration of speech chunks in \"seconds\"." + ), + gr.Number( + label="Minimum Silence Duration (ms)", precision=0, + value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default), + info="In the end of each speech chunk wait for this time before separating it" + ), + gr.Number( + label="Speech Padding (ms)", precision=0, + value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default), + info="Final speech chunks are padded by this time each side" + ) + ] + + +class DiarizationParams(BaseParams): + """Speaker diarization parameters""" + is_diarize: bool = Field(default=False, description="Enable speaker diarization") + device: str = Field(default="cuda", description="Device to run Diarization model.") + hf_token: str = Field( + default="", + description="Hugging Face token for downloading diarization models" + ) + + @classmethod + def to_gradio_inputs(cls, + defaults: Optional[Dict] = None, + available_devices: Optional[List] = None, + device: Optional[str] = None) -> List[gr.components.base.FormComponent]: + return [ + gr.Checkbox( + label=_("Enable Diarization"), + value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default), + ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value=defaults.get("device", device), + ), + gr.Textbox( + label=_("HuggingFace Token"), + value=defaults.get("hf_token", cls.__fields__["hf_token"].default), + info=_("This is only needed the first time you download the model") + ), + ] + + +class BGMSeparationParams(BaseParams): + """Background music separation parameters""" + is_separate_bgm: bool = Field(default=False, description="Enable background music separation") + model_size: str = Field( + default="UVR-MDX-NET-Inst_HQ_4", + description="UVR model size" + ) + device: str = Field(default="cuda", description="Device to run UVR model.") + segment_size: int = Field( + default=256, + gt=0, + description="Segment size for UVR model" + ) + save_file: bool = Field( + default=False, + description="Whether to save separated audio files" + ) + enable_offload: bool = Field( + default=True, + description="Offload UVR model after transcription" + ) + + @classmethod + def to_gradio_input(cls, + defaults: Optional[Dict] = None, + available_devices: Optional[List] = None, + device: Optional[str] = None, + available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]: + return [ + gr.Checkbox( + label=_("Enable Background Music Remover Filter"), + value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default), + interactive=True, + info=_("Enabling this will remove background music") + ), + gr.Dropdown( + label=_("Model"), + choices=["UVR-MDX-NET-Inst_HQ_4", + "UVR-MDX-NET-Inst_3"] if available_models is None else available_models, + value=defaults.get("model_size", cls.__fields__["model_size"].default), + ), + gr.Dropdown( + label=_("Device"), + choices=["cpu", "cuda"] if available_devices is None else available_devices, + value=defaults.get("device", device), + ), + gr.Number( + label="Segment Size", + 
value=defaults.get("segment_size", cls.__fields__["segment_size"].default), + precision=0, + info="Segment size for UVR model" + ), + gr.Checkbox( + label=_("Save separated files to output"), + value=defaults.get("save_file", cls.__fields__["save_file"].default), + ), + gr.Checkbox( + label=_("Offload sub model after removing background music"), + value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default), + ) + ] + + +class WhisperParams(BaseParams): + """Whisper parameters""" + model_size: str = Field(default="large-v2", description="Whisper model size") + lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe") + is_translate: bool = Field(default=False, description="Translate speech to English end-to-end") + beam_size: int = Field(default=5, ge=1, description="Beam size for decoding") + log_prob_threshold: float = Field( + default=-1.0, + description="Threshold for average log probability of sampled tokens" + ) + no_speech_threshold: float = Field( + default=0.6, + ge=0.0, + le=1.0, + description="Threshold for detecting silence" + ) + compute_type: str = Field(default="float16", description="Computation type for transcription") + best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling") + patience: float = Field(default=1.0, gt=0, description="Beam search patience factor") + condition_on_previous_text: bool = Field( + default=True, + description="Use previous output as prompt for next window" + ) + prompt_reset_on_temperature: float = Field( + default=0.5, + ge=0.0, + le=1.0, + description="Temperature threshold for resetting prompt" + ) + initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window") + temperature: float = Field( + default=0.0, + ge=0.0, + description="Temperature for sampling" + ) + compression_ratio_threshold: float = Field( + default=2.4, + gt=0, + description="Threshold for gzip compression ratio" + ) + length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty") + repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens") + no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition") + prefix: Optional[str] = Field(default=None, description="Prefix text for first window") + suppress_blank: bool = Field( + default=True, + description="Suppress blank outputs at start of sampling" + ) + suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress") + max_initial_timestamp: float = Field( + default=1.0, + ge=0.0, + description="Maximum initial timestamp" + ) + word_timestamps: bool = Field(default=False, description="Extract word-level timestamps") + prepend_punctuations: Optional[str] = Field( + default="\"'“¿([{-", + description="Punctuations to merge with next word" + ) + append_punctuations: Optional[str] = Field( + default="\"'.。,,!!??::”)]}、", + description="Punctuations to merge with previous word" + ) + max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk") + chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds") + hallucination_silence_threshold: Optional[float] = Field( + default=None, + description="Threshold for skipping silent periods in hallucination detection" + ) + hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model") + 
language_detection_threshold: Optional[float] = Field( + default=None, + description="Threshold for language detection probability" + ) + language_detection_segments: int = Field( + default=1, + gt=0, + description="Number of segments for language detection" + ) + batch_size: int = Field(default=24, gt=0, description="Batch size for processing") + + @field_validator('lang') + def validate_lang(cls, v): + from modules.utils.constants import AUTOMATIC_DETECTION + return None if v == AUTOMATIC_DETECTION.unwrap() else v + + @field_validator('suppress_tokens') + def validate_supress_tokens(cls, v): + import ast + try: + if isinstance(v, str): + suppress_tokens = ast.literal_eval(v) + if not isinstance(suppress_tokens, list): + raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]") + return suppress_tokens + if isinstance(v, list): + return v + except Exception as e: + raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}") + + @classmethod + def to_gradio_inputs(cls, + defaults: Optional[Dict] = None, + only_advanced: Optional[bool] = True, + whisper_type: Optional[str] = None, + available_models: Optional[List] = None, + available_langs: Optional[List] = None, + available_compute_types: Optional[List] = None, + compute_type: Optional[str] = None): + whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower() + + inputs = [] + if not only_advanced: + inputs += [ + gr.Dropdown( + label=_("Model"), + choices=available_models, + value=defaults.get("model_size", cls.__fields__["model_size"].default), + ), + gr.Dropdown( + label=_("Language"), + choices=available_langs, + value=defaults.get("lang", AUTOMATIC_DETECTION), + ), + gr.Checkbox( + label=_("Translate to English?"), + value=defaults.get("is_translate", cls.__fields__["is_translate"].default), + ), + ] + + inputs += [ + gr.Number( + label="Beam Size", + value=defaults.get("beam_size", cls.__fields__["beam_size"].default), + precision=0, + info="Beam size for decoding" + ), + gr.Number( + label="Log Probability Threshold", + value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default), + info="Threshold for average log probability of sampled tokens" + ), + gr.Number( + label="No Speech Threshold", + value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default), + info="Threshold for detecting silence" + ), + gr.Dropdown( + label="Compute Type", + choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types, + value=defaults.get("compute_type", compute_type), + info="Computation type for transcription" + ), + gr.Number( + label="Best Of", + value=defaults.get("best_of", cls.__fields__["best_of"].default), + precision=0, + info="Number of candidates when sampling" + ), + gr.Number( + label="Patience", + value=defaults.get("patience", cls.__fields__["patience"].default), + info="Beam search patience factor" + ), + gr.Checkbox( + label="Condition On Previous Text", + value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default), + info="Use previous output as prompt for next window" + ), + gr.Slider( + label="Prompt Reset On Temperature", + value=defaults.get("prompt_reset_on_temperature", + cls.__fields__["prompt_reset_on_temperature"].default), + minimum=0, + maximum=1, + step=0.01, + info="Temperature threshold for resetting prompt" + ), + gr.Textbox( + label="Initial Prompt", + 
value=defaults.get("initial_prompt", GRADIO_NONE_STR), + info="Initial prompt for first window" + ), + gr.Slider( + label="Temperature", + value=defaults.get("temperature", cls.__fields__["temperature"].default), + minimum=0.0, + step=0.01, + maximum=1.0, + info="Temperature for sampling" + ), + gr.Number( + label="Compression Ratio Threshold", + value=defaults.get("compression_ratio_threshold", + cls.__fields__["compression_ratio_threshold"].default), + info="Threshold for gzip compression ratio" + ) + ] + + faster_whisper_inputs = [ + gr.Number( + label="Length Penalty", + value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default), + info="Exponential length penalty", + ), + gr.Number( + label="Repetition Penalty", + value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default), + info="Penalty for repeated tokens" + ), + gr.Number( + label="No Repeat N-gram Size", + value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default), + precision=0, + info="Size of n-grams to prevent repetition" + ), + gr.Textbox( + label="Prefix", + value=defaults.get("prefix", GRADIO_NONE_STR), + info="Prefix text for first window" + ), + gr.Checkbox( + label="Suppress Blank", + value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default), + info="Suppress blank outputs at start of sampling" + ), + gr.Textbox( + label="Suppress Tokens", + value=defaults.get("suppress_tokens", "[-1]"), + info="Token IDs to suppress" + ), + gr.Number( + label="Max Initial Timestamp", + value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default), + info="Maximum initial timestamp" + ), + gr.Checkbox( + label="Word Timestamps", + value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default), + info="Extract word-level timestamps" + ), + gr.Textbox( + label="Prepend Punctuations", + value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default), + info="Punctuations to merge with next word" + ), + gr.Textbox( + label="Append Punctuations", + value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default), + info="Punctuations to merge with previous word" + ), + gr.Number( + label="Max New Tokens", + value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN), + precision=0, + info="Maximum number of new tokens per chunk" + ), + gr.Number( + label="Chunk Length (s)", + value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default), + precision=0, + info="Length of audio segments in seconds" + ), + gr.Number( + label="Hallucination Silence Threshold (sec)", + value=defaults.get("hallucination_silence_threshold", + GRADIO_NONE_NUMBER_MIN), + info="Threshold for skipping silent periods in hallucination detection" + ), + gr.Textbox( + label="Hotwords", + value=defaults.get("hotwords", cls.__fields__["hotwords"].default), + info="Hotwords/hint phrases for the model" + ), + gr.Number( + label="Language Detection Threshold", + value=defaults.get("language_detection_threshold", + GRADIO_NONE_NUMBER_MIN), + info="Threshold for language detection probability" + ), + gr.Number( + label="Language Detection Segments", + value=defaults.get("language_detection_segments", + cls.__fields__["language_detection_segments"].default), + precision=0, + info="Number of segments for language detection" + ) + ] + + insanely_fast_whisper_inputs = [ + gr.Number( + label="Batch Size", + value=defaults.get("batch_size", 
cls.__fields__["batch_size"].default), + precision=0, + info="Batch size for processing" + ) + ] + + if whisper_type != WhisperImpl.FASTER_WHISPER.value: + for input_component in faster_whisper_inputs: + input_component.visible = False + + if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value: + for input_component in insanely_fast_whisper_inputs: + input_component.visible = False + + inputs += faster_whisper_inputs + insanely_fast_whisper_inputs + + return inputs + + +class TranscriptionPipelineParams(BaseModel): + """Transcription pipeline parameters""" + whisper: WhisperParams = Field(default_factory=WhisperParams) + vad: VadParams = Field(default_factory=VadParams) + diarization: DiarizationParams = Field(default_factory=DiarizationParams) + bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams) + + def to_dict(self) -> Dict: + data = { + "whisper": self.whisper.to_dict(), + "vad": self.vad.to_dict(), + "diarization": self.diarization.to_dict(), + "bgm_separation": self.bgm_separation.to_dict() + } + return data + + def to_list(self) -> List: + """ + Convert data class to the list because I have to pass the parameters as a list in the gradio. + Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 + See more about Gradio pre-processing: https://www.gradio.app/docs/components + """ + whisper_list = self.whisper.to_list() + vad_list = self.vad.to_list() + diarization_list = self.diarization.to_list() + bgm_sep_list = self.bgm_separation.to_list() + return whisper_list + vad_list + diarization_list + bgm_sep_list + + @staticmethod + def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams': + """Convert list to the data class again to use it in a function.""" + data_list = deepcopy(pipeline_list) + + whisper_list = data_list[0:len(WhisperParams.__annotations__)] + data_list = data_list[len(WhisperParams.__annotations__):] + + vad_list = data_list[0:len(VadParams.__annotations__)] + data_list = data_list[len(VadParams.__annotations__):] + + diarization_list = data_list[0:len(DiarizationParams.__annotations__)] + data_list = data_list[len(DiarizationParams.__annotations__):] + + bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)] + + return TranscriptionPipelineParams( + whisper=WhisperParams.from_list(whisper_list), + vad=VadParams.from_list(vad_list), + diarization=DiarizationParams.from_list(diarization_list), + bgm_separation=BGMSeparationParams.from_list(bgm_sep_list) + ) diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py index f12fc01..bc1e8ed 100644 --- a/modules/whisper/faster_whisper_inference.py +++ b/modules/whisper/faster_whisper_inference.py @@ -12,11 +12,11 @@ import gradio as gr from argparse import Namespace from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) -from modules.whisper.whisper_parameter import * -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.data_classes import * +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline -class FasterWhisperInference(WhisperBase): +class FasterWhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = FASTER_WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase): audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) 
-> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. @@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase): Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) - # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723 - if not params.initial_prompt: - params.initial_prompt = None - if not params.prefix: - params.prefix = None - if not params.hotwords: - params.hotwords = None - - params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens) - segments, info = self.model.transcribe( audio=audio, language=params.lang, @@ -112,11 +102,7 @@ class FasterWhisperInference(WhisperBase): segments_result = [] for segment in segments: progress(segment.start / info.duration, desc="Transcribing..") - segments_result.append({ - "start": segment.start, - "end": segment.end, - "text": segment.text - }) + segments_result.append(Segment.from_faster_whisper(segment)) elapsed_time = time.time() - start_time return segments_result, elapsed_time diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py index fe6f4fd..2773166 100644 --- a/modules/whisper/insanely_fast_whisper_inference.py +++ b/modules/whisper/insanely_fast_whisper_inference.py @@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn from argparse import Namespace from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR) -from modules.whisper.whisper_parameter import * -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.data_classes import * +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline -class InsanelyFastWhisperInference(WhisperBase): +class InsanelyFastWhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -32,15 +32,13 @@ class InsanelyFastWhisperInference(WhisperBase): self.model_dir = model_dir os.makedirs(self.model_dir, exist_ok=True) - openai_models = whisper.available_models() - distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] - self.available_models = openai_models + distil_models + self.available_models = self.get_model_paths() def transcribe(self, audio: Union[str, np.ndarray, torch.Tensor], progress: gr.Progress = gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. 
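
With this change every backend returns `List[Segment]` (pydantic models) instead of loosely typed dicts, so downstream code such as the subtitle writer can consume results uniformly. A small usage sketch; the segment values are made up.

```python
# Consuming the unified Segment result type (values are made up).
from modules.whisper.data_classes import Segment

segments = [
    Segment(start=0.0, end=2.4, text=" And so, my fellow Americans:"),
    Segment(start=2.4, end=5.1, text=" ask not what your country can do for you."),
]
for seg in segments:
    print(f"[{seg.start:6.2f} --> {seg.end:6.2f}]{seg.text}")
```
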
@@ -55,13 +53,13 @@ class InsanelyFastWhisperInference(WhisperBase): Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) @@ -95,9 +93,17 @@ class InsanelyFastWhisperInference(WhisperBase): generate_kwargs=kwargs ) - segments_result = self.format_result( - transcribed_result=segments, - ) + segments_result = [] + for item in segments["chunks"]: + start, end = item["timestamp"][0], item["timestamp"][1] + if end is None: + end = start + segments_result.append(Segment( + text=item["text"], + start=start, + end=end + )) + elapsed_time = time.time() - start_time return segments_result, elapsed_time @@ -138,31 +144,26 @@ class InsanelyFastWhisperInference(WhisperBase): model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"}, ) - @staticmethod - def format_result( - transcribed_result: dict - ) -> List[dict]: + def get_model_paths(self): """ - Format the transcription result of insanely_fast_whisper as the same with other implementation. - - Parameters - ---------- - transcribed_result: dict - Transcription result of the insanely_fast_whisper + Get available models from models path including fine-tuned model. Returns ---------- - result: List[dict] - Formatted result as the same with other implementation + Name set of models """ - result = transcribed_result["chunks"] - for item in result: - start, end = item["timestamp"][0], item["timestamp"][1] - if end is None: - end = start - item["start"] = start - item["end"] = end - return result + openai_models = whisper.available_models() + distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"] + default_models = openai_models + distil_models + + existing_models = os.listdir(self.model_dir) + wrong_dirs = [".locks"] + + available_models = default_models + existing_models + available_models = [model for model in available_models if model not in wrong_dirs] + available_models = sorted(set(available_models), key=available_models.index) + + return available_models @staticmethod def download_model( diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py index f87fbe5..ccd4bbb 100644 --- a/modules/whisper/whisper_Inference.py +++ b/modules/whisper/whisper_Inference.py @@ -8,11 +8,11 @@ import os from argparse import Namespace from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR) -from modules.whisper.whisper_base import WhisperBase -from modules.whisper.whisper_parameter import * +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline +from modules.whisper.data_classes import * -class WhisperInference(WhisperBase): +class WhisperInference(BaseTranscriptionPipeline): def __init__(self, model_dir: str = WHISPER_MODELS_DIR, diarization_model_dir: str = DIARIZATION_MODELS_DIR, @@ -30,7 +30,7 @@ class WhisperInference(WhisperBase): audio: Union[str, np.ndarray, torch.Tensor], progress: gr.Progress = 
gr.Progress(), *whisper_params, - ) -> Tuple[List[dict], float]: + ) -> Tuple[List[Segment], float]: """ transcribe method for faster-whisper. @@ -45,13 +45,13 @@ class WhisperInference(WhisperBase): Returns ---------- - segments_result: List[dict] - list of dicts that includes start, end timestamps and transcribed text + segments_result: List[Segment] + list of Segment that includes start, end timestamps and transcribed text elapsed_time: float elapsed time for transcription """ start_time = time.time() - params = WhisperParameters.as_value(*whisper_params) + params = WhisperParams.from_list(list(whisper_params)) if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type: self.update_model(params.model_size, params.compute_type, progress) @@ -59,21 +59,28 @@ class WhisperInference(WhisperBase): def progress_callback(progress_value): progress(progress_value, desc="Transcribing..") - segments_result = self.model.transcribe(audio=audio, - language=params.lang, - verbose=False, - beam_size=params.beam_size, - logprob_threshold=params.log_prob_threshold, - no_speech_threshold=params.no_speech_threshold, - task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", - fp16=True if params.compute_type == "float16" else False, - best_of=params.best_of, - patience=params.patience, - temperature=params.temperature, - compression_ratio_threshold=params.compression_ratio_threshold, - progress_callback=progress_callback,)["segments"] - elapsed_time = time.time() - start_time + result = self.model.transcribe(audio=audio, + language=params.lang, + verbose=False, + beam_size=params.beam_size, + logprob_threshold=params.log_prob_threshold, + no_speech_threshold=params.no_speech_threshold, + task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe", + fp16=True if params.compute_type == "float16" else False, + best_of=params.best_of, + patience=params.patience, + temperature=params.temperature, + compression_ratio_threshold=params.compression_ratio_threshold, + progress_callback=progress_callback,)["segments"] + segments_result = [] + for segment in result: + segments_result.append(Segment( + start=segment["start"], + end=segment["end"], + text=segment["text"] + )) + elapsed_time = time.time() - start_time return segments_result, elapsed_time def update_model(self, diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py index 6bda8c5..b5ae33a 100644 --- a/modules/whisper/whisper_factory.py +++ b/modules/whisper/whisper_factory.py @@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D from modules.whisper.faster_whisper_inference import FasterWhisperInference from modules.whisper.whisper_Inference import WhisperInference from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference -from modules.whisper.whisper_base import WhisperBase +from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline +from modules.whisper.data_classes import * class WhisperFactory: @@ -19,7 +20,7 @@ class WhisperFactory: diarization_model_dir: str = DIARIZATION_MODELS_DIR, uvr_model_dir: str = UVR_MODELS_DIR, output_dir: str = OUTPUT_DIR, - ) -> "WhisperBase": + ) -> "BaseTranscriptionPipeline": """ Create a whisper inference class based on the provided whisper_type. 
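
The factory now dispatches on the canonical `WhisperImpl` enum values instead of hand-maintained lists of typo aliases. A hedged usage sketch follows; only `whisper_type` is passed here, and the directory arguments fall back to the defaults shown in the diff.

```python
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import WhisperImpl

# WhisperImpl.FASTER_WHISPER.value == "faster-whisper"
pipeline = WhisperFactory.create_whisper_inference(
    whisper_type=WhisperImpl.FASTER_WHISPER.value,
)
print(type(pipeline).__name__)  # FasterWhisperInference
```
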
@@ -45,36 +46,29 @@ class WhisperFactory: Returns ------- - WhisperBase + BaseTranscriptionPipeline An instance of the appropriate whisper inference class based on the whisper_type. """ # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144 os.environ['KMP_DUPLICATE_LIB_OK'] = 'True' - whisper_type = whisper_type.lower().strip() + whisper_type = whisper_type.strip().lower() - faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"] - whisper_typos = ["whisper"] - insanely_fast_whisper_typos = [ - "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper", - "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper" - ] - - if whisper_type in faster_whisper_typos: + if whisper_type == WhisperImpl.FASTER_WHISPER.value: return FasterWhisperInference( model_dir=faster_whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in whisper_typos: + elif whisper_type == WhisperImpl.WHISPER.value: return WhisperInference( model_dir=whisper_model_dir, output_dir=output_dir, diarization_model_dir=diarization_model_dir, uvr_model_dir=uvr_model_dir ) - elif whisper_type in insanely_fast_whisper_typos: + elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value: return InsanelyFastWhisperInference( model_dir=insanely_fast_whisper_model_dir, output_dir=output_dir, diff --git a/modules/whisper/whisper_parameter.py b/modules/whisper/whisper_parameter.py deleted file mode 100644 index 19115fc..0000000 --- a/modules/whisper/whisper_parameter.py +++ /dev/null @@ -1,371 +0,0 @@ -from dataclasses import dataclass, fields -import gradio as gr -from typing import Optional, Dict -import yaml - -from modules.utils.constants import AUTOMATIC_DETECTION - - -@dataclass -class WhisperParameters: - model_size: gr.Dropdown - lang: gr.Dropdown - is_translate: gr.Checkbox - beam_size: gr.Number - log_prob_threshold: gr.Number - no_speech_threshold: gr.Number - compute_type: gr.Dropdown - best_of: gr.Number - patience: gr.Number - condition_on_previous_text: gr.Checkbox - prompt_reset_on_temperature: gr.Slider - initial_prompt: gr.Textbox - temperature: gr.Slider - compression_ratio_threshold: gr.Number - vad_filter: gr.Checkbox - threshold: gr.Slider - min_speech_duration_ms: gr.Number - max_speech_duration_s: gr.Number - min_silence_duration_ms: gr.Number - speech_pad_ms: gr.Number - batch_size: gr.Number - is_diarize: gr.Checkbox - hf_token: gr.Textbox - diarization_device: gr.Dropdown - length_penalty: gr.Number - repetition_penalty: gr.Number - no_repeat_ngram_size: gr.Number - prefix: gr.Textbox - suppress_blank: gr.Checkbox - suppress_tokens: gr.Textbox - max_initial_timestamp: gr.Number - word_timestamps: gr.Checkbox - prepend_punctuations: gr.Textbox - append_punctuations: gr.Textbox - max_new_tokens: gr.Number - chunk_length: gr.Number - hallucination_silence_threshold: gr.Number - hotwords: gr.Textbox - language_detection_threshold: gr.Number - language_detection_segments: gr.Number - is_bgm_separate: gr.Checkbox - uvr_model_size: gr.Dropdown - uvr_device: gr.Dropdown - uvr_segment_size: gr.Number - uvr_save_file: gr.Checkbox - uvr_enable_offload: gr.Checkbox - """ - A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing. - This data class is used to mitigate the key-value problem between Gradio components and function parameters. 
- Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471 - See more about Gradio pre-processing: https://www.gradio.app/docs/components - - Attributes - ---------- - model_size: gr.Dropdown - Whisper model size. - - lang: gr.Dropdown - Source language of the file to transcribe. - - is_translate: gr.Checkbox - Boolean value that determines whether to translate to English. - It's Whisper's feature to translate speech from another language directly into English end-to-end. - - beam_size: gr.Number - Int value that is used for decoding option. - - log_prob_threshold: gr.Number - If the average log probability over sampled tokens is below this value, treat as failed. - - no_speech_threshold: gr.Number - If the no_speech probability is higher than this value AND - the average log probability over sampled tokens is below `log_prob_threshold`, - consider the segment as silent. - - compute_type: gr.Dropdown - compute type for transcription. - see more info : https://opennmt.net/CTranslate2/quantization.html - - best_of: gr.Number - Number of candidates when sampling with non-zero temperature. - - patience: gr.Number - Beam search patience factor. - - condition_on_previous_text: gr.Checkbox - if True, the previous output of the model is provided as a prompt for the next window; - disabling may make the text inconsistent across windows, but the model becomes less prone to - getting stuck in a failure loop, such as repetition looping or timestamps going out of sync. - - initial_prompt: gr.Textbox - Optional text to provide as a prompt for the first window. This can be used to provide, or - "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns - to make it more likely to predict those word correctly. - - temperature: gr.Slider - Temperature for sampling. It can be a tuple of temperatures, - which will be successively used upon failures according to either - `compression_ratio_threshold` or `log_prob_threshold`. - - compression_ratio_threshold: gr.Number - If the gzip compression ratio is above this value, treat as failed - - vad_filter: gr.Checkbox - Enable the voice activity detection (VAD) to filter out parts of the audio - without speech. This step is using the Silero VAD model - https://github.com/snakers4/silero-vad. - - threshold: gr.Slider - This parameter is related with Silero VAD. Speech threshold. - Silero VAD outputs speech probabilities for each audio chunk, - probabilities ABOVE this value are considered as SPEECH. It is better to tune this - parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets. - - min_speech_duration_ms: gr.Number - This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out. - - max_speech_duration_s: gr.Number - This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer - than max_speech_duration_s will be split at the timestamp of the last silence that - lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be - split aggressively just before max_speech_duration_s. - - min_silence_duration_ms: gr.Number - This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms - before separating it - - speech_pad_ms: gr.Number - This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side - - batch_size: gr.Number - This parameter is related with insanely-fast-whisper pipe. 
Batch size to pass to the pipe - - is_diarize: gr.Checkbox - This parameter is related with whisperx. Boolean value that determines whether to diarize or not. - - hf_token: gr.Textbox - This parameter is related with whisperx. Huggingface token is needed to download diarization models. - Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements - - diarization_device: gr.Dropdown - This parameter is related with whisperx. Device to run diarization model - - length_penalty: gr.Number - This parameter is related to faster-whisper. Exponential length penalty constant. - - repetition_penalty: gr.Number - This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens - (set > 1 to penalize). - - no_repeat_ngram_size: gr.Number - This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable). - - prefix: gr.Textbox - This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window. - - suppress_blank: gr.Checkbox - This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling. - - suppress_tokens: gr.Textbox - This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set - of symbols as defined in the model config.json file. - - max_initial_timestamp: gr.Number - This parameter is related to faster-whisper. The initial timestamp cannot be later than this. - - word_timestamps: gr.Checkbox - This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern - and dynamic time warping, and include the timestamps for each word in each segment. - - prepend_punctuations: gr.Textbox - This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols - with the next word. - - append_punctuations: gr.Textbox - This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols - with the previous word. - - max_new_tokens: gr.Number - This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set, - the maximum will be set by the default max_length. - - chunk_length: gr.Number - This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds. - If it is not None, it will overwrite the default chunk_length of the FeatureExtractor. - - hallucination_silence_threshold: gr.Number - This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold - (in seconds) when a possible hallucination is detected. - - hotwords: gr.Textbox - This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None. - - language_detection_threshold: gr.Number - This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected. - - language_detection_segments: gr.Number - This parameter is related to faster-whisper. Number of segments to consider for the language detection. - - is_separate_bgm: gr.Checkbox - This parameter is related to UVR. Boolean value that determines whether to separate bgm or not. - - uvr_model_size: gr.Dropdown - This parameter is related to UVR. UVR model size. - - uvr_device: gr.Dropdown - This parameter is related to UVR. Device to run UVR model. 
- - uvr_segment_size: gr.Number - This parameter is related to UVR. Segment size for UVR model. - - uvr_save_file: gr.Checkbox - This parameter is related to UVR. Boolean value that determines whether to save the file or not. - - uvr_enable_offload: gr.Checkbox - This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not - after each transcription. - """ - - def as_list(self) -> list: - """ - Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing. - See more about Gradio pre-processing: : https://www.gradio.app/docs/components - - Returns - ---------- - A list of Gradio components - """ - return [getattr(self, f.name) for f in fields(self)] - - @staticmethod - def as_value(*args) -> 'WhisperValues': - """ - To use Whisper parameters in function after Gradio post-processing. - See more about Gradio post-processing: : https://www.gradio.app/docs/components - - Returns - ---------- - WhisperValues - Data class that has values of parameters - """ - return WhisperValues(*args) - - -@dataclass -class WhisperValues: - model_size: str = "large-v2" - lang: Optional[str] = None - is_translate: bool = False - beam_size: int = 5 - log_prob_threshold: float = -1.0 - no_speech_threshold: float = 0.6 - compute_type: str = "float16" - best_of: int = 5 - patience: float = 1.0 - condition_on_previous_text: bool = True - prompt_reset_on_temperature: float = 0.5 - initial_prompt: Optional[str] = None - temperature: float = 0.0 - compression_ratio_threshold: float = 2.4 - vad_filter: bool = False - threshold: float = 0.5 - min_speech_duration_ms: int = 250 - max_speech_duration_s: float = float("inf") - min_silence_duration_ms: int = 2000 - speech_pad_ms: int = 400 - batch_size: int = 24 - is_diarize: bool = False - hf_token: str = "" - diarization_device: str = "cuda" - length_penalty: float = 1.0 - repetition_penalty: float = 1.0 - no_repeat_ngram_size: int = 0 - prefix: Optional[str] = None - suppress_blank: bool = True - suppress_tokens: Optional[str] = "[-1]" - max_initial_timestamp: float = 0.0 - word_timestamps: bool = False - prepend_punctuations: Optional[str] = "\"'“¿([{-" - append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、" - max_new_tokens: Optional[int] = None - chunk_length: Optional[int] = 30 - hallucination_silence_threshold: Optional[float] = None - hotwords: Optional[str] = None - language_detection_threshold: Optional[float] = None - language_detection_segments: int = 1 - is_bgm_separate: bool = False - uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4" - uvr_device: str = "cuda" - uvr_segment_size: int = 256 - uvr_save_file: bool = False - uvr_enable_offload: bool = True - """ - A data class to use Whisper parameters. 
- """ - - def to_yaml(self) -> Dict: - data = { - "whisper": { - "model_size": self.model_size, - "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang, - "is_translate": self.is_translate, - "beam_size": self.beam_size, - "log_prob_threshold": self.log_prob_threshold, - "no_speech_threshold": self.no_speech_threshold, - "best_of": self.best_of, - "patience": self.patience, - "condition_on_previous_text": self.condition_on_previous_text, - "prompt_reset_on_temperature": self.prompt_reset_on_temperature, - "initial_prompt": None if not self.initial_prompt else self.initial_prompt, - "temperature": self.temperature, - "compression_ratio_threshold": self.compression_ratio_threshold, - "batch_size": self.batch_size, - "length_penalty": self.length_penalty, - "repetition_penalty": self.repetition_penalty, - "no_repeat_ngram_size": self.no_repeat_ngram_size, - "prefix": None if not self.prefix else self.prefix, - "suppress_blank": self.suppress_blank, - "suppress_tokens": self.suppress_tokens, - "max_initial_timestamp": self.max_initial_timestamp, - "word_timestamps": self.word_timestamps, - "prepend_punctuations": self.prepend_punctuations, - "append_punctuations": self.append_punctuations, - "max_new_tokens": self.max_new_tokens, - "chunk_length": self.chunk_length, - "hallucination_silence_threshold": self.hallucination_silence_threshold, - "hotwords": None if not self.hotwords else self.hotwords, - "language_detection_threshold": self.language_detection_threshold, - "language_detection_segments": self.language_detection_segments, - }, - "vad": { - "vad_filter": self.vad_filter, - "threshold": self.threshold, - "min_speech_duration_ms": self.min_speech_duration_ms, - "max_speech_duration_s": self.max_speech_duration_s, - "min_silence_duration_ms": self.min_silence_duration_ms, - "speech_pad_ms": self.speech_pad_ms, - }, - "diarization": { - "is_diarize": self.is_diarize, - "hf_token": self.hf_token - }, - "bgm_separation": { - "is_separate_bgm": self.is_bgm_separate, - "model_size": self.uvr_model_size, - "segment_size": self.uvr_segment_size, - "save_file": self.uvr_save_file, - "enable_offload": self.uvr_enable_offload - }, - } - return data - - def as_list(self) -> list: - """ - Converts the data class attributes into a list - - Returns - ---------- - A list of Whisper parameters - """ - return [getattr(self, f.name) for f in fields(self)] diff --git a/notebook/whisper-webui.ipynb b/notebook/whisper-webui.ipynb index 558daaf..0eb8eab 100644 --- a/notebook/whisper-webui.ipynb +++ b/notebook/whisper-webui.ipynb @@ -56,7 +56,7 @@ "!pip install faster-whisper==1.0.3\n", "!pip install ctranslate2==4.4.0\n", "!pip install gradio\n", - "!pip install git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error\n", + "!pip install gradio-i18n\n", "# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n", "!pip install git+https://github.com/JuanBindez/pytubefix.git\n", "!pip install tokenizers==0.19.1\n", @@ -99,7 +99,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": { "id": "PQroYRRZzQiN", "cellView": "form" diff --git a/requirements.txt b/requirements.txt index 9aeeb66..4a34878 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,7 +11,7 @@ git+https://github.com/jhj0517/jhj0517-whisper.git faster-whisper==1.0.3 transformers gradio -git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error +gradio-i18n pytubefix ruamel.yaml==0.18.6 pyannote.audio==3.3.1 diff --git a/tests/test_bgm_separation.py 
b/tests/test_bgm_separation.py index cc4a6f8..95b77a0 100644 --- a/tests/test_bgm_separation.py +++ b/tests/test_bgm_separation.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -17,9 +17,9 @@ import os @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, True, False), - ("faster-whisper", False, True, False), - ("insanely_fast_whisper", False, True, False) + (WhisperImpl.WHISPER.value, False, True, False), + (WhisperImpl.FASTER_WHISPER.value, False, True, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False) ] ) def test_bgm_separation_pipeline( @@ -38,9 +38,9 @@ def test_bgm_separation_pipeline( @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", True, True, False), - ("faster-whisper", True, True, False), - ("insanely_fast_whisper", True, True, False) + (WhisperImpl.WHISPER.value, True, True, False), + (WhisperImpl.FASTER_WHISPER.value, True, True, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False) ] ) def test_bgm_separation_with_vad_pipeline( diff --git a/tests/test_config.py b/tests/test_config.py index 0f60aa5..f82e4f1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,10 +1,14 @@ -from modules.utils.paths import * - +import functools +import jiwer import os import torch +from modules.utils.paths import * +from modules.utils.youtube_manager import * + TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav" TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav") +TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country" TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer" TEST_WHISPER_MODEL = "tiny" TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4" @@ -13,5 +17,24 @@ TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt") TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt") +@functools.lru_cache def is_cuda_available(): return torch.cuda.is_available() + + +@functools.lru_cache +def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL): + try: + yt_temp_path = os.path.join("modules", "yt_tmp.wav") + if os.path.exists(yt_temp_path): + return False + yt = get_ytdata(url) + audio = get_ytaudio(yt) + return False + except Exception as e: + print(f"Pytube has detected as a bot: {e}") + return True + + +def calculate_wer(answer, prediction): + return jiwer.wer(answer, prediction) diff --git a/tests/test_diarization.py b/tests/test_diarization.py index 54e7244..f18a263 100644 --- a/tests/test_diarization.py +++ b/tests/test_diarization.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -16,9 +16,9 @@ import os @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, False, True), - ("faster-whisper", False, False, True), - ("insanely_fast_whisper", False, False, True) + (WhisperImpl.WHISPER.value, False, False, 
True), + (WhisperImpl.FASTER_WHISPER.value, False, False, True), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True) ] ) def test_diarization_pipeline( diff --git a/tests/test_transcription.py b/tests/test_transcription.py index 4b5ab98..bc5267c 100644 --- a/tests/test_transcription.py +++ b/tests/test_transcription.py @@ -1,5 +1,6 @@ from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * +from modules.utils.subtitle_manager import read_file from modules.utils.paths import WEBUI_DIR from test_config import * @@ -12,9 +13,9 @@ import os @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", False, False, False), - ("faster-whisper", False, False, False), - ("insanely_fast_whisper", False, False, False) + (WhisperImpl.WHISPER.value, False, False, False), + (WhisperImpl.FASTER_WHISPER.value, False, False, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False) ] ) def test_transcribe( @@ -28,6 +29,10 @@ def test_transcribe( if not os.path.exists(audio_path): download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir) + answer = TEST_ANSWER + if diarization: + answer = "SPEAKER_00|"+TEST_ANSWER + whisper_inferencer = WhisperFactory.create_whisper_inference( whisper_type=whisper_type, ) @@ -37,16 +42,24 @@ def test_transcribe( f"""Diarization Device: {whisper_inferencer.diarizer.device}""" ) - hparams = WhisperValues( - model_size=TEST_WHISPER_MODEL, - vad_filter=vad_filter, - is_bgm_separate=bgm_separation, - compute_type=whisper_inferencer.current_compute_type, - uvr_enable_offload=True, - is_diarize=diarization, - ).as_list() + hparams = TranscriptionPipelineParams( + whisper=WhisperParams( + model_size=TEST_WHISPER_MODEL, + compute_type=whisper_inferencer.current_compute_type + ), + vad=VadParams( + vad_filter=vad_filter + ), + bgm_separation=BGMSeparationParams( + is_separate_bgm=bgm_separation, + enable_offload=True + ), + diarization=DiarizationParams( + is_diarize=diarization + ), + ).to_list() - subtitle_str, file_path = whisper_inferencer.transcribe_file( + subtitle_str, file_paths = whisper_inferencer.transcribe_file( [audio_path], None, "SRT", @@ -54,29 +67,29 @@ def test_transcribe( gr.Progress(), *hparams, ) + subtitle = read_file(file_paths[0]).split("\n") + assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1 - assert isinstance(subtitle_str, str) and subtitle_str - assert isinstance(file_path[0], str) and file_path + if not is_pytube_detected_bot(): + subtitle_str, file_path = whisper_inferencer.transcribe_youtube( + TEST_YOUTUBE_URL, + "SRT", + False, + gr.Progress(), + *hparams, + ) + assert isinstance(subtitle_str, str) and subtitle_str + assert os.path.exists(file_path) - whisper_inferencer.transcribe_youtube( - TEST_YOUTUBE_URL, - "SRT", - False, - gr.Progress(), - *hparams, - ) - assert isinstance(subtitle_str, str) and subtitle_str - assert isinstance(file_path[0], str) and file_path - - whisper_inferencer.transcribe_mic( + subtitle_str, file_path = whisper_inferencer.transcribe_mic( audio_path, "SRT", False, gr.Progress(), *hparams, ) - assert isinstance(subtitle_str, str) and subtitle_str - assert isinstance(file_path[0], str) and file_path + subtitle = read_file(file_path).split("\n") + assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1 def download_file(url, save_dir): diff --git a/tests/test_vad.py b/tests/test_vad.py 
index 124a043..cb3dc05 100644 --- a/tests/test_vad.py +++ b/tests/test_vad.py @@ -1,6 +1,6 @@ from modules.utils.paths import * from modules.whisper.whisper_factory import WhisperFactory -from modules.whisper.whisper_parameter import WhisperValues +from modules.whisper.data_classes import * from test_config import * from test_transcription import download_file, test_transcribe @@ -12,9 +12,9 @@ import os @pytest.mark.parametrize( "whisper_type,vad_filter,bgm_separation,diarization", [ - ("whisper", True, False, False), - ("faster-whisper", True, False, False), - ("insanely_fast_whisper", True, False, False) + (WhisperImpl.WHISPER.value, True, False, False), + (WhisperImpl.FASTER_WHISPER.value, True, False, False), + (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False) ] ) def test_vad_pipeline(