Merge pull request #380 from linuxlurak/master-mod
Include loading of default value of file_format from config file.
This commit is contained in:
2
app.py
2
app.py
@@ -53,7 +53,7 @@ class App:
|
||||
dd_lang = gr.Dropdown(choices=self.whisper_inf.available_langs + [AUTOMATIC_DETECTION],
|
||||
value=AUTOMATIC_DETECTION if whisper_params["lang"] == AUTOMATIC_DETECTION.unwrap()
|
||||
else whisper_params["lang"], label=_("Language"))
|
||||
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value="SRT", label=_("File Format"))
|
||||
dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt", "LRC"], value=whisper_params["file_format"], label=_("File Format"))
|
||||
with gr.Row():
|
||||
cb_translate = gr.Checkbox(value=whisper_params["is_translate"], label=_("Translate to English?"),
|
||||
interactive=True)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
whisper:
|
||||
model_size: "large-v2"
|
||||
file_format: "SRT"
|
||||
lang: "Automatic Detection"
|
||||
is_translate: false
|
||||
beam_size: 5
|
||||
|
||||
@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
def run(self,
|
||||
audio: Union[str, BinaryIO, np.ndarray],
|
||||
progress: gr.Progress = gr.Progress(),
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True,
|
||||
*pipeline_params,
|
||||
) -> Tuple[List[Segment], float]:
|
||||
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
|
||||
Audio input. This can be file path or binary type.
|
||||
progress: gr.Progress
|
||||
Indicator to show progress directly in gradio.
|
||||
file_format: str
|
||||
Subtitle file format, one of ["SRT", "WebVTT", "txt", "LRC"]
|
||||
add_timestamp: bool
|
||||
Whether to add a timestamp at the end of the filename.
|
||||
*pipeline_params: tuple
|
||||
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
|
||||
self.cache_parameters(
|
||||
params=params,
|
||||
file_format=file_format,
|
||||
add_timestamp=add_timestamp
|
||||
)
|
||||
return result, elapsed_time
|
||||
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
file,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*pipeline_params,
|
||||
)
|
||||
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
mic_audio,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*pipeline_params,
|
||||
)
|
||||
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
audio,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*pipeline_params,
|
||||
)
|
||||
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
|
||||
@staticmethod
|
||||
def cache_parameters(
|
||||
params: TranscriptionPipelineParams,
|
||||
add_timestamp: bool
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True
|
||||
):
|
||||
"""Cache parameters to the yaml file"""
|
||||
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
|
||||
cached_yaml = {**cached_params, **param_to_cache}
|
||||
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
||||
cached_yaml["whisper"]["file_format"] = file_format
|
||||
|
||||
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
||||
if supress_token and isinstance(supress_token, list):
|
||||
|
||||
Reference in New Issue
Block a user