Merge pull request #380 from linuxlurak/master-mod

Include loading of the default value of file_format from the config file.
jhj0517 committed (via GitHub) on 2024-11-04 23:57:24 +09:00
3 changed files with 12 additions and 2 deletions
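
In short: the subtitle file-format default shown in the UI is now read from the cached parameters YAML instead of being hard-coded to "SRT", and the format chosen for a run is written back to that YAML. A minimal, self-contained sketch of the read side, with the file name assumed for illustration (the project itself goes through load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH), as seen in the pipeline diff below):

import yaml  # PyYAML

def read_default_file_format(path: str = "default_parameters.yaml") -> str:
    # Illustrative helper, not the project's code: return the configured
    # subtitle format, falling back to "SRT" when the key is absent
    # (e.g. a config file written before this change).
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    return cfg.get("whisper", {}).get("file_format", "SRT")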

app.py

@@ -53,7 +53,7 @@ class App:
             dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
                                   value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
                                   else whisper_params["lang"], label=_("Language"))
-            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format"))
+            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
             with gr.Row():
                 cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
                                            interactive=True)
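
The dropdown's initial value now comes from whisper_params["file_format"] instead of the hard-coded "SRT". If whisper_params behaves like a dict, a slightly more defensive variant (a sketch, not the PR's code) could keep "SRT" as a fallback for cached configs written before the new key existed:

# Hedged variant, assuming whisper_params supports dict-style .get():
dd_file_format = gr.Dropdown(
    choices=["SRT", "WebVTT", "txt", "LRC"],
    value=whisper_params.get("file_format", "SRT"),
    label=_("File Format"),
)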

@@ -1,5 +1,6 @@
 whisper:
   model_size: "large-v2"
+  file_format: "SRT"
   lang: "Automatic Detection"
   is_translate: false
   beam_size: 5
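
The whisper section of the default parameters YAML now ships a file_format entry, so any standard YAML loader sees it under the same mapping. A small, self-contained check of that shape (snippet trimmed to the keys shown above):

import yaml

snippet = """
whisper:
  model_size: "large-v2"
  file_format: "SRT"
  lang: "Automatic Detection"
"""
assert yaml.safe_load(snippet)["whisper"]["file_format"] == "SRT"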

@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
+            file_format: str = "SRT",
             add_timestamp: bool = True,
             *pipeline_params,
             ) -> Tuple[List[Segment], float]:
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        file_format: str
+            Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
         *pipeline_params: tuple
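
With the new parameter in place, a call to run() passes the chosen format between progress and add_timestamp. An illustrative call shape only, not runnable on its own: "pipeline" stands for any concrete implementation of BaseTranscriptionPipeline, and the argument values are examples:

import gradio as gr

segments, elapsed_time = pipeline.run(
    "sample.wav",   # audio: file path, binary, or numpy array
    gr.Progress(),  # progress indicator for gradio
    "WebVTT",       # file_format (new in this PR)
    True,           # add_timestamp
)
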
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
         self.cache_parameters(
             params=params,
+            file_format=file_format,
             add_timestamp=add_timestamp
         )
         return result, elapsed_time
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 file,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 mic_audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
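
Each call site above (file, mic_audio, and audio) forwards its arguments to run() positionally, so file_format has to occupy the same slot everywhere: between progress and add_timestamp, ahead of the catch-all *pipeline_params. A tiny standalone illustration of why that ordering matters (simplified signature, no self):

def run(audio, progress, file_format="SRT", add_timestamp=True, *pipeline_params):
    # Toy stand-in for BaseTranscriptionPipeline.run: report what landed where.
    return file_format, add_timestamp, pipeline_params

# Without forwarding file_format, later arguments shift into the wrong slots:
print(run("a.wav", None, True, 0.5))            # -> (True, 0.5, ())
# Forwarding it keeps every argument in its intended position:
print(run("a.wav", None, "WebVTT", True, 0.5))  # -> ('WebVTT', True, (0.5,))
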
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
     @staticmethod
     def cache_parameters(
         params: TranscriptionPipelineParams,
-        add_timestamp: bool
+        file_format: str = "SRT",
+        add_timestamp: bool = True
     ):
         """Cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
         cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+        cached_yaml["whisper"]["file_format"] = file_format
         supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
         if supress_token and isinstance(supress_token, list):
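
On the write side, cache_parameters() now persists the chosen format next to add_timestamp, so the YAML cached after a run carries it forward to the next launch, where the dropdown above reads it back as the default. A toy, self-contained view of the update (the real code merges TranscriptionPipelineParams into the YAML loaded from DEFAULT_PARAMETERS_CONFIG_PATH):

# Toy mapping standing in for the merged cached_yaml:
cached_yaml = {"whisper": {"model_size": "large-v2", "add_timestamp": True}}
file_format, add_timestamp = "WebVTT", False

cached_yaml["whisper"]["add_timestamp"] = add_timestamp
cached_yaml["whisper"]["file_format"] = file_format  # new key written by this PR

print(cached_yaml["whisper"])
# {'model_size': 'large-v2', 'add_timestamp': False, 'file_format': 'WebVTT'}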