Use constant for gradio none validation values

This commit is contained in:
jhj0517
2024-10-29 00:15:41 +09:00
parent 19e342ad3b
commit 2a2f7c60fa
3 changed files with 19 additions and 16 deletions

View File

@@ -1,3 +1,6 @@
from gradio_i18n import Translate, gettext as _
AUTOMATIC_DETECTION = _("Automatic Detection")
GRADIO_NONE_STR = ""
GRADIO_NONE_NUMBER_MAX = 9999
GRADIO_NONE_NUMBER_MIN = 0

View File

@@ -15,7 +15,7 @@ from dataclasses import astuple
from modules.uvr.music_separator import MusicSeparator
from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
UVR_MODELS_DIR)
from modules.utils.constants import AUTOMATIC_DETECTION
from modules.utils.constants import *
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
@@ -519,19 +519,19 @@ class BaseTranscriptionPipeline(ABC):
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
params.whisper.lang = language_code_dict[params.lang]
if not params.whisper.initial_prompt:
if params.whisper.initial_prompt == GRADIO_NONE_STR:
params.whisper.initial_prompt = None
if not params.whisper.prefix:
if params.whisper.prefix == GRADIO_NONE_STR:
params.whisper.prefix = None
if not params.whisper.hotwords:
if params.whisper.hotwords == GRADIO_NONE_STR:
params.whisper.hotwords = None
if params.whisper.max_new_tokens == 0:
if params.whisper.max_new_tokens == GRADIO_NONE_NUMBER_MIN:
params.whisper.max_new_tokens = None
if params.whisper.hallucination_silence_threshold == 0:
if params.whisper.hallucination_silence_threshold == GRADIO_NONE_NUMBER_MIN:
params.whisper.hallucination_silence_threshold = None
if params.whisper.language_detection_threshold == 0:
if params.whisper.language_detection_threshold == GRADIO_NONE_NUMBER_MIN:
params.whisper.language_detection_threshold = None
if params.vad.max_speech_duration_s >= 9999:
if params.vad.max_speech_duration_s == GRADIO_NONE_NUMBER_MAX:
params.vad.max_speech_duration_s = float('inf')
return params
@@ -555,7 +555,7 @@ class BaseTranscriptionPipeline(ABC):
cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
cached_yaml["vad"]["max_speech_duration_s"] = 9999
cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
if cached_yaml is not None and cached_yaml:
save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)

View File

@@ -7,7 +7,7 @@ from enum import Enum
from copy import deepcopy
import yaml
from modules.utils.constants import AUTOMATIC_DETECTION
from modules.utils.constants import *
class WhisperImpl(Enum):
@@ -82,7 +82,7 @@ class VadParams(BaseParams):
),
gr.Number(
label="Maximum Speech Duration (s)",
value=defaults.get("max_speech_duration_s", cls.__fields__["max_speech_duration_s"].default),
value=defaults.get("max_speech_duration_s", GRADIO_NONE_NUMBER_MAX),
info="Maximum duration of speech chunks in \"seconds\"."
),
gr.Number(
@@ -373,7 +373,7 @@ class WhisperParams(BaseParams):
),
gr.Textbox(
label="Initial Prompt",
value=defaults.get("initial_prompt", cls.__fields__["initial_prompt"].default),
value=defaults.get("initial_prompt", GRADIO_NONE_STR),
info="Initial prompt for first window"
),
gr.Slider(
@@ -411,7 +411,7 @@ class WhisperParams(BaseParams):
),
gr.Textbox(
label="Prefix",
value=defaults.get("prefix", cls.__fields__["prefix"].default),
value=defaults.get("prefix", GRADIO_NONE_STR),
info="Prefix text for first window"
),
gr.Checkbox(
@@ -446,7 +446,7 @@ class WhisperParams(BaseParams):
),
gr.Number(
label="Max New Tokens",
value=defaults.get("max_new_tokens", cls.__fields__["max_new_tokens"].default),
value=defaults.get("max_new_tokens", GRADIO_NONE_NUMBER_MIN),
precision=0,
info="Maximum number of new tokens per chunk"
),
@@ -459,7 +459,7 @@ class WhisperParams(BaseParams):
gr.Number(
label="Hallucination Silence Threshold (sec)",
value=defaults.get("hallucination_silence_threshold",
cls.__fields__["hallucination_silence_threshold"].default),
GRADIO_NONE_NUMBER_MIN),
info="Threshold for skipping silent periods in hallucination detection"
),
gr.Textbox(
@@ -470,7 +470,7 @@ class WhisperParams(BaseParams):
gr.Number(
label="Language Detection Threshold",
value=defaults.get("language_detection_threshold",
cls.__fields__["language_detection_threshold"].default),
GRADIO_NONE_NUMBER_MIN),
info="Threshold for language detection probability"
),
gr.Number(