Add gradio parameter file_format to cache
This commit is contained in:
@@ -71,6 +71,7 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
def run(self,
|
def run(self,
|
||||||
audio: Union[str, BinaryIO, np.ndarray],
|
audio: Union[str, BinaryIO, np.ndarray],
|
||||||
progress: gr.Progress = gr.Progress(),
|
progress: gr.Progress = gr.Progress(),
|
||||||
|
file_format: str = "SRT",
|
||||||
add_timestamp: bool = True,
|
add_timestamp: bool = True,
|
||||||
*pipeline_params,
|
*pipeline_params,
|
||||||
) -> Tuple[List[Segment], float]:
|
) -> Tuple[List[Segment], float]:
|
||||||
@@ -86,6 +87,8 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
Audio input. This can be file path or binary type.
|
Audio input. This can be file path or binary type.
|
||||||
progress: gr.Progress
|
progress: gr.Progress
|
||||||
Indicator to show progress directly in gradio.
|
Indicator to show progress directly in gradio.
|
||||||
|
file_format: str
|
||||||
|
Subtitle file format between ["SRT", "WebVTT", "txt", "lrc"]
|
||||||
add_timestamp: bool
|
add_timestamp: bool
|
||||||
Whether to add a timestamp at the end of the filename.
|
Whether to add a timestamp at the end of the filename.
|
||||||
*pipeline_params: tuple
|
*pipeline_params: tuple
|
||||||
@@ -168,6 +171,7 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
|
|
||||||
self.cache_parameters(
|
self.cache_parameters(
|
||||||
params=params,
|
params=params,
|
||||||
|
file_format=file_format,
|
||||||
add_timestamp=add_timestamp
|
add_timestamp=add_timestamp
|
||||||
)
|
)
|
||||||
return result, elapsed_time
|
return result, elapsed_time
|
||||||
@@ -224,6 +228,7 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
transcribed_segments, time_for_task = self.run(
|
transcribed_segments, time_for_task = self.run(
|
||||||
file,
|
file,
|
||||||
progress,
|
progress,
|
||||||
|
file_format,
|
||||||
add_timestamp,
|
add_timestamp,
|
||||||
*pipeline_params,
|
*pipeline_params,
|
||||||
)
|
)
|
||||||
@@ -298,6 +303,7 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
transcribed_segments, time_for_task = self.run(
|
transcribed_segments, time_for_task = self.run(
|
||||||
mic_audio,
|
mic_audio,
|
||||||
progress,
|
progress,
|
||||||
|
file_format,
|
||||||
add_timestamp,
|
add_timestamp,
|
||||||
*pipeline_params,
|
*pipeline_params,
|
||||||
)
|
)
|
||||||
@@ -364,6 +370,7 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
transcribed_segments, time_for_task = self.run(
|
transcribed_segments, time_for_task = self.run(
|
||||||
audio,
|
audio,
|
||||||
progress,
|
progress,
|
||||||
|
file_format,
|
||||||
add_timestamp,
|
add_timestamp,
|
||||||
*pipeline_params,
|
*pipeline_params,
|
||||||
)
|
)
|
||||||
@@ -513,7 +520,8 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def cache_parameters(
|
def cache_parameters(
|
||||||
params: TranscriptionPipelineParams,
|
params: TranscriptionPipelineParams,
|
||||||
add_timestamp: bool
|
file_format: str = "SRT",
|
||||||
|
add_timestamp: bool = True
|
||||||
):
|
):
|
||||||
"""Cache parameters to the yaml file"""
|
"""Cache parameters to the yaml file"""
|
||||||
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
|
||||||
@@ -521,6 +529,7 @@ class BaseTranscriptionPipeline(ABC):
|
|||||||
|
|
||||||
cached_yaml = {**cached_params, **param_to_cache}
|
cached_yaml = {**cached_params, **param_to_cache}
|
||||||
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
cached_yaml["whisper"]["add_timestamp"] = add_timestamp
|
||||||
|
cached_yaml["whisper"]["file_format"] = file_format
|
||||||
|
|
||||||
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
supress_token = cached_yaml["whisper"].get("suppress_tokens", None)
|
||||||
if supress_token and isinstance(supress_token, list):
|
if supress_token and isinstance(supress_token, list):
|
||||||
|
|||||||
Reference in New Issue
Block a user