From cc2c1d1d53dea5d2c45f858d1baa5bd279f47987 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: 星期三, 31 五月 2023 14:39:25 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR

---
 funasr/bin/diar_infer.py |  350 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 350 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/diar_infer.py b/funasr/bin/diar_infer.py
new file mode 100755
index 0000000..4460e3d
--- /dev/null
+++ b/funasr/bin/diar_infer.py
@@ -0,0 +1,350 @@
+# -*- encoding: utf-8 -*-
+#!/usr/bin/env python3
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import argparse
+import logging
+import os
+import sys
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+
+from collections import OrderedDict
+import numpy as np
+import soundfile
+import torch
+from torch.nn import functional as F
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.tasks.diar import DiarTask
+from funasr.tasks.diar import EENDOLADiarTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from scipy.ndimage import median_filter
+from funasr.utils.misc import statistic_model_parameters
+from funasr.datasets.iterable_dataset import load_bytes
+from funasr.models.frontend.wav_frontend import WavFrontendMel23
+
+class Speech2DiarizationEEND:
+    """Speech2Diarlization class
+
+    Examples:
+        >>> import soundfile
+        >>> import numpy as np
+        >>> speech2diar = Speech2DiarizationEEND("diar_sond_config.yml", "diar_sond.pb")
+        >>> profile = np.load("profiles.npy")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2diar(audio, profile)
+        {"spk1": [(int, int), ...], ...}
+
+    """
+
+    def __init__(
+            self,
+            diar_train_config: Union[Path, str] = None,
+            diar_model_file: Union[Path, str] = None,
+            device: str = "cpu",
+            dtype: str = "float32",
+    ):
+        assert check_argument_types()
+
+        # 1. Build Diarization model
+        diar_model, diar_train_args = EENDOLADiarTask.build_model_from_file(
+            config_file=diar_train_config,
+            model_file=diar_model_file,
+            device=device
+        )
+        frontend = None
+        if diar_train_args.frontend is not None and diar_train_args.frontend_conf is not None:
+            frontend = WavFrontendMel23(**diar_train_args.frontend_conf)
+
+        # set up seed for eda
+        np.random.seed(diar_train_args.seed)
+        torch.manual_seed(diar_train_args.seed)
+        torch.cuda.manual_seed(diar_train_args.seed)
+        os.environ['PYTORCH_SEED'] = str(diar_train_args.seed)
+        logging.info("diar_model: {}".format(diar_model))
+        logging.info("diar_train_args: {}".format(diar_train_args))
+        diar_model.to(dtype=getattr(torch, dtype)).eval()
+
+        self.diar_model = diar_model
+        self.diar_train_args = diar_train_args
+        self.device = device
+        self.dtype = dtype
+        self.frontend = frontend
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            speech: Union[torch.Tensor, np.ndarray],
+            speech_lengths: Union[torch.Tensor, np.ndarray] = None
+    ):
+        """Inference
+
+        Args:
+            speech: Input speech data
+        Returns:
+            diarization results
+
+        """
+        assert check_argument_types()
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            self.diar_model.frontend = None
+        else:
+            feats = speech
+            feats_len = speech_lengths
+        batch = {"speech": feats, "speech_lengths": feats_len}
+        batch = to_device(batch, device=self.device)
+        results = self.diar_model.estimate_sequential(**batch)
+
+        return results
+
+    @staticmethod
+    def from_pretrained(
+            model_tag: Optional[str] = None,
+            **kwargs: Optional[Any],
+    ):
+        """Build Speech2Diarization instance from the pretrained model.
+
+        Args:
+            model_tag (Optional[str]): Model tag of the pretrained models.
+                Currently, the tags of espnet_model_zoo are supported.
+
+        Returns:
+            Speech2Diarization: Speech2Diarization instance.
+
+        """
+        if model_tag is not None:
+            try:
+                from espnet_model_zoo.downloader import ModelDownloader
+
+            except ImportError:
+                logging.error(
+                    "`espnet_model_zoo` is not installed. "
+                    "Please install via `pip install -U espnet_model_zoo`."
+                )
+                raise
+            d = ModelDownloader()
+            kwargs.update(**d.download_and_unpack(model_tag))
+
+        return Speech2DiarizationEEND(**kwargs)
+
+
+class Speech2DiarizationSOND:
+    """Speech2Xvector class
+
+    Examples:
+        >>> import soundfile
+        >>> import numpy as np
+        >>> speech2diar = Speech2DiarizationSOND("diar_sond_config.yml", "diar_sond.pb")
+        >>> profile = np.load("profiles.npy")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> speech2diar(audio, profile)
+        {"spk1": [(int, int), ...], ...}
+
+    """
+
+    def __init__(
+            self,
+            diar_train_config: Union[Path, str] = None,
+            diar_model_file: Union[Path, str] = None,
+            device: Union[str, torch.device] = "cpu",
+            batch_size: int = 1,
+            dtype: str = "float32",
+            streaming: bool = False,
+            smooth_size: int = 83,
+            dur_threshold: float = 10,
+    ):
+        assert check_argument_types()
+
+        # TODO: 1. Build Diarization model
+        diar_model, diar_train_args = DiarTask.build_model_from_file(
+            config_file=diar_train_config,
+            model_file=diar_model_file,
+            device=device
+        )
+        logging.info("diar_model: {}".format(diar_model))
+        logging.info("model parameter number: {}".format(statistic_model_parameters(diar_model)))
+        logging.info("diar_train_args: {}".format(diar_train_args))
+        diar_model.to(dtype=getattr(torch, dtype)).eval()
+
+        self.diar_model = diar_model
+        self.diar_train_args = diar_train_args
+        self.token_list = diar_train_args.token_list
+        self.smooth_size = smooth_size
+        self.dur_threshold = dur_threshold
+        self.device = device
+        self.dtype = dtype
+
+    def smooth_multi_labels(self, multi_label):
+        multi_label = median_filter(multi_label, (self.smooth_size, 1), mode="constant", cval=0.0).astype(int)
+        return multi_label
+
+    @staticmethod
+    def calc_spk_turns(label_arr, spk_list):
+        turn_list = []
+        length = label_arr.shape[0]
+        n_spk = label_arr.shape[1]
+        for k in range(n_spk):
+            if spk_list[k] == "None":
+                continue
+            in_utt = False
+            start = 0
+            for i in range(length):
+                if label_arr[i, k] == 1 and in_utt is False:
+                    start = i
+                    in_utt = True
+                if label_arr[i, k] == 0 and in_utt is True:
+                    turn_list.append([spk_list[k], start, i - start])
+                    in_utt = False
+            if in_utt:
+                turn_list.append([spk_list[k], start, length - start])
+        return turn_list
+
+    @staticmethod
+    def seq2arr(seq, vec_dim=8):
+        def int2vec(x, vec_dim=8, dtype=np.int):
+            b = ('{:0' + str(vec_dim) + 'b}').format(x)
+            # little-endian order: lower bit first
+            return (np.array(list(b)[::-1]) == '1').astype(dtype)
+
+        # process oov
+        seq = np.array([int(x) for x in seq])
+        new_seq = []
+        for i, x in enumerate(seq):
+            if x < 2 ** vec_dim:
+                new_seq.append(x)
+            else:
+                idx_list = np.where(seq < 2 ** vec_dim)[0]
+                if len(idx_list) > 0:
+                    idx = np.abs(idx_list - i).argmin()
+                    new_seq.append(seq[idx_list[idx]])
+                else:
+                    new_seq.append(0)
+        return np.row_stack([int2vec(x, vec_dim) for x in new_seq])
+
+    def post_processing(self, raw_logits: torch.Tensor, spk_num: int, output_format: str = "speaker_turn"):
+        logits_idx = raw_logits.argmax(-1)  # B, T, vocab_size -> B, T
+        # upsampling outputs to match inputs
+        ut = logits_idx.shape[1] * self.diar_model.encoder.time_ds_ratio
+        logits_idx = F.upsample(
+            logits_idx.unsqueeze(1).float(),
+            size=(ut, ),
+            mode="nearest",
+        ).squeeze(1).long()
+        logits_idx = logits_idx[0].tolist()
+        pse_labels = [self.token_list[x] for x in logits_idx]
+        if output_format == "pse_labels":
+            return pse_labels, None
+
+        multi_labels = self.seq2arr(pse_labels, spk_num)[:, :spk_num]  # remove padding speakers
+        multi_labels = self.smooth_multi_labels(multi_labels)
+        if output_format == "binary_labels":
+            return multi_labels, None
+
+        spk_list = ["spk{}".format(i + 1) for i in range(spk_num)]
+        spk_turns = self.calc_spk_turns(multi_labels, spk_list)
+        results = OrderedDict()
+        for spk, st, dur in spk_turns:
+            if spk not in results:
+                results[spk] = []
+            if dur > self.dur_threshold:
+                results[spk].append((st, st+dur))
+
+        # sort segments in start time ascending
+        for spk in results:
+            results[spk] = sorted(results[spk], key=lambda x: x[0])
+
+        return results, pse_labels
+
+    @torch.no_grad()
+    def __call__(
+            self,
+            speech: Union[torch.Tensor, np.ndarray],
+            profile: Union[torch.Tensor, np.ndarray],
+            output_format: str = "speaker_turn"
+    ):
+        """Inference
+
+        Args:
+            speech: Input speech data
+            profile: Speaker profiles
+        Returns:
+            diarization results for each speaker
+
+        """
+        assert check_argument_types()
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        if isinstance(profile, np.ndarray):
+            profile = torch.tensor(profile)
+
+        # data: (Nsamples,) -> (1, Nsamples)
+        speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        profile = profile.unsqueeze(0).to(getattr(torch, self.dtype))
+        # lengths: (1,)
+        speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
+        profile_lengths = profile.new_full([1], dtype=torch.long, fill_value=profile.size(1))
+        batch = {"speech": speech, "speech_lengths": speech_lengths,
+                 "profile": profile, "profile_lengths": profile_lengths}
+        # a. To device
+        batch = to_device(batch, device=self.device)
+
+        logits = self.diar_model.prediction_forward(**batch)
+        results, pse_labels = self.post_processing(logits, profile.shape[1], output_format)
+
+        return results, pse_labels
+
+    @staticmethod
+    def from_pretrained(
+            model_tag: Optional[str] = None,
+            **kwargs: Optional[Any],
+    ):
+        """Build Speech2Xvector instance from the pretrained model.
+
+        Args:
+            model_tag (Optional[str]): Model tag of the pretrained models.
+                Currently, the tags of espnet_model_zoo are supported.
+
+        Returns:
+            Speech2Xvector: Speech2Xvector instance.
+
+        """
+        if model_tag is not None:
+            try:
+                from espnet_model_zoo.downloader import ModelDownloader
+
+            except ImportError:
+                logging.error(
+                    "`espnet_model_zoo` is not installed. "
+                    "Please install via `pip install -U espnet_model_zoo`."
+                )
+                raise
+            d = ModelDownloader()
+            kwargs.update(**d.download_and_unpack(model_tag))
+
+        return Speech2DiarizationSOND(**kwargs)
+
+
+
+

--
Gitblit v1.9.1