From e2844449722beee014637e807d3426a7a8983f50 Mon Sep 17 00:00:00 2001 From: jhj0517 <97279763+jhj0517@users.noreply.github.com> Date: Mon, 4 Nov 2024 23:21:57 +0900 Subject: [PATCH] Add gradio parameter `file_format` to cache --- modules/whisper/base_transcription_pipeline.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/modules/whisper/base_transcription_pipeline.py b/modules/whisper/base_transcription_pipeline.py index 4882abd..2791dc6 100644 --- a/modules/whisper/base_transcription_pipeline.py +++ b/modules/whisper/base_transcription_pipeline.py @@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC): def run(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), + file_format: str = "SRT", add_timestamp: bool = True, *pipeline_params, ) -> Tuple[List[Segment], float]: @@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC): Audio input. This can be file path or binary type. progress: gr.Progress Indicator to show progress directly in gradio. + file_format: str + Subtitle file format, one of ["SRT", "WebVTT", "txt", "lrc"] add_timestamp: bool Whether to add a timestamp at the end of the filename. 
*pipeline_params: tuple @@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC): self.cache_parameters( params=params, + file_format=file_format, add_timestamp=add_timestamp ) return result, elapsed_time @@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC): transcribed_segments, time_for_task = self.run( file, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC): transcribed_segments, time_for_task = self.run( mic_audio, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC): transcribed_segments, time_for_task = self.run( audio, progress, + file_format, add_timestamp, *pipeline_params, ) @@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC): @staticmethod def cache_parameters( params: TranscriptionPipelineParams, - add_timestamp: bool + file_format: str = "SRT", + add_timestamp: bool = True ): """Cache parameters to the yaml file""" cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) @@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC): cached_yaml = {**cached_params, **param_to_cache} cached_yaml["whisper"]["add_timestamp"] = add_timestamp + cached_yaml["whisper"]["file_format"] = file_format supress_token = cached_yaml["whisper"].get("suppress_tokens", None) if supress_token and isinstance(supress_token, list):