zhuyunfeng
2023-05-09 b15db52e4e67da8a133a67e8ffa415386de48b40
funasr/bin/eend_ola_inference.py
@@ -16,8 +16,8 @@
import numpy as np
import torch
from typeguard import check_argument_types
from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import EENDOLADiarTask
@@ -28,13 +28,14 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
class Speech2Diarization:
    """Speech2Diarlization class
    Examples:
        >>> import soundfile
        >>> import numpy as np
        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2diar(audio, profile)
@@ -157,6 +158,8 @@
        **kwargs,
):
    assert check_argument_types()
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if batch_size > 1:
        raise NotImplementedError("batch decoding is not implemented")
    if ngpu > 1:
@@ -208,7 +211,7 @@
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs[0], "speech", "bytes"]
            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
@@ -237,7 +240,7 @@
            results = speech2diar(**batch)
            # post process
            a = results[0].cpu().numpy()
            a = results[0][0].cpu().numpy()
            a = medfilt(a, (11, 1))
            rst = []
            for spkid, frames in enumerate(a.T):
@@ -246,8 +249,8 @@
                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
                for s, e in zip(changes[::2], changes[1::2]):
                    st = s / 10.
                    ed = e / 10.
                    rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid))))
                    dur = (e - s) / 10.
                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
            # Only supporting batch_size==1
            value = "\n".join(rst)