Merge branch 'master' into salfter

2024-11-04 13:15:50 -08:00
27 changed files with 1476 additions and 969 deletions

View File

@@ -37,7 +37,7 @@ jobs:
run: sudo apt-get update && sudo apt-get install -y git ffmpeg
- name: Install dependencies
run: pip install -r requirements.txt pytest
run: pip install -r requirements.txt pytest jiwer
- name: Run test
run: python -m pytest -rs tests
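Since `jiwer` joins the test dependencies here, the CI run presumably checks transcription accuracy via word error rate. A minimal sketch of such an assertion (the strings and threshold are illustrative, not taken from the repository's tests):

```python
# Minimal WER check with jiwer; the reference/hypothesis strings and the 0.25
# threshold are illustrative assumptions, not values from the actual tests.
import jiwer

reference = "this is a test recording"
hypothesis = "this is test recording"            # e.g. output of the pipeline under test

assert jiwer.wer(reference, hypothesis) <= 0.25  # one deletion in five words -> WER 0.2
```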

View File

@@ -128,3 +128,6 @@ This is Whisper's original VRAM usage table for models.
- [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
- [ ] Add fast api script
- [ ] Support real-time transcription for microphone
### Translation 🌐
Any PRs adding translations for Japanese, Spanish, French, German, Chinese, or any other language to [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!

app.py
View File

@@ -7,17 +7,14 @@ import yaml
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR, I18N_YAML_PATH)
from modules.utils.constants import AUTOMATIC_DETECTION
from modules.utils.files_manager import load_yaml
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.translation.nllb_inference import NLLBInference
from modules.ui.htmls import *
from modules.utils.cli_manager import str2bool
from modules.utils.youtube_manager import get_ytmetas
from modules.translation.deepl_api import DeepLAPI
from modules.whisper.whisper_parameter import *
from modules.whisper.data_classes import *
class App:
@@ -44,7 +41,7 @@ class App:
print(f"Use \"{self.args.whisper_type}\" implementation\n"
f"Device \"{self.whisper_inf.device}\" is detected")
def create_whisper_parameters(self):
def create_pipeline_inputs(self):
whisper_params = self.default_params["whisper"]
vad_params = self.default_params["vad"]
diarization_params = self.default_params["diarization"]
@@ -56,7 +53,7 @@ class App:
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
else whisper_params["lang"], label=_("Language"))
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
with gr.Row():
cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
interactive=True)
@@ -66,158 +63,31 @@ class App:
interactive=True)
with gr.Accordion(_("Advanced Parameters"), open=False):
nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0,
interactive=True,
info="Beam size to use for decoding.")
nb_log_prob_threshold = gr.Number(label="Log Probability Threshold",
value=whisper_params["log_prob_threshold"], interactive=True,
info="If the average log probability over sampled tokens is below this value, treat as failed.")
nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"],
interactive=True,
info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
value=self.whisper_inf.current_compute_type, interactive=True,
allow_custom_value=True,
info="Select the type of computation to perform.")
nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
info="Number of candidates when sampling with non-zero temperature.")
nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
info="Beam search patience factor.")
cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text",
value=whisper_params["condition_on_previous_text"],
interactive=True,
info="Condition on previous text during decoding.")
sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature",
value=whisper_params["prompt_reset_on_temperature"],
minimum=0, maximum=1, step=0.01, interactive=True,
info="Resets prompt if temperature is above this value."
" Arg has effect only if 'Condition On Previous Text' is True.")
tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
info="Initial prompt to use for decoding.")
sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
step=0.01, maximum=1.0, interactive=True,
info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold",
value=whisper_params["compression_ratio_threshold"],
interactive=True,
info="If the gzip compression ratio is above this value, treat as failed.")
nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
precision=0,
info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
info="Exponential length penalty constant.")
nb_repetition_penalty = gr.Number(label="Repetition Penalty",
value=whisper_params["repetition_penalty"],
info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size",
value=whisper_params["no_repeat_ngram_size"],
precision=0,
info="Prevent repetitions of n-grams with this size (set 0 to disable).")
tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
info="Optional text to provide as a prefix for the first window.")
cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
info="Suppress blank outputs at the beginning of the sampling.")
tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp",
value=whisper_params["max_initial_timestamp"],
info="The initial timestamp cannot be later than this.")
cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations",
value=whisper_params["prepend_punctuations"],
info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
tb_append_punctuations = gr.Textbox(label="Append Punctuations",
value=whisper_params["append_punctuations"],
info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
precision=0,
info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
value=lambda: whisper_params[
"hallucination_silence_threshold"],
info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
nb_language_detection_threshold = gr.Number(label="Language Detection Threshold",
value=lambda: whisper_params[
"language_detection_threshold"],
info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
nb_language_detection_segments = gr.Number(label="Language Detection Segments",
value=lambda: whisper_params["language_detection_segments"],
precision=0,
info="Number of segments to consider for the language detection.")
with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
whisper_type=self.args.whisper_type,
available_compute_types=self.whisper_inf.available_compute_types,
compute_type=self.whisper_inf.current_compute_type)
with gr.Accordion(_("Background Music Remover Filter"), open=False):
cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"),
value=uvr_params["is_separate_bgm"],
interactive=True,
info=_("Enabling this will remove background music"))
dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
choices=self.whisper_inf.music_separator.available_devices)
dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"],
choices=self.whisper_inf.music_separator.available_models)
nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"])
cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"),
value=uvr_params["enable_offload"])
uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
available_models=self.whisper_inf.music_separator.available_models,
available_devices=self.whisper_inf.music_separator.available_devices,
device=self.whisper_inf.music_separator.device)
with gr.Accordion(_("Voice Detection Filter"), open=False):
cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"],
interactive=True,
info=_("Enable this to transcribe only detected voice"))
sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
value=vad_params["threshold"],
info="Lower it to be more sensitive to small sounds.")
nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
value=vad_params["min_speech_duration_ms"],
info="Final speech chunks shorter than this time are thrown out")
nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
value=vad_params["max_speech_duration_s"],
info="Maximum duration of speech chunks in \"seconds\".")
nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
value=vad_params["min_silence_duration_ms"],
info="In the end of each speech chunk wait for this time"
" before separating it")
nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
info="Final speech chunks are padded by this time each side")
vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)
with gr.Accordion(_("Diarization"), open=False):
cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"])
tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"],
info=_("This is only needed the first time you download the model"))
dd_diarization_device = gr.Dropdown(label=_("Device"),
choices=self.whisper_inf.diarizer.get_available_device(),
value=self.whisper_inf.diarizer.get_device())
diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
available_devices=self.whisper_inf.diarizer.available_device,
device=self.whisper_inf.diarizer.device)
dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
return (
WhisperParameters(
model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
language_detection_threshold=nb_language_detection_threshold,
language_detection_segments=nb_language_detection_segments,
prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
),
pipeline_inputs,
dd_file_format,
cb_timestamp
)
@@ -243,7 +113,7 @@ class App:
visible=self.args.colab,
value="")
whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
with gr.Row():
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -254,7 +124,7 @@ class App:
params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
btn_run.click(fn=self.whisper_inf.transcribe_file,
inputs=params + whisper_params.as_list(),
inputs=params + pipeline_params,
outputs=[tb_indicator, files_subtitles])
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
@@ -268,7 +138,7 @@ class App:
tb_title = gr.Label(label=_("Youtube Title"))
tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
with gr.Row():
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -280,7 +150,7 @@ class App:
params = [tb_youtubelink, dd_file_format, cb_timestamp]
btn_run.click(fn=self.whisper_inf.transcribe_youtube,
inputs=params + whisper_params.as_list(),
inputs=params + pipeline_params,
outputs=[tb_indicator, files_subtitles])
tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
outputs=[img_thumbnail, tb_title, tb_description])
@@ -290,7 +160,7 @@ class App:
with gr.Row():
mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)
whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
with gr.Row():
btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -302,7 +172,7 @@ class App:
params = [mic_input, dd_file_format, cb_timestamp]
btn_run.click(fn=self.whisper_inf.transcribe_mic,
inputs=params + whisper_params.as_list(),
inputs=params + pipeline_params,
outputs=[tb_indicator, files_subtitles])
btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
@@ -417,7 +287,6 @@ class App:
# Launch the app with optional gradio settings
args = self.args
self.app.queue(
api_open=args.api_open
).launch(
@@ -447,8 +316,8 @@ class App:
parser = argparse.ArgumentParser()
parser.add_argument('--whisper_type', type=str, default="faster-whisper",
choices=["whisper", "faster-whisper", "insanely-fast-whisper"],
parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
choices=[item.value for item in WhisperImpl],
help='A type of the whisper implementation (Github repo name)')
parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
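With this change the hand-built wall of gr.* components is replaced by each parameter group emitting its own gradio inputs (`WhisperParams.to_gradio_inputs`, `VadParams.to_gradio_inputs`, `DiarizationParams.to_gradio_inputs`, `BGMSeparationParams.to_gradio_input`), concatenated into one flat `pipeline_inputs` list. A rough sketch of that round-trip pattern (field set trimmed; the method bodies below are assumptions about the pattern, not copies of `modules/whisper/data_classes.py`):

```python
# Rough sketch of the flat-list round trip between gradio and the param models.
# The field set is trimmed and the method bodies are assumed; see data_classes.py
# in this commit for the real definitions.
import gradio as gr
from pydantic import BaseModel


class WhisperParams(BaseModel):
    beam_size: int = 5
    best_of: int = 5

    @classmethod
    def to_gradio_inputs(cls, defaults: dict) -> list:
        # One gradio component per field, in declaration order.
        return [
            gr.Number(label="Beam Size", value=defaults["beam_size"], precision=0),
            gr.Number(label="Best Of", value=defaults["best_of"], precision=0),
        ]

    @classmethod
    def from_list(cls, values: list) -> "WhisperParams":
        # Gradio hands the component values back as a flat, ordered list.
        return cls(beam_size=int(values[0]), best_of=int(values[1]))
```

`btn_run.click(..., inputs=params + pipeline_params, ...)` then delivers those component values positionally, which is why the pipeline later rebuilds them with `TranscriptionPipelineParams.from_list(list(pipeline_params))`.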

View File

@@ -1,5 +1,6 @@
whisper:
model_size: "medium.en"
file_format: "SRT"
lang: "english"
is_translate: false
beam_size: 5

View File

@@ -411,3 +411,49 @@ ru: # Russian
Instrumental: Инструментал
Vocals: Вокал
SEPARATE BACKGROUND MUSIC: РАЗДЕЛИТЬ ФОНОВУЮ МУЗЫКУ
tr: # Turkish
Language: Dil
File: Dosya
Youtube: Youtube
Mic: Mikrofon
T2T Translation: T2T Çeviri
BGM Separation: Arka Plan Müziği Ayırma
GENERATE SUBTITLE FILE: ALTYAZI DOSYASI OLUŞTUR
Output: Çıktı
Downloadable output file: İndirilebilir çıktı dosyası
Upload File here: Dosya Yükle
Model: Model
Automatic Detection: Otomatik Algılama
File Format: Dosya Formatı
Translate to English?: İngilizceye Çevir?
Add a timestamp to the end of the filename: Dosya adının sonuna zaman damgası ekle
Advanced Parameters: Gelişmiş Parametreler
Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresi
Enabling this will remove background music: Bunu etkinleştirmek, arka plan müziğini alt model tarafından transkripsiyondan önce kaldıracaktır
Enable Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresini Etkinleştir
Save separated files to output: Ayrılmış dosyaları çıktıya kaydet
Offload sub model after removing background music: Arka plan müziği kaldırıldıktan sonra alt modeli devre dışı bırak
Voice Detection Filter: Ses Algılama Filtresi
Enable this to transcribe only detected voice: Bunu etkinleştirerek yalnızca alt model tarafından algılanan ses kısımlarını transkribe et
Enable Silero VAD Filter: Silero VAD Filtresini Etkinleştir
Diarization: Konuşmacı Ayrımı
Enable Diarization: Konuşmacı Ayrımını Etkinleştir
HuggingFace Token: HuggingFace Anahtarı
This is only needed the first time you download the model: Bu, modeli ilk kez indirirken gereklidir. Zaten modelleriniz varsa girmenize gerek yok. Modeli indirmek için "https://huggingface.co/pyannote/speaker-diarization-3.1" ve "https://huggingface.co/pyannote/segmentation-3.0" adreslerine gidip gereksinimlerini kabul etmeniz gerekiyor
Device: Cihaz
Youtube Link: Youtube Bağlantısı
Youtube Thumbnail: Youtube Küçük Resmi
Youtube Title: Youtube Başlığı
Youtube Description: Youtube Açıklaması
Record with Mic: Mikrofonla Kaydet
Upload Subtitle Files to translate here: Çeviri için altyazı dosyalarını buraya yükle
Your Auth Key (API KEY): Yetki Anahtarınız (API ANAHTARI)
Source Language: Kaynak Dil
Target Language: Hedef Dil
Pro User?: Pro Kullanıcı?
TRANSLATE SUBTITLE FILE: ALTYAZI DOSYASINI ÇEVİR
Upload Audio Files to separate background music: Arka plan müziğini ayırmak için ses dosyalarını yükle
Instrumental: Enstrümantal
Vocals: Vokal
SEPARATE BACKGROUND MUSIC: ARKA PLAN MÜZİĞİNİ AYIR

View File

@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
from typing import Optional, Union
import torch
from modules.whisper.data_classes import *
from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
@@ -43,6 +44,8 @@ class DiarizationPipeline:
def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
transcript_segments = transcript_result["segments"]
if transcript_segments and isinstance(transcript_segments[0], Segment):
transcript_segments = [seg.model_dump() for seg in transcript_segments]
for seg in transcript_segments:
# assign speaker to segment (if any)
diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
@@ -63,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
seg["speaker"] = speaker
# assign speaker to words
if 'words' in seg:
if 'words' in seg and seg['words'] is not None:
for word in seg['words']:
if 'start' in word:
diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
@@ -85,10 +88,10 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
if word_speaker is not None:
word["speaker"] = word_speaker
return transcript_result
return {"segments": transcript_segments}
class Segment:
class DiarizationSegment:
def __init__(self, start, end, speaker=None):
self.start = start
self.end = end
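`assign_word_speakers` now accepts pydantic `Segment` models and converts them to plain dicts before the pandas/numpy intersection math. A minimal illustration of that conversion (the simplified model below only mirrors fields used elsewhere in this diff; the real class lives in `modules/whisper/data_classes.py`):

```python
# Minimal illustration: pydantic models convert to plain dicts via model_dump().
# This Segment is a trimmed stand-in for modules.whisper.data_classes.Segment.
from typing import Optional
from pydantic import BaseModel


class Segment(BaseModel):
    start: Optional[float] = None
    end: Optional[float] = None
    text: Optional[str] = None


segments = [Segment(start=0.0, end=1.5, text="hello")]
print([seg.model_dump() for seg in segments])
# [{'start': 0.0, 'end': 1.5, 'text': 'hello'}]
```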

View File

@@ -1,6 +1,6 @@
import os
import torch
from typing import List, Union, BinaryIO, Optional
from typing import List, Union, BinaryIO, Optional, Tuple
import numpy as np
import time
import logging
@@ -8,6 +8,7 @@ import logging
from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio
from modules.whisper.data_classes import *
class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:
def run(self,
audio: Union[str, BinaryIO, np.ndarray],
transcribed_result: List[dict],
transcribed_result: List[Segment],
use_auth_token: str,
device: Optional[str] = None
):
) -> Tuple[List[Segment], float]:
"""
Diarize transcribed result as a post-processing
@@ -34,7 +35,7 @@ class Diarizer:
----------
audio: Union[str, BinaryIO, np.ndarray]
Audio input. This can be file path or binary type.
transcribed_result: List[dict]
transcribed_result: List[Segment]
transcribed result through whisper.
use_auth_token: str
Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer:
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment that includes start, end timestamps and transcribed text
elapsed_time: float
elapsed time for running
"""
@@ -68,14 +69,20 @@ class Diarizer:
{"segments": transcribed_result}
)
segments_result = []
for segment in diarized_result["segments"]:
speaker = "None"
if "speaker" in segment:
speaker = segment["speaker"]
segment["text"] = speaker + "|" + segment["text"].strip()
diarized_text = speaker + "|" + segment["text"].strip()
segments_result.append(Segment(
start=segment["start"],
end=segment["end"],
text=diarized_text
))
elapsed_time = time.time() - start_time
return diarized_result["segments"], elapsed_time
return segments_result, elapsed_time
def update_pipe(self,
use_auth_token: str,

View File

@@ -139,37 +139,27 @@ class DeepLAPI:
)
files_info = {}
for fileobj in fileobjs:
file_path = fileobj
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
if file_ext == ".srt":
parsed_dicts = parse_srt(file_path=file_path)
elif file_ext == ".vtt":
parsed_dicts = parse_vtt(file_path=file_path)
for file_path in fileobjs:
file_name, file_ext = os.path.splitext(os.path.basename(file_path))
writer = get_writer(file_ext, self.output_dir)
segments = writer.to_segments(file_path)
batch_size = self.max_text_batch_size
for batch_start in range(0, len(parsed_dicts), batch_size):
batch_end = min(batch_start + batch_size, len(parsed_dicts))
sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
for batch_start in range(0, len(segments), batch_size):
progress(batch_start / len(segments), desc="Translating..")
sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
target_lang, is_pro)
for i, translated_text in enumerate(translated_texts):
parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
progress(batch_end / len(parsed_dicts), desc="Translating..")
segments[batch_start + i].text = translated_text["text"]
if file_ext == ".srt":
subtitle = get_serialized_srt(parsed_dicts)
elif file_ext == ".vtt":
subtitle = get_serialized_vtt(parsed_dicts)
if add_timestamp:
timestamp = datetime.now().strftime("%m%d%H%M%S")
file_name += f"-{timestamp}"
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
write_file(subtitle, output_path)
subtitle, output_path = generate_file(
output_dir=self.output_dir,
output_file_name=file_name,
output_format=file_ext,
result=segments,
add_timestamp=add_timestamp
)
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
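Both translation paths now share the same read-edit-write round trip: parse the uploaded subtitle with `get_writer(...).to_segments(...)`, mutate `segment.text`, and let `generate_file(...)` serialize and name the output. A sketch of that flow (the `translate_fn` callable and the output file name are placeholders, not repo code):

```python
# Sketch of the shared read -> translate -> write round trip.
# translate_fn and "translated" are placeholders, not repo code.
from modules.utils.subtitle_manager import get_writer, generate_file


def translate_subtitle(file_path: str, output_dir: str, translate_fn) -> str:
    writer = get_writer(".srt", output_dir)       # extension is normalized internally
    segments = writer.to_segments(file_path)      # parse into Segment objects
    for seg in segments:
        seg.text = translate_fn(seg.text)         # replace each segment's text in place
    _, out_path = generate_file(
        output_dir=output_dir,
        output_file_name="translated",
        output_format="srt",
        result=segments,
        add_timestamp=True,
    )
    return out_path
```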

View File

@@ -6,7 +6,7 @@ from typing import List
from datetime import datetime
import modules.translation.nllb_inference as nllb
from modules.whisper.whisper_parameter import *
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import *
from modules.utils.files_manager import load_yaml, save_yaml
from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
@@ -95,32 +95,22 @@ class TranslationBase(ABC):
files_info = {}
for fileobj in fileobjs:
file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
if file_ext == ".srt":
parsed_dicts = parse_srt(file_path=fileobj)
total_progress = len(parsed_dicts)
for index, dic in enumerate(parsed_dicts):
progress(index / total_progress, desc="Translating..")
translated_text = self.translate(dic["sentence"], max_length=max_length)
dic["sentence"] = translated_text
subtitle = get_serialized_srt(parsed_dicts)
writer = get_writer(file_ext, self.output_dir)
segments = writer.to_segments(fileobj)
for i, segment in enumerate(segments):
progress(i / len(segments), desc="Translating..")
translated_text = self.translate(segment.text, max_length=max_length)
segment.text = translated_text
elif file_ext == ".vtt":
parsed_dicts = parse_vtt(file_path=fileobj)
total_progress = len(parsed_dicts)
for index, dic in enumerate(parsed_dicts):
progress(index / total_progress, desc="Translating..")
translated_text = self.translate(dic["sentence"], max_length=max_length)
dic["sentence"] = translated_text
subtitle = get_serialized_vtt(parsed_dicts)
subtitle, file_path = generate_file(
output_dir=self.output_dir,
output_file_name=file_name,
output_format=file_ext,
result=segments,
add_timestamp=add_timestamp
)
if add_timestamp:
timestamp = datetime.now().strftime("%m%d%H%M%S")
file_name += f"-{timestamp}"
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
write_file(subtitle, output_path)
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
files_info[file_name] = {"subtitle": subtitle, "path": file_path}
total_result = ''
for file_name, info in files_info.items():
@@ -133,7 +123,8 @@ class TranslationBase(ABC):
return [gr_str, output_file_paths]
except Exception as e:
print(f"Error: {str(e)}")
print(f"Error translating file: {e}")
raise
finally:
self.release_cuda_memory()

View File

@@ -1,3 +1,6 @@
from gradio_i18n import Translate, gettext as _
AUTOMATIC_DETECTION = _("Automatic Detection")
GRADIO_NONE_STR = ""
GRADIO_NONE_NUMBER_MAX = 9999
GRADIO_NONE_NUMBER_MIN = 0
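These sentinels exist because gradio can't display `None` in its inputs (see the issue referenced by `validate_gradio_values` further down in this diff); empty strings and the 0/9999 bounds stand in for `None` and `inf` in the UI and are mapped back before transcription. A tiny illustration of the intended round trip (the helper below is illustrative, not part of the module):

```python
# Illustrative mapping from gradio sentinel values back to real ones,
# mirroring validate_gradio_values() in BaseTranscriptionPipeline later in this diff.
GRADIO_NONE_STR = ""
GRADIO_NONE_NUMBER_MAX = 9999


def from_gradio(initial_prompt: str, max_speech_duration_s: float):
    prompt = None if initial_prompt == GRADIO_NONE_STR else initial_prompt
    max_dur = float("inf") if max_speech_duration_s == GRADIO_NONE_NUMBER_MAX else max_speech_duration_s
    return prompt, max_dur


print(from_gradio("", 9999))  # (None, inf)
```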

View File

@@ -67,3 +67,9 @@ def is_video(file_path):
video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp']
extension = os.path.splitext(file_path)[1].lower()
return extension in video_extensions
def read_file(file_path):
with open(file_path, "r", encoding="utf-8") as f:
subtitle_content = f.read()
return subtitle_content

View File

@@ -1,121 +1,425 @@
# Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py
import json
import os
import re
import sys
import zlib
from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
from datetime import datetime
from modules.whisper.data_classes import Segment, Word
from .files_manager import read_file
def timeformat_srt(time):
hours = time // 3600
minutes = (time - hours * 3600) // 60
seconds = time - hours * 3600 - minutes * 60
milliseconds = (time - int(time)) * 1000
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
def format_timestamp(
seconds: float, always_include_hours: bool = True, decimal_marker: str = ","
) -> str:
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
hours = milliseconds // 3_600_000
milliseconds -= hours * 3_600_000
minutes = milliseconds // 60_000
milliseconds -= minutes * 60_000
seconds = milliseconds // 1_000
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
return (
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
)
def timeformat_vtt(time):
hours = time // 3600
minutes = (time - hours * 3600) // 60
seconds = time - hours * 3600 - minutes * 60
milliseconds = (time - int(time)) * 1000
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float:
times = time_str.split(":")
if len(times) == 3:
hours, minutes, rest = times
hours = int(hours)
else:
hours = 0
minutes, rest = times
seconds, fractional = rest.split(decimal_marker)
minutes = int(minutes)
seconds = int(seconds)
fractional_seconds = float("0." + fractional)
return hours * 3600 + minutes * 60 + seconds + fractional_seconds
def write_file(subtitle, output_file):
with open(output_file, 'w', encoding='utf-8') as f:
f.write(subtitle)
def get_start(segments: List[dict]) -> Optional[float]:
return next(
(w["start"] for s in segments for w in s["words"]),
segments[0]["start"] if segments else None,
)
def get_srt(segments):
output = ""
for i, segment in enumerate(segments):
output += f"{i + 1}\n"
output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
if segment['text'].startswith(' '):
segment['text'] = segment['text'][1:]
output += f"{segment['text']}\n\n"
return output
def get_end(segments: List[dict]) -> Optional[float]:
return next(
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
segments[-1]["end"] if segments else None,
)
def get_vtt(segments):
output = "WebVTT\n\n"
for i, segment in enumerate(segments):
output += f"{i + 1}\n"
output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
if segment['text'].startswith(' '):
segment['text'] = segment['text'][1:]
output += f"{segment['text']}\n\n"
return output
class ResultWriter:
extension: str
def __init__(self, output_dir: str):
self.output_dir = output_dir
def __call__(
self, result: Union[dict, List[Segment]], output_file_name: str,
options: Optional[dict] = None, **kwargs
):
if isinstance(result, List) and result and isinstance(result[0], Segment):
result = {"segments": [seg.model_dump() for seg in result]}
output_path = os.path.join(
self.output_dir, output_file_name + "." + self.extension
)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
raise NotImplementedError
def get_txt(segments):
output = ""
for i, segment in enumerate(segments):
if segment['text'].startswith(' '):
segment['text'] = segment['text'][1:]
output += f"{segment['text']}\n"
return output
class WriteTXT(ResultWriter):
extension: str = "txt"
def write_result(
self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs
):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
def parse_srt(file_path):
"""Reads SRT file and returns as dict"""
with open(file_path, 'r', encoding='utf-8') as file:
srt_data = file.read()
class SubtitlesWriter(ResultWriter):
always_include_hours: bool
decimal_marker: str
data = []
blocks = srt_data.split('\n\n')
def iterate_result(
self,
result: dict,
options: Optional[dict] = None,
*,
max_line_width: Optional[int] = None,
max_line_count: Optional[int] = None,
highlight_words: bool = False,
align_lrc_words: bool = False,
max_words_per_line: Optional[int] = None,
):
options = options or {}
max_line_width = max_line_width or options.get("max_line_width")
max_line_count = max_line_count or options.get("max_line_count")
highlight_words = highlight_words or options.get("highlight_words", False)
align_lrc_words = align_lrc_words or options.get("align_lrc_words", False)
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
preserve_segments = max_line_count is None or max_line_width is None
max_line_width = max_line_width or 1000
max_words_per_line = max_words_per_line or 1000
for block in blocks:
if block.strip() != '':
lines = block.strip().split('\n')
index = lines[0]
timestamp = lines[1]
sentence = ' '.join(lines[2:])
def iterate_subtitles():
line_len = 0
line_count = 1
# the next subtitle to yield (a list of word timings with whitespace)
subtitle: List[dict] = []
last: float = get_start(result["segments"]) or 0.0
for segment in result["segments"]:
chunk_index = 0
words_count = max_words_per_line
while chunk_index < len(segment["words"]):
remaining_words = len(segment["words"]) - chunk_index
if max_words_per_line > len(segment["words"]) - chunk_index:
words_count = remaining_words
for i, original_timing in enumerate(
segment["words"][chunk_index : chunk_index + words_count]
):
timing = original_timing.copy()
long_pause = (
not preserve_segments and timing["start"] - last > 3.0
)
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
if (
line_len > 0
and has_room
and not long_pause
and not seg_break
):
# line continuation
line_len += len(timing["word"])
else:
# new line
timing["word"] = timing["word"].strip()
if (
len(subtitle) > 0
and max_line_count is not None
and (long_pause or line_count >= max_line_count)
or seg_break
):
# subtitle break
yield subtitle
subtitle = []
line_count = 1
elif line_len > 0:
# line break
line_count += 1
timing["word"] = "\n" + timing["word"]
line_len = len(timing["word"].strip())
subtitle.append(timing)
last = timing["start"]
chunk_index += max_words_per_line
if len(subtitle) > 0:
yield subtitle
data.append({
"index": index,
"timestamp": timestamp,
"sentence": sentence
})
return data
if len(result["segments"]) > 0 and "words" in result["segments"][0] and result["segments"][0]["words"]:
for subtitle in iterate_subtitles():
subtitle_start = self.format_timestamp(subtitle[0]["start"])
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
subtitle_text = "".join([word["word"] for word in subtitle])
if highlight_words:
last = subtitle_start
all_words = [timing["word"] for timing in subtitle]
for i, this_word in enumerate(subtitle):
start = self.format_timestamp(this_word["start"])
end = self.format_timestamp(this_word["end"])
if last != start:
yield last, start, subtitle_text
yield start, end, "".join(
[
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
if j == i
else word
for j, word in enumerate(all_words)
]
)
last = end
if align_lrc_words:
lrc_aligned_words = [f"[{self.format_timestamp(sub['start'])}]{sub['word']}" for sub in subtitle]
l_start, l_end = self.format_timestamp(subtitle[-1]['start']), self.format_timestamp(subtitle[-1]['end'])
lrc_aligned_words[-1] = f"[{l_start}]{subtitle[-1]['word']}[{l_end}]"
lrc_aligned_words = ' '.join(lrc_aligned_words)
yield None, None, lrc_aligned_words
else:
yield subtitle_start, subtitle_end, subtitle_text
else:
for segment in result["segments"]:
segment_start = self.format_timestamp(segment["start"])
segment_end = self.format_timestamp(segment["end"])
segment_text = segment["text"].strip().replace("-->", "->")
yield segment_start, segment_end, segment_text
def format_timestamp(self, seconds: float):
return format_timestamp(
seconds=seconds,
always_include_hours=self.always_include_hours,
decimal_marker=self.decimal_marker,
)
def parse_vtt(file_path):
"""Reads WebVTT file and returns as dict"""
with open(file_path, 'r', encoding='utf-8') as file:
webvtt_data = file.read()
class WriteVTT(SubtitlesWriter):
extension: str = "vtt"
always_include_hours: bool = False
decimal_marker: str = "."
data = []
blocks = webvtt_data.split('\n\n')
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
for block in blocks:
if block.strip() != '' and not block.strip().startswith("WebVTT"):
lines = block.strip().split('\n')
index = lines[0]
timestamp = lines[1]
sentence = ' '.join(lines[2:])
def to_segments(self, file_path: str) -> List[Segment]:
segments = []
data.append({
"index": index,
"timestamp": timestamp,
"sentence": sentence
})
blocks = read_file(file_path).split('\n\n')
return data
for block in blocks:
if block.strip() != '' and not block.strip().startswith("WEBVTT"):
lines = block.strip().split('\n')
time_line = lines[0].split(" --> ")
start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
sentence = ' '.join(lines[1:])
segments.append(Segment(
start=start,
end=end,
text=sentence
))
return segments
def get_serialized_srt(dicts):
output = ""
for dic in dicts:
output += f'{dic["index"]}\n'
output += f'{dic["timestamp"]}\n'
output += f'{dic["sentence"]}\n\n'
return output
class WriteSRT(SubtitlesWriter):
extension: str = "srt"
always_include_hours: bool = True
decimal_marker: str = ","
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1
):
print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
def to_segments(self, file_path: str) -> List[Segment]:
segments = []
blocks = read_file(file_path).split('\n\n')
for block in blocks:
if block.strip() != '':
lines = block.strip().split('\n')
index = lines[0]
time_line = lines[1].split(" --> ")
start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
sentence = ' '.join(lines[2:])
segments.append(Segment(
start=start,
end=end,
text=sentence
))
return segments
def get_serialized_vtt(dicts):
output = "WebVTT\n\n"
for dic in dicts:
output += f'{dic["index"]}\n'
output += f'{dic["timestamp"]}\n'
output += f'{dic["sentence"]}\n\n'
return output
class WriteLRC(SubtitlesWriter):
extension: str = "lrc"
always_include_hours: bool = False
decimal_marker: str = "."
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1
):
if "align_lrc_words" in kwargs and kwargs["align_lrc_words"]:
print(f"{text}\n", file=file, flush=True)
else:
print(f"[{start}]{text}[{end}]\n", file=file, flush=True)
def to_segments(self, file_path: str) -> List[Segment]:
segments = []
blocks = read_file(file_path).split('\n')
for block in blocks:
if block.strip() != '':
lines = block.strip()
pattern = r'(\[.*?\])'
parts = re.split(pattern, lines)
parts = [part.strip() for part in parts if part]
for i, part in enumerate(parts):
sentence_i = i%2
if sentence_i == 1:
start_str, text, end_str = parts[sentence_i-1], parts[sentence_i], parts[sentence_i+1]
start_str, end_str = start_str.replace("[", "").replace("]", ""), end_str.replace("[", "").replace("]", "")
start, end = time_str_to_seconds(start_str, self.decimal_marker), time_str_to_seconds(end_str, self.decimal_marker)
segments.append(Segment(
start=start,
end=end,
text=text,
))
return segments
class WriteTSV(ResultWriter):
"""
Write a transcript to a file in TSV (tab-separated values) format containing lines like:
<start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
Using integer milliseconds as start and end times means there's no chance of interference from
an environment setting a language encoding that causes the decimal in a floating point number
to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
"""
extension: str = "tsv"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
print(round(1000 * segment["end"]), file=file, end="\t")
print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
class WriteJSON(ResultWriter):
extension: str = "json"
def write_result(
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
json.dump(result, file)
def get_writer(
output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
output_format = output_format.strip().lower().replace(".", "")
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
"srt": WriteSRT,
"tsv": WriteTSV,
"json": WriteJSON,
"lrc": WriteLRC
}
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
def write_all(
result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
):
for writer in all_writers:
writer(result, file, options, **kwargs)
return write_all
return writers[output_format](output_dir)
def generate_file(
output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str,
add_timestamp: bool = True, **kwargs
) -> Tuple[str, str]:
output_format = output_format.strip().lower().replace(".", "")
output_format = "vtt" if output_format == "webvtt" else output_format
if add_timestamp:
timestamp = datetime.now().strftime("%m%d%H%M%S")
output_file_name += f"-{timestamp}"
file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
file_writer = get_writer(output_format=output_format, output_dir=output_dir)
if isinstance(file_writer, WriteLRC) and kwargs.get("highlight_words", False):
kwargs["highlight_words"], kwargs["align_lrc_words"] = False, True
file_writer(result=result, output_file_name=output_file_name, **kwargs)
content = read_file(file_path)
return content, file_path
def safe_filename(name):
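For reference, the two timestamp helpers introduced above are inverses of each other, which is what lets the new `to_segments` parsers feed `Segment` objects back into the writers:

```python
# Quick sanity check of the timestamp helpers defined in this file.
from modules.utils.subtitle_manager import format_timestamp, time_str_to_seconds

print(format_timestamp(3661.5, always_include_hours=True, decimal_marker=","))
# 01:01:01,500
print(time_str_to_seconds("01:01:01,500", decimal_marker=","))
# 3661.5
```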

View File

@@ -5,7 +5,8 @@ import numpy as np
from typing import BinaryIO, Union, List, Optional, Tuple
import warnings
import faster_whisper
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
from modules.whisper.data_classes import *
from faster_whisper.transcribe import SpeechTimestampsMap
import gradio as gr
@@ -247,18 +248,18 @@ class SileroVAD:
def restore_speech_timestamps(
self,
segments: List[dict],
segments: List[Segment],
speech_chunks: List[dict],
sampling_rate: Optional[int] = None,
) -> List[dict]:
) -> List[Segment]:
if sampling_rate is None:
sampling_rate = self.sampling_rate
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
for segment in segments:
segment["start"] = ts_map.get_original_time(segment["start"])
segment["end"] = ts_map.get_original_time(segment["end"])
segment.start = ts_map.get_original_time(segment.start)
segment.end = ts_map.get_original_time(segment.end)
return segments

View File

@@ -1,5 +1,4 @@
import os
import torch
import whisper
import ctranslate2
import gradio as gr
@@ -9,21 +8,20 @@ from typing import BinaryIO, Union, Tuple, List
import numpy as np
from datetime import datetime
from faster_whisper.vad import VadOptions
from dataclasses import astuple
from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR)
from modules.utils.constants import AUTOMATIC_DETECTION
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.utils.constants import *
from modules.utils.subtitle_manager import *
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
from modules.whisper.whisper_parameter import *
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
from modules.whisper.data_classes import *
from modules.diarize.diarizer import Diarizer
from modules.vad.silero_vad import SileroVAD
class WhisperBase(ABC):
class BaseTranscriptionPipeline(ABC):
def __init__(self,
model_dir: str = WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -73,13 +71,15 @@ class WhisperBase(ABC):
def run(self,
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
file_format: str = "SRT",
add_timestamp: bool = True,
*whisper_params,
) -> Tuple[List[dict], float]:
*pipeline_params,
) -> Tuple[List[Segment], float]:
"""
Run transcription with conditional pre-processing and post-processing.
The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
The diarization will be performed in post-processing, if enabled.
Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.
Parameters
----------
@@ -87,40 +87,33 @@ class WhisperBase(ABC):
Audio input. This can be file path or binary type.
progress: gr.Progress
Indicator to show progress directly in gradio.
file_format: str
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
add_timestamp: bool
Whether to add a timestamp at the end of the filename.
*whisper_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
*pipeline_params: tuple
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class.
This must be provided as a List with * wildcard because of the integration with gradio.
See more info at : https://github.com/gradio-app/gradio/issues/2471
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment that includes start, end timestamps and transcribed text
elapsed_time: float
elapsed time for running
"""
params = WhisperParameters.as_value(*whisper_params)
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
params = self.validate_gradio_values(params)
bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization
self.cache_parameters(
whisper_params=params,
add_timestamp=add_timestamp
)
if params.lang is None:
pass
elif params.lang == AUTOMATIC_DETECTION:
params.lang = None
else:
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
params.lang = language_code_dict[params.lang]
if params.is_bgm_separate:
if bgm_params.is_separate_bgm:
music, audio, _ = self.music_separator.separate(
audio=audio,
model_name=params.uvr_model_size,
device=params.uvr_device,
segment_size=params.uvr_segment_size,
save_file=params.uvr_save_file,
model_name=bgm_params.model_size,
device=bgm_params.device,
segment_size=bgm_params.segment_size,
save_file=bgm_params.save_file,
progress=progress
)
@@ -132,47 +125,55 @@ class WhisperBase(ABC):
origin_sample_rate = self.music_separator.audio_info.sample_rate
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
if params.uvr_enable_offload:
if bgm_params.enable_offload:
self.music_separator.offload()
if params.vad_filter:
# Explicit value set for float('inf') from gr.Number()
if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
params.max_speech_duration_s = float('inf')
if vad_params.vad_filter:
vad_options = VadOptions(
threshold=params.threshold,
min_speech_duration_ms=params.min_speech_duration_ms,
max_speech_duration_s=params.max_speech_duration_s,
min_silence_duration_ms=params.min_silence_duration_ms,
speech_pad_ms=params.speech_pad_ms
threshold=vad_params.threshold,
min_speech_duration_ms=vad_params.min_speech_duration_ms,
max_speech_duration_s=vad_params.max_speech_duration_s,
min_silence_duration_ms=vad_params.min_silence_duration_ms,
speech_pad_ms=vad_params.speech_pad_ms
)
audio, speech_chunks = self.vad.run(
vad_processed, speech_chunks = self.vad.run(
audio=audio,
vad_parameters=vad_options,
progress=progress
)
if vad_processed.size > 0:
audio = vad_processed
else:
vad_params.vad_filter = False
result, elapsed_time = self.transcribe(
audio,
progress,
*astuple(params)
*whisper_params.to_list()
)
if params.vad_filter:
if vad_params.vad_filter:
result = self.vad.restore_speech_timestamps(
segments=result,
speech_chunks=speech_chunks,
)
if params.is_diarize:
if diarization_params.is_diarize:
result, elapsed_time_diarization = self.diarizer.run(
audio=audio,
use_auth_token=params.hf_token,
use_auth_token=diarization_params.hf_token,
transcribed_result=result,
device=diarization_params.device
)
elapsed_time += elapsed_time_diarization
self.cache_parameters(
params=params,
file_format=file_format,
add_timestamp=add_timestamp
)
return result, elapsed_time
def transcribe_file(self,
@@ -181,8 +182,8 @@ class WhisperBase(ABC):
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
) -> list:
*pipeline_params,
) -> Tuple[str, List]:
"""
Write subtitle file from Files
@@ -199,8 +200,8 @@ class WhisperBase(ABC):
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
*pipeline_params: tuple
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
Returns
----------
@@ -210,6 +211,11 @@ class WhisperBase(ABC):
Output file path to return to gr.Files()
"""
try:
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
writer_options = {
"highlight_words": True if params.whisper.word_timestamps else False
}
if input_folder_path:
files = get_media_files(input_folder_path)
if isinstance(files, str):
@@ -222,19 +228,21 @@ class WhisperBase(ABC):
transcribed_segments, time_for_task = self.run(
file,
progress,
file_format,
add_timestamp,
*whisper_params,
*pipeline_params,
)
file_name, file_ext = os.path.splitext(os.path.basename(file))
subtitle, file_path = self.generate_and_write_file(
file_name=file_name,
transcribed_segments=transcribed_segments,
subtitle, file_path = generate_file(
output_dir=self.output_dir,
output_file_name=file_name,
output_format=file_format,
result=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format,
output_dir=self.output_dir
**writer_options
)
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
total_result = ''
total_time = 0
@@ -247,10 +255,11 @@ class WhisperBase(ABC):
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
result_file_path = [info['path'] for info in files_info.values()]
return [result_str, result_file_path]
return result_str, result_file_path
except Exception as e:
print(f"Error transcribing file: {e}")
raise
finally:
self.release_cuda_memory()
@@ -259,8 +268,8 @@ class WhisperBase(ABC):
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
) -> list:
*pipeline_params,
) -> Tuple[str, str]:
"""
Write subtitle file from microphone
@@ -274,7 +283,7 @@ class WhisperBase(ABC):
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
*pipeline_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
Returns
@@ -285,27 +294,36 @@ class WhisperBase(ABC):
Output file path to return to gr.Files()
"""
try:
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
writer_options = {
"highlight_words": True if params.whisper.word_timestamps else False
}
progress(0, desc="Loading Audio..")
transcribed_segments, time_for_task = self.run(
mic_audio,
progress,
file_format,
add_timestamp,
*whisper_params,
*pipeline_params,
)
progress(1, desc="Completed!")
subtitle, result_file_path = self.generate_and_write_file(
file_name="Mic",
transcribed_segments=transcribed_segments,
file_name = "Mic"
subtitle, file_path = generate_file(
output_dir=self.output_dir,
output_file_name=file_name,
output_format=file_format,
result=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format,
output_dir=self.output_dir
**writer_options
)
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
return [result_str, result_file_path]
return result_str, file_path
except Exception as e:
print(f"Error transcribing file: {e}")
print(f"Error transcribing mic: {e}")
raise
finally:
self.release_cuda_memory()
@@ -314,8 +332,8 @@ class WhisperBase(ABC):
file_format: str = "SRT",
add_timestamp: bool = True,
progress=gr.Progress(),
*whisper_params,
) -> list:
*pipeline_params,
) -> Tuple[str, str]:
"""
Write subtitle file from Youtube
@@ -329,7 +347,7 @@ class WhisperBase(ABC):
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
progress: gr.Progress
Indicator to show progress directly in gradio.
*whisper_params: tuple
*pipeline_params: tuple
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
Returns
@@ -340,6 +358,11 @@ class WhisperBase(ABC):
Output file path to return to gr.Files()
"""
try:
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
writer_options = {
"highlight_words": True if params.whisper.word_timestamps else False
}
progress(0, desc="Loading Audio from Youtube..")
yt = get_ytdata(youtube_link)
audio = get_ytaudio(yt)
@@ -347,29 +370,33 @@ class WhisperBase(ABC):
transcribed_segments, time_for_task = self.run(
audio,
progress,
file_format,
add_timestamp,
*whisper_params,
*pipeline_params,
)
progress(1, desc="Completed!")
file_name = safe_filename(yt.title)
subtitle, result_file_path = self.generate_and_write_file(
file_name=file_name,
transcribed_segments=transcribed_segments,
subtitle, file_path = generate_file(
output_dir=self.output_dir,
output_file_name=file_name,
output_format=file_format,
result=transcribed_segments,
add_timestamp=add_timestamp,
file_format=file_format,
output_dir=self.output_dir
**writer_options
)
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
if os.path.exists(audio):
os.remove(audio)
return [result_str, result_file_path]
return result_str, file_path
except Exception as e:
print(f"Error transcribing file: {e}")
print(f"Error transcribing youtube: {e}")
raise
finally:
self.release_cuda_memory()
@@ -387,58 +414,6 @@ class WhisperBase(ABC):
else:
return list(ctranslate2.get_supported_compute_types("cpu"))
@staticmethod
def generate_and_write_file(file_name: str,
transcribed_segments: list,
add_timestamp: bool,
file_format: str,
output_dir: str
) -> str:
"""
Writes subtitle file
Parameters
----------
file_name: str
Output file name
transcribed_segments: list
Text segments transcribed from audio
add_timestamp: bool
Determines whether to add a timestamp to the end of the filename.
file_format: str
File format to write. Supported formats: [SRT, WebVTT, txt]
output_dir: str
Directory path of the output
Returns
----------
content: str
Result of the transcription
output_path: str
output file path
"""
if add_timestamp:
timestamp = datetime.now().strftime("%m%d%H%M%S")
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
else:
output_path = os.path.join(output_dir, f"{file_name}")
file_format = file_format.strip().lower()
if file_format == "srt":
content = get_srt(transcribed_segments)
output_path += '.srt'
elif file_format == "webvtt":
content = get_vtt(transcribed_segments)
output_path += '.vtt'
elif file_format == "txt":
content = get_txt(transcribed_segments)
output_path += '.txt'
write_file(content, output_path)
return content, output_path
@staticmethod
def format_time(elapsed_time: float) -> str:
"""
@@ -471,7 +446,7 @@ class WhisperBase(ABC):
if torch.cuda.is_available():
return "cuda"
elif torch.backends.mps.is_available():
if not WhisperBase.is_sparse_api_supported():
if not BaseTranscriptionPipeline.is_sparse_api_supported():
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
return "cpu"
return "mps"
@@ -513,17 +488,64 @@ class WhisperBase(ABC):
os.remove(file_path)
@staticmethod
def cache_parameters(
whisper_params: WhisperValues,
add_timestamp: bool
):
"""cache parameters to the yaml file"""
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
cached_whisper_param = whisper_params.to_yaml()
cached_yaml = {**cached_params, **cached_whisper_param}
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
def validate_gradio_values(params: TranscriptionPipelineParams):
"""
Validate Gradio-specific values that can't be displayed as None in the UI.
Related issue : https://github.com/gradio-app/gradio/issues/8723
"""
if params.whisper.lang is None:
pass
elif params.whisper.lang == AUTOMATIC_DETECTION:
params.whisper.lang = None
else:
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
params.whisper.lang = language_code_dict[params.whisper.lang]
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
if params.whisper.initial_prompt == GRADIO_NONE_STR:
params.whisper.initial_prompt = None
if params.whisper.prefix == GRADIO_NONE_STR:
params.whisper.prefix = None
if params.whisper.hotwords == GRADIO_NONE_STR:
params.whisper.hotwords = None
if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
params.whisper.max_new_tokens = None
if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
params.whisper.hallucination_silence_threshold = None
if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
params.whisper.language_detection_threshold = None
if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
params.vad.max_speech_duration_s = float('inf')
return params
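A minimal usage sketch of the sentinel handling above. It assumes GRADIO_NONE_STR and GRADIO_NONE_NUMBER_MAX are exported by modules.utils.constants (the same constants module the data classes star-import); the concrete sentinel values are not assumed, only that the UI passes them through unchanged:

from modules.utils.constants import GRADIO_NONE_STR, GRADIO_NONE_NUMBER_MAX  # assumed location
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
from modules.whisper.data_classes import TranscriptionPipelineParams, WhisperParams, VadParams

ui_params = TranscriptionPipelineParams(
    whisper=WhisperParams(initial_prompt=GRADIO_NONE_STR),        # sentinel coming from an empty Textbox
    vad=VadParams(max_speech_duration_s=GRADIO_NONE_NUMBER_MAX),  # sentinel meaning "no limit"
)
cleaned = BaseTranscriptionPipeline.validate_gradio_values(ui_params)
assert cleaned.whisper.initial_prompt is None
assert cleaned.vad.max_speech_duration_s == float("inf")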
@staticmethod
def cache_parameters(
params: TranscriptionPipelineParams,
file_format: str = "SRT",
add_timestamp: bool = True
):
"""Cache parameters to the yaml file"""
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
param_to_cache = params.to_dict()
cached_yaml = {**cached_params, **param_to_cache}
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
cached_yaml["whisper"]["file_format"] = file_format
suppress_tokens = cached_yaml["whisper"].get("suppress_tokens", None)
if suppress_tokens and isinstance(suppress_tokens, list):
cached_yaml["whisper"]["suppress_tokens"] = str(suppress_tokens)
if cached_yaml["whisper"].get("lang", None) is None:
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
else:
language_dict = whisper.tokenizer.LANGUAGES
cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
if cached_yaml is not None and cached_yaml:
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
@staticmethod
def resample_audio(audio: Union[str, np.ndarray],

View File

@@ -0,0 +1,608 @@
import faster_whisper.transcribe
import gradio as gr
import torch
from typing import Optional, Dict, List, Union, NamedTuple
from pydantic import BaseModel, Field, field_validator, ConfigDict
from gradio_i18n import Translate, gettext as _
from enum import Enum
from copy import deepcopy
import yaml
from modules.utils.constants import *
class WhisperImpl(Enum):
WHISPER = "whisper"
FASTER_WHISPER = "faster-whisper"
INSANELY_FAST_WHISPER = "insanely_fast_whisper"
class Segment(BaseModel):
id: Optional[int] = Field(default=None, description="Incremental id for the segment")
seek: Optional[int] = Field(default=None, description="Seek of the segment from chunked audio")
text: Optional[str] = Field(default=None, description="Transcription text of the segment")
start: Optional[float] = Field(default=None, description="Start time of the segment")
end: Optional[float] = Field(default=None, description="End time of the segment")
tokens: Optional[List[int]] = Field(default=None, description="List of token IDs")
temperature: Optional[float] = Field(default=None, description="Temperature used during the decoding process")
avg_logprob: Optional[float] = Field(default=None, description="Average log probability of the tokens")
compression_ratio: Optional[float] = Field(default=None, description="Compression ratio of the segment")
no_speech_prob: Optional[float] = Field(default=None, description="Probability that it's not speech")
words: Optional[List['Word']] = Field(default=None, description="List of words contained in the segment")
@classmethod
def from_faster_whisper(cls,
seg: faster_whisper.transcribe.Segment):
if seg.words is not None:
words = [
Word(
start=w.start,
end=w.end,
word=w.word,
probability=w.probability
) for w in seg.words
]
else:
words = None
return cls(
id=seg.id,
seek=seg.seek,
text=seg.text,
start=seg.start,
end=seg.end,
tokens=seg.tokens,
temperature=seg.temperature,
avg_logprob=seg.avg_logprob,
compression_ratio=seg.compression_ratio,
no_speech_prob=seg.no_speech_prob,
words=words
)
class Word(BaseModel):
start: Optional[float] = Field(default=None, description="Start time of the word")
end: Optional[float] = Field(default=None, description="Start time of the word")
word: Optional[str] = Field(default=None, description="Word text")
probability: Optional[float] = Field(default=None, description="Probability of the word")
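As a quick illustration (values are hypothetical), the Segment and Word models can be built directly; they replace the plain dicts the pipelines previously returned:

from modules.whisper.data_classes import Segment, Word

seg = Segment(
    id=0,
    start=0.0,
    end=2.4,
    text=" And so my fellow Americans",
    words=[Word(start=0.0, end=0.4, word=" And", probability=0.98)],
)
print(seg.model_dump(exclude_none=True))  # back to a plain dict, omitting unset fields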
class BaseParams(BaseModel):
model_config = ConfigDict(protected_namespaces=())
def to_dict(self) -> Dict:
return self.model_dump()
def to_list(self) -> List:
return list(self.model_dump().values())
@classmethod
def from_list(cls, data_list: List) -> 'BaseParams':
field_names = list(cls.model_fields.keys())
return cls(**dict(zip(field_names, data_list)))
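Since every parameter group inherits these helpers, a single group can be flattened and rebuilt like this (a small sketch using VadParams, defined below):

from modules.whisper.data_classes import VadParams

vad = VadParams(vad_filter=True, threshold=0.6)
flat = vad.to_list()                  # field values in declaration order
restored = VadParams.from_list(flat)  # zipped back onto the field names
assert restored == vad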
class VadParams(BaseParams):
"""Voice Activity Detection parameters"""
vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
threshold: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
)
min_speech_duration_ms: int = Field(
default=250,
ge=0,
description="Final speech chunks shorter than this are discarded"
)
max_speech_duration_s: float = Field(
default=float("inf"),
gt=0,
description="Maximum duration of speech chunks in seconds"
)
min_silence_duration_ms: int = Field(
default=2000,
ge=0,
description="Minimum silence duration between speech chunks"
)
speech_pad_ms: int = Field(
default=400,
ge=0,
description="Padding added to each side of speech chunks"
)
@classmethod
def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
return [
gr.Checkbox(
label=_("Enable Silero VAD Filter"),
value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
interactive=True,
info=_("Enable this to transcribe only detected voice")
),
gr.Slider(
minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
value=defaults.get("threshold", cls.__fields__["threshold"].default),
info="Lower it to be more sensitive to small sounds."
),
gr.Number(
label="Minimum Speech Duration (ms)", precision=0,
value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
info="Final speech chunks shorter than this time are thrown out"
),
gr.Number(
label="Maximum Speech Duration (s)",
value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
info="Maximum duration of speech chunks in \"seconds\"."
),
gr.Number(
label="Minimum Silence Duration (ms)", precision=0,
value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
info="In the end of each speech chunk wait for this time before separating it"
),
gr.Number(
label="Speech Padding (ms)", precision=0,
value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
info="Final speech chunks are padded by this time each side"
)
]
class DiarizationParams(BaseParams):
"""Speaker diarization parameters"""
is_diarize: bool = Field(default=False, description="Enable speaker diarization")
device: str = Field(default="cuda", description="Device to run Diarization model.")
hf_token: str = Field(
default="",
description="Hugging Face token for downloading diarization models"
)
@classmethod
def to_gradio_inputs(cls,
defaults: Optional[Dict] = None,
available_devices: Optional[List] = None,
device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
return [
gr.Checkbox(
label=_("Enable Diarization"),
value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
),
gr.Dropdown(
label=_("Device"),
choices=["cpu", "cuda"] if available_devices is None else available_devices,
value=defaults.get("device", device),
),
gr.Textbox(
label=_("HuggingFace Token"),
value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
info=_("This is only needed the first time you download the model")
),
]
class BGMSeparationParams(BaseParams):
"""Background music separation parameters"""
is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
model_size: str = Field(
default="UVR-MDX-NET-Inst_HQ_4",
description="UVR model size"
)
device: str = Field(default="cuda", description="Device to run UVR model.")
segment_size: int = Field(
default=256,
gt=0,
description="Segment size for UVR model"
)
save_file: bool = Field(
default=False,
description="Whether to save separated audio files"
)
enable_offload: bool = Field(
default=True,
description="Offload UVR model after transcription"
)
@classmethod
def to_gradio_input(cls,
defaults: Optional[Dict] = None,
available_devices: Optional[List] = None,
device: Optional[str] = None,
available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
return [
gr.Checkbox(
label=_("Enable Background Music Remover Filter"),
value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
interactive=True,
info=_("Enabling this will remove background music")
),
gr.Dropdown(
label=_("Model"),
choices=["UVR-MDX-NET-Inst_HQ_4",
"UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
value=defaults.get("model_size", cls.__fields__["model_size"].default),
),
gr.Dropdown(
label=_("Device"),
choices=["cpu", "cuda"] if available_devices is None else available_devices,
value=defaults.get("device", device),
),
gr.Number(
label="Segment Size",
value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
precision=0,
info="Segment size for UVR model"
),
gr.Checkbox(
label=_("Save separated files to output"),
value=defaults.get("save_file", cls.__fields__["save_file"].default),
),
gr.Checkbox(
label=_("Offload sub model after removing background music"),
value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
)
]
class WhisperParams(BaseParams):
"""Whisper parameters"""
model_size: str = Field(default="large-v2", description="Whisper model size")
lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
log_prob_threshold: float = Field(
default=-1.0,
description="Threshold for average log probability of sampled tokens"
)
no_speech_threshold: float = Field(
default=0.6,
ge=0.0,
le=1.0,
description="Threshold for detecting silence"
)
compute_type: str = Field(default="float16", description="Computation type for transcription")
best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
condition_on_previous_text: bool = Field(
default=True,
description="Use previous output as prompt for next window"
)
prompt_reset_on_temperature: float = Field(
default=0.5,
ge=0.0,
le=1.0,
description="Temperature threshold for resetting prompt"
)
initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
temperature: float = Field(
default=0.0,
ge=0.0,
description="Temperature for sampling"
)
compression_ratio_threshold: float = Field(
default=2.4,
gt=0,
description="Threshold for gzip compression ratio"
)
length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
suppress_blank: bool = Field(
default=True,
description="Suppress blank outputs at start of sampling"
)
suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress")
max_initial_timestamp: float = Field(
default=1.0,
ge=0.0,
description="Maximum initial timestamp"
)
word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
prepend_punctuations: Optional[str] = Field(
default="\"'“¿([{-",
description="Punctuations to merge with next word"
)
append_punctuations: Optional[str] = Field(
default="\"'.。,!?::”)]}、",
description="Punctuations to merge with previous word"
)
max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
hallucination_silence_threshold: Optional[float] = Field(
default=None,
description="Threshold for skipping silent periods in hallucination detection"
)
hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
language_detection_threshold: Optional[float] = Field(
default=None,
description="Threshold for language detection probability"
)
language_detection_segments: int = Field(
default=1,
gt=0,
description="Number of segments for language detection"
)
batch_size: int = Field(default=24, gt=0, description="Batch size for processing")
@field_validator('lang')
def validate_lang(cls, v):
from modules.utils.constants import AUTOMATIC_DETECTION
return None if v == AUTOMATIC_DETECTION.unwrap() else v
@field_validator('suppress_tokens')
def validate_suppress_tokens(cls, v):
import ast
try:
if isinstance(v, str):
suppress_tokens = ast.literal_eval(v)
if not isinstance(suppress_tokens, list):
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
return suppress_tokens
if isinstance(v, list):
return v
except Exception as e:
raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")
@classmethod
def to_gradio_inputs(cls,
defaults: Optional[Dict] = None,
only_advanced: Optional[bool] = True,
whisper_type: Optional[str] = None,
available_models: Optional[List] = None,
available_langs: Optional[List] = None,
available_compute_types: Optional[List] = None,
compute_type: Optional[str] = None):
whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower()
inputs = []
if not only_advanced:
inputs += [
gr.Dropdown(
label=_("Model"),
choices=available_models,
value=defaults.get("model_size", cls.__fields__["model_size"].default),
),
gr.Dropdown(
label=_("Language"),
choices=available_langs,
value=defaults.get("lang", AUTOMATIC_DETECTION),
),
gr.Checkbox(
label=_("Translate to English?"),
value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
),
]
inputs += [
gr.Number(
label="Beam Size",
value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
precision=0,
info="Beam size for decoding"
),
gr.Number(
label="Log Probability Threshold",
value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
info="Threshold for average log probability of sampled tokens"
),
gr.Number(
label="No Speech Threshold",
value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
info="Threshold for detecting silence"
),
gr.Dropdown(
label="Compute Type",
choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
value=defaults.get("compute_type", compute_type),
info="Computation type for transcription"
),
gr.Number(
label="Best Of",
value=defaults.get("best_of", cls.__fields__["best_of"].default),
precision=0,
info="Number of candidates when sampling"
),
gr.Number(
label="Patience",
value=defaults.get("patience", cls.__fields__["patience"].default),
info="Beam search patience factor"
),
gr.Checkbox(
label="Condition On Previous Text",
value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
info="Use previous output as prompt for next window"
),
gr.Slider(
label="Prompt Reset On Temperature",
value=defaults.get("prompt_reset_on_temperature",
cls.__fields__["prompt_reset_on_temperature"].default),
minimum=0,
maximum=1,
step=0.01,
info="Temperature threshold for resetting prompt"
),
gr.Textbox(
label="Initial Prompt",
value=defaults.get("initial_prompt", GRADIO_NONE_STR),
info="Initial prompt for first window"
),
gr.Slider(
label="Temperature",
value=defaults.get("temperature", cls.__fields__["temperature"].default),
minimum=0.0,
step=0.01,
maximum=1.0,
info="Temperature for sampling"
),
gr.Number(
label="Compression Ratio Threshold",
value=defaults.get("compression_ratio_threshold",
cls.__fields__["compression_ratio_threshold"].default),
info="Threshold for gzip compression ratio"
)
]
faster_whisper_inputs = [
gr.Number(
label="Length Penalty",
value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
info="Exponential length penalty",
),
gr.Number(
label="Repetition Penalty",
value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
info="Penalty for repeated tokens"
),
gr.Number(
label="No Repeat N-gram Size",
value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
precision=0,
info="Size of n-grams to prevent repetition"
),
gr.Textbox(
label="Prefix",
value=defaults.get("prefix", GRADIO_NONE_STR),
info="Prefix text for first window"
),
gr.Checkbox(
label="Suppress Blank",
value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
info="Suppress blank outputs at start of sampling"
),
gr.Textbox(
label="Suppress Tokens",
value=defaults.get("suppress_tokens", "[-1]"),
info="Token IDs to suppress"
),
gr.Number(
label="Max Initial Timestamp",
value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
info="Maximum initial timestamp"
),
gr.Checkbox(
label="Word Timestamps",
value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
info="Extract word-level timestamps"
),
gr.Textbox(
label="Prepend Punctuations",
value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
info="Punctuations to merge with next word"
),
gr.Textbox(
label="Append Punctuations",
value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
info="Punctuations to merge with previous word"
),
gr.Number(
label="Max New Tokens",
value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
precision=0,
info="Maximum number of new tokens per chunk"
),
gr.Number(
label="Chunk Length (s)",
value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
precision=0,
info="Length of audio segments in seconds"
),
gr.Number(
label="Hallucination Silence Threshold (sec)",
value=defaults.get("hallucination_silence_threshold",
GRADIO_NONE_NUMBER_MIN),
info="Threshold for skipping silent periods in hallucination detection"
),
gr.Textbox(
label="Hotwords",
value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
info="Hotwords/hint phrases for the model"
),
gr.Number(
label="Language Detection Threshold",
value=defaults.get("language_detection_threshold",
GRADIO_NONE_NUMBER_MIN),
info="Threshold for language detection probability"
),
gr.Number(
label="Language Detection Segments",
value=defaults.get("language_detection_segments",
cls.__fields__["language_detection_segments"].default),
precision=0,
info="Number of segments for language detection"
)
]
insanely_fast_whisper_inputs = [
gr.Number(
label="Batch Size",
value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
precision=0,
info="Batch size for processing"
)
]
if whisper_type != WhisperImpl.FASTER_WHISPER.value:
for input_component in faster_whisper_inputs:
input_component.visible = False
if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value:
for input_component in insanely_fast_whisper_inputs:
input_component.visible = False
inputs += faster_whisper_inputs + insanely_fast_whisper_inputs
return inputs
class TranscriptionPipelineParams(BaseModel):
"""Transcription pipeline parameters"""
whisper: WhisperParams = Field(default_factory=WhisperParams)
vad: VadParams = Field(default_factory=VadParams)
diarization: DiarizationParams = Field(default_factory=DiarizationParams)
bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
def to_dict(self) -> Dict:
data = {
"whisper": self.whisper.to_dict(),
"vad": self.vad.to_dict(),
"diarization": self.diarization.to_dict(),
"bgm_separation": self.bgm_separation.to_dict()
}
return data
def to_list(self) -> List:
"""
Convert the data class to a list, because the parameters have to be passed to Gradio as a flat list.
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
See more about Gradio pre-processing: https://www.gradio.app/docs/components
"""
whisper_list = self.whisper.to_list()
vad_list = self.vad.to_list()
diarization_list = self.diarization.to_list()
bgm_sep_list = self.bgm_separation.to_list()
return whisper_list + vad_list + diarization_list + bgm_sep_list
@staticmethod
def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
"""Convert list to the data class again to use it in a function."""
data_list = deepcopy(pipeline_list)
whisper_list = data_list[0:len(WhisperParams.__annotations__)]
data_list = data_list[len(WhisperParams.__annotations__):]
vad_list = data_list[0:len(VadParams.__annotations__)]
data_list = data_list[len(VadParams.__annotations__):]
diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
data_list = data_list[len(DiarizationParams.__annotations__):]
bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
return TranscriptionPipelineParams(
whisper=WhisperParams.from_list(whisper_list),
vad=VadParams.from_list(vad_list),
diarization=DiarizationParams.from_list(diarization_list),
bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
)
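Putting it together, a hedged end-to-end sketch of the flatten/rebuild cycle the docstrings above describe: the whole pipeline configuration is flattened into one list for Gradio, then sliced back into the nested data class inside the handler (the example values are illustrative):

from modules.whisper.data_classes import (
    TranscriptionPipelineParams,
    WhisperParams,
    VadParams,
)

params = TranscriptionPipelineParams(
    whisper=WhisperParams(model_size="tiny", lang="en"),
    vad=VadParams(vad_filter=True),
)
flat = params.to_list()  # whisper + vad + diarization + bgm_separation values, in order
# In the WebUI this flat list is what Gradio passes through as the trailing *pipeline_params arguments
rebuilt = TranscriptionPipelineParams.from_list(flat)
assert rebuilt.whisper.model_size == "tiny" and rebuilt.vad.vad_filter is True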

View File

@@ -12,11 +12,11 @@ import gradio as gr
from argparse import Namespace
from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase
from modules.whisper.data_classes import *
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
class FasterWhisperInference(WhisperBase):
class FasterWhisperInference(BaseTranscriptionPipeline):
def __init__(self,
model_dir: str = FASTER_WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
) -> Tuple[List[Segment], float]:
"""
transcribe method for faster-whisper.
@@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase):
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects that include start and end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
start_time = time.time()
params = WhisperParameters.as_value(*whisper_params)
params = WhisperParams.from_list(list(whisper_params))
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
# None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
if not params.initial_prompt:
params.initial_prompt = None
if not params.prefix:
params.prefix = None
if not params.hotwords:
params.hotwords = None
params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
segments, info = self.model.transcribe(
audio=audio,
language=params.lang,
@@ -112,11 +102,7 @@ class FasterWhisperInference(WhisperBase):
segments_result = []
for segment in segments:
progress(segment.start / info.duration, desc="Transcribing..")
segments_result.append({
"start": segment.start,
"end": segment.end,
"text": segment.text
})
segments_result.append(Segment.from_faster_whisper(segment))
elapsed_time = time.time() - start_time
return segments_result, elapsed_time

View File

@@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
from argparse import Namespace
from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
from modules.whisper.whisper_parameter import *
from modules.whisper.whisper_base import WhisperBase
from modules.whisper.data_classes import *
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
class InsanelyFastWhisperInference(WhisperBase):
class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
def __init__(self,
model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -32,15 +32,13 @@ class InsanelyFastWhisperInference(WhisperBase):
self.model_dir = model_dir
os.makedirs(self.model_dir, exist_ok=True)
openai_models = whisper.available_models()
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
self.available_models = openai_models + distil_models
self.available_models = self.get_model_paths()
def transcribe(self,
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
) -> Tuple[List[Segment], float]:
"""
transcribe method for insanely-fast-whisper.
@@ -55,13 +53,13 @@ class InsanelyFastWhisperInference(WhisperBase):
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects that include start and end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
start_time = time.time()
params = WhisperParameters.as_value(*whisper_params)
params = WhisperParams.from_list(list(whisper_params))
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
@@ -95,9 +93,17 @@ class InsanelyFastWhisperInference(WhisperBase):
generate_kwargs=kwargs
)
segments_result = self.format_result(
transcribed_result=segments,
)
segments_result = []
for item in segments["chunks"]:
start, end = item["timestamp"][0], item["timestamp"][1]
if end is None:
end = start
segments_result.append(Segment(
text=item["text"],
start=start,
end=end
))
elapsed_time = time.time() - start_time
return segments_result, elapsed_time
@@ -138,31 +144,26 @@ class InsanelyFastWhisperInference(WhisperBase):
model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
)
@staticmethod
def format_result(
transcribed_result: dict
) -> List[dict]:
def get_model_paths(self):
"""
Format the transcription result of insanely_fast_whisper as the same with other implementation.
Parameters
----------
transcribed_result: dict
Transcription result of the insanely_fast_whisper
Get available models from the models directory, including fine-tuned models.
Returns
----------
result: List[dict]
Formatted result as the same with other implementation
Set of available model names
"""
result = transcribed_result["chunks"]
for item in result:
start, end = item["timestamp"][0], item["timestamp"][1]
if end is None:
end = start
item["start"] = start
item["end"] = end
return result
openai_models = whisper.available_models()
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
default_models = openai_models + distil_models
existing_models = os.listdir(self.model_dir)
wrong_dirs = [".locks"]
available_models = default_models + existing_models
available_models = [model for model in available_models if model not in wrong_dirs]
available_models = sorted(set(available_models), key=available_models.index)
return available_models
@staticmethod
def download_model(

View File

@@ -8,11 +8,11 @@ import os
from argparse import Namespace
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
from modules.whisper.whisper_base import WhisperBase
from modules.whisper.whisper_parameter import *
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
from modules.whisper.data_classes import *
class WhisperInference(WhisperBase):
class WhisperInference(BaseTranscriptionPipeline):
def __init__(self,
model_dir: str = WHISPER_MODELS_DIR,
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -30,7 +30,7 @@ class WhisperInference(WhisperBase):
audio: Union[str, np.ndarray, torch.Tensor],
progress: gr.Progress = gr.Progress(),
*whisper_params,
) -> Tuple[List[dict], float]:
) -> Tuple[List[Segment], float]:
"""
transcribe method for whisper.
@@ -45,13 +45,13 @@ class WhisperInference(WhisperBase):
Returns
----------
segments_result: List[dict]
list of dicts that includes start, end timestamps and transcribed text
segments_result: List[Segment]
list of Segment objects that include start and end timestamps and transcribed text
elapsed_time: float
elapsed time for transcription
"""
start_time = time.time()
params = WhisperParameters.as_value(*whisper_params)
params = WhisperParams.from_list(list(whisper_params))
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
self.update_model(params.model_size, params.compute_type, progress)
@@ -59,21 +59,28 @@ class WhisperInference(WhisperBase):
def progress_callback(progress_value):
progress(progress_value, desc="Transcribing..")
segments_result = self.model.transcribe(audio=audio,
language=params.lang,
verbose=False,
beam_size=params.beam_size,
logprob_threshold=params.log_prob_threshold,
no_speech_threshold=params.no_speech_threshold,
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
fp16=True if params.compute_type == "float16" else False,
best_of=params.best_of,
patience=params.patience,
temperature=params.temperature,
compression_ratio_threshold=params.compression_ratio_threshold,
progress_callback=progress_callback,)["segments"]
elapsed_time = time.time() - start_time
result = self.model.transcribe(audio=audio,
language=params.lang,
verbose=False,
beam_size=params.beam_size,
logprob_threshold=params.log_prob_threshold,
no_speech_threshold=params.no_speech_threshold,
task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
fp16=True if params.compute_type == "float16" else False,
best_of=params.best_of,
patience=params.patience,
temperature=params.temperature,
compression_ratio_threshold=params.compression_ratio_threshold,
progress_callback=progress_callback,)["segments"]
segments_result = []
for segment in result:
segments_result.append(Segment(
start=segment["start"],
end=segment["end"],
text=segment["text"]
))
elapsed_time = time.time() - start_time
return segments_result, elapsed_time
def update_model(self,

View File

@@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
from modules.whisper.whisper_base import WhisperBase
from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
from modules.whisper.data_classes import *
class WhisperFactory:
@@ -19,7 +20,7 @@ class WhisperFactory:
diarization_model_dir: str = DIARIZATION_MODELS_DIR,
uvr_model_dir: str = UVR_MODELS_DIR,
output_dir: str = OUTPUT_DIR,
) -> "WhisperBase":
) -> "BaseTranscriptionPipeline":
"""
Create a whisper inference class based on the provided whisper_type.
@@ -45,36 +46,29 @@ class WhisperFactory:
Returns
-------
WhisperBase
BaseTranscriptionPipeline
An instance of the appropriate whisper inference class based on the whisper_type.
"""
# Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
whisper_type = whisper_type.lower().strip()
whisper_type = whisper_type.strip().lower()
faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
whisper_typos = ["whisper"]
insanely_fast_whisper_typos = [
"insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
]
if whisper_type in faster_whisper_typos:
if whisper_type == WhisperImpl.FASTER_WHISPER.value:
return FasterWhisperInference(
model_dir=faster_whisper_model_dir,
output_dir=output_dir,
diarization_model_dir=diarization_model_dir,
uvr_model_dir=uvr_model_dir
)
elif whisper_type in whisper_typos:
elif whisper_type == WhisperImpl.WHISPER.value:
return WhisperInference(
model_dir=whisper_model_dir,
output_dir=output_dir,
diarization_model_dir=diarization_model_dir,
uvr_model_dir=uvr_model_dir
)
elif whisper_type in insanely_fast_whisper_typos:
elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value:
return InsanelyFastWhisperInference(
model_dir=insanely_fast_whisper_model_dir,
output_dir=output_dir,

View File

@@ -1,371 +0,0 @@
from dataclasses import dataclass, fields
import gradio as gr
from typing import Optional, Dict
import yaml
from modules.utils.constants import AUTOMATIC_DETECTION
@dataclass
class WhisperParameters:
model_size: gr.Dropdown
lang: gr.Dropdown
is_translate: gr.Checkbox
beam_size: gr.Number
log_prob_threshold: gr.Number
no_speech_threshold: gr.Number
compute_type: gr.Dropdown
best_of: gr.Number
patience: gr.Number
condition_on_previous_text: gr.Checkbox
prompt_reset_on_temperature: gr.Slider
initial_prompt: gr.Textbox
temperature: gr.Slider
compression_ratio_threshold: gr.Number
vad_filter: gr.Checkbox
threshold: gr.Slider
min_speech_duration_ms: gr.Number
max_speech_duration_s: gr.Number
min_silence_duration_ms: gr.Number
speech_pad_ms: gr.Number
batch_size: gr.Number
is_diarize: gr.Checkbox
hf_token: gr.Textbox
diarization_device: gr.Dropdown
length_penalty: gr.Number
repetition_penalty: gr.Number
no_repeat_ngram_size: gr.Number
prefix: gr.Textbox
suppress_blank: gr.Checkbox
suppress_tokens: gr.Textbox
max_initial_timestamp: gr.Number
word_timestamps: gr.Checkbox
prepend_punctuations: gr.Textbox
append_punctuations: gr.Textbox
max_new_tokens: gr.Number
chunk_length: gr.Number
hallucination_silence_threshold: gr.Number
hotwords: gr.Textbox
language_detection_threshold: gr.Number
language_detection_segments: gr.Number
is_bgm_separate: gr.Checkbox
uvr_model_size: gr.Dropdown
uvr_device: gr.Dropdown
uvr_segment_size: gr.Number
uvr_save_file: gr.Checkbox
uvr_enable_offload: gr.Checkbox
"""
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
See more about Gradio pre-processing: https://www.gradio.app/docs/components
Attributes
----------
model_size: gr.Dropdown
Whisper model size.
lang: gr.Dropdown
Source language of the file to transcribe.
is_translate: gr.Checkbox
Boolean value that determines whether to translate to English.
It's Whisper's feature to translate speech from another language directly into English end-to-end.
beam_size: gr.Number
Int value that is used for decoding option.
log_prob_threshold: gr.Number
If the average log probability over sampled tokens is below this value, treat as failed.
no_speech_threshold: gr.Number
If the no_speech probability is higher than this value AND
the average log probability over sampled tokens is below `log_prob_threshold`,
consider the segment as silent.
compute_type: gr.Dropdown
compute type for transcription.
see more info : https://opennmt.net/CTranslate2/quantization.html
best_of: gr.Number
Number of candidates when sampling with non-zero temperature.
patience: gr.Number
Beam search patience factor.
condition_on_previous_text: gr.Checkbox
if True, the previous output of the model is provided as a prompt for the next window;
disabling may make the text inconsistent across windows, but the model becomes less prone to
getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
initial_prompt: gr.Textbox
Optional text to provide as a prompt for the first window. This can be used to provide, or
"prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
to make it more likely to predict those word correctly.
temperature: gr.Slider
Temperature for sampling. It can be a tuple of temperatures,
which will be successively used upon failures according to either
`compression_ratio_threshold` or `log_prob_threshold`.
compression_ratio_threshold: gr.Number
If the gzip compression ratio is above this value, treat as failed
vad_filter: gr.Checkbox
Enable the voice activity detection (VAD) to filter out parts of the audio
without speech. This step is using the Silero VAD model
https://github.com/snakers4/silero-vad.
threshold: gr.Slider
This parameter is related with Silero VAD. Speech threshold.
Silero VAD outputs speech probabilities for each audio chunk,
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
min_speech_duration_ms: gr.Number
This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.
max_speech_duration_s: gr.Number
This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
than max_speech_duration_s will be split at the timestamp of the last silence that
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
split aggressively just before max_speech_duration_s.
min_silence_duration_ms: gr.Number
This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
before separating it
speech_pad_ms: gr.Number
This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
batch_size: gr.Number
This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe
is_diarize: gr.Checkbox
This parameter is related with whisperx. Boolean value that determines whether to diarize or not.
hf_token: gr.Textbox
This parameter is related with whisperx. Huggingface token is needed to download diarization models.
Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements
diarization_device: gr.Dropdown
This parameter is related with whisperx. Device to run diarization model
length_penalty: gr.Number
This parameter is related to faster-whisper. Exponential length penalty constant.
repetition_penalty: gr.Number
This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
(set > 1 to penalize).
no_repeat_ngram_size: gr.Number
This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
prefix: gr.Textbox
This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
suppress_blank: gr.Checkbox
This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
suppress_tokens: gr.Textbox
This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
of symbols as defined in the model config.json file.
max_initial_timestamp: gr.Number
This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
word_timestamps: gr.Checkbox
This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
and dynamic time warping, and include the timestamps for each word in each segment.
prepend_punctuations: gr.Textbox
This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
with the next word.
append_punctuations: gr.Textbox
This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
with the previous word.
max_new_tokens: gr.Number
This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
the maximum will be set by the default max_length.
chunk_length: gr.Number
This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.
hallucination_silence_threshold: gr.Number
This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
(in seconds) when a possible hallucination is detected.
hotwords: gr.Textbox
This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
language_detection_threshold: gr.Number
This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
language_detection_segments: gr.Number
This parameter is related to faster-whisper. Number of segments to consider for the language detection.
is_separate_bgm: gr.Checkbox
This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.
uvr_model_size: gr.Dropdown
This parameter is related to UVR. UVR model size.
uvr_device: gr.Dropdown
This parameter is related to UVR. Device to run UVR model.
uvr_segment_size: gr.Number
This parameter is related to UVR. Segment size for UVR model.
uvr_save_file: gr.Checkbox
This parameter is related to UVR. Boolean value that determines whether to save the file or not.
uvr_enable_offload: gr.Checkbox
This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
after each transcription.
"""
def as_list(self) -> list:
"""
Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
See more about Gradio pre-processing: : https://www.gradio.app/docs/components
Returns
----------
A list of Gradio components
"""
return [getattr(self, f.name) for f in fields(self)]
@staticmethod
def as_value(*args) -> 'WhisperValues':
"""
To use Whisper parameters in function after Gradio post-processing.
See more about Gradio post-processing: : https://www.gradio.app/docs/components
Returns
----------
WhisperValues
Data class that has values of parameters
"""
return WhisperValues(*args)
@dataclass
class WhisperValues:
model_size: str = "large-v2"
lang: Optional[str] = None
is_translate: bool = False
beam_size: int = 5
log_prob_threshold: float = -1.0
no_speech_threshold: float = 0.6
compute_type: str = "float16"
best_of: int = 5
patience: float = 1.0
condition_on_previous_text: bool = True
prompt_reset_on_temperature: float = 0.5
initial_prompt: Optional[str] = None
temperature: float = 0.0
compression_ratio_threshold: float = 2.4
vad_filter: bool = False
threshold: float = 0.5
min_speech_duration_ms: int = 250
max_speech_duration_s: float = float("inf")
min_silence_duration_ms: int = 2000
speech_pad_ms: int = 400
batch_size: int = 24
is_diarize: bool = False
hf_token: str = ""
diarization_device: str = "cuda"
length_penalty: float = 1.0
repetition_penalty: float = 1.0
no_repeat_ngram_size: int = 0
prefix: Optional[str] = None
suppress_blank: bool = True
suppress_tokens: Optional[str] = "[-1]"
max_initial_timestamp: float = 0.0
word_timestamps: bool = False
prepend_punctuations: Optional[str] = "\"'“¿([{-"
append_punctuations: Optional[str] = "\"'.。,!?::”)]}、"
max_new_tokens: Optional[int] = None
chunk_length: Optional[int] = 30
hallucination_silence_threshold: Optional[float] = None
hotwords: Optional[str] = None
language_detection_threshold: Optional[float] = None
language_detection_segments: int = 1
is_bgm_separate: bool = False
uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
uvr_device: str = "cuda"
uvr_segment_size: int = 256
uvr_save_file: bool = False
uvr_enable_offload: bool = True
"""
A data class to use Whisper parameters.
"""
def to_yaml(self) -> Dict:
data = {
"whisper": {
"model_size": self.model_size,
"lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
"is_translate": self.is_translate,
"beam_size": self.beam_size,
"log_prob_threshold": self.log_prob_threshold,
"no_speech_threshold": self.no_speech_threshold,
"best_of": self.best_of,
"patience": self.patience,
"condition_on_previous_text": self.condition_on_previous_text,
"prompt_reset_on_temperature": self.prompt_reset_on_temperature,
"initial_prompt": None if not self.initial_prompt else self.initial_prompt,
"temperature": self.temperature,
"compression_ratio_threshold": self.compression_ratio_threshold,
"batch_size": self.batch_size,
"length_penalty": self.length_penalty,
"repetition_penalty": self.repetition_penalty,
"no_repeat_ngram_size": self.no_repeat_ngram_size,
"prefix": None if not self.prefix else self.prefix,
"suppress_blank": self.suppress_blank,
"suppress_tokens": self.suppress_tokens,
"max_initial_timestamp": self.max_initial_timestamp,
"word_timestamps": self.word_timestamps,
"prepend_punctuations": self.prepend_punctuations,
"append_punctuations": self.append_punctuations,
"max_new_tokens": self.max_new_tokens,
"chunk_length": self.chunk_length,
"hallucination_silence_threshold": self.hallucination_silence_threshold,
"hotwords": None if not self.hotwords else self.hotwords,
"language_detection_threshold": self.language_detection_threshold,
"language_detection_segments": self.language_detection_segments,
},
"vad": {
"vad_filter": self.vad_filter,
"threshold": self.threshold,
"min_speech_duration_ms": self.min_speech_duration_ms,
"max_speech_duration_s": self.max_speech_duration_s,
"min_silence_duration_ms": self.min_silence_duration_ms,
"speech_pad_ms": self.speech_pad_ms,
},
"diarization": {
"is_diarize": self.is_diarize,
"hf_token": self.hf_token
},
"bgm_separation": {
"is_separate_bgm": self.is_bgm_separate,
"model_size": self.uvr_model_size,
"segment_size": self.uvr_segment_size,
"save_file": self.uvr_save_file,
"enable_offload": self.uvr_enable_offload
},
}
return data
def as_list(self) -> list:
"""
Converts the data class attributes into a list
Returns
----------
A list of Whisper parameters
"""
return [getattr(self, f.name) for f in fields(self)]

View File

@@ -56,7 +56,7 @@
"!pip install faster-whisper==1.0.3\n",
"!pip install ctranslate2==4.4.0\n",
"!pip install gradio\n",
"!pip install git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error\n",
"!pip install gradio-i18n\n",
"# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n",
"!pip install git+https://github.com/JuanBindez/pytubefix.git\n",
"!pip install tokenizers==0.19.1\n",
@@ -99,7 +99,7 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {
"id": "PQroYRRZzQiN",
"cellView": "form"

View File

@@ -11,7 +11,7 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
faster-whisper==1.0.3
transformers
gradio
git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error
gradio-i18n
pytubefix
ruamel.yaml==0.18.6
pyannote.audio==3.3.1

View File

@@ -1,6 +1,6 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from modules.whisper.data_classes import *
from test_config import *
from test_transcription import download_file, test_transcribe
@@ -17,9 +17,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
("whisper", False, True, False),
("faster-whisper", False, True, False),
("insanely_fast_whisper", False, True, False)
(WhisperImpl.WHISPER.value, False, True, False),
(WhisperImpl.FASTER_WHISPER.value, False, True, False),
(WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
]
)
def test_bgm_separation_pipeline(
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
("whisper", True, True, False),
("faster-whisper", True, True, False),
("insanely_fast_whisper", True, True, False)
(WhisperImpl.WHISPER.value, True, True, False),
(WhisperImpl.FASTER_WHISPER.value, True, True, False),
(WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
]
)
def test_bgm_separation_with_vad_pipeline(

View File

@@ -1,10 +1,14 @@
from modules.utils.paths import *
import functools
import jiwer
import os
import torch
from modules.utils.paths import *
from modules.utils.youtube_manager import *
TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
TEST_WHISPER_MODEL = "tiny"
TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
@@ -13,5 +17,24 @@ TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")
@functools.lru_cache
def is_cuda_available():
return torch.cuda.is_available()
@functools.lru_cache
def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL):
try:
yt_temp_path = os.path.join("modules", "yt_tmp.wav")
if os.path.exists(yt_temp_path):
return False
yt = get_ytdata(url)
audio = get_ytaudio(yt)
return False
except Exception as e:
print(f"Pytube has detected as a bot: {e}")
return True
def calculate_wer(answer, prediction):
return jiwer.wer(answer, prediction)
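A brief illustrative check of the helper above (jiwer.wer compares a reference and a hypothesis at word level); the 0.1 threshold used in the transcription tests below is shown only as an example:

assert calculate_wer(TEST_ANSWER, TEST_ANSWER) == 0.0  # identical strings -> WER of 0
sample = "and so my fellow americans ask not what your country could do for you ask what you can do for your country"
assert calculate_wer(TEST_ANSWER.lower(), sample) < 0.1  # one substituted word out of 22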

View File

@@ -1,6 +1,6 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from modules.whisper.data_classes import *
from test_config import *
from test_transcription import download_file, test_transcribe
@@ -16,9 +16,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
("whisper", False, False, True),
("faster-whisper", False, False, True),
("insanely_fast_whisper", False, False, True)
(WhisperImpl.WHISPER.value, False, False, True),
(WhisperImpl.FASTER_WHISPER.value, False, False, True),
(WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
]
)
def test_diarization_pipeline(

View File

@@ -1,5 +1,6 @@
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR
from test_config import *
@@ -12,9 +13,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
("whisper", False, False, False),
("faster-whisper", False, False, False),
("insanely_fast_whisper", False, False, False)
(WhisperImpl.WHISPER.value, False, False, False),
(WhisperImpl.FASTER_WHISPER.value, False, False, False),
(WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
]
)
def test_transcribe(
@@ -28,6 +29,10 @@ def test_transcribe(
if not os.path.exists(audio_path):
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
answer = TEST_ANSWER
if diarization:
answer = "SPEAKER_00|"+TEST_ANSWER
whisper_inferencer = WhisperFactory.create_whisper_inference(
whisper_type=whisper_type,
)
@@ -37,16 +42,24 @@ def test_transcribe(
f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
)
hparams = WhisperValues(
model_size=TEST_WHISPER_MODEL,
vad_filter=vad_filter,
is_bgm_separate=bgm_separation,
compute_type=whisper_inferencer.current_compute_type,
uvr_enable_offload=True,
is_diarize=diarization,
).as_list()
hparams = TranscriptionPipelineParams(
whisper=WhisperParams(
model_size=TEST_WHISPER_MODEL,
compute_type=whisper_inferencer.current_compute_type
),
vad=VadParams(
vad_filter=vad_filter
),
bgm_separation=BGMSeparationParams(
is_separate_bgm=bgm_separation,
enable_offload=True
),
diarization=DiarizationParams(
is_diarize=diarization
),
).to_list()
subtitle_str, file_path = whisper_inferencer.transcribe_file(
subtitle_str, file_paths = whisper_inferencer.transcribe_file(
[audio_path],
None,
"SRT",
@@ -54,29 +67,29 @@ def test_transcribe(
gr.Progress(),
*hparams,
)
subtitle = read_file(file_paths[0]).split("\n")
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
if not is_pytube_detected_bot():
subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
TEST_YOUTUBE_URL,
"SRT",
False,
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert os.path.exists(file_path)
whisper_inferencer.transcribe_youtube(
TEST_YOUTUBE_URL,
"SRT",
False,
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
whisper_inferencer.transcribe_mic(
subtitle_str, file_path = whisper_inferencer.transcribe_mic(
audio_path,
"SRT",
False,
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
subtitle = read_file(file_path).split("\n")
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
def download_file(url, save_dir):

View File

@@ -1,6 +1,6 @@
from modules.utils.paths import *
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.whisper_parameter import WhisperValues
from modules.whisper.data_classes import *
from test_config import *
from test_transcription import download_file, test_transcribe
@@ -12,9 +12,9 @@ import os
@pytest.mark.parametrize(
"whisper_type,vad_filter,bgm_separation,diarization",
[
("whisper", True, False, False),
("faster-whisper", True, False, False),
("insanely_fast_whisper", True, False, False)
(WhisperImpl.WHISPER.value, True, False, False),
(WhisperImpl.FASTER_WHISPER.value, True, False, False),
(WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
]
)
def test_vad_pipeline(