shixian.shi
2023-10-10 8a0930d682fe3206e0b41c694fc03d7d10c7eed2
paraformer-speaker inference pipeline
1个文件已修改
3个文件已添加
1474 ■■■■■ 已修改文件
funasr/bin/asr_inference_launch.py 363 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/cluster_backend.py 191 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/modelscope_file.py 328 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/speaker_utils.py 592 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py
@@ -5,6 +5,7 @@
import argparse
import logging
from optparse import Option
import os
import sys
import time
@@ -45,6 +46,15 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.vad_utils import slice_padding_fbank
from funasr.utils.speaker_utils import (check_audio_list,
                                        sv_preprocess,
                                        sv_chunk,
                                        CAMPPlus,
                                        extract_feature,
                                        postprocess,
                                        distribute_spk)
from funasr.build_utils.build_model_from_file import build_model_from_file
from funasr.utils.cluster_backend import ClusterBackend
from tqdm import tqdm
def inference_asr(
@@ -739,6 +749,342 @@
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        torch.cuda.empty_cache()
        return asr_result_list
    return _forward
def inference_paraformer_vad_speaker(
        maxlenratio: float,
        minlenratio: float,
        batch_size: int,
        beam_size: int,
        ngpu: int,
        ctc_weight: float,
        lm_weight: float,
        penalty: float,
        log_level: Union[int, str],
        # data_path_and_name_and_type,
        asr_train_config: Optional[str],
        asr_model_file: Optional[str],
        cmvn_file: Optional[str] = None,
        lm_train_config: Optional[str] = None,
        lm_file: Optional[str] = None,
        token_type: Optional[str] = None,
        key_file: Optional[str] = None,
        word_lm_train_config: Optional[str] = None,
        bpemodel: Optional[str] = None,
        allow_variable_data_keys: bool = False,
        output_dir: Optional[str] = None,
        dtype: str = "float32",
        seed: int = 0,
        ngram_weight: float = 0.9,
        nbest: int = 1,
        num_workers: int = 1,
        vad_infer_config: Optional[str] = None,
        vad_model_file: Optional[str] = None,
        vad_cmvn_file: Optional[str] = None,
        time_stamp_writer: bool = True,
        punc_infer_config: Optional[str] = None,
        punc_model_file: Optional[str] = None,
        sv_model_file: Optional[str] = None,
        streaming: bool = False,
        embedding_node: str = "resnet1_dense",
        sv_threshold: float = 0.9465,
        outputs_dict: Optional[bool] = True,
        param_dict: dict = None,
        **kwargs,
):
    """Build a VAD + Paraformer ASR + punctuation + speaker pipeline.

    The returned ``_forward`` callable processes each utterance produced by
    the data loader as follows:

    1. It runs VAD segmentation.
    2. It extracts CAMPPlus speaker embeddings from the VAD segments and
       clusters them with ``ClusterBackend``.
    3. It decodes the VAD segments in length-sorted batches with Paraformer.
    4. It optionally restores punctuation.
    5. It hands the sentence list and the speaker output to ``distribute_spk``.

    Args:
        sv_model_file: Path to a CAMPPlus state dict loaded with
            ``torch.load``. It is required for speaker processing; ``_forward``
            raises ``ValueError`` when it is missing.
        streaming, embedding_node, sv_threshold, outputs_dict,
        time_stamp_writer, batch_size, allow_variable_data_keys:
            Accepted for interface compatibility; not used here.
        Remaining arguments configure the VAD, ASR, LM and punctuation models
        exactly as in the other ``inference_paraformer_*`` builders.

    Returns:
        ``_forward(data_path_and_name_and_type, raw_inputs=None,
        output_dir_v2=None, fs=None, param_dict=None, **kwargs)``. It returns
        a list with one dict per utterance, containing ``key``, ``value`` and
        ``sentences``, plus ``text_postprocessed``/``time_stamp`` when they
        are non-empty.
    """
    ncpu = kwargs.get("ncpu", 1)
    torch.set_num_threads(ncpu)
    if word_lm_train_config is not None:
        raise NotImplementedError("Word LM is not implemented")
    if ngpu > 1:
        raise NotImplementedError("only single GPU decoding is supported")
    logging.basicConfig(
        level=log_level,
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )
    if param_dict is not None:
        hotword_list_or_file = param_dict.get('hotword')
    else:
        hotword_list_or_file = None
    if ngpu >= 1 and torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    # 1. Set random-seed
    set_all_random_seed(seed)
    # 2. Build speech2vadsegment
    speech2vadsegment_kwargs = dict(
        vad_infer_config=vad_infer_config,
        vad_model_file=vad_model_file,
        vad_cmvn_file=vad_cmvn_file,
        device=device,
        dtype=dtype,
    )
    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
    # 3. Build speech2text
    speech2text_kwargs = dict(
        asr_train_config=asr_train_config,
        asr_model_file=asr_model_file,
        cmvn_file=cmvn_file,
        lm_train_config=lm_train_config,
        lm_file=lm_file,
        token_type=token_type,
        bpemodel=bpemodel,
        device=device,
        maxlenratio=maxlenratio,
        minlenratio=minlenratio,
        dtype=dtype,
        beam_size=beam_size,
        ctc_weight=ctc_weight,
        lm_weight=lm_weight,
        ngram_weight=ngram_weight,
        penalty=penalty,
        nbest=nbest,
        hotword_list_or_file=hotword_list_or_file,
    )
    speech2text = Speech2TextParaformer(**speech2text_kwargs)
    text2punc = None
    if punc_model_file is not None:
        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
    # 4. Build the speaker-verification model and clustering backend ONCE,
    #    instead of reloading the checkpoint for every utterance.
    sv_model = None
    if sv_model_file is not None:
        sv_model = CAMPPlus()
        sv_model.load_state_dict(torch.load(sv_model_file, map_location=torch.device('cpu')))
        sv_model.eval()
    cb_model = ClusterBackend()
    if output_dir is not None:
        writer = DatadirWriter(output_dir)
        ibest_writer = writer[f"1best_recog"]
        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)

    def _diarize(audio, vadsegments):
        """Cluster one utterance's VAD segments into speakers.

        ``vadsegments`` holds ``[start_ms, end_ms]`` pairs; the audio is sliced
        assuming a hard-coded 16 kHz sample rate. Returns the result of
        ``postprocess(segments, vad_segments, labels, embeddings)``.
        """
        if sv_model is None:
            raise ValueError("sv_model_file is required for paraformer_vad_speaker inference")
        vad_segments = []
        for vadsegment in vadsegments:
            st = int(vadsegment[0]) / 1000
            ed = int(vadsegment[1]) / 1000
            vad_segments.append([st, ed, audio[int(st * 16000):int(ed * 16000)]])
        check_audio_list(vad_segments)
        segments = sv_chunk(vad_segments)
        embeddings = []
        # Inference only: no autograd graph is needed for the embeddings.
        with torch.no_grad():
            for seg in segments:
                wavs = sv_preprocess([seg[2]])
                embs = torch.cat([sv_model(extract_feature([x])) for x in wavs])
                embeddings.append(embs.detach().numpy())
        embeddings = np.concatenate(embeddings)
        labels = cb_model(embeddings)
        return postprocess(segments, vad_segments, labels, embeddings)

    def _forward(data_path_and_name_and_type,
                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
                 output_dir_v2: Optional[str] = None,
                 fs: dict = None,
                 param_dict: dict = None,
                 **kwargs,
                 ):
        hotword_list_or_file = None
        if param_dict is not None:
            hotword_list_or_file = param_dict.get('hotword')
        if 'hotword' in kwargs:
            hotword_list_or_file = kwargs['hotword']
        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
        batch_size_token_threshold_s = kwargs.get(
            "batch_size_token_threshold_s",
            int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time * 0.67 / 1000)) * 1000
        batch_size_token = kwargs.get("batch_size_token", 6000)
        logging.info("batch_size_token: {}".format(batch_size_token))
        if speech2text.hotword_list is None:
            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
        # Build data-iterator
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, torch.Tensor):
                raw_inputs = raw_inputs.numpy()
            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
        loader = build_streaming_iterator(
            task_name="asr",
            preprocess_args=None,
            data_path_and_name_and_type=data_path_and_name_and_type,
            dtype=dtype,
            fs=fs,
            batch_size=1,
            key_file=key_file,
            num_workers=num_workers,
        )
        if param_dict is not None:
            use_timestamp = param_dict.get('use_timestamp', True)
        else:
            use_timestamp = True
        asr_result_list = []
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        writer = None
        if output_path is not None:
            writer = DatadirWriter(output_path)
            ibest_writer = writer[f"1best_recog"]
        for keys, batch in loader:
            assert isinstance(batch, dict), type(batch)
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            key = keys[0]
            beg_vad = time.time()
            vad_results = speech2vadsegment(**batch)
            logging.info("time cost vad: {}".format(time.time() - beg_vad))
            vadsegments = vad_results[1][0]
            n = len(vadsegments)
            if n == 0:
                # No active segments after VAD. Check before speaker processing,
                # which cannot handle an empty segment list.
                if writer is not None:
                    # Write empty results
                    ibest_writer["token"][key] = ""
                    ibest_writer["token_int"][key] = ""
                    ibest_writer["vad"][key] = ""
                    ibest_writer["text"][key] = ""
                    ibest_writer["text_with_punc"][key] = ""
                    if use_timestamp:
                        ibest_writer["time_stamp"][key] = ""
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue

            # Speaker verification + clustering on the VAD segments.
            sv_output = _diarize(batch['speech'].numpy().reshape(-1), vadsegments)

            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
            # Decode segments shortest-first so that batches have similar
            # lengths; remember each segment's original index to restore order.
            data_with_index = [(vadsegments[i], i) for i in range(n)]
            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
            results_sorted = []
            batch_size_token_ms = batch_size_token * 60
            if speech2text.device == "cpu":
                batch_size_token_ms = 0
            # A batch must be able to hold at least the shortest segment.
            batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
            batch_size_token_ms_cum = 0
            beg_idx = 0
            beg_asr_total = time.time()
            for j in tqdm(range(n)):
                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                # Keep accumulating while the next segment still fits in the
                # batch budget and is shorter than the per-segment threshold.
                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
                    continue
                batch_size_token_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                beg_idx = end_idx
                asr_batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                asr_batch = to_device(asr_batch, device=device)
                results = speech2text(**asr_batch)
                if len(results) < 1:
                    results = [["", [], [], [], [], [], []]]
                results_sorted.extend(results)
            logging.info("total time cost asr: {}".format(time.time() - beg_asr_total))
            restored_data = [0] * n
            for j in range(n):
                restored_data[sorted_data[j][1]] = results_sorted[j]
            result = ["", [], [], [], [], [], []]
            for j in range(n):
                result[0] += restored_data[j][0]
                result[1] += restored_data[j][1]
                result[2] += restored_data[j][2]
                if len(restored_data[j][4]) > 0:
                    # Shift segment-relative timestamps by the segment start.
                    for t in restored_data[j][4]:
                        t[0] += vadsegments[j][0]
                        t[1] += vadsegments[j][0]
                    result[4] += restored_data[j][4]
            text, token, token_int = result[0], result[1], result[2]
            time_stamp = result[4] if len(result[4]) > 0 else None
            if use_timestamp and time_stamp is not None and len(time_stamp):
                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
            else:
                postprocessed_result = postprocess_utils.sentence_postprocess(token)
            time_stamp_postprocessed = ""
            if len(postprocessed_result) == 3:
                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
                                                                           postprocessed_result[1], \
                                                                           postprocessed_result[2]
            else:
                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
            text_postprocessed_punc = text_postprocessed
            punc_id_list = []
            if len(word_lists) > 0 and text2punc is not None:
                beg_punc = time.time()
                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
                logging.info("time cost punc: {}".format(time.time() - beg_punc))
            item = {'key': key, 'value': text_postprocessed_punc}
            if text_postprocessed != "":
                item['text_postprocessed'] = text_postprocessed
            if time_stamp_postprocessed != "":
                item['time_stamp'] = time_stamp_postprocessed
            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
            # Attach this utterance's speaker output to its own sentences.
            # The return value is ignored, so distribute_spk presumably
            # mutates the sentence list in place.
            distribute_spk(item['sentences'], sv_output)
            asr_result_list.append(item)
            if writer is not None:
                # Write the result to each file
                ibest_writer["token"][key] = " ".join(token)
                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                ibest_writer["vad"][key] = "{}".format(vadsegments)
                ibest_writer["text"][key] = " ".join(word_lists)
                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
                if time_stamp_postprocessed is not None:
                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
        torch.cuda.empty_cache()
        return asr_result_list
    return _forward
@@ -1684,6 +2030,8 @@
        return inference_paraformer(**kwargs)
    elif mode == "paraformer_streaming":
        return inference_paraformer_online(**kwargs)
    elif mode == "paraformer_vad_speaker":
        return inference_paraformer_vad_speaker(**kwargs)
    elif mode.startswith("paraformer_vad"):
        return inference_paraformer_vad_punc(**kwargs)
    elif mode == "mfcca":
@@ -1782,6 +2130,16 @@
        help="VAD model parameter file",
    )
    group.add_argument(
        "--punc_infer_config",
        type=str,
        help="PUNC infer configuration",
    )
    group.add_argument(
        "--punc_model_file",
        type=str,
        help="PUNC model parameter file",
    )
    group.add_argument(
        "--cmvn_file",
        type=str,
        help="Global CMVN file",
@@ -1797,6 +2155,11 @@
        help="ASR model parameter file",
    )
    group.add_argument(
        "--sv_model_file",
        type=str,
        help="SV model parameter file",
    )
    group.add_argument(
        "--lm_train_config",
        type=str,
        help="LM training configuration",
funasr/utils/cluster_backend.py
New file
@@ -0,0 +1,191 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union
import hdbscan
import numpy as np
import scipy
import sklearn
import umap
from sklearn.cluster._kmeans import k_means
from torch import nn
class SpectralCluster:
    r"""A spectral clustering method using the unnormalized Laplacian of an affinity matrix.

    This implementation is adapted from https://github.com/speechbrain/speechbrain.

    Args:
        min_num_spks: Smallest cluster count considered when the count is
            estimated from the eigen-gap.
        max_num_spks: Largest cluster count considered when the count is
            estimated from the eigen-gap.
        pval: Fraction of the largest similarities kept in each row of the
            affinity matrix; at least 6 entries per row are always kept.
    """

    def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.pval = pval

    def __call__(self, X, oracle_num=None):
        """Cluster the rows of ``X``; ``oracle_num`` fixes the number of clusters."""
        # Similarity matrix computation
        sim_mat = self.get_sim_mat(X)
        # Refining similarity matrix with pval
        prunned_sim_mat = self.p_pruning(sim_mat)
        # Symmetrization
        sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
        # Laplacian calculation
        laplacian = self.get_laplacian(sym_prund_sim_mat)
        # Get Spectral Embeddings
        emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
        # Perform clustering
        return self.cluster_embs(emb, num_of_spk)

    def get_sim_mat(self, X):
        """Return the pairwise cosine-similarity matrix of the rows of ``X``."""
        # Import the submodule explicitly: a bare ``import sklearn`` does not
        # guarantee that ``sklearn.metrics`` has been loaded.
        from sklearn.metrics.pairwise import cosine_similarity
        return cosine_similarity(X, X)

    def p_pruning(self, A):
        """Zero the smallest entries of each row of ``A`` in place and return it."""
        if A.shape[0] * self.pval < 6:
            pval = 6. / A.shape[0]
        else:
            pval = self.pval
        n_elems = int((1 - pval) * A.shape[0])
        # With fewer than 6 rows n_elems is negative. Python slicing would then
        # prune almost the whole row, so prune nothing instead.
        if n_elems > 0:
            low_indexes = np.argsort(A, axis=1)[:, :n_elems]
            np.put_along_axis(A, low_indexes, 0, axis=1)
        return A

    def get_laplacian(self, M):
        """Return the unnormalized Laplacian ``D - M`` (zeroes ``M``'s diagonal in place)."""
        M[np.diag_indices(M.shape[0])] = 0
        D = np.diag(np.sum(np.abs(M), axis=1))
        return D - M

    def get_spec_embs(self, L, k_oracle=None):
        """Return ``(embeddings, num_clusters)`` from the eigenvectors of ``L``.

        Without ``k_oracle``, the cluster count is chosen at the largest gap
        between consecutive eigenvalues in ``[min_num_spks, max_num_spks]``.
        """
        from scipy.linalg import eigh
        lambdas, eig_vecs = eigh(L)
        if k_oracle is not None:
            num_of_spk = k_oracle
        else:
            lambda_gap_list = self.getEigenGaps(
                lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
            if lambda_gap_list:
                num_of_spk = int(np.argmax(lambda_gap_list)) + self.min_num_spks
            else:
                # Too few eigenvalues to measure a gap.
                num_of_spk = self.min_num_spks
        emb = eig_vecs[:, :num_of_spk]
        return emb, num_of_spk

    def cluster_embs(self, emb, k):
        """Run k-means with ``k`` clusters on ``emb`` and return the labels."""
        _, labels, _ = k_means(emb, k)
        return labels

    def getEigenGaps(self, eig_vals):
        """Return the list of differences between consecutive eigenvalues."""
        return np.diff(np.asarray(eig_vals, dtype=float)).tolist()
class UmapHdbscan:
    r"""Cluster embeddings by UMAP dimensionality reduction followed by HDBSCAN.

    Reference:
    - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
      Emphasis On Topological Structure. ICASSP2022

    Args:
        n_neighbors: UMAP neighbourhood size.
        n_components: Upper bound on the UMAP output dimension (capped at
            ``n_samples - 2`` when called).
        min_samples: HDBSCAN ``min_samples``.
        min_cluster_size: HDBSCAN ``min_cluster_size``.
        metric: Distance metric used by UMAP.
    """

    def __init__(self,
                 n_neighbors=20,
                 n_components=60,
                 min_samples=10,
                 min_cluster_size=10,
                 metric='cosine'):
        self.metric = metric
        self.min_cluster_size = min_cluster_size
        self.min_samples = min_samples
        self.n_components = n_components
        self.n_neighbors = n_neighbors

    def __call__(self, X):
        """Return one HDBSCAN label per row of ``X``."""
        target_dim = min(self.n_components, X.shape[0] - 2)
        reducer = umap.UMAP(
            n_neighbors=self.n_neighbors,
            min_dist=0.0,
            n_components=target_dim,
            metric=self.metric,
        )
        reduced = reducer.fit_transform(X)
        clusterer = hdbscan.HDBSCAN(
            min_samples=self.min_samples,
            min_cluster_size=self.min_cluster_size,
            allow_single_cluster=True)
        return clusterer.fit_predict(reduced)
class ClusterBackend(nn.Module):
    r"""Perform clustering for input embeddings and output one label per row.

    - Fewer than 20 embeddings all get label 0.
    - Spectral clustering is used for fewer than 2048 embeddings, or whenever
      ``oracle_num`` is given; otherwise UMAP + HDBSCAN is used.
    - Without ``oracle_num``, clusters whose mean embeddings have cosine
      similarity >= ``merge_thr`` are then merged.

    Args:
        merge_thr: Cosine-similarity threshold for merging clusters, in
            ``(0, 1]``. Defaults to 0.78. ``None`` disables merging.
    """

    def __init__(self, merge_thr=0.78):
        super().__init__()
        self.model_config = {} if merge_thr is None else {'merge_thr': merge_thr}
        self.spectral_cluster = SpectralCluster()
        self.umap_hdbscan_cluster = UmapHdbscan()

    def forward(self, X, **params):
        """Cluster the rows of the 2-D array ``X`` and return integer labels.

        Keyword Args:
            oracle_num: Optional known number of clusters.
        """
        k = params.get('oracle_num')
        assert len(
            X.shape
        ) == 2, 'modelscope error: the shape of input should be [N, C]'
        if X.shape[0] < 20:
            return np.zeros(X.shape[0], dtype='int')
        if X.shape[0] < 2048 or k is not None:
            labels = self.spectral_cluster(X, k)
        else:
            labels = self.umap_hdbscan_cluster(X)
        if k is None and 'merge_thr' in self.model_config:
            labels = self.merge_by_cos(labels, X,
                                       self.model_config['merge_thr'])
        return labels

    def merge_by_cos(self, labels, embs, cos_thr):
        """Repeatedly merge the two most similar clusters while similarity >= ``cos_thr``.

        ``labels`` is modified in place and returned. Labels stay contiguous:
        after a merge, every label above the absorbed one is decremented.
        """
        assert cos_thr > 0 and cos_thr <= 1
        while True:
            spk_num = labels.max() + 1
            if spk_num == 1:
                break
            spk_center = np.stack(
                [embs[labels == i].mean(0) for i in range(spk_num)], axis=0)
            norm_spk_center = spk_center / np.linalg.norm(
                spk_center, axis=1, keepdims=True)
            affinity = np.matmul(norm_spk_center, norm_spk_center.T)
            # Keep only pairs (i, j) with i < j, so kept < absorbed below.
            affinity = np.triu(affinity, 1)
            kept, absorbed = np.unravel_index(np.argmax(affinity), affinity.shape)
            if affinity[kept, absorbed] < cos_thr:
                break
            # Compute both masks before mutating so the relabel is unambiguous.
            absorbed_mask = labels == absorbed
            shift_mask = labels > absorbed
            labels[absorbed_mask] = kept
            labels[shift_mask] -= 1
        return labels
funasr/utils/modelscope_file.py
New file
@@ -0,0 +1,328 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
import contextlib
import os
import tempfile
from abc import ABCMeta, abstractmethod
from pathlib import Path
from typing import Generator, Union
import requests
class Storage(metaclass=ABCMeta):
    """Abstract class of storage.

    All backends need to implement four apis: ``read()`` and ``read_text()``
    read a file as a byte stream or as text, while ``write()`` and
    ``write_text()`` write bytes or text to a file.
    """

    @abstractmethod
    def read(self, filepath: str):
        """Return the content of ``filepath`` as bytes."""
        pass

    @abstractmethod
    def read_text(self, filepath: str, encoding: str = 'utf-8'):
        """Return the content of ``filepath`` decoded as text."""
        pass

    @abstractmethod
    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Write the bytes ``obj`` to ``filepath``."""
        pass

    @abstractmethod
    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Write the text ``obj`` to ``filepath`` using ``encoding``."""
        pass
class LocalStorage(Storage):
    """Storage backend for files on the local hard disk."""

    @staticmethod
    def _make_parent_dir(filepath):
        # Create the containing directory if it is missing; a bare file name
        # has an empty dirname and needs nothing.
        parent = os.path.dirname(filepath)
        if parent and not os.path.exists(parent):
            os.makedirs(parent, exist_ok=True)

    def read(self, filepath: Union[str, Path]) -> bytes:
        """Return the raw bytes of ``filepath`` (opened in 'rb' mode).

        Args:
            filepath (str or Path): File to read.
        Returns:
            bytes: The whole file content.
        """
        with open(filepath, 'rb') as handle:
            return handle.read()

    def read_text(self,
                  filepath: Union[str, Path],
                  encoding: str = 'utf-8') -> str:
        """Return the text of ``filepath`` (opened in 'r' mode).

        Args:
            filepath (str or Path): File to read.
            encoding (str): Text encoding used to decode the file.
                Default: 'utf-8'.
        Returns:
            str: The whole file content.
        """
        with open(filepath, 'r', encoding=encoding) as handle:
            return handle.read()

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        """Store the bytes ``obj`` in ``filepath`` (opened in 'wb' mode).

        The parent directory is created first when it does not exist.

        Args:
            obj (bytes): Payload to store.
            filepath (str or Path): Destination file.
        """
        self._make_parent_dir(filepath)
        with open(filepath, 'wb') as handle:
            handle.write(obj)

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        """Store the text ``obj`` in ``filepath`` (opened in 'w' mode).

        The parent directory is created first when it does not exist.

        Args:
            obj (str): Text to store.
            filepath (str or Path): Destination file.
            encoding (str): Text encoding used to encode ``obj``.
                Default: 'utf-8'.
        """
        self._make_parent_dir(filepath)
        with open(filepath, 'w', encoding=encoding) as handle:
            handle.write(obj)

    @contextlib.contextmanager
    def as_local_path(
            self,
            filepath: Union[str,
                            Path]) -> Generator[Union[str, Path], None, None]:
        """Yield ``filepath`` unchanged; exists only for a uniform storage API."""
        yield filepath
class HTTPStorage(Storage):
    """HTTP and HTTPS storage (read-only)."""

    def read(self, url):
        """Download ``url`` and return the response body as bytes.

        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        # TODO @wenmeng.zwm add progress bar if file is too large
        r = requests.get(url)
        r.raise_for_status()
        return r.content

    def read_text(self, url, encoding=None):
        """Download ``url`` and return the response body as text.

        Args:
            url: Address to fetch.
            encoding: Optional text encoding. When ``None`` (the default),
                the encoding that ``requests`` derives from the response is
                used.
        Raises:
            requests.HTTPError: If the server answers with an error status.
        """
        r = requests.get(url)
        r.raise_for_status()
        if encoding is not None:
            r.encoding = encoding
        return r.text

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` into a temporary local file.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It is used in a ``with`` statement, and the temporary file is removed
        when the ``with`` block exits, even when an error occurs.

        Args:
            filepath (str): Download a file from ``filepath``.
        Examples:
            >>> storage = HTTPStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('http://path/to/file') as path:
            ...     # do something here
        """
        # Create the file outside ``try``: if creation fails there is nothing
        # to clean up, so the original error surfaces instead of an
        # UnboundLocalError from ``finally``.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            # ``with`` closes the handle even if the download fails.
            with tmp:
                tmp.write(self.read(filepath))
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, url: Union[str, Path]) -> None:
        raise NotImplementedError('write is not supported by HTTP Storage')

    def write_text(self,
                   obj: str,
                   url: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        raise NotImplementedError(
            'write_text is not supported by HTTP Storage')
class OSSStorage(Storage):
    """OSS storage placeholder: every operation raises ``NotImplementedError``."""

    def __init__(self, oss_config_file=None):
        # read from config file or env var
        raise NotImplementedError(
            'OSSStorage.__init__ to be implemented in the future')

    def read(self, filepath):
        raise NotImplementedError(
            'OSSStorage.read to be implemented in the future')

    def read_text(self, filepath, encoding='utf-8'):
        raise NotImplementedError(
            'OSSStorage.read_text to be implemented in the future')

    @contextlib.contextmanager
    def as_local_path(
            self, filepath: str) -> Generator[Union[str, Path], None, None]:
        """Download a file from ``filepath`` into a temporary local file.

        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`.
        It is used in a ``with`` statement, and the temporary file is removed
        when the ``with`` block exits, even when an error occurs.

        Args:
            filepath (str): Download a file from ``filepath``.
        Examples:
            >>> storage = OSSStorage()
            >>> # After exiting from the ``with`` clause,
            >>> # the path will be removed
            >>> with storage.as_local_path('oss://path/to/file') as path:
            ...     # do something here
        """
        # Create the file outside ``try`` so that a creation failure is not
        # masked by an UnboundLocalError in ``finally``.
        tmp = tempfile.NamedTemporaryFile(delete=False)
        try:
            # ``with`` closes the handle even if ``read`` raises.
            with tmp:
                tmp.write(self.read(filepath))
            yield tmp.name
        finally:
            os.remove(tmp.name)

    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
        raise NotImplementedError(
            'OSSStorage.write to be implemented in the future')

    def write_text(self,
                   obj: str,
                   filepath: Union[str, Path],
                   encoding: str = 'utf-8') -> None:
        raise NotImplementedError(
            'OSSStorage.write_text to be implemented in the future')
# Cache of storage backend instances, keyed by storage type ('local', 'http', ...).
G_STORAGES = {}


class File(object):
    """Dispatch file operations to a storage backend chosen from the URI scheme.

    URIs without ``://`` (and ``pathlib.Path`` objects) use ``LocalStorage``;
    ``http``/``https`` URIs use ``HTTPStorage``; ``oss`` URIs use
    ``OSSStorage``. Backends are instantiated lazily and cached in
    ``G_STORAGES``.
    """
    _prefix_to_storage: dict = {
        'oss': OSSStorage,
        'http': HTTPStorage,
        'https': HTTPStorage,
        'local': LocalStorage,
    }

    @staticmethod
    def _get_storage(uri):
        # A Path can only denote a local file (Path collapses "//").
        if isinstance(uri, Path):
            uri = str(uri)
        assert isinstance(uri,
                          str), f'uri should be str type, but got {type(uri)}'
        if '://' not in uri:
            # local path
            storage_type = 'local'
        else:
            # Split on the first separator only, so that URIs whose path or
            # query contains another "://" still resolve to their scheme.
            storage_type = uri.split('://', 1)[0]
        assert storage_type in File._prefix_to_storage, \
            f'Unsupported uri {uri}, valid prefixs: '\
            f'{list(File._prefix_to_storage.keys())}'
        if storage_type not in G_STORAGES:
            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
        return G_STORAGES[storage_type]

    @staticmethod
    def read(uri: str) -> bytes:
        """Read data from a given ``uri`` as bytes.

        Args:
            uri (str or Path): Location to read data from.
        Returns:
            bytes: Expected bytes object.
        """
        storage = File._get_storage(uri)
        return storage.read(uri)

    @staticmethod
    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
        """Read data from a given ``uri`` as text.

        Args:
            uri (str or Path): Location to read data from.
            encoding (str): The encoding used to decode the data.
                Default: 'utf-8'.
        Returns:
            str: Expected text read from ``uri``.
        """
        storage = File._get_storage(uri)
        if encoding == 'utf-8':
            return storage.read_text(uri)
        # Forward a non-default encoding explicitly. The default is not
        # forwarded because not every backend's read_text accepts it.
        return storage.read_text(uri, encoding=encoding)

    @staticmethod
    def write(obj: bytes, uri: Union[str, Path]) -> None:
        """Write bytes to a given ``uri``.

        Note:
            The local backend creates the parent directory of ``uri`` if it
            does not exist.
        Args:
            obj (bytes): Data to be written.
            uri (str or Path): Destination.
        """
        storage = File._get_storage(uri)
        return storage.write(obj, uri)

    @staticmethod
    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
        """Write text to a given ``uri``.

        Note:
            The local backend creates the parent directory of ``uri`` if it
            does not exist.
        Args:
            obj (str): Data to be written.
            uri (str or Path): Destination.
            encoding (str): The encoding used to encode ``obj``.
                Default: 'utf-8'.
        """
        storage = File._get_storage(uri)
        return storage.write_text(obj, uri, encoding=encoding)

    @staticmethod
    @contextlib.contextmanager
    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
        """Yield a local path for ``uri``, downloading it first for remote backends."""
        storage = File._get_storage(uri)
        with storage.as_local_path(uri) as local_path:
            yield local_path
funasr/utils/speaker_utils.py
New file
@@ -0,0 +1,592 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
"""
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import nn
import io
import os
from typing import Any, Dict, List, Union
import numpy as np
import soundfile as sf
import torch
import torchaudio
import logging
from funasr.utils.modelscope_file import File
from collections import OrderedDict
import torchaudio.compliance.kaldi as Kaldi
def check_audio_list(audio: list, fs: int = 16000):
    """Validate a chronological list of ``[start_sec, end_sec, samples]`` segments.

    For every segment:
      * ``end >= start``;
      * ``samples`` is an ``np.ndarray``;
      * ``len(samples) == int(end * fs) - int(start * fs)``;
      * it does not start before the previous segment ends.
    The summed duration of all segments must exceed 5 seconds.

    Args:
        audio: Segments ``[start, end, data]`` in chronological order.
        fs: Sample rate used to convert time stamps into sample counts.
            Defaults to 16000, the value previously hard-coded here.

    Raises:
        AssertionError: If any check fails. It is raised explicitly rather
            than via ``assert`` so validation still happens under ``python -O``.
    """
    def _require(condition, message):
        if not condition:
            raise AssertionError(message)

    audio_dur = 0
    for i, seg in enumerate(audio):
        _require(seg[1] >= seg[0], 'modelscope error: Wrong time stamps.')
        _require(isinstance(seg[2], np.ndarray),
                 'modelscope error: Wrong data type.')
        _require(
            int(seg[1] * fs) - int(seg[0] * fs) == seg[2].shape[0],
            'modelscope error: audio data in list is inconsistent with time length.')
        if i > 0:
            # Segments must not overlap their predecessor.
            _require(seg[0] >= audio[i - 1][1],
                     'modelscope error: Wrong time stamps.')
        audio_dur += seg[1] - seg[0]
    _require(audio_dur > 5,
             'modelscope error: The effective audio duration is too short.')
def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Convert each audio input into a 1-D float32 torch tensor.

    Accepted items:
      * ``str`` -- read with ``File.read`` and decoded by soundfile as
        float32; for 2-D (multi-channel) data only the first column is kept.
        The sample rate reported by soundfile is not used or checked.
      * ``np.ndarray`` -- must be 1-D. int16/int32/int64 arrays are divided
        by 2**15; any other dtype is simply cast to float32.

    Returns:
        list of torch.Tensor, one per input item, in order.

    Raises:
        AssertionError: If an ndarray item is not 1-D.
        ValueError: For any other item type.
    """
    def _to_tensor(item):
        if isinstance(item, str):
            wav, _sample_rate = sf.read(io.BytesIO(File.read(item)),
                                        dtype='float32')
            if len(wav.shape) == 2:
                wav = wav[:, 0]
            return torch.from_numpy(wav)
        if isinstance(item, np.ndarray):
            assert len(item.shape) == 1, 'modelscope error: Input array should be [N, T]'
            is_int_pcm = item.dtype in ['int16', 'int32', 'int64']
            wav = (item / (1 << 15)).astype('float32') if is_int_pcm else item.astype('float32')
            return torch.from_numpy(wav)
        raise ValueError(
            'modelscope error: The input type is restricted to audio address and nump array.'
        )

    return [_to_tensor(item) for item in inputs]
def sv_chunk(vad_segments: list, fs=16000, seg_dur: float = 1.5,
             seg_shift: float = 0.75) -> list:
    """Split VAD segments into fixed-length, overlapping chunks.

    Args:
        vad_segments: Segments ``[start_sec, end_sec, samples]``; only
            ``start_sec`` (index 0) and ``samples`` (index 2) are read.
        fs: Sample rate of ``samples``.
        seg_dur: Chunk length in seconds (default 1.5, previously hard-coded).
        seg_shift: Hop between chunk starts in seconds (default 0.75,
            previously hard-coded).

    Returns:
        list of ``[chunk_start_sec, chunk_end_sec, chunk_samples]``. Every
        chunk holds exactly ``int(seg_dur * fs)`` samples:
          * The last chunk of a segment is shifted back so that it ends at
            the segment's end.
          * A segment shorter than one chunk is zero-padded at the end, and
            its reported end time stays at the real end of the data.

    Raises:
        ValueError: If ``seg_shift * fs`` rounds down to 0 samples.
    """
    chunk_len = int(seg_dur * fs)
    chunk_shift = int(seg_shift * fs)
    if chunk_shift <= 0:
        raise ValueError(
            'seg_shift * fs must be at least one sample, got {}'.format(
                seg_shift * fs))

    def seg_chunk(seg_data):
        seg_st = seg_data[0]
        data = seg_data[2]
        num_samples = data.shape[0]
        last_chunk_ed = 0
        seg_res = []
        for window_st in range(0, num_samples, chunk_shift):
            chunk_ed = min(window_st + chunk_len, num_samples)
            if chunk_ed <= last_chunk_ed:
                # This window adds no samples beyond the previous chunk.
                break
            last_chunk_ed = chunk_ed
            # Align the chunk to end at chunk_ed so it is full length when possible.
            chunk_st = max(0, chunk_ed - chunk_len)
            chunk_data = data[chunk_st:chunk_ed]
            if chunk_data.shape[0] < chunk_len:
                chunk_data = np.pad(chunk_data,
                                    (0, chunk_len - chunk_data.shape[0]),
                                    'constant')
            seg_res.append([
                chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
                chunk_data
            ])
        return seg_res

    segs = []
    for seg in vad_segments:
        segs.extend(seg_chunk(seg))
    return segs
class BasicResBlock(nn.Module):
    """2-D residual block: two 3x3 conv+BN layers plus a shortcut.

    ``stride`` is applied as ``(stride, 1)``, so only the first spatial axis
    is down-sampled. The shortcut is a 1x1 conv + BN projection when the
    stride or the channel count changes, and the identity otherwise. The
    attribute names (conv1, bn1, conv2, bn2, shortcut) form the state_dict
    keys.
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        out_planes = self.expansion * planes
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=(stride, 1), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride == 1 and in_planes == out_planes:
            self.shortcut = nn.Sequential()
        else:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, out_planes, kernel_size=1,
                          stride=(stride, 1), bias=False),
                nn.BatchNorm2d(out_planes))

    def forward(self, x):
        residual = self.shortcut(x)
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        return F.relu(hidden + residual)
class FCM(nn.Module):
    """Front-end convolution module (FCM) used by CAM++.

    The input feature map is treated as a 1-channel image. It goes through
    a stem conv+BN, two residual stages, and a final conv+BN. The two
    residual stages and the final conv each use stride 2 on the first
    spatial axis. Finally, the channel and first-spatial axes are flattened
    into one.

    Args:
        block: Residual block class used in both stages.
        num_blocks: Number of blocks in ``layer1`` and ``layer2``
            respectively.
        m_channels: Channel width used throughout.
        feat_dim: Size of the first spatial input axis.
            ``out_channels = m_channels * (feat_dim // 8)`` equals the
            flattened size when ``feat_dim`` is a multiple of 8.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=(2, 2),
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(
            1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)
        self.layer1 = self._make_layer(
            block, m_channels, num_blocks[0], stride=2)
        # The second stage uses num_blocks[1]; it previously reused
        # num_blocks[0], which only coincided for the default (2, 2).
        self.layer2 = self._make_layer(
            block, m_channels, num_blocks[1], stride=2)
        self.conv2 = nn.Conv2d(
            m_channels,
            m_channels,
            kernel_size=3,
            stride=(2, 1),
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks, striding only the first one.

        Updates ``self.in_planes`` to the stage's output width so the next
        stage chains correctly.
        """
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (B, D1, D2) to (B, m_channels * D1', D2).

        D1' is D1 after three stride-2 reductions.
        """
        x = x.unsqueeze(1)  # add a singleton channel axis for the 2-D convs
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
class CAMPPlus(nn.Module):
    """CAM++ speaker embedding network.

    Architecture, as assembled in ``__init__``:
      1. An ``FCM`` front-end (``self.head``).
      2. ``self.xvector``, an ``nn.Sequential`` containing:
         * a stride-2 TDNN layer;
         * three CAM dense TDNN blocks, each followed by a channel-halving
           transit layer;
         * a final nonlinearity.
      3. With ``output_level='segment'``, statistics pooling and a dense
         layer are appended to produce an ``embedding_size`` vector per
         input. With ``output_level='frame'``, frame-level features are
         returned instead.

    NOTE: the module names registered on ``self.xvector`` ('tdnn',
    'block1', 'transit1', ..., 'out_nonlinear', 'stats', 'dense') fix both
    the forward execution order and the state_dict keys, so they must match
    any checkpoint loaded into this model.
    """
    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment'):
        super(CAMPPlus, self).__init__()
        self.head = FCM(feat_dim=feat_dim)
        # FCM flattens its channels and reduced feature axis into one axis.
        channels = self.head.out_channels
        self.output_level = output_level
        # Sequential order == forward order; names become state_dict keys.
        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(
                     channels,
                     init_channels,
                     5,
                     stride=2,
                     dilation=1,
                     padding=-1,
                     config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks with (num_layers, kernel_size, dilation) of
        # (12, 3, 1), (24, 3, 2) and (16, 3, 2). Each block adds
        # num_layers * growth_rate channels; the transit layer after it
        # halves the channel count.
        for i, (num_layers, kernel_size, dilation) in enumerate(
                zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(
                    channels, channels // 2, bias=False,
                    config_str=config_str))
            channels //= 2
        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, channels))
        if self.output_level == 'segment':
            # Mean+std pooling over time doubles the channel count; the
            # dense layer then uses a BatchNorm without affine parameters
            # ('batchnorm_') to produce the embedding.
            self.xvector.add_module('stats', StatsPool())
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))
        else:
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '
        # Kaiming-normal init for every Conv1d/Linear weight, zero biases.
        # The Conv2d layers inside FCM do not match this isinstance check
        # and keep PyTorch's default initialisation.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
    def forward(self, x):
        """Compute speaker embeddings or frame-level features.

        Args:
            x: Input features, assumed (B, T, F) with F == feat_dim (per the
                permute comment below) — TODO confirm with callers.

        Returns:
            (B, embedding_size) when ``output_level == 'segment'``.
            For ``'frame'``, a (B, T', C) tensor, where T' is the time
            length after the stride-2 TDNN layer.
        """
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            x = x.transpose(1, 2)
        return x
def get_nonlinear(config_str, channels):
    """Build an ``nn.Sequential`` of activation/normalisation layers.

    ``config_str`` is a '-'-separated list of layer names, applied in order.
    Recognised names:
      * 'relu' -- in-place ReLU.
      * 'prelu' -- PReLU with ``channels`` parameters.
      * 'batchnorm' -- BatchNorm1d over ``channels``.
      * 'batchnorm_' -- BatchNorm1d with ``affine=False``, registered under
        the module name 'batchnorm'.

    Example: 'batchnorm-relu'.

    Raises:
        ValueError: For an unrecognised layer name.
    """
    builders = {
        'relu': ('relu', lambda: nn.ReLU(inplace=True)),
        'prelu': ('prelu', lambda: nn.PReLU(channels)),
        'batchnorm': ('batchnorm', lambda: nn.BatchNorm1d(channels)),
        'batchnorm_': ('batchnorm',
                       lambda: nn.BatchNorm1d(channels, affine=False)),
    }
    stack = nn.Sequential()
    for token in config_str.split('-'):
        if token not in builders:
            raise ValueError('Unexpected module ({}).'.format(token))
        module_name, build = builders[token]
        stack.add_module(module_name, build())
    return stack
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate the mean and standard deviation of ``x`` along ``dim``.

    Both statistics reduce ``dim`` and are concatenated on the last axis of
    the reduced tensor. With ``keepdim=True``, a singleton axis is
    re-inserted at ``dim``.

    NOTE(review): ``eps`` is accepted but not used (no variance floor is
    applied), so the std of a length-1 axis with ``unbiased=True`` is NaN.
    """
    pooled = torch.cat(
        [x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased)], dim=-1)
    return pooled.unsqueeze(dim=dim) if keepdim else pooled
class StatsPool(nn.Module):
    """Parameter-free module wrapper around ``statistics_pooling``.

    Pools over the last axis with the function's default settings.
    """

    def forward(self, x):
        return statistics_pooling(x, dim=-1, keepdim=False)
class TDNNLayer(nn.Module):
    """1-D convolution (``linear``) followed by ``get_nonlinear`` layers.

    A negative ``padding`` requests symmetric 'same'-style padding of
    ``(kernel_size - 1) // 2 * dilation``, which requires an odd kernel.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        if padding >= 0:
            pad = padding
        else:
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            pad = dilation * ((kernel_size - 1) // 2)
        self.linear = nn.Conv1d(in_channels, out_channels, kernel_size,
                                stride=stride, padding=pad,
                                dilation=dilation, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
def extract_feature(audio):
    """Compute per-utterance mean-normalised 80-bin Kaldi fbank features.

    Args:
        audio: Iterable of 1-D waveform tensors. ``Kaldi.fbank`` is called
            without an explicit sample frequency, so its default is used.

    Returns:
        A tensor stacking each waveform's (frames, 80) fbank matrix along a
        new leading axis, after subtracting that matrix's mean over frames.
        All inputs must produce the same number of frames.
    """
    normalised = []
    for wav in audio:
        fbank = Kaldi.fbank(wav.unsqueeze(0), num_mel_bins=80)
        normalised.append(fbank - fbank.mean(dim=0, keepdim=True))
    return torch.stack(normalised)
class CAMLayer(nn.Module):
    """Context-aware masking (CAM) layer.

    A local 1-D convolution (``linear_local``) is gated element-wise by a
    sigmoid mask. The mask is computed from a context vector: the mean over
    the last axis plus a segment-wise average (100-frame segments). That
    context passes through a 1x1-conv bottleneck of width
    ``bn_channels // reduction`` with a ReLU.
    """

    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super().__init__()
        hidden = bn_channels // reduction
        self.linear_local = nn.Conv1d(bn_channels, out_channels, kernel_size,
                                      stride=stride, padding=padding,
                                      dilation=dilation, bias=bias)
        self.linear1 = nn.Conv1d(bn_channels, hidden, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(hidden, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        local = self.linear_local(x)
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        gate = self.sigmoid(self.linear2(self.relu(self.linear1(context))))
        return local * gate

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool over non-overlapping ``seg_len`` windows on the last axis.

        Each pooled value is then repeated across its window, and the result
        is trimmed back to ``x``'s length.
        """
        poolers = {'avg': F.avg_pool1d, 'max': F.max_pool1d}
        if stype not in poolers:
            raise ValueError('Wrong segment pooling type.')
        pooled = poolers[stype](
            x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        repeated = pooled.unsqueeze(-1).expand(*pooled.shape, seg_len)
        flat = repeated.reshape(*pooled.shape[:-1], -1)
        return flat[..., :x.shape[-1]]
class CAMDenseTDNNLayer(nn.Module):
    """A single densely connected CAM-TDNN layer.

    Data path: ``nonlinear1`` -> 1x1 conv to ``bn_channels`` (``linear1``)
    -> ``nonlinear2`` -> ``CAMLayer`` producing ``out_channels``. The CAM
    convolution uses 'same'-style padding ``(kernel_size - 1) // 2 *
    dilation``, which requires an odd kernel. When training with
    ``memory_efficient=True``, the first two steps run under
    ``torch.utils.checkpoint``.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super().__init__()
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        same_pad = dilation * ((kernel_size - 1) // 2)
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(bn_channels, out_channels, kernel_size,
                                  stride=stride, padding=same_pad,
                                  dilation=dilation, bias=bias)

    def bn_function(self, x):
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        use_checkpoint = self.training and self.memory_efficient
        hidden = (cp.checkpoint(self.bn_function, x) if use_checkpoint
                  else self.bn_function(x))
        return self.cam_layer(self.nonlinear2(hidden))
class CAMDenseTDNNBlock(nn.ModuleList):
    """Dense block of ``num_layers`` ``CAMDenseTDNNLayer`` modules.

    Layer k (0-based) receives ``in_channels + k * out_channels`` channels.
    Each layer's output is concatenated onto its input along dim 1, so the
    block emits ``in_channels + num_layers * out_channels`` channels.
    Layers are registered as 'tdnnd1', 'tdnnd2', ...
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super().__init__()
        for idx in range(1, num_layers + 1):
            self.add_module(
                'tdnnd%d' % idx,
                CAMDenseTDNNLayer(
                    in_channels=in_channels + (idx - 1) * out_channels,
                    out_channels=out_channels,
                    bn_channels=bn_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=bias,
                    config_str=config_str,
                    memory_efficient=memory_efficient))

    def forward(self, x):
        features = x
        for layer in self:
            features = torch.cat([features, layer(features)], dim=1)
        return features
class TransitLayer(nn.Module):
    """``get_nonlinear`` layers followed by a 1x1 Conv1d channel projection."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
class DenseLayer(nn.Module):
    """1x1 Conv1d projection followed by ``get_nonlinear`` layers.

    A 2-D input is treated as a length-1 sequence: a trailing axis is added
    for the convolution and removed afterwards.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super().__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if x.dim() == 2:
            projected = self.linear(x.unsqueeze(-1)).squeeze(-1)
        else:
            projected = self.linear(x)
        return self.nonlinear(projected)
def postprocess(segments: list, vad_segments: list,
                labels: np.ndarray, embeddings: np.ndarray) -> list:
    """Turn per-chunk cluster labels into a speaker timeline.

    Steps:
      1. Relabel speakers in order of first appearance (``correct_labels``).
      2. Merge consecutive same-speaker chunks (``merge_seque``).
      3. Split any remaining overlap between neighbouring regions at its
         midpoint.
      4. Round times and absorb short regions into neighbours (``smooth``).

    Args:
        segments: Chunks ``[start_sec, end_sec, ...]``, one per label.
        vad_segments: Unused; kept for interface compatibility.
        labels: Cluster label of each chunk.
        embeddings: Unused; kept for interface compatibility. A per-speaker
            centroid computation on it had no consumer and was removed.

    Returns:
        list of ``[start_sec, end_sec, speaker_id]``.
    """
    assert len(segments) == len(labels)
    labels = correct_labels(labels)
    distribute_res = [[seg[0], seg[1], label]
                      for seg, label in zip(segments, labels)]
    # merge the same speakers chronologically
    distribute_res = merge_seque(distribute_res)
    # Split overlaps (beyond a 1e-4 s tolerance) at the midpoint. ``cur``
    # becomes ``prev`` on the next step; only its start was changed here.
    for prev, cur in zip(distribute_res, distribute_res[1:]):
        if prev[1] > cur[0] + 1e-4:
            midpoint = (cur[0] + prev[1]) / 2
            cur[0] = midpoint
            prev[1] = midpoint
    return smooth(distribute_res)
def correct_labels(labels):
    """Renumber labels 0, 1, 2, ... in order of first appearance.

    Example: [3, 3, 1, 2, 1] -> array([0, 0, 1, 2, 1]).
    """
    remap = {}
    for label in labels:
        remap.setdefault(label, len(remap))
    return np.array([remap[label] for label in labels])
def merge_seque(distribute_res):
    """Merge consecutive same-speaker segments that touch or overlap.

    Args:
        distribute_res: Chronologically ordered ``[start, end, label]``
            entries.

    Returns:
        A new list of ``[start, end, label]`` lists. A segment is folded
        into the previous result when it has the same label and starts no
        later than that result ends; the merged segment takes the later
        segment's end. Entries are copied, so the input's lists are not
        modified. An empty input yields an empty list instead of raising
        IndexError.
    """
    if len(distribute_res) == 0:
        return []
    res = [list(distribute_res[0])]
    for seg in distribute_res[1:]:
        last = res[-1]
        if seg[2] != last[2] or seg[0] > last[1]:
            res.append(list(seg))
        else:
            last[1] = seg[1]
    return res
def smooth(res, mindur=1):
    """Round segment times and reassign segments shorter than ``mindur``.

    Each ``[start, end, label]`` entry is rounded in place to 2 decimals. A
    segment shorter than ``mindur`` seconds takes the label of a neighbour:
      * the only neighbour, at either end of the list;
      * otherwise the neighbour with the smaller gap, preferring the
        previous one on ties.
    Labels are reassigned front to back, so an updated label is visible to
    later segments. Same-speaker neighbours are then merged with
    ``merge_seque``.

    A single segment has no neighbour and keeps its label (previously this
    raised IndexError). An empty list is returned unchanged.
    """
    if len(res) == 0:
        return res
    count = len(res)
    for i, seg in enumerate(res):
        seg[0] = round(seg[0], 2)
        seg[1] = round(seg[1], 2)
        if count == 1 or seg[1] - seg[0] >= mindur:
            continue
        # short segments are assigned to nearest speakers.
        if i == 0:
            seg[2] = res[i + 1][2]
        elif i == count - 1:
            seg[2] = res[i - 1][2]
        elif seg[0] - res[i - 1][1] <= res[i + 1][0] - seg[1]:
            seg[2] = res[i - 1][2]
        else:
            seg[2] = res[i + 1][2]
    # merge the speakers
    return merge_seque(res)
def distribute_spk(sentence_list, sd_time_list):
    """Assign each sentence the speaker it overlaps most in time.

    Args:
        sentence_list: Dicts with a ``'ts_list'`` of ``[start, end]`` pairs.
            A sentence spans from the first pair's start to the last pair's
            end. These values are compared with ``sd_time_list`` times
            multiplied by 1000 (presumably milliseconds vs. seconds —
            confirm with callers).
        sd_time_list: ``[start, end, speaker]`` diarization entries.

    Returns:
        A new list holding the same dicts, each updated in place with a
        ``'spk'`` key. On equal overlap the earlier entry wins. A sentence
        that overlaps no entry, or has an empty ``ts_list``, gets speaker 0.
    """
    # Convert diarization times once instead of once per sentence; this
    # also lets an iterator ``sd_time_list`` serve every sentence.
    sd_scaled = [(spk_st * 1000, spk_ed * 1000, spk)
                 for spk_st, spk_ed, spk in sd_time_list]
    sd_sentence_list = []
    for d in sentence_list:
        ts_list = d['ts_list']
        sentence_spk = 0
        if len(ts_list) > 0:
            sentence_start = ts_list[0][0]
            sentence_end = ts_list[-1][1]
            max_overlap = 0
            for spk_st, spk_ed, spk in sd_scaled:
                overlap = max(
                    min(sentence_end, spk_ed) - max(sentence_start, spk_st), 0)
                if overlap > max_overlap:
                    max_overlap = overlap
                    sentence_spk = spk
        d['spk'] = sentence_spk
        sd_sentence_list.append(d)
    return sd_sentence_list