Merge branch 'master' into salfter
This commit is contained in:
2
.github/workflows/ci.yml
vendored
2
.github/workflows/ci.yml
vendored
@@ -37,7 +37,7 @@ jobs:
|
||||
run: sudo apt-get update && sudo apt-get install -y git ffmpeg
|
||||
|
||||
- name: Install dependencies
|
||||
run: pip install -r requirements.txt pytest
|
||||
run: pip install -r requirements.txt pytest jiwer
|
||||
|
||||
- name: Run test
|
||||
run: python -m pytest -rs tests
|
||||
@@ -128,3 +128,6 @@ This is Whisper's original VRAM usage table for models.
|
||||
- [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
|
||||
- [ ] Add fast api script
|
||||
- [ ] Support real-time transcription for microphone
|
||||
|
||||
### Translation 🌐
|
||||
Any PRs translating Japanese, Spanish, French, German, Chinese, or any other language into [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!
|
||||
|
||||
183
app.py
183
app.py
@@ -7,17 +7,14 @@ import yaml
|
||||
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
|
||||
INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
|
||||
UVR_MODELS_DIR, I18N_YAML_PATH)
|
||||
from modules.utils.constants import AUTOMATIC_DETECTION
|
||||
from modules.utils.files_manager import load_yaml
|
||||
from modules.whisper.whisper_factory import WhisperFactory
|
||||
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
||||
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
||||
from modules.translation.nllb_inference import NLLBInference
|
||||
from modules.ui.htmls import *
|
||||
from modules.utils.cli_manager import str2bool
|
||||
from modules.utils.youtube_manager import get_ytmetas
|
||||
from modules.translation.deepl_api import DeepLAPI
|
||||
from modules.whisper.whisper_parameter import *
|
||||
from modules.whisper.data_classes import *
|
||||
|
||||
|
||||
class App:
|
||||
@@ -44,7 +41,7 @@ class App:
|
||||
print(f"Use \"{self.args.whisper_type}\" implementation\n"
|
||||
f"Device \"{self.whisper_inf.device}\" is detected")
|
||||
|
||||
def create_whisper_parameters(self):
|
||||
def create_pipeline_inputs(self):
|
||||
whisper_params = self.default_params["whisper"]
|
||||
vad_params = self.default_params["vad"]
|
||||
diarization_params = self.default_params["diarization"]
|
||||
@@ -56,7 +53,7 @@ class App:
|
||||
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
|
||||
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
|
||||
else whisper_params["lang"], label=_("Language"))
|
||||
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
|
||||
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
|
||||
with gr.Row():
|
||||
cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
|
||||
interactive=True)
|
||||
@@ -66,158 +63,31 @@ class App:
|
||||
interactive=True)
|
||||
|
||||
with gr.Accordion(_("Advanced Parameters"), open=False):
|
||||
nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0,
|
||||
interactive=True,
|
||||
info="Beam size to use for decoding.")
|
||||
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold",
|
||||
value=whisper_params["log_prob_threshold"], interactive=True,
|
||||
info="If the average log probability over sampled tokens is below this value, treat as failed.")
|
||||
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"],
|
||||
interactive=True,
|
||||
info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
|
||||
dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
|
||||
value=self.whisper_inf.current_compute_type, interactive=True,
|
||||
allow_custom_value=True,
|
||||
info="Select the type of computation to perform.")
|
||||
nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
|
||||
info="Number of candidates when sampling with non-zero temperature.")
|
||||
nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
|
||||
info="Beam search patience factor.")
|
||||
cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text",
|
||||
value=whisper_params["condition_on_previous_text"],
|
||||
interactive=True,
|
||||
info="Condition on previous text during decoding.")
|
||||
sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature",
|
||||
value=whisper_params["prompt_reset_on_temperature"],
|
||||
minimum=0, maximum=1, step=0.01, interactive=True,
|
||||
info="Resets prompt if temperature is above this value."
|
||||
" Arg has effect only if 'Condition On Previous Text' is True.")
|
||||
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
|
||||
info="Initial prompt to use for decoding.")
|
||||
sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
|
||||
step=0.01, maximum=1.0, interactive=True,
|
||||
info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
|
||||
nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold",
|
||||
value=whisper_params["compression_ratio_threshold"],
|
||||
interactive=True,
|
||||
info="If the gzip compression ratio is above this value, treat as failed.")
|
||||
nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
|
||||
precision=0,
|
||||
info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
|
||||
with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
|
||||
nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
|
||||
info="Exponential length penalty constant.")
|
||||
nb_repetition_penalty = gr.Number(label="Repetition Penalty",
|
||||
value=whisper_params["repetition_penalty"],
|
||||
info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
|
||||
nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size",
|
||||
value=whisper_params["no_repeat_ngram_size"],
|
||||
precision=0,
|
||||
info="Prevent repetitions of n-grams with this size (set 0 to disable).")
|
||||
tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
|
||||
info="Optional text to provide as a prefix for the first window.")
|
||||
cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
|
||||
info="Suppress blank outputs at the beginning of the sampling.")
|
||||
tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
|
||||
info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
|
||||
nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp",
|
||||
value=whisper_params["max_initial_timestamp"],
|
||||
info="The initial timestamp cannot be later than this.")
|
||||
cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
|
||||
info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
|
||||
tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations",
|
||||
value=whisper_params["prepend_punctuations"],
|
||||
info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
|
||||
tb_append_punctuations = gr.Textbox(label="Append Punctuations",
|
||||
value=whisper_params["append_punctuations"],
|
||||
info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
|
||||
nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
|
||||
precision=0,
|
||||
info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
|
||||
nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
|
||||
value=lambda: whisper_params[
|
||||
"hallucination_silence_threshold"],
|
||||
info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
|
||||
tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
|
||||
info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
|
||||
nb_language_detection_threshold = gr.Number(label="Language Detection Threshold",
|
||||
value=lambda: whisper_params[
|
||||
"language_detection_threshold"],
|
||||
info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
|
||||
nb_language_detection_segments = gr.Number(label="Language Detection Segments",
|
||||
value=lambda: whisper_params["language_detection_segments"],
|
||||
precision=0,
|
||||
info="Number of segments to consider for the language detection.")
|
||||
with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
|
||||
nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
|
||||
whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
|
||||
whisper_type=self.args.whisper_type,
|
||||
available_compute_types=self.whisper_inf.available_compute_types,
|
||||
compute_type=self.whisper_inf.current_compute_type)
|
||||
|
||||
with gr.Accordion(_("Background Music Remover Filter"), open=False):
|
||||
cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"),
|
||||
value=uvr_params["is_separate_bgm"],
|
||||
interactive=True,
|
||||
info=_("Enabling this will remove background music"))
|
||||
dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
|
||||
choices=self.whisper_inf.music_separator.available_devices)
|
||||
dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"],
|
||||
choices=self.whisper_inf.music_separator.available_models)
|
||||
nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
|
||||
cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"])
|
||||
cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"),
|
||||
value=uvr_params["enable_offload"])
|
||||
uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
|
||||
available_models=self.whisper_inf.music_separator.available_models,
|
||||
available_devices=self.whisper_inf.music_separator.available_devices,
|
||||
device=self.whisper_inf.music_separator.device)
|
||||
|
||||
with gr.Accordion(_("Voice Detection Filter"), open=False):
|
||||
cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"],
|
||||
interactive=True,
|
||||
info=_("Enable this to transcribe only detected voice"))
|
||||
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
|
||||
value=vad_params["threshold"],
|
||||
info="Lower it to be more sensitive to small sounds.")
|
||||
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
|
||||
value=vad_params["min_speech_duration_ms"],
|
||||
info="Final speech chunks shorter than this time are thrown out")
|
||||
nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
|
||||
value=vad_params["max_speech_duration_s"],
|
||||
info="Maximum duration of speech chunks in \"seconds\".")
|
||||
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
|
||||
value=vad_params["min_silence_duration_ms"],
|
||||
info="In the end of each speech chunk wait for this time"
|
||||
" before separating it")
|
||||
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
|
||||
info="Final speech chunks are padded by this time each side")
|
||||
vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)
|
||||
|
||||
with gr.Accordion(_("Diarization"), open=False):
|
||||
cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"])
|
||||
tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"],
|
||||
info=_("This is only needed the first time you download the model"))
|
||||
dd_diarization_device = gr.Dropdown(label=_("Device"),
|
||||
choices=self.whisper_inf.diarizer.get_available_device(),
|
||||
value=self.whisper_inf.diarizer.get_device())
|
||||
diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
|
||||
available_devices=self.whisper_inf.diarizer.available_device,
|
||||
device=self.whisper_inf.diarizer.device)
|
||||
|
||||
dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
|
||||
|
||||
pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
|
||||
|
||||
return (
|
||||
WhisperParameters(
|
||||
model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
|
||||
log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
|
||||
compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
|
||||
condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
|
||||
temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
|
||||
vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
|
||||
max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
|
||||
speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
|
||||
is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
|
||||
length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
|
||||
no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
|
||||
suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
|
||||
word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
|
||||
append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
|
||||
hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
|
||||
language_detection_threshold=nb_language_detection_threshold,
|
||||
language_detection_segments=nb_language_detection_segments,
|
||||
prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
|
||||
uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
|
||||
uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
|
||||
),
|
||||
pipeline_inputs,
|
||||
dd_file_format,
|
||||
cb_timestamp
|
||||
)
|
||||
@@ -243,7 +113,7 @@ class App:
|
||||
visible=self.args.colab,
|
||||
value="")
|
||||
|
||||
whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
|
||||
pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
|
||||
|
||||
with gr.Row():
|
||||
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
|
||||
@@ -254,7 +124,7 @@ class App:
|
||||
|
||||
params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
|
||||
btn_run.click(fn=self.whisper_inf.transcribe_file,
|
||||
inputs=params + whisper_params.as_list(),
|
||||
inputs=params + pipeline_params,
|
||||
outputs=[tb_indicator, files_subtitles])
|
||||
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
|
||||
|
||||
@@ -268,7 +138,7 @@ class App:
|
||||
tb_title = gr.Label(label=_("Youtube Title"))
|
||||
tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
|
||||
|
||||
whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
|
||||
pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
|
||||
|
||||
with gr.Row():
|
||||
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
|
||||
@@ -280,7 +150,7 @@ class App:
|
||||
params = [tb_youtubelink, dd_file_format, cb_timestamp]
|
||||
|
||||
btn_run.click(fn=self.whisper_inf.transcribe_youtube,
|
||||
inputs=params + whisper_params.as_list(),
|
||||
inputs=params + pipeline_params,
|
||||
outputs=[tb_indicator, files_subtitles])
|
||||
tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
|
||||
outputs=[img_thumbnail, tb_title, tb_description])
|
||||
@@ -290,7 +160,7 @@ class App:
|
||||
with gr.Row():
|
||||
mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)
|
||||
|
||||
whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
|
||||
pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
|
||||
|
||||
with gr.Row():
|
||||
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
|
||||
@@ -302,7 +172,7 @@ class App:
|
||||
params = [mic_input, dd_file_format, cb_timestamp]
|
||||
|
||||
btn_run.click(fn=self.whisper_inf.transcribe_mic,
|
||||
inputs=params + whisper_params.as_list(),
|
||||
inputs=params + pipeline_params,
|
||||
outputs=[tb_indicator, files_subtitles])
|
||||
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
|
||||
|
||||
@@ -417,7 +287,6 @@ class App:
|
||||
|
||||
# Launch the app with optional gradio settings
|
||||
args = self.args
|
||||
|
||||
self.app.queue(
|
||||
api_open=args.api_open
|
||||
).launch(
|
||||
@@ -447,8 +316,8 @@ class App:
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--whisper_type', type=str, default="faster-whisper",
|
||||
choices=["whisper", "faster-whisper", "insanely-fast-whisper"],
|
||||
parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
|
||||
choices=[item.value for item in WhisperImpl],
|
||||
help='A type of the whisper implementation (Github repo name)')
|
||||
parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
|
||||
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
whisper:
|
||||
model_size: "medium.en"
|
||||
file_format: "SRT"
|
||||
lang: "english"
|
||||
is_translate: false
|
||||
beam_size: 5
|
||||
|
||||
@@ -411,3 +411,49 @@ ru: # Russian
|
||||
Instrumental: Инструментал
|
||||
Vocals: Вокал
|
||||
SEPARATE BACKGROUND MUSIC: РАЗДЕЛИТЬ ФОНОВУЮ МУЗЫКУ
|
||||
|
||||
tr: # Turkish
|
||||
Language: Dil
|
||||
File: Dosya
|
||||
Youtube: Youtube
|
||||
Mic: Mikrofon
|
||||
T2T Translation: T2T Çeviri
|
||||
BGM Separation: Arka Plan Müziği Ayırma
|
||||
GENERATE SUBTITLE FILE: ALTYAZI DOSYASI OLUŞTUR
|
||||
Output: Çıktı
|
||||
Downloadable output file: İndirilebilir çıktı dosyası
|
||||
Upload File here: Dosya Yükle
|
||||
Model: Model
|
||||
Automatic Detection: Otomatik Algılama
|
||||
File Format: Dosya Formatı
|
||||
Translate to English?: İngilizceye Çevir?
|
||||
Add a timestamp to the end of the filename: Dosya adının sonuna zaman damgası ekle
|
||||
Advanced Parameters: Gelişmiş Parametreler
|
||||
Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresi
|
||||
Enabling this will remove background music: Bunu etkinleştirmek, arka plan müziğini alt model tarafından transkripsiyondan önce kaldıracaktır
|
||||
Enable Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresini Etkinleştir
|
||||
Save separated files to output: Ayrılmış dosyaları çıktıya kaydet
|
||||
Offload sub model after removing background music: Arka plan müziği kaldırıldıktan sonra alt modeli devre dışı bırak
|
||||
Voice Detection Filter: Ses Algılama Filtresi
|
||||
Enable this to transcribe only detected voice: Bunu etkinleştirerek yalnızca alt model tarafından algılanan ses kısımlarını transkribe et
|
||||
Enable Silero VAD Filter: Silero VAD Filtresini Etkinleştir
|
||||
Diarization: Konuşmacı Ayrımı
|
||||
Enable Diarization: Konuşmacı Ayrımını Etkinleştir
|
||||
HuggingFace Token: HuggingFace Anahtarı
|
||||
This is only needed the first time you download the model: Bu, modeli ilk kez indirirken gereklidir. Zaten modelleriniz varsa girmenize gerek yok. Modeli indirmek için "https://huggingface.co/pyannote/speaker-diarization-3.1" ve "https://huggingface.co/pyannote/segmentation-3.0" adreslerine gidip gereksinimlerini kabul etmeniz gerekiyor
|
||||
Device: Cihaz
|
||||
Youtube Link: Youtube Bağlantısı
|
||||
Youtube Thumbnail: Youtube Küçük Resmi
|
||||
Youtube Title: Youtube Başlığı
|
||||
Youtube Description: Youtube Açıklaması
|
||||
Record with Mic: Mikrofonla Kaydet
|
||||
Upload Subtitle Files to translate here: Çeviri için altyazı dosyalarını buraya yükle
|
||||
Your Auth Key (API KEY): Yetki Anahtarınız (API ANAHTARI)
|
||||
Source Language: Kaynak Dil
|
||||
Target Language: Hedef Dil
|
||||
Pro User?: Pro Kullanıcı?
|
||||
TRANSLATE SUBTITLE FILE: ALTYAZI DOSYASINI ÇEVİR
|
||||
Upload Audio Files to separate background music: Arka plan müziğini ayırmak için ses dosyalarını yükle
|
||||
Instrumental: Enstrümantal
|
||||
Vocals: Vokal
|
||||
SEPARATE BACKGROUND MUSIC: ARKA PLAN MÜZİĞİNİ AYIR
|
||||
|
||||
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
|
||||
from typing import Optional, Union
|
||||
import torch
|
||||
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.utils.paths import DIARIZATION_MODELS_DIR
|
||||
from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
|
||||
|
||||
@@ -43,6 +44,8 @@ class DiarizationPipeline:
|
||||
|
||||
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
transcript_segments = transcript_result["segments"]
|
||||
if transcript_segments and isinstance(transcript_segments[0], Segment):
|
||||
transcript_segments = [seg.model_dump() for seg in transcript_segments]
|
||||
for seg in transcript_segments:
|
||||
# assign speaker to segment (if any)
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
|
||||
@@ -63,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
seg["speaker"] = speaker
|
||||
|
||||
# assign speaker to words
|
||||
if 'words' in seg:
|
||||
if 'words' in seg and seg['words'] is not None:
|
||||
for word in seg['words']:
|
||||
if 'start' in word:
|
||||
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
|
||||
@@ -85,10 +88,10 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
|
||||
if word_speaker is not None:
|
||||
word["speaker"] = word_speaker
|
||||
|
||||
return transcript_result
|
||||
return {"segments": transcript_segments}
|
||||
|
||||
|
||||
class Segment:
|
||||
class DiarizationSegment:
|
||||
def __init__(self, start, end, speaker=None):
|
||||
self.start = start
|
||||
self.end = end
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
import torch
|
||||
from typing import List, Union, BinaryIO, Optional
|
||||
from typing import List, Union, BinaryIO, Optional, Tuple
|
||||
import numpy as np
|
||||
import time
|
||||
import logging
|
||||
@@ -8,6 +8,7 @@ import logging
|
||||
from modules.utils.paths import DIARIZATION_MODELS_DIR
|
||||
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
|
||||
from modules.diarize.audio_loader import load_audio
|
||||
from modules.whisper.data_classes import *
|
||||
|
||||
|
||||
class Diarizer:
|
||||
@@ -23,10 +24,10 @@ class Diarizer:
|
||||
|
||||
def run(self,
|
||||
audio: Union[str, BinaryIO, np.ndarray],
|
||||
transcribed_result: List[dict],
|
||||
transcribed_result: List[Segment],
|
||||
use_auth_token: str,
|
||||
device: Optional[str] = None
|
||||
):
|
||||
) -> Tuple[List[Segment], float]:
|
||||
"""
|
||||
Diarize transcribed result as a post-processing
|
||||
|
||||
@@ -34,7 +35,7 @@ class Diarizer:
|
||||
----------
|
||||
audio: Union[str, BinaryIO, np.ndarray]
|
||||
Audio input. This can be file path or binary type.
|
||||
transcribed_result: List[dict]
|
||||
transcribed_result: List[Segment]
|
||||
transcribed result through whisper.
|
||||
use_auth_token: str
|
||||
Huggingface token with READ permission. This is only needed the first time you download the model.
|
||||
@@ -44,8 +45,8 @@ class Diarizer:
|
||||
|
||||
Returns
|
||||
----------
|
||||
segments_result: List[dict]
|
||||
list of dicts that includes start, end timestamps and transcribed text
|
||||
segments_result: List[Segment]
|
||||
list of Segment that includes start, end timestamps and transcribed text
|
||||
elapsed_time: float
|
||||
elapsed time for running
|
||||
"""
|
||||
@@ -68,14 +69,20 @@ class Diarizer:
|
||||
{"segments": transcribed_result}
|
||||
)
|
||||
|
||||
segments_result = []
|
||||
for segment in diarized_result["segments"]:
|
||||
speaker = "None"
|
||||
if "speaker" in segment:
|
||||
speaker = segment["speaker"]
|
||||
segment["text"] = speaker + "|" + segment["text"].strip()
|
||||
diarized_text = speaker + "|" + segment["text"].strip()
|
||||
segments_result.append(Segment(
|
||||
start=segment["start"],
|
||||
end=segment["end"],
|
||||
text=diarized_text
|
||||
))
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
return diarized_result["segments"], elapsed_time
|
||||
return segments_result, elapsed_time
|
||||
|
||||
def update_pipe(self,
|
||||
use_auth_token: str,
|
||||
|
||||
@@ -139,37 +139,27 @@ class DeepLAPI:
|
||||
)
|
||||
|
||||
files_info = {}
|
||||
for fileobj in fileobjs:
|
||||
file_path = fileobj
|
||||
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
|
||||
|
||||
if file_ext == ".srt":
|
||||
parsed_dicts = parse_srt(file_path=file_path)
|
||||
|
||||
elif file_ext == ".vtt":
|
||||
parsed_dicts = parse_vtt(file_path=file_path)
|
||||
for file_path in fileobjs:
|
||||
file_name, file_ext = os.path.splitext(os.path.basename(file_path))
|
||||
writer = get_writer(file_ext, self.output_dir)
|
||||
segments = writer.to_segments(file_path)
|
||||
|
||||
batch_size = self.max_text_batch_size
|
||||
for batch_start in range(0, len(parsed_dicts), batch_size):
|
||||
batch_end = min(batch_start + batch_size, len(parsed_dicts))
|
||||
sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
|
||||
for batch_start in range(0, len(segments), batch_size):
|
||||
progress(batch_start / len(segments), desc="Translating..")
|
||||
sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
|
||||
translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
|
||||
target_lang, is_pro)
|
||||
for i, translated_text in enumerate(translated_texts):
|
||||
parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
|
||||
progress(batch_end / len(parsed_dicts), desc="Translating..")
|
||||
segments[batch_start + i].text = translated_text["text"]
|
||||
|
||||
if file_ext == ".srt":
|
||||
subtitle = get_serialized_srt(parsed_dicts)
|
||||
elif file_ext == ".vtt":
|
||||
subtitle = get_serialized_vtt(parsed_dicts)
|
||||
|
||||
if add_timestamp:
|
||||
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
||||
file_name += f"-{timestamp}"
|
||||
|
||||
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
|
||||
write_file(subtitle, output_path)
|
||||
subtitle, output_path = generate_file(
|
||||
output_dir=self.output_dir,
|
||||
output_file_name=file_name,
|
||||
output_format=file_ext,
|
||||
result=segments,
|
||||
add_timestamp=add_timestamp
|
||||
)
|
||||
|
||||
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from typing import List
|
||||
from datetime import datetime
|
||||
|
||||
import modules.translation.nllb_inference as nllb
|
||||
from modules.whisper.whisper_parameter import *
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.utils.subtitle_manager import *
|
||||
from modules.utils.files_manager import load_yaml, save_yaml
|
||||
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
|
||||
@@ -95,32 +95,22 @@ class TranslationBase(ABC):
|
||||
files_info = {}
|
||||
for fileobj in fileobjs:
|
||||
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
|
||||
if file_ext == ".srt":
|
||||
parsed_dicts = parse_srt(file_path=fileobj)
|
||||
total_progress = len(parsed_dicts)
|
||||
for index, dic in enumerate(parsed_dicts):
|
||||
progress(index / total_progress, desc="Translating..")
|
||||
translated_text = self.translate(dic["sentence"], max_length=max_length)
|
||||
dic["sentence"] = translated_text
|
||||
subtitle = get_serialized_srt(parsed_dicts)
|
||||
writer = get_writer(file_ext, self.output_dir)
|
||||
segments = writer.to_segments(fileobj)
|
||||
for i, segment in enumerate(segments):
|
||||
progress(i / len(segments), desc="Translating..")
|
||||
translated_text = self.translate(segment.text, max_length=max_length)
|
||||
segment.text = translated_text
|
||||
|
||||
elif file_ext == ".vtt":
|
||||
parsed_dicts = parse_vtt(file_path=fileobj)
|
||||
total_progress = len(parsed_dicts)
|
||||
for index, dic in enumerate(parsed_dicts):
|
||||
progress(index / total_progress, desc="Translating..")
|
||||
translated_text = self.translate(dic["sentence"], max_length=max_length)
|
||||
dic["sentence"] = translated_text
|
||||
subtitle = get_serialized_vtt(parsed_dicts)
|
||||
subtitle, file_path = generate_file(
|
||||
output_dir=self.output_dir,
|
||||
output_file_name=file_name,
|
||||
output_format=file_ext,
|
||||
result=segments,
|
||||
add_timestamp=add_timestamp
|
||||
)
|
||||
|
||||
if add_timestamp:
|
||||
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
||||
file_name += f"-{timestamp}"
|
||||
|
||||
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
|
||||
write_file(subtitle, output_path)
|
||||
|
||||
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
|
||||
files_info[file_name] = {"subtitle": subtitle, "path": file_path}
|
||||
|
||||
total_result = ''
|
||||
for file_name, info in files_info.items():
|
||||
@@ -133,7 +123,8 @@ class TranslationBase(ABC):
|
||||
return [gr_str, output_file_paths]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
print(f"Error translating file: {e}")
|
||||
raise
|
||||
finally:
|
||||
self.release_cuda_memory()
|
||||
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
from gradio_i18n import Translate, gettext as _
|
||||
|
||||
AUTOMATIC_DETECTION = _("Automatic Detection")
|
||||
GRADIO_NONE_STR = ""
|
||||
GRADIO_NONE_NUMBER_MAX = 9999
|
||||
GRADIO_NONE_NUMBER_MIN = 0
|
||||
|
||||
@@ -67,3 +67,9 @@ def is_video(file_path):
|
||||
video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp']
|
||||
extension = os.path.splitext(file_path)[1].lower()
|
||||
return extension in video_extensions
|
||||
|
||||
|
||||
def read_file(file_path):
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
subtitle_content = f.read()
|
||||
return subtitle_content
|
||||
|
||||
@@ -1,121 +1,425 @@
|
||||
# Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py
|
||||
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import zlib
|
||||
from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
|
||||
from datetime import datetime
|
||||
|
||||
from modules.whisper.data_classes import Segment, Word
|
||||
from .files_manager import read_file
|
||||
|
||||
|
||||
def timeformat_srt(time):
|
||||
hours = time // 3600
|
||||
minutes = (time - hours * 3600) // 60
|
||||
seconds = time - hours * 3600 - minutes * 60
|
||||
milliseconds = (time - int(time)) * 1000
|
||||
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
|
||||
def format_timestamp(
|
||||
seconds: float, always_include_hours: bool = True, decimal_marker: str = ","
|
||||
) -> str:
|
||||
assert seconds >= 0, "non-negative timestamp expected"
|
||||
milliseconds = round(seconds * 1000.0)
|
||||
|
||||
hours = milliseconds // 3_600_000
|
||||
milliseconds -= hours * 3_600_000
|
||||
|
||||
minutes = milliseconds // 60_000
|
||||
milliseconds -= minutes * 60_000
|
||||
|
||||
seconds = milliseconds // 1_000
|
||||
milliseconds -= seconds * 1_000
|
||||
|
||||
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||
return (
|
||||
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||
)
|
||||
|
||||
|
||||
def timeformat_vtt(time):
|
||||
hours = time // 3600
|
||||
minutes = (time - hours * 3600) // 60
|
||||
seconds = time - hours * 3600 - minutes * 60
|
||||
milliseconds = (time - int(time)) * 1000
|
||||
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
|
||||
def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float:
|
||||
times = time_str.split(":")
|
||||
|
||||
if len(times) == 3:
|
||||
hours, minutes, rest = times
|
||||
hours = int(hours)
|
||||
else:
|
||||
hours = 0
|
||||
minutes, rest = times
|
||||
|
||||
seconds, fractional = rest.split(decimal_marker)
|
||||
|
||||
minutes = int(minutes)
|
||||
seconds = int(seconds)
|
||||
fractional_seconds = float("0." + fractional)
|
||||
|
||||
return hours * 3600 + minutes * 60 + seconds + fractional_seconds
|
||||
|
||||
|
||||
def write_file(subtitle, output_file):
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
f.write(subtitle)
|
||||
def get_start(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["start"] for s in segments for w in s["words"]),
|
||||
segments[0]["start"] if segments else None,
|
||||
)
|
||||
|
||||
|
||||
def get_srt(segments):
|
||||
output = ""
|
||||
for i, segment in enumerate(segments):
|
||||
output += f"{i + 1}\n"
|
||||
output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
|
||||
if segment['text'].startswith(' '):
|
||||
segment['text'] = segment['text'][1:]
|
||||
output += f"{segment['text']}\n\n"
|
||||
return output
|
||||
def get_end(segments: List[dict]) -> Optional[float]:
|
||||
return next(
|
||||
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||
segments[-1]["end"] if segments else None,
|
||||
)
|
||||
|
||||
|
||||
def get_vtt(segments):
|
||||
output = "WebVTT\n\n"
|
||||
for i, segment in enumerate(segments):
|
||||
output += f"{i + 1}\n"
|
||||
output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
|
||||
if segment['text'].startswith(' '):
|
||||
segment['text'] = segment['text'][1:]
|
||||
output += f"{segment['text']}\n\n"
|
||||
return output
|
||||
class ResultWriter:
|
||||
extension: str
|
||||
|
||||
def __init__(self, output_dir: str):
|
||||
self.output_dir = output_dir
|
||||
|
||||
def __call__(
|
||||
self, result: Union[dict, List[Segment]], output_file_name: str,
|
||||
options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
if isinstance(result, List) and result and isinstance(result[0], Segment):
|
||||
result = {"segments": [seg.model_dump() for seg in result]}
|
||||
|
||||
output_path = os.path.join(
|
||||
self.output_dir, output_file_name + "." + self.extension
|
||||
)
|
||||
|
||||
with open(output_path, "w", encoding="utf-8") as f:
|
||||
self.write_result(result, file=f, options=options, **kwargs)
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def get_txt(segments):
|
||||
output = ""
|
||||
for i, segment in enumerate(segments):
|
||||
if segment['text'].startswith(' '):
|
||||
segment['text'] = segment['text'][1:]
|
||||
output += f"{segment['text']}\n"
|
||||
return output
|
||||
class WriteTXT(ResultWriter):
|
||||
extension: str = "txt"
|
||||
|
||||
def write_result(
|
||||
self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for segment in result["segments"]:
|
||||
print(segment["text"].strip(), file=file, flush=True)
|
||||
|
||||
|
||||
def parse_srt(file_path):
|
||||
"""Reads SRT file and returns as dict"""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
srt_data = file.read()
|
||||
class SubtitlesWriter(ResultWriter):
|
||||
always_include_hours: bool
|
||||
decimal_marker: str
|
||||
|
||||
data = []
|
||||
blocks = srt_data.split('\n\n')
|
||||
def iterate_result(
|
||||
self,
|
||||
result: dict,
|
||||
options: Optional[dict] = None,
|
||||
*,
|
||||
max_line_width: Optional[int] = None,
|
||||
max_line_count: Optional[int] = None,
|
||||
highlight_words: bool = False,
|
||||
align_lrc_words: bool = False,
|
||||
max_words_per_line: Optional[int] = None,
|
||||
):
|
||||
options = options or {}
|
||||
max_line_width = max_line_width or options.get("max_line_width")
|
||||
max_line_count = max_line_count or options.get("max_line_count")
|
||||
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||
align_lrc_words = align_lrc_words or options.get("align_lrc_words", False)
|
||||
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||
preserve_segments = max_line_count is None or max_line_width is None
|
||||
max_line_width = max_line_width or 1000
|
||||
max_words_per_line = max_words_per_line or 1000
|
||||
|
||||
for block in blocks:
|
||||
if block.strip() != '':
|
||||
lines = block.strip().split('\n')
|
||||
index = lines[0]
|
||||
timestamp = lines[1]
|
||||
sentence = ' '.join(lines[2:])
|
||||
def iterate_subtitles():
|
||||
line_len = 0
|
||||
line_count = 1
|
||||
# the next subtitle to yield (a list of word timings with whitespace)
|
||||
subtitle: List[dict] = []
|
||||
last: float = get_start(result["segments"]) or 0.0
|
||||
for segment in result["segments"]:
|
||||
chunk_index = 0
|
||||
words_count = max_words_per_line
|
||||
while chunk_index < len(segment["words"]):
|
||||
remaining_words = len(segment["words"]) - chunk_index
|
||||
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||
words_count = remaining_words
|
||||
for i, original_timing in enumerate(
|
||||
segment["words"][chunk_index : chunk_index + words_count]
|
||||
):
|
||||
timing = original_timing.copy()
|
||||
long_pause = (
|
||||
not preserve_segments and timing["start"] - last > 3.0
|
||||
)
|
||||
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||
if (
|
||||
line_len > 0
|
||||
and has_room
|
||||
and not long_pause
|
||||
and not seg_break
|
||||
):
|
||||
# line continuation
|
||||
line_len += len(timing["word"])
|
||||
else:
|
||||
# new line
|
||||
timing["word"] = timing["word"].strip()
|
||||
if (
|
||||
len(subtitle) > 0
|
||||
and max_line_count is not None
|
||||
and (long_pause or line_count >= max_line_count)
|
||||
or seg_break
|
||||
):
|
||||
# subtitle break
|
||||
yield subtitle
|
||||
subtitle = []
|
||||
line_count = 1
|
||||
elif line_len > 0:
|
||||
# line break
|
||||
line_count += 1
|
||||
timing["word"] = "\n" + timing["word"]
|
||||
line_len = len(timing["word"].strip())
|
||||
subtitle.append(timing)
|
||||
last = timing["start"]
|
||||
chunk_index += max_words_per_line
|
||||
if len(subtitle) > 0:
|
||||
yield subtitle
|
||||
|
||||
data.append({
|
||||
"index": index,
|
||||
"timestamp": timestamp,
|
||||
"sentence": sentence
|
||||
})
|
||||
return data
|
||||
if len(result["segments"]) > 0 and "words" in result["segments"][0] and result["segments"][0]["words"]:
|
||||
for subtitle in iterate_subtitles():
|
||||
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||
if highlight_words:
|
||||
last = subtitle_start
|
||||
all_words = [timing["word"] for timing in subtitle]
|
||||
for i, this_word in enumerate(subtitle):
|
||||
start = self.format_timestamp(this_word["start"])
|
||||
end = self.format_timestamp(this_word["end"])
|
||||
if last != start:
|
||||
yield last, start, subtitle_text
|
||||
|
||||
yield start, end, "".join(
|
||||
[
|
||||
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||
if j == i
|
||||
else word
|
||||
for j, word in enumerate(all_words)
|
||||
]
|
||||
)
|
||||
last = end
|
||||
|
||||
if align_lrc_words:
|
||||
lrc_aligned_words = [f"[{self.format_timestamp(sub['start'])}]{sub['word']}" for sub in subtitle]
|
||||
l_start, l_end = self.format_timestamp(subtitle[-1]['start']), self.format_timestamp(subtitle[-1]['end'])
|
||||
lrc_aligned_words[-1] = f"[{l_start}]{subtitle[-1]['word']}[{l_end}]"
|
||||
lrc_aligned_words = ' '.join(lrc_aligned_words)
|
||||
yield None, None, lrc_aligned_words
|
||||
|
||||
else:
|
||||
yield subtitle_start, subtitle_end, subtitle_text
|
||||
else:
|
||||
for segment in result["segments"]:
|
||||
segment_start = self.format_timestamp(segment["start"])
|
||||
segment_end = self.format_timestamp(segment["end"])
|
||||
segment_text = segment["text"].strip().replace("-->", "->")
|
||||
yield segment_start, segment_end, segment_text
|
||||
|
||||
def format_timestamp(self, seconds: float):
|
||||
return format_timestamp(
|
||||
seconds=seconds,
|
||||
always_include_hours=self.always_include_hours,
|
||||
decimal_marker=self.decimal_marker,
|
||||
)
|
||||
|
||||
|
||||
def parse_vtt(file_path):
|
||||
"""Reads WebVTT file and returns as dict"""
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
webvtt_data = file.read()
|
||||
class WriteVTT(SubtitlesWriter):
|
||||
extension: str = "vtt"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
data = []
|
||||
blocks = webvtt_data.split('\n\n')
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
print("WEBVTT\n", file=file)
|
||||
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
for block in blocks:
|
||||
if block.strip() != '' and not block.strip().startswith("WebVTT"):
|
||||
lines = block.strip().split('\n')
|
||||
index = lines[0]
|
||||
timestamp = lines[1]
|
||||
sentence = ' '.join(lines[2:])
|
||||
def to_segments(self, file_path: str) -> List[Segment]:
|
||||
segments = []
|
||||
|
||||
data.append({
|
||||
"index": index,
|
||||
"timestamp": timestamp,
|
||||
"sentence": sentence
|
||||
})
|
||||
blocks = read_file(file_path).split('\n\n')
|
||||
|
||||
return data
|
||||
for block in blocks:
|
||||
if block.strip() != '' and not block.strip().startswith("WEBVTT"):
|
||||
lines = block.strip().split('\n')
|
||||
time_line = lines[0].split(" --> ")
|
||||
start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
|
||||
sentence = ' '.join(lines[1:])
|
||||
|
||||
segments.append(Segment(
|
||||
start=start,
|
||||
end=end,
|
||||
text=sentence
|
||||
))
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
def get_serialized_srt(dicts):
|
||||
output = ""
|
||||
for dic in dicts:
|
||||
output += f'{dic["index"]}\n'
|
||||
output += f'{dic["timestamp"]}\n'
|
||||
output += f'{dic["sentence"]}\n\n'
|
||||
return output
|
||||
class WriteSRT(SubtitlesWriter):
|
||||
extension: str = "srt"
|
||||
always_include_hours: bool = True
|
||||
decimal_marker: str = ","
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options, **kwargs), start=1
|
||||
):
|
||||
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||
|
||||
def to_segments(self, file_path: str) -> List[Segment]:
|
||||
segments = []
|
||||
|
||||
blocks = read_file(file_path).split('\n\n')
|
||||
|
||||
for block in blocks:
|
||||
if block.strip() != '':
|
||||
lines = block.strip().split('\n')
|
||||
index = lines[0]
|
||||
time_line = lines[1].split(" --> ")
|
||||
start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
|
||||
sentence = ' '.join(lines[2:])
|
||||
|
||||
segments.append(Segment(
|
||||
start=start,
|
||||
end=end,
|
||||
text=sentence
|
||||
))
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
def get_serialized_vtt(dicts):
|
||||
output = "WebVTT\n\n"
|
||||
for dic in dicts:
|
||||
output += f'{dic["index"]}\n'
|
||||
output += f'{dic["timestamp"]}\n'
|
||||
output += f'{dic["sentence"]}\n\n'
|
||||
return output
|
||||
class WriteLRC(SubtitlesWriter):
|
||||
extension: str = "lrc"
|
||||
always_include_hours: bool = False
|
||||
decimal_marker: str = "."
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for i, (start, end, text) in enumerate(
|
||||
self.iterate_result(result, options, **kwargs), start=1
|
||||
):
|
||||
if "align_lrc_words" in kwargs and kwargs["align_lrc_words"]:
|
||||
print(f"{text}\n", file=file, flush=True)
|
||||
else:
|
||||
print(f"[{start}]{text}[{end}]\n", file=file, flush=True)
|
||||
|
||||
def to_segments(self, file_path: str) -> List[Segment]:
|
||||
segments = []
|
||||
|
||||
blocks = read_file(file_path).split('\n')
|
||||
|
||||
for block in blocks:
|
||||
if block.strip() != '':
|
||||
lines = block.strip()
|
||||
pattern = r'(\[.*?\])'
|
||||
parts = re.split(pattern, lines)
|
||||
parts = [part.strip() for part in parts if part]
|
||||
|
||||
for i, part in enumerate(parts):
|
||||
sentence_i = i%2
|
||||
if sentence_i == 1:
|
||||
start_str, text, end_str = parts[sentence_i-1], parts[sentence_i], parts[sentence_i+1]
|
||||
start_str, end_str = start_str.replace("[", "").replace("]", ""), end_str.replace("[", "").replace("]", "")
|
||||
start, end = time_str_to_seconds(start_str, self.decimal_marker), time_str_to_seconds(end_str, self.decimal_marker)
|
||||
|
||||
segments.append(Segment(
|
||||
start=start,
|
||||
end=end,
|
||||
text=text,
|
||||
))
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
class WriteTSV(ResultWriter):
|
||||
"""
|
||||
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
|
||||
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
|
||||
|
||||
Using integer milliseconds as start and end times means there's no chance of interference from
|
||||
an environment setting a language encoding that causes the decimal in a floating point number
|
||||
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
|
||||
"""
|
||||
|
||||
extension: str = "tsv"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
print("start", "end", "text", sep="\t", file=file)
|
||||
for segment in result["segments"]:
|
||||
print(round(1000 * segment["start"]), file=file, end="\t")
|
||||
print(round(1000 * segment["end"]), file=file, end="\t")
|
||||
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
|
||||
|
||||
|
||||
class WriteJSON(ResultWriter):
|
||||
extension: str = "json"
|
||||
|
||||
def write_result(
|
||||
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
json.dump(result, file)
|
||||
|
||||
|
||||
def get_writer(
|
||||
output_format: str, output_dir: str
|
||||
) -> Callable[[dict, TextIO, dict], None]:
|
||||
output_format = output_format.strip().lower().replace(".", "")
|
||||
|
||||
writers = {
|
||||
"txt": WriteTXT,
|
||||
"vtt": WriteVTT,
|
||||
"srt": WriteSRT,
|
||||
"tsv": WriteTSV,
|
||||
"json": WriteJSON,
|
||||
"lrc": WriteLRC
|
||||
}
|
||||
|
||||
if output_format == "all":
|
||||
all_writers = [writer(output_dir) for writer in writers.values()]
|
||||
|
||||
def write_all(
|
||||
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||
):
|
||||
for writer in all_writers:
|
||||
writer(result, file, options, **kwargs)
|
||||
|
||||
return write_all
|
||||
|
||||
return writers[output_format](output_dir)
|
||||
|
||||
|
||||
def generate_file(
|
||||
output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str,
|
||||
add_timestamp: bool = True, **kwargs
|
||||
) -> Tuple[str, str]:
|
||||
output_format = output_format.strip().lower().replace(".", "")
|
||||
output_format = "vtt" if output_format == "webvtt" else output_format
|
||||
|
||||
if add_timestamp:
|
||||
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
||||
output_file_name += f"-{timestamp}"
|
||||
|
||||
file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
|
||||
file_writer = get_writer(output_format=output_format, output_dir=output_dir)
|
||||
|
||||
if isinstance(file_writer, WriteLRC) and kwargs.get("highlight_words", False):
|
||||
kwargs["highlight_words"], kwargs["align_lrc_words"] = False, True
|
||||
|
||||
file_writer(result=result, output_file_name=output_file_name, **kwargs)
|
||||
content = read_file(file_path)
|
||||
return content, file_path
|
||||
|
||||
|
||||
def safe_filename(name):
|
||||
|
||||
@@ -5,7 +5,8 @@ import numpy as np
|
||||
from typing import BinaryIO, Union, List, Optional, Tuple
|
||||
import warnings
|
||||
import faster_whisper
|
||||
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
|
||||
from modules.whisper.data_classes import *
|
||||
from faster_whisper.transcribe import SpeechTimestampsMap
|
||||
import gradio as gr
|
||||
|
||||
|
||||
@@ -247,18 +248,18 @@ class SileroVAD:
|
||||
|
||||
def restore_speech_timestamps(
|
||||
self,
|
||||
segments: List[dict],
|
||||
segments: List[Segment],
|
||||
speech_chunks: List[dict],
|
||||
sampling_rate: Optional[int] = None,
|
||||
) -> List[dict]:
|
||||
) -> List[Segment]:
|
||||
if sampling_rate is None:
|
||||
sampling_rate = self.sampling_rate
|
||||
|
||||
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
||||
|
||||
for segment in segments:
|
||||
segment["start"] = ts_map.get_original_time(segment["start"])
|
||||
segment["end"] = ts_map.get_original_time(segment["end"])
|
||||
segment.start = ts_map.get_original_time(segment.start)
|
||||
segment.end = ts_map.get_original_time(segment.end)
|
||||
|
||||
return segments
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import os
|
||||
import torch
|
||||
import whisper
|
||||
import ctranslate2
|
||||
import gradio as gr
|
||||
@@ -9,21 +8,20 @@ from typing import BinaryIO, Union, Tuple, List
|
||||
import numpy as np
|
||||
from datetime import datetime
|
||||
from faster_whisper.vad import VadOptions
|
||||
from dataclasses import astuple
|
||||
|
||||
from modules.uvr.music_separator import MusicSeparator
|
||||
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
|
||||
UVR_MODELS_DIR)
|
||||
from modules.utils.constants import AUTOMATIC_DETECTION
|
||||
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
||||
from modules.utils.constants import *
|
||||
from modules.utils.subtitle_manager import *
|
||||
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
||||
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
|
||||
from modules.whisper.whisper_parameter import *
|
||||
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.diarize.diarizer import Diarizer
|
||||
from modules.vad.silero_vad import SileroVAD
|
||||
|
||||
|
||||
class WhisperBase(ABC):
|
||||
class BaseTranscriptionPipeline(ABC):
|
||||
def __init__(self,
|
||||
model_dir: str = WHISPER_MODELS_DIR,
|
||||
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
||||
@@ -73,13 +71,15 @@ class WhisperBase(ABC):
|
||||
def run(self,
|
||||
audio: Union[str, BinaryIO, np.ndarray],
|
||||
progress: gr.Progress = gr.Progress(),
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True,
|
||||
*whisper_params,
|
||||
) -> Tuple[List[dict], float]:
|
||||
*pipeline_params,
|
||||
) -> Tuple[List[Segment], float]:
|
||||
"""
|
||||
Run transcription with conditional pre-processing and post-processing.
|
||||
The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
|
||||
The diarization will be performed in post-processing, if enabled.
|
||||
Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -87,40 +87,33 @@ class WhisperBase(ABC):
|
||||
Audio input. This can be file path or binary type.
|
||||
progress: gr.Progress
|
||||
Indicator to show progress directly in gradio.
|
||||
file_format: str
|
||||
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
|
||||
add_timestamp: bool
|
||||
Whether to add a timestamp at the end of the filename.
|
||||
*whisper_params: tuple
|
||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
||||
*pipeline_params: tuple
|
||||
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class.
|
||||
This must be provided as a List with * wildcard because of the integration with gradio.
|
||||
See more info at : https://github.com/gradio-app/gradio/issues/2471
|
||||
|
||||
Returns
|
||||
----------
|
||||
segments_result: List[dict]
|
||||
list of dicts that includes start, end timestamps and transcribed text
|
||||
segments_result: List[Segment]
|
||||
list of Segment that includes start, end timestamps and transcribed text
|
||||
elapsed_time: float
|
||||
elapsed time for running
|
||||
"""
|
||||
params = WhisperParameters.as_value(*whisper_params)
|
||||
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||
params = self.validate_gradio_values(params)
|
||||
bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
|
||||
|
||||
self.cache_parameters(
|
||||
whisper_params=params,
|
||||
add_timestamp=add_timestamp
|
||||
)
|
||||
|
||||
if params.lang is None:
|
||||
pass
|
||||
elif params.lang == AUTOMATIC_DETECTION:
|
||||
params.lang = None
|
||||
else:
|
||||
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
||||
params.lang = language_code_dict[params.lang]
|
||||
|
||||
if params.is_bgm_separate:
|
||||
if bgm_params.is_separate_bgm:
|
||||
music, audio, _ = self.music_separator.separate(
|
||||
audio=audio,
|
||||
model_name=params.uvr_model_size,
|
||||
device=params.uvr_device,
|
||||
segment_size=params.uvr_segment_size,
|
||||
save_file=params.uvr_save_file,
|
||||
model_name=bgm_params.model_size,
|
||||
device=bgm_params.device,
|
||||
segment_size=bgm_params.segment_size,
|
||||
save_file=bgm_params.save_file,
|
||||
progress=progress
|
||||
)
|
||||
|
||||
@@ -132,47 +125,55 @@ class WhisperBase(ABC):
|
||||
origin_sample_rate = self.music_separator.audio_info.sample_rate
|
||||
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
|
||||
|
||||
if params.uvr_enable_offload:
|
||||
if bgm_params.enable_offload:
|
||||
self.music_separator.offload()
|
||||
|
||||
if params.vad_filter:
|
||||
# Explicit value set for float('inf') from gr.Number()
|
||||
if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
|
||||
params.max_speech_duration_s = float('inf')
|
||||
|
||||
if vad_params.vad_filter:
|
||||
vad_options = VadOptions(
|
||||
threshold=params.threshold,
|
||||
min_speech_duration_ms=params.min_speech_duration_ms,
|
||||
max_speech_duration_s=params.max_speech_duration_s,
|
||||
min_silence_duration_ms=params.min_silence_duration_ms,
|
||||
speech_pad_ms=params.speech_pad_ms
|
||||
threshold=vad_params.threshold,
|
||||
min_speech_duration_ms=vad_params.min_speech_duration_ms,
|
||||
max_speech_duration_s=vad_params.max_speech_duration_s,
|
||||
min_silence_duration_ms=vad_params.min_silence_duration_ms,
|
||||
speech_pad_ms=vad_params.speech_pad_ms
|
||||
)
|
||||
|
||||
audio, speech_chunks = self.vad.run(
|
||||
vad_processed, speech_chunks = self.vad.run(
|
||||
audio=audio,
|
||||
vad_parameters=vad_options,
|
||||
progress=progress
|
||||
)
|
||||
|
||||
if vad_processed.size > 0:
|
||||
audio = vad_processed
|
||||
else:
|
||||
vad_params.vad_filter = False
|
||||
|
||||
result, elapsed_time = self.transcribe(
|
||||
audio,
|
||||
progress,
|
||||
*astuple(params)
|
||||
*whisper_params.to_list()
|
||||
)
|
||||
|
||||
if params.vad_filter:
|
||||
if vad_params.vad_filter:
|
||||
result = self.vad.restore_speech_timestamps(
|
||||
segments=result,
|
||||
speech_chunks=speech_chunks,
|
||||
)
|
||||
|
||||
if params.is_diarize:
|
||||
if diarization_params.is_diarize:
|
||||
result, elapsed_time_diarization = self.diarizer.run(
|
||||
audio=audio,
|
||||
use_auth_token=params.hf_token,
|
||||
use_auth_token=diarization_params.hf_token,
|
||||
transcribed_result=result,
|
||||
device=diarization_params.device
|
||||
)
|
||||
elapsed_time += elapsed_time_diarization
|
||||
|
||||
self.cache_parameters(
|
||||
params=params,
|
||||
file_format=file_format,
|
||||
add_timestamp=add_timestamp
|
||||
)
|
||||
return result, elapsed_time
|
||||
|
||||
def transcribe_file(self,
|
||||
@@ -181,8 +182,8 @@ class WhisperBase(ABC):
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True,
|
||||
progress=gr.Progress(),
|
||||
*whisper_params,
|
||||
) -> list:
|
||||
*pipeline_params,
|
||||
) -> Tuple[str, List]:
|
||||
"""
|
||||
Write subtitle file from Files
|
||||
|
||||
@@ -199,8 +200,8 @@ class WhisperBase(ABC):
|
||||
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
|
||||
progress: gr.Progress
|
||||
Indicator to show progress directly in gradio.
|
||||
*whisper_params: tuple
|
||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
||||
*pipeline_params: tuple
|
||||
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
|
||||
|
||||
Returns
|
||||
----------
|
||||
@@ -210,6 +211,11 @@ class WhisperBase(ABC):
|
||||
Output file path to return to gr.Files()
|
||||
"""
|
||||
try:
|
||||
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||
writer_options = {
|
||||
"highlight_words": True if params.whisper.word_timestamps else False
|
||||
}
|
||||
|
||||
if input_folder_path:
|
||||
files = get_media_files(input_folder_path)
|
||||
if isinstance(files, str):
|
||||
@@ -222,19 +228,21 @@ class WhisperBase(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
file,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*whisper_params,
|
||||
*pipeline_params,
|
||||
)
|
||||
|
||||
file_name, file_ext = os.path.splitext(os.path.basename(file))
|
||||
subtitle, file_path = self.generate_and_write_file(
|
||||
file_name=file_name,
|
||||
transcribed_segments=transcribed_segments,
|
||||
subtitle, file_path = generate_file(
|
||||
output_dir=self.output_dir,
|
||||
output_file_name=file_name,
|
||||
output_format=file_format,
|
||||
result=transcribed_segments,
|
||||
add_timestamp=add_timestamp,
|
||||
file_format=file_format,
|
||||
output_dir=self.output_dir
|
||||
**writer_options
|
||||
)
|
||||
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
|
||||
files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
|
||||
|
||||
total_result = ''
|
||||
total_time = 0
|
||||
@@ -247,10 +255,11 @@ class WhisperBase(ABC):
|
||||
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
|
||||
result_file_path = [info['path'] for info in files_info.values()]
|
||||
|
||||
return [result_str, result_file_path]
|
||||
return result_str, result_file_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error transcribing file: {e}")
|
||||
raise
|
||||
finally:
|
||||
self.release_cuda_memory()
|
||||
|
||||
@@ -259,8 +268,8 @@ class WhisperBase(ABC):
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True,
|
||||
progress=gr.Progress(),
|
||||
*whisper_params,
|
||||
) -> list:
|
||||
*pipeline_params,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Write subtitle file from microphone
|
||||
|
||||
@@ -274,7 +283,7 @@ class WhisperBase(ABC):
|
||||
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
||||
progress: gr.Progress
|
||||
Indicator to show progress directly in gradio.
|
||||
*whisper_params: tuple
|
||||
*pipeline_params: tuple
|
||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
||||
|
||||
Returns
|
||||
@@ -285,27 +294,36 @@ class WhisperBase(ABC):
|
||||
Output file path to return to gr.Files()
|
||||
"""
|
||||
try:
|
||||
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||
writer_options = {
|
||||
"highlight_words": True if params.whisper.word_timestamps else False
|
||||
}
|
||||
|
||||
progress(0, desc="Loading Audio..")
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
mic_audio,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*whisper_params,
|
||||
*pipeline_params,
|
||||
)
|
||||
progress(1, desc="Completed!")
|
||||
|
||||
subtitle, result_file_path = self.generate_and_write_file(
|
||||
file_name="Mic",
|
||||
transcribed_segments=transcribed_segments,
|
||||
file_name = "Mic"
|
||||
subtitle, file_path = generate_file(
|
||||
output_dir=self.output_dir,
|
||||
output_file_name=file_name,
|
||||
output_format=file_format,
|
||||
result=transcribed_segments,
|
||||
add_timestamp=add_timestamp,
|
||||
file_format=file_format,
|
||||
output_dir=self.output_dir
|
||||
**writer_options
|
||||
)
|
||||
|
||||
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
||||
return [result_str, result_file_path]
|
||||
return result_str, file_path
|
||||
except Exception as e:
|
||||
print(f"Error transcribing file: {e}")
|
||||
print(f"Error transcribing mic: {e}")
|
||||
raise
|
||||
finally:
|
||||
self.release_cuda_memory()
|
||||
|
||||
@@ -314,8 +332,8 @@ class WhisperBase(ABC):
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True,
|
||||
progress=gr.Progress(),
|
||||
*whisper_params,
|
||||
) -> list:
|
||||
*pipeline_params,
|
||||
) -> Tuple[str, str]:
|
||||
"""
|
||||
Write subtitle file from Youtube
|
||||
|
||||
@@ -329,7 +347,7 @@ class WhisperBase(ABC):
|
||||
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
||||
progress: gr.Progress
|
||||
Indicator to show progress directly in gradio.
|
||||
*whisper_params: tuple
|
||||
*pipeline_params: tuple
|
||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
||||
|
||||
Returns
|
||||
@@ -340,6 +358,11 @@ class WhisperBase(ABC):
|
||||
Output file path to return to gr.Files()
|
||||
"""
|
||||
try:
|
||||
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||
writer_options = {
|
||||
"highlight_words": True if params.whisper.word_timestamps else False
|
||||
}
|
||||
|
||||
progress(0, desc="Loading Audio from Youtube..")
|
||||
yt = get_ytdata(youtube_link)
|
||||
audio = get_ytaudio(yt)
|
||||
@@ -347,29 +370,33 @@ class WhisperBase(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
audio,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*whisper_params,
|
||||
*pipeline_params,
|
||||
)
|
||||
|
||||
progress(1, desc="Completed!")
|
||||
|
||||
file_name = safe_filename(yt.title)
|
||||
subtitle, result_file_path = self.generate_and_write_file(
|
||||
file_name=file_name,
|
||||
transcribed_segments=transcribed_segments,
|
||||
subtitle, file_path = generate_file(
|
||||
output_dir=self.output_dir,
|
||||
output_file_name=file_name,
|
||||
output_format=file_format,
|
||||
result=transcribed_segments,
|
||||
add_timestamp=add_timestamp,
|
||||
file_format=file_format,
|
||||
output_dir=self.output_dir
|
||||
**writer_options
|
||||
)
|
||||
|
||||
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
||||
|
||||
if os.path.exists(audio):
|
||||
os.remove(audio)
|
||||
|
||||
return [result_str, result_file_path]
|
||||
return result_str, file_path
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error transcribing file: {e}")
|
||||
print(f"Error transcribing youtube: {e}")
|
||||
raise
|
||||
finally:
|
||||
self.release_cuda_memory()
|
||||
|
||||
@@ -387,58 +414,6 @@ class WhisperBase(ABC):
|
||||
else:
|
||||
return list(ctranslate2.get_supported_compute_types("cpu"))
|
||||
|
||||
@staticmethod
|
||||
def generate_and_write_file(file_name: str,
|
||||
transcribed_segments: list,
|
||||
add_timestamp: bool,
|
||||
file_format: str,
|
||||
output_dir: str
|
||||
) -> str:
|
||||
"""
|
||||
Writes subtitle file
|
||||
|
||||
Parameters
|
||||
----------
|
||||
file_name: str
|
||||
Output file name
|
||||
transcribed_segments: list
|
||||
Text segments transcribed from audio
|
||||
add_timestamp: bool
|
||||
Determines whether to add a timestamp to the end of the filename.
|
||||
file_format: str
|
||||
File format to write. Supported formats: [SRT, WebVTT, txt]
|
||||
output_dir: str
|
||||
Directory path of the output
|
||||
|
||||
Returns
|
||||
----------
|
||||
content: str
|
||||
Result of the transcription
|
||||
output_path: str
|
||||
output file path
|
||||
"""
|
||||
if add_timestamp:
|
||||
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
||||
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
|
||||
else:
|
||||
output_path = os.path.join(output_dir, f"{file_name}")
|
||||
|
||||
file_format = file_format.strip().lower()
|
||||
if file_format == "srt":
|
||||
content = get_srt(transcribed_segments)
|
||||
output_path += '.srt'
|
||||
|
||||
elif file_format == "webvtt":
|
||||
content = get_vtt(transcribed_segments)
|
||||
output_path += '.vtt'
|
||||
|
||||
elif file_format == "txt":
|
||||
content = get_txt(transcribed_segments)
|
||||
output_path += '.txt'
|
||||
|
||||
write_file(content, output_path)
|
||||
return content, output_path
|
||||
|
||||
@staticmethod
|
||||
def format_time(elapsed_time: float) -> str:
|
||||
"""
|
||||
@@ -471,7 +446,7 @@ class WhisperBase(ABC):
|
||||
if torch.cuda.is_available():
|
||||
return "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
if not WhisperBase.is_sparse_api_supported():
|
||||
if not BaseTranscriptionPipeline.is_sparse_api_supported():
|
||||
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
||||
return "cpu"
|
||||
return "mps"
|
||||
@@ -513,17 +488,64 @@ class WhisperBase(ABC):
|
||||
os.remove(file_path)
|
||||
|
||||
@staticmethod
|
||||
def cache_parameters(
|
||||
whisper_params: WhisperValues,
|
||||
add_timestamp: bool
|
||||
):
|
||||
"""cache parameters to the yaml file"""
|
||||
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||
cached_whisper_param = whisper_params.to_yaml()
|
||||
cached_yaml = {**cached_params, **cached_whisper_param}
|
||||
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
||||
def validate_gradio_values(params: TranscriptionPipelineParams):
|
||||
"""
|
||||
Validate gradio specific values that can't be displayed as None in the UI.
|
||||
Related issue : https://github.com/gradio-app/gradio/issues/8723
|
||||
"""
|
||||
if params.whisper.lang is None:
|
||||
pass
|
||||
elif params.whisper.lang == AUTOMATIC_DETECTION:
|
||||
params.whisper.lang = None
|
||||
else:
|
||||
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
||||
params.whisper.lang = language_code_dict[params.whisper.lang]
|
||||
|
||||
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||
if params.whisper.initial_prompt == GRADIO_NONE_STR:
|
||||
params.whisper.initial_prompt = None
|
||||
if params.whisper.prefix == GRADIO_NONE_STR:
|
||||
params.whisper.prefix = None
|
||||
if params.whisper.hotwords == GRADIO_NONE_STR:
|
||||
params.whisper.hotwords = None
|
||||
if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
|
||||
params.whisper.max_new_tokens = None
|
||||
if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
|
||||
params.whisper.hallucination_silence_threshold = None
|
||||
if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
|
||||
params.whisper.language_detection_threshold = None
|
||||
if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
|
||||
params.vad.max_speech_duration_s = float('inf')
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def cache_parameters(
|
||||
params: TranscriptionPipelineParams,
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True
|
||||
):
|
||||
"""Cache parameters to the yaml file"""
|
||||
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||
param_to_cache = params.to_dict()
|
||||
|
||||
cached_yaml = {**cached_params, **param_to_cache}
|
||||
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
||||
cached_yaml["whisper"]["file_format"] = file_format
|
||||
|
||||
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
||||
if supress_token and isinstance(supress_token, list):
|
||||
cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)
|
||||
|
||||
if cached_yaml["whisper"].get("lang", None) is None:
|
||||
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
|
||||
else:
|
||||
language_dict = whisper.tokenizer.LANGUAGES
|
||||
cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
|
||||
|
||||
if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
|
||||
cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
|
||||
|
||||
if cached_yaml is not None and cached_yaml:
|
||||
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||
|
||||
@staticmethod
|
||||
def resample_audio(audio: Union[str, np.ndarray],
|
||||
608
modules/whisper/data_classes.py
Normal file
608
modules/whisper/data_classes.py
Normal file
@@ -0,0 +1,608 @@
|
||||
import faster_whisper.transcribe
|
||||
import gradio as gr
|
||||
import torch
|
||||
from typing import Optional, Dict, List, Union, NamedTuple
|
||||
from pydantic import BaseModel, Field, field_validator, ConfigDict
|
||||
from gradio_i18n import Translate, gettext as _
|
||||
from enum import Enum
|
||||
from copy import deepcopy
|
||||
|
||||
import yaml
|
||||
|
||||
from modules.utils.constants import *
|
||||
|
||||
|
||||
class WhisperImpl(Enum):
|
||||
WHISPER = "whisper"
|
||||
FASTER_WHISPER = "faster-whisper"
|
||||
INSANELY_FAST_WHISPER = "insanely_fast_whisper"
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
id: Optional[int] = Field(default=None, description="Incremental id for the segment")
|
||||
seek: Optional[int] = Field(default=None, description="Seek of the segment from chunked audio")
|
||||
text: Optional[str] = Field(default=None, description="Transcription text of the segment")
|
||||
start: Optional[float] = Field(default=None, description="Start time of the segment")
|
||||
end: Optional[float] = Field(default=None, description="End time of the segment")
|
||||
tokens: Optional[List[int]] = Field(default=None, description="List of token IDs")
|
||||
temperature: Optional[float] = Field(default=None, description="Temperature used during the decoding process")
|
||||
avg_logprob: Optional[float] = Field(default=None, description="Average log probability of the tokens")
|
||||
compression_ratio: Optional[float] = Field(default=None, description="Compression ratio of the segment")
|
||||
no_speech_prob: Optional[float] = Field(default=None, description="Probability that it's not speech")
|
||||
words: Optional[List['Word']] = Field(default=None, description="List of words contained in the segment")
|
||||
|
||||
@classmethod
|
||||
def from_faster_whisper(cls,
|
||||
seg: faster_whisper.transcribe.Segment):
|
||||
if seg.words is not None:
|
||||
words = [
|
||||
Word(
|
||||
start=w.start,
|
||||
end=w.end,
|
||||
word=w.word,
|
||||
probability=w.probability
|
||||
) for w in seg.words
|
||||
]
|
||||
else:
|
||||
words = None
|
||||
|
||||
return cls(
|
||||
id=seg.id,
|
||||
seek=seg.seek,
|
||||
text=seg.text,
|
||||
start=seg.start,
|
||||
end=seg.end,
|
||||
tokens=seg.tokens,
|
||||
temperature=seg.temperature,
|
||||
avg_logprob=seg.avg_logprob,
|
||||
compression_ratio=seg.compression_ratio,
|
||||
no_speech_prob=seg.no_speech_prob,
|
||||
words=words
|
||||
)
|
||||
|
||||
|
||||
class Word(BaseModel):
|
||||
start: Optional[float] = Field(default=None, description="Start time of the word")
|
||||
end: Optional[float] = Field(default=None, description="Start time of the word")
|
||||
word: Optional[str] = Field(default=None, description="Word text")
|
||||
probability: Optional[float] = Field(default=None, description="Probability of the word")
|
||||
|
||||
|
||||
class BaseParams(BaseModel):
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
return self.model_dump()
|
||||
|
||||
def to_list(self) -> List:
|
||||
return list(self.model_dump().values())
|
||||
|
||||
@classmethod
|
||||
def from_list(cls, data_list: List) -> 'BaseParams':
|
||||
field_names = list(cls.model_fields.keys())
|
||||
return cls(**dict(zip(field_names, data_list)))
|
||||
|
||||
|
||||
class VadParams(BaseParams):
|
||||
"""Voice Activity Detection parameters"""
|
||||
vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
|
||||
threshold: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
|
||||
)
|
||||
min_speech_duration_ms: int = Field(
|
||||
default=250,
|
||||
ge=0,
|
||||
description="Final speech chunks shorter than this are discarded"
|
||||
)
|
||||
max_speech_duration_s: float = Field(
|
||||
default=float("inf"),
|
||||
gt=0,
|
||||
description="Maximum duration of speech chunks in seconds"
|
||||
)
|
||||
min_silence_duration_ms: int = Field(
|
||||
default=2000,
|
||||
ge=0,
|
||||
description="Minimum silence duration between speech chunks"
|
||||
)
|
||||
speech_pad_ms: int = Field(
|
||||
default=400,
|
||||
ge=0,
|
||||
description="Padding added to each side of speech chunks"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
|
||||
return [
|
||||
gr.Checkbox(
|
||||
label=_("Enable Silero VAD Filter"),
|
||||
value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
|
||||
interactive=True,
|
||||
info=_("Enable this to transcribe only detected voice")
|
||||
),
|
||||
gr.Slider(
|
||||
minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
|
||||
value=defaults.get("threshold", cls.__fields__["threshold"].default),
|
||||
info="Lower it to be more sensitive to small sounds."
|
||||
),
|
||||
gr.Number(
|
||||
label="Minimum Speech Duration (ms)", precision=0,
|
||||
value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
|
||||
info="Final speech chunks shorter than this time are thrown out"
|
||||
),
|
||||
gr.Number(
|
||||
label="Maximum Speech Duration (s)",
|
||||
value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
|
||||
info="Maximum duration of speech chunks in \"seconds\"."
|
||||
),
|
||||
gr.Number(
|
||||
label="Minimum Silence Duration (ms)", precision=0,
|
||||
value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
|
||||
info="In the end of each speech chunk wait for this time before separating it"
|
||||
),
|
||||
gr.Number(
|
||||
label="Speech Padding (ms)", precision=0,
|
||||
value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
|
||||
info="Final speech chunks are padded by this time each side"
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class DiarizationParams(BaseParams):
|
||||
"""Speaker diarization parameters"""
|
||||
is_diarize: bool = Field(default=False, description="Enable speaker diarization")
|
||||
device: str = Field(default="cuda", description="Device to run Diarization model.")
|
||||
hf_token: str = Field(
|
||||
default="",
|
||||
description="Hugging Face token for downloading diarization models"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_gradio_inputs(cls,
|
||||
defaults: Optional[Dict] = None,
|
||||
available_devices: Optional[List] = None,
|
||||
device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
|
||||
return [
|
||||
gr.Checkbox(
|
||||
label=_("Enable Diarization"),
|
||||
value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
|
||||
),
|
||||
gr.Dropdown(
|
||||
label=_("Device"),
|
||||
choices=["cpu", "cuda"] if available_devices is None else available_devices,
|
||||
value=defaults.get("device", device),
|
||||
),
|
||||
gr.Textbox(
|
||||
label=_("HuggingFace Token"),
|
||||
value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
|
||||
info=_("This is only needed the first time you download the model")
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class BGMSeparationParams(BaseParams):
|
||||
"""Background music separation parameters"""
|
||||
is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
|
||||
model_size: str = Field(
|
||||
default="UVR-MDX-NET-Inst_HQ_4",
|
||||
description="UVR model size"
|
||||
)
|
||||
device: str = Field(default="cuda", description="Device to run UVR model.")
|
||||
segment_size: int = Field(
|
||||
default=256,
|
||||
gt=0,
|
||||
description="Segment size for UVR model"
|
||||
)
|
||||
save_file: bool = Field(
|
||||
default=False,
|
||||
description="Whether to save separated audio files"
|
||||
)
|
||||
enable_offload: bool = Field(
|
||||
default=True,
|
||||
description="Offload UVR model after transcription"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def to_gradio_input(cls,
|
||||
defaults: Optional[Dict] = None,
|
||||
available_devices: Optional[List] = None,
|
||||
device: Optional[str] = None,
|
||||
available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
|
||||
return [
|
||||
gr.Checkbox(
|
||||
label=_("Enable Background Music Remover Filter"),
|
||||
value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
|
||||
interactive=True,
|
||||
info=_("Enabling this will remove background music")
|
||||
),
|
||||
gr.Dropdown(
|
||||
label=_("Model"),
|
||||
choices=["UVR-MDX-NET-Inst_HQ_4",
|
||||
"UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
|
||||
value=defaults.get("model_size", cls.__fields__["model_size"].default),
|
||||
),
|
||||
gr.Dropdown(
|
||||
label=_("Device"),
|
||||
choices=["cpu", "cuda"] if available_devices is None else available_devices,
|
||||
value=defaults.get("device", device),
|
||||
),
|
||||
gr.Number(
|
||||
label="Segment Size",
|
||||
value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
|
||||
precision=0,
|
||||
info="Segment size for UVR model"
|
||||
),
|
||||
gr.Checkbox(
|
||||
label=_("Save separated files to output"),
|
||||
value=defaults.get("save_file", cls.__fields__["save_file"].default),
|
||||
),
|
||||
gr.Checkbox(
|
||||
label=_("Offload sub model after removing background music"),
|
||||
value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class WhisperParams(BaseParams):
|
||||
"""Whisper parameters"""
|
||||
model_size: str = Field(default="large-v2", description="Whisper model size")
|
||||
lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
|
||||
is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
|
||||
beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
|
||||
log_prob_threshold: float = Field(
|
||||
default=-1.0,
|
||||
description="Threshold for average log probability of sampled tokens"
|
||||
)
|
||||
no_speech_threshold: float = Field(
|
||||
default=0.6,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Threshold for detecting silence"
|
||||
)
|
||||
compute_type: str = Field(default="float16", description="Computation type for transcription")
|
||||
best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
|
||||
patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
|
||||
condition_on_previous_text: bool = Field(
|
||||
default=True,
|
||||
description="Use previous output as prompt for next window"
|
||||
)
|
||||
prompt_reset_on_temperature: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
description="Temperature threshold for resetting prompt"
|
||||
)
|
||||
initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
|
||||
temperature: float = Field(
|
||||
default=0.0,
|
||||
ge=0.0,
|
||||
description="Temperature for sampling"
|
||||
)
|
||||
compression_ratio_threshold: float = Field(
|
||||
default=2.4,
|
||||
gt=0,
|
||||
description="Threshold for gzip compression ratio"
|
||||
)
|
||||
length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
|
||||
repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
|
||||
no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
|
||||
prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
|
||||
suppress_blank: bool = Field(
|
||||
default=True,
|
||||
description="Suppress blank outputs at start of sampling"
|
||||
)
|
||||
suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress")
|
||||
max_initial_timestamp: float = Field(
|
||||
default=1.0,
|
||||
ge=0.0,
|
||||
description="Maximum initial timestamp"
|
||||
)
|
||||
word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
|
||||
prepend_punctuations: Optional[str] = Field(
|
||||
default="\"'“¿([{-",
|
||||
description="Punctuations to merge with next word"
|
||||
)
|
||||
append_punctuations: Optional[str] = Field(
|
||||
default="\"'.。,,!!??::”)]}、",
|
||||
description="Punctuations to merge with previous word"
|
||||
)
|
||||
max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
|
||||
chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
|
||||
hallucination_silence_threshold: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Threshold for skipping silent periods in hallucination detection"
|
||||
)
|
||||
hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
|
||||
language_detection_threshold: Optional[float] = Field(
|
||||
default=None,
|
||||
description="Threshold for language detection probability"
|
||||
)
|
||||
language_detection_segments: int = Field(
|
||||
default=1,
|
||||
gt=0,
|
||||
description="Number of segments for language detection"
|
||||
)
|
||||
batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
|
||||
|
||||
@field_validator('lang')
|
||||
def validate_lang(cls, v):
|
||||
from modules.utils.constants import AUTOMATIC_DETECTION
|
||||
return None if v == AUTOMATIC_DETECTION.unwrap() else v
|
||||
|
||||
@field_validator('suppress_tokens')
|
||||
def validate_supress_tokens(cls, v):
|
||||
import ast
|
||||
try:
|
||||
if isinstance(v, str):
|
||||
suppress_tokens = ast.literal_eval(v)
|
||||
if not isinstance(suppress_tokens, list):
|
||||
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
||||
return suppress_tokens
|
||||
if isinstance(v, list):
|
||||
return v
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
|
||||
|
||||
@classmethod
|
||||
def to_gradio_inputs(cls,
|
||||
defaults: Optional[Dict] = None,
|
||||
only_advanced: Optional[bool] = True,
|
||||
whisper_type: Optional[str] = None,
|
||||
available_models: Optional[List] = None,
|
||||
available_langs: Optional[List] = None,
|
||||
available_compute_types: Optional[List] = None,
|
||||
compute_type: Optional[str] = None):
|
||||
whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower()
|
||||
|
||||
inputs = []
|
||||
if not only_advanced:
|
||||
inputs += [
|
||||
gr.Dropdown(
|
||||
label=_("Model"),
|
||||
choices=available_models,
|
||||
value=defaults.get("model_size", cls.__fields__["model_size"].default),
|
||||
),
|
||||
gr.Dropdown(
|
||||
label=_("Language"),
|
||||
choices=available_langs,
|
||||
value=defaults.get("lang", AUTOMATIC_DETECTION),
|
||||
),
|
||||
gr.Checkbox(
|
||||
label=_("Translate to English?"),
|
||||
value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
|
||||
),
|
||||
]
|
||||
|
||||
inputs += [
|
||||
gr.Number(
|
||||
label="Beam Size",
|
||||
value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
|
||||
precision=0,
|
||||
info="Beam size for decoding"
|
||||
),
|
||||
gr.Number(
|
||||
label="Log Probability Threshold",
|
||||
value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
|
||||
info="Threshold for average log probability of sampled tokens"
|
||||
),
|
||||
gr.Number(
|
||||
label="No Speech Threshold",
|
||||
value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
|
||||
info="Threshold for detecting silence"
|
||||
),
|
||||
gr.Dropdown(
|
||||
label="Compute Type",
|
||||
choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
|
||||
value=defaults.get("compute_type", compute_type),
|
||||
info="Computation type for transcription"
|
||||
),
|
||||
gr.Number(
|
||||
label="Best Of",
|
||||
value=defaults.get("best_of", cls.__fields__["best_of"].default),
|
||||
precision=0,
|
||||
info="Number of candidates when sampling"
|
||||
),
|
||||
gr.Number(
|
||||
label="Patience",
|
||||
value=defaults.get("patience", cls.__fields__["patience"].default),
|
||||
info="Beam search patience factor"
|
||||
),
|
||||
gr.Checkbox(
|
||||
label="Condition On Previous Text",
|
||||
value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
|
||||
info="Use previous output as prompt for next window"
|
||||
),
|
||||
gr.Slider(
|
||||
label="Prompt Reset On Temperature",
|
||||
value=defaults.get("prompt_reset_on_temperature",
|
||||
cls.__fields__["prompt_reset_on_temperature"].default),
|
||||
minimum=0,
|
||||
maximum=1,
|
||||
step=0.01,
|
||||
info="Temperature threshold for resetting prompt"
|
||||
),
|
||||
gr.Textbox(
|
||||
label="Initial Prompt",
|
||||
value=defaults.get("initial_prompt", GRADIO_NONE_STR),
|
||||
info="Initial prompt for first window"
|
||||
),
|
||||
gr.Slider(
|
||||
label="Temperature",
|
||||
value=defaults.get("temperature", cls.__fields__["temperature"].default),
|
||||
minimum=0.0,
|
||||
step=0.01,
|
||||
maximum=1.0,
|
||||
info="Temperature for sampling"
|
||||
),
|
||||
gr.Number(
|
||||
label="Compression Ratio Threshold",
|
||||
value=defaults.get("compression_ratio_threshold",
|
||||
cls.__fields__["compression_ratio_threshold"].default),
|
||||
info="Threshold for gzip compression ratio"
|
||||
)
|
||||
]
|
||||
|
||||
faster_whisper_inputs = [
|
||||
gr.Number(
|
||||
label="Length Penalty",
|
||||
value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
|
||||
info="Exponential length penalty",
|
||||
),
|
||||
gr.Number(
|
||||
label="Repetition Penalty",
|
||||
value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
|
||||
info="Penalty for repeated tokens"
|
||||
),
|
||||
gr.Number(
|
||||
label="No Repeat N-gram Size",
|
||||
value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
|
||||
precision=0,
|
||||
info="Size of n-grams to prevent repetition"
|
||||
),
|
||||
gr.Textbox(
|
||||
label="Prefix",
|
||||
value=defaults.get("prefix", GRADIO_NONE_STR),
|
||||
info="Prefix text for first window"
|
||||
),
|
||||
gr.Checkbox(
|
||||
label="Suppress Blank",
|
||||
value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
|
||||
info="Suppress blank outputs at start of sampling"
|
||||
),
|
||||
gr.Textbox(
|
||||
label="Suppress Tokens",
|
||||
value=defaults.get("suppress_tokens", "[-1]"),
|
||||
info="Token IDs to suppress"
|
||||
),
|
||||
gr.Number(
|
||||
label="Max Initial Timestamp",
|
||||
value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
|
||||
info="Maximum initial timestamp"
|
||||
),
|
||||
gr.Checkbox(
|
||||
label="Word Timestamps",
|
||||
value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
|
||||
info="Extract word-level timestamps"
|
||||
),
|
||||
gr.Textbox(
|
||||
label="Prepend Punctuations",
|
||||
value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
|
||||
info="Punctuations to merge with next word"
|
||||
),
|
||||
gr.Textbox(
|
||||
label="Append Punctuations",
|
||||
value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
|
||||
info="Punctuations to merge with previous word"
|
||||
),
|
||||
gr.Number(
|
||||
label="Max New Tokens",
|
||||
value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
|
||||
precision=0,
|
||||
info="Maximum number of new tokens per chunk"
|
||||
),
|
||||
gr.Number(
|
||||
label="Chunk Length (s)",
|
||||
value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
|
||||
precision=0,
|
||||
info="Length of audio segments in seconds"
|
||||
),
|
||||
gr.Number(
|
||||
label="Hallucination Silence Threshold (sec)",
|
||||
value=defaults.get("hallucination_silence_threshold",
|
||||
GRADIO_NONE_NUMBER_MIN),
|
||||
info="Threshold for skipping silent periods in hallucination detection"
|
||||
),
|
||||
gr.Textbox(
|
||||
label="Hotwords",
|
||||
value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
|
||||
info="Hotwords/hint phrases for the model"
|
||||
),
|
||||
gr.Number(
|
||||
label="Language Detection Threshold",
|
||||
value=defaults.get("language_detection_threshold",
|
||||
GRADIO_NONE_NUMBER_MIN),
|
||||
info="Threshold for language detection probability"
|
||||
),
|
||||
gr.Number(
|
||||
label="Language Detection Segments",
|
||||
value=defaults.get("language_detection_segments",
|
||||
cls.__fields__["language_detection_segments"].default),
|
||||
precision=0,
|
||||
info="Number of segments for language detection"
|
||||
)
|
||||
]
|
||||
|
||||
insanely_fast_whisper_inputs = [
|
||||
gr.Number(
|
||||
label="Batch Size",
|
||||
value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
|
||||
precision=0,
|
||||
info="Batch size for processing"
|
||||
)
|
||||
]
|
||||
|
||||
if whisper_type != WhisperImpl.FASTER_WHISPER.value:
|
||||
for input_component in faster_whisper_inputs:
|
||||
input_component.visible = False
|
||||
|
||||
if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value:
|
||||
for input_component in insanely_fast_whisper_inputs:
|
||||
input_component.visible = False
|
||||
|
||||
inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class TranscriptionPipelineParams(BaseModel):
|
||||
"""Transcription pipeline parameters"""
|
||||
whisper: WhisperParams = Field(default_factory=WhisperParams)
|
||||
vad: VadParams = Field(default_factory=VadParams)
|
||||
diarization: DiarizationParams = Field(default_factory=DiarizationParams)
|
||||
bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
|
||||
|
||||
def to_dict(self) -> Dict:
|
||||
data = {
|
||||
"whisper": self.whisper.to_dict(),
|
||||
"vad": self.vad.to_dict(),
|
||||
"diarization": self.diarization.to_dict(),
|
||||
"bgm_separation": self.bgm_separation.to_dict()
|
||||
}
|
||||
return data
|
||||
|
||||
def to_list(self) -> List:
|
||||
"""
|
||||
Convert data class to the list because I have to pass the parameters as a list in the gradio.
|
||||
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
|
||||
See more about Gradio pre-processing: https://www.gradio.app/docs/components
|
||||
"""
|
||||
whisper_list = self.whisper.to_list()
|
||||
vad_list = self.vad.to_list()
|
||||
diarization_list = self.diarization.to_list()
|
||||
bgm_sep_list = self.bgm_separation.to_list()
|
||||
return whisper_list + vad_list + diarization_list + bgm_sep_list
|
||||
|
||||
@staticmethod
|
||||
def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
|
||||
"""Convert list to the data class again to use it in a function."""
|
||||
data_list = deepcopy(pipeline_list)
|
||||
|
||||
whisper_list = data_list[0:len(WhisperParams.__annotations__)]
|
||||
data_list = data_list[len(WhisperParams.__annotations__):]
|
||||
|
||||
vad_list = data_list[0:len(VadParams.__annotations__)]
|
||||
data_list = data_list[len(VadParams.__annotations__):]
|
||||
|
||||
diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
|
||||
data_list = data_list[len(DiarizationParams.__annotations__):]
|
||||
|
||||
bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
|
||||
|
||||
return TranscriptionPipelineParams(
|
||||
whisper=WhisperParams.from_list(whisper_list),
|
||||
vad=VadParams.from_list(vad_list),
|
||||
diarization=DiarizationParams.from_list(diarization_list),
|
||||
bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
|
||||
)
|
||||
@@ -12,11 +12,11 @@ import gradio as gr
|
||||
from argparse import Namespace
|
||||
|
||||
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
||||
from modules.whisper.whisper_parameter import *
|
||||
from modules.whisper.whisper_base import WhisperBase
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
||||
|
||||
|
||||
class FasterWhisperInference(WhisperBase):
|
||||
class FasterWhisperInference(BaseTranscriptionPipeline):
|
||||
def __init__(self,
|
||||
model_dir: str = FASTER_WHISPER_MODELS_DIR,
|
||||
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
||||
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
|
||||
audio: Union[str, BinaryIO, np.ndarray],
|
||||
progress: gr.Progress = gr.Progress(),
|
||||
*whisper_params,
|
||||
) -> Tuple[List[dict], float]:
|
||||
) -> Tuple[List[Segment], float]:
|
||||
"""
|
||||
transcribe method for faster-whisper.
|
||||
|
||||
@@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase):
|
||||
|
||||
Returns
|
||||
----------
|
||||
segments_result: List[dict]
|
||||
list of dicts that includes start, end timestamps and transcribed text
|
||||
segments_result: List[Segment]
|
||||
list of Segment that includes start, end timestamps and transcribed text
|
||||
elapsed_time: float
|
||||
elapsed time for transcription
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
params = WhisperParameters.as_value(*whisper_params)
|
||||
params = WhisperParams.from_list(list(whisper_params))
|
||||
|
||||
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
||||
self.update_model(params.model_size, params.compute_type, progress)
|
||||
|
||||
# None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
|
||||
if not params.initial_prompt:
|
||||
params.initial_prompt = None
|
||||
if not params.prefix:
|
||||
params.prefix = None
|
||||
if not params.hotwords:
|
||||
params.hotwords = None
|
||||
|
||||
params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
|
||||
|
||||
segments, info = self.model.transcribe(
|
||||
audio=audio,
|
||||
language=params.lang,
|
||||
@@ -112,11 +102,7 @@ class FasterWhisperInference(WhisperBase):
|
||||
segments_result = []
|
||||
for segment in segments:
|
||||
progress(segment.start / info.duration, desc="Transcribing..")
|
||||
segments_result.append({
|
||||
"start": segment.start,
|
||||
"end": segment.end,
|
||||
"text": segment.text
|
||||
})
|
||||
segments_result.append(Segment.from_faster_whisper(segment))
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
return segments_result, elapsed_time
|
||||
|
||||
@@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
|
||||
from argparse import Namespace
|
||||
|
||||
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
|
||||
from modules.whisper.whisper_parameter import *
|
||||
from modules.whisper.whisper_base import WhisperBase
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
||||
|
||||
|
||||
class InsanelyFastWhisperInference(WhisperBase):
|
||||
class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
|
||||
def __init__(self,
|
||||
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
|
||||
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
||||
@@ -32,15 +32,13 @@ class InsanelyFastWhisperInference(WhisperBase):
|
||||
self.model_dir = model_dir
|
||||
os.makedirs(self.model_dir, exist_ok=True)
|
||||
|
||||
openai_models = whisper.available_models()
|
||||
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
||||
self.available_models = openai_models + distil_models
|
||||
self.available_models = self.get_model_paths()
|
||||
|
||||
def transcribe(self,
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
progress: gr.Progress = gr.Progress(),
|
||||
*whisper_params,
|
||||
) -> Tuple[List[dict], float]:
|
||||
) -> Tuple[List[Segment], float]:
|
||||
"""
|
||||
transcribe method for faster-whisper.
|
||||
|
||||
@@ -55,13 +53,13 @@ class InsanelyFastWhisperInference(WhisperBase):
|
||||
|
||||
Returns
|
||||
----------
|
||||
segments_result: List[dict]
|
||||
list of dicts that includes start, end timestamps and transcribed text
|
||||
segments_result: List[Segment]
|
||||
list of Segment that includes start, end timestamps and transcribed text
|
||||
elapsed_time: float
|
||||
elapsed time for transcription
|
||||
"""
|
||||
start_time = time.time()
|
||||
params = WhisperParameters.as_value(*whisper_params)
|
||||
params = WhisperParams.from_list(list(whisper_params))
|
||||
|
||||
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
||||
self.update_model(params.model_size, params.compute_type, progress)
|
||||
@@ -95,9 +93,17 @@ class InsanelyFastWhisperInference(WhisperBase):
|
||||
generate_kwargs=kwargs
|
||||
)
|
||||
|
||||
segments_result = self.format_result(
|
||||
transcribed_result=segments,
|
||||
)
|
||||
segments_result = []
|
||||
for item in segments["chunks"]:
|
||||
start, end = item["timestamp"][0], item["timestamp"][1]
|
||||
if end is None:
|
||||
end = start
|
||||
segments_result.append(Segment(
|
||||
text=item["text"],
|
||||
start=start,
|
||||
end=end
|
||||
))
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
return segments_result, elapsed_time
|
||||
|
||||
@@ -138,31 +144,26 @@ class InsanelyFastWhisperInference(WhisperBase):
|
||||
model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_result(
|
||||
transcribed_result: dict
|
||||
) -> List[dict]:
|
||||
def get_model_paths(self):
|
||||
"""
|
||||
Format the transcription result of insanely_fast_whisper as the same with other implementation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
transcribed_result: dict
|
||||
Transcription result of the insanely_fast_whisper
|
||||
Get available models from models path including fine-tuned model.
|
||||
|
||||
Returns
|
||||
----------
|
||||
result: List[dict]
|
||||
Formatted result as the same with other implementation
|
||||
Name set of models
|
||||
"""
|
||||
result = transcribed_result["chunks"]
|
||||
for item in result:
|
||||
start, end = item["timestamp"][0], item["timestamp"][1]
|
||||
if end is None:
|
||||
end = start
|
||||
item["start"] = start
|
||||
item["end"] = end
|
||||
return result
|
||||
openai_models = whisper.available_models()
|
||||
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
||||
default_models = openai_models + distil_models
|
||||
|
||||
existing_models = os.listdir(self.model_dir)
|
||||
wrong_dirs = [".locks"]
|
||||
|
||||
available_models = default_models + existing_models
|
||||
available_models = [model for model in available_models if model not in wrong_dirs]
|
||||
available_models = sorted(set(available_models), key=available_models.index)
|
||||
|
||||
return available_models
|
||||
|
||||
@staticmethod
|
||||
def download_model(
|
||||
|
||||
@@ -8,11 +8,11 @@ import os
|
||||
from argparse import Namespace
|
||||
|
||||
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
|
||||
from modules.whisper.whisper_base import WhisperBase
|
||||
from modules.whisper.whisper_parameter import *
|
||||
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
||||
from modules.whisper.data_classes import *
|
||||
|
||||
|
||||
class WhisperInference(WhisperBase):
|
||||
class WhisperInference(BaseTranscriptionPipeline):
|
||||
def __init__(self,
|
||||
model_dir: str = WHISPER_MODELS_DIR,
|
||||
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
||||
@@ -30,7 +30,7 @@ class WhisperInference(WhisperBase):
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
progress: gr.Progress = gr.Progress(),
|
||||
*whisper_params,
|
||||
) -> Tuple[List[dict], float]:
|
||||
) -> Tuple[List[Segment], float]:
|
||||
"""
|
||||
transcribe method for faster-whisper.
|
||||
|
||||
@@ -45,13 +45,13 @@ class WhisperInference(WhisperBase):
|
||||
|
||||
Returns
|
||||
----------
|
||||
segments_result: List[dict]
|
||||
list of dicts that includes start, end timestamps and transcribed text
|
||||
segments_result: List[Segment]
|
||||
list of Segment that includes start, end timestamps and transcribed text
|
||||
elapsed_time: float
|
||||
elapsed time for transcription
|
||||
"""
|
||||
start_time = time.time()
|
||||
params = WhisperParameters.as_value(*whisper_params)
|
||||
params = WhisperParams.from_list(list(whisper_params))
|
||||
|
||||
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
||||
self.update_model(params.model_size, params.compute_type, progress)
|
||||
@@ -59,21 +59,28 @@ class WhisperInference(WhisperBase):
|
||||
def progress_callback(progress_value):
|
||||
progress(progress_value, desc="Transcribing..")
|
||||
|
||||
segments_result = self.model.transcribe(audio=audio,
|
||||
language=params.lang,
|
||||
verbose=False,
|
||||
beam_size=params.beam_size,
|
||||
logprob_threshold=params.log_prob_threshold,
|
||||
no_speech_threshold=params.no_speech_threshold,
|
||||
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
|
||||
fp16=True if params.compute_type == "float16" else False,
|
||||
best_of=params.best_of,
|
||||
patience=params.patience,
|
||||
temperature=params.temperature,
|
||||
compression_ratio_threshold=params.compression_ratio_threshold,
|
||||
progress_callback=progress_callback,)["segments"]
|
||||
elapsed_time = time.time() - start_time
|
||||
result = self.model.transcribe(audio=audio,
|
||||
language=params.lang,
|
||||
verbose=False,
|
||||
beam_size=params.beam_size,
|
||||
logprob_threshold=params.log_prob_threshold,
|
||||
no_speech_threshold=params.no_speech_threshold,
|
||||
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
|
||||
fp16=True if params.compute_type == "float16" else False,
|
||||
best_of=params.best_of,
|
||||
patience=params.patience,
|
||||
temperature=params.temperature,
|
||||
compression_ratio_threshold=params.compression_ratio_threshold,
|
||||
progress_callback=progress_callback,)["segments"]
|
||||
segments_result = []
|
||||
for segment in result:
|
||||
segments_result.append(Segment(
|
||||
start=segment["start"],
|
||||
end=segment["end"],
|
||||
text=segment["text"]
|
||||
))
|
||||
|
||||
elapsed_time = time.time() - start_time
|
||||
return segments_result, elapsed_time
|
||||
|
||||
def update_model(self,
|
||||
|
||||
@@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
|
||||
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
||||
from modules.whisper.whisper_Inference import WhisperInference
|
||||
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
||||
from modules.whisper.whisper_base import WhisperBase
|
||||
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
|
||||
from modules.whisper.data_classes import *
|
||||
|
||||
|
||||
class WhisperFactory:
|
||||
@@ -19,7 +20,7 @@ class WhisperFactory:
|
||||
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
|
||||
uvr_model_dir: str = UVR_MODELS_DIR,
|
||||
output_dir: str = OUTPUT_DIR,
|
||||
) -> "WhisperBase":
|
||||
) -> "BaseTranscriptionPipeline":
|
||||
"""
|
||||
Create a whisper inference class based on the provided whisper_type.
|
||||
|
||||
@@ -45,36 +46,29 @@ class WhisperFactory:
|
||||
|
||||
Returns
|
||||
-------
|
||||
WhisperBase
|
||||
BaseTranscriptionPipeline
|
||||
An instance of the appropriate whisper inference class based on the whisper_type.
|
||||
"""
|
||||
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|
||||
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
||||
|
||||
whisper_type = whisper_type.lower().strip()
|
||||
whisper_type = whisper_type.strip().lower()
|
||||
|
||||
faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
|
||||
whisper_typos = ["whisper"]
|
||||
insanely_fast_whisper_typos = [
|
||||
"insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
|
||||
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
|
||||
]
|
||||
|
||||
if whisper_type in faster_whisper_typos:
|
||||
if whisper_type == WhisperImpl.FASTER_WHISPER.value:
|
||||
return FasterWhisperInference(
|
||||
model_dir=faster_whisper_model_dir,
|
||||
output_dir=output_dir,
|
||||
diarization_model_dir=diarization_model_dir,
|
||||
uvr_model_dir=uvr_model_dir
|
||||
)
|
||||
elif whisper_type in whisper_typos:
|
||||
elif whisper_type == WhisperImpl.WHISPER.value:
|
||||
return WhisperInference(
|
||||
model_dir=whisper_model_dir,
|
||||
output_dir=output_dir,
|
||||
diarization_model_dir=diarization_model_dir,
|
||||
uvr_model_dir=uvr_model_dir
|
||||
)
|
||||
elif whisper_type in insanely_fast_whisper_typos:
|
||||
elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value:
|
||||
return InsanelyFastWhisperInference(
|
||||
model_dir=insanely_fast_whisper_model_dir,
|
||||
output_dir=output_dir,
|
||||
|
||||
@@ -1,371 +0,0 @@
|
||||
from dataclasses import dataclass, fields
|
||||
import gradio as gr
|
||||
from typing import Optional, Dict
|
||||
import yaml
|
||||
|
||||
from modules.utils.constants import AUTOMATIC_DETECTION
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhisperParameters:
|
||||
model_size: gr.Dropdown
|
||||
lang: gr.Dropdown
|
||||
is_translate: gr.Checkbox
|
||||
beam_size: gr.Number
|
||||
log_prob_threshold: gr.Number
|
||||
no_speech_threshold: gr.Number
|
||||
compute_type: gr.Dropdown
|
||||
best_of: gr.Number
|
||||
patience: gr.Number
|
||||
condition_on_previous_text: gr.Checkbox
|
||||
prompt_reset_on_temperature: gr.Slider
|
||||
initial_prompt: gr.Textbox
|
||||
temperature: gr.Slider
|
||||
compression_ratio_threshold: gr.Number
|
||||
vad_filter: gr.Checkbox
|
||||
threshold: gr.Slider
|
||||
min_speech_duration_ms: gr.Number
|
||||
max_speech_duration_s: gr.Number
|
||||
min_silence_duration_ms: gr.Number
|
||||
speech_pad_ms: gr.Number
|
||||
batch_size: gr.Number
|
||||
is_diarize: gr.Checkbox
|
||||
hf_token: gr.Textbox
|
||||
diarization_device: gr.Dropdown
|
||||
length_penalty: gr.Number
|
||||
repetition_penalty: gr.Number
|
||||
no_repeat_ngram_size: gr.Number
|
||||
prefix: gr.Textbox
|
||||
suppress_blank: gr.Checkbox
|
||||
suppress_tokens: gr.Textbox
|
||||
max_initial_timestamp: gr.Number
|
||||
word_timestamps: gr.Checkbox
|
||||
prepend_punctuations: gr.Textbox
|
||||
append_punctuations: gr.Textbox
|
||||
max_new_tokens: gr.Number
|
||||
chunk_length: gr.Number
|
||||
hallucination_silence_threshold: gr.Number
|
||||
hotwords: gr.Textbox
|
||||
language_detection_threshold: gr.Number
|
||||
language_detection_segments: gr.Number
|
||||
is_bgm_separate: gr.Checkbox
|
||||
uvr_model_size: gr.Dropdown
|
||||
uvr_device: gr.Dropdown
|
||||
uvr_segment_size: gr.Number
|
||||
uvr_save_file: gr.Checkbox
|
||||
uvr_enable_offload: gr.Checkbox
|
||||
"""
|
||||
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
|
||||
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
|
||||
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
|
||||
See more about Gradio pre-processing: https://www.gradio.app/docs/components
|
||||
|
||||
Attributes
|
||||
----------
|
||||
model_size: gr.Dropdown
|
||||
Whisper model size.
|
||||
|
||||
lang: gr.Dropdown
|
||||
Source language of the file to transcribe.
|
||||
|
||||
is_translate: gr.Checkbox
|
||||
Boolean value that determines whether to translate to English.
|
||||
It's Whisper's feature to translate speech from another language directly into English end-to-end.
|
||||
|
||||
beam_size: gr.Number
|
||||
Int value that is used for decoding option.
|
||||
|
||||
log_prob_threshold: gr.Number
|
||||
If the average log probability over sampled tokens is below this value, treat as failed.
|
||||
|
||||
no_speech_threshold: gr.Number
|
||||
If the no_speech probability is higher than this value AND
|
||||
the average log probability over sampled tokens is below `log_prob_threshold`,
|
||||
consider the segment as silent.
|
||||
|
||||
compute_type: gr.Dropdown
|
||||
compute type for transcription.
|
||||
see more info : https://opennmt.net/CTranslate2/quantization.html
|
||||
|
||||
best_of: gr.Number
|
||||
Number of candidates when sampling with non-zero temperature.
|
||||
|
||||
patience: gr.Number
|
||||
Beam search patience factor.
|
||||
|
||||
condition_on_previous_text: gr.Checkbox
|
||||
if True, the previous output of the model is provided as a prompt for the next window;
|
||||
disabling may make the text inconsistent across windows, but the model becomes less prone to
|
||||
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
|
||||
|
||||
initial_prompt: gr.Textbox
|
||||
Optional text to provide as a prompt for the first window. This can be used to provide, or
|
||||
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
|
||||
to make it more likely to predict those word correctly.
|
||||
|
||||
temperature: gr.Slider
|
||||
Temperature for sampling. It can be a tuple of temperatures,
|
||||
which will be successively used upon failures according to either
|
||||
`compression_ratio_threshold` or `log_prob_threshold`.
|
||||
|
||||
compression_ratio_threshold: gr.Number
|
||||
If the gzip compression ratio is above this value, treat as failed
|
||||
|
||||
vad_filter: gr.Checkbox
|
||||
Enable the voice activity detection (VAD) to filter out parts of the audio
|
||||
without speech. This step is using the Silero VAD model
|
||||
https://github.com/snakers4/silero-vad.
|
||||
|
||||
threshold: gr.Slider
|
||||
This parameter is related with Silero VAD. Speech threshold.
|
||||
Silero VAD outputs speech probabilities for each audio chunk,
|
||||
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
|
||||
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
||||
|
||||
min_speech_duration_ms: gr.Number
|
||||
This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
|
||||
|
||||
max_speech_duration_s: gr.Number
|
||||
This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
|
||||
than max_speech_duration_s will be split at the timestamp of the last silence that
|
||||
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
|
||||
split aggressively just before max_speech_duration_s.
|
||||
|
||||
min_silence_duration_ms: gr.Number
|
||||
This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
|
||||
before separating it
|
||||
|
||||
speech_pad_ms: gr.Number
|
||||
This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
|
||||
|
||||
batch_size: gr.Number
|
||||
This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
|
||||
|
||||
is_diarize: gr.Checkbox
|
||||
This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
|
||||
|
||||
hf_token: gr.Textbox
|
||||
This parameter is related with whisperx. Huggingface token is needed to download diarization models.
|
||||
Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
|
||||
|
||||
diarization_device: gr.Dropdown
|
||||
This parameter is related with whisperx. Device to run diarization model
|
||||
|
||||
length_penalty: gr.Number
|
||||
This parameter is related to faster-whisper. Exponential length penalty constant.
|
||||
|
||||
repetition_penalty: gr.Number
|
||||
This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
|
||||
(set > 1 to penalize).
|
||||
|
||||
no_repeat_ngram_size: gr.Number
|
||||
This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
|
||||
|
||||
prefix: gr.Textbox
|
||||
This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
|
||||
|
||||
suppress_blank: gr.Checkbox
|
||||
This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
|
||||
|
||||
suppress_tokens: gr.Textbox
|
||||
This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
|
||||
of symbols as defined in the model config.json file.
|
||||
|
||||
max_initial_timestamp: gr.Number
|
||||
This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
|
||||
|
||||
word_timestamps: gr.Checkbox
|
||||
This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
|
||||
and dynamic time warping, and include the timestamps for each word in each segment.
|
||||
|
||||
prepend_punctuations: gr.Textbox
|
||||
This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
|
||||
with the next word.
|
||||
|
||||
append_punctuations: gr.Textbox
|
||||
This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
|
||||
with the previous word.
|
||||
|
||||
max_new_tokens: gr.Number
|
||||
This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
|
||||
the maximum will be set by the default max_length.
|
||||
|
||||
chunk_length: gr.Number
|
||||
This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
|
||||
If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
|
||||
|
||||
hallucination_silence_threshold: gr.Number
|
||||
This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
|
||||
(in seconds) when a possible hallucination is detected.
|
||||
|
||||
hotwords: gr.Textbox
|
||||
This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
|
||||
|
||||
language_detection_threshold: gr.Number
|
||||
This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
|
||||
|
||||
language_detection_segments: gr.Number
|
||||
This parameter is related to faster-whisper. Number of segments to consider for the language detection.
|
||||
|
||||
is_separate_bgm: gr.Checkbox
|
||||
This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.
|
||||
|
||||
uvr_model_size: gr.Dropdown
|
||||
This parameter is related to UVR. UVR model size.
|
||||
|
||||
uvr_device: gr.Dropdown
|
||||
This parameter is related to UVR. Device to run UVR model.
|
||||
|
||||
uvr_segment_size: gr.Number
|
||||
This parameter is related to UVR. Segment size for UVR model.
|
||||
|
||||
uvr_save_file: gr.Checkbox
|
||||
This parameter is related to UVR. Boolean value that determines whether to save the file or not.
|
||||
|
||||
uvr_enable_offload: gr.Checkbox
|
||||
This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
|
||||
after each transcription.
|
||||
"""
|
||||
|
||||
def as_list(self) -> list:
|
||||
"""
|
||||
Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
|
||||
See more about Gradio pre-processing: : https://www.gradio.app/docs/components
|
||||
|
||||
Returns
|
||||
----------
|
||||
A list of Gradio components
|
||||
"""
|
||||
return [getattr(self, f.name) for f in fields(self)]
|
||||
|
||||
@staticmethod
|
||||
def as_value(*args) -> 'WhisperValues':
|
||||
"""
|
||||
To use Whisper parameters in function after Gradio post-processing.
|
||||
See more about Gradio post-processing: : https://www.gradio.app/docs/components
|
||||
|
||||
Returns
|
||||
----------
|
||||
WhisperValues
|
||||
Data class that has values of parameters
|
||||
"""
|
||||
return WhisperValues(*args)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhisperValues:
|
||||
model_size: str = "large-v2"
|
||||
lang: Optional[str] = None
|
||||
is_translate: bool = False
|
||||
beam_size: int = 5
|
||||
log_prob_threshold: float = -1.0
|
||||
no_speech_threshold: float = 0.6
|
||||
compute_type: str = "float16"
|
||||
best_of: int = 5
|
||||
patience: float = 1.0
|
||||
condition_on_previous_text: bool = True
|
||||
prompt_reset_on_temperature: float = 0.5
|
||||
initial_prompt: Optional[str] = None
|
||||
temperature: float = 0.0
|
||||
compression_ratio_threshold: float = 2.4
|
||||
vad_filter: bool = False
|
||||
threshold: float = 0.5
|
||||
min_speech_duration_ms: int = 250
|
||||
max_speech_duration_s: float = float("inf")
|
||||
min_silence_duration_ms: int = 2000
|
||||
speech_pad_ms: int = 400
|
||||
batch_size: int = 24
|
||||
is_diarize: bool = False
|
||||
hf_token: str = ""
|
||||
diarization_device: str = "cuda"
|
||||
length_penalty: float = 1.0
|
||||
repetition_penalty: float = 1.0
|
||||
no_repeat_ngram_size: int = 0
|
||||
prefix: Optional[str] = None
|
||||
suppress_blank: bool = True
|
||||
suppress_tokens: Optional[str] = "[-1]"
|
||||
max_initial_timestamp: float = 0.0
|
||||
word_timestamps: bool = False
|
||||
prepend_punctuations: Optional[str] = "\"'“¿([{-"
|
||||
append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
|
||||
max_new_tokens: Optional[int] = None
|
||||
chunk_length: Optional[int] = 30
|
||||
hallucination_silence_threshold: Optional[float] = None
|
||||
hotwords: Optional[str] = None
|
||||
language_detection_threshold: Optional[float] = None
|
||||
language_detection_segments: int = 1
|
||||
is_bgm_separate: bool = False
|
||||
uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
|
||||
uvr_device: str = "cuda"
|
||||
uvr_segment_size: int = 256
|
||||
uvr_save_file: bool = False
|
||||
uvr_enable_offload: bool = True
|
||||
"""
|
||||
A data class to use Whisper parameters.
|
||||
"""
|
||||
|
||||
def to_yaml(self) -> Dict:
|
||||
data = {
|
||||
"whisper": {
|
||||
"model_size": self.model_size,
|
||||
"lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
|
||||
"is_translate": self.is_translate,
|
||||
"beam_size": self.beam_size,
|
||||
"log_prob_threshold": self.log_prob_threshold,
|
||||
"no_speech_threshold": self.no_speech_threshold,
|
||||
"best_of": self.best_of,
|
||||
"patience": self.patience,
|
||||
"condition_on_previous_text": self.condition_on_previous_text,
|
||||
"prompt_reset_on_temperature": self.prompt_reset_on_temperature,
|
||||
"initial_prompt": None if not self.initial_prompt else self.initial_prompt,
|
||||
"temperature": self.temperature,
|
||||
"compression_ratio_threshold": self.compression_ratio_threshold,
|
||||
"batch_size": self.batch_size,
|
||||
"length_penalty": self.length_penalty,
|
||||
"repetition_penalty": self.repetition_penalty,
|
||||
"no_repeat_ngram_size": self.no_repeat_ngram_size,
|
||||
"prefix": None if not self.prefix else self.prefix,
|
||||
"suppress_blank": self.suppress_blank,
|
||||
"suppress_tokens": self.suppress_tokens,
|
||||
"max_initial_timestamp": self.max_initial_timestamp,
|
||||
"word_timestamps": self.word_timestamps,
|
||||
"prepend_punctuations": self.prepend_punctuations,
|
||||
"append_punctuations": self.append_punctuations,
|
||||
"max_new_tokens": self.max_new_tokens,
|
||||
"chunk_length": self.chunk_length,
|
||||
"hallucination_silence_threshold": self.hallucination_silence_threshold,
|
||||
"hotwords": None if not self.hotwords else self.hotwords,
|
||||
"language_detection_threshold": self.language_detection_threshold,
|
||||
"language_detection_segments": self.language_detection_segments,
|
||||
},
|
||||
"vad": {
|
||||
"vad_filter": self.vad_filter,
|
||||
"threshold": self.threshold,
|
||||
"min_speech_duration_ms": self.min_speech_duration_ms,
|
||||
"max_speech_duration_s": self.max_speech_duration_s,
|
||||
"min_silence_duration_ms": self.min_silence_duration_ms,
|
||||
"speech_pad_ms": self.speech_pad_ms,
|
||||
},
|
||||
"diarization": {
|
||||
"is_diarize": self.is_diarize,
|
||||
"hf_token": self.hf_token
|
||||
},
|
||||
"bgm_separation": {
|
||||
"is_separate_bgm": self.is_bgm_separate,
|
||||
"model_size": self.uvr_model_size,
|
||||
"segment_size": self.uvr_segment_size,
|
||||
"save_file": self.uvr_save_file,
|
||||
"enable_offload": self.uvr_enable_offload
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
def as_list(self) -> list:
|
||||
"""
|
||||
Converts the data class attributes into a list
|
||||
|
||||
Returns
|
||||
----------
|
||||
A list of Whisper parameters
|
||||
"""
|
||||
return [getattr(self, f.name) for f in fields(self)]
|
||||
@@ -56,7 +56,7 @@
|
||||
"!pip install faster-whisper==1.0.3\n",
|
||||
"!pip install ctranslate2==4.4.0\n",
|
||||
"!pip install gradio\n",
|
||||
"!pip install git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error\n",
|
||||
"!pip install gradio-i18n\n",
|
||||
"# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n",
|
||||
"!pip install git+https://github.com/JuanBindez/pytubefix.git\n",
|
||||
"!pip install tokenizers==0.19.1\n",
|
||||
@@ -99,7 +99,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 3,
|
||||
"metadata": {
|
||||
"id": "PQroYRRZzQiN",
|
||||
"cellView": "form"
|
||||
|
||||
@@ -11,7 +11,7 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
|
||||
faster-whisper==1.0.3
|
||||
transformers
|
||||
gradio
|
||||
git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error
|
||||
gradio-i18n
|
||||
pytubefix
|
||||
ruamel.yaml==0.18.6
|
||||
pyannote.audio==3.3.1
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from modules.utils.paths import *
|
||||
from modules.whisper.whisper_factory import WhisperFactory
|
||||
from modules.whisper.whisper_parameter import WhisperValues
|
||||
from modules.whisper.data_classes import *
|
||||
from test_config import *
|
||||
from test_transcription import download_file, test_transcribe
|
||||
|
||||
@@ -17,9 +17,9 @@ import os
|
||||
@pytest.mark.parametrize(
|
||||
"whisper_type,vad_filter,bgm_separation,diarization",
|
||||
[
|
||||
("whisper", False, True, False),
|
||||
("faster-whisper", False, True, False),
|
||||
("insanely_fast_whisper", False, True, False)
|
||||
(WhisperImpl.WHISPER.value, False, True, False),
|
||||
(WhisperImpl.FASTER_WHISPER.value, False, True, False),
|
||||
(WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
|
||||
]
|
||||
)
|
||||
def test_bgm_separation_pipeline(
|
||||
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
|
||||
@pytest.mark.parametrize(
|
||||
"whisper_type,vad_filter,bgm_separation,diarization",
|
||||
[
|
||||
("whisper", True, True, False),
|
||||
("faster-whisper", True, True, False),
|
||||
("insanely_fast_whisper", True, True, False)
|
||||
(WhisperImpl.WHISPER.value, True, True, False),
|
||||
(WhisperImpl.FASTER_WHISPER.value, True, True, False),
|
||||
(WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
|
||||
]
|
||||
)
|
||||
def test_bgm_separation_with_vad_pipeline(
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
from modules.utils.paths import *
|
||||
|
||||
import functools
|
||||
import jiwer
|
||||
import os
|
||||
import torch
|
||||
|
||||
from modules.utils.paths import *
|
||||
from modules.utils.youtube_manager import *
|
||||
|
||||
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
|
||||
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
|
||||
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
|
||||
TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
|
||||
TEST_WHISPER_MODEL = "tiny"
|
||||
TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
|
||||
@@ -13,5 +17,24 @@ TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
|
||||
TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_cuda_available():
|
||||
return torch.cuda.is_available()
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL):
|
||||
try:
|
||||
yt_temp_path = os.path.join("modules", "yt_tmp.wav")
|
||||
if os.path.exists(yt_temp_path):
|
||||
return False
|
||||
yt = get_ytdata(url)
|
||||
audio = get_ytaudio(yt)
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"Pytube has detected as a bot: {e}")
|
||||
return True
|
||||
|
||||
|
||||
def calculate_wer(answer, prediction):
|
||||
return jiwer.wer(answer, prediction)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from modules.utils.paths import *
|
||||
from modules.whisper.whisper_factory import WhisperFactory
|
||||
from modules.whisper.whisper_parameter import WhisperValues
|
||||
from modules.whisper.data_classes import *
|
||||
from test_config import *
|
||||
from test_transcription import download_file, test_transcribe
|
||||
|
||||
@@ -16,9 +16,9 @@ import os
|
||||
@pytest.mark.parametrize(
|
||||
"whisper_type,vad_filter,bgm_separation,diarization",
|
||||
[
|
||||
("whisper", False, False, True),
|
||||
("faster-whisper", False, False, True),
|
||||
("insanely_fast_whisper", False, False, True)
|
||||
(WhisperImpl.WHISPER.value, False, False, True),
|
||||
(WhisperImpl.FASTER_WHISPER.value, False, False, True),
|
||||
(WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
|
||||
]
|
||||
)
|
||||
def test_diarization_pipeline(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from modules.whisper.whisper_factory import WhisperFactory
|
||||
from modules.whisper.whisper_parameter import WhisperValues
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.utils.subtitle_manager import read_file
|
||||
from modules.utils.paths import WEBUI_DIR
|
||||
from test_config import *
|
||||
|
||||
@@ -12,9 +13,9 @@ import os
|
||||
@pytest.mark.parametrize(
|
||||
"whisper_type,vad_filter,bgm_separation,diarization",
|
||||
[
|
||||
("whisper", False, False, False),
|
||||
("faster-whisper", False, False, False),
|
||||
("insanely_fast_whisper", False, False, False)
|
||||
(WhisperImpl.WHISPER.value, False, False, False),
|
||||
(WhisperImpl.FASTER_WHISPER.value, False, False, False),
|
||||
(WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
|
||||
]
|
||||
)
|
||||
def test_transcribe(
|
||||
@@ -28,6 +29,10 @@ def test_transcribe(
|
||||
if not os.path.exists(audio_path):
|
||||
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
|
||||
|
||||
answer = TEST_ANSWER
|
||||
if diarization:
|
||||
answer = "SPEAKER_00|"+TEST_ANSWER
|
||||
|
||||
whisper_inferencer = WhisperFactory.create_whisper_inference(
|
||||
whisper_type=whisper_type,
|
||||
)
|
||||
@@ -37,16 +42,24 @@ def test_transcribe(
|
||||
f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
|
||||
)
|
||||
|
||||
hparams = WhisperValues(
|
||||
model_size=TEST_WHISPER_MODEL,
|
||||
vad_filter=vad_filter,
|
||||
is_bgm_separate=bgm_separation,
|
||||
compute_type=whisper_inferencer.current_compute_type,
|
||||
uvr_enable_offload=True,
|
||||
is_diarize=diarization,
|
||||
).as_list()
|
||||
hparams = TranscriptionPipelineParams(
|
||||
whisper=WhisperParams(
|
||||
model_size=TEST_WHISPER_MODEL,
|
||||
compute_type=whisper_inferencer.current_compute_type
|
||||
),
|
||||
vad=VadParams(
|
||||
vad_filter=vad_filter
|
||||
),
|
||||
bgm_separation=BGMSeparationParams(
|
||||
is_separate_bgm=bgm_separation,
|
||||
enable_offload=True
|
||||
),
|
||||
diarization=DiarizationParams(
|
||||
is_diarize=diarization
|
||||
),
|
||||
).to_list()
|
||||
|
||||
subtitle_str, file_path = whisper_inferencer.transcribe_file(
|
||||
subtitle_str, file_paths = whisper_inferencer.transcribe_file(
|
||||
[audio_path],
|
||||
None,
|
||||
"SRT",
|
||||
@@ -54,29 +67,29 @@ def test_transcribe(
|
||||
gr.Progress(),
|
||||
*hparams,
|
||||
)
|
||||
subtitle = read_file(file_paths[0]).split("\n")
|
||||
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
|
||||
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert isinstance(file_path[0], str) and file_path
|
||||
if not is_pytube_detected_bot():
|
||||
subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
|
||||
TEST_YOUTUBE_URL,
|
||||
"SRT",
|
||||
False,
|
||||
gr.Progress(),
|
||||
*hparams,
|
||||
)
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert os.path.exists(file_path)
|
||||
|
||||
whisper_inferencer.transcribe_youtube(
|
||||
TEST_YOUTUBE_URL,
|
||||
"SRT",
|
||||
False,
|
||||
gr.Progress(),
|
||||
*hparams,
|
||||
)
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert isinstance(file_path[0], str) and file_path
|
||||
|
||||
whisper_inferencer.transcribe_mic(
|
||||
subtitle_str, file_path = whisper_inferencer.transcribe_mic(
|
||||
audio_path,
|
||||
"SRT",
|
||||
False,
|
||||
gr.Progress(),
|
||||
*hparams,
|
||||
)
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert isinstance(file_path[0], str) and file_path
|
||||
subtitle = read_file(file_path).split("\n")
|
||||
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
|
||||
|
||||
|
||||
def download_file(url, save_dir):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from modules.utils.paths import *
|
||||
from modules.whisper.whisper_factory import WhisperFactory
|
||||
from modules.whisper.whisper_parameter import WhisperValues
|
||||
from modules.whisper.data_classes import *
|
||||
from test_config import *
|
||||
from test_transcription import download_file, test_transcribe
|
||||
|
||||
@@ -12,9 +12,9 @@ import os
|
||||
@pytest.mark.parametrize(
|
||||
"whisper_type,vad_filter,bgm_separation,diarization",
|
||||
[
|
||||
("whisper", True, False, False),
|
||||
("faster-whisper", True, False, False),
|
||||
("insanely_fast_whisper", True, False, False)
|
||||
(WhisperImpl.WHISPER.value, True, False, False),
|
||||
(WhisperImpl.FASTER_WHISPER.value, True, False, False),
|
||||
(WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
|
||||
]
|
||||
)
|
||||
def test_vad_pipeline(
|
||||
|
||||
Reference in New Issue
Block a user