import os
import time
import logging
from typing import List, Union, BinaryIO, Optional, Tuple

import numpy as np
import torch

from modules.utils.paths import DIARIZATION_MODELS_DIR
from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
from modules.diarize.audio_loader import load_audio
from modules.whisper.data_classes import *


class Diarizer:
    def __init__(self,
                 model_dir: str = DIARIZATION_MODELS_DIR
                 ):
        self.device = self.get_device()
        self.available_device = self.get_available_device()
        self.compute_type = "float16"
        self.model_dir = model_dir
        os.makedirs(self.model_dir, exist_ok=True)
        # The pyannote pipeline is loaded lazily by update_pipe() on the first run() call.
        self.pipe = None

    def run(self,
            audio: Union[str, BinaryIO, np.ndarray],
            transcribed_result: List[Segment],
            use_auth_token: str,
            device: Optional[str] = None
            ) -> Tuple[List[Segment], float]:
        """
        Diarize a transcribed result as a post-processing step.

        Parameters
        ----------
        audio: Union[str, BinaryIO, np.ndarray]
            Audio input. This can be a file path, a binary stream, or a numpy array.
        transcribed_result: List[Segment]
            Transcribed result produced by Whisper.
        use_auth_token: str
            Hugging Face token with READ permission. This is only needed the first time the model is downloaded.
            You must visit https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to its terms of use before the model can be downloaded.
        device: Optional[str]
            Device to run diarization on.

        Returns
        -------
        segments_result: List[Segment]
            List of Segment objects that include start and end timestamps and the transcribed text.
        elapsed_time: float
            Elapsed time of the run in seconds.
        """
        start_time = time.time()

        if device is None:
            device = self.device

        # Rebuild the pipeline if the requested device changed or it has not been loaded yet.
        if device != self.device or self.pipe is None:
            self.update_pipe(
                device=device,
                use_auth_token=use_auth_token
            )

        audio = load_audio(audio)

        diarization_segments = self.pipe(audio)
        diarized_result = assign_word_speakers(
            diarization_segments,
            {"segments": transcribed_result}
        )

        segments_result = []
        for segment in diarized_result["segments"]:
            speaker = "None"
            if "speaker" in segment:
                speaker = segment["speaker"]
            # Prefix each segment's text with its speaker label, separated by "|".
            diarized_text = speaker + "|" + segment["text"].strip()
            segments_result.append(Segment(
                start=segment["start"],
                end=segment["end"],
                text=diarized_text
            ))

        elapsed_time = time.time() - start_time
        return segments_result, elapsed_time

    def update_pipe(self,
                    use_auth_token: str,
                    device: str
                    ):
        """
        Set up the pipeline for diarization.

        Parameters
        ----------
        use_auth_token: str
            Hugging Face token with READ permission. This is only needed the first time the model is downloaded.
            You must visit https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to its terms of use before the model can be downloaded.
        device: str
            Device to run diarization on.
        """
        self.device = device

        os.makedirs(self.model_dir, exist_ok=True)

        if (not os.listdir(self.model_dir) and
                not use_auth_token):
            print(
                "\nFailed to diarize. You need a Hugging Face token, and you must agree to the model's terms of use to download the diarization model.\n"
                "Go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and follow the instructions there to download the model.\n"
            )
            return

        logger = logging.getLogger("speechbrain.utils.train_logger")
        # Temporarily silence a redundant torchvision warning emitted while the pipeline loads.
        logger.disabled = True
        self.pipe = DiarizationPipeline(
            use_auth_token=use_auth_token,
            device=device,
            cache_dir=self.model_dir
        )
        logger.disabled = False

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
        elif torch.backends.mps.is_available():
            return "mps"
        else:
            return "cpu"

    @staticmethod
    def get_available_device():
        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        elif torch.backends.mps.is_available():
            devices.append("mps")
        return devices
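

# --- Usage sketch (not part of the original module) ---------------------------
# A minimal example of how Diarizer.run() might be called once a Whisper
# transcription has produced Segment objects. It assumes Segment (imported above
# from modules.whisper.data_classes) accepts start/end/text keyword arguments, as
# used inside run(). The audio path and the HF_TOKEN environment variable name
# are placeholders, not part of the original code.
if __name__ == "__main__":
    diarizer = Diarizer()

    # Hypothetical transcription result; in practice this comes from the Whisper pipeline.
    transcribed = [
        Segment(start=0.0, end=2.5, text="Hello, how are you?"),
        Segment(start=2.5, end=5.0, text="I'm fine, thanks."),
    ]

    segments, elapsed = diarizer.run(
        audio="sample.wav",                             # placeholder audio file path
        transcribed_result=transcribed,
        use_auth_token=os.environ.get("HF_TOKEN", ""),  # placeholder token source
    )

    for seg in segments:
        # run() prefixes each text with "<speaker>|", so the label is already embedded.
        print(f"[{seg.start:.2f}-{seg.end:.2f}] {seg.text}")

    print(f"Diarization took {elapsed:.2f}s")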