Add gradio parameter file_format to cache
This commit is contained in:
@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
def run(self,
|
||||
audio: Union[str, BinaryIO, np.ndarray],
|
||||
progress: gr.Progress = gr.Progress(),
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True,
|
||||
*pipeline_params,
|
||||
) -> Tuple[List[Segment], float]:
|
||||
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
|
||||
Audio input. This can be file path or binary type.
|
||||
progress: gr.Progress
|
||||
Indicator to show progress directly in gradio.
|
||||
file_format: str
|
||||
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
|
||||
add_timestamp: bool
|
||||
Whether to add a timestamp at the end of the filename.
|
||||
*pipeline_params: tuple
|
||||
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
|
||||
self.cache_parameters(
|
||||
params=params,
|
||||
file_format=file_format,
|
||||
add_timestamp=add_timestamp
|
||||
)
|
||||
return result, elapsed_time
|
||||
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
file,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*pipeline_params,
|
||||
)
|
||||
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
mic_audio,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*pipeline_params,
|
||||
)
|
||||
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
transcribed_segments, time_for_task = self.run(
|
||||
audio,
|
||||
progress,
|
||||
file_format,
|
||||
add_timestamp,
|
||||
*pipeline_params,
|
||||
)
|
||||
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
|
||||
@staticmethod
|
||||
def cache_parameters(
|
||||
params: TranscriptionPipelineParams,
|
||||
add_timestamp: bool
|
||||
file_format: str = "SRT",
|
||||
add_timestamp: bool = True
|
||||
):
|
||||
"""Cache parameters to the yaml file"""
|
||||
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
|
||||
|
||||
cached_yaml = {**cached_params, **param_to_cache}
|
||||
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
||||
cached_yaml["whisper"]["file_format"] = file_format
|
||||
|
||||
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
||||
if supress_token and isinstance(supress_token, list):
|
||||
|
||||
Reference in New Issue
Block a user