Add gradio parameter file_format to cache

This commit is contained in:
jhj0517
2024-11-04 23:21:57 +09:00
parent 8a4343101e
commit e284444972

View File

@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
def run(self, def run(self,
audio: Union[str, BinaryIO, np.ndarray], audio: Union[str, BinaryIO, np.ndarray],
progress: gr.Progress = gr.Progress(), progress: gr.Progress = gr.Progress(),
file_format: str = "SRT",
add_timestamp: bool = True, add_timestamp: bool = True,
*pipeline_params, *pipeline_params,
) -> Tuple[List[Segment], float]: ) -> Tuple[List[Segment], float]:
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
Audio input. This can be file path or binary type. Audio input. This can be file path or binary type.
progress: gr.Progress progress: gr.Progress
Indicator to show progress directly in gradio. Indicator to show progress directly in gradio.
file_format: str
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
add_timestamp: bool add_timestamp: bool
Whether to add a timestamp at the end of the filename. Whether to add a timestamp at the end of the filename.
*pipeline_params: tuple *pipeline_params: tuple
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
self.cache_parameters( self.cache_parameters(
params=params, params=params,
file_format=file_format,
add_timestamp=add_timestamp add_timestamp=add_timestamp
) )
return result, elapsed_time return result, elapsed_time
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
transcribed_segments, time_for_task = self.run( transcribed_segments, time_for_task = self.run(
file, file,
progress, progress,
file_format,
add_timestamp, add_timestamp,
*pipeline_params, *pipeline_params,
) )
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
transcribed_segments, time_for_task = self.run( transcribed_segments, time_for_task = self.run(
mic_audio, mic_audio,
progress, progress,
file_format,
add_timestamp, add_timestamp,
*pipeline_params, *pipeline_params,
) )
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
transcribed_segments, time_for_task = self.run( transcribed_segments, time_for_task = self.run(
audio, audio,
progress, progress,
file_format,
add_timestamp, add_timestamp,
*pipeline_params, *pipeline_params,
) )
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
@staticmethod @staticmethod
def cache_parameters( def cache_parameters(
params: TranscriptionPipelineParams, params: TranscriptionPipelineParams,
add_timestamp: bool file_format: str = "SRT",
add_timestamp: bool = True
): ):
"""Cache parameters to the yaml file""" """Cache parameters to the yaml file"""
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH) cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
cached_yaml = {**cached_params, **param_to_cache} cached_yaml = {**cached_params, **param_to_cache}
cached_yaml["whisper"]["add_timestamp"] = add_timestamp cached_yaml["whisper"]["add_timestamp"] = add_timestamp
cached_yaml["whisper"]["file_format"] = file_format
supress_token = cached_yaml["whisper"].get("suppress_tokens", None) supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
if supress_token and isinstance(supress_token, list): if supress_token and isinstance(supress_token, list):