Add gradio parameter file_format to cache

Author: jhj0517
Date: 2024-11-04 23:21:57 +09:00
parent 8a4343101e
commit e284444972


@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
     def run(self,
             audio: Union[str, BinaryIO, np.ndarray],
             progress: gr.Progress = gr.Progress(),
+            file_format: str = "SRT",
             add_timestamp: bool = True,
             *pipeline_params,
             ) -> Tuple[List[Segment], float]:
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
             Audio input. This can be file path or binary type.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
+        file_format: str
+            Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
         add_timestamp: bool
             Whether to add a timestamp at the end of the filename.
         *pipeline_params: tuple
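
For orientation, here is a minimal usage sketch of the updated signature. The pipeline instance, the audio path, and the trailing pipeline parameters are placeholders for illustration; only the file_format values come from the docstring above.

# Hypothetical call site, not part of this commit.
segments, elapsed_time = pipeline.run(
    "example_audio.wav",   # audio: file path, BinaryIO, or np.ndarray
    gr.Progress(),         # progress: Gradio progress indicator
    "WebVTT",              # file_format: one of "SRT", "WebVTT", "txt", "lrc"
    True,                  # add_timestamp: append a timestamp to the output filename
    *pipeline_params,      # remaining transcription parameters
)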
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
         self.cache_parameters(
             params=params,
+            file_format=file_format,
             add_timestamp=add_timestamp
         )
         return result, elapsed_time
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
                 transcribed_segments, time_for_task = self.run(
                     file,
                     progress,
+                    file_format,
                     add_timestamp,
                     *pipeline_params,
                 )
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 mic_audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
             transcribed_segments, time_for_task = self.run(
                 audio,
                 progress,
+                file_format,
                 add_timestamp,
                 *pipeline_params,
             )
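
Each of the three transcription entry points (file, microphone, and YouTube input) now forwards the file_format it receives from the UI into run() positionally. A rough sketch of the Gradio components that would typically supply these values; the component names and the commented wiring are assumptions, not part of this commit:

import gradio as gr

# Hypothetical UI components feeding the updated handler signatures.
dd_file_format = gr.Dropdown(
    choices=["SRT", "WebVTT", "txt", "lrc"],
    value="SRT",
    label="File Format",
)
cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the file name")
# btn_run.click(fn=pipeline.transcribe_file,
#               inputs=[input_files, dd_file_format, cb_timestamp, ...],
#               outputs=[tb_result, files_subtitles])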
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
     @staticmethod
     def cache_parameters(
         params: TranscriptionPipelineParams,
-        add_timestamp: bool
+        file_format: str = "SRT",
+        add_timestamp: bool = True
     ):
         """Cache parameters to the yaml file"""
         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
         cached_yaml = {**cached_params, **param_to_cache}
         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+        cached_yaml["whisper"]["file_format"] = file_format
         supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
         if supress_token and isinstance(supress_token, list):
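
Taken together, run() now passes the selected subtitle format into cache_parameters(), which persists it next to add_timestamp in the YAML config. A self-contained sketch of that caching pattern, using PyYAML directly in place of the project's load_yaml and save helpers (the function name and path handling here are assumptions):

import yaml

def cache_parameters_sketch(config_path: str,
                            param_to_cache: dict,
                            file_format: str = "SRT",
                            add_timestamp: bool = True) -> None:
    # Load the previously cached config, merge the new pipeline params over it,
    # then record the Gradio-level options explicitly under the "whisper" section.
    with open(config_path, "r", encoding="utf-8") as f:
        cached_params = yaml.safe_load(f) or {}
    cached_yaml = {**cached_params, **param_to_cache}
    cached_yaml.setdefault("whisper", {})
    cached_yaml["whisper"]["add_timestamp"] = add_timestamp
    cached_yaml["whisper"]["file_format"] = file_format
    with open(config_path, "w", encoding="utf-8") as f:
        yaml.safe_dump(cached_yaml, f, allow_unicode=True)

The dict merge keeps any keys already present in the cached file, while the explicit assignments ensure the two UI-level options are always written with their latest values.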