Merge branch 'master' into salfter
.github/workflows/ci.yml
@@ -37,7 +37,7 @@ jobs:
 run: sudo apt-get update && sudo apt-get install -y git ffmpeg
 
 - name: Install dependencies
-run: pip install -r requirements.txt pytest
+run: pip install -r requirements.txt pytest jiwer
 
 - name: Run test
 run: python -m pytest -rs tests
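Note: the hunk above adds jiwer to the CI test dependencies; jiwer is a word-error-rate library, so the test suite presumably compares generated transcripts against reference text. A minimal sketch of such a check (the reference and hypothesis strings are made up):

import jiwer

reference = "the quick brown fox"
hypothesis = "the quick brown box"
print(jiwer.wer(reference, hypothesis))  # 0.25 -- one substituted word out of four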
@@ -128,3 +128,6 @@ This is Whisper's original VRAM usage table for models.
 - [x] Add background music separation pre-processing with [UVR](https://github.com/Anjok07/ultimatevocalremovergui)
 - [ ] Add fast api script
 - [ ] Support real-time transcription for microphone
+
+### Translation 🌐
+Any PRs translating Japanese, Spanish, French, German, Chinese, or any other language into [translation.yaml](https://github.com/jhj0517/Whisper-WebUI/blob/master/configs/translation.yaml) would be greatly appreciated!
app.py
@@ -7,17 +7,14 @@ import yaml
 from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, WHISPER_MODELS_DIR,
 INSANELY_FAST_WHISPER_MODELS_DIR, NLLB_MODELS_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
 UVR_MODELS_DIR, I18N_YAML_PATH)
-from modules.utils.constants import AUTOMATIC_DETECTION
 from modules.utils.files_manager import load_yaml
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.faster_whisper_inference import FasterWhisperInference
-from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
 from modules.ui.htmls import *
 from modules.utils.cli_manager import str2bool
 from modules.utils.youtube_manager import get_ytmetas
 from modules.translation.deepl_api import DeepLAPI
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
 
 
 class App:
@@ -44,7 +41,7 @@ class App:
 print(f"Use \"{self.args.whisper_type}\" implementation\n"
 f"Device \"{self.whisper_inf.device}\" is detected")
 
-def create_whisper_parameters(self):
+def create_pipeline_inputs(self):
 whisper_params = self.default_params["whisper"]
 vad_params = self.default_params["vad"]
 diarization_params = self.default_params["diarization"]
@@ -56,7 +53,7 @@ class App:
 dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
 value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
 else whisper_params["lang"], label=_("Language"))
-dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label=_("File Format"))
+dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
 with gr.Row():
 cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
 interactive=True)
@@ -66,158 +63,31 @@ class App:
 interactive=True)
 
 with gr.Accordion(_("Advanced Parameters"), open=False):
-nb_beam_size = gr.Number(label="Beam Size", value=whisper_params["beam_size"], precision=0,
-interactive=True,
-info="Beam size to use for decoding.")
-nb_log_prob_threshold = gr.Number(label="Log Probability Threshold",
-value=whisper_params["log_prob_threshold"], interactive=True,
-info="If the average log probability over sampled tokens is below this value, treat as failed.")
-nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=whisper_params["no_speech_threshold"],
-interactive=True,
-info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
-dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
-value=self.whisper_inf.current_compute_type, interactive=True,
-allow_custom_value=True,
-info="Select the type of computation to perform.")
-nb_best_of = gr.Number(label="Best Of", value=whisper_params["best_of"], interactive=True,
-info="Number of candidates when sampling with non-zero temperature.")
-nb_patience = gr.Number(label="Patience", value=whisper_params["patience"], interactive=True,
-info="Beam search patience factor.")
-cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text",
-value=whisper_params["condition_on_previous_text"],
-interactive=True,
-info="Condition on previous text during decoding.")
-sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature",
-value=whisper_params["prompt_reset_on_temperature"],
-minimum=0, maximum=1, step=0.01, interactive=True,
-info="Resets prompt if temperature is above this value."
-" Arg has effect only if 'Condition On Previous Text' is True.")
-tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
-info="Initial prompt to use for decoding.")
-sd_temperature = gr.Slider(label="Temperature", value=whisper_params["temperature"], minimum=0.0,
-step=0.01, maximum=1.0, interactive=True,
-info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
-nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold",
-value=whisper_params["compression_ratio_threshold"],
-interactive=True,
-info="If the gzip compression ratio is above this value, treat as failed.")
-nb_chunk_length = gr.Number(label="Chunk Length (s)", value=lambda: whisper_params["chunk_length"],
-precision=0,
-info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
-with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
-nb_length_penalty = gr.Number(label="Length Penalty", value=whisper_params["length_penalty"],
-info="Exponential length penalty constant.")
-nb_repetition_penalty = gr.Number(label="Repetition Penalty",
-value=whisper_params["repetition_penalty"],
-info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
-nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size",
-value=whisper_params["no_repeat_ngram_size"],
-precision=0,
-info="Prevent repetitions of n-grams with this size (set 0 to disable).")
-tb_prefix = gr.Textbox(label="Prefix", value=lambda: whisper_params["prefix"],
-info="Optional text to provide as a prefix for the first window.")
-cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=whisper_params["suppress_blank"],
-info="Suppress blank outputs at the beginning of the sampling.")
-tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value=whisper_params["suppress_tokens"],
-info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
-nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp",
-value=whisper_params["max_initial_timestamp"],
-info="The initial timestamp cannot be later than this.")
-cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=whisper_params["word_timestamps"],
-info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
-tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations",
-value=whisper_params["prepend_punctuations"],
-info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
-tb_append_punctuations = gr.Textbox(label="Append Punctuations",
-value=whisper_params["append_punctuations"],
-info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
-nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: whisper_params["max_new_tokens"],
-precision=0,
-info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
-nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
-value=lambda: whisper_params[
-"hallucination_silence_threshold"],
-info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
-tb_hotwords = gr.Textbox(label="Hotwords", value=lambda: whisper_params["hotwords"],
-info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
-nb_language_detection_threshold = gr.Number(label="Language Detection Threshold",
-value=lambda: whisper_params[
-"language_detection_threshold"],
-info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
-nb_language_detection_segments = gr.Number(label="Language Detection Segments",
-value=lambda: whisper_params["language_detection_segments"],
-precision=0,
-info="Number of segments to consider for the language detection.")
-with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-nb_batch_size = gr.Number(label="Batch Size", value=whisper_params["batch_size"], precision=0)
+whisper_inputs = WhisperParams.to_gradio_inputs(defaults=whisper_params, only_advanced=True,
+whisper_type=self.args.whisper_type,
+available_compute_types=self.whisper_inf.available_compute_types,
+compute_type=self.whisper_inf.current_compute_type)
 
 with gr.Accordion(_("Background Music Remover Filter"), open=False):
-cb_bgm_separation = gr.Checkbox(label=_("Enable Background Music Remover Filter"),
-value=uvr_params["is_separate_bgm"],
-interactive=True,
-info=_("Enabling this will remove background music"))
-dd_uvr_device = gr.Dropdown(label=_("Device"), value=self.whisper_inf.music_separator.device,
-choices=self.whisper_inf.music_separator.available_devices)
-dd_uvr_model_size = gr.Dropdown(label=_("Model"), value=uvr_params["model_size"],
-choices=self.whisper_inf.music_separator.available_models)
-nb_uvr_segment_size = gr.Number(label="Segment Size", value=uvr_params["segment_size"], precision=0)
-cb_uvr_save_file = gr.Checkbox(label=_("Save separated files to output"), value=uvr_params["save_file"])
-cb_uvr_enable_offload = gr.Checkbox(label=_("Offload sub model after removing background music"),
-value=uvr_params["enable_offload"])
+uvr_inputs = BGMSeparationParams.to_gradio_input(defaults=uvr_params,
+available_models=self.whisper_inf.music_separator.available_models,
+available_devices=self.whisper_inf.music_separator.available_devices,
+device=self.whisper_inf.music_separator.device)
 
 with gr.Accordion(_("Voice Detection Filter"), open=False):
-cb_vad_filter = gr.Checkbox(label=_("Enable Silero VAD Filter"), value=vad_params["vad_filter"],
-interactive=True,
-info=_("Enable this to transcribe only detected voice"))
-sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
-value=vad_params["threshold"],
-info="Lower it to be more sensitive to small sounds.")
-nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0,
-value=vad_params["min_speech_duration_ms"],
-info="Final speech chunks shorter than this time are thrown out")
-nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)",
-value=vad_params["max_speech_duration_s"],
-info="Maximum duration of speech chunks in \"seconds\".")
-nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0,
-value=vad_params["min_silence_duration_ms"],
-info="In the end of each speech chunk wait for this time"
-" before separating it")
-nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=vad_params["speech_pad_ms"],
-info="Final speech chunks are padded by this time each side")
+vad_inputs = VadParams.to_gradio_inputs(defaults=vad_params)
 
 with gr.Accordion(_("Diarization"), open=False):
-cb_diarize = gr.Checkbox(label=_("Enable Diarization"), value=diarization_params["is_diarize"])
-tb_hf_token = gr.Text(label=_("HuggingFace Token"), value=diarization_params["hf_token"],
-info=_("This is only needed the first time you download the model"))
-dd_diarization_device = gr.Dropdown(label=_("Device"),
-choices=self.whisper_inf.diarizer.get_available_device(),
-value=self.whisper_inf.diarizer.get_device())
+diarization_inputs = DiarizationParams.to_gradio_inputs(defaults=diarization_params,
+available_devices=self.whisper_inf.diarizer.available_device,
+device=self.whisper_inf.diarizer.device)
 
 dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
 
+pipeline_inputs = [dd_model, dd_lang, cb_translate] + whisper_inputs + vad_inputs + diarization_inputs + uvr_inputs
 
 return (
-WhisperParameters(
-model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
-log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
-compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
-condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
-temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
-vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
-max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
-speech_pad_ms=nb_speech_pad_ms, chunk_length=nb_chunk_length, batch_size=nb_batch_size,
-is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
-length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
-no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
-suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
-word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
-append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens,
-hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
-language_detection_threshold=nb_language_detection_threshold,
-language_detection_segments=nb_language_detection_segments,
-prompt_reset_on_temperature=sld_prompt_reset_on_temperature, is_bgm_separate=cb_bgm_separation,
-uvr_device=dd_uvr_device, uvr_model_size=dd_uvr_model_size, uvr_segment_size=nb_uvr_segment_size,
-uvr_save_file=cb_uvr_save_file, uvr_enable_offload=cb_uvr_enable_offload
-),
+pipeline_inputs,
 dd_file_format,
 cb_timestamp
 )
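Note: the hunk above replaces dozens of hand-built Gradio components with calls such as WhisperParams.to_gradio_inputs(...), which return a list of components that gets concatenated into pipeline_inputs. The real implementations live in modules/whisper/data_classes.py and are not shown in this diff; the following is only a hedged sketch of the pattern, with a hypothetical ExampleParams class standing in for WhisperParams/VadParams/etc.:

# Hedged sketch, not the project's data_classes implementation.
from dataclasses import dataclass, fields
from typing import List
import gradio as gr

@dataclass
class ExampleParams:                       # hypothetical stand-in
    beam_size: int = 5
    temperature: float = 0.0
    word_timestamps: bool = False

    @classmethod
    def to_gradio_inputs(cls, defaults: dict) -> List:
        # Build one Gradio input per field, seeded from the YAML defaults.
        components = []
        for f in fields(cls):
            value = defaults.get(f.name, getattr(cls, f.name))
            if isinstance(value, bool):
                components.append(gr.Checkbox(label=f.name, value=value))
            else:
                components.append(gr.Number(label=f.name, value=value))
        return components

# Inside a gr.Blocks() context the UI code can then simply do:
# pipeline_inputs = ExampleParams.to_gradio_inputs(defaults={"beam_size": 3})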
@@ -243,7 +113,7 @@ class App:
 visible=self.args.colab,
 value="")
 
-whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
 
 with gr.Row():
 btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -254,7 +124,7 @@ class App:
 
 params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
 btn_run.click(fn=self.whisper_inf.transcribe_file,
-inputs=params + whisper_params.as_list(),
+inputs=params + pipeline_params,
 outputs=[tb_indicator, files_subtitles])
 btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
 
@@ -268,7 +138,7 @@ class App:
 tb_title = gr.Label(label=_("Youtube Title"))
 tb_description = gr.Textbox(label=_("Youtube Description"), max_lines=15)
 
-whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
 
 with gr.Row():
 btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -280,7 +150,7 @@ class App:
 params = [tb_youtubelink, dd_file_format, cb_timestamp]
 
 btn_run.click(fn=self.whisper_inf.transcribe_youtube,
-inputs=params + whisper_params.as_list(),
+inputs=params + pipeline_params,
 outputs=[tb_indicator, files_subtitles])
 tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
 outputs=[img_thumbnail, tb_title, tb_description])
@@ -290,7 +160,7 @@ class App:
 with gr.Row():
 mic_input = gr.Microphone(label=_("Record with Mic"), type="filepath", interactive=True)
 
-whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+pipeline_params, dd_file_format, cb_timestamp = self.create_pipeline_inputs()
 
 with gr.Row():
 btn_run = gr.Button(_("GENERATE SUBTITLE FILE"), variant="primary")
@@ -302,7 +172,7 @@ class App:
 params = [mic_input, dd_file_format, cb_timestamp]
 
 btn_run.click(fn=self.whisper_inf.transcribe_mic,
-inputs=params + whisper_params.as_list(),
+inputs=params + pipeline_params,
 outputs=[tb_indicator, files_subtitles])
 btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
 
@@ -417,7 +287,6 @@ class App:
 
 # Launch the app with optional gradio settings
 args = self.args
-
 self.app.queue(
 api_open=args.api_open
 ).launch(
@@ -447,8 +316,8 @@ class App:
 
 
 parser = argparse.ArgumentParser()
-parser.add_argument('--whisper_type', type=str, default="faster-whisper",
-choices=["whisper", "faster-whisper", "insanely-fast-whisper"],
+parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
+choices=[item.value for item in WhisperImpl],
 help='A type of the whisper implementation (Github repo name)')
 parser.add_argument('--share', type=str2bool, default=False, nargs='?', const=True, help='Gradio share value')
 parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
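Note: the argparse change above reads the CLI choices from a WhisperImpl enum instead of a hard-coded list. The enum itself is defined in modules/whisper/data_classes.py and is not shown in this diff; a hedged sketch of the pattern, using the values from the removed line:

# Hedged sketch; the real WhisperImpl may differ in member names.
import argparse
from enum import Enum

class WhisperImpl(Enum):
    WHISPER = "whisper"
    FASTER_WHISPER = "faster-whisper"
    INSANELY_FAST_WHISPER = "insanely-fast-whisper"

parser = argparse.ArgumentParser()
parser.add_argument('--whisper_type', type=str, default=WhisperImpl.FASTER_WHISPER.value,
                    choices=[item.value for item in WhisperImpl])
print(parser.parse_args([]).whisper_type)  # "faster-whisper"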
@@ -1,5 +1,6 @@
 whisper:
   model_size: "medium.en"
+  file_format: "SRT"
   lang: "english"
   is_translate: false
   beam_size: 5
@@ -411,3 +411,49 @@ ru: # Russian
   Instrumental: Инструментал
   Vocals: Вокал
   SEPARATE BACKGROUND MUSIC: РАЗДЕЛИТЬ ФОНОВУЮ МУЗЫКУ
+
+tr: # Turkish
+  Language: Dil
+  File: Dosya
+  Youtube: Youtube
+  Mic: Mikrofon
+  T2T Translation: T2T Çeviri
+  BGM Separation: Arka Plan Müziği Ayırma
+  GENERATE SUBTITLE FILE: ALTYAZI DOSYASI OLUŞTUR
+  Output: Çıktı
+  Downloadable output file: İndirilebilir çıktı dosyası
+  Upload File here: Dosya Yükle
+  Model: Model
+  Automatic Detection: Otomatik Algılama
+  File Format: Dosya Formatı
+  Translate to English?: İngilizceye Çevir?
+  Add a timestamp to the end of the filename: Dosya adının sonuna zaman damgası ekle
+  Advanced Parameters: Gelişmiş Parametreler
+  Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresi
+  Enabling this will remove background music: Bunu etkinleştirmek, arka plan müziğini alt model tarafından transkripsiyondan önce kaldıracaktır
+  Enable Background Music Remover Filter: Arka Plan Müziği Kaldırma Filtresini Etkinleştir
+  Save separated files to output: Ayrılmış dosyaları çıktıya kaydet
+  Offload sub model after removing background music: Arka plan müziği kaldırıldıktan sonra alt modeli devre dışı bırak
+  Voice Detection Filter: Ses Algılama Filtresi
+  Enable this to transcribe only detected voice: Bunu etkinleştirerek yalnızca alt model tarafından algılanan ses kısımlarını transkribe et
+  Enable Silero VAD Filter: Silero VAD Filtresini Etkinleştir
+  Diarization: Konuşmacı Ayrımı
+  Enable Diarization: Konuşmacı Ayrımını Etkinleştir
+  HuggingFace Token: HuggingFace Anahtarı
+  This is only needed the first time you download the model: Bu, modeli ilk kez indirirken gereklidir. Zaten modelleriniz varsa girmenize gerek yok. Modeli indirmek için "https://huggingface.co/pyannote/speaker-diarization-3.1" ve "https://huggingface.co/pyannote/segmentation-3.0" adreslerine gidip gereksinimlerini kabul etmeniz gerekiyor
+  Device: Cihaz
+  Youtube Link: Youtube Bağlantısı
+  Youtube Thumbnail: Youtube Küçük Resmi
+  Youtube Title: Youtube Başlığı
+  Youtube Description: Youtube Açıklaması
+  Record with Mic: Mikrofonla Kaydet
+  Upload Subtitle Files to translate here: Çeviri için altyazı dosyalarını buraya yükle
+  Your Auth Key (API KEY): Yetki Anahtarınız (API ANAHTARI)
+  Source Language: Kaynak Dil
+  Target Language: Hedef Dil
+  Pro User?: Pro Kullanıcı?
+  TRANSLATE SUBTITLE FILE: ALTYAZI DOSYASINI ÇEVİR
+  Upload Audio Files to separate background music: Arka plan müziğini ayırmak için ses dosyalarını yükle
+  Instrumental: Enstrümantal
+  Vocals: Vokal
+  SEPARATE BACKGROUND MUSIC: ARKA PLAN MÜZİĞİNİ AYIR
@@ -7,6 +7,7 @@ from pyannote.audio import Pipeline
 from typing import Optional, Union
 import torch
 
+from modules.whisper.data_classes import *
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.audio_loader import load_audio, SAMPLE_RATE
 
@@ -43,6 +44,8 @@ class DiarizationPipeline:
 
 def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
 transcript_segments = transcript_result["segments"]
+if transcript_segments and isinstance(transcript_segments[0], Segment):
+transcript_segments = [seg.model_dump() for seg in transcript_segments]
 for seg in transcript_segments:
 # assign speaker to segment (if any)
 diarize_df['intersection'] = np.minimum(diarize_df['end'], seg['end']) - np.maximum(diarize_df['start'],
@@ -63,7 +66,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
 seg["speaker"] = speaker
 
 # assign speaker to words
-if 'words' in seg:
+if 'words' in seg and seg['words'] is not None:
 for word in seg['words']:
 if 'start' in word:
 diarize_df['intersection'] = np.minimum(diarize_df['end'], word['end']) - np.maximum(
@@ -85,10 +88,10 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
 if word_speaker is not None:
 word["speaker"] = word_speaker
 
-return transcript_result
+return {"segments": transcript_segments}
 
 
-class Segment:
+class DiarizationSegment:
 def __init__(self, start, end, speaker=None):
 self.start = start
 self.end = end
@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import List, Union, BinaryIO, Optional
+from typing import List, Union, BinaryIO, Optional, Tuple
 import numpy as np
 import time
 import logging
@@ -8,6 +8,7 @@ import logging
 from modules.utils.paths import DIARIZATION_MODELS_DIR
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
+from modules.whisper.data_classes import *
 
 
 class Diarizer:
@@ -23,10 +24,10 @@ class Diarizer:
 
 def run(self,
 audio: Union[str, BinaryIO, np.ndarray],
-transcribed_result: List[dict],
+transcribed_result: List[Segment],
 use_auth_token: str,
 device: Optional[str] = None
-):
+) -> Tuple[List[Segment], float]:
 """
 Diarize transcribed result as a post-processing
 
@@ -34,7 +35,7 @@ class Diarizer:
 ----------
 audio: Union[str, BinaryIO, np.ndarray]
 Audio input. This can be file path or binary type.
-transcribed_result: List[dict]
+transcribed_result: List[Segment]
 transcribed result through whisper.
 use_auth_token: str
 Huggingface token with READ permission. This is only needed the first time you download the model.
@@ -44,8 +45,8 @@ class Diarizer:
 
 Returns
 ----------
-segments_result: List[dict]
-list of dicts that includes start, end timestamps and transcribed text
+segments_result: List[Segment]
+list of Segment that includes start, end timestamps and transcribed text
 elapsed_time: float
 elapsed time for running
 """
@@ -68,14 +69,20 @@ class Diarizer:
 {"segments": transcribed_result}
 )
 
+segments_result = []
 for segment in diarized_result["segments"]:
 speaker = "None"
 if "speaker" in segment:
 speaker = segment["speaker"]
-segment["text"] = speaker + "|" + segment["text"].strip()
+diarized_text = speaker + "|" + segment["text"].strip()
+segments_result.append(Segment(
+start=segment["start"],
+end=segment["end"],
+text=diarized_text
+))
 
 elapsed_time = time.time() - start_time
-return diarized_result["segments"], elapsed_time
+return segments_result, elapsed_time
 
 def update_pipe(self,
 use_auth_token: str,
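Note: after the change above, Diarizer.run() returns Segment objects whose text carries the speaker label in front, separated by "|". A trivial sketch of that convention (the speaker label and text are example values, not from the repo):

diarized_text = "SPEAKER_00|hello there"
speaker, text = diarized_text.split("|", maxsplit=1)
print(speaker)  # SPEAKER_00
print(text)     # hello there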
@@ -139,37 +139,27 @@ class DeepLAPI:
 )
 
 files_info = {}
-for fileobj in fileobjs:
-file_path = fileobj
-file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-if file_ext == ".srt":
-parsed_dicts = parse_srt(file_path=file_path)
-elif file_ext == ".vtt":
-parsed_dicts = parse_vtt(file_path=file_path)
+for file_path in fileobjs:
+file_name, file_ext = os.path.splitext(os.path.basename(file_path))
+writer = get_writer(file_ext, self.output_dir)
+segments = writer.to_segments(file_path)
 
 batch_size = self.max_text_batch_size
-for batch_start in range(0, len(parsed_dicts), batch_size):
-batch_end = min(batch_start + batch_size, len(parsed_dicts))
-sentences_to_translate = [dic["sentence"] for dic in parsed_dicts[batch_start:batch_end]]
+for batch_start in range(0, len(segments), batch_size):
+progress(batch_start / len(segments), desc="Translating..")
+sentences_to_translate = [seg.text for seg in segments[batch_start:batch_start+batch_size]]
 translated_texts = self.request_deepl_translate(auth_key, sentences_to_translate, source_lang,
 target_lang, is_pro)
 for i, translated_text in enumerate(translated_texts):
-parsed_dicts[batch_start + i]["sentence"] = translated_text["text"]
-progress(batch_end / len(parsed_dicts), desc="Translating..")
+segments[batch_start + i].text = translated_text["text"]
 
-if file_ext == ".srt":
-subtitle = get_serialized_srt(parsed_dicts)
-elif file_ext == ".vtt":
-subtitle = get_serialized_vtt(parsed_dicts)
-
-if add_timestamp:
-timestamp = datetime.now().strftime("%m%d%H%M%S")
-file_name += f"-{timestamp}"
-
-output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
-write_file(subtitle, output_path)
+subtitle, output_path = generate_file(
+output_dir=self.output_dir,
+output_file_name=file_name,
+output_format=file_ext,
+result=segments,
+add_timestamp=add_timestamp
+)
 
 files_info[file_name] = {"subtitle": subtitle, "path": output_path}
 
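Note: both the DeepL path above and the NLLB path further down now share the same flow: parse the subtitle into Segment objects, translate each segment's text, then regenerate the file. A hedged sketch of that flow under this PR's module layout (translate_fn is a stand-in for the actual DeepL/NLLB call, and the function below is illustrative, not repo code):

import os
from modules.utils.subtitle_manager import get_writer, generate_file  # paths as used in this PR

def translate_subtitle(file_path, output_dir, translate_fn):
    file_name, file_ext = os.path.splitext(os.path.basename(file_path))
    writer = get_writer(file_ext, output_dir)      # picks WriteSRT / WriteVTT / WriteLRC ...
    segments = writer.to_segments(file_path)       # parse subtitle blocks into Segment objects
    for seg in segments:
        seg.text = translate_fn(seg.text)          # translate in place
    subtitle, out_path = generate_file(
        output_dir=output_dir,
        output_file_name=file_name,
        output_format=file_ext,
        result=segments,
        add_timestamp=True,
    )
    return subtitle, out_path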
@@ -6,7 +6,7 @@ from typing import List
 from datetime import datetime
 
 import modules.translation.nllb_inference as nllb
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
 from modules.utils.subtitle_manager import *
 from modules.utils.files_manager import load_yaml, save_yaml
 from modules.utils.paths import DEFAULT_PARAMETERS_CONFIG_PATH, NLLB_MODELS_DIR, TRANSLATION_OUTPUT_DIR
@@ -95,32 +95,22 @@ class TranslationBase(ABC):
 files_info = {}
 for fileobj in fileobjs:
 file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
-if file_ext == ".srt":
-parsed_dicts = parse_srt(file_path=fileobj)
-total_progress = len(parsed_dicts)
-for index, dic in enumerate(parsed_dicts):
-progress(index / total_progress, desc="Translating..")
-translated_text = self.translate(dic["sentence"], max_length=max_length)
-dic["sentence"] = translated_text
-subtitle = get_serialized_srt(parsed_dicts)
-
-elif file_ext == ".vtt":
-parsed_dicts = parse_vtt(file_path=fileobj)
-total_progress = len(parsed_dicts)
-for index, dic in enumerate(parsed_dicts):
-progress(index / total_progress, desc="Translating..")
-translated_text = self.translate(dic["sentence"], max_length=max_length)
-dic["sentence"] = translated_text
-subtitle = get_serialized_vtt(parsed_dicts)
-
-if add_timestamp:
-timestamp = datetime.now().strftime("%m%d%H%M%S")
-file_name += f"-{timestamp}"
+writer = get_writer(file_ext, self.output_dir)
+segments = writer.to_segments(fileobj)
+for i, segment in enumerate(segments):
+progress(i / len(segments), desc="Translating..")
+translated_text = self.translate(segment.text, max_length=max_length)
+segment.text = translated_text
 
-output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
-write_file(subtitle, output_path)
+subtitle, file_path = generate_file(
+output_dir=self.output_dir,
+output_file_name=file_name,
+output_format=file_ext,
+result=segments,
+add_timestamp=add_timestamp
+)
 
-files_info[file_name] = {"subtitle": subtitle, "path": output_path}
+files_info[file_name] = {"subtitle": subtitle, "path": file_path}
 
 total_result = ''
 for file_name, info in files_info.items():
@@ -133,7 +123,8 @@ class TranslationBase(ABC):
 return [gr_str, output_file_paths]
 
 except Exception as e:
-print(f"Error: {str(e)}")
+print(f"Error translating file: {e}")
+raise
 finally:
 self.release_cuda_memory()
 
@@ -1,3 +1,6 @@
 from gradio_i18n import Translate, gettext as _
 
 AUTOMATIC_DETECTION = _("Automatic Detection")
+GRADIO_NONE_STR = ""
+GRADIO_NONE_NUMBER_MAX = 9999
+GRADIO_NONE_NUMBER_MIN = 0
@@ -67,3 +67,9 @@ def is_video(file_path):
 video_extensions = ['.mp4', '.mkv', '.avi', '.mov', '.flv', '.wmv', '.webm', '.m4v', '.mpeg', '.mpg', '.3gp']
 extension = os.path.splitext(file_path)[1].lower()
 return extension in video_extensions
+
+
+def read_file(file_path):
+with open(file_path, "r", encoding="utf-8") as f:
+subtitle_content = f.read()
+return subtitle_content
@@ -1,121 +1,425 @@
|
|||||||
|
# Ported from https://github.com/openai/whisper/blob/main/whisper/utils.py
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
import zlib
|
||||||
|
from typing import Callable, List, Optional, TextIO, Union, Dict, Tuple
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from modules.whisper.data_classes import Segment, Word
|
||||||
|
from .files_manager import read_file
|
||||||
|
|
||||||
|
|
||||||
def timeformat_srt(time):
|
def format_timestamp(
|
||||||
hours = time // 3600
|
seconds: float, always_include_hours: bool = True, decimal_marker: str = ","
|
||||||
minutes = (time - hours * 3600) // 60
|
) -> str:
|
||||||
seconds = time - hours * 3600 - minutes * 60
|
assert seconds >= 0, "non-negative timestamp expected"
|
||||||
milliseconds = (time - int(time)) * 1000
|
milliseconds = round(seconds * 1000.0)
|
||||||
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d},{int(milliseconds):03d}"
|
|
||||||
|
hours = milliseconds // 3_600_000
|
||||||
|
milliseconds -= hours * 3_600_000
|
||||||
|
|
||||||
|
minutes = milliseconds // 60_000
|
||||||
|
milliseconds -= minutes * 60_000
|
||||||
|
|
||||||
|
seconds = milliseconds // 1_000
|
||||||
|
milliseconds -= seconds * 1_000
|
||||||
|
|
||||||
|
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
||||||
|
return (
|
||||||
|
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def timeformat_vtt(time):
|
def time_str_to_seconds(time_str: str, decimal_marker: str = ",") -> float:
|
||||||
hours = time // 3600
|
times = time_str.split(":")
|
||||||
minutes = (time - hours * 3600) // 60
|
|
||||||
seconds = time - hours * 3600 - minutes * 60
|
if len(times) == 3:
|
||||||
milliseconds = (time - int(time)) * 1000
|
hours, minutes, rest = times
|
||||||
return f"{int(hours):02d}:{int(minutes):02d}:{int(seconds):02d}.{int(milliseconds):03d}"
|
hours = int(hours)
|
||||||
|
else:
|
||||||
|
hours = 0
|
||||||
|
minutes, rest = times
|
||||||
|
|
||||||
|
seconds, fractional = rest.split(decimal_marker)
|
||||||
|
|
||||||
|
minutes = int(minutes)
|
||||||
|
seconds = int(seconds)
|
||||||
|
fractional_seconds = float("0." + fractional)
|
||||||
|
|
||||||
|
return hours * 3600 + minutes * 60 + seconds + fractional_seconds
|
||||||
|
|
||||||
|
|
||||||
def write_file(subtitle, output_file):
|
def get_start(segments: List[dict]) -> Optional[float]:
|
||||||
with open(output_file, 'w', encoding='utf-8') as f:
|
return next(
|
||||||
f.write(subtitle)
|
(w["start"] for s in segments for w in s["words"]),
|
||||||
|
segments[0]["start"] if segments else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_srt(segments):
|
def get_end(segments: List[dict]) -> Optional[float]:
|
||||||
output = ""
|
return next(
|
||||||
for i, segment in enumerate(segments):
|
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
||||||
output += f"{i + 1}\n"
|
segments[-1]["end"] if segments else None,
|
||||||
output += f"{timeformat_srt(segment['start'])} --> {timeformat_srt(segment['end'])}\n"
|
)
|
||||||
if segment['text'].startswith(' '):
|
|
||||||
segment['text'] = segment['text'][1:]
|
|
||||||
output += f"{segment['text']}\n\n"
|
|
||||||
return output
|
|
||||||
|
|
||||||
|
|
||||||
def get_vtt(segments):
|
class ResultWriter:
|
||||||
output = "WebVTT\n\n"
|
extension: str
|
||||||
for i, segment in enumerate(segments):
|
|
||||||
output += f"{i + 1}\n"
|
def __init__(self, output_dir: str):
|
||||||
output += f"{timeformat_vtt(segment['start'])} --> {timeformat_vtt(segment['end'])}\n"
|
self.output_dir = output_dir
|
||||||
if segment['text'].startswith(' '):
|
|
||||||
segment['text'] = segment['text'][1:]
|
def __call__(
|
||||||
output += f"{segment['text']}\n\n"
|
self, result: Union[dict, List[Segment]], output_file_name: str,
|
||||||
return output
|
options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
if isinstance(result, List) and result and isinstance(result[0], Segment):
|
||||||
|
result = {"segments": [seg.model_dump() for seg in result]}
|
||||||
|
|
||||||
|
output_path = os.path.join(
|
||||||
|
self.output_dir, output_file_name + "." + self.extension
|
||||||
|
)
|
||||||
|
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
self.write_result(result, file=f, options=options, **kwargs)
|
||||||
|
|
||||||
|
def write_result(
|
||||||
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
def get_txt(segments):
|
class WriteTXT(ResultWriter):
|
||||||
output = ""
|
extension: str = "txt"
|
||||||
for i, segment in enumerate(segments):
|
|
||||||
if segment['text'].startswith(' '):
|
def write_result(
|
||||||
segment['text'] = segment['text'][1:]
|
self, result: Union[Dict, List[Segment]], file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
output += f"{segment['text']}\n"
|
):
|
||||||
return output
|
for segment in result["segments"]:
|
||||||
|
print(segment["text"].strip(), file=file, flush=True)
|
||||||
|
|
||||||
|
|
||||||
def parse_srt(file_path):
|
class SubtitlesWriter(ResultWriter):
|
||||||
"""Reads SRT file and returns as dict"""
|
always_include_hours: bool
|
||||||
with open(file_path, 'r', encoding='utf-8') as file:
|
decimal_marker: str
|
||||||
srt_data = file.read()
|
|
||||||
|
|
||||||
data = []
|
def iterate_result(
|
||||||
blocks = srt_data.split('\n\n')
|
self,
|
||||||
|
result: dict,
|
||||||
|
options: Optional[dict] = None,
|
||||||
|
*,
|
||||||
|
max_line_width: Optional[int] = None,
|
||||||
|
max_line_count: Optional[int] = None,
|
||||||
|
highlight_words: bool = False,
|
||||||
|
align_lrc_words: bool = False,
|
||||||
|
max_words_per_line: Optional[int] = None,
|
||||||
|
):
|
||||||
|
options = options or {}
|
||||||
|
max_line_width = max_line_width or options.get("max_line_width")
|
||||||
|
max_line_count = max_line_count or options.get("max_line_count")
|
||||||
|
highlight_words = highlight_words or options.get("highlight_words", False)
|
||||||
|
align_lrc_words = align_lrc_words or options.get("align_lrc_words", False)
|
||||||
|
max_words_per_line = max_words_per_line or options.get("max_words_per_line")
|
||||||
|
preserve_segments = max_line_count is None or max_line_width is None
|
||||||
|
max_line_width = max_line_width or 1000
|
||||||
|
max_words_per_line = max_words_per_line or 1000
|
||||||
|
|
||||||
for block in blocks:
|
def iterate_subtitles():
|
||||||
if block.strip() != '':
|
line_len = 0
|
||||||
lines = block.strip().split('\n')
|
line_count = 1
|
||||||
index = lines[0]
|
# the next subtitle to yield (a list of word timings with whitespace)
|
||||||
timestamp = lines[1]
|
subtitle: List[dict] = []
|
||||||
sentence = ' '.join(lines[2:])
|
last: float = get_start(result["segments"]) or 0.0
|
||||||
|
for segment in result["segments"]:
|
||||||
|
chunk_index = 0
|
||||||
|
words_count = max_words_per_line
|
||||||
|
while chunk_index < len(segment["words"]):
|
||||||
|
remaining_words = len(segment["words"]) - chunk_index
|
||||||
|
if max_words_per_line > len(segment["words"]) - chunk_index:
|
||||||
|
words_count = remaining_words
|
||||||
|
for i, original_timing in enumerate(
|
||||||
|
segment["words"][chunk_index : chunk_index + words_count]
|
||||||
|
):
|
||||||
|
timing = original_timing.copy()
|
||||||
|
long_pause = (
|
||||||
|
not preserve_segments and timing["start"] - last > 3.0
|
||||||
|
)
|
||||||
|
has_room = line_len + len(timing["word"]) <= max_line_width
|
||||||
|
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
|
||||||
|
if (
|
||||||
|
line_len > 0
|
||||||
|
and has_room
|
||||||
|
and not long_pause
|
||||||
|
and not seg_break
|
||||||
|
):
|
||||||
|
# line continuation
|
||||||
|
line_len += len(timing["word"])
|
||||||
|
else:
|
||||||
|
# new line
|
||||||
|
timing["word"] = timing["word"].strip()
|
||||||
|
if (
|
||||||
|
len(subtitle) > 0
|
||||||
|
and max_line_count is not None
|
||||||
|
and (long_pause or line_count >= max_line_count)
|
||||||
|
or seg_break
|
||||||
|
):
|
||||||
|
# subtitle break
|
||||||
|
yield subtitle
|
||||||
|
subtitle = []
|
||||||
|
line_count = 1
|
||||||
|
elif line_len > 0:
|
||||||
|
# line break
|
||||||
|
line_count += 1
|
||||||
|
timing["word"] = "\n" + timing["word"]
|
||||||
|
line_len = len(timing["word"].strip())
|
||||||
|
subtitle.append(timing)
|
||||||
|
last = timing["start"]
|
||||||
|
chunk_index += max_words_per_line
|
||||||
|
if len(subtitle) > 0:
|
||||||
|
yield subtitle
|
||||||
|
|
||||||
data.append({
|
if len(result["segments"]) > 0 and "words" in result["segments"][0] and result["segments"][0]["words"]:
|
||||||
"index": index,
|
for subtitle in iterate_subtitles():
|
||||||
"timestamp": timestamp,
|
subtitle_start = self.format_timestamp(subtitle[0]["start"])
|
||||||
"sentence": sentence
|
subtitle_end = self.format_timestamp(subtitle[-1]["end"])
|
||||||
})
|
subtitle_text = "".join([word["word"] for word in subtitle])
|
||||||
return data
|
if highlight_words:
|
||||||
|
last = subtitle_start
|
||||||
|
all_words = [timing["word"] for timing in subtitle]
|
||||||
|
for i, this_word in enumerate(subtitle):
|
||||||
|
start = self.format_timestamp(this_word["start"])
|
||||||
|
end = self.format_timestamp(this_word["end"])
|
||||||
|
if last != start:
|
||||||
|
yield last, start, subtitle_text
|
||||||
|
|
||||||
|
yield start, end, "".join(
|
||||||
|
[
|
||||||
|
re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
|
||||||
|
if j == i
|
||||||
|
else word
|
||||||
|
for j, word in enumerate(all_words)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
last = end
|
||||||
|
|
||||||
|
if align_lrc_words:
|
||||||
|
lrc_aligned_words = [f"[{self.format_timestamp(sub['start'])}]{sub['word']}" for sub in subtitle]
|
||||||
|
l_start, l_end = self.format_timestamp(subtitle[-1]['start']), self.format_timestamp(subtitle[-1]['end'])
|
||||||
|
lrc_aligned_words[-1] = f"[{l_start}]{subtitle[-1]['word']}[{l_end}]"
|
||||||
|
lrc_aligned_words = ' '.join(lrc_aligned_words)
|
||||||
|
yield None, None, lrc_aligned_words
|
||||||
|
|
||||||
|
else:
|
||||||
|
yield subtitle_start, subtitle_end, subtitle_text
|
||||||
|
else:
|
||||||
|
for segment in result["segments"]:
|
||||||
|
segment_start = self.format_timestamp(segment["start"])
|
||||||
|
segment_end = self.format_timestamp(segment["end"])
|
||||||
|
segment_text = segment["text"].strip().replace("-->", "->")
|
||||||
|
yield segment_start, segment_end, segment_text
|
||||||
|
|
||||||
|
def format_timestamp(self, seconds: float):
|
||||||
|
return format_timestamp(
|
||||||
|
seconds=seconds,
|
||||||
|
always_include_hours=self.always_include_hours,
|
||||||
|
decimal_marker=self.decimal_marker,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_vtt(file_path):
|
class WriteVTT(SubtitlesWriter):
|
||||||
"""Reads WebVTT file and returns as dict"""
|
extension: str = "vtt"
|
||||||
with open(file_path, 'r', encoding='utf-8') as file:
|
always_include_hours: bool = False
|
||||||
webvtt_data = file.read()
|
decimal_marker: str = "."
|
||||||
|
|
||||||
data = []
|
def write_result(
|
||||||
blocks = webvtt_data.split('\n\n')
|
self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
|
||||||
|
):
|
||||||
|
print("WEBVTT\n", file=file)
|
||||||
|
for start, end, text in self.iterate_result(result, options, **kwargs):
|
||||||
|
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
|
||||||
|
|
||||||
for block in blocks:
|
def to_segments(self, file_path: str) -> List[Segment]:
|
||||||
if block.strip() != '' and not block.strip().startswith("WebVTT"):
|
segments = []
|
||||||
lines = block.strip().split('\n')
|
|
||||||
index = lines[0]
|
|
||||||
timestamp = lines[1]
|
|
||||||
sentence = ' '.join(lines[2:])
|
|
||||||
|
|
||||||
data.append({
|
blocks = read_file(file_path).split('\n\n')
|
||||||
"index": index,
|
|
||||||
"timestamp": timestamp,
|
|
||||||
"sentence": sentence
|
|
||||||
})
|
|
||||||
|
|
||||||
return data
|
for block in blocks:
|
||||||
|
if block.strip() != '' and not block.strip().startswith("WEBVTT"):
|
||||||
|
lines = block.strip().split('\n')
|
||||||
|
time_line = lines[0].split(" --> ")
|
||||||
|
start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
|
||||||
|
sentence = ' '.join(lines[1:])
|
||||||
|
|
||||||
|
segments.append(Segment(
|
||||||
|
start=start,
|
||||||
|
end=end,
|
||||||
|
text=sentence
|
||||||
|
))
|
||||||
|
|
||||||
|
return segments
|
||||||
|
|
||||||
|
|
||||||
-def get_serialized_srt(dicts):
-    output = ""
-    for dic in dicts:
-        output += f'{dic["index"]}\n'
-        output += f'{dic["timestamp"]}\n'
-        output += f'{dic["sentence"]}\n\n'
-    return output
+class WriteSRT(SubtitlesWriter):
+    extension: str = "srt"
+    always_include_hours: bool = True
+    decimal_marker: str = ","
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options, **kwargs), start=1
+        ):
+            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)
+
+    def to_segments(self, file_path: str) -> List[Segment]:
+        segments = []
+
+        blocks = read_file(file_path).split('\n\n')
+
+        for block in blocks:
+            if block.strip() != '':
+                lines = block.strip().split('\n')
+                index = lines[0]
+                time_line = lines[1].split(" --> ")
+                start, end = time_str_to_seconds(time_line[0], self.decimal_marker), time_str_to_seconds(time_line[1], self.decimal_marker)
+                sentence = ' '.join(lines[2:])
+
+                segments.append(Segment(
+                    start=start,
+                    end=end,
+                    text=sentence
+                ))
+
+        return segments
+
+
-def get_serialized_vtt(dicts):
-    output = "WebVTT\n\n"
-    for dic in dicts:
-        output += f'{dic["index"]}\n'
-        output += f'{dic["timestamp"]}\n'
-        output += f'{dic["sentence"]}\n\n'
-    return output
+class WriteLRC(SubtitlesWriter):
+    extension: str = "lrc"
+    always_include_hours: bool = False
+    decimal_marker: str = "."
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        for i, (start, end, text) in enumerate(
+            self.iterate_result(result, options, **kwargs), start=1
+        ):
+            if "align_lrc_words" in kwargs and kwargs["align_lrc_words"]:
+                print(f"{text}\n", file=file, flush=True)
+            else:
+                print(f"[{start}]{text}[{end}]\n", file=file, flush=True)
+
+    def to_segments(self, file_path: str) -> List[Segment]:
+        segments = []
+
+        blocks = read_file(file_path).split('\n')
+
+        for block in blocks:
+            if block.strip() != '':
+                lines = block.strip()
+                pattern = r'(\[.*?\])'
+                parts = re.split(pattern, lines)
+                parts = [part.strip() for part in parts if part]
+
+                for i, part in enumerate(parts):
+                    sentence_i = i%2
+                    if sentence_i == 1:
+                        start_str, text, end_str = parts[sentence_i-1], parts[sentence_i], parts[sentence_i+1]
+                        start_str, end_str = start_str.replace("[", "").replace("]", ""), end_str.replace("[", "").replace("]", "")
+                        start, end = time_str_to_seconds(start_str, self.decimal_marker), time_str_to_seconds(end_str, self.decimal_marker)
+
+                        segments.append(Segment(
+                            start=start,
+                            end=end,
+                            text=text,
+                        ))
+
+        return segments
+
+
+class WriteTSV(ResultWriter):
+    """
+    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
+    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>
+
+    Using integer milliseconds as start and end times means there's no chance of interference from
+    an environment setting a language encoding that causes the decimal in a floating point number
+    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
+    """
+
+    extension: str = "tsv"
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        print("start", "end", "text", sep="\t", file=file)
+        for segment in result["segments"]:
+            print(round(1000 * segment["start"]), file=file, end="\t")
+            print(round(1000 * segment["end"]), file=file, end="\t")
+            print(segment["text"].strip().replace("\t", " "), file=file, flush=True)
+
+
+class WriteJSON(ResultWriter):
+    extension: str = "json"
+
+    def write_result(
+        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+    ):
+        json.dump(result, file)
+
+
+def get_writer(
+    output_format: str, output_dir: str
+) -> Callable[[dict, TextIO, dict], None]:
+    output_format = output_format.strip().lower().replace(".", "")
+
+    writers = {
+        "txt": WriteTXT,
+        "vtt": WriteVTT,
+        "srt": WriteSRT,
+        "tsv": WriteTSV,
+        "json": WriteJSON,
+        "lrc": WriteLRC
+    }
+
+    if output_format == "all":
+        all_writers = [writer(output_dir) for writer in writers.values()]
+
+        def write_all(
+            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
+        ):
+            for writer in all_writers:
+                writer(result, file, options, **kwargs)
+
+        return write_all
+
+    return writers[output_format](output_dir)
+
+
+def generate_file(
+    output_format: str, output_dir: str, result: Union[dict, List[Segment]], output_file_name: str,
+    add_timestamp: bool = True, **kwargs
+) -> Tuple[str, str]:
+    output_format = output_format.strip().lower().replace(".", "")
+    output_format = "vtt" if output_format == "webvtt" else output_format
+
+    if add_timestamp:
+        timestamp = datetime.now().strftime("%m%d%H%M%S")
+        output_file_name += f"-{timestamp}"
+
+    file_path = os.path.join(output_dir, f"{output_file_name}.{output_format}")
+    file_writer = get_writer(output_format=output_format, output_dir=output_dir)
+
+    if isinstance(file_writer, WriteLRC) and kwargs.get("highlight_words", False):
+        kwargs["highlight_words"], kwargs["align_lrc_words"] = False, True
+
+    file_writer(result=result, output_file_name=output_file_name, **kwargs)
+    content = read_file(file_path)
+    return content, file_path
+
+
 def safe_filename(name):
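A minimal usage sketch of the new writer API above (the output directory, file name, and segment values are illustrative only; the exact on-disk behavior follows the writer implementations shown in the diff):

    from modules.utils.subtitle_manager import generate_file
    from modules.whisper.data_classes import Segment

    # Hypothetical segments; in the app they come from the transcription pipeline.
    segments = [Segment(start=0.0, end=2.5, text="Hello world")]

    # Picks the matching writer ("srt" -> WriteSRT), writes the file into
    # output_dir, and returns the written content together with its path.
    content, path = generate_file(
        output_format="srt",
        output_dir="outputs",
        result=segments,
        output_file_name="demo",
        add_timestamp=False,
    )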
@@ -5,7 +5,8 @@ import numpy as np
 from typing import BinaryIO, Union, List, Optional, Tuple
 import warnings
 import faster_whisper
-from faster_whisper.transcribe import SpeechTimestampsMap, Segment
+from modules.whisper.data_classes import *
+from faster_whisper.transcribe import SpeechTimestampsMap
 import gradio as gr


@@ -247,18 +248,18 @@ class SileroVAD:

     def restore_speech_timestamps(
             self,
-            segments: List[dict],
+            segments: List[Segment],
             speech_chunks: List[dict],
             sampling_rate: Optional[int] = None,
-    ) -> List[dict]:
+    ) -> List[Segment]:
         if sampling_rate is None:
             sampling_rate = self.sampling_rate

         ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)

         for segment in segments:
-            segment["start"] = ts_map.get_original_time(segment["start"])
-            segment["end"] = ts_map.get_original_time(segment["end"])
+            segment.start = ts_map.get_original_time(segment.start)
+            segment.end = ts_map.get_original_time(segment.end)

         return segments
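A short sketch of the idea behind restore_speech_timestamps (the chunk offsets and segment values below are hypothetical): after VAD keeps only the speech chunks, segment times refer to the trimmed audio, and faster-whisper's SpeechTimestampsMap projects them back onto the original timeline.

    from faster_whisper.transcribe import SpeechTimestampsMap
    from modules.whisper.data_classes import Segment

    speech_chunks = [{"start": 16000, "end": 48000}]    # sample offsets from Silero VAD (example values)
    segments = [Segment(start=0.5, end=1.5, text="hi")]  # times within the trimmed audio

    ts_map = SpeechTimestampsMap(speech_chunks, 16000)
    for seg in segments:
        # Shift each timestamp from trimmed-audio time to original-audio time.
        seg.start = ts_map.get_original_time(seg.start)
        seg.end = ts_map.get_original_time(seg.end)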
@@ -1,5 +1,4 @@
 import os
-import torch
 import whisper
 import ctranslate2
 import gradio as gr
@@ -9,21 +8,20 @@ from typing import BinaryIO, Union, Tuple, List
 import numpy as np
 from datetime import datetime
 from faster_whisper.vad import VadOptions
-from dataclasses import astuple

 from modules.uvr.music_separator import MusicSeparator
 from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
                                  UVR_MODELS_DIR)
-from modules.utils.constants import AUTOMATIC_DETECTION
+from modules.utils.constants import *
-from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
+from modules.utils.subtitle_manager import *
 from modules.utils.youtube_manager import get_ytdata, get_ytaudio
-from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
+from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml, read_file
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
 from modules.diarize.diarizer import Diarizer
 from modules.vad.silero_vad import SileroVAD


-class WhisperBase(ABC):
+class BaseTranscriptionPipeline(ABC):
     def __init__(self,
                  model_dir: str = WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -73,13 +71,15 @@ class WhisperBase(ABC):
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
+            file_format: str = "SRT",
             add_timestamp: bool = True,
-            *whisper_params,
+            *pipeline_params,
-            ) -> Tuple[List[dict], float]:
+            ) -> Tuple[List[Segment], float]:
         """
         Run transcription with conditional pre-processing and post-processing.
         The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
         The diarization will be performed in post-processing, if enabled.
+        Due to the integration with gradio, the parameters have to be specified with a `*` wildcard.

         Parameters
         ----------
@@ -87,40 +87,33 @@ class WhisperBase(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        file_format: str
+            Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
-        *whisper_params: tuple
+        *pipeline_params: tuple
-            Parameters related with whisper. This will be dealt with "WhisperParameters" data class
+            Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class.
+            This must be provided as a List with * wildcard because of the integration with gradio.
+            See more info at : https://github.com/gradio-app/gradio/issues/2471

         Returns
         ----------
-        segments_result: List[dict]
+        segments_result: List[Segment]
-            list of dicts that includes start, end timestamps and transcribed text
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for running
         """
-        params = WhisperParameters.as_value(*whisper_params)
+        params = TranscriptionPipelineParams.from_list(list(pipeline_params))
+        params = self.validate_gradio_values(params)
+        bgm_params, vad_params, whisper_params, diarization_params = params.bgm_separation, params.vad, params.whisper, params.diarization

-        self.cache_parameters(
-            whisper_params=params,
-            add_timestamp=add_timestamp
-        )
-
-        if params.lang is None:
-            pass
-        elif params.lang == AUTOMATIC_DETECTION:
-            params.lang = None
-        else:
-            language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
-            params.lang = language_code_dict[params.lang]
-
-        if params.is_bgm_separate:
+        if bgm_params.is_separate_bgm:
             music, audio, _ = self.music_separator.separate(
                 audio=audio,
-                model_name=params.uvr_model_size,
+                model_name=bgm_params.model_size,
-                device=params.uvr_device,
+                device=bgm_params.device,
-                segment_size=params.uvr_segment_size,
+                segment_size=bgm_params.segment_size,
-                save_file=params.uvr_save_file,
+                save_file=bgm_params.save_file,
                 progress=progress
             )
|
|||||||
origin_sample_rate = self.music_separator.audio_info.sample_rate
|
origin_sample_rate = self.music_separator.audio_info.sample_rate
|
||||||
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
|
audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
|
||||||
|
|
||||||
if params.uvr_enable_offload:
|
if bgm_params.enable_offload:
|
||||||
self.music_separator.offload()
|
self.music_separator.offload()
|
||||||
|
|
||||||
if params.vad_filter:
|
if vad_params.vad_filter:
|
||||||
# Explicit value set for float('inf') from gr.Number()
|
|
||||||
if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
|
|
||||||
params.max_speech_duration_s = float('inf')
|
|
||||||
|
|
||||||
vad_options = VadOptions(
|
vad_options = VadOptions(
|
||||||
threshold=params.threshold,
|
threshold=vad_params.threshold,
|
||||||
min_speech_duration_ms=params.min_speech_duration_ms,
|
min_speech_duration_ms=vad_params.min_speech_duration_ms,
|
||||||
max_speech_duration_s=params.max_speech_duration_s,
|
max_speech_duration_s=vad_params.max_speech_duration_s,
|
||||||
min_silence_duration_ms=params.min_silence_duration_ms,
|
min_silence_duration_ms=vad_params.min_silence_duration_ms,
|
||||||
speech_pad_ms=params.speech_pad_ms
|
speech_pad_ms=vad_params.speech_pad_ms
|
||||||
)
|
)
|
||||||
|
|
||||||
audio, speech_chunks = self.vad.run(
|
vad_processed, speech_chunks = self.vad.run(
|
||||||
audio=audio,
|
audio=audio,
|
||||||
vad_parameters=vad_options,
|
vad_parameters=vad_options,
|
||||||
progress=progress
|
progress=progress
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if vad_processed.size > 0:
|
||||||
|
audio = vad_processed
|
||||||
|
else:
|
||||||
|
vad_params.vad_filter = False
|
||||||
|
|
||||||
result, elapsed_time = self.transcribe(
|
result, elapsed_time = self.transcribe(
|
||||||
audio,
|
audio,
|
||||||
progress,
|
progress,
|
||||||
*astuple(params)
|
*whisper_params.to_list()
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.vad_filter:
|
if vad_params.vad_filter:
|
||||||
result = self.vad.restore_speech_timestamps(
|
result = self.vad.restore_speech_timestamps(
|
||||||
segments=result,
|
segments=result,
|
||||||
speech_chunks=speech_chunks,
|
speech_chunks=speech_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
if params.is_diarize:
|
if diarization_params.is_diarize:
|
||||||
result, elapsed_time_diarization = self.diarizer.run(
|
result, elapsed_time_diarization = self.diarizer.run(
|
||||||
audio=audio,
|
audio=audio,
|
||||||
use_auth_token=params.hf_token,
|
use_auth_token=diarization_params.hf_token,
|
||||||
transcribed_result=result,
|
transcribed_result=result,
|
||||||
|
device=diarization_params.device
|
||||||
)
|
)
|
||||||
elapsed_time += elapsed_time_diarization
|
elapsed_time += elapsed_time_diarization
|
||||||
|
|
||||||
|
self.cache_parameters(
|
||||||
|
params=params,
|
||||||
|
file_format=file_format,
|
||||||
|
add_timestamp=add_timestamp
|
||||||
|
)
|
||||||
return result, elapsed_time
|
return result, elapsed_time
|
||||||
|
|
||||||
def transcribe_file(self,
|
def transcribe_file(self,
|
||||||
@@ -181,8 +182,8 @@ class WhisperBase(ABC):
|
|||||||
file_format: str = "SRT",
|
file_format: str = "SRT",
|
||||||
add_timestamp: bool = True,
|
add_timestamp: bool = True,
|
||||||
progress=gr.Progress(),
|
progress=gr.Progress(),
|
||||||
*whisper_params,
|
*pipeline_params,
|
||||||
) -> list:
|
) -> Tuple[str, List]:
|
||||||
"""
|
"""
|
||||||
Write subtitle file from Files
|
Write subtitle file from Files
|
||||||
|
|
||||||
@@ -199,8 +200,8 @@ class WhisperBase(ABC):
|
|||||||
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
|
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
|
||||||
progress: gr.Progress
|
progress: gr.Progress
|
||||||
Indicator to show progress directly in gradio.
|
Indicator to show progress directly in gradio.
|
||||||
*whisper_params: tuple
|
*pipeline_params: tuple
|
||||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
Parameters for the transcription pipeline. This will be dealt with "TranscriptionPipelineParams" data class
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
----------
|
----------
|
||||||
@@ -210,6 +211,11 @@ class WhisperBase(ABC):
|
|||||||
Output file path to return to gr.Files()
|
Output file path to return to gr.Files()
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||||
|
writer_options = {
|
||||||
|
"highlight_words": True if params.whisper.word_timestamps else False
|
||||||
|
}
|
||||||
|
|
||||||
if input_folder_path:
|
if input_folder_path:
|
||||||
files = get_media_files(input_folder_path)
|
files = get_media_files(input_folder_path)
|
||||||
if isinstance(files, str):
|
if isinstance(files, str):
|
||||||
@@ -222,19 +228,21 @@ class WhisperBase(ABC):
|
|||||||
transcribed_segments, time_for_task = self.run(
|
transcribed_segments, time_for_task = self.run(
|
||||||
file,
|
file,
|
||||||
progress,
|
progress,
|
||||||
|
file_format,
|
||||||
add_timestamp,
|
add_timestamp,
|
||||||
*whisper_params,
|
*pipeline_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
file_name, file_ext = os.path.splitext(os.path.basename(file))
|
file_name, file_ext = os.path.splitext(os.path.basename(file))
|
||||||
subtitle, file_path = self.generate_and_write_file(
|
subtitle, file_path = generate_file(
|
||||||
file_name=file_name,
|
output_dir=self.output_dir,
|
||||||
transcribed_segments=transcribed_segments,
|
output_file_name=file_name,
|
||||||
|
output_format=file_format,
|
||||||
|
result=transcribed_segments,
|
||||||
add_timestamp=add_timestamp,
|
add_timestamp=add_timestamp,
|
||||||
file_format=file_format,
|
**writer_options
|
||||||
output_dir=self.output_dir
|
|
||||||
)
|
)
|
||||||
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
|
files_info[file_name] = {"subtitle": read_file(file_path), "time_for_task": time_for_task, "path": file_path}
|
||||||
|
|
||||||
total_result = ''
|
total_result = ''
|
||||||
total_time = 0
|
total_time = 0
|
||||||
@@ -247,10 +255,11 @@ class WhisperBase(ABC):
|
|||||||
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
|
result_str = f"Done in {self.format_time(total_time)}! Subtitle is in the outputs folder.\n\n{total_result}"
|
||||||
result_file_path = [info['path'] for info in files_info.values()]
|
result_file_path = [info['path'] for info in files_info.values()]
|
||||||
|
|
||||||
return [result_str, result_file_path]
|
return result_str, result_file_path
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error transcribing file: {e}")
|
print(f"Error transcribing file: {e}")
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
self.release_cuda_memory()
|
self.release_cuda_memory()
|
||||||
|
|
||||||
@@ -259,8 +268,8 @@ class WhisperBase(ABC):
|
|||||||
file_format: str = "SRT",
|
file_format: str = "SRT",
|
||||||
add_timestamp: bool = True,
|
add_timestamp: bool = True,
|
||||||
progress=gr.Progress(),
|
progress=gr.Progress(),
|
||||||
*whisper_params,
|
*pipeline_params,
|
||||||
) -> list:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Write subtitle file from microphone
|
Write subtitle file from microphone
|
||||||
|
|
||||||
@@ -274,7 +283,7 @@ class WhisperBase(ABC):
|
|||||||
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
||||||
progress: gr.Progress
|
progress: gr.Progress
|
||||||
Indicator to show progress directly in gradio.
|
Indicator to show progress directly in gradio.
|
||||||
*whisper_params: tuple
|
*pipeline_params: tuple
|
||||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -285,27 +294,36 @@ class WhisperBase(ABC):
|
|||||||
Output file path to return to gr.Files()
|
Output file path to return to gr.Files()
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||||
|
writer_options = {
|
||||||
|
"highlight_words": True if params.whisper.word_timestamps else False
|
||||||
|
}
|
||||||
|
|
||||||
progress(0, desc="Loading Audio..")
|
progress(0, desc="Loading Audio..")
|
||||||
transcribed_segments, time_for_task = self.run(
|
transcribed_segments, time_for_task = self.run(
|
||||||
mic_audio,
|
mic_audio,
|
||||||
progress,
|
progress,
|
||||||
|
file_format,
|
||||||
add_timestamp,
|
add_timestamp,
|
||||||
*whisper_params,
|
*pipeline_params,
|
||||||
)
|
)
|
||||||
progress(1, desc="Completed!")
|
progress(1, desc="Completed!")
|
||||||
|
|
||||||
subtitle, result_file_path = self.generate_and_write_file(
|
file_name = "Mic"
|
||||||
file_name="Mic",
|
subtitle, file_path = generate_file(
|
||||||
transcribed_segments=transcribed_segments,
|
output_dir=self.output_dir,
|
||||||
|
output_file_name=file_name,
|
||||||
|
output_format=file_format,
|
||||||
|
result=transcribed_segments,
|
||||||
add_timestamp=add_timestamp,
|
add_timestamp=add_timestamp,
|
||||||
file_format=file_format,
|
**writer_options
|
||||||
output_dir=self.output_dir
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
||||||
return [result_str, result_file_path]
|
return result_str, file_path
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error transcribing file: {e}")
|
print(f"Error transcribing mic: {e}")
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
self.release_cuda_memory()
|
self.release_cuda_memory()
|
||||||
|
|
||||||
@@ -314,8 +332,8 @@ class WhisperBase(ABC):
|
|||||||
file_format: str = "SRT",
|
file_format: str = "SRT",
|
||||||
add_timestamp: bool = True,
|
add_timestamp: bool = True,
|
||||||
progress=gr.Progress(),
|
progress=gr.Progress(),
|
||||||
*whisper_params,
|
*pipeline_params,
|
||||||
) -> list:
|
) -> Tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Write subtitle file from Youtube
|
Write subtitle file from Youtube
|
||||||
|
|
||||||
@@ -329,7 +347,7 @@ class WhisperBase(ABC):
|
|||||||
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
||||||
progress: gr.Progress
|
progress: gr.Progress
|
||||||
Indicator to show progress directly in gradio.
|
Indicator to show progress directly in gradio.
|
||||||
*whisper_params: tuple
|
*pipeline_params: tuple
|
||||||
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
Parameters related with whisper. This will be dealt with "WhisperParameters" data class
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
@@ -340,6 +358,11 @@ class WhisperBase(ABC):
|
|||||||
Output file path to return to gr.Files()
|
Output file path to return to gr.Files()
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
params = TranscriptionPipelineParams.from_list(list(pipeline_params))
|
||||||
|
writer_options = {
|
||||||
|
"highlight_words": True if params.whisper.word_timestamps else False
|
||||||
|
}
|
||||||
|
|
||||||
progress(0, desc="Loading Audio from Youtube..")
|
progress(0, desc="Loading Audio from Youtube..")
|
||||||
yt = get_ytdata(youtube_link)
|
yt = get_ytdata(youtube_link)
|
||||||
audio = get_ytaudio(yt)
|
audio = get_ytaudio(yt)
|
||||||
@@ -347,29 +370,33 @@ class WhisperBase(ABC):
|
|||||||
transcribed_segments, time_for_task = self.run(
|
transcribed_segments, time_for_task = self.run(
|
||||||
audio,
|
audio,
|
||||||
progress,
|
progress,
|
||||||
|
file_format,
|
||||||
add_timestamp,
|
add_timestamp,
|
||||||
*whisper_params,
|
*pipeline_params,
|
||||||
)
|
)
|
||||||
|
|
||||||
progress(1, desc="Completed!")
|
progress(1, desc="Completed!")
|
||||||
|
|
||||||
file_name = safe_filename(yt.title)
|
file_name = safe_filename(yt.title)
|
||||||
subtitle, result_file_path = self.generate_and_write_file(
|
subtitle, file_path = generate_file(
|
||||||
file_name=file_name,
|
output_dir=self.output_dir,
|
||||||
transcribed_segments=transcribed_segments,
|
output_file_name=file_name,
|
||||||
|
output_format=file_format,
|
||||||
|
result=transcribed_segments,
|
||||||
add_timestamp=add_timestamp,
|
add_timestamp=add_timestamp,
|
||||||
file_format=file_format,
|
**writer_options
|
||||||
output_dir=self.output_dir
|
|
||||||
)
|
)
|
||||||
|
|
||||||
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
|
||||||
|
|
||||||
if os.path.exists(audio):
|
if os.path.exists(audio):
|
||||||
os.remove(audio)
|
os.remove(audio)
|
||||||
|
|
||||||
return [result_str, result_file_path]
|
return result_str, file_path
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"Error transcribing file: {e}")
|
print(f"Error transcribing youtube: {e}")
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
self.release_cuda_memory()
|
self.release_cuda_memory()
|
||||||
|
|
||||||
@@ -387,58 +414,6 @@ class WhisperBase(ABC):
|
|||||||
else:
|
else:
|
||||||
return list(ctranslate2.get_supported_compute_types("cpu"))
|
return list(ctranslate2.get_supported_compute_types("cpu"))
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def generate_and_write_file(file_name: str,
|
|
||||||
transcribed_segments: list,
|
|
||||||
add_timestamp: bool,
|
|
||||||
file_format: str,
|
|
||||||
output_dir: str
|
|
||||||
) -> str:
|
|
||||||
"""
|
|
||||||
Writes subtitle file
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
file_name: str
|
|
||||||
Output file name
|
|
||||||
transcribed_segments: list
|
|
||||||
Text segments transcribed from audio
|
|
||||||
add_timestamp: bool
|
|
||||||
Determines whether to add a timestamp to the end of the filename.
|
|
||||||
file_format: str
|
|
||||||
File format to write. Supported formats: [SRT, WebVTT, txt]
|
|
||||||
output_dir: str
|
|
||||||
Directory path of the output
|
|
||||||
|
|
||||||
Returns
|
|
||||||
----------
|
|
||||||
content: str
|
|
||||||
Result of the transcription
|
|
||||||
output_path: str
|
|
||||||
output file path
|
|
||||||
"""
|
|
||||||
if add_timestamp:
|
|
||||||
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
|
||||||
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
|
|
||||||
else:
|
|
||||||
output_path = os.path.join(output_dir, f"{file_name}")
|
|
||||||
|
|
||||||
file_format = file_format.strip().lower()
|
|
||||||
if file_format == "srt":
|
|
||||||
content = get_srt(transcribed_segments)
|
|
||||||
output_path += '.srt'
|
|
||||||
|
|
||||||
elif file_format == "webvtt":
|
|
||||||
content = get_vtt(transcribed_segments)
|
|
||||||
output_path += '.vtt'
|
|
||||||
|
|
||||||
elif file_format == "txt":
|
|
||||||
content = get_txt(transcribed_segments)
|
|
||||||
output_path += '.txt'
|
|
||||||
|
|
||||||
write_file(content, output_path)
|
|
||||||
return content, output_path
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def format_time(elapsed_time: float) -> str:
|
def format_time(elapsed_time: float) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -471,7 +446,7 @@ class WhisperBase(ABC):
|
|||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
return "cuda"
|
return "cuda"
|
||||||
elif torch.backends.mps.is_available():
|
elif torch.backends.mps.is_available():
|
||||||
if not WhisperBase.is_sparse_api_supported():
|
if not BaseTranscriptionPipeline.is_sparse_api_supported():
|
||||||
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
# Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
|
||||||
return "cpu"
|
return "cpu"
|
||||||
return "mps"
|
return "mps"
|
||||||
@@ -513,17 +488,64 @@ class WhisperBase(ABC):
|
|||||||
os.remove(file_path)
|
os.remove(file_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def cache_parameters(
|
def validate_gradio_values(params: TranscriptionPipelineParams):
|
||||||
whisper_params: WhisperValues,
|
"""
|
||||||
add_timestamp: bool
|
Validate gradio specific values that can't be displayed as None in the UI.
|
||||||
):
|
Related issue : https://github.com/gradio-app/gradio/issues/8723
|
||||||
"""cache parameters to the yaml file"""
|
"""
|
||||||
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
if params.whisper.lang is None:
|
||||||
cached_whisper_param = whisper_params.to_yaml()
|
pass
|
||||||
cached_yaml = {**cached_params, **cached_whisper_param}
|
elif params.whisper.lang == AUTOMATIC_DETECTION:
|
||||||
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
params.whisper.lang = None
|
||||||
|
else:
|
||||||
|
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
||||||
|
params.whisper.lang = language_code_dict[params.whisper.lang]
|
||||||
|
|
||||||
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
if params.whisper.initial_prompt == GRADIO_NONE_STR:
|
||||||
|
params.whisper.initial_prompt = None
|
||||||
|
if params.whisper.prefix == GRADIO_NONE_STR:
|
||||||
|
params.whisper.prefix = None
|
||||||
|
if params.whisper.hotwords == GRADIO_NONE_STR:
|
||||||
|
params.whisper.hotwords = None
|
||||||
|
if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
|
||||||
|
params.whisper.max_new_tokens = None
|
||||||
|
if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
|
||||||
|
params.whisper.hallucination_silence_threshold = None
|
||||||
|
if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
|
||||||
|
params.whisper.language_detection_threshold = None
|
||||||
|
if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
|
||||||
|
params.vad.max_speech_duration_s = float('inf')
|
||||||
|
return params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def cache_parameters(
|
||||||
|
params: TranscriptionPipelineParams,
|
||||||
|
file_format: str = "SRT",
|
||||||
|
add_timestamp: bool = True
|
||||||
|
):
|
||||||
|
"""Cache parameters to the yaml file"""
|
||||||
|
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||||
|
param_to_cache = params.to_dict()
|
||||||
|
|
||||||
|
cached_yaml = {**cached_params, **param_to_cache}
|
||||||
|
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
||||||
|
cached_yaml["whisper"]["file_format"] = file_format
|
||||||
|
|
||||||
|
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
||||||
|
if supress_token and isinstance(supress_token, list):
|
||||||
|
cached_yaml["whisper"]["suppress_tokens"] = str(supress_token)
|
||||||
|
|
||||||
|
if cached_yaml["whisper"].get("lang", None) is None:
|
||||||
|
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
|
||||||
|
else:
|
||||||
|
language_dict = whisper.tokenizer.LANGUAGES
|
||||||
|
cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
|
||||||
|
|
||||||
|
if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
|
||||||
|
cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
|
||||||
|
|
||||||
|
if cached_yaml is not None and cached_yaml:
|
||||||
|
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def resample_audio(audio: Union[str, np.ndarray],
|
def resample_audio(audio: Union[str, np.ndarray],
|
||||||
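A sketch of the sentinel round trip that validate_gradio_values() handles (the constant values below are assumptions for illustration only; the real values live in modules/utils/constants): gradio Textbox/Number components cannot display None, so placeholder values are shown in the UI and converted back to None before transcription.

    GRADIO_NONE_STR = ""          # assumed placeholder for an "empty" gr.Textbox
    GRADIO_NONE_NUMBER_MIN = 0    # assumed placeholder for an "unset" gr.Number

    ui_prompt = GRADIO_NONE_STR                     # value coming back from the UI
    initial_prompt = None if ui_prompt == GRADIO_NONE_STR else ui_prompt

    ui_tokens = GRADIO_NONE_NUMBER_MIN
    max_new_tokens = None if ui_tokens == GRADIO_NONE_NUMBER_MIN else ui_tokens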
 608  modules/whisper/data_classes.py  Normal file
@@ -0,0 +1,608 @@
import faster_whisper.transcribe
import gradio as gr
import torch
from typing import Optional, Dict, List, Union, NamedTuple
from pydantic import BaseModel, Field, field_validator, ConfigDict
from gradio_i18n import Translate, gettext as _
from enum import Enum
from copy import deepcopy

import yaml

from modules.utils.constants import *


class WhisperImpl(Enum):
    WHISPER = "whisper"
    FASTER_WHISPER = "faster-whisper"
    INSANELY_FAST_WHISPER = "insanely_fast_whisper"

class Segment(BaseModel):
    id: Optional[int] = Field(default=None, description="Incremental id for the segment")
    seek: Optional[int] = Field(default=None, description="Seek of the segment from chunked audio")
    text: Optional[str] = Field(default=None, description="Transcription text of the segment")
    start: Optional[float] = Field(default=None, description="Start time of the segment")
    end: Optional[float] = Field(default=None, description="End time of the segment")
    tokens: Optional[List[int]] = Field(default=None, description="List of token IDs")
    temperature: Optional[float] = Field(default=None, description="Temperature used during the decoding process")
    avg_logprob: Optional[float] = Field(default=None, description="Average log probability of the tokens")
    compression_ratio: Optional[float] = Field(default=None, description="Compression ratio of the segment")
    no_speech_prob: Optional[float] = Field(default=None, description="Probability that it's not speech")
    words: Optional[List['Word']] = Field(default=None, description="List of words contained in the segment")

    @classmethod
    def from_faster_whisper(cls,
                            seg: faster_whisper.transcribe.Segment):
        if seg.words is not None:
            words = [
                Word(
                    start=w.start,
                    end=w.end,
                    word=w.word,
                    probability=w.probability
                ) for w in seg.words
            ]
        else:
            words = None

        return cls(
            id=seg.id,
            seek=seg.seek,
            text=seg.text,
            start=seg.start,
            end=seg.end,
            tokens=seg.tokens,
            temperature=seg.temperature,
            avg_logprob=seg.avg_logprob,
            compression_ratio=seg.compression_ratio,
            no_speech_prob=seg.no_speech_prob,
            words=words
        )


class Word(BaseModel):
    start: Optional[float] = Field(default=None, description="Start time of the word")
    end: Optional[float] = Field(default=None, description="Start time of the word")
    word: Optional[str] = Field(default=None, description="Word text")
    probability: Optional[float] = Field(default=None, description="Probability of the word")

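A usage sketch of the conversion above (the model size and audio path are placeholders; faster-whisper's transcribe() returns a lazy generator of its own Segment objects, which are adapted into the pydantic Segment):

    from faster_whisper import WhisperModel
    from modules.whisper.data_classes import Segment

    model = WhisperModel("tiny")                      # small model chosen only for the example
    fw_segments, _info = model.transcribe("audio.wav", word_timestamps=True)

    # Materialize the generator and convert each item into the pipeline's Segment type.
    segments = [Segment.from_faster_whisper(s) for s in fw_segments]
    if segments:
        print(segments[0].start, segments[0].text)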
class BaseParams(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    def to_dict(self) -> Dict:
        return self.model_dump()

    def to_list(self) -> List:
        return list(self.model_dump().values())

    @classmethod
    def from_list(cls, data_list: List) -> 'BaseParams':
        field_names = list(cls.model_fields.keys())
        return cls(**dict(zip(field_names, data_list)))

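A short sketch of the list round trip that BaseParams provides (using the VadParams subclass defined just below): the declaration order of the fields is what keeps the positional list stable.

    vad = VadParams(vad_filter=True, threshold=0.6)
    values = vad.to_list()            # [True, 0.6, 250, inf, 2000, 400] in field order
    same = VadParams.from_list(values)
    assert same.threshold == 0.6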
class VadParams(BaseParams):
    """Voice Activity Detection parameters"""
    vad_filter: bool = Field(default=False, description="Enable voice activity detection to filter out non-speech parts")
    threshold: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Speech threshold for Silero VAD. Probabilities above this value are considered speech"
    )
    min_speech_duration_ms: int = Field(
        default=250,
        ge=0,
        description="Final speech chunks shorter than this are discarded"
    )
    max_speech_duration_s: float = Field(
        default=float("inf"),
        gt=0,
        description="Maximum duration of speech chunks in seconds"
    )
    min_silence_duration_ms: int = Field(
        default=2000,
        ge=0,
        description="Minimum silence duration between speech chunks"
    )
    speech_pad_ms: int = Field(
        default=400,
        ge=0,
        description="Padding added to each side of speech chunks"
    )

    @classmethod
    def to_gradio_inputs(cls, defaults: Optional[Dict] = None) -> List[gr.components.base.FormComponent]:
        return [
            gr.Checkbox(
                label=_("Enable Silero VAD Filter"),
                value=defaults.get("vad_filter", cls.__fields__["vad_filter"].default),
                interactive=True,
                info=_("Enable this to transcribe only detected voice")
            ),
            gr.Slider(
                minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold",
                value=defaults.get("threshold", cls.__fields__["threshold"].default),
                info="Lower it to be more sensitive to small sounds."
            ),
            gr.Number(
                label="Minimum Speech Duration (ms)", precision=0,
                value=defaults.get("min_speech_duration_ms", cls.__fields__["min_speech_duration_ms"].default),
                info="Final speech chunks shorter than this time are thrown out"
            ),
            gr.Number(
                label="Maximum Speech Duration (s)",
                value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
                info="Maximum duration of speech chunks in \"seconds\"."
            ),
            gr.Number(
                label="Minimum Silence Duration (ms)", precision=0,
                value=defaults.get("min_silence_duration_ms", cls.__fields__["min_silence_duration_ms"].default),
                info="In the end of each speech chunk wait for this time before separating it"
            ),
            gr.Number(
                label="Speech Padding (ms)", precision=0,
                value=defaults.get("speech_pad_ms", cls.__fields__["speech_pad_ms"].default),
                info="Final speech chunks are padded by this time each side"
            )
        ]

class DiarizationParams(BaseParams):
    """Speaker diarization parameters"""
    is_diarize: bool = Field(default=False, description="Enable speaker diarization")
    device: str = Field(default="cuda", description="Device to run Diarization model.")
    hf_token: str = Field(
        default="",
        description="Hugging Face token for downloading diarization models"
    )

    @classmethod
    def to_gradio_inputs(cls,
                         defaults: Optional[Dict] = None,
                         available_devices: Optional[List] = None,
                         device: Optional[str] = None) -> List[gr.components.base.FormComponent]:
        return [
            gr.Checkbox(
                label=_("Enable Diarization"),
                value=defaults.get("is_diarize", cls.__fields__["is_diarize"].default),
            ),
            gr.Dropdown(
                label=_("Device"),
                choices=["cpu", "cuda"] if available_devices is None else available_devices,
                value=defaults.get("device", device),
            ),
            gr.Textbox(
                label=_("HuggingFace Token"),
                value=defaults.get("hf_token", cls.__fields__["hf_token"].default),
                info=_("This is only needed the first time you download the model")
            ),
        ]

class BGMSeparationParams(BaseParams):
    """Background music separation parameters"""
    is_separate_bgm: bool = Field(default=False, description="Enable background music separation")
    model_size: str = Field(
        default="UVR-MDX-NET-Inst_HQ_4",
        description="UVR model size"
    )
    device: str = Field(default="cuda", description="Device to run UVR model.")
    segment_size: int = Field(
        default=256,
        gt=0,
        description="Segment size for UVR model"
    )
    save_file: bool = Field(
        default=False,
        description="Whether to save separated audio files"
    )
    enable_offload: bool = Field(
        default=True,
        description="Offload UVR model after transcription"
    )

    @classmethod
    def to_gradio_input(cls,
                        defaults: Optional[Dict] = None,
                        available_devices: Optional[List] = None,
                        device: Optional[str] = None,
                        available_models: Optional[List] = None) -> List[gr.components.base.FormComponent]:
        return [
            gr.Checkbox(
                label=_("Enable Background Music Remover Filter"),
                value=defaults.get("is_separate_bgm", cls.__fields__["is_separate_bgm"].default),
                interactive=True,
                info=_("Enabling this will remove background music")
            ),
            gr.Dropdown(
                label=_("Model"),
                choices=["UVR-MDX-NET-Inst_HQ_4",
                         "UVR-MDX-NET-Inst_3"] if available_models is None else available_models,
                value=defaults.get("model_size", cls.__fields__["model_size"].default),
            ),
            gr.Dropdown(
                label=_("Device"),
                choices=["cpu", "cuda"] if available_devices is None else available_devices,
                value=defaults.get("device", device),
            ),
            gr.Number(
                label="Segment Size",
                value=defaults.get("segment_size", cls.__fields__["segment_size"].default),
                precision=0,
                info="Segment size for UVR model"
            ),
            gr.Checkbox(
                label=_("Save separated files to output"),
                value=defaults.get("save_file", cls.__fields__["save_file"].default),
            ),
            gr.Checkbox(
                label=_("Offload sub model after removing background music"),
                value=defaults.get("enable_offload", cls.__fields__["enable_offload"].default),
            )
        ]

class WhisperParams(BaseParams):
    """Whisper parameters"""
    model_size: str = Field(default="large-v2", description="Whisper model size")
    lang: Optional[str] = Field(default=None, description="Source language of the file to transcribe")
    is_translate: bool = Field(default=False, description="Translate speech to English end-to-end")
    beam_size: int = Field(default=5, ge=1, description="Beam size for decoding")
    log_prob_threshold: float = Field(
        default=-1.0,
        description="Threshold for average log probability of sampled tokens"
    )
    no_speech_threshold: float = Field(
        default=0.6,
        ge=0.0,
        le=1.0,
        description="Threshold for detecting silence"
    )
    compute_type: str = Field(default="float16", description="Computation type for transcription")
    best_of: int = Field(default=5, ge=1, description="Number of candidates when sampling")
    patience: float = Field(default=1.0, gt=0, description="Beam search patience factor")
    condition_on_previous_text: bool = Field(
        default=True,
        description="Use previous output as prompt for next window"
    )
    prompt_reset_on_temperature: float = Field(
        default=0.5,
        ge=0.0,
        le=1.0,
        description="Temperature threshold for resetting prompt"
    )
    initial_prompt: Optional[str] = Field(default=None, description="Initial prompt for first window")
    temperature: float = Field(
        default=0.0,
        ge=0.0,
        description="Temperature for sampling"
    )
    compression_ratio_threshold: float = Field(
        default=2.4,
        gt=0,
        description="Threshold for gzip compression ratio"
    )
    length_penalty: float = Field(default=1.0, gt=0, description="Exponential length penalty")
    repetition_penalty: float = Field(default=1.0, gt=0, description="Penalty for repeated tokens")
    no_repeat_ngram_size: int = Field(default=0, ge=0, description="Size of n-grams to prevent repetition")
    prefix: Optional[str] = Field(default=None, description="Prefix text for first window")
    suppress_blank: bool = Field(
        default=True,
        description="Suppress blank outputs at start of sampling"
    )
    suppress_tokens: Optional[Union[List[int], str]] = Field(default=[-1], description="Token IDs to suppress")
    max_initial_timestamp: float = Field(
        default=1.0,
        ge=0.0,
        description="Maximum initial timestamp"
    )
    word_timestamps: bool = Field(default=False, description="Extract word-level timestamps")
    prepend_punctuations: Optional[str] = Field(
        default="\"'“¿([{-",
        description="Punctuations to merge with next word"
    )
    append_punctuations: Optional[str] = Field(
        default="\"'.。,,!!??::”)]}、",
        description="Punctuations to merge with previous word"
    )
    max_new_tokens: Optional[int] = Field(default=None, description="Maximum number of new tokens per chunk")
    chunk_length: Optional[int] = Field(default=30, description="Length of audio segments in seconds")
    hallucination_silence_threshold: Optional[float] = Field(
        default=None,
        description="Threshold for skipping silent periods in hallucination detection"
    )
    hotwords: Optional[str] = Field(default=None, description="Hotwords/hint phrases for the model")
    language_detection_threshold: Optional[float] = Field(
        default=None,
        description="Threshold for language detection probability"
    )
    language_detection_segments: int = Field(
        default=1,
        gt=0,
        description="Number of segments for language detection"
    )
    batch_size: int = Field(default=24, gt=0, description="Batch size for processing")

    @field_validator('lang')
    def validate_lang(cls, v):
        from modules.utils.constants import AUTOMATIC_DETECTION
        return None if v == AUTOMATIC_DETECTION.unwrap() else v

    @field_validator('suppress_tokens')
    def validate_supress_tokens(cls, v):
        import ast
        try:
            if isinstance(v, str):
                suppress_tokens = ast.literal_eval(v)
                if not isinstance(suppress_tokens, list):
                    raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
                return suppress_tokens
            if isinstance(v, list):
                return v
        except Exception as e:
            raise ValueError(f"Invalid Suppress Tokens. The value must be type of List[int]: {e}")

    @classmethod
    def to_gradio_inputs(cls,
                         defaults: Optional[Dict] = None,
                         only_advanced: Optional[bool] = True,
                         whisper_type: Optional[str] = None,
                         available_models: Optional[List] = None,
                         available_langs: Optional[List] = None,
                         available_compute_types: Optional[List] = None,
                         compute_type: Optional[str] = None):
        whisper_type = WhisperImpl.FASTER_WHISPER.value if whisper_type is None else whisper_type.strip().lower()

        inputs = []
        if not only_advanced:
            inputs += [
                gr.Dropdown(
                    label=_("Model"),
                    choices=available_models,
                    value=defaults.get("model_size", cls.__fields__["model_size"].default),
                ),
                gr.Dropdown(
                    label=_("Language"),
                    choices=available_langs,
                    value=defaults.get("lang", AUTOMATIC_DETECTION),
                ),
                gr.Checkbox(
                    label=_("Translate to English?"),
                    value=defaults.get("is_translate", cls.__fields__["is_translate"].default),
                ),
            ]

        inputs += [
            gr.Number(
                label="Beam Size",
                value=defaults.get("beam_size", cls.__fields__["beam_size"].default),
                precision=0,
                info="Beam size for decoding"
            ),
            gr.Number(
                label="Log Probability Threshold",
                value=defaults.get("log_prob_threshold", cls.__fields__["log_prob_threshold"].default),
                info="Threshold for average log probability of sampled tokens"
            ),
            gr.Number(
                label="No Speech Threshold",
                value=defaults.get("no_speech_threshold", cls.__fields__["no_speech_threshold"].default),
                info="Threshold for detecting silence"
            ),
            gr.Dropdown(
                label="Compute Type",
                choices=["float16", "int8", "int16"] if available_compute_types is None else available_compute_types,
                value=defaults.get("compute_type", compute_type),
                info="Computation type for transcription"
            ),
            gr.Number(
                label="Best Of",
                value=defaults.get("best_of", cls.__fields__["best_of"].default),
                precision=0,
                info="Number of candidates when sampling"
            ),
            gr.Number(
                label="Patience",
                value=defaults.get("patience", cls.__fields__["patience"].default),
                info="Beam search patience factor"
            ),
            gr.Checkbox(
                label="Condition On Previous Text",
                value=defaults.get("condition_on_previous_text", cls.__fields__["condition_on_previous_text"].default),
                info="Use previous output as prompt for next window"
            ),
            gr.Slider(
                label="Prompt Reset On Temperature",
                value=defaults.get("prompt_reset_on_temperature",
                                   cls.__fields__["prompt_reset_on_temperature"].default),
                minimum=0,
                maximum=1,
                step=0.01,
                info="Temperature threshold for resetting prompt"
            ),
            gr.Textbox(
                label="Initial Prompt",
                value=defaults.get("initial_prompt", GRADIO_NONE_STR),
                info="Initial prompt for first window"
            ),
            gr.Slider(
                label="Temperature",
                value=defaults.get("temperature", cls.__fields__["temperature"].default),
                minimum=0.0,
                step=0.01,
                maximum=1.0,
                info="Temperature for sampling"
            ),
            gr.Number(
                label="Compression Ratio Threshold",
                value=defaults.get("compression_ratio_threshold",
                                   cls.__fields__["compression_ratio_threshold"].default),
                info="Threshold for gzip compression ratio"
            )
        ]

        faster_whisper_inputs = [
            gr.Number(
                label="Length Penalty",
                value=defaults.get("length_penalty", cls.__fields__["length_penalty"].default),
                info="Exponential length penalty",
            ),
            gr.Number(
                label="Repetition Penalty",
                value=defaults.get("repetition_penalty", cls.__fields__["repetition_penalty"].default),
                info="Penalty for repeated tokens"
            ),
            gr.Number(
                label="No Repeat N-gram Size",
                value=defaults.get("no_repeat_ngram_size", cls.__fields__["no_repeat_ngram_size"].default),
                precision=0,
                info="Size of n-grams to prevent repetition"
            ),
            gr.Textbox(
                label="Prefix",
                value=defaults.get("prefix", GRADIO_NONE_STR),
                info="Prefix text for first window"
            ),
            gr.Checkbox(
                label="Suppress Blank",
                value=defaults.get("suppress_blank", cls.__fields__["suppress_blank"].default),
                info="Suppress blank outputs at start of sampling"
            ),
            gr.Textbox(
                label="Suppress Tokens",
                value=defaults.get("suppress_tokens", "[-1]"),
                info="Token IDs to suppress"
            ),
            gr.Number(
                label="Max Initial Timestamp",
                value=defaults.get("max_initial_timestamp", cls.__fields__["max_initial_timestamp"].default),
                info="Maximum initial timestamp"
            ),
            gr.Checkbox(
                label="Word Timestamps",
                value=defaults.get("word_timestamps", cls.__fields__["word_timestamps"].default),
                info="Extract word-level timestamps"
            ),
            gr.Textbox(
                label="Prepend Punctuations",
                value=defaults.get("prepend_punctuations", cls.__fields__["prepend_punctuations"].default),
                info="Punctuations to merge with next word"
            ),
            gr.Textbox(
                label="Append Punctuations",
                value=defaults.get("append_punctuations", cls.__fields__["append_punctuations"].default),
                info="Punctuations to merge with previous word"
            ),
            gr.Number(
                label="Max New Tokens",
                value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
                precision=0,
                info="Maximum number of new tokens per chunk"
            ),
            gr.Number(
                label="Chunk Length (s)",
                value=defaults.get("chunk_length", cls.__fields__["chunk_length"].default),
                precision=0,
                info="Length of audio segments in seconds"
            ),
            gr.Number(
                label="Hallucination Silence Threshold (sec)",
                value=defaults.get("hallucination_silence_threshold",
                                   GRADIO_NONE_NUMBER_MIN),
                info="Threshold for skipping silent periods in hallucination detection"
            ),
            gr.Textbox(
                label="Hotwords",
                value=defaults.get("hotwords", cls.__fields__["hotwords"].default),
                info="Hotwords/hint phrases for the model"
            ),
            gr.Number(
                label="Language Detection Threshold",
                value=defaults.get("language_detection_threshold",
                                   GRADIO_NONE_NUMBER_MIN),
                info="Threshold for language detection probability"
            ),
            gr.Number(
                label="Language Detection Segments",
                value=defaults.get("language_detection_segments",
                                   cls.__fields__["language_detection_segments"].default),
                precision=0,
                info="Number of segments for language detection"
            )
        ]

        insanely_fast_whisper_inputs = [
            gr.Number(
                label="Batch Size",
                value=defaults.get("batch_size", cls.__fields__["batch_size"].default),
                precision=0,
                info="Batch size for processing"
            )
        ]

        if whisper_type != WhisperImpl.FASTER_WHISPER.value:
            for input_component in faster_whisper_inputs:
                input_component.visible = False

        if whisper_type != WhisperImpl.INSANELY_FAST_WHISPER.value:
            for input_component in insanely_fast_whisper_inputs:
                input_component.visible = False

        inputs += faster_whisper_inputs + insanely_fast_whisper_inputs

        return inputs

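A small sketch of what the suppress_tokens validator above accepts: either a real list, or the string representation coming back from the gr.Textbox ("[-1]"), which is parsed and normalized.

    p1 = WhisperParams(suppress_tokens="[-1]")
    p2 = WhisperParams(suppress_tokens=[-1, 0])
    assert p1.suppress_tokens == [-1]
    assert p2.suppress_tokens == [-1, 0]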
class TranscriptionPipelineParams(BaseModel):
|
||||||
|
"""Transcription pipeline parameters"""
|
||||||
|
whisper: WhisperParams = Field(default_factory=WhisperParams)
|
||||||
|
vad: VadParams = Field(default_factory=VadParams)
|
||||||
|
diarization: DiarizationParams = Field(default_factory=DiarizationParams)
|
||||||
|
bgm_separation: BGMSeparationParams = Field(default_factory=BGMSeparationParams)
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict:
|
||||||
|
data = {
|
||||||
|
"whisper": self.whisper.to_dict(),
|
||||||
|
"vad": self.vad.to_dict(),
|
||||||
|
"diarization": self.diarization.to_dict(),
|
||||||
|
"bgm_separation": self.bgm_separation.to_dict()
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
|
||||||
|
def to_list(self) -> List:
|
||||||
|
"""
|
||||||
|
Convert data class to the list because I have to pass the parameters as a list in the gradio.
|
||||||
|
Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
|
||||||
|
See more about Gradio pre-processing: https://www.gradio.app/docs/components
|
||||||
|
"""
|
||||||
|
whisper_list = self.whisper.to_list()
|
||||||
|
vad_list = self.vad.to_list()
|
||||||
|
diarization_list = self.diarization.to_list()
|
||||||
|
bgm_sep_list = self.bgm_separation.to_list()
|
||||||
|
return whisper_list + vad_list + diarization_list + bgm_sep_list
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def from_list(pipeline_list: List) -> 'TranscriptionPipelineParams':
|
||||||
|
"""Convert list to the data class again to use it in a function."""
|
||||||
|
data_list = deepcopy(pipeline_list)
|
||||||
|
|
||||||
|
whisper_list = data_list[0:len(WhisperParams.__annotations__)]
|
||||||
|
data_list = data_list[len(WhisperParams.__annotations__):]
|
||||||
|
|
||||||
|
vad_list = data_list[0:len(VadParams.__annotations__)]
|
||||||
|
data_list = data_list[len(VadParams.__annotations__):]
|
||||||
|
|
||||||
|
diarization_list = data_list[0:len(DiarizationParams.__annotations__)]
|
||||||
|
data_list = data_list[len(DiarizationParams.__annotations__):]
|
||||||
|
|
||||||
|
bgm_sep_list = data_list[0:len(BGMSeparationParams.__annotations__)]
|
||||||
|
|
||||||
|
return TranscriptionPipelineParams(
|
||||||
|
whisper=WhisperParams.from_list(whisper_list),
|
||||||
|
vad=VadParams.from_list(vad_list),
|
||||||
|
diarization=DiarizationParams.from_list(diarization_list),
|
||||||
|
bgm_separation=BGMSeparationParams.from_list(bgm_sep_list)
|
||||||
|
)
|
||||||
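The flattening scheme above exists because Gradio passes component values positionally (see the linked issue). A minimal, self-contained sketch of the same to_list / split-by-__annotations__ round trip, using toy stand-ins rather than the project's real parameter classes:

# Toy stand-ins (not the project's real classes) showing the same
# flatten / split-by-__annotations__ round trip used by
# TranscriptionPipelineParams.to_list() / from_list().
from copy import deepcopy
from typing import List
from pydantic import BaseModel

class ToyWhisperParams(BaseModel):
    model_size: str = "large-v2"
    beam_size: int = 5

    def to_list(self) -> List:
        return [getattr(self, name) for name in self.__annotations__.keys()]

class ToyVadParams(BaseModel):
    vad_filter: bool = False
    threshold: float = 0.5

    def to_list(self) -> List:
        return [getattr(self, name) for name in self.__annotations__.keys()]

# Flatten both sections into one positional list (what Gradio passes around)...
flat = ToyWhisperParams().to_list() + ToyVadParams(vad_filter=True).to_list()

# ...and split it back by each section's field count, as from_list() does.
data = deepcopy(flat)
whisper_part = data[:len(ToyWhisperParams.__annotations__)]
vad_part = data[len(ToyWhisperParams.__annotations__):]
print(whisper_part)  # ['large-v2', 5]
print(vad_part)      # [True, 0.5]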
@@ -12,11 +12,11 @@ import gradio as gr
 from argparse import Namespace

 from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline


-class FasterWhisperInference(WhisperBase):
+class FasterWhisperInference(BaseTranscriptionPipeline):
     def __init__(self,
                  model_dir: str = FASTER_WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -40,7 +40,7 @@ class FasterWhisperInference(WhisperBase):
                    audio: Union[str, BinaryIO, np.ndarray],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.

@@ -55,28 +55,18 @@ class FasterWhisperInference(WhisperBase):

         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
         start_time = time.time()

-        params = WhisperParameters.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))

         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)

-        # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
-        if not params.initial_prompt:
-            params.initial_prompt = None
-        if not params.prefix:
-            params.prefix = None
-        if not params.hotwords:
-            params.hotwords = None
-
-        params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
-
         segments, info = self.model.transcribe(
             audio=audio,
             language=params.lang,
@@ -112,11 +102,7 @@ class FasterWhisperInference(WhisperBase):
         segments_result = []
         for segment in segments:
             progress(segment.start / info.duration, desc="Transcribing..")
-            segments_result.append({
-                "start": segment.start,
-                "end": segment.end,
-                "text": segment.text
-            })
+            segments_result.append(Segment.from_faster_whisper(segment))

         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
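For orientation, a rough stand-in for the Segment model that the new append relies on; the field names mirror the Segment(text=..., start=..., end=...) calls elsewhere in this diff, and from_faster_whisper is sketched here as a plain attribute copy, which may be simpler than the real implementation in modules.whisper.data_classes:

# Hedged sketch: a minimal Segment-like model and the kind of conversion
# Segment.from_faster_whisper() performs. The real class may carry more
# fields (e.g. word-level timestamps).
from collections import namedtuple
from typing import Optional
from pydantic import BaseModel

class SegmentSketch(BaseModel):
    text: Optional[str] = None
    start: Optional[float] = None
    end: Optional[float] = None

    @classmethod
    def from_faster_whisper(cls, seg) -> "SegmentSketch":
        # faster-whisper yields segment objects with .start, .end and .text attributes
        return cls(text=seg.text, start=seg.start, end=seg.end)

FakeFWSegment = namedtuple("FakeFWSegment", ["start", "end", "text"])
print(SegmentSketch.from_faster_whisper(FakeFWSegment(0.0, 2.5, " And so my fellow Americans")))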
@@ -12,11 +12,11 @@ from rich.progress import Progress, TimeElapsedColumn, BarColumn, TextColumn
 from argparse import Namespace

 from modules.utils.paths import (INSANELY_FAST_WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, UVR_MODELS_DIR, OUTPUT_DIR)
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline


-class InsanelyFastWhisperInference(WhisperBase):
+class InsanelyFastWhisperInference(BaseTranscriptionPipeline):
     def __init__(self,
                  model_dir: str = INSANELY_FAST_WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -32,15 +32,13 @@ class InsanelyFastWhisperInference(WhisperBase):
         self.model_dir = model_dir
         os.makedirs(self.model_dir, exist_ok=True)

-        openai_models = whisper.available_models()
-        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
-        self.available_models = openai_models + distil_models
+        self.available_models = self.get_model_paths()

     def transcribe(self,
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.

@@ -55,13 +53,13 @@ class InsanelyFastWhisperInference(WhisperBase):

         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
         start_time = time.time()
-        params = WhisperParameters.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))

         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
@@ -95,9 +93,17 @@ class InsanelyFastWhisperInference(WhisperBase):
                 generate_kwargs=kwargs
             )

-        segments_result = self.format_result(
-            transcribed_result=segments,
-        )
+        segments_result = []
+        for item in segments["chunks"]:
+            start, end = item["timestamp"][0], item["timestamp"][1]
+            if end is None:
+                end = start
+            segments_result.append(Segment(
+                text=item["text"],
+                start=start,
+                end=end
+            ))

         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time

@@ -138,31 +144,26 @@ class InsanelyFastWhisperInference(WhisperBase):
             model_kwargs={"attn_implementation": "flash_attention_2"} if is_flash_attn_2_available() else {"attn_implementation": "sdpa"},
         )

-    @staticmethod
-    def format_result(
-            transcribed_result: dict
-    ) -> List[dict]:
+    def get_model_paths(self):
         """
-        Format the transcription result of insanely_fast_whisper as the same with other implementation.
-
-        Parameters
-        ----------
-        transcribed_result: dict
-            Transcription result of the insanely_fast_whisper
+        Get available models from models path including fine-tuned model.

         Returns
         ----------
-        result: List[dict]
-            Formatted result as the same with other implementation
+        Name set of models
         """
-        result = transcribed_result["chunks"]
-        for item in result:
-            start, end = item["timestamp"][0], item["timestamp"][1]
-            if end is None:
-                end = start
-            item["start"] = start
-            item["end"] = end
-        return result
+        openai_models = whisper.available_models()
+        distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
+        default_models = openai_models + distil_models
+
+        existing_models = os.listdir(self.model_dir)
+        wrong_dirs = [".locks"]
+
+        available_models = default_models + existing_models
+        available_models = [model for model in available_models if model not in wrong_dirs]
+        available_models = sorted(set(available_models), key=available_models.index)
+
+        return available_models

     @staticmethod
     def download_model(
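The de-duplication line in get_model_paths() keeps first-seen order, so the bundled model names stay ahead of whatever is already downloaded into model_dir; a quick standalone illustration:

# sorted(set(models), key=models.index) removes duplicates while preserving
# the order in which each name first appears.
models = ["tiny", "base", "distil-large-v3", "tiny", "my-finetuned-model"]
deduped = sorted(set(models), key=models.index)
print(deduped)  # ['tiny', 'base', 'distil-large-v3', 'my-finetuned-model']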
@@ -8,11 +8,11 @@ import os
 from argparse import Namespace

 from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, UVR_MODELS_DIR)
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
-from modules.whisper.whisper_parameter import *
+from modules.whisper.data_classes import *


-class WhisperInference(WhisperBase):
+class WhisperInference(BaseTranscriptionPipeline):
     def __init__(self,
                  model_dir: str = WHISPER_MODELS_DIR,
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
@@ -30,7 +30,7 @@ class WhisperInference(WhisperBase):
                    audio: Union[str, np.ndarray, torch.Tensor],
                    progress: gr.Progress = gr.Progress(),
                    *whisper_params,
-                   ) -> Tuple[List[dict], float]:
+                   ) -> Tuple[List[Segment], float]:
         """
         transcribe method for faster-whisper.

@@ -45,13 +45,13 @@ class WhisperInference(WhisperBase):

         Returns
         ----------
-        segments_result: List[dict]
-            list of dicts that includes start, end timestamps and transcribed text
+        segments_result: List[Segment]
+            list of Segment that includes start, end timestamps and transcribed text
         elapsed_time: float
             elapsed time for transcription
         """
         start_time = time.time()
-        params = WhisperParameters.as_value(*whisper_params)
+        params = WhisperParams.from_list(list(whisper_params))

         if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
             self.update_model(params.model_size, params.compute_type, progress)
@@ -59,21 +59,28 @@ class WhisperInference(WhisperBase):
         def progress_callback(progress_value):
             progress(progress_value, desc="Transcribing..")

-        segments_result = self.model.transcribe(audio=audio,
+        result = self.model.transcribe(audio=audio,
                                        language=params.lang,
                                        verbose=False,
                                        beam_size=params.beam_size,
                                        logprob_threshold=params.log_prob_threshold,
                                        no_speech_threshold=params.no_speech_threshold,
                                        task="translate" if params.is_translate and self.current_model_size in self.translatable_models else "transcribe",
                                        fp16=True if params.compute_type == "float16" else False,
                                        best_of=params.best_of,
                                        patience=params.patience,
                                        temperature=params.temperature,
                                        compression_ratio_threshold=params.compression_ratio_threshold,
                                        progress_callback=progress_callback,)["segments"]
-        elapsed_time = time.time() - start_time
+        segments_result = []
+        for segment in result:
+            segments_result.append(Segment(
+                start=segment["start"],
+                end=segment["end"],
+                text=segment["text"]
+            ))
+
+        elapsed_time = time.time() - start_time
         return segments_result, elapsed_time

     def update_model(self,
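The new conversion loop walks the "segments" list that openai-whisper's transcribe() returns, where each entry carries start, end and text keys. A standalone sketch of that mapping, using a hypothetical SegmentSketch dataclass instead of the project's Segment class:

# Stand-in for the project's Segment model, for illustration only.
from dataclasses import dataclass

@dataclass
class SegmentSketch:
    start: float
    end: float
    text: str

# Shape of the dict returned by whisper's model.transcribe(...)
fake_result = {
    "text": " ask not what your country can do for you",
    "segments": [
        {"start": 0.0, "end": 3.2, "text": " ask not what your country"},
        {"start": 3.2, "end": 5.0, "text": " can do for you"},
    ],
}

segments_result = [
    SegmentSketch(start=s["start"], end=s["end"], text=s["text"])
    for s in fake_result["segments"]
]
print(segments_result)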
@@ -6,7 +6,8 @@ from modules.utils.paths import (FASTER_WHISPER_MODELS_DIR, DIARIZATION_MODELS_D
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.whisper_Inference import WhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
-from modules.whisper.whisper_base import WhisperBase
+from modules.whisper.base_transcription_pipeline import BaseTranscriptionPipeline
+from modules.whisper.data_classes import *


 class WhisperFactory:
@@ -19,7 +20,7 @@ class WhisperFactory:
                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
                  uvr_model_dir: str = UVR_MODELS_DIR,
                  output_dir: str = OUTPUT_DIR,
-                 ) -> "WhisperBase":
+                 ) -> "BaseTranscriptionPipeline":
         """
         Create a whisper inference class based on the provided whisper_type.

@@ -45,36 +46,29 @@ class WhisperFactory:

         Returns
         -------
-        WhisperBase
+        BaseTranscriptionPipeline
             An instance of the appropriate whisper inference class based on the whisper_type.
         """
         # Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
         os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

-        whisper_type = whisper_type.lower().strip()
+        whisper_type = whisper_type.strip().lower()

-        faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
-        whisper_typos = ["whisper"]
-        insanely_fast_whisper_typos = [
-            "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
-            "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
-        ]
-
-        if whisper_type in faster_whisper_typos:
+        if whisper_type == WhisperImpl.FASTER_WHISPER.value:
             return FasterWhisperInference(
                 model_dir=faster_whisper_model_dir,
                 output_dir=output_dir,
                 diarization_model_dir=diarization_model_dir,
                 uvr_model_dir=uvr_model_dir
             )
-        elif whisper_type in whisper_typos:
+        elif whisper_type == WhisperImpl.WHISPER.value:
             return WhisperInference(
                 model_dir=whisper_model_dir,
                 output_dir=output_dir,
                 diarization_model_dir=diarization_model_dir,
                 uvr_model_dir=uvr_model_dir
             )
-        elif whisper_type in insanely_fast_whisper_typos:
+        elif whisper_type == WhisperImpl.INSANELY_FAST_WHISPER.value:
             return InsanelyFastWhisperInference(
                 model_dir=insanely_fast_whisper_model_dir,
                 output_dir=output_dir,
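With the typo lists gone, callers are expected to pass one of the WhisperImpl values. A usage sketch, assuming the repository root is on PYTHONPATH; it mirrors the call the tests below make with only whisper_type set:

from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import WhisperImpl

# Returns a FasterWhisperInference instance, i.e. a BaseTranscriptionPipeline subclass.
inferencer = WhisperFactory.create_whisper_inference(
    whisper_type=WhisperImpl.FASTER_WHISPER.value,
)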
@@ -1,371 +0,0 @@
from dataclasses import dataclass, fields
import gradio as gr
from typing import Optional, Dict
import yaml

from modules.utils.constants import AUTOMATIC_DETECTION


@dataclass
class WhisperParameters:
    model_size: gr.Dropdown
    lang: gr.Dropdown
    is_translate: gr.Checkbox
    beam_size: gr.Number
    log_prob_threshold: gr.Number
    no_speech_threshold: gr.Number
    compute_type: gr.Dropdown
    best_of: gr.Number
    patience: gr.Number
    condition_on_previous_text: gr.Checkbox
    prompt_reset_on_temperature: gr.Slider
    initial_prompt: gr.Textbox
    temperature: gr.Slider
    compression_ratio_threshold: gr.Number
    vad_filter: gr.Checkbox
    threshold: gr.Slider
    min_speech_duration_ms: gr.Number
    max_speech_duration_s: gr.Number
    min_silence_duration_ms: gr.Number
    speech_pad_ms: gr.Number
    batch_size: gr.Number
    is_diarize: gr.Checkbox
    hf_token: gr.Textbox
    diarization_device: gr.Dropdown
    length_penalty: gr.Number
    repetition_penalty: gr.Number
    no_repeat_ngram_size: gr.Number
    prefix: gr.Textbox
    suppress_blank: gr.Checkbox
    suppress_tokens: gr.Textbox
    max_initial_timestamp: gr.Number
    word_timestamps: gr.Checkbox
    prepend_punctuations: gr.Textbox
    append_punctuations: gr.Textbox
    max_new_tokens: gr.Number
    chunk_length: gr.Number
    hallucination_silence_threshold: gr.Number
    hotwords: gr.Textbox
    language_detection_threshold: gr.Number
    language_detection_segments: gr.Number
    is_bgm_separate: gr.Checkbox
    uvr_model_size: gr.Dropdown
    uvr_device: gr.Dropdown
    uvr_segment_size: gr.Number
    uvr_save_file: gr.Checkbox
    uvr_enable_offload: gr.Checkbox
    """
    A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
    This data class is used to mitigate the key-value problem between Gradio components and function parameters.
    Related Gradio issue: https://github.com/gradio-app/gradio/issues/2471
    See more about Gradio pre-processing: https://www.gradio.app/docs/components

    Attributes
    ----------
    model_size: gr.Dropdown
        Whisper model size.

    lang: gr.Dropdown
        Source language of the file to transcribe.

    is_translate: gr.Checkbox
        Boolean value that determines whether to translate to English.
        It's Whisper's feature to translate speech from another language directly into English end-to-end.

    beam_size: gr.Number
        Int value that is used for decoding option.

    log_prob_threshold: gr.Number
        If the average log probability over sampled tokens is below this value, treat as failed.

    no_speech_threshold: gr.Number
        If the no_speech probability is higher than this value AND
        the average log probability over sampled tokens is below `log_prob_threshold`,
        consider the segment as silent.

    compute_type: gr.Dropdown
        compute type for transcription.
        see more info : https://opennmt.net/CTranslate2/quantization.html

    best_of: gr.Number
        Number of candidates when sampling with non-zero temperature.

    patience: gr.Number
        Beam search patience factor.

    condition_on_previous_text: gr.Checkbox
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    initial_prompt: gr.Textbox
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those word correctly.

    temperature: gr.Slider
        Temperature for sampling. It can be a tuple of temperatures,
        which will be successively used upon failures according to either
        `compression_ratio_threshold` or `log_prob_threshold`.

    compression_ratio_threshold: gr.Number
        If the gzip compression ratio is above this value, treat as failed

    vad_filter: gr.Checkbox
        Enable the voice activity detection (VAD) to filter out parts of the audio
        without speech. This step is using the Silero VAD model
        https://github.com/snakers4/silero-vad.

    threshold: gr.Slider
        This parameter is related with Silero VAD. Speech threshold.
        Silero VAD outputs speech probabilities for each audio chunk,
        probabilities ABOVE this value are considered as SPEECH. It is better to tune this
        parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.

    min_speech_duration_ms: gr.Number
        This parameter is related with Silero VAD. Final speech chunks shorter min_speech_duration_ms are thrown out.

    max_speech_duration_s: gr.Number
        This parameter is related with Silero VAD. Maximum duration of speech chunks in seconds. Chunks longer
        than max_speech_duration_s will be split at the timestamp of the last silence that
        lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
        split aggressively just before max_speech_duration_s.

    min_silence_duration_ms: gr.Number
        This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
        before separating it

    speech_pad_ms: gr.Number
        This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side

    batch_size: gr.Number
        This parameter is related with insanely-fast-whisper pipe. Batch size to pass to the pipe

    is_diarize: gr.Checkbox
        This parameter is related with whisperx. Boolean value that determines whether to diarize or not.

    hf_token: gr.Textbox
        This parameter is related with whisperx. Huggingface token is needed to download diarization models.
        Read more about : https://huggingface.co/pyannote/speaker-diarization-3.1#requirements

    diarization_device: gr.Dropdown
        This parameter is related with whisperx. Device to run diarization model

    length_penalty: gr.Number
        This parameter is related to faster-whisper. Exponential length penalty constant.

    repetition_penalty: gr.Number
        This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
        (set > 1 to penalize).

    no_repeat_ngram_size: gr.Number
        This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).

    prefix: gr.Textbox
        This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.

    suppress_blank: gr.Checkbox
        This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.

    suppress_tokens: gr.Textbox
        This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
        of symbols as defined in the model config.json file.

    max_initial_timestamp: gr.Number
        This parameter is related to faster-whisper. The initial timestamp cannot be later than this.

    word_timestamps: gr.Checkbox
        This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
        and dynamic time warping, and include the timestamps for each word in each segment.

    prepend_punctuations: gr.Textbox
        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
        with the next word.

    append_punctuations: gr.Textbox
        This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
        with the previous word.

    max_new_tokens: gr.Number
        This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
        the maximum will be set by the default max_length.

    chunk_length: gr.Number
        This parameter is related to faster-whisper and insanely-fast-whisper. The length of audio segments in seconds.
        If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.

    hallucination_silence_threshold: gr.Number
        This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
        (in seconds) when a possible hallucination is detected.

    hotwords: gr.Textbox
        This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.

    language_detection_threshold: gr.Number
        This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.

    language_detection_segments: gr.Number
        This parameter is related to faster-whisper. Number of segments to consider for the language detection.

    is_separate_bgm: gr.Checkbox
        This parameter is related to UVR. Boolean value that determines whether to separate bgm or not.

    uvr_model_size: gr.Dropdown
        This parameter is related to UVR. UVR model size.

    uvr_device: gr.Dropdown
        This parameter is related to UVR. Device to run UVR model.

    uvr_segment_size: gr.Number
        This parameter is related to UVR. Segment size for UVR model.

    uvr_save_file: gr.Checkbox
        This parameter is related to UVR. Boolean value that determines whether to save the file or not.

    uvr_enable_offload: gr.Checkbox
        This parameter is related to UVR. Boolean value that determines whether to offload the UVR model or not
        after each transcription.
    """

    def as_list(self) -> list:
        """
        Converts the data class attributes into a list, Use in Gradio UI before Gradio pre-processing.
        See more about Gradio pre-processing: : https://www.gradio.app/docs/components

        Returns
        ----------
        A list of Gradio components
        """
        return [getattr(self, f.name) for f in fields(self)]

    @staticmethod
    def as_value(*args) -> 'WhisperValues':
        """
        To use Whisper parameters in function after Gradio post-processing.
        See more about Gradio post-processing: : https://www.gradio.app/docs/components

        Returns
        ----------
        WhisperValues
           Data class that has values of parameters
        """
        return WhisperValues(*args)


@dataclass
class WhisperValues:
    model_size: str = "large-v2"
    lang: Optional[str] = None
    is_translate: bool = False
    beam_size: int = 5
    log_prob_threshold: float = -1.0
    no_speech_threshold: float = 0.6
    compute_type: str = "float16"
    best_of: int = 5
    patience: float = 1.0
    condition_on_previous_text: bool = True
    prompt_reset_on_temperature: float = 0.5
    initial_prompt: Optional[str] = None
    temperature: float = 0.0
    compression_ratio_threshold: float = 2.4
    vad_filter: bool = False
    threshold: float = 0.5
    min_speech_duration_ms: int = 250
    max_speech_duration_s: float = float("inf")
    min_silence_duration_ms: int = 2000
    speech_pad_ms: int = 400
    batch_size: int = 24
    is_diarize: bool = False
    hf_token: str = ""
    diarization_device: str = "cuda"
    length_penalty: float = 1.0
    repetition_penalty: float = 1.0
    no_repeat_ngram_size: int = 0
    prefix: Optional[str] = None
    suppress_blank: bool = True
    suppress_tokens: Optional[str] = "[-1]"
    max_initial_timestamp: float = 0.0
    word_timestamps: bool = False
    prepend_punctuations: Optional[str] = "\"'“¿([{-"
    append_punctuations: Optional[str] = "\"'.。,,!!??::”)]}、"
    max_new_tokens: Optional[int] = None
    chunk_length: Optional[int] = 30
    hallucination_silence_threshold: Optional[float] = None
    hotwords: Optional[str] = None
    language_detection_threshold: Optional[float] = None
    language_detection_segments: int = 1
    is_bgm_separate: bool = False
    uvr_model_size: str = "UVR-MDX-NET-Inst_HQ_4"
    uvr_device: str = "cuda"
    uvr_segment_size: int = 256
    uvr_save_file: bool = False
    uvr_enable_offload: bool = True
    """
    A data class to use Whisper parameters.
    """

    def to_yaml(self) -> Dict:
        data = {
            "whisper": {
                "model_size": self.model_size,
                "lang": AUTOMATIC_DETECTION.unwrap() if self.lang is None else self.lang,
                "is_translate": self.is_translate,
                "beam_size": self.beam_size,
                "log_prob_threshold": self.log_prob_threshold,
                "no_speech_threshold": self.no_speech_threshold,
                "best_of": self.best_of,
                "patience": self.patience,
                "condition_on_previous_text": self.condition_on_previous_text,
                "prompt_reset_on_temperature": self.prompt_reset_on_temperature,
                "initial_prompt": None if not self.initial_prompt else self.initial_prompt,
                "temperature": self.temperature,
                "compression_ratio_threshold": self.compression_ratio_threshold,
                "batch_size": self.batch_size,
                "length_penalty": self.length_penalty,
                "repetition_penalty": self.repetition_penalty,
                "no_repeat_ngram_size": self.no_repeat_ngram_size,
                "prefix": None if not self.prefix else self.prefix,
                "suppress_blank": self.suppress_blank,
                "suppress_tokens": self.suppress_tokens,
                "max_initial_timestamp": self.max_initial_timestamp,
                "word_timestamps": self.word_timestamps,
                "prepend_punctuations": self.prepend_punctuations,
                "append_punctuations": self.append_punctuations,
                "max_new_tokens": self.max_new_tokens,
                "chunk_length": self.chunk_length,
                "hallucination_silence_threshold": self.hallucination_silence_threshold,
                "hotwords": None if not self.hotwords else self.hotwords,
                "language_detection_threshold": self.language_detection_threshold,
                "language_detection_segments": self.language_detection_segments,
            },
            "vad": {
                "vad_filter": self.vad_filter,
                "threshold": self.threshold,
                "min_speech_duration_ms": self.min_speech_duration_ms,
                "max_speech_duration_s": self.max_speech_duration_s,
                "min_silence_duration_ms": self.min_silence_duration_ms,
                "speech_pad_ms": self.speech_pad_ms,
            },
            "diarization": {
                "is_diarize": self.is_diarize,
                "hf_token": self.hf_token
            },
            "bgm_separation": {
                "is_separate_bgm": self.is_bgm_separate,
                "model_size": self.uvr_model_size,
                "segment_size": self.uvr_segment_size,
                "save_file": self.uvr_save_file,
                "enable_offload": self.uvr_enable_offload
            },
        }
        return data

    def as_list(self) -> list:
        """
        Converts the data class attributes into a list

        Returns
        ----------
        A list of Whisper parameters
        """
        return [getattr(self, f.name) for f in fields(self)]
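The removed WhisperValues relied on dataclass declaration order for its positional list form; a small self-contained reminder of why as_list() and as_value(*args) round-tripped cleanly:

from dataclasses import dataclass, fields

# Toy example, not the removed class itself.
@dataclass
class MiniValues:
    model_size: str = "large-v2"
    beam_size: int = 5
    vad_filter: bool = False

    def as_list(self) -> list:
        # fields() yields fields in declaration order, so the list is positional
        return [getattr(self, f.name) for f in fields(self)]

v = MiniValues(beam_size=3)
print(v.as_list())                    # ['large-v2', 3, False]
print(MiniValues(*v.as_list()) == v)  # True: positional round-trip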
@@ -56,7 +56,7 @@
     "!pip install faster-whisper==1.0.3\n",
     "!pip install ctranslate2==4.4.0\n",
     "!pip install gradio\n",
-    "!pip install git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error\n",
+    "!pip install gradio-i18n\n",
     "# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/256\n",
     "!pip install git+https://github.com/JuanBindez/pytubefix.git\n",
     "!pip install tokenizers==0.19.1\n",
@@ -99,7 +99,7 @@
   },
   {
     "cell_type": "code",
-    "execution_count": 2,
+    "execution_count": 3,
     "metadata": {
       "id": "PQroYRRZzQiN",
       "cellView": "form"
@@ -11,7 +11,7 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
 faster-whisper==1.0.3
 transformers
 gradio
-git+https://github.com/jhj0517/gradio-i18n.git@fix/encoding-error
+gradio-i18n
 pytubefix
 ruamel.yaml==0.18.6
 pyannote.audio==3.3.1
@@ -1,6 +1,6 @@
 from modules.utils.paths import *
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
 from test_config import *
 from test_transcription import download_file, test_transcribe

@@ -17,9 +17,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        ("whisper", False, True, False),
-        ("faster-whisper", False, True, False),
-        ("insanely_fast_whisper", False, True, False)
+        (WhisperImpl.WHISPER.value, False, True, False),
+        (WhisperImpl.FASTER_WHISPER.value, False, True, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, True, False)
     ]
 )
 def test_bgm_separation_pipeline(
@@ -38,9 +38,9 @@ def test_bgm_separation_pipeline(
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        ("whisper", True, True, False),
-        ("faster-whisper", True, True, False),
-        ("insanely_fast_whisper", True, True, False)
+        (WhisperImpl.WHISPER.value, True, True, False),
+        (WhisperImpl.FASTER_WHISPER.value, True, True, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, True, True, False)
     ]
 )
 def test_bgm_separation_with_vad_pipeline(
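The parametrization now references enum members instead of raw strings. A sketch of what such an enum could look like; the member names come from this diff, while the concrete string values here are assumptions:

from enum import Enum

class WhisperImplSketch(str, Enum):
    WHISPER = "whisper"                              # assumed value
    FASTER_WHISPER = "faster-whisper"                # assumed value
    INSANELY_FAST_WHISPER = "insanely_fast_whisper"  # assumed value

print(WhisperImplSketch.FASTER_WHISPER.value)  # "faster-whisper"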
@@ -1,10 +1,14 @@
-from modules.utils.paths import *
+import functools
+import jiwer
 import os
 import torch

+from modules.utils.paths import *
+from modules.utils.youtube_manager import *
+
 TEST_FILE_DOWNLOAD_URL = "https://github.com/jhj0517/whisper_flutter_new/raw/main/example/assets/jfk.wav"
 TEST_FILE_PATH = os.path.join(WEBUI_DIR, "tests", "jfk.wav")
+TEST_ANSWER = "And so my fellow Americans ask not what your country can do for you ask what you can do for your country"
 TEST_YOUTUBE_URL = "https://www.youtube.com/watch?v=4WEQtgnBu0I&ab_channel=AndriaFitzer"
 TEST_WHISPER_MODEL = "tiny"
 TEST_UVR_MODEL = "UVR-MDX-NET-Inst_HQ_4"
@@ -13,5 +17,24 @@ TEST_SUBTITLE_SRT_PATH = os.path.join(WEBUI_DIR, "tests", "test_srt.srt")
 TEST_SUBTITLE_VTT_PATH = os.path.join(WEBUI_DIR, "tests", "test_vtt.vtt")


+@functools.lru_cache
 def is_cuda_available():
     return torch.cuda.is_available()


+@functools.lru_cache
+def is_pytube_detected_bot(url: str = TEST_YOUTUBE_URL):
+    try:
+        yt_temp_path = os.path.join("modules", "yt_tmp.wav")
+        if os.path.exists(yt_temp_path):
+            return False
+        yt = get_ytdata(url)
+        audio = get_ytaudio(yt)
+        return False
+    except Exception as e:
+        print(f"Pytube has detected as a bot: {e}")
+        return True
+
+
+def calculate_wer(answer, prediction):
+    return jiwer.wer(answer, prediction)
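calculate_wer() defined above is a thin wrapper around jiwer.wer(reference, hypothesis), so the 0.1 threshold used by the tests roughly allows one error per ten reference words; for example:

import jiwer

reference = "ask not what your country can do for you"
hypothesis = "ask not what your country could do for you"
print(jiwer.wer(reference, hypothesis))  # ~0.111: one substitution over nine reference words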
@@ -1,6 +1,6 @@
 from modules.utils.paths import *
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
 from test_config import *
 from test_transcription import download_file, test_transcribe

@@ -16,9 +16,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        ("whisper", False, False, True),
-        ("faster-whisper", False, False, True),
-        ("insanely_fast_whisper", False, False, True)
+        (WhisperImpl.WHISPER.value, False, False, True),
+        (WhisperImpl.FASTER_WHISPER.value, False, False, True),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, True)
     ]
 )
 def test_diarization_pipeline(
@@ -1,5 +1,6 @@
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
+from modules.utils.subtitle_manager import read_file
 from modules.utils.paths import WEBUI_DIR
 from test_config import *

@@ -12,9 +13,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        ("whisper", False, False, False),
-        ("faster-whisper", False, False, False),
-        ("insanely_fast_whisper", False, False, False)
+        (WhisperImpl.WHISPER.value, False, False, False),
+        (WhisperImpl.FASTER_WHISPER.value, False, False, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, False, False, False)
     ]
 )
 def test_transcribe(
@@ -28,6 +29,10 @@ def test_transcribe(
     if not os.path.exists(audio_path):
         download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)

+    answer = TEST_ANSWER
+    if diarization:
+        answer = "SPEAKER_00|"+TEST_ANSWER
+
     whisper_inferencer = WhisperFactory.create_whisper_inference(
         whisper_type=whisper_type,
     )
@@ -37,16 +42,24 @@ def test_transcribe(
         f"""Diarization Device: {whisper_inferencer.diarizer.device}"""
     )

-    hparams = WhisperValues(
-        model_size=TEST_WHISPER_MODEL,
-        vad_filter=vad_filter,
-        is_bgm_separate=bgm_separation,
-        compute_type=whisper_inferencer.current_compute_type,
-        uvr_enable_offload=True,
-        is_diarize=diarization,
-    ).as_list()
+    hparams = TranscriptionPipelineParams(
+        whisper=WhisperParams(
+            model_size=TEST_WHISPER_MODEL,
+            compute_type=whisper_inferencer.current_compute_type
+        ),
+        vad=VadParams(
+            vad_filter=vad_filter
+        ),
+        bgm_separation=BGMSeparationParams(
+            is_separate_bgm=bgm_separation,
+            enable_offload=True
+        ),
+        diarization=DiarizationParams(
+            is_diarize=diarization
+        ),
+    ).to_list()

-    subtitle_str, file_path = whisper_inferencer.transcribe_file(
+    subtitle_str, file_paths = whisper_inferencer.transcribe_file(
         [audio_path],
         None,
         "SRT",
@@ -54,29 +67,29 @@ def test_transcribe(
         gr.Progress(),
         *hparams,
     )
+    subtitle = read_file(file_paths[0]).split("\n")
+    assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1

-    assert isinstance(subtitle_str, str) and subtitle_str
-    assert isinstance(file_path[0], str) and file_path
+    if not is_pytube_detected_bot():
+        subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
+            TEST_YOUTUBE_URL,
+            "SRT",
+            False,
+            gr.Progress(),
+            *hparams,
+        )
+        assert isinstance(subtitle_str, str) and subtitle_str
+        assert os.path.exists(file_path)

-    whisper_inferencer.transcribe_youtube(
-        TEST_YOUTUBE_URL,
-        "SRT",
-        False,
-        gr.Progress(),
-        *hparams,
-    )
-    assert isinstance(subtitle_str, str) and subtitle_str
-    assert isinstance(file_path[0], str) and file_path
-
-    whisper_inferencer.transcribe_mic(
+    subtitle_str, file_path = whisper_inferencer.transcribe_mic(
         audio_path,
         "SRT",
         False,
         gr.Progress(),
         *hparams,
     )
-    assert isinstance(subtitle_str, str) and subtitle_str
-    assert isinstance(file_path[0], str) and file_path
+    subtitle = read_file(file_path).split("\n")
+    assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1


 def download_file(url, save_dir):
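The subtitle[2] lookup in the new assertions leans on SRT layout: line 0 is the cue index, line 1 the timestamp range, line 2 the first text line. A small illustration with a made-up SRT string:

srt_text = (
    "1\n"
    "00:00:00,000 --> 00:00:10,000\n"
    "And so my fellow Americans, ask not what your country can do for you\n"
)
lines = srt_text.split("\n")
# Strip punctuation the same way the test does before computing WER.
first_cue_text = lines[2].strip().replace(",", "").replace(".", "")
print(first_cue_text)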
@@ -1,6 +1,6 @@
 from modules.utils.paths import *
 from modules.whisper.whisper_factory import WhisperFactory
-from modules.whisper.whisper_parameter import WhisperValues
+from modules.whisper.data_classes import *
 from test_config import *
 from test_transcription import download_file, test_transcribe

@@ -12,9 +12,9 @@ import os
 @pytest.mark.parametrize(
     "whisper_type,vad_filter,bgm_separation,diarization",
     [
-        ("whisper", True, False, False),
-        ("faster-whisper", True, False, False),
-        ("insanely_fast_whisper", True, False, False)
+        (WhisperImpl.WHISPER.value, True, False, False),
+        (WhisperImpl.FASTER_WHISPER.value, True, False, False),
+        (WhisperImpl.INSANELY_FAST_WHISPER.value, True, False, False)
     ]
 )
 def test_vad_pipeline(