diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 076f0fc..571558c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -37,7 +37,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y git ffmpeg
- name: Install dependencies
- run: pip install -r requirements.txt pytest
+ run: pip install -r requirements.txt pytest jiwer
- name: Run test
run: python -m pytest -rs tests
\ No newline at end of file
diff --git a/README.md b/README.md
index df50889..8d85c13 100644
--- a/README.md
+++ b/README.md
@@ -128,3 +128,6 @@ This is Whisper's original VRAM usage table for models.
- [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
- [ ] Add fast api script
- [ ] Support real-time transcription for microphone
+
+### Translation 🌐
+Any PRs that add Japanese, Spanish, French, German, Chinese, or any other language translations to [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!
diff --git a/app.py b/app.py
index ea094c5..175522b 100644
--- a/app.py
+++ b/app.py
@@ -7,17 +7,14 @@ import yaml
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR, I18N_YAML_PATH)
-from modules.utils.constants import AUTOMATIC_DETECTION
from modules.utils.files_manager import load_yaml
from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.faster_whisper_inference import FasterWhisperInference
-from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.translation.nllb_inference import NLLBInference
from modules.ui.htmls import *
from modules.utils.cli_manager import str2bool
from modules.utils.youtube_manager import get_ytmetas
from modules.translation.deepl_api import DeepLAPI
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
class App:
@@ -44,7 +41,7 @@ class App:
print(f"Use \"{self.args.whisper_type}\" implementation\n"
f"Device \"{self.whisper_inf.device}\" is detected")
- def create_whisper_parameters(self):
+ def create_pipeline_inputs(self):
whisper_params = self.default_params["whisper"]
vad_params = self.default_params["vad"]
diarization_params = self.default_params["diarization"]
@@ -56,7 +53,7 @@ class App:
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
else whisper_params["lang"], label=_("Language"))
- dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
+ dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
with gr.Row():
cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
interactive=True)
@@ -66,158 +63,31 @@ class App:
interactive=True)
with gr.Accordion(_("Advanced Parameters"), open=False):
- nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0,
- interactive=True,
- info="Beam size to use for decoding.")
- nb_log_prob_threshold = gr.Number(label="Log Probability Threshold",
- value=whisper_params["log_prob_threshold"], interactive=True,
- info="If the average log probability over sampled tokens is below this value, treat as failed.")
- nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"],
- interactive=True,
- info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
- dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
- value=self.whisper_inf.current_compute_type, interactive=True,
- allow_custom_value=True,
- info="Select the type of computation to perform.")
- nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
- info="Number of candidates when sampling with non-zero temperature.")
- nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
- info="Beam search patience factor.")
- cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text",
- value=whisper_params["condition_on_previous_text"],
- interactive=True,
- info="Condition on previous text during decoding.")
- sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature",
- value=whisper_params["prompt_reset_on_temperature"],
- minimum=0, maximum=1, step=0.01, interactive=True,
- info="Resets prompt if temperature is above this value."
- " Arg has effect only if 'Condition On Previous Text' is True.")
- tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
- info="Initial prompt to use for decoding.")
- sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
- step=0.01, maximum=1.0, interactive=True,
- info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
- nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold",
- value=whisper_params["compression_ratio_threshold"],
- interactive=True,
- info="If the gzip compression ratio is above this value, treat as failed.")
- nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
- precision=0,
- info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
- with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
- nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
- info="Exponential length penalty constant.")
- nb_repetition_penalty = gr.Number(label="Repetition Penalty",
- value=whisper_params["repetition_penalty"],
- info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
- nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size",
- value=whisper_params["no_repeat_ngram_size"],
- precision=0,
- info="Prevent repetitions of n-grams with this size (set 0 to disable).")
- tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
- info="Optional text to provide as a prefix for the first window.")
- cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
- info="Suppress blank outputs at the beginning of the sampling.")
- tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
- info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
- nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp",
- value=whisper_params["max_initial_timestamp"],
- info="The initial timestamp cannot be later than this.")
- cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
- info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
- tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations",
- value=whisper_params["prepend_punctuations"],
- info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
- tb_append_punctuations = gr.Textbox(label="Append Punctuations",
- value=whisper_params["append_punctuations"],
- info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
- nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
- precision=0,
- info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
- nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
- value=lambda: whisper_params[
- "hallucination_silence_threshold"],
- info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
- tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
- info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
- nb_language_detection_threshold = gr.Number(label="Language Detection Threshold",
- value=lambda: whisper_params[
- "language_detection_threshold"],
- info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
- nb_language_detection_segments = gr.Number(label="Language Detection Segments",
- value=lambda: whisper_params["language_detection_segments"],
- precision=0,
- info="Number of segments to consider for the language detection.")
- with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
- nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
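+            # Build the advanced parameter components from the WhisperParams data class defaults instead of declaring each gradio component inline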
+ whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
+ whisper_type=self.args.whisper_type,
+ available_compute_types=self.whisper_inf.available_compute_types,
+ compute_type=self.whisper_inf.current_compute_type)
with gr.Accordion(_("Background Music Remover Filter"), open=False):
- cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"),
- value=uvr_params["is_separate_bgm"],
- interactive=True,
- info=_("Enabling this will remove background music"))
- dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
- choices=self.whisper_inf.music_separator.available_devices)
- dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"],
- choices=self.whisper_inf.music_separator.available_models)
- nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
- cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"])
- cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"),
- value=uvr_params["enable_offload"])
+ uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
+ available_models=self.whisper_inf.music_separator.available_models,
+ available_devices=self.whisper_inf.music_separator.available_devices,
+ device=self.whisper_inf.music_separator.device)
with gr.Accordion(_("Voice Detection Filter"), open=False):
- cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"],
- interactive=True,
- info=_("Enable this to transcribe only detected voice"))
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
- value=vad_params["threshold"],
- info="Lower it to be more sensitive to small sounds.")
- nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
- value=vad_params["min_speech_duration_ms"],
- info="Final speech chunks shorter than this time are thrown out")
- nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
- value=vad_params["max_speech_duration_s"],
- info="Maximum duration of speech chunks in \"seconds\".")
- nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
- value=vad_params["min_silence_duration_ms"],
- info="In the end of each speech chunk wait for this time"
- " before separating it")
- nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
- info="Final speech chunks are padded by this time each side")
+ vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)
with gr.Accordion(_("Diarization"), open=False):
- cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"])
- tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"],
- info=_("This is only needed the first time you download the model"))
- dd_diarization_device = gr.Dropdown(label=_("Device"),
- choices=self.whisper_inf.diarizer.get_available_device(),
- value=self.whisper_inf.diarizer.get_device())
+ diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
+ available_devices=self.whisper_inf.diarizer.available_device,
+ device=self.whisper_inf.diarizer.device)
dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
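+            # Flatten all gradio components into one list; TranscriptionPipelineParams.from_list() later rebuilds the parameter objects from these values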
+ pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
+
return (
- WhisperParameters(
- model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
- log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
- compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
- condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
- temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
- vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
- max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
- speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
- is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
- length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
- no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
- suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
- word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
- append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
- hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
- language_detection_threshold=nb_language_detection_threshold,
- language_detection_segments=nb_language_detection_segments,
- prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
- uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
- uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
- ),
+ pipeline_inputs,
dd_file_format,
cb_timestamp
)
@@ -243,7 +113,7 @@ class App:
visible=self.args.colab,
value="")
- whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
with gr.Row():
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -254,7 +124,7 @@ class App:
params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
btn_run.click(fn=self.whisper_inf.transcribe_file,
- inputs=params + whisper_params.as_list(),
+ inputs=params + pipeline_params,
outputs=[tb_indicator, files_subtitles])
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
@@ -268,7 +138,7 @@ class App:
tb_title = gr.Label(label=_("Youtube Title"))
tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
- whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
with gr.Row():
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -280,7 +150,7 @@ class App:
params = [tb_youtubelink, dd_file_format, cb_timestamp]
btn_run.click(fn=self.whisper_inf.transcribe_youtube,
- inputs=params + whisper_params.as_list(),
+ inputs=params + pipeline_params,
outputs=[tb_indicator, files_subtitles])
tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
outputs=[img_thumbnail, tb_title, tb_description])
@@ -290,7 +160,7 @@ class App:
with gr.Row():
mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)
- whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+ pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
with gr.Row():
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -302,7 +172,7 @@ class App:
params = [mic_input, dd_file_format, cb_timestamp]
btn_run.click(fn=self.whisper_inf.transcribe_mic,
- inputs=params + whisper_params.as_list(),
+ inputs=params + pipeline_params,
outputs=[tb_indicator, files_subtitles])
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
@@ -417,7 +287,6 @@ class App:
# Launch the app with optional gradio settings
args = self.args
-
self.app.queue(
api_open=args.api_open
).launch(
@@ -447,8 +316,8 @@ class App:
parser = argparse.ArgumentParser()
-parser.add_argument('--whisper_type', type=str, default="faster-whisper",
- choices=["whisper", "faster-whisper", "insanely-fast-whisper"],
+parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
+ choices=[item.value for item in WhisperImpl],
help='A type of the whisper implementation (Github repo name)')
parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
diff --git a/configs/default_parameters.yaml b/configs/default_parameters.yaml
index 8339eda..483d17a 100644
--- a/configs/default_parameters.yaml
+++ b/configs/default_parameters.yaml
@@ -1,5 +1,6 @@
whisper:
model_size: "medium.en"
+ file_format: "SRT"
lang: "english"
is_translate: false
beam_size: 5
diff --git a/configs/translation.yaml b/configs/translation.yaml
index 1386867..68e35e3 100644
--- a/configs/translation.yaml
+++ b/configs/translation.yaml
@@ -411,3 +411,49 @@ ru: # Russian
Instrumental: Инструментал
Vocals: Вокал
SEPARATE BACKGROUND MUSIC: РАЗДЕЛИТЬ ФОНОВУЮ МУЗЫКУ
+
+tr: # Turkish
+ Language: Dil
+ File: Dosya
+ Youtube: Youtube
+ Mic: Mikrofon
+ T2T Translation: T2T Çeviri
+ BGM Separation: Arka Plan Müziği Ayırma
+ GENERATE SUBTITLE FILE: ALTYAZI DOSYASI OLUŞTUR
+ Output: Çıktı
+ Downloadable output file: İndirilebilir çıktı dosyası
+ Upload File here: Dosya Yükle
+ Model: Model
+ Automatic Detection: Otomatik Algılama
+ File Format: Dosya Formatı
+ Translate to English?: İngilizceye Çevir?
+ Add a timestamp to the end of the filename: Dosya adının sonuna zaman damgası ekle
+ Advanced Parameters: Gelişmiş Parametreler
+ Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresi
+ Enabling this will remove background music: Bunu etkinleştirmek, arka plan müziğini alt model tarafından transkripsiyondan önce kaldıracaktır
+ Enable Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresini Etkinleştir
+ Save separated files to output: Ayrılmış dosyaları çıktıya kaydet
+ Offload sub model after removing background music: Arka plan müziği kaldırıldıktan sonra alt modeli devre dışı bırak
+ Voice Detection Filter: Ses Algılama Filtresi
+ Enable this to transcribe only detected voice: Bunu etkinleştirerek yalnızca alt model tarafından algılanan ses kısımlarını transkribe et
+ Enable Silero VAD Filter: Silero VAD Filtresini Etkinleştir
+ Diarization: Konuşmacı Ayrımı
+ Enable Diarization: Konuşmacı Ayrımını Etkinleştir
+ HuggingFace Token: HuggingFace Anahtarı
+ This is only needed the first time you download the model: Bu, modeli ilk kez indirirken gereklidir. Zaten modelleriniz varsa girmenize gerek yok. Modeli indirmek için "https://huggingface.co/pyannote/speaker-diarization-3.1" ve "https://huggingface.co/pyannote/segmentation-3.0" adreslerine gidip gereksinimlerini kabul etmeniz gerekiyor
+ Device: Cihaz
+ Youtube Link: Youtube Bağlantısı
+ Youtube Thumbnail: Youtube Küçük Resmi
+ Youtube Title: Youtube Başlığı
+ Youtube Description: Youtube Açıklaması
+ Record with Mic: Mikrofonla Kaydet
+ Upload Subtitle Files to translate here: Çeviri için altyazı dosyalarını buraya yükle
+ Your Auth Key (API KEY): Yetki Anahtarınız (API ANAHTARI)
+ Source Language: Kaynak Dil
+ Target Language: Hedef Dil
+ Pro User?: Pro Kullanıcı?
+ TRANSLATE SUBTITLE FILE: ALTYAZI DOSYASINI ÇEVİR
+ Upload Audio Files to separate background music: Arka plan müziğini ayırmak için ses dosyalarını yükle
+ Instrumental: Enstrümantal
+ Vocals: Vokal
+ SEPARATE BACKGROUND MUSIC: ARKA PLAN MÜZİĞİNİ AYIR
diff --git a/modules/diarize/diarize_pipeline.py b/modules/diarize/diarize_pipeline.py
index b4109e8..4313f5c 100644
--- a/modules/diarize/diarize_pipeline.py
+++ b/modules/diarize/diarize_pipeline.py
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
+from modules.whisper.data_classes import *
from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
@@ -43,6 +44,8 @@ class DiarizationPipeline:
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
transcript_segments = transcript_result["segments"]
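+    # Normalize pydantic Segment objects into plain dicts so the speaker assignment below can keep using dict access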
+ if transcript_segments and isinstance(transcript_segments[0], Segment):
+ transcript_segments = [seg.model_dump() for seg in transcript_segments]
for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
@@ -63,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
seg["speaker"] = speaker
# assign speaker to words
- if 'words' in seg:
+ if 'words' in seg and seg['words'] is not None:
for word in seg['words']:
if 'start' in word:
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
@@ -85,10 +88,10 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
if word_speaker is not None:
word["speaker"] = word_speaker
- return transcript_result
+ return {"segments": transcript_segments}
-class Segment:
+class DiarizationSegment:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
diff --git a/modules/diarize/diarizer.py b/modules/diarize/diarizer.py
index 2dd3f94..38e150a 100644
--- a/modules/diarize/diarizer.py
+++ b/modules/diarize/diarizer.py
@@ -1,6 +1,6 @@
import os
import torch
-from typing import List, Union, BinaryIO, Optional
+from typing import List, Union, BinaryIO, Optional, Tuple
import numpy as np
import time
import logging
@@ -8,6 +8,7 @@ import logging
from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *
class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:
def run(self,
audio: Union[str, BinaryIO, np.ndarray],
- transcribed_result: List[dict],
+ transcribed_result: List[Segment],
use_auth_token: str,
device: Optional[str] = None
- ):
+ ) -> Tuple[List[Segment], float]:
"""
Diarize transcribed result as a post-processing
@@ -34,7 +35,7 @@ class Diarizer:
----------
audio: Union[str, BinaryIO, np.ndarray]
Audio input. This can be file path or binary type.
- transcribed_result: List[dict]
+ transcribed_result: List[Segment]
transcribed result through whisper.
use_auth_token: str
Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer:
Returns
----------
- segments_result: List[dict]
- list of dicts that includes start, end timestamps and transcribed text
+ segments_result: List[Segment]
+            list of Segment objects that include start and end timestamps and the transcribed text
elapsed_time: float
elapsed time for running
"""
@@ -68,14 +69,20 @@ class Diarizer:
{"segments": transcribed_result}
)
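+        # Rebuild Segment objects from the diarized result, prefixing each text with its speaker label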
+ segments_result = []
for segment in diarized_result["segments"]:
speaker = "None"
if "speaker" in segment:
speaker = segment["speaker"]
- segment["text"] = speaker + "|" + segment["text"].strip()
+ diarized_text = speaker + "|" + segment["text"].strip()
+ segments_result.append(Segment(
+ start=segment["start"],
+ end=segment["end"],
+ text=diarized_text
+ ))
elapsed_time = time.time() - start_time
- return diarized_result["segments"], elapsed_time
+ return segments_result, elapsed_time
def update_pipe(self,
use_auth_token: str,
diff --git a/modules/translation/deepl_api.py b/modules/translation/deepl_api.py
index 35f245b..0814fb7 100644
--- a/modules/translation/deepl_api.py
+++ b/modules/translation/deepl_api.py
@@ -139,37 +139,27 @@ class DeepLAPI:
)
files_info = {}
- for fileobj in fileobjs:
- file_path = fileobj
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-
- if file_ext == ".srt":
- parsed_dicts = parse_srt(file_path=file_path)
-
- elif file_ext == ".vtt":
- parsed_dicts = parse_vtt(file_path=file_path)
+ for file_path in fileobjs:
+ file_name, file_ext = os.path.splitext(os.path.basename(file_path))
+ writer = get_writer(file_ext, self.output_dir)
+ segments = writer.to_segments(file_path)
batch_size = self.max_text_batch_size
- for batch_start in range(0, len(parsed_dicts), batch_size):
- batch_end = min(batch_start + batch_size, len(parsed_dicts))
- sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
+ for batch_start in range(0, len(segments), batch_size):
+ progress(batch_start / len(segments), desc="Translating..")
+ sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
target_lang, is_pro)
for i, translated_text in enumerate(translated_texts):
- parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
- progress(batch_end / len(parsed_dicts), desc="Translating..")
+ segments[batch_start + i].text = translated_text["text"]
- if file_ext == ".srt":
- subtitle = get_serialized_srt(parsed_dicts)
- elif file_ext == ".vtt":
- subtitle = get_serialized_vtt(parsed_dicts)
-
- if add_timestamp:
- timestamp = datetime.now().strftime("%m%d%H%M%S")
- file_name += f"-{timestamp}"
-
- output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
- write_file(subtitle, output_path)
+ subtitle, output_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_ext,
+ result=segments,
+ add_timestamp=add_timestamp
+ )
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
diff --git a/modules/translation/translation_base.py b/modules/translation/translation_base.py
index abc7f44..6087767 100644
--- a/modules/translation/translation_base.py
+++ b/modules/translation/translation_base.py
@@ -6,7 +6,7 @@ from typing import List
from datetime import datetime
import modules.translation.nllb_inference as nllb
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import *
from modules.utils.files_manager import load_yaml, save_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
@@ -95,32 +95,22 @@ class TranslationBase(ABC):
files_info = {}
for fileobj in fileobjs:
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
- if file_ext == ".srt":
- parsed_dicts = parse_srt(file_path=fileobj)
- total_progress = len(parsed_dicts)
- for index, dic in enumerate(parsed_dicts):
- progress(index / total_progress, desc="Translating..")
- translated_text = self.translate(dic["sentence"], max_length=max_length)
- dic["sentence"] = translated_text
- subtitle = get_serialized_srt(parsed_dicts)
+ writer = get_writer(file_ext, self.output_dir)
+ segments = writer.to_segments(fileobj)
+ for i, segment in enumerate(segments):
+ progress(i / len(segments), desc="Translating..")
+ translated_text = self.translate(segment.text, max_length=max_length)
+ segment.text = translated_text
- elif file_ext == ".vtt":
- parsed_dicts = parse_vtt(file_path=fileobj)
- total_progress = len(parsed_dicts)
- for index, dic in enumerate(parsed_dicts):
- progress(index / total_progress, desc="Translating..")
- translated_text = self.translate(dic["sentence"], max_length=max_length)
- dic["sentence"] = translated_text
- subtitle = get_serialized_vtt(parsed_dicts)
+ subtitle, file_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_ext,
+ result=segments,
+ add_timestamp=add_timestamp
+ )
- if add_timestamp:
- timestamp = datetime.now().strftime("%m%d%H%M%S")
- file_name += f"-{timestamp}"
-
- output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
- write_file(subtitle, output_path)
-
- files_info[file_name] = {"subtitle": subtitle, "path": output_path}
+ files_info[file_name] = {"subtitle": subtitle, "path": file_path}
total_result = ''
for file_name, info in files_info.items():
@@ -133,7 +123,8 @@ class TranslationBase(ABC):
return [gr_str, output_file_paths]
except Exception as e:
- print(f"Error: {str(e)}")
+ print(f"Error translating file: {e}")
+ raise
finally:
self.release_cuda_memory()
diff --git a/modules/utils/constants.py b/modules/utils/constants.py
index e9309bc..49b45c8 100644
--- a/modules/utils/constants.py
+++ b/modules/utils/constants.py
@@ -1,3 +1,6 @@
from gradio_i18n import Translate, gettext as _
AUTOMATIC_DETECTION = _("Automatic Detection")
+GRADIO_NONE_STR = ""
+GRADIO_NONE_NUMBER_MAX = 9999
+GRADIO_NONE_NUMBER_MIN = 0
diff --git a/modules/utils/files_manager.py b/modules/utils/files_manager.py
index 4ac0c63..29b5242 100644
--- a/modules/utils/files_manager.py
+++ b/modules/utils/files_manager.py
@@ -67,3 +67,9 @@ def is_video(file_path):
video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp']
extension = os.path.splitext(file_path)[1].lower()
return extension in video_extensions
+
+
+def read_file(file_path):
+ with open(file_path, "r", encoding="utf-8") as f:
+ subtitle_content = f.read()
+ return subtitle_content
diff --git a/modules/utils/subtitle_manager.py b/modules/utils/subtitle_manager.py
index 4b48425..1a6ad12 100644
--- a/modules/utils/subtitle_manager.py
+++ b/modules/utils/subtitle_manager.py
@@ -1,121 +1,425 @@
+# Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py
+
+import json
+import os
import re
+import sys
+import zlib
+from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
+from datetime import datetime
+
+from modules.whisper.data_classes import Segment, Word
+from .files_manager import read_file
-def timeformat_srt(time):
- hours = time // 3600
- minutes = (time - hours * 3600) // 60
- seconds = time - hours * 3600 - minutes * 60
- milliseconds = (time - int(time)) * 1000
- return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
+def format_timestamp(
+ seconds: float, always_include_hours: bool = True, decimal_marker: str = ","
+) -> str:
+ assert seconds >= 0, "non-negative timestamp expected"
+ milliseconds = round(seconds * 1000.0)
+
+ hours = milliseconds // 3_600_000
+ milliseconds -= hours * 3_600_000
+
+ minutes = milliseconds // 60_000
+ milliseconds -= minutes * 60_000
+
+ seconds = milliseconds // 1_000
+ milliseconds -= seconds * 1_000
+
+ hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
+ return (
+ f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
+ )
-def timeformat_vtt(time):
- hours = time // 3600
- minutes = (time - hours * 3600) // 60
- seconds = time - hours * 3600 - minutes * 60
- milliseconds = (time - int(time)) * 1000
- return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
+def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float:
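+    """Convert a timestamp string such as "01:02:03,450" or "02:03.450" back into seconds."""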
+ times = time_str.split(":")
+
+ if len(times) == 3:
+ hours, minutes, rest = times
+ hours = int(hours)
+ else:
+ hours = 0
+ minutes, rest = times
+
+ seconds, fractional = rest.split(decimal_marker)
+
+ minutes = int(minutes)
+ seconds = int(seconds)
+ fractional_seconds = float("0." + fractional)
+
+ return hours * 3600 + minutes * 60 + seconds + fractional_seconds
-def write_file(subtitle, output_file):
- with open(output_file, 'w', encoding='utf-8') as f:
- f.write(subtitle)
+def get_start(segments: List[dict]) -> Optional[float]:
+ return next(
+ (w["start"] for s in segments for w in s["words"]),
+ segments[0]["start"] if segments else None,
+ )
-def get_srt(segments):
- output = ""
- for i, segment in enumerate(segments):
- output += f"{i + 1}\n"
- output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
- if segment['text'].startswith(' '):
- segment['text'] = segment['text'][1:]
- output += f"{segment['text']}\n\n"
- return output
+def get_end(segments: List[dict]) -> Optional[float]:
+ return next(
+ (w["end"] for s in reversed(segments) for w in reversed(s["words"])),
+ segments[-1]["end"] if segments else None,
+ )
-def get_vtt(segments):
- output = "WebVTT\n\n"
- for i, segment in enumerate(segments):
- output += f"{i + 1}\n"
- output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
- if segment['text'].startswith(' '):
- segment['text'] = segment['text'][1:]
- output += f"{segment['text']}\n\n"
- return output
+class ResultWriter:
+ extension: str
+
+ def __init__(self, output_dir: str):
+ self.output_dir = output_dir
+
+ def __call__(
+ self, result: Union[dict, List[Segment]], output_file_name: str,
+ options: Optional[dict] = None, **kwargs
+ ):
+ if isinstance(result, List) and result and isinstance(result[0], Segment):
+ result = {"segments": [seg.model_dump() for seg in result]}
+
+ output_path = os.path.join(
+ self.output_dir, output_file_name + "." + self.extension
+ )
+
+ with open(output_path, "w", encoding="utf-8") as f:
+ self.write_result(result, file=f, options=options, **kwargs)
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ raise NotImplementedError
-def get_txt(segments):
- output = ""
- for i, segment in enumerate(segments):
- if segment['text'].startswith(' '):
- segment['text'] = segment['text'][1:]
- output += f"{segment['text']}\n"
- return output
+class WriteTXT(ResultWriter):
+ extension: str = "txt"
+
+ def write_result(
+ self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for segment in result["segments"]:
+ print(segment["text"].strip(), file=file, flush=True)
-def parse_srt(file_path):
- """Reads SRT file and returns as dict"""
- with open(file_path, 'r', encoding='utf-8') as file:
- srt_data = file.read()
+class SubtitlesWriter(ResultWriter):
+ always_include_hours: bool
+ decimal_marker: str
- data = []
- blocks = srt_data.split('\n\n')
+ def iterate_result(
+ self,
+ result: dict,
+ options: Optional[dict] = None,
+ *,
+ max_line_width: Optional[int] = None,
+ max_line_count: Optional[int] = None,
+ highlight_words: bool = False,
+ align_lrc_words: bool = False,
+ max_words_per_line: Optional[int] = None,
+ ):
+ options = options or {}
+ max_line_width = max_line_width or options.get("max_line_width")
+ max_line_count = max_line_count or options.get("max_line_count")
+ highlight_words = highlight_words or options.get("highlight_words", False)
+ align_lrc_words = align_lrc_words or options.get("align_lrc_words", False)
+ max_words_per_line = max_words_per_line or options.get("max_words_per_line")
+ preserve_segments = max_line_count is None or max_line_width is None
+ max_line_width = max_line_width or 1000
+ max_words_per_line = max_words_per_line or 1000
- for block in blocks:
- if block.strip() != '':
- lines = block.strip().split('\n')
- index = lines[0]
- timestamp = lines[1]
- sentence = ' '.join(lines[2:])
+ def iterate_subtitles():
+ line_len = 0
+ line_count = 1
+ # the next subtitle to yield (a list of word timings with whitespace)
+ subtitle: List[dict] = []
+ last: float = get_start(result["segments"]) or 0.0
+ for segment in result["segments"]:
+ chunk_index = 0
+ words_count = max_words_per_line
+ while chunk_index < len(segment["words"]):
+ remaining_words = len(segment["words"]) - chunk_index
+ if max_words_per_line > len(segment["words"]) - chunk_index:
+ words_count = remaining_words
+ for i, original_timing in enumerate(
+ segment["words"][chunk_index : chunk_index + words_count]
+ ):
+ timing = original_timing.copy()
+ long_pause = (
+ not preserve_segments and timing["start"] - last > 3.0
+ )
+ has_room = line_len + len(timing["word"]) <= max_line_width
+ seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
+ if (
+ line_len > 0
+ and has_room
+ and not long_pause
+ and not seg_break
+ ):
+ # line continuation
+ line_len += len(timing["word"])
+ else:
+ # new line
+ timing["word"] = timing["word"].strip()
+ if (
+ len(subtitle) > 0
+ and max_line_count is not None
+ and (long_pause or line_count >= max_line_count)
+ or seg_break
+ ):
+ # subtitle break
+ yield subtitle
+ subtitle = []
+ line_count = 1
+ elif line_len > 0:
+ # line break
+ line_count += 1
+ timing["word"] = "\n" + timing["word"]
+ line_len = len(timing["word"].strip())
+ subtitle.append(timing)
+ last = timing["start"]
+ chunk_index += max_words_per_line
+ if len(subtitle) > 0:
+ yield subtitle
- data.append({
- "index": index,
- "timestamp": timestamp,
- "sentence": sentence
- })
- return data
+ if len(result["segments"]) > 0 and "words" in result["segments"][0] and result["segments"][0]["words"]:
+ for subtitle in iterate_subtitles():
+ subtitle_start = self.format_timestamp(subtitle[0]["start"])
+ subtitle_end = self.format_timestamp(subtitle[-1]["end"])
+ subtitle_text = "".join([word["word"] for word in subtitle])
+ if highlight_words:
+ last = subtitle_start
+ all_words = [timing["word"] for timing in subtitle]
+ for i, this_word in enumerate(subtitle):
+ start = self.format_timestamp(this_word["start"])
+ end = self.format_timestamp(this_word["end"])
+ if last != start:
+ yield last, start, subtitle_text
+
+ yield start, end, "".join(
+ [
+                            re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
+ if j == i
+ else word
+ for j, word in enumerate(all_words)
+ ]
+ )
+ last = end
+
+ if align_lrc_words:
+ lrc_aligned_words = [f"[{self.format_timestamp(sub['start'])}]{sub['word']}" for sub in subtitle]
+ l_start, l_end = self.format_timestamp(subtitle[-1]['start']), self.format_timestamp(subtitle[-1]['end'])
+ lrc_aligned_words[-1] = f"[{l_start}]{subtitle[-1]['word']}[{l_end}]"
+ lrc_aligned_words = ' '.join(lrc_aligned_words)
+ yield None, None, lrc_aligned_words
+
+ else:
+ yield subtitle_start, subtitle_end, subtitle_text
+ else:
+ for segment in result["segments"]:
+ segment_start = self.format_timestamp(segment["start"])
+ segment_end = self.format_timestamp(segment["end"])
+ segment_text = segment["text"].strip().replace("-->", "->")
+ yield segment_start, segment_end, segment_text
+
+ def format_timestamp(self, seconds: float):
+ return format_timestamp(
+ seconds=seconds,
+ always_include_hours=self.always_include_hours,
+ decimal_marker=self.decimal_marker,
+ )
-def parse_vtt(file_path):
- """Reads WebVTT file and returns as dict"""
- with open(file_path, 'r', encoding='utf-8') as file:
- webvtt_data = file.read()
+class WriteVTT(SubtitlesWriter):
+ extension: str = "vtt"
+ always_include_hours: bool = False
+ decimal_marker: str = "."
- data = []
- blocks = webvtt_data.split('\n\n')
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ print("WEBVTT\n", file=file)
+ for start, end, text in self.iterate_result(result, options, **kwargs):
+ print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
- for block in blocks:
- if block.strip() != '' and not block.strip().startswith("WebVTT"):
- lines = block.strip().split('\n')
- index = lines[0]
- timestamp = lines[1]
- sentence = ' '.join(lines[2:])
+ def to_segments(self, file_path: str) -> List[Segment]:
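+        """Parse a WebVTT subtitle file back into a list of Segment objects."""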
+ segments = []
- data.append({
- "index": index,
- "timestamp": timestamp,
- "sentence": sentence
- })
+ blocks = read_file(file_path).split('\n\n')
- return data
+ for block in blocks:
+ if block.strip() != '' and not block.strip().startswith("WEBVTT"):
+ lines = block.strip().split('\n')
+ time_line = lines[0].split(" --> ")
+ start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
+ sentence = ' '.join(lines[1:])
+
+ segments.append(Segment(
+ start=start,
+ end=end,
+ text=sentence
+ ))
+
+ return segments
-def get_serialized_srt(dicts):
- output = ""
- for dic in dicts:
- output += f'{dic["index"]}\n'
- output += f'{dic["timestamp"]}\n'
- output += f'{dic["sentence"]}\n\n'
- return output
+class WriteSRT(SubtitlesWriter):
+ extension: str = "srt"
+ always_include_hours: bool = True
+ decimal_marker: str = ","
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for i, (start, end, text) in enumerate(
+ self.iterate_result(result, options, **kwargs), start=1
+ ):
+ print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+ def to_segments(self, file_path: str) -> List[Segment]:
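+        """Parse an SRT subtitle file back into a list of Segment objects."""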
+ segments = []
+
+ blocks = read_file(file_path).split('\n\n')
+
+ for block in blocks:
+ if block.strip() != '':
+ lines = block.strip().split('\n')
+ index = lines[0]
+ time_line = lines[1].split(" --> ")
+ start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
+ sentence = ' '.join(lines[2:])
+
+ segments.append(Segment(
+ start=start,
+ end=end,
+ text=sentence
+ ))
+
+ return segments
-def get_serialized_vtt(dicts):
- output = "WebVTT\n\n"
- for dic in dicts:
- output += f'{dic["index"]}\n'
- output += f'{dic["timestamp"]}\n'
- output += f'{dic["sentence"]}\n\n'
- return output
+class WriteLRC(SubtitlesWriter):
+ extension: str = "lrc"
+ always_include_hours: bool = False
+ decimal_marker: str = "."
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for i, (start, end, text) in enumerate(
+ self.iterate_result(result, options, **kwargs), start=1
+ ):
+ if "align_lrc_words" in kwargs and kwargs["align_lrc_words"]:
+ print(f"{text}\n", file=file, flush=True)
+ else:
+ print(f"[{start}]{text}[{end}]\n", file=file, flush=True)
+
+ def to_segments(self, file_path: str) -> List[Segment]:
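+        """Parse an LRC file back into a list of Segment objects."""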
+ segments = []
+
+ blocks = read_file(file_path).split('\n')
+
+ for block in blocks:
+ if block.strip() != '':
+ lines = block.strip()
+ pattern = r'(\[.*?\])'
+ parts = re.split(pattern, lines)
+ parts = [part.strip() for part in parts if part]
+
+ for i, part in enumerate(parts):
+                if i % 2 == 1:
+                    start_str, text, end_str = parts[i-1], parts[i], parts[i+1]
+ start_str, end_str = start_str.replace("[", "").replace("]", ""), end_str.replace("[", "").replace("]", "")
+ start, end = time_str_to_seconds(start_str, self.decimal_marker), time_str_to_seconds(end_str, self.decimal_marker)
+
+ segments.append(Segment(
+ start=start,
+ end=end,
+ text=text,
+ ))
+
+ return segments
+
+
+class WriteTSV(ResultWriter):
+ """
+ Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+ Using integer milliseconds as start and end times means there's no chance of interference from
+ an environment setting a language encoding that causes the decimal in a floating point number
+ to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+ """
+
+ extension: str = "tsv"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ print("start", "end", "text", sep="\t", file=file)
+ for segment in result["segments"]:
+ print(round(1000 * segment["start"]), file=file, end="\t")
+ print(round(1000 * segment["end"]), file=file, end="\t")
+ print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+ extension: str = "json"
+
+ def write_result(
+ self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ json.dump(result, file)
+
+
+def get_writer(
+ output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
+ output_format = output_format.strip().lower().replace(".", "")
+
+ writers = {
+ "txt": WriteTXT,
+ "vtt": WriteVTT,
+ "srt": WriteSRT,
+ "tsv": WriteTSV,
+ "json": WriteJSON,
+ "lrc": WriteLRC
+ }
+
+ if output_format == "all":
+ all_writers = [writer(output_dir) for writer in writers.values()]
+
+ def write_all(
+ result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+ ):
+ for writer in all_writers:
+ writer(result, file, options, **kwargs)
+
+ return write_all
+
+ return writers[output_format](output_dir)
+
+
+def generate_file(
+ output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str,
+ add_timestamp: bool = True, **kwargs
+) -> Tuple[str, str]:
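+    """Write the result to a subtitle file in the given format and return the file content and its path."""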
+ output_format = output_format.strip().lower().replace(".", "")
+ output_format = "vtt" if output_format == "webvtt" else output_format
+
+ if add_timestamp:
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
+ output_file_name += f"-{timestamp}"
+
+ file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
+ file_writer = get_writer(output_format=output_format, output_dir=output_dir)
+
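+    # For LRC, replace word highlighting with per-word timestamp alignment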
+ if isinstance(file_writer, WriteLRC) and kwargs.get("highlight_words", False):
+ kwargs["highlight_words"], kwargs["align_lrc_words"] = False, True
+
+ file_writer(result=result, output_file_name=output_file_name, **kwargs)
+ content = read_file(file_path)
+ return content, file_path
def safe_filename(name):
diff --git a/modules/vad/silero_vad.py b/modules/vad/silero_vad.py
index bb5c919..cb6da93 100644
--- a/modules/vad/silero_vad.py
+++ b/modules/vad/silero_vad.py
@@ -5,7 +5,8 @@ import numpy as np
from typing import BinaryIO, Union, List, Optional, Tuple
import warnings
import faster_whisper
-from faster_whisper.transcribe import SpeechTimestampsMap, Segment
+from modules.whisper.data_classes import *
+from faster_whisper.transcribe import SpeechTimestampsMap
import gradio as gr
@@ -247,18 +248,18 @@ class SileroVAD:
def restore_speech_timestamps(
self,
- segments: List[dict],
+ segments: List[Segment],
speech_chunks: List[dict],
sampling_rate: Optional[int] = None,
- ) -> List[dict]:
+ ) -> List[Segment]:
if sampling_rate is None:
sampling_rate = self.sampling_rate
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
for segment in segments:
- segment["start"] = ts_map.get_original_time(segment["start"])
- segment["end"] = ts_map.get_original_time(segment["end"])
+ segment.start = ts_map.get_original_time(segment.start)
+ segment.end = ts_map.get_original_time(segment.end)
return segments
diff --git a/modules/whisper/whisper_base.py b/modules/whisper/base_transcription_pipeline.py
similarity index 66%
rename from modules/whisper/whisper_base.py
rename to modules/whisper/base_transcription_pipeline.py
index 51c87dd..2791dc6 100644
--- a/modules/whisper/whisper_base.py
+++ b/modules/whisper/base_transcription_pipeline.py
@@ -1,5 +1,4 @@
import os
-import torch
import whisper
import ctranslate2
import gradio as gr
@@ -9,21 +8,20 @@ from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
-from dataclasses import astuple
from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR)
-from modules.utils.constants import AUTOMATIC_DETECTION
-from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
+from modules.utils.constants import *
+from modules.utils.subtitle_manager import *
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
-from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
-from modules.whisper.whisper_parameter import *
+from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
+from modules.whisper.data_classes import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
-class WhisperBase(ABC):
+class BaseTranscriptionPipeline(ABC):
def __init__(self,
model_dir: str = WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -73,13 +71,15 @@ class WhisperBase(ABC):
def run(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
+ file_format: str = "SRT",
add_timestamp: bool = True,
- *whisper_params,
- ) -> Tuple[List[dict], float]:
+ *pipeline_params,
+ ) -> Tuple[List[Segment], float]:
"""
Run transcription with conditional pre-processing and post-processing.
The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
The diarization will be performed in post-processing, if enabled.
+ Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.
Parameters
----------
@@ -87,40 +87,33 @@ class WhisperBase(ABC):
Audio input. This can be file path or binary type.
progress: gr.Progress
Indicator to show progress directly in gradio.
+ file_format: str
+            Subtitle file format, one of ["SRT", "WebVTT", "txt", "LRC"]
add_timestamp: bool
Whether to add a timestamp at the end of the filename.
- *whisper_params: tuple
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+ *pipeline_params: tuple
+            Parameters for the transcription pipeline. These are handled by the "TranscriptionPipelineParams" data class.
+            They must be provided as a flat list with the * wildcard because of the integration with gradio.
+            See more info at: https://github.com/gradio-app/gradio/issues/2471
Returns
----------
- segments_result: List[dict]
- list of dicts that includes start, end timestamps and transcribed text
+ segments_result: List[Segment]
+            list of Segment objects that include start and end timestamps and the transcribed text
elapsed_time: float
elapsed time for running
"""
- params = WhisperParameters.as_value(*whisper_params)
+ params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+ params = self.validate_gradio_values(params)
+ bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
- self.cache_parameters(
- whisper_params=params,
- add_timestamp=add_timestamp
- )
-
- if params.lang is None:
- pass
- elif params.lang == AUTOMATIC_DETECTION:
- params.lang = None
- else:
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
- params.lang = language_code_dict[params.lang]
-
- if params.is_bgm_separate:
+ if bgm_params.is_separate_bgm:
music, audio, _ = self.music_separator.separate(
audio=audio,
- model_name=params.uvr_model_size,
- device=params.uvr_device,
- segment_size=params.uvr_segment_size,
- save_file=params.uvr_save_file,
+ model_name=bgm_params.model_size,
+ device=bgm_params.device,
+ segment_size=bgm_params.segment_size,
+ save_file=bgm_params.save_file,
progress=progress
)
@@ -132,47 +125,55 @@ class WhisperBase(ABC):
origin_sample_rate = self.music_separator.audio_info.sample_rate
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
- if params.uvr_enable_offload:
+ if bgm_params.enable_offload:
self.music_separator.offload()
- if params.vad_filter:
- # Explicit value set for float('inf') from gr.Number()
- if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
- params.max_speech_duration_s = float('inf')
-
+ if vad_params.vad_filter:
vad_options = VadOptions(
- threshold=params.threshold,
- min_speech_duration_ms=params.min_speech_duration_ms,
- max_speech_duration_s=params.max_speech_duration_s,
- min_silence_duration_ms=params.min_silence_duration_ms,
- speech_pad_ms=params.speech_pad_ms
+ threshold=vad_params.threshold,
+ min_speech_duration_ms=vad_params.min_speech_duration_ms,
+ max_speech_duration_s=vad_params.max_speech_duration_s,
+ min_silence_duration_ms=vad_params.min_silence_duration_ms,
+ speech_pad_ms=vad_params.speech_pad_ms
)
- audio, speech_chunks = self.vad.run(
+ vad_processed, speech_chunks = self.vad.run(
audio=audio,
vad_parameters=vad_options,
progress=progress
)
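+            # If VAD removed everything, fall back to the original audio and skip timestamp restoration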
+ if vad_processed.size > 0:
+ audio = vad_processed
+ else:
+ vad_params.vad_filter = False
+
result, elapsed_time = self.transcribe(
audio,
progress,
- *astuple(params)
+ *whisper_params.to_list()
)
- if params.vad_filter:
+ if vad_params.vad_filter:
result = self.vad.restore_speech_timestamps(
segments=result,
speech_chunks=speech_chunks,
)
- if params.is_diarize:
+ if diarization_params.is_diarize:
result, elapsed_time_diarization = self.diarizer.run(
audio=audio,
- use_auth_token=params.hf_token,
+ use_auth_token=diarization_params.hf_token,
transcribed_result=result,
+ device=diarization_params.device
)
elapsed_time += elapsed_time_diarization
+
+ self.cache_parameters(
+ params=params,
+ file_format=file_format,
+ add_timestamp=add_timestamp
+ )
return result, elapsed_time
def transcribe_file(self,
@@ -181,8 +182,8 @@ class WhisperBase(ABC):
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
- *whisper_params,
- ) -> list:
+ *pipeline_params,
+ ) -> Tuple[str, List]:
"""
Write subtitle file from Files
@@ -199,8 +200,8 @@ class WhisperBase(ABC):
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
- *whisper_params: tuple
- Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+ *pipeline_params: tuple
+            Parameters for the transcription pipeline. These are handled by the "TranscriptionPipelineParams" data class
Returns
----------
@@ -210,6 +211,11 @@ class WhisperBase(ABC):
Output file path to return to gr.Files()
"""
try:
+ params = TranscriptionPipelineParams.from_list(list(pipeline_params))
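+            # Enable word-level highlighting in the subtitle writer only when word timestamps are requested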
+ writer_options = {
+ "highlight_words": True if params.whisper.word_timestamps else False
+ }
+
if input_folder_path:
files = get_media_files(input_folder_path)
if isinstance(files, str):
@@ -222,19 +228,21 @@ class WhisperBase(ABC):
transcribed_segments, time_for_task = self.run(
file,
progress,
+ file_format,
add_timestamp,
- *whisper_params,
+ *pipeline_params,
)
file_name, file_ext = os.path.splitext(os.path.basename(file))
- subtitle, file_path = self.generate_and_write_file(
- file_name=file_name,
- transcribed_segments=transcribed_segments,
+ subtitle, file_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_format,
+ result=transcribed_segments,
add_timestamp=add_timestamp,
- file_format=file_format,
- output_dir=self.output_dir
+ **writer_options
)
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
+ files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
total_result = ''
total_time = 0
@@ -247,10 +255,11 @@ class WhisperBase(ABC):
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
result_file_path = [info['path'] for info in files_info.values()]
- return [result_str, result_file_path]
+ return result_str, result_file_path
except Exception as e:
print(f"Error transcribing file: {e}")
+ raise
finally:
self.release_cuda_memory()
@@ -259,8 +268,8 @@ class WhisperBase(ABC):
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
- *whisper_params,
- ) -> list:
+ *pipeline_params,
+ ) -> Tuple[str, str]:
"""
Write subtitle file from microphone
@@ -274,7 +283,7 @@ class WhisperBase(ABC):
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
- *whisper_params: tuple
+ *pipeline_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
Returns
@@ -285,27 +294,36 @@ class WhisperBase(ABC):
Output file path to return to gr.Files()
"""
try:
+ params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+ writer_options = {
+ "highlight_words": True if params.whisper.word_timestamps else False
+ }
+
progress(0, desc="Loading Audio..")
transcribed_segments, time_for_task = self.run(
mic_audio,
progress,
+ file_format,
add_timestamp,
- *whisper_params,
+ *pipeline_params,
)
progress(1, desc="Completed!")
- subtitle, result_file_path = self.generate_and_write_file(
- file_name="Mic",
- transcribed_segments=transcribed_segments,
+ file_name = "Mic"
+ subtitle, file_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_format,
+ result=transcribed_segments,
add_timestamp=add_timestamp,
- file_format=file_format,
- output_dir=self.output_dir
+ **writer_options
)
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
- return [result_str, result_file_path]
+ return result_str, file_path
except Exception as e:
- print(f"Error transcribing file: {e}")
+ print(f"Error transcribing mic: {e}")
+ raise
finally:
self.release_cuda_memory()
@@ -314,8 +332,8 @@ class WhisperBase(ABC):
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
- *whisper_params,
- ) -> list:
+ *pipeline_params,
+ ) -> Tuple[str, str]:
"""
Write subtitle file from Youtube
@@ -329,7 +347,7 @@ class WhisperBase(ABC):
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
-        *whisper_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+        *pipeline_params: tuple
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
Returns
@@ -340,6 +358,11 @@ class WhisperBase(ABC):
Output file path to return to gr.Files()
"""
try:
+ params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+ writer_options = {
+ "highlight_words": True if params.whisper.word_timestamps else False
+ }
+
progress(0, desc="Loading Audio from Youtube..")
yt = get_ytdata(youtube_link)
audio = get_ytaudio(yt)
@@ -347,29 +370,33 @@ class WhisperBase(ABC):
transcribed_segments, time_for_task = self.run(
audio,
progress,
+ file_format,
add_timestamp,
- *whisper_params,
+ *pipeline_params,
)
progress(1, desc="Completed!")
file_name = safe_filename(yt.title)
- subtitle, result_file_path = self.generate_and_write_file(
- file_name=file_name,
- transcribed_segments=transcribed_segments,
+ subtitle, file_path = generate_file(
+ output_dir=self.output_dir,
+ output_file_name=file_name,
+ output_format=file_format,
+ result=transcribed_segments,
add_timestamp=add_timestamp,
- file_format=file_format,
- output_dir=self.output_dir
+ **writer_options
)
+
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
if os.path.exists(audio):
os.remove(audio)
- return [result_str, result_file_path]
+ return result_str, file_path
except Exception as e:
- print(f"Error transcribing file: {e}")
+ print(f"Error transcribing youtube: {e}")
+ raise
finally:
self.release_cuda_memory()
@@ -387,58 +414,6 @@ class WhisperBase(ABC):
else:
return list(ctranslate2.get_supported_compute_types("cpu"))
- @staticmethod
- def generate_and_write_file(file_name: str,
- transcribed_segments: list,
- add_timestamp: bool,
- file_format: str,
- output_dir: str
- ) -> str:
- """
- Writes subtitle file
-
- Parameters
- ----------
- file_name: str
- Output file name
- transcribed_segments: list
- Text segments transcribed from audio
- add_timestamp: bool
- Determines whether to add a timestamp to the end of the filename.
- file_format: str
- File format to write. Supported formats: [SRT, WebVTT, txt]
- output_dir: str
- Directory path of the output
-
- Returns
- ----------
- content: str
- Result of the transcription
- output_path: str
- output file path
- """
- if add_timestamp:
- timestamp = datetime.now().strftime("%m%d%H%M%S")
- output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
- else:
- output_path = os.path.join(output_dir, f"{file_name}")
-
- file_format = file_format.strip().lower()
- if file_format == "srt":
- content = get_srt(transcribed_segments)
- output_path += '.srt'
-
- elif file_format == "webvtt":
- content = get_vtt(transcribed_segments)
- output_path += '.vtt'
-
- elif file_format == "txt":
- content = get_txt(transcribed_segments)
- output_path += '.txt'
-
- write_file(content, output_path)
- return content, output_path
-
@staticmethod
def format_time(elapsed_time: float) -> str:
"""
@@ -471,7 +446,7 @@ class WhisperBase(ABC):
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
- if not WhisperBase.is_sparse_api_supported():
+ if not BaseTranscriptionPipeline.is_sparse_api_supported():
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
return "cpu"
return "mps"
@@ -513,17 +488,64 @@ class WhisperBase(ABC):
os.remove(file_path)
@staticmethod
- def cache_parameters(
- whisper_params: WhisperValues,
- add_timestamp: bool
- ):
- """cache parameters to the yaml file"""
- cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
- cached_whisper_param = whisper_params.to_yaml()
- cached_yaml = {**cached_params, **cached_whisper_param}
-        cached_yaml["whisper"]["add_timestamp"] = add_timestamp
-        save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
+ def validate_gradio_values(params: TranscriptionPipelineParams):
+ """
+ Validate gradio specific values that can't be displayed as None in the UI.
+ Related issue : https://github.com/gradio-app/gradio/issues/8723
+ """
+ if params.whisper.lang is None:
+ pass
+ elif params.whisper.lang == AUTOMATIC_DETECTION:
+ params.whisper.lang = None
+ else:
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+ params.whisper.lang = language_code_dict[params.whisper.lang]
+ if params.whisper.initial_prompt == GRADIO_NONE_STR:
+ params.whisper.initial_prompt = None
+ if params.whisper.prefix == GRADIO_NONE_STR:
+ params.whisper.prefix = None
+ if params.whisper.hotwords == GRADIO_NONE_STR:
+ params.whisper.hotwords = None
+ if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
+ params.whisper.max_new_tokens = None
+ if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
+ params.whisper.hallucination_silence_threshold = None
+ if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
+ params.whisper.language_detection_threshold = None
+ if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
+ params.vad.max_speech_duration_s = float('inf')
+ return params
+
+ @staticmethod
+ def cache_parameters(
+ params: TranscriptionPipelineParams,
+ file_format: str = "SRT",
+ add_timestamp: bool = True
+ ):
+ """Cache parameters to the yaml file"""
+ cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+ param_to_cache = params.to_dict()
+
+ cached_yaml = {**cached_params, **param_to_cache}
+ cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+ cached_yaml["whisper"]["file_format"] = file_format
+
+        suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
+        if suppress_token and isinstance(suppress_token, list):
+            cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
+
+ if cached_yaml["whisper"].get("lang", None) is None:
+ cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
+ else:
+ language_dict = whisper.tokenizer.LANGUAGES
+ cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
+
+ if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
+ cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
+
+        if cached_yaml:
+ save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
@staticmethod
def resample_audio(audio: Union[str, np.ndarray],
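Note on the two new static methods above: Gradio Textbox and Number components cannot represent None, so the UI sends sentinel values (GRADIO_NONE_STR, GRADIO_NONE_NUMBER_MIN, GRADIO_NONE_NUMBER_MAX) that validate_gradio_values maps back before transcription, while cache_parameters converts in the opposite direction when persisting to YAML. A minimal sketch of the intended round-trip, using only names that appear in this diff:

from modules.utils.constants import GRADIO_NONE_STR, GRADIO_NONE_NUMBER_MIN
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
from modules.whisper.data_classes import TranscriptionPipelineParams

# Simulate what empty Textbox / Number components send from the UI.
params = TranscriptionPipelineParams()
params.whisper.initial_prompt = GRADIO_NONE_STR
params.whisper.max_new_tokens = GRADIO_NONE_NUMBER_MIN

# The sentinels are mapped back to real None values before transcription runs.
params = BaseTranscriptionPipeline.validate_gradio_values(params)
assert params.whisper.initial_prompt is None
assert params.whisper.max_new_tokens is None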
diff --git a/modules/whisper/data_classes.py b/modules/whisper/data_classes.py
new file mode 100644
index 0000000..2bd9c31
--- /dev/null
+++ b/modules/whisper/data_classes.py
@@ -0,0 +1,608 @@
+import faster_whisper.transcribe
+import gradio as gr
+import torch
+from typing import Optional, Dict, List, Union, NamedTuple
+from pydantic import BaseModel, Field, field_validator, ConfigDict
+from gradio_i18n import Translate, gettext as _
+from enum import Enum
+from copy import deepcopy
+
+import yaml
+
+from modules.utils.constants import *
+
+
+class WhisperImpl(Enum):
+ WHISPER = "whisper"
+ FASTER_WHISPER = "faster-whisper"
+ INSANELY_FAST_WHISPER = "insanely_fast_whisper"
+
+
+class Segment(BaseModel):
+ id: Optional[int] = Field(default=None, description="Incremental id for the segment")
+ seek: Optional[int] = Field(default=None, description="Seek of the segment from chunked audio")
+ text: Optional[str] = Field(default=None, description="Transcription text of the segment")
+ start: Optional[float] = Field(default=None, description="Start time of the segment")
+ end: Optional[float] = Field(default=None, description="End time of the segment")
+ tokens: Optional[List[int]] = Field(default=None, description="List of token IDs")
+ temperature: Optional[float] = Field(default=None, description="Temperature used during the decoding process")
+ avg_logprob: Optional[float] = Field(default=None, description="Average log probability of the tokens")
+ compression_ratio: Optional[float] = Field(default=None, description="Compression ratio of the segment")
+ no_speech_prob: Optional[float] = Field(default=None, description="Probability that it's not speech")
+ words: Optional[List['Word']] = Field(default=None, description="List of words contained in the segment")
+
+ @classmethod
+ def from_faster_whisper(cls,
+ seg: faster_whisper.transcribe.Segment):
+ if seg.words is not None:
+ words = [
+ Word(
+ start=w.start,
+ end=w.end,
+ word=w.word,
+ probability=w.probability
+ ) for w in seg.words
+ ]
+ else:
+ words = None
+
+ return cls(
+ id=seg.id,
+ seek=seg.seek,
+ text=seg.text,
+ start=seg.start,
+ end=seg.end,
+ tokens=seg.tokens,
+ temperature=seg.temperature,
+ avg_logprob=seg.avg_logprob,
+ compression_ratio=seg.compression_ratio,
+ no_speech_prob=seg.no_speech_prob,
+ words=words
+ )
+
+
+class Word(BaseModel):
+ start: Optional[float] = Field(default=None, description="Start time of the word")
+    end: Optional[float] = Field(default=None, description="End time of the word")
+ word: Optional[str] = Field(default=None, description="Word text")
+ probability: Optional[float] = Field(default=None, description="Probability of the word")
+
+
+class BaseParams(BaseModel):
+ model_config = ConfigDict(protected_namespaces=())
+
+ def to_dict(self) -> Dict:
+ return self.model_dump()
+
+ def to_list(self) -> List:
+ return list(self.model_dump().values())
+
+ @classmethod
+ def from_list(cls, data_list: List) -> 'BaseParams':
+ field_names = list(cls.model_fields.keys())
+ return cls(**dict(zip(field_names, data_list)))
+
+
+class VadParams(BaseParams):
+ """Voice Activity Detection parameters"""
+ vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
+ threshold: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
+ )
+ min_speech_duration_ms: int = Field(
+ default=250,
+ ge=0,
+ description="Final speech chunks shorter than this are discarded"
+ )
+ max_speech_duration_s: float = Field(
+ default=float("inf"),
+ gt=0,
+ description="Maximum duration of speech chunks in seconds"
+ )
+ min_silence_duration_ms: int = Field(
+ default=2000,
+ ge=0,
+ description="Minimum silence duration between speech chunks"
+ )
+ speech_pad_ms: int = Field(
+ default=400,
+ ge=0,
+ description="Padding added to each side of speech chunks"
+ )
+
+ @classmethod
+ def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
+ return [
+ gr.Checkbox(
+ label=_("Enable Silero VAD Filter"),
+ value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
+ interactive=True,
+ info=_("Enable this to transcribe only detected voice")
+ ),
+ gr.Slider(
+ minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
+ value=defaults.get("threshold", cls.__fields__["threshold"].default),
+ info="Lower it to be more sensitive to small sounds."
+ ),
+ gr.Number(
+ label="Minimum Speech Duration (ms)", precision=0,
+ value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
+ info="Final speech chunks shorter than this time are thrown out"
+ ),
+ gr.Number(
+ label="Maximum Speech Duration (s)",
+ value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
+ info="Maximum duration of speech chunks in \"seconds\"."
+ ),
+ gr.Number(
+ label="Minimum Silence Duration (ms)", precision=0,
+ value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
+                info="At the end of each speech chunk, wait this long before separating it"
+ ),
+ gr.Number(
+ label="Speech Padding (ms)", precision=0,
+ value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
+                info="Final speech chunks are padded by this amount of time on each side"
+ )
+ ]
+
+
+class DiarizationParams(BaseParams):
+ """Speaker diarization parameters"""
+ is_diarize: bool = Field(default=False, description="Enable speaker diarization")
+ device: str = Field(default="cuda", description="Device to run Diarization model.")
+ hf_token: str = Field(
+ default="",
+ description="Hugging Face token for downloading diarization models"
+ )
+
+ @classmethod
+ def to_gradio_inputs(cls,
+ defaults: Optional[Dict] = None,
+ available_devices: Optional[List] = None,
+ device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
+ return [
+ gr.Checkbox(
+ label=_("Enable Diarization"),
+ value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
+ ),
+ gr.Dropdown(
+ label=_("Device"),
+ choices=["cpu", "cuda"] if available_devices is None else available_devices,
+ value=defaults.get("device", device),
+ ),
+ gr.Textbox(
+ label=_("HuggingFace Token"),
+ value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
+ info=_("This is only needed the first time you download the model")
+ ),
+ ]
+
+
+class BGMSeparationParams(BaseParams):
+ """Background music separation parameters"""
+ is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
+ model_size: str = Field(
+ default="UVR-MDX-NET-Inst_HQ_4",
+ description="UVR model size"
+ )
+ device: str = Field(default="cuda", description="Device to run UVR model.")
+ segment_size: int = Field(
+ default=256,
+ gt=0,
+ description="Segment size for UVR model"
+ )
+ save_file: bool = Field(
+ default=False,
+ description="Whether to save separated audio files"
+ )
+ enable_offload: bool = Field(
+ default=True,
+ description="Offload UVR model after transcription"
+ )
+
+ @classmethod
+ def to_gradio_input(cls,
+ defaults: Optional[Dict] = None,
+ available_devices: Optional[List] = None,
+ device: Optional[str] = None,
+ available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
+ return [
+ gr.Checkbox(
+ label=_("Enable Background Music Remover Filter"),
+ value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
+ interactive=True,
+ info=_("Enabling this will remove background music")
+ ),
+ gr.Dropdown(
+ label=_("Model"),
+ choices=["UVR-MDX-NET-Inst_HQ_4",
+ "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
+ value=defaults.get("model_size", cls.__fields__["model_size"].default),
+ ),
+ gr.Dropdown(
+ label=_("Device"),
+ choices=["cpu", "cuda"] if available_devices is None else available_devices,
+ value=defaults.get("device", device),
+ ),
+ gr.Number(
+ label="Segment Size",
+ value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
+ precision=0,
+ info="Segment size for UVR model"
+ ),
+ gr.Checkbox(
+ label=_("Save separated files to output"),
+ value=defaults.get("save_file", cls.__fields__["save_file"].default),
+ ),
+ gr.Checkbox(
+ label=_("Offload sub model after removing background music"),
+ value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
+ )
+ ]
+
+
+class WhisperParams(BaseParams):
+ """Whisper parameters"""
+ model_size: str = Field(default="large-v2", description="Whisper model size")
+ lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
+ is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
+ beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
+ log_prob_threshold: float = Field(
+ default=-1.0,
+ description="Threshold for average log probability of sampled tokens"
+ )
+ no_speech_threshold: float = Field(
+ default=0.6,
+ ge=0.0,
+ le=1.0,
+ description="Threshold for detecting silence"
+ )
+ compute_type: str = Field(default="float16", description="Computation type for transcription")
+ best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
+ patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
+ condition_on_previous_text: bool = Field(
+ default=True,
+ description="Use previous output as prompt for next window"
+ )
+ prompt_reset_on_temperature: float = Field(
+ default=0.5,
+ ge=0.0,
+ le=1.0,
+ description="Temperature threshold for resetting prompt"
+ )
+ initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
+ temperature: float = Field(
+ default=0.0,
+ ge=0.0,
+ description="Temperature for sampling"
+ )
+ compression_ratio_threshold: float = Field(
+ default=2.4,
+ gt=0,
+ description="Threshold for gzip compression ratio"
+ )
+ length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
+ repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
+ no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
+ prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
+ suppress_blank: bool = Field(
+ default=True,
+ description="Suppress blank outputs at start of sampling"
+ )
+ suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress")
+ max_initial_timestamp: float = Field(
+ default=1.0,
+ ge=0.0,
+ description="Maximum initial timestamp"
+ )
+ word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
+ prepend_punctuations: Optional[str] = Field(
+ default="\"'“¿([{-",
+ description="Punctuations to merge with next word"
+ )
+ append_punctuations: Optional[str] = Field(
+ default="\"'.。,,!!??::”)]}、",
+ description="Punctuations to merge with previous word"
+ )
+ max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
+ chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
+ hallucination_silence_threshold: Optional[float] = Field(
+ default=None,
+ description="Threshold for skipping silent periods in hallucination detection"
+ )
+ hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
+ language_detection_threshold: Optional[float] = Field(
+ default=None,
+ description="Threshold for language detection probability"
+ )
+ language_detection_segments: int = Field(
+ default=1,
+ gt=0,
+ description="Number of segments for language detection"
+ )
+ batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
+
+ @field_validator('lang')
+ def validate_lang(cls, v):
+ from modules.utils.constants import AUTOMATIC_DETECTION
+ return None if v == AUTOMATIC_DETECTION.unwrap() else v
+
+ @field_validator('suppress_tokens')
+    def validate_suppress_tokens(cls, v):
+ import ast
+ try:
+ if isinstance(v, str):
+ suppress_tokens = ast.literal_eval(v)
+ if not isinstance(suppress_tokens, list):
+                    raise ValueError("Invalid Suppress Tokens. The value must be of type List[int]")
+ return suppress_tokens
+ if isinstance(v, list):
+ return v
+ except Exception as e:
+            raise ValueError(f"Invalid Suppress Tokens. The value must be of type List[int]: {e}")
+
+ @classmethod
+ def to_gradio_inputs(cls,
+ defaults: Optional[Dict] = None,
+ only_advanced: Optional[bool] = True,
+ whisper_type: Optional[str] = None,
+ available_models: Optional[List] = None,
+ available_langs: Optional[List] = None,
+ available_compute_types: Optional[List] = None,
+ compute_type: Optional[str] = None):
+ whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower()
+
+ inputs = []
+ if not only_advanced:
+ inputs += [
+ gr.Dropdown(
+ label=_("Model"),
+ choices=available_models,
+ value=defaults.get("model_size", cls.__fields__["model_size"].default),
+ ),
+ gr.Dropdown(
+ label=_("Language"),
+ choices=available_langs,
+ value=defaults.get("lang", AUTOMATIC_DETECTION),
+ ),
+ gr.Checkbox(
+ label=_("Translate to English?"),
+ value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
+ ),
+ ]
+
+ inputs += [
+ gr.Number(
+ label="Beam Size",
+ value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
+ precision=0,
+ info="Beam size for decoding"
+ ),
+ gr.Number(
+ label="Log Probability Threshold",
+ value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
+ info="Threshold for average log probability of sampled tokens"
+ ),
+ gr.Number(
+ label="No Speech Threshold",
+ value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
+ info="Threshold for detecting silence"
+ ),
+ gr.Dropdown(
+ label="Compute Type",
+ choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
+ value=defaults.get("compute_type", compute_type),
+ info="Computation type for transcription"
+ ),
+ gr.Number(
+ label="Best Of",
+ value=defaults.get("best_of", cls.__fields__["best_of"].default),
+ precision=0,
+ info="Number of candidates when sampling"
+ ),
+ gr.Number(
+ label="Patience",
+ value=defaults.get("patience", cls.__fields__["patience"].default),
+ info="Beam search patience factor"
+ ),
+ gr.Checkbox(
+ label="Condition On Previous Text",
+ value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
+ info="Use previous output as prompt for next window"
+ ),
+ gr.Slider(
+ label="Prompt Reset On Temperature",
+ value=defaults.get("prompt_reset_on_temperature",
+ cls.__fields__["prompt_reset_on_temperature"].default),
+ minimum=0,
+ maximum=1,
+ step=0.01,
+ info="Temperature threshold for resetting prompt"
+ ),
+ gr.Textbox(
+ label="Initial Prompt",
+ value=defaults.get("initial_prompt", GRADIO_NONE_STR),
+ info="Initial prompt for first window"
+ ),
+ gr.Slider(
+ label="Temperature",
+ value=defaults.get("temperature", cls.__fields__["temperature"].default),
+ minimum=0.0,
+ step=0.01,
+ maximum=1.0,
+ info="Temperature for sampling"
+ ),
+ gr.Number(
+ label="Compression Ratio Threshold",
+ value=defaults.get("compression_ratio_threshold",
+ cls.__fields__["compression_ratio_threshold"].default),
+ info="Threshold for gzip compression ratio"
+ )
+ ]
+
+ faster_whisper_inputs = [
+ gr.Number(
+ label="Length Penalty",
+ value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
+ info="Exponential length penalty",
+ ),
+ gr.Number(
+ label="Repetition Penalty",
+ value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
+ info="Penalty for repeated tokens"
+ ),
+ gr.Number(
+ label="No Repeat N-gram Size",
+ value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
+ precision=0,
+ info="Size of n-grams to prevent repetition"
+ ),
+ gr.Textbox(
+ label="Prefix",
+ value=defaults.get("prefix", GRADIO_NONE_STR),
+ info="Prefix text for first window"
+ ),
+ gr.Checkbox(
+ label="Suppress Blank",
+ value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
+ info="Suppress blank outputs at start of sampling"
+ ),
+ gr.Textbox(
+ label="Suppress Tokens",
+ value=defaults.get("suppress_tokens", "[-1]"),
+ info="Token IDs to suppress"
+ ),
+ gr.Number(
+ label="Max Initial Timestamp",
+ value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
+ info="Maximum initial timestamp"
+ ),
+ gr.Checkbox(
+ label="Word Timestamps",
+ value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
+ info="Extract word-level timestamps"
+ ),
+ gr.Textbox(
+ label="Prepend Punctuations",
+ value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
+ info="Punctuations to merge with next word"
+ ),
+ gr.Textbox(
+ label="Append Punctuations",
+ value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
+ info="Punctuations to merge with previous word"
+ ),
+ gr.Number(
+ label="Max New Tokens",
+ value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
+ precision=0,
+ info="Maximum number of new tokens per chunk"
+ ),
+ gr.Number(
+ label="Chunk Length (s)",
+ value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
+ precision=0,
+ info="Length of audio segments in seconds"
+ ),
+ gr.Number(
+ label="Hallucination Silence Threshold (sec)",
+ value=defaults.get("hallucination_silence_threshold",
+ GRADIO_NONE_NUMBER_MIN),
+ info="Threshold for skipping silent periods in hallucination detection"
+ ),
+ gr.Textbox(
+ label="Hotwords",
+ value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
+ info="Hotwords/hint phrases for the model"
+ ),
+ gr.Number(
+ label="Language Detection Threshold",
+ value=defaults.get("language_detection_threshold",
+ GRADIO_NONE_NUMBER_MIN),
+ info="Threshold for language detection probability"
+ ),
+ gr.Number(
+ label="Language Detection Segments",
+ value=defaults.get("language_detection_segments",
+ cls.__fields__["language_detection_segments"].default),
+ precision=0,
+ info="Number of segments for language detection"
+ )
+ ]
+
+ insanely_fast_whisper_inputs = [
+ gr.Number(
+ label="Batch Size",
+ value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
+ precision=0,
+ info="Batch size for processing"
+ )
+ ]
+
+ if whisper_type != WhisperImpl.FASTER_WHISPER.value:
+ for input_component in faster_whisper_inputs:
+ input_component.visible = False
+
+ if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value:
+ for input_component in insanely_fast_whisper_inputs:
+ input_component.visible = False
+
+ inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
+
+ return inputs
+
+
+class TranscriptionPipelineParams(BaseModel):
+ """Transcription pipeline parameters"""
+ whisper: WhisperParams = Field(default_factory=WhisperParams)
+ vad: VadParams = Field(default_factory=VadParams)
+ diarization: DiarizationParams = Field(default_factory=DiarizationParams)
+ bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
+
+ def to_dict(self) -> Dict:
+ data = {
+ "whisper": self.whisper.to_dict(),
+ "vad": self.vad.to_dict(),
+ "diarization": self.diarization.to_dict(),
+ "bgm_separation": self.bgm_separation.to_dict()
+ }
+ return data
+
+ def to_list(self) -> List:
+ """
+        Convert the data class to a list, because the parameters have to be passed to Gradio as one flat list.
+ Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
+ See more about Gradio pre-processing: https://www.gradio.app/docs/components
+ """
+ whisper_list = self.whisper.to_list()
+ vad_list = self.vad.to_list()
+ diarization_list = self.diarization.to_list()
+ bgm_sep_list = self.bgm_separation.to_list()
+ return whisper_list + vad_list + diarization_list + bgm_sep_list
+
+ @staticmethod
+ def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
+        """Convert the flat list back into the data class so it can be used inside a function."""
+ data_list = deepcopy(pipeline_list)
+
+ whisper_list = data_list[0:len(WhisperParams.__annotations__)]
+ data_list = data_list[len(WhisperParams.__annotations__):]
+
+ vad_list = data_list[0:len(VadParams.__annotations__)]
+ data_list = data_list[len(VadParams.__annotations__):]
+
+ diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
+ data_list = data_list[len(DiarizationParams.__annotations__):]
+
+ bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
+
+ return TranscriptionPipelineParams(
+ whisper=WhisperParams.from_list(whisper_list),
+ vad=VadParams.from_list(vad_list),
+ diarization=DiarizationParams.from_list(diarization_list),
+ bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
+ )
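Because Gradio passes component values to callbacks as one flat positional list, these classes flatten themselves with to_list() and rebuild themselves with from_list(). A small round-trip sketch (the parameter values are illustrative):

from modules.whisper.data_classes import TranscriptionPipelineParams, WhisperParams

# Flatten for the Gradio callback, then rebuild inside the handler.
params = TranscriptionPipelineParams(whisper=WhisperParams(model_size="tiny", beam_size=3))
flat = params.to_list()  # whisper values + vad values + diarization values + bgm_separation values
restored = TranscriptionPipelineParams.from_list(flat)

assert restored.whisper.model_size == "tiny"
assert restored.whisper.beam_size == 3
assert restored.vad.threshold == params.vad.threshold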
diff --git a/modules/whisper/faster_whisper_inference.py b/modules/whisper/faster_whisper_inference.py
index f12fc01..bc1e8ed 100644
--- a/modules/whisper/faster_whisper_inference.py
+++ b/modules/whisper/faster_whisper_inference.py
@@ -12,11 +12,11 @@ import gradio as gr
from argparse import Namespace
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
-from modules.whisper.whisper_parameter import *
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.data_classes import *
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
-class FasterWhisperInference(WhisperBase):
+class FasterWhisperInference(BaseTranscriptionPipeline):
def __init__(self,
model_dir: str = FASTER_WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
*whisper_params,
- ) -> Tuple[List[dict], float]:
+ ) -> Tuple[List[Segment], float]:
"""
transcribe method for faster-whisper.
@@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase):
Returns
----------
- segments_result: List[dict]
- list of dicts that includes start, end timestamps and transcribed text
+ segments_result: List[Segment]
+ list of Segment that includes start, end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
start_time = time.time()
- params = WhisperParameters.as_value(*whisper_params)
+ params = WhisperParams.from_list(list(whisper_params))
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
- # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
- if not params.initial_prompt:
- params.initial_prompt = None
- if not params.prefix:
- params.prefix = None
- if not params.hotwords:
- params.hotwords = None
-
- params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
-
segments, info = self.model.transcribe(
audio=audio,
language=params.lang,
@@ -112,11 +102,7 @@ class FasterWhisperInference(WhisperBase):
segments_result = []
for segment in segments:
progress(segment.start / info.duration, desc="Transcribing..")
- segments_result.append({
- "start": segment.start,
- "end": segment.end,
- "text": segment.text
- })
+ segments_result.append(Segment.from_faster_whisper(segment))
elapsed_time = time.time() - start_time
return segments_result, elapsed_time
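Every backend now returns List[Segment] instead of ad-hoc dicts, so downstream code can rely on one shape. A short illustrative sketch (the timestamps and text below are made up):

from modules.whisper.data_classes import Segment, Word

# Hypothetical segment in the shared result shape; every field is optional, so a
# backend without word-level timestamps can simply leave `words` unset.
seg = Segment(
    start=0.0,
    end=2.5,
    text=" And so my fellow Americans",
    words=[Word(start=0.0, end=0.4, word=" And", probability=0.98)],
)
print(f"[{seg.start:.2f} --> {seg.end:.2f}]{seg.text}")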
diff --git a/modules/whisper/insanely_fast_whisper_inference.py b/modules/whisper/insanely_fast_whisper_inference.py
index fe6f4fd..2773166 100644
--- a/modules/whisper/insanely_fast_whisper_inference.py
+++ b/modules/whisper/insanely_fast_whisper_inference.py
@@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
from argparse import Namespace
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
-from modules.whisper.whisper_parameter import *
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.data_classes import *
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
-class InsanelyFastWhisperInference(WhisperBase):
+class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
def __init__(self,
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -32,15 +32,13 @@ class InsanelyFastWhisperInference(WhisperBase):
self.model_dir = model_dir
os.makedirs(self.model_dir, exist_ok=True)
- openai_models = whisper.available_models()
- distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
- self.available_models = openai_models + distil_models
+ self.available_models = self.get_model_paths()
def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress = gr.Progress(),
*whisper_params,
- ) -> Tuple[List[dict], float]:
+ ) -> Tuple[List[Segment], float]:
"""
transcribe method for faster-whisper.
@@ -55,13 +53,13 @@ class InsanelyFastWhisperInference(WhisperBase):
Returns
----------
- segments_result: List[dict]
- list of dicts that includes start, end timestamps and transcribed text
+ segments_result: List[Segment]
+ list of Segment that includes start, end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
start_time = time.time()
- params = WhisperParameters.as_value(*whisper_params)
+ params = WhisperParams.from_list(list(whisper_params))
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
@@ -95,9 +93,17 @@ class InsanelyFastWhisperInference(WhisperBase):
generate_kwargs=kwargs
)
- segments_result = self.format_result(
- transcribed_result=segments,
- )
+ segments_result = []
+ for item in segments["chunks"]:
+ start, end = item["timestamp"][0], item["timestamp"][1]
+ if end is None:
+ end = start
+ segments_result.append(Segment(
+ text=item["text"],
+ start=start,
+ end=end
+ ))
+
elapsed_time = time.time() - start_time
return segments_result, elapsed_time
@@ -138,31 +144,26 @@ class InsanelyFastWhisperInference(WhisperBase):
model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
)
- @staticmethod
- def format_result(
- transcribed_result: dict
- ) -> List[dict]:
+ def get_model_paths(self):
"""
- Format the transcription result of insanely_fast_whisper as the same with other implementation.
-
- Parameters
- ----------
- transcribed_result: dict
- Transcription result of the insanely_fast_whisper
+        Get available models from the models directory, including fine-tuned models.
Returns
----------
- result: List[dict]
- Formatted result as the same with other implementation
+            List of available model names
"""
- result = transcribed_result["chunks"]
- for item in result:
- start, end = item["timestamp"][0], item["timestamp"][1]
- if end is None:
- end = start
- item["start"] = start
- item["end"] = end
- return result
+ openai_models = whisper.available_models()
+ distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
+ default_models = openai_models + distil_models
+
+ existing_models = os.listdir(self.model_dir)
+ wrong_dirs = [".locks"]
+
+ available_models = default_models + existing_models
+ available_models = [model for model in available_models if model not in wrong_dirs]
+ available_models = sorted(set(available_models), key=available_models.index)
+
+ return available_models
@staticmethod
def download_model(
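The sorted(set(...), key=list.index) idiom in get_model_paths() de-duplicates while preserving the original order, so the default models stay ahead of locally downloaded or fine-tuned ones. A self-contained sketch with made-up model names:

# Made-up model names; only the order-preserving de-duplication idiom matters here.
default_models = ["tiny", "base", "distil-large-v3"]
existing_models = ["my-finetuned-model", "base"]

available_models = default_models + existing_models
available_models = sorted(set(available_models), key=available_models.index)
print(available_models)  # ['tiny', 'base', 'distil-large-v3', 'my-finetuned-model']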
diff --git a/modules/whisper/whisper_Inference.py b/modules/whisper/whisper_Inference.py
index f87fbe5..ccd4bbb 100644
--- a/modules/whisper/whisper_Inference.py
+++ b/modules/whisper/whisper_Inference.py
@@ -8,11 +8,11 @@ import os
from argparse import Namespace
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
-from modules.whisper.whisper_base import WhisperBase
-from modules.whisper.whisper_parameter import *
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
+from modules.whisper.data_classes import *
-class WhisperInference(WhisperBase):
+class WhisperInference(BaseTranscriptionPipeline):
def __init__(self,
model_dir: str = WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -30,7 +30,7 @@ class WhisperInference(WhisperBase):
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress = gr.Progress(),
*whisper_params,
- ) -> Tuple[List[dict], float]:
+ ) -> Tuple[List[Segment], float]:
"""
transcribe method for faster-whisper.
@@ -45,13 +45,13 @@ class WhisperInference(WhisperBase):
Returns
----------
- segments_result: List[dict]
- list of dicts that includes start, end timestamps and transcribed text
+ segments_result: List[Segment]
+ list of Segment that includes start, end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
start_time = time.time()
- params = WhisperParameters.as_value(*whisper_params)
+ params = WhisperParams.from_list(list(whisper_params))
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
@@ -59,21 +59,28 @@ class WhisperInference(WhisperBase):
def progress_callback(progress_value):
progress(progress_value, desc="Transcribing..")
- segments_result = self.model.transcribe(audio=audio,
- language=params.lang,
- verbose=False,
- beam_size=params.beam_size,
- logprob_threshold=params.log_prob_threshold,
- no_speech_threshold=params.no_speech_threshold,
- task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
- fp16=True if params.compute_type == "float16" else False,
- best_of=params.best_of,
- patience=params.patience,
- temperature=params.temperature,
- compression_ratio_threshold=params.compression_ratio_threshold,
- progress_callback=progress_callback,)["segments"]
- elapsed_time = time.time() - start_time
+ result = self.model.transcribe(audio=audio,
+ language=params.lang,
+ verbose=False,
+ beam_size=params.beam_size,
+ logprob_threshold=params.log_prob_threshold,
+ no_speech_threshold=params.no_speech_threshold,
+ task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
+ fp16=True if params.compute_type == "float16" else False,
+ best_of=params.best_of,
+ patience=params.patience,
+ temperature=params.temperature,
+ compression_ratio_threshold=params.compression_ratio_threshold,
+ progress_callback=progress_callback,)["segments"]
+ segments_result = []
+ for segment in result:
+ segments_result.append(Segment(
+ start=segment["start"],
+ end=segment["end"],
+ text=segment["text"]
+ ))
+ elapsed_time = time.time() - start_time
return segments_result, elapsed_time
def update_model(self,
diff --git a/modules/whisper/whisper_factory.py b/modules/whisper/whisper_factory.py
index 6bda8c5..b5ae33a 100644
--- a/modules/whisper/whisper_factory.py
+++ b/modules/whisper/whisper_factory.py
@@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
+from modules.whisper.data_classes import *
class WhisperFactory:
@@ -19,7 +20,7 @@ class WhisperFactory:
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
uvr_model_dir: str = UVR_MODELS_DIR,
output_dir: str = OUTPUT_DIR,
- ) -> "WhisperBase":
+ ) -> "BaseTranscriptionPipeline":
"""
Create a whisper inference class based on the provided whisper_type.
@@ -45,36 +46,29 @@ class WhisperFactory:
Returns
-------
- WhisperBase
+ BaseTranscriptionPipeline
An instance of the appropriate whisper inference class based on the whisper_type.
"""
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
- whisper_type = whisper_type.lower().strip()
+ whisper_type = whisper_type.strip().lower()
- faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
- whisper_typos = ["whisper"]
- insanely_fast_whisper_typos = [
- "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
- "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
- ]
-
- if whisper_type in faster_whisper_typos:
+ if whisper_type == WhisperImpl.FASTER_WHISPER.value:
return FasterWhisperInference(
model_dir=faster_whisper_model_dir,
output_dir=output_dir,
diarization_model_dir=diarization_model_dir,
uvr_model_dir=uvr_model_dir
)
- elif whisper_type in whisper_typos:
+ elif whisper_type == WhisperImpl.WHISPER.value:
return WhisperInference(
model_dir=whisper_model_dir,
output_dir=output_dir,
diarization_model_dir=diarization_model_dir,
uvr_model_dir=uvr_model_dir
)
- elif whisper_type in insanely_fast_whisper_typos:
+ elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value:
return InsanelyFastWhisperInference(
model_dir=insanely_fast_whisper_model_dir,
output_dir=output_dir,
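With the typo lists removed, the factory only accepts the canonical WhisperImpl values ("whisper", "faster-whisper", "insanely_fast_whisper"). A minimal usage sketch; the model and output directories fall back to their defaults:

from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import WhisperImpl

pipeline = WhisperFactory.create_whisper_inference(whisper_type=WhisperImpl.FASTER_WHISPER.value)
print(type(pipeline).__name__)  # FasterWhisperInference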
diff --git a/modules/whisper/whisper_parameter.py b/modules/whisper/whisper_parameter.py
deleted file mode 100644
index 19115fc..0000000
--- a/modules/whisper/whisper_parameter.py
+++ /dev/null
@@ -1,371 +0,0 @@
-from dataclasses import dataclass, fields
-import gradio as gr
-from typing import Optional, Dict
-import yaml
-
-from modules.utils.constants import AUTOMATIC_DETECTION
-
-
-@dataclass
-class WhisperParameters:
- model_size: gr.Dropdown
- lang: gr.Dropdown
- is_translate: gr.Checkbox
- beam_size: gr.Number
- log_prob_threshold: gr.Number
- no_speech_threshold: gr.Number
- compute_type: gr.Dropdown
- best_of: gr.Number
- patience: gr.Number
- condition_on_previous_text: gr.Checkbox
- prompt_reset_on_temperature: gr.Slider
- initial_prompt: gr.Textbox
- temperature: gr.Slider
- compression_ratio_threshold: gr.Number
- vad_filter: gr.Checkbox
- threshold: gr.Slider
- min_speech_duration_ms: gr.Number
- max_speech_duration_s: gr.Number
- min_silence_duration_ms: gr.Number
- speech_pad_ms: gr.Number
- batch_size: gr.Number
- is_diarize: gr.Checkbox
- hf_token: gr.Textbox
- diarization_device: gr.Dropdown
- length_penalty: gr.Number
- repetition_penalty: gr.Number
- no_repeat_ngram_size: gr.Number
- prefix: gr.Textbox
- suppress_blank: gr.Checkbox
- suppress_tokens: gr.Textbox
- max_initial_timestamp: gr.Number
- word_timestamps: gr.Checkbox
- prepend_punctuations: gr.Textbox
- append_punctuations: gr.Textbox
- max_new_tokens: gr.Number
- chunk_length: gr.Number
- hallucination_silence_threshold: gr.Number
- hotwords: gr.Textbox
- language_detection_threshold: gr.Number
- language_detection_segments: gr.Number
- is_bgm_separate: gr.Checkbox
- uvr_model_size: gr.Dropdown
- uvr_device: gr.Dropdown
- uvr_segment_size: gr.Number
- uvr_save_file: gr.Checkbox
- uvr_enable_offload: gr.Checkbox
- """
- A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
- This data class is used to mitigate the key-value problem between Gradio components and function parameters.
- Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
- See more about Gradio pre-processing: https://www.gradio.app/docs/components
-
- Attributes
- ----------
- model_size: gr.Dropdown
- Whisper model size.
-
- lang: gr.Dropdown
- Source language of the file to transcribe.
-
- is_translate: gr.Checkbox
- Boolean value that determines whether to translate to English.
- It's Whisper's feature to translate speech from another language directly into English end-to-end.
-
- beam_size: gr.Number
- Int value that is used for decoding option.
-
- log_prob_threshold: gr.Number
- If the average log probability over sampled tokens is below this value, treat as failed.
-
- no_speech_threshold: gr.Number
- If the no_speech probability is higher than this value AND
- the average log probability over sampled tokens is below `log_prob_threshold`,
- consider the segment as silent.
-
- compute_type: gr.Dropdown
- compute type for transcription.
- see more info : https://opennmt.net/CTranslate2/quantization.html
-
- best_of: gr.Number
- Number of candidates when sampling with non-zero temperature.
-
- patience: gr.Number
- Beam search patience factor.
-
- condition_on_previous_text: gr.Checkbox
- if True, the previous output of the model is provided as a prompt for the next window;
- disabling may make the text inconsistent across windows, but the model becomes less prone to
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
-
- initial_prompt: gr.Textbox
- Optional text to provide as a prompt for the first window. This can be used to provide, or
- "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
- to make it more likely to predict those word correctly.
-
- temperature: gr.Slider
- Temperature for sampling. It can be a tuple of temperatures,
- which will be successively used upon failures according to either
- `compression_ratio_threshold` or `log_prob_threshold`.
-
- compression_ratio_threshold: gr.Number
- If the gzip compression ratio is above this value, treat as failed
-
- vad_filter: gr.Checkbox
- Enable the voice activity detection (VAD) to filter out parts of the audio
- without speech. This step is using the Silero VAD model
- https://github.com/snakers4/silero-vad.
-
- threshold: gr.Slider
- This parameter is related with Silero VAD. Speech threshold.
- Silero VAD outputs speech probabilities for each audio chunk,
- probabilities ABOVE this value are considered as SPEECH. It is better to tune this
- parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
-
- min_speech_duration_ms: gr.Number
- This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
-
- max_speech_duration_s: gr.Number
- This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
- than max_speech_duration_s will be split at the timestamp of the last silence that
- lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
- split aggressively just before max_speech_duration_s.
-
- min_silence_duration_ms: gr.Number
- This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
- before separating it
-
- speech_pad_ms: gr.Number
- This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
-
- batch_size: gr.Number
- This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
-
- is_diarize: gr.Checkbox
- This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
-
- hf_token: gr.Textbox
- This parameter is related with whisperx. Huggingface token is needed to download diarization models.
- Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
-
- diarization_device: gr.Dropdown
- This parameter is related with whisperx. Device to run diarization model
-
- length_penalty: gr.Number
- This parameter is related to faster-whisper. Exponential length penalty constant.
-
- repetition_penalty: gr.Number
- This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
- (set > 1 to penalize).
-
- no_repeat_ngram_size: gr.Number
- This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
-
- prefix: gr.Textbox
- This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
-
- suppress_blank: gr.Checkbox
- This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
-
- suppress_tokens: gr.Textbox
- This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
- of symbols as defined in the model config.json file.
-
- max_initial_timestamp: gr.Number
- This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
-
- word_timestamps: gr.Checkbox
- This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
- and dynamic time warping, and include the timestamps for each word in each segment.
-
- prepend_punctuations: gr.Textbox
- This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
- with the next word.
-
- append_punctuations: gr.Textbox
- This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
- with the previous word.
-
- max_new_tokens: gr.Number
- This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
- the maximum will be set by the default max_length.
-
- chunk_length: gr.Number
- This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
- If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
-
- hallucination_silence_threshold: gr.Number
- This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
- (in seconds) when a possible hallucination is detected.
-
- hotwords: gr.Textbox
- This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
-
- language_detection_threshold: gr.Number
- This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
-
- language_detection_segments: gr.Number
- This parameter is related to faster-whisper. Number of segments to consider for the language detection.
-
- is_separate_bgm: gr.Checkbox
- This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.
-
- uvr_model_size: gr.Dropdown
- This parameter is related to UVR. UVR model size.
-
- uvr_device: gr.Dropdown
- This parameter is related to UVR. Device to run UVR model.
-
- uvr_segment_size: gr.Number
- This parameter is related to UVR. Segment size for UVR model.
-
- uvr_save_file: gr.Checkbox
- This parameter is related to UVR. Boolean value that determines whether to save the file or not.
-
- uvr_enable_offload: gr.Checkbox
- This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
- after each transcription.
- """
-
- def as_list(self) -> list:
- """
- Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
- See more about Gradio pre-processing: : https://www.gradio.app/docs/components
-
- Returns
- ----------
- A list of Gradio components
- """
- return [getattr(self, f.name) for f in fields(self)]
-
- @staticmethod
- def as_value(*args) -> 'WhisperValues':
- """
- To use Whisper parameters in function after Gradio post-processing.
- See more about Gradio post-processing: : https://www.gradio.app/docs/components
-
- Returns
- ----------
- WhisperValues
- Data class that has values of parameters
- """
- return WhisperValues(*args)
-
-
-@dataclass
-class WhisperValues:
- model_size: str = "large-v2"
- lang: Optional[str] = None
- is_translate: bool = False
- beam_size: int = 5
- log_prob_threshold: float = -1.0
- no_speech_threshold: float = 0.6
- compute_type: str = "float16"
- best_of: int = 5
- patience: float = 1.0
- condition_on_previous_text: bool = True
- prompt_reset_on_temperature: float = 0.5
- initial_prompt: Optional[str] = None
- temperature: float = 0.0
- compression_ratio_threshold: float = 2.4
- vad_filter: bool = False
- threshold: float = 0.5
- min_speech_duration_ms: int = 250
- max_speech_duration_s: float = float("inf")
- min_silence_duration_ms: int = 2000
- speech_pad_ms: int = 400
- batch_size: int = 24
- is_diarize: bool = False
- hf_token: str = ""
- diarization_device: str = "cuda"
- length_penalty: float = 1.0
- repetition_penalty: float = 1.0
- no_repeat_ngram_size: int = 0
- prefix: Optional[str] = None
- suppress_blank: bool = True
- suppress_tokens: Optional[str] = "[-1]"
- max_initial_timestamp: float = 0.0
- word_timestamps: bool = False
- prepend_punctuations: Optional[str] = "\"'“¿([{-"
- append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
- max_new_tokens: Optional[int] = None
- chunk_length: Optional[int] = 30
- hallucination_silence_threshold: Optional[float] = None
- hotwords: Optional[str] = None
- language_detection_threshold: Optional[float] = None
- language_detection_segments: int = 1
- is_bgm_separate: bool = False
- uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
- uvr_device: str = "cuda"
- uvr_segment_size: int = 256
- uvr_save_file: bool = False
- uvr_enable_offload: bool = True
- """
- A data class to use Whisper parameters.
- """
-
- def to_yaml(self) -> Dict:
- data = {
- "whisper": {
- "model_size": self.model_size,
- "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
- "is_translate": self.is_translate,
- "beam_size": self.beam_size,
- "log_prob_threshold": self.log_prob_threshold,
- "no_speech_threshold": self.no_speech_threshold,
- "best_of": self.best_of,
- "patience": self.patience,
- "condition_on_previous_text": self.condition_on_previous_text,
- "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
- "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
- "temperature": self.temperature,
- "compression_ratio_threshold": self.compression_ratio_threshold,
- "batch_size": self.batch_size,
- "length_penalty": self.length_penalty,
- "repetition_penalty": self.repetition_penalty,
- "no_repeat_ngram_size": self.no_repeat_ngram_size,
- "prefix": None if not self.prefix else self.prefix,
- "suppress_blank": self.suppress_blank,
- "suppress_tokens": self.suppress_tokens,
- "max_initial_timestamp": self.max_initial_timestamp,
- "word_timestamps": self.word_timestamps,
- "prepend_punctuations": self.prepend_punctuations,
- "append_punctuations": self.append_punctuations,
- "max_new_tokens": self.max_new_tokens,
- "chunk_length": self.chunk_length,
- "hallucination_silence_threshold": self.hallucination_silence_threshold,
- "hotwords": None if not self.hotwords else self.hotwords,
- "language_detection_threshold": self.language_detection_threshold,
- "language_detection_segments": self.language_detection_segments,
- },
- "vad": {
- "vad_filter": self.vad_filter,
- "threshold": self.threshold,
- "min_speech_duration_ms": self.min_speech_duration_ms,
- "max_speech_duration_s": self.max_speech_duration_s,
- "min_silence_duration_ms": self.min_silence_duration_ms,
- "speech_pad_ms": self.speech_pad_ms,
- },
- "diarization": {
- "is_diarize": self.is_diarize,
- "hf_token": self.hf_token
- },
- "bgm_separation": {
- "is_separate_bgm": self.is_bgm_separate,
- "model_size": self.uvr_model_size,
- "segment_size": self.uvr_segment_size,
- "save_file": self.uvr_save_file,
- "enable_offload": self.uvr_enable_offload
- },
- }
- return data
-
- def as_list(self) -> list:
- """
- Converts the data class attributes into a list
-
- Returns
- ----------
- A list of Whisper parameters
- """
- return [getattr(self, f.name) for f in fields(self)]
diff --git a/notebook/whisper-webui.ipynb b/notebook/whisper-webui.ipynb
index 558daaf..0eb8eab 100644
--- a/notebook/whisper-webui.ipynb
+++ b/notebook/whisper-webui.ipynb
@@ -56,7 +56,7 @@
"!pip install faster-whisper==1.0.3\n",
"!pip install ctranslate2==4.4.0\n",
"!pip install gradio\n",
- "!pip install git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error\n",
+ "!pip install gradio-i18n\n",
"# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n",
"!pip install git+https://github.com/JuanBindez/pytubefix.git\n",
"!pip install tokenizers==0.19.1\n",
@@ -99,7 +99,7 @@
},
{
"cell_type": "code",
- "execution_count": 2,
+ "execution_count": 3,
"metadata": {
"id": "PQroYRRZzQiN",
"cellView": "form"
diff --git a/requirements.txt b/requirements.txt
index 9aeeb66..4a34878 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,7 +11,7 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
faster-whisper==1.0.3
transformers
gradio
-git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error
+gradio-i18n
pytubefix
ruamel.yaml==0.18.6
pyannote.audio==3.3.1
diff --git a/tests/test_bgm_separation.py b/tests/test_bgm_separation.py
index cc4a6f8..95b77a0 100644
--- a/tests/test_bgm_separation.py
+++ b/tests/test_bgm_separation.py
@@ -1,6 +1,6 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
from test_config import *
from test_transcription import download_file, test_transcribe
@@ -17,9 +17,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
- ("whisper", False, True, False),
- ("faster-whisper", False, True, False),
- ("insanely_fast_whisper", False, True, False)
+ (WhisperImpl.WHISPER.value, False, True, False),
+ (WhisperImpl.FASTER_WHISPER.value, False, True, False),
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
]
)
def test_bgm_separation_pipeline(
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
- ("whisper", True, True, False),
- ("faster-whisper", True, True, False),
- ("insanely_fast_whisper", True, True, False)
+ (WhisperImpl.WHISPER.value, True, True, False),
+ (WhisperImpl.FASTER_WHISPER.value, True, True, False),
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
]
)
def test_bgm_separation_with_vad_pipeline(
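
The parametrized implementation names are now taken from a `WhisperImpl` enum rather than string literals. Its definition lives in `modules/whisper/data_classes.py` and is not included in this diff; a plausible shape, with values matching the strings it replaces, would be:

```python
# Assumed shape of WhisperImpl (defined in modules/whisper/data_classes.py,
# not shown here); the .value strings match the literals removed from the tests.
from enum import Enum


class WhisperImpl(Enum):
    WHISPER = "whisper"
    FASTER_WHISPER = "faster-whisper"
    INSANELY_FAST_WHISPER = "insanely_fast_whisper"
```

Using `.value` in `@pytest.mark.parametrize` keeps the generated test IDs as plain strings while centralizing the allowed implementation names in one place.
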
diff --git a/tests/test_config.py b/tests/test_config.py
index 0f60aa5..f82e4f1 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -1,10 +1,14 @@
-from modules.utils.paths import *
-
+import functools
+import jiwer
import os
import torch
+from modules.utils.paths import *
+from modules.utils.youtube_manager import *
+
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
+TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
TEST_WHISPER_MODEL = "tiny"
TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
@@ -13,5 +17,24 @@ TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
+@functools.lru_cache
def is_cuda_available():
return torch.cuda.is_available()
+
+
+@functools.lru_cache
+def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL):
+ try:
+ yt_temp_path = os.path.join("modules", "yt_tmp.wav")
+ if os.path.exists(yt_temp_path):
+ return False
+ yt = get_ytdata(url)
+ audio = get_ytaudio(yt)
+ return False
+ except Exception as e:
+        print(f"Pytube has been detected as a bot: {e}")
+ return True
+
+
+def calculate_wer(answer, prediction):
+ return jiwer.wer(answer, prediction)
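
`jiwer` (added to the CI install step and wrapped by `calculate_wer` above) computes the word error rate between a reference and a hypothesis string. A quick standalone example of the call the tests rely on:

```python
import jiwer

reference = "ask not what your country can do for you"
hypothesis = "ask not what your country could do for you"

# One substituted word out of nine reference words -> WER of about 0.111.
print(jiwer.wer(reference, hypothesis))
```
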
diff --git a/tests/test_diarization.py b/tests/test_diarization.py
index 54e7244..f18a263 100644
--- a/tests/test_diarization.py
+++ b/tests/test_diarization.py
@@ -1,6 +1,6 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
from test_config import *
from test_transcription import download_file, test_transcribe
@@ -16,9 +16,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
- ("whisper", False, False, True),
- ("faster-whisper", False, False, True),
- ("insanely_fast_whisper", False, False, True)
+ (WhisperImpl.WHISPER.value, False, False, True),
+ (WhisperImpl.FASTER_WHISPER.value, False, False, True),
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
]
)
def test_diarization_pipeline(
diff --git a/tests/test_transcription.py b/tests/test_transcription.py
index 4b5ab98..bc5267c 100644
--- a/tests/test_transcription.py
+++ b/tests/test_transcription.py
@@ -1,5 +1,6 @@
from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
+from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR
from test_config import *
@@ -12,9 +13,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
- ("whisper", False, False, False),
- ("faster-whisper", False, False, False),
- ("insanely_fast_whisper", False, False, False)
+ (WhisperImpl.WHISPER.value, False, False, False),
+ (WhisperImpl.FASTER_WHISPER.value, False, False, False),
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
]
)
def test_transcribe(
@@ -28,6 +29,10 @@ def test_transcribe(
if not os.path.exists(audio_path):
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
+ answer = TEST_ANSWER
+ if diarization:
+        answer = "SPEAKER_00|" + TEST_ANSWER
+
whisper_inferencer = WhisperFactory.create_whisper_inference(
whisper_type=whisper_type,
)
@@ -37,16 +42,24 @@ def test_transcribe(
f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
)
- hparams = WhisperValues(
- model_size=TEST_WHISPER_MODEL,
- vad_filter=vad_filter,
- is_bgm_separate=bgm_separation,
- compute_type=whisper_inferencer.current_compute_type,
- uvr_enable_offload=True,
- is_diarize=diarization,
- ).as_list()
+ hparams = TranscriptionPipelineParams(
+ whisper=WhisperParams(
+ model_size=TEST_WHISPER_MODEL,
+ compute_type=whisper_inferencer.current_compute_type
+ ),
+ vad=VadParams(
+ vad_filter=vad_filter
+ ),
+ bgm_separation=BGMSeparationParams(
+ is_separate_bgm=bgm_separation,
+ enable_offload=True
+ ),
+ diarization=DiarizationParams(
+ is_diarize=diarization
+ ),
+ ).to_list()
- subtitle_str, file_path = whisper_inferencer.transcribe_file(
+ subtitle_str, file_paths = whisper_inferencer.transcribe_file(
[audio_path],
None,
"SRT",
@@ -54,29 +67,29 @@ def test_transcribe(
gr.Progress(),
*hparams,
)
+ subtitle = read_file(file_paths[0]).split("\n")
+ assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
- assert isinstance(subtitle_str, str) and subtitle_str
- assert isinstance(file_path[0], str) and file_path
+ if not is_pytube_detected_bot():
+ subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
+ TEST_YOUTUBE_URL,
+ "SRT",
+ False,
+ gr.Progress(),
+ *hparams,
+ )
+ assert isinstance(subtitle_str, str) and subtitle_str
+ assert os.path.exists(file_path)
- whisper_inferencer.transcribe_youtube(
- TEST_YOUTUBE_URL,
- "SRT",
- False,
- gr.Progress(),
- *hparams,
- )
- assert isinstance(subtitle_str, str) and subtitle_str
- assert isinstance(file_path[0], str) and file_path
-
- whisper_inferencer.transcribe_mic(
+ subtitle_str, file_path = whisper_inferencer.transcribe_mic(
audio_path,
"SRT",
False,
gr.Progress(),
*hparams,
)
- assert isinstance(subtitle_str, str) and subtitle_str
- assert isinstance(file_path[0], str) and file_path
+ subtitle = read_file(file_path).split("\n")
+ assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
def download_file(url, save_dir):
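
The updated assertions read the generated SRT back and compare its first cue against `TEST_ANSWER` by word error rate instead of only checking return types. A standalone sketch of that pattern, using a hand-written SRT string in place of the pipeline output at `file_paths[0]`:

```python
import jiwer

# Hand-written stand-in for the SRT file the pipeline would produce.
srt_text = """1
00:00:00,000 --> 00:00:10,000
And so, my fellow Americans, ask not what your country can do for you; ask what you can do for your country.
"""

answer = ("And so my fellow Americans ask not what your country can do for you "
          "ask what you can do for your country")

# Line 0 is the cue counter, line 1 the timestamps, line 2 the first text line;
# commas and periods are stripped before scoring, as in the test above.
prediction = srt_text.split("\n")[2].strip().replace(",", "").replace(".", "")
assert jiwer.wer(answer, prediction) < 0.1
```
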
diff --git a/tests/test_vad.py b/tests/test_vad.py
index 124a043..cb3dc05 100644
--- a/tests/test_vad.py
+++ b/tests/test_vad.py
@@ -1,6 +1,6 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
from test_config import *
from test_transcription import download_file, test_transcribe
@@ -12,9 +12,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
- ("whisper", True, False, False),
- ("faster-whisper", True, False, False),
- ("insanely_fast_whisper", True, False, False)
+ (WhisperImpl.WHISPER.value, True, False, False),
+ (WhisperImpl.FASTER_WHISPER.value, True, False, False),
+ (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
]
)
def test_vad_pipeline(