Calculate WER (word error rate) between the generated transcription result and the expected answer

This commit is contained in:
jhj0517
2024-11-02 16:35:46 +09:00
parent 5fee9a3edb
commit c7bfcf2316

View File

@@ -1,5 +1,6 @@
from modules.whisper.whisper_factory import WhisperFactory from modules.whisper.whisper_factory import WhisperFactory
from modules.whisper.data_classes import * from modules.whisper.data_classes import *
from modules.utils.subtitle_manager import read_file
from modules.utils.paths import WEBUI_DIR from modules.utils.paths import WEBUI_DIR
from test_config import * from test_config import *
@@ -28,6 +29,10 @@ def test_transcribe(
if not os.path.exists(audio_path): if not os.path.exists(audio_path):
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir) download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
answer = TEST_ANSWER
if diarization:
answer = "SPEAKER_00|"+TEST_ANSWER
whisper_inferencer = WhisperFactory.create_whisper_inference( whisper_inferencer = WhisperFactory.create_whisper_inference(
whisper_type=whisper_type, whisper_type=whisper_type,
) )
@@ -54,7 +59,7 @@ def test_transcribe(
), ),
).to_list() ).to_list()
subtitle_str, file_path = whisper_inferencer.transcribe_file( subtitle_str, file_paths = whisper_inferencer.transcribe_file(
[audio_path], [audio_path],
None, None,
"SRT", "SRT",
@@ -62,12 +67,11 @@ def test_transcribe(
gr.Progress(), gr.Progress(),
*hparams, *hparams,
) )
subtitle = read_file(file_paths[0]).split("\n")
assert isinstance(subtitle_str, str) and subtitle_str assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
assert isinstance(file_path[0], str) and file_path
if not is_pytube_detected_bot(): if not is_pytube_detected_bot():
whisper_inferencer.transcribe_youtube( subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
TEST_YOUTUBE_URL, TEST_YOUTUBE_URL,
"SRT", "SRT",
False, False,
@@ -75,17 +79,17 @@ def test_transcribe(
*hparams, *hparams,
) )
assert isinstance(subtitle_str, str) and subtitle_str assert isinstance(subtitle_str, str) and subtitle_str
assert isinstance(file_path[0], str) and file_path assert os.path.exists(file_path)
whisper_inferencer.transcribe_mic( subtitle_str, file_path = whisper_inferencer.transcribe_mic(
audio_path, audio_path,
"SRT", "SRT",
False, False,
gr.Progress(), gr.Progress(),
*hparams, *hparams,
) )
assert isinstance(subtitle_str, str) and subtitle_str subtitle = read_file(file_path).split("\n")
assert isinstance(file_path[0], str) and file_path assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
def download_file(url, save_dir): def download_file(url, save_dir):