Calculate WER between gen result & answer
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from modules.whisper.whisper_factory import WhisperFactory
|
||||
from modules.whisper.data_classes import *
|
||||
from modules.utils.subtitle_manager import read_file
|
||||
from modules.utils.paths import WEBUI_DIR
|
||||
from test_config import *
|
||||
|
||||
@@ -28,6 +29,10 @@ def test_transcribe(
|
||||
if not os.path.exists(audio_path):
|
||||
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
|
||||
|
||||
answer = TEST_ANSWER
|
||||
if diarization:
|
||||
answer = "SPEAKER_00|"+TEST_ANSWER
|
||||
|
||||
whisper_inferencer = WhisperFactory.create_whisper_inference(
|
||||
whisper_type=whisper_type,
|
||||
)
|
||||
@@ -54,7 +59,7 @@ def test_transcribe(
|
||||
),
|
||||
).to_list()
|
||||
|
||||
subtitle_str, file_path = whisper_inferencer.transcribe_file(
|
||||
subtitle_str, file_paths = whisper_inferencer.transcribe_file(
|
||||
[audio_path],
|
||||
None,
|
||||
"SRT",
|
||||
@@ -62,12 +67,11 @@ def test_transcribe(
|
||||
gr.Progress(),
|
||||
*hparams,
|
||||
)
|
||||
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert isinstance(file_path[0], str) and file_path
|
||||
subtitle = read_file(file_paths[0]).split("\n")
|
||||
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
|
||||
|
||||
if not is_pytube_detected_bot():
|
||||
whisper_inferencer.transcribe_youtube(
|
||||
subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
|
||||
TEST_YOUTUBE_URL,
|
||||
"SRT",
|
||||
False,
|
||||
@@ -75,17 +79,17 @@ def test_transcribe(
|
||||
*hparams,
|
||||
)
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert isinstance(file_path[0], str) and file_path
|
||||
assert os.path.exists(file_path)
|
||||
|
||||
whisper_inferencer.transcribe_mic(
|
||||
subtitle_str, file_path = whisper_inferencer.transcribe_mic(
|
||||
audio_path,
|
||||
"SRT",
|
||||
False,
|
||||
gr.Progress(),
|
||||
*hparams,
|
||||
)
|
||||
assert isinstance(subtitle_str, str) and subtitle_str
|
||||
assert isinstance(file_path[0], str) and file_path
|
||||
subtitle = read_file(file_path).split("\n")
|
||||
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
|
||||
|
||||
|
||||
def download_file(url, save_dir):
|
||||
|
||||
Reference in New Issue
Block a user