Calculate WER between gen result & answer

This commit is contained in:
jhj0517
2024-11-02 16:35:46 +09:00
parent 5fee9a3edb
commit c7bfcf2316

View File

@@ -1,5 +1,6 @@
from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR
from test_config import *
@@ -28,6 +29,10 @@ def test_transcribe(
if not os.path.exists(audio_path):
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
answer = TEST_ANSWER
if diarization:
answer = "SPEAKER_00|"+TEST_ANSWER
whisper_inferencer = WhisperFactory.create_whisper_inference(
whisper_type=whisper_type,
)
@@ -54,7 +59,7 @@ def test_transcribe(
),
).to_list()
subtitle_str, file_path = whisper_inferencer.transcribe_file(
subtitle_str, file_paths = whisper_inferencer.transcribe_file(
[audio_path],
None,
"SRT",
@@ -62,12 +67,11 @@ def test_transcribe(
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
subtitle = read_file(file_paths[0]).split("\n")
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
if not is_pytube_detected_bot():
whisper_inferencer.transcribe_youtube(
subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
TEST_YOUTUBE_URL,
"SRT",
False,
@@ -75,17 +79,17 @@ def test_transcribe(
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
assert os.path.exists(file_path)
whisper_inferencer.transcribe_mic(
subtitle_str, file_path = whisper_inferencer.transcribe_mic(
audio_path,
"SRT",
False,
gr.Progress(),
*hparams,
)
assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path
subtitle = read_file(file_path).split("\n")
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
def download_file(url, save_dir):