Calculate WER between gen result & answer
This commit is contained in:
@@ -1,5 +1,6 @@
|
|||||||
from modules.whisper.whisper_factory import WhisperFactory
|
from modules.whisper.whisper_factory import WhisperFactory
|
||||||
from modules.whisper.data_classes import *
|
from modules.whisper.data_classes import *
|
||||||
|
from modules.utils.subtitle_manager import read_file
|
||||||
from modules.utils.paths import WEBUI_DIR
|
from modules.utils.paths import WEBUI_DIR
|
||||||
from test_config import *
|
from test_config import *
|
||||||
|
|
||||||
@@ -28,6 +29,10 @@ def test_transcribe(
|
|||||||
if not os.path.exists(audio_path):
|
if not os.path.exists(audio_path):
|
||||||
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
|
download_file(TEST_FILE_DOWNLOAD_URL, audio_path_dir)
|
||||||
|
|
||||||
|
answer = TEST_ANSWER
|
||||||
|
if diarization:
|
||||||
|
answer = "SPEAKER_00|"+TEST_ANSWER
|
||||||
|
|
||||||
whisper_inferencer = WhisperFactory.create_whisper_inference(
|
whisper_inferencer = WhisperFactory.create_whisper_inference(
|
||||||
whisper_type=whisper_type,
|
whisper_type=whisper_type,
|
||||||
)
|
)
|
||||||
@@ -54,7 +59,7 @@ def test_transcribe(
|
|||||||
),
|
),
|
||||||
).to_list()
|
).to_list()
|
||||||
|
|
||||||
subtitle_str, file_path = whisper_inferencer.transcribe_file(
|
subtitle_str, file_paths = whisper_inferencer.transcribe_file(
|
||||||
[audio_path],
|
[audio_path],
|
||||||
None,
|
None,
|
||||||
"SRT",
|
"SRT",
|
||||||
@@ -62,12 +67,11 @@ def test_transcribe(
|
|||||||
gr.Progress(),
|
gr.Progress(),
|
||||||
*hparams,
|
*hparams,
|
||||||
)
|
)
|
||||||
|
subtitle = read_file(file_paths[0]).split("\n")
|
||||||
assert isinstance(subtitle_str, str) and subtitle_str
|
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
|
||||||
assert isinstance(file_path[0], str) and file_path
|
|
||||||
|
|
||||||
if not is_pytube_detected_bot():
|
if not is_pytube_detected_bot():
|
||||||
whisper_inferencer.transcribe_youtube(
|
subtitle_str, file_path = whisper_inferencer.transcribe_youtube(
|
||||||
TEST_YOUTUBE_URL,
|
TEST_YOUTUBE_URL,
|
||||||
"SRT",
|
"SRT",
|
||||||
False,
|
False,
|
||||||
@@ -75,17 +79,17 @@ def test_transcribe(
|
|||||||
*hparams,
|
*hparams,
|
||||||
)
|
)
|
||||||
assert isinstance(subtitle_str, str) and subtitle_str
|
assert isinstance(subtitle_str, str) and subtitle_str
|
||||||
assert isinstance(file_path[0], str) and file_path
|
assert os.path.exists(file_path)
|
||||||
|
|
||||||
whisper_inferencer.transcribe_mic(
|
subtitle_str, file_path = whisper_inferencer.transcribe_mic(
|
||||||
audio_path,
|
audio_path,
|
||||||
"SRT",
|
"SRT",
|
||||||
False,
|
False,
|
||||||
gr.Progress(),
|
gr.Progress(),
|
||||||
*hparams,
|
*hparams,
|
||||||
)
|
)
|
||||||
assert isinstance(subtitle_str, str) and subtitle_str
|
subtitle = read_file(file_path).split("\n")
|
||||||
assert isinstance(file_path[0], str) and file_path
|
assert calculate_wer(answer, subtitle[2].strip().replace(",", "").replace(".", "")) < 0.1
|
||||||
|
|
||||||
|
|
||||||
def download_file(url, save_dir):
|
def download_file(url, save_dir):
|
||||||
|
|||||||
Reference in New Issue
Block a user