Merge pull request #244 from jhj0517/fix/device

Fix device bug
This commit is contained in:
jhj0517
2024-08-30 15:21:31 +09:00
committed by GitHub
2 changed files with 8 additions and 6 deletions

View File

@@ -1,6 +1,6 @@
import os
import torch
from typing import List, Union, BinaryIO
from typing import List, Union, BinaryIO, Optional
import numpy as np
import time
import logging
@@ -24,7 +24,7 @@ class Diarizer:
audio: Union[str, BinaryIO, np.ndarray],
transcribed_result: List[dict],
use_auth_token: str,
device: str
device: Optional[str] = None
):
"""
Diarize transcribed result as a post-processing
@@ -38,7 +38,7 @@ class Diarizer:
use_auth_token: str
Huggingface token with READ permission. This is only needed the first time you download the model.
You must manually go to the website https://huggingface.co/pyannote/speaker-diarization-3.1 and agree to their TOS to download the model.
device: str
device: Optional[str]
Device for diarization.
Returns
@@ -50,8 +50,10 @@ class Diarizer:
"""
start_time = time.time()
if (device != self.device
or self.pipe is None):
if device is None:
device = self.device
if device != self.device or self.pipe is None:
self.update_pipe(
device=device,
use_auth_token=use_auth_token
@@ -89,6 +91,7 @@ class Diarizer:
device: str
Device for diarization.
"""
self.device = device
os.makedirs(self.model_dir, exist_ok=True)

View File

@@ -130,7 +130,6 @@ class WhisperBase(ABC):
audio=audio,
use_auth_token=params.hf_token,
transcribed_result=result,
device=self.device
)
elapsed_time += elapsed_time_diarization
return result, elapsed_time