huangmingming
2023-03-16 0f3b189d3563e9b048ad9642d700e74dc2f73b1e
funasr/bin/eend_ola_inference.py
@@ -16,6 +16,7 @@
import numpy as np
import torch
from scipy.signal import medfilt
from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
@@ -34,7 +35,7 @@
    Examples:
        >>> import soundfile
        >>> import numpy as np
        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pth")
        >>> speech2diar = Speech2Diarization("diar_sond_config.yml", "diar_sond.pb")
        >>> profile = np.load("profiles.npy")
        >>> audio, rate = soundfile.read("speech.wav")
        >>> speech2diar(audio, profile)
@@ -146,7 +147,7 @@
        output_dir: Optional[str] = None,
        batch_size: int = 1,
        dtype: str = "float32",
        ngpu: int = 0,
        ngpu: int = 1,
        num_workers: int = 0,
        log_level: Union[int, str] = "INFO",
        key_file: Optional[str] = None,
@@ -179,7 +180,6 @@
        diar_model_file=diar_model_file,
        device=device,
        dtype=dtype,
        streaming=streaming,
    )
    logging.info("speech2diarization_kwargs: {}".format(speech2diar_kwargs))
    speech2diar = Speech2Diarization.from_pretrained(
@@ -209,7 +209,7 @@
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
            data_path_and_name_and_type = [raw_inputs[0], "speech", "sound"]
        loader = EENDOLADiarTask.build_streaming_iterator(
            data_path_and_name_and_type,
            dtype=dtype,
@@ -236,9 +236,23 @@
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            results = speech2diar(**batch)
            # post process
            a = results[0][0].cpu().numpy()
            a = medfilt(a, (11, 1))
            rst = []
            for spkid, frames in enumerate(a.T):
                frames = np.pad(frames, (1, 1), 'constant')
                changes, = np.where(np.diff(frames, axis=0) != 0)
                fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
                for s, e in zip(changes[::2], changes[1::2]):
                    st = s / 10.
                    dur = (e - s) / 10.
                    rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
            # Only supporting batch_size==1
            key, value = keys[0], output_results_str(results, keys[0])
            item = {"key": key, "value": value}
            value = "\n".join(rst)
            item = {"key": keys[0], "value": value}
            result_list.append(item)
            if output_path is not None:
                output_writer.write(value)