From 8a0930d682fe3206e0b41c694fc03d7d10c7eed2 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期二, 10 十月 2023 11:35:42 +0800
Subject: [PATCH] paraformer-speaker inference pipeline

---
 funasr/utils/modelscope_file.py    |  328 +++++++++++++
 funasr/utils/speaker_utils.py      |  592 +++++++++++++++++++++++
 funasr/utils/cluster_backend.py    |  191 +++++++
 funasr/bin/asr_inference_launch.py |  363 ++++++++++++++
 4 files changed, 1474 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 50b9886..15dbdd4 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -5,6 +5,7 @@
 
 import argparse
 import logging
+# NOTE: dropped unused "from optparse import Option" (dead accidental import; optparse is deprecated)
 import os
 import sys
 import time
@@ -45,6 +46,15 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.vad_utils import slice_padding_fbank
+from funasr.utils.speaker_utils import (check_audio_list, 
+                                        sv_preprocess, 
+                                        sv_chunk, 
+                                        CAMPPlus, 
+                                        extract_feature, 
+                                        postprocess,
+                                        distribute_spk)
+# NOTE(review): removed "from funasr.build_utils.build_model_from_file import build_model_from_file" — unused by this patch; confirm nothing else needs it
+from funasr.utils.cluster_backend import ClusterBackend
 from tqdm import tqdm
 
 def inference_asr(
@@ -739,6 +749,342 @@
 
             logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
         torch.cuda.empty_cache()
+        return asr_result_list
+
+    return _forward
+
+
+def inference_paraformer_vad_speaker(
+        maxlenratio: float,
+        minlenratio: float,
+        batch_size: int,
+        beam_size: int,
+        ngpu: int,
+        ctc_weight: float,
+        lm_weight: float,
+        penalty: float,
+        log_level: Union[int, str],
+        # data_path_and_name_and_type,
+        asr_train_config: Optional[str],
+        asr_model_file: Optional[str],
+        cmvn_file: Optional[str] = None,
+        lm_train_config: Optional[str] = None,
+        lm_file: Optional[str] = None,
+        token_type: Optional[str] = None,
+        key_file: Optional[str] = None,
+        word_lm_train_config: Optional[str] = None,
+        bpemodel: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        ngram_weight: float = 0.9,
+        nbest: int = 1,
+        num_workers: int = 1,
+        vad_infer_config: Optional[str] = None,
+        vad_model_file: Optional[str] = None,
+        vad_cmvn_file: Optional[str] = None,
+        time_stamp_writer: bool = True,
+        punc_infer_config: Optional[str] = None,
+        punc_model_file: Optional[str] = None,
+        sv_model_file: Optional[str] = None,
+        streaming: bool = False,
+        embedding_node: str = "resnet1_dense",
+        sv_threshold: float = 0.9465,
+        outputs_dict: Optional[bool] = True,
+        param_dict: dict = None,
+
+        **kwargs,
+):
+    ncpu = kwargs.get("ncpu", 1)
+    torch.set_num_threads(ncpu)
+
+    if word_lm_train_config is not None:
+        raise NotImplementedError("Word LM is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if param_dict is not None:
+        hotword_list_or_file = param_dict.get('hotword')
+    else:
+        hotword_list_or_file = None
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2vadsegment
+    speech2vadsegment_kwargs = dict(
+        vad_infer_config=vad_infer_config,
+        vad_model_file=vad_model_file,
+        vad_cmvn_file=vad_cmvn_file,
+        device=device,
+        dtype=dtype,
+    )
+    # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+    speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
+
+    # 3. Build speech2text
+    speech2text_kwargs = dict(
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        bpemodel=bpemodel,
+        device=device,
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        dtype=dtype,
+        beam_size=beam_size,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        ngram_weight=ngram_weight,
+        penalty=penalty,
+        nbest=nbest,
+        hotword_list_or_file=hotword_list_or_file,
+    )
+    speech2text = Speech2TextParaformer(**speech2text_kwargs)
+    text2punc = None
+    if punc_model_file is not None:
+        text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
+
+    if output_dir is not None:
+        writer = DatadirWriter(output_dir)
+        ibest_writer = writer[f"1best_recog"]
+        ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
+
+    def _forward(data_path_and_name_and_type,
+                 raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+                 output_dir_v2: Optional[str] = None,
+                 fs: dict = None,
+                 param_dict: dict = None,
+                 **kwargs,
+                 ):
+
+        hotword_list_or_file = None
+        if param_dict is not None:
+            hotword_list_or_file = param_dict.get('hotword')
+
+        if 'hotword' in kwargs:
+            hotword_list_or_file = kwargs['hotword']
+
+        speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
+        batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000  # value is in ms despite the _s suffix
+        batch_size_token = kwargs.get("batch_size_token", 6000)
+        logging.info("batch_size_token: %s", batch_size_token)  # was a bare print(); route through logging like the rest of the file
+
+        if speech2text.hotword_list is None:
+            speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = build_streaming_iterator(
+            task_name="asr",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
+            dtype=dtype,
+            fs=fs,
+            batch_size=1,
+            key_file=key_file,
+            num_workers=num_workers,
+        )
+
+        if param_dict is not None:
+            use_timestamp = param_dict.get('use_timestamp', True)
+        else:
+            use_timestamp = True
+
+        finish_count = 0
+        file_count = 1
+        lfr_factor = 6
+        # 7 .Start for-loop
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        writer = None
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+            ibest_writer = writer[f"1best_recog"]
+
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            beg_vad = time.time()
+            vad_results = speech2vadsegment(**batch)
+            end_vad = time.time()
+            logging.info("time cost vad: %s", end_vad - beg_vad)  # was a bare print()
+            _, vadsegments = vad_results[0], vad_results[1][0]
+            ##################################
+            #####  speaker_verification  #####
+            ##################################
+            # NOTE(review): the SV checkpoint is re-loaded from disk for EVERY utterance in the loader loop; hoist this load above the loop
+            sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
+            sv_model = CAMPPlus()
+            sv_model.load_state_dict(sv_model_dict)
+            sv_model.eval()
+            cb_model = ClusterBackend()
+            vad_segments = []
+            audio = batch['speech'].numpy().reshape(-1)
+            for vadsegment in vadsegments:
+                st = int(vadsegment[0]) / 1000
+                ed = int(vadsegment[1]) / 1000
+                vad_segments.append(
+                    [st, ed, audio[int(st * 16000):int(ed * 16000)]])  # assumes 16 kHz input audio — TODO confirm fs
+            check_audio_list(vad_segments)
+            # sv pipeline: chunk the VAD segments, embed each chunk with CAM++, then cluster the embeddings
+            segments = sv_chunk(vad_segments)
+            embeddings = []
+            for s in segments:
+                # s = [start_sec, end_sec, samples]; embed the raw samples s[2]
+                # (inlined replacement for the former modelscope sv_pipeline call)
+                wavs = sv_preprocess([s[2]])
+                # one embedding per preprocessed waveform
+                embs = []
+                for x in wavs:
+                    x = extract_feature([x])
+                    embs.append(sv_model(x))
+                embs = torch.cat(embs)
+                embeddings.append(embs.detach().numpy())
+            embeddings = np.concatenate(embeddings)
+            labels = cb_model(embeddings)
+            sv_output = postprocess(segments, vad_segments, labels, embeddings)
+
+            speech, speech_lengths = batch["speech"], batch["speech_lengths"]
+
+            n = len(vadsegments)
+            data_with_index = [(vadsegments[i], i) for i in range(n)]
+            sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+            results_sorted = []
+            
+            if not len(sorted_data):
+                key = keys[0]
+                # no active segments after VAD
+                if writer is not None:
+                    # Write empty results
+                    ibest_writer["token"][key] = ""
+                    ibest_writer["token_int"][key] = ""
+                    ibest_writer["vad"][key] = ""
+                    ibest_writer["text"][key] = ""
+                    ibest_writer["text_with_punc"][key] = ""
+                    if use_timestamp:
+                        ibest_writer["time_stamp"][key] = ""
+
+                logging.info("decoding, utt: {}, empty speech".format(key))
+                continue
+
+            batch_size_token_ms = batch_size_token*60
+            if speech2text.device == "cpu":
+                batch_size_token_ms = 0
+            if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+                batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+            
+            batch_size_token_ms_cum = 0
+            beg_idx = 0
+            beg_asr_total = time.time()
+            for j, _ in enumerate(tqdm(range(0, n))):
+                batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
+                if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
+                    continue
+                batch_size_token_ms_cum = 0
+                end_idx = j + 1
+                speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+                beg_idx = end_idx
+                batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
+                batch = to_device(batch, device=device)
+                # print("batch: ", speech_j.shape[0])
+                beg_asr = time.time()
+                results = speech2text(**batch)
+                end_asr = time.time()
+                # print("time cost asr: ", end_asr - beg_asr)
+
+                if len(results) < 1:
+                    results = [["", [], [], [], [], [], []]]
+                results_sorted.extend(results)
+            end_asr_total = time.time()
+            logging.info("total time cost asr: %s", end_asr_total - beg_asr_total)  # was a bare print()
+            restored_data = [0] * n
+            for j in range(n):
+                index = sorted_data[j][1]
+                restored_data[index] = results_sorted[j]
+            result = ["", [], [], [], [], [], []]
+            for j in range(n):
+                result[0] += restored_data[j][0]
+                result[1] += restored_data[j][1]
+                result[2] += restored_data[j][2]
+                if len(restored_data[j][4]) > 0:
+                    for t in restored_data[j][4]:
+                        t[0] += vadsegments[j][0]
+                        t[1] += vadsegments[j][0]
+                    result[4] += restored_data[j][4]
+                # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
+
+            key = keys[0]
+            # result = result_segments[0]
+            text, token, token_int = result[0], result[1], result[2]
+            time_stamp = result[4] if len(result[4]) > 0 else None
+
+            if use_timestamp and time_stamp is not None and len(time_stamp):
+                postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+            else:
+                postprocessed_result = postprocess_utils.sentence_postprocess(token)
+            text_postprocessed = ""
+            time_stamp_postprocessed = ""
+            text_postprocessed_punc = postprocessed_result
+            if len(postprocessed_result) == 3:
+                text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+                                                                           postprocessed_result[1], \
+                                                                           postprocessed_result[2]
+            else:
+                text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+            text_postprocessed_punc = text_postprocessed
+            punc_id_list = []
+            if len(word_lists) > 0 and text2punc is not None:
+                beg_punc = time.time()
+                text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+                end_punc = time.time()
+                logging.info("time cost punc: %s", end_punc - beg_punc)  # was a bare print()
+
+            item = {'key': key, 'value': text_postprocessed_punc}
+            if text_postprocessed != "":
+                item['text_postprocessed'] = text_postprocessed
+            if time_stamp_postprocessed != "":
+                item['time_stamp'] = time_stamp_postprocessed
+
+            item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
+
+            asr_result_list.append(item)
+            finish_count += 1
+            # asr_utils.print_progress(finish_count / file_count)
+            if writer is not None:
+                # Write the result to each file
+                ibest_writer["token"][key] = " ".join(token)
+                ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                ibest_writer["vad"][key] = "{}".format(vadsegments)
+                ibest_writer["text"][key] = " ".join(word_lists)
+                ibest_writer["text_with_punc"][key] = text_postprocessed_punc
+                if time_stamp_postprocessed is not None:
+                    ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
+
+            logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
+        torch.cuda.empty_cache()
+        if asr_result_list:  # guard empty input; removed leftover pdb.set_trace(). NOTE(review): this pairs sv_output of the LAST utterance with the FIRST result — correct only for single-utterance input
+            distribute_spk(asr_result_list[0]['sentences'], sv_output)
+        return asr_result_list
 
     return _forward
@@ -1684,6 +2030,8 @@
         return inference_paraformer(**kwargs)
     elif mode == "paraformer_streaming":
         return inference_paraformer_online(**kwargs)
+    elif mode == "paraformer_vad_speaker":
+        return inference_paraformer_vad_speaker(**kwargs)
     elif mode.startswith("paraformer_vad"):
         return inference_paraformer_vad_punc(**kwargs)
     elif mode == "mfcca":
@@ -1782,6 +2130,16 @@
         help="VAD model parameter file",
     )
     group.add_argument(
+        "--punc_infer_config",
+        type=str,
+        help="PUNC infer configuration",
+    )
+    group.add_argument(
+        "--punc_model_file",
+        type=str,
+        help="PUNC model parameter file",
+    )
+    group.add_argument(
         "--cmvn_file",
         type=str,
         help="Global CMVN file",
@@ -1797,6 +2155,11 @@
         help="ASR model parameter file",
     )
     group.add_argument(
+        "--sv_model_file",
+        type=str,
+        help="SV model parameter file",
+    )
+    group.add_argument(
         "--lm_train_config",
         type=str,
         help="LM training configuration",
diff --git a/funasr/utils/cluster_backend.py b/funasr/utils/cluster_backend.py
new file mode 100644
index 0000000..47b45d2
--- /dev/null
+++ b/funasr/utils/cluster_backend.py
@@ -0,0 +1,191 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Any, Dict, Union
+
+import hdbscan
+import numpy as np
+import scipy
+import sklearn
+import umap
+from sklearn.cluster._kmeans import k_means
+from torch import nn
+
+
+class SpectralCluster:
+    r"""A spectral clustering method using the unnormalized Laplacian of the affinity matrix.
+    This implementation is adapted from https://github.com/speechbrain/speechbrain.
+    """
+
+    def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
+        self.min_num_spks = min_num_spks
+        self.max_num_spks = max_num_spks
+        self.pval = pval
+
+    def __call__(self, X, oracle_num=None):
+        # Similarity matrix computation
+        sim_mat = self.get_sim_mat(X)
+
+        # Refining similarity matrix with pval
+        prunned_sim_mat = self.p_pruning(sim_mat)
+
+        # Symmetrization
+        sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
+
+        # Laplacian calculation
+        laplacian = self.get_laplacian(sym_prund_sim_mat)
+
+        # Get Spectral Embeddings
+        emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
+
+        # Perform clustering
+        labels = self.cluster_embs(emb, num_of_spk)
+
+        return labels
+
+    def get_sim_mat(self, X):
+        # Cosine similarities
+        M = sklearn.metrics.pairwise.cosine_similarity(X, X)
+        return M
+
+    def p_pruning(self, A):
+        if A.shape[0] * self.pval < 6:
+            pval = 6. / A.shape[0]
+        else:
+            pval = self.pval
+
+        n_elems = int((1 - pval) * A.shape[0])
+
+        # For each row in a affinity matrix
+        for i in range(A.shape[0]):
+            low_indexes = np.argsort(A[i, :])
+            low_indexes = low_indexes[0:n_elems]
+
+            # Replace smaller similarity values by 0s
+            A[i, low_indexes] = 0
+        return A
+
+    def get_laplacian(self, M):
+        M[np.diag_indices(M.shape[0])] = 0
+        D = np.sum(np.abs(M), axis=1)
+        D = np.diag(D)
+        L = D - M
+        return L
+
+    def get_spec_embs(self, L, k_oracle=None):
+        lambdas, eig_vecs = scipy.linalg.eigh(L)
+
+        if k_oracle is not None:
+            num_of_spk = k_oracle
+        else:
+            lambda_gap_list = self.getEigenGaps(
+                lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
+            num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
+
+        emb = eig_vecs[:, :num_of_spk]
+        return emb, num_of_spk
+
+    def cluster_embs(self, emb, k):
+        _, labels, _ = k_means(emb, k)
+        return labels
+
+    def getEigenGaps(self, eig_vals):
+        eig_vals_gap_list = []
+        for i in range(len(eig_vals) - 1):
+            gap = float(eig_vals[i + 1]) - float(eig_vals[i])
+            eig_vals_gap_list.append(gap)
+        return eig_vals_gap_list
+
+
+class UmapHdbscan:
+    r"""
+    Reference:
+    - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
+      Emphasis On Topological Structure. ICASSP2022
+    """
+
+    def __init__(self,
+                 n_neighbors=20,
+                 n_components=60,
+                 min_samples=10,
+                 min_cluster_size=10,
+                 metric='cosine'):
+        self.n_neighbors = n_neighbors
+        self.n_components = n_components
+        self.min_samples = min_samples
+        self.min_cluster_size = min_cluster_size
+        self.metric = metric
+
+    def __call__(self, X):
+        umap_X = umap.UMAP(
+            n_neighbors=self.n_neighbors,
+            min_dist=0.0,
+            n_components=min(self.n_components, X.shape[0] - 2),
+            metric=self.metric,
+        ).fit_transform(X)
+        labels = hdbscan.HDBSCAN(
+            min_samples=self.min_samples,
+            min_cluster_size=self.min_cluster_size,
+            allow_single_cluster=True).fit_predict(umap_X)
+        return labels
+
+
+class ClusterBackend(nn.Module):
+    r"""Perform clustering on input embeddings [N, C] and return integer labels.
+    Pipeline: spectral clustering for small N (or when an oracle speaker count
+    is supplied), UMAP + HDBSCAN for large N, then an optional cosine-similarity
+    merge of near-duplicate speaker clusters (threshold ``merge_thr``).
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.model_config = {'merge_thr':0.78}
+        # merge_thr: cosine similarity above which two speaker centroids are merged
+
+        self.spectral_cluster = SpectralCluster()
+        self.umap_hdbscan_cluster = UmapHdbscan()
+
+    def forward(self, X, **params):
+        # cluster embeddings and return per-row labels; params may carry 'oracle_num'
+        k = params['oracle_num'] if 'oracle_num' in params else None
+        assert len(
+            X.shape
+        ) == 2, 'modelscope error: the shape of input should be [N, C]'
+        if X.shape[0] < 20:
+            return np.zeros(X.shape[0], dtype='int')
+        if X.shape[0] < 2048 or k is not None:
+            labels = self.spectral_cluster(X, k)
+        else:
+            labels = self.umap_hdbscan_cluster(X)
+
+        if k is None and 'merge_thr' in self.model_config:
+            labels = self.merge_by_cos(labels, X,
+                                       self.model_config['merge_thr'])
+
+        return labels
+
+    def merge_by_cos(self, labels, embs, cos_thr):
+        # merge similar speakers by cosine similarity; assumes labels are contiguous ints >= 0 — HDBSCAN noise label -1 would break this, TODO confirm upstream never emits -1
+        assert cos_thr > 0 and cos_thr <= 1
+        while True:
+            spk_num = labels.max() + 1
+            if spk_num == 1:
+                break
+            spk_center = []
+            for i in range(spk_num):
+                spk_emb = embs[labels == i].mean(0)
+                spk_center.append(spk_emb)
+            assert len(spk_center) > 0
+            spk_center = np.stack(spk_center, axis=0)
+            norm_spk_center = spk_center / np.linalg.norm(
+                spk_center, axis=1, keepdims=True)
+            affinity = np.matmul(norm_spk_center, norm_spk_center.T)
+            affinity = np.triu(affinity, 1)
+            spks = np.unravel_index(np.argmax(affinity), affinity.shape)
+            if affinity[spks] < cos_thr:
+                break
+            for i in range(len(labels)):
+                if labels[i] == spks[1]:
+                    labels[i] = spks[0]
+                elif labels[i] > spks[1]:
+                    labels[i] -= 1
+        return labels
diff --git a/funasr/utils/modelscope_file.py b/funasr/utils/modelscope_file.py
new file mode 100644
index 0000000..d93f24c
--- /dev/null
+++ b/funasr/utils/modelscope_file.py
@@ -0,0 +1,328 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+
+
+class Storage(metaclass=ABCMeta):
+    """Abstract class of storage.
+
+    All backends need to implement two apis: ``read()`` and ``read_text()``.
+    ``read()`` reads the file as a byte stream and ``read_text()`` reads
+    the file as texts.
+    """
+
+    @abstractmethod
+    def read(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def read_text(self, filepath: str):
+        pass
+
+    @abstractmethod
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        pass
+
+    @abstractmethod
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        pass
+
+
+class LocalStorage(Storage):
+    """Local hard disk storage"""
+
+    def read(self, filepath: Union[str, Path]) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        with open(filepath, 'rb') as f:
+            content = f.read()
+        return content
+
+    def read_text(self,
+                  filepath: Union[str, Path],
+                  encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        with open(filepath, 'r', encoding=encoding) as f:
+            value_buf = f.read()
+        return value_buf
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'wb') as f:
+            f.write(obj)
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        dirname = os.path.dirname(filepath)
+        if dirname and not os.path.exists(dirname):
+            os.makedirs(dirname, exist_ok=True)
+
+        with open(filepath, 'w', encoding=encoding) as f:
+            f.write(obj)
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self,
+            filepath: Union[str,
+                            Path]) -> Generator[Union[str, Path], None, None]:
+        """Only for unified API and do nothing."""
+        yield filepath
+
+
+class HTTPStorage(Storage):
+    """HTTP and HTTPS storage."""
+
+    def read(self, url):
+        # TODO @wenmeng.zwm add progress bar if file is too large
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.content
+
+    def read_text(self, url):
+        r = requests.get(url)
+        r.raise_for_status()
+        return r.text
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exists from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = HTTPStorage()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the temporary file will be removed
+            >>> with storage.as_local_path('http://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, url: Union[str, Path]) -> None:
+        raise NotImplementedError('write is not supported by HTTP Storage')
+
+    def write_text(self,
+                   obj: str,
+                   url: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'write_text is not supported by HTTP Storage')
+
+
+class OSSStorage(Storage):
+    """OSS storage."""
+
+    def __init__(self, oss_config_file=None):
+        # read from config file or env var
+        raise NotImplementedError(
+            'OSSStorage.__init__ to be implemented in the future')
+
+    def read(self, filepath):
+        raise NotImplementedError(
+            'OSSStorage.read to be implemented in the future')
+
+    def read_text(self, filepath, encoding='utf-8'):
+        raise NotImplementedError(
+            'OSSStorage.read_text to be implemented in the future')
+
+    @contextlib.contextmanager
+    def as_local_path(
+            self, filepath: str) -> Generator[Union[str, Path], None, None]:
+        """Download a file from ``filepath``.
+
+        ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exists from the
+        ``with`` statement, the temporary path will be released.
+
+        Args:
+            filepath (str): Download a file from ``filepath``.
+
+        Examples:
+            >>> storage = OSSStorage()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the temporary file will be removed
+            >>> with storage.as_local_path('oss://path/to/file') as path:
+            ...     # do something here
+        """
+        try:
+            f = tempfile.NamedTemporaryFile(delete=False)
+            f.write(self.read(filepath))
+            f.close()
+            yield f.name
+        finally:
+            os.remove(f.name)
+
+    def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+        raise NotImplementedError(
+            'OSSStorage.write to be implemented in the future')
+
+    def write_text(self,
+                   obj: str,
+                   filepath: Union[str, Path],
+                   encoding: str = 'utf-8') -> None:
+        raise NotImplementedError(
+            'OSSStorage.write_text to be implemented in the future')
+
+
+G_STORAGES = {}
+
+
+class File(object):
+    _prefix_to_storage: dict = {
+        'oss': OSSStorage,
+        'http': HTTPStorage,
+        'https': HTTPStorage,
+        'local': LocalStorage,
+    }
+
+    @staticmethod
+    def _get_storage(uri):
+        assert isinstance(uri,
+                          str), f'uri should be str type, but got {type(uri)}'
+
+        if '://' not in uri:
+            # local path
+            storage_type = 'local'
+        else:
+            prefix, _ = uri.split('://')
+            storage_type = prefix
+
+        assert storage_type in File._prefix_to_storage, \
+            f'Unsupported uri {uri}, valid prefixs: '\
+            f'{list(File._prefix_to_storage.keys())}'
+
+        if storage_type not in G_STORAGES:
+            G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+        return G_STORAGES[storage_type]
+
+    @staticmethod
+    def read(uri: str) -> bytes:
+        """Read data from a given ``filepath`` with 'rb' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+
+        Returns:
+            bytes: Expected bytes object.
+        """
+        storage = File._get_storage(uri)
+        return storage.read(uri)
+
+    @staticmethod
+    def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
+        """Read data from a given ``filepath`` with 'r' mode.
+
+        Args:
+            filepath (str or Path): Path to read data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+
+        Returns:
+            str: Expected text reading from ``filepath``.
+        """
+        storage = File._get_storage(uri)
+        return storage.read_text(uri)
+
+    @staticmethod
+    def write(obj: bytes, uri: Union[str, Path]) -> None:
+        """Write data to a given ``filepath`` with 'wb' mode.
+
+        Note:
+            ``write`` will create a directory if the directory of ``filepath``
+            does not exist.
+
+        Args:
+            obj (bytes): Data to be written.
+            filepath (str or Path): Path to write data.
+        """
+        storage = File._get_storage(uri)
+        return storage.write(obj, uri)
+
+    @staticmethod
+    def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
+        """Write data to a given ``filepath`` with 'w' mode.
+
+        Note:
+            ``write_text`` will create a directory if the directory of
+            ``filepath`` does not exist.
+
+        Args:
+            obj (str): Data to be written.
+            filepath (str or Path): Path to write data.
+            encoding (str): The encoding format used to open the ``filepath``.
+                Default: 'utf-8'.
+        """
+        storage = File._get_storage(uri)
+        return storage.write_text(obj, uri)
+
+    @contextlib.contextmanager
+    def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
+        """Yield a local path for ``uri`` via the matching backend. NOTE(review): defined without @staticmethod, so it only works when called on the class, ``File.as_local_path(uri)``, not on an instance — confirm no instance-call sites."""
+        storage = File._get_storage(uri)
+        with storage.as_local_path(uri) as local_path:
+            yield local_path
diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py
new file mode 100644
index 0000000..a1c610f
--- /dev/null
+++ b/funasr/utils/speaker_utils.py
@@ -0,0 +1,592 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
+"""
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import nn
+
+import io
+import os
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+import logging
+from funasr.utils.modelscope_file import File
+from collections import OrderedDict
+import torchaudio.compliance.kaldi as Kaldi
+
+
def check_audio_list(audio: list):
    """Validate a VAD segment list of ``[start_s, end_s, samples]`` entries.

    Verifies that timestamps are well-ordered and non-overlapping, that each
    payload is a numpy array whose length matches the stamped duration at a
    16 kHz sampling rate, and that the total voiced duration exceeds 5 s.

    Raises:
        AssertionError: If any of the checks above fails.
    """
    total_dur = 0
    prev_end = None
    for seg in audio:
        start, end, data = seg[0], seg[1], seg[2]
        assert end >= start, 'modelscope error: Wrong time stamps.'
        assert isinstance(data, np.ndarray), 'modelscope error: Wrong data type.'
        expected_len = int(end * 16000) - int(start * 16000)
        assert expected_len == data.shape[
            0], 'modelscope error: audio data in list is inconsistent with time length.'
        if prev_end is not None:
            assert start >= prev_end, 'modelscope error: Wrong time stamps.'
        prev_end = end
        total_dur += end - start
    assert total_dur > 5, 'modelscope error: The effective audio duration is too short.'
+
+
+def sv_preprocess(inputs: Union[np.ndarray, list]):
+        output = []
+        for i in range(len(inputs)):
+            if isinstance(inputs[i], str):
+                file_bytes = File.read(inputs[i])
+                data, fs = sf.read(io.BytesIO(file_bytes), dtype='float32')
+                if len(data.shape) == 2:
+                    data = data[:, 0]
+                data = torch.from_numpy(data).unsqueeze(0)
+                data = data.squeeze(0)
+            elif isinstance(inputs[i], np.ndarray):
+                assert len(
+                    inputs[i].shape
+                ) == 1, 'modelscope error: Input array should be [N, T]'
+                data = inputs[i]
+                if data.dtype in ['int16', 'int32', 'int64']:
+                    data = (data / (1 << 15)).astype('float32')
+                else:
+                    data = data.astype('float32')
+                data = torch.from_numpy(data)
+            else:
+                raise ValueError(
+                    'modelscope error: The input type is restricted to audio address and nump array.'
+                )
+            output.append(data)
+        return output
+
+
def sv_chunk(vad_segments: list, fs = 16000) -> list:
    """Split VAD segments into fixed 1.5 s windows with a 0.75 s hop.

    Each input segment ``[start_s, end_s, samples]`` is cut into windows of
    1.5 s shifted by 0.75 s; the final window is right-aligned to the segment
    end, and a segment shorter than one window is zero-padded on the right.

    Args:
        vad_segments: List of ``[start_s, end_s, samples]`` entries.
        fs: Sampling rate of ``samples``. Default 16000.

    Returns:
        list: ``[chunk_start_s, chunk_end_s, chunk_samples]`` entries.
    """
    seg_dur = 1.5
    seg_shift = 0.75
    win = int(seg_dur * fs)
    hop = int(seg_shift * fs)

    def split_one(segment):
        seg_start, samples = segment[0], segment[2]
        chunks = []
        prev_end = 0
        for begin in range(0, samples.shape[0], hop):
            end = min(begin + win, samples.shape[0])
            if end <= prev_end:
                break
            prev_end = end
            begin = max(0, end - win)  # right-align the last window
            piece = samples[begin:end]
            if piece.shape[0] < win:
                piece = np.pad(piece, (0, win - piece.shape[0]), 'constant')
            chunks.append([begin / fs + seg_start, end / fs + seg_start, piece])
        return chunks

    out = []
    for segment in vad_segments:
        out.extend(split_one(segment))
    return out
+
+
class BasicResBlock(nn.Module):
    """Basic 2-D residual block whose stride applies to the frequency
    (height) axis only, preserving time resolution.

    Args:
        in_planes: Number of input channels.
        planes: Number of output channels.
        stride: Frequency-axis stride of the first conv (time stride is 1).
    """
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        self.conv1 = nn.Conv2d(
            in_planes,
            planes,
            kernel_size=3,
            stride=(stride, 1),  # downsample frequency only
            padding=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(
            planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # 1x1 projection shortcut when shape changes, identity otherwise.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(
                    in_planes,
                    self.expansion * planes,
                    kernel_size=1,
                    stride=(stride, 1),
                    bias=False), nn.BatchNorm2d(self.expansion * planes))

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = F.relu(out)
        return out


class FCM(nn.Module):
    """Front-end convolution module of CAM++: a small 2-D ResNet that
    downsamples the mel-frequency axis by 8x and flattens (channel, freq)
    into a single feature dimension.

    Args:
        block: Residual block class to stack.
        num_blocks: Blocks per stage, ``(stage1, stage2)``.
        m_channels: Channel width of every conv layer.
        feat_dim: Input mel-bin count; ``out_channels`` becomes
            ``m_channels * (feat_dim // 8)``.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=(2, 2),  # tuple: avoids a mutable default
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(
            1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(
            block, m_channels, num_blocks[0], stride=2)
        # Bug fix: stage 2 previously reused num_blocks[0]; with the default
        # (2, 2) configuration the built network is unchanged.
        self.layer2 = self._make_layer(
            block, m_channels, num_blocks[1], stride=2)

        self.conv2 = nn.Conv2d(
            m_channels,
            m_channels,
            kernel_size=3,
            stride=(2, 1),  # final frequency-only downsample
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first one strides."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        # x: (B, feat_dim, T) -> add a singleton channel axis for Conv2d.
        x = x.unsqueeze(1)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Flatten (channels, freq) into one feature axis: (B, C*F/8, T).
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
+
+
class CAMPPlus(nn.Module):
    """CAM++ speaker-embedding network: FCM front-end followed by a
    D-TDNN-style ``xvector`` stack of dense TDNN blocks.

    Args:
        feat_dim: Input mel-bin count fed to the FCM head.
        embedding_size: Output embedding size when ``output_level='segment'``.
        growth_rate: Channels each dense-TDNN layer adds.
        bn_size: Bottleneck multiplier (bn_channels = bn_size * growth_rate).
        init_channels: Channels produced by the initial TDNN layer.
        config_str: Nonlinearity spec parsed by ``get_nonlinear``.
        memory_efficient: Enable gradient checkpointing inside dense layers.
        output_level: 'segment' for one embedding per utterance (B, emb);
            'frame' for frame-level features (B, T', C).
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment'):
        super(CAMPPlus, self).__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(
                     channels,
                     init_channels,
                     5,
                     stride=2,
                     dilation=1,
                     padding=-1,  # negative padding => 'same'-style padding
                     config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks (12/24/16 layers); after each one, a transit
        # layer halves the accumulated channel count.
        for i, (num_layers, kernel_size, dilation) in enumerate(
                zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            # Dense growth: each layer in the block concatenates growth_rate
            # channels onto the running feature map.
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(
                    channels, channels // 2, bias=False,
                    config_str=config_str))
            channels //= 2

        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, channels))

        if self.output_level == 'segment':
            # Stats pooling concatenates mean+std (doubling channels), then a
            # dense projection yields the fixed-size utterance embedding.
            self.xvector.add_module('stats', StatsPool())
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))
        else:
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '

        # Kaiming init for conv/linear weights; zero the biases.
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            x = x.transpose(1, 2)  # back to (B, T', C) for frame-level output
        return x
+
+
def get_nonlinear(config_str, channels):
    """Build an ``nn.Sequential`` of norm/activation modules from a spec.

    ``config_str`` is a ``'-'``-separated list of tokens: ``relu``, ``prelu``,
    ``batchnorm``, or ``batchnorm_`` (affine-free BatchNorm1d, registered
    under the module name ``batchnorm``).

    Raises:
        ValueError: On an unrecognized token.
    """
    builders = {
        'relu': lambda: ('relu', nn.ReLU(inplace=True)),
        'prelu': lambda: ('prelu', nn.PReLU(channels)),
        'batchnorm': lambda: ('batchnorm', nn.BatchNorm1d(channels)),
        'batchnorm_':
        lambda: ('batchnorm', nn.BatchNorm1d(channels, affine=False)),
    }
    nonlinear = nn.Sequential()
    for token in config_str.split('-'):
        if token not in builders:
            raise ValueError('Unexpected module ({}).'.format(token))
        name, module = builders[token]()
        nonlinear.add_module(name, module)
    return nonlinear
+
+
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate the mean and (unbiased) std of ``x`` along ``dim``.

    The reduced statistics are concatenated on the last axis, doubling its
    size; with ``keepdim`` the reduced dimension is restored with size 1.

    Note: ``eps`` is accepted for API compatibility but is not used here.
    """
    mu = x.mean(dim=dim)
    sigma = x.std(dim=dim, unbiased=unbiased)
    pooled = torch.cat((mu, sigma), dim=-1)
    if keepdim:
        pooled = pooled.unsqueeze(dim=dim)
    return pooled


class StatsPool(nn.Module):
    """Module wrapper around ``statistics_pooling`` (mean+std over time)."""

    def forward(self, x):
        return statistics_pooling(x)
+
+
class TDNNLayer(nn.Module):
    """TDNN layer: a Conv1d followed by the configured norm/activation.

    A negative ``padding`` requests 'same'-style padding (requires an odd
    kernel size) and is computed as ``(kernel_size - 1) // 2 * dilation``.
    Submodule names (``linear``, ``nonlinear``) are part of the state_dict
    contract and must not change.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            # 'same'-style padding only works for odd kernels.
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
+
+
def extract_feature(audio):
    """Compute 80-dim mean-normalized fbank features for each waveform.

    Args:
        audio: Iterable of 1-D waveform tensors (sample rate assumed by
            ``Kaldi.fbank`` defaults — presumably 16 kHz; see callers).

    Returns:
        torch.Tensor: Stacked features of shape (N, T, 80); all waveforms
        must yield the same frame count for the concatenation to succeed.
    """
    feats = []
    for wav in audio:
        fbank = Kaldi.fbank(wav.unsqueeze(0), num_mel_bins=80)
        # Cepstral mean normalization over the time axis.
        fbank = fbank - fbank.mean(dim=0, keepdim=True)
        feats.append(fbank.unsqueeze(0))
    return torch.cat(feats)
+
+
class CAMLayer(nn.Module):
    """Context-aware masking (CAM) convolution.

    A local Conv1d response is gated element-wise by a sigmoid mask predicted
    from a context vector: the global mean over time plus segment-level
    pooled statistics (see ``seg_pooling``).

    Args:
        bn_channels: Input (bottleneck) channel count.
        out_channels: Output channel count.
        kernel_size, stride, padding, dilation, bias: Local Conv1d settings.
        reduction: Channel-reduction factor inside the mask predictor.
    """

    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super(CAMLayer, self).__init__()
        self.linear_local = nn.Conv1d(
            bn_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)
        self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        local_out = self.linear_local(x)
        # Context = global mean over time (broadcast) + segment pooling.
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        gate = self.sigmoid(self.linear2(self.relu(self.linear1(context))))
        return local_out * gate

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool in non-overlapping ``seg_len`` windows, then repeat each
        pooled value so the output matches the input length."""
        if stype == 'avg':
            pooled = F.avg_pool1d(
                x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        elif stype == 'max':
            pooled = F.max_pool1d(
                x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        else:
            raise ValueError('Wrong segment pooling type.')
        shape = pooled.shape
        # Upsample by repetition, then trim the ceil_mode overshoot.
        pooled = pooled.unsqueeze(-1).expand(*shape,
                                             seg_len).reshape(*shape[:-1], -1)
        return pooled[..., :x.shape[-1]]
+
+
class CAMDenseTDNNLayer(nn.Module):
    """Bottleneck + CAM convolution layer used inside a dense TDNN block.

    Computes ``cam_layer(nonlinear2(linear1(nonlinear1(x))))``. The
    bottleneck half (``nonlinear1`` + ``linear1``) can be gradient-
    checkpointed during training when ``memory_efficient`` is set.
    Submodule names are part of the state_dict contract.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNLayer, self).__init__()
        # 'same'-style padding requires an odd kernel.
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(
            bn_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

    def bn_function(self, x):
        # Bottleneck: norm/activation followed by a 1x1 conv.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        if self.training and self.memory_efficient:
            # Trade recompute for memory on the bottleneck during training.
            bottleneck = cp.checkpoint(self.bn_function, x)
        else:
            bottleneck = self.bn_function(x)
        return self.cam_layer(self.nonlinear2(bottleneck))
+
+
class CAMDenseTDNNBlock(nn.ModuleList):
    """Dense block of ``num_layers`` CAMDenseTDNNLayers.

    Each layer's output (``out_channels`` wide) is concatenated onto the
    running feature map, so the channel count grows by ``out_channels`` per
    layer; layer ``i`` therefore sees ``in_channels + i * out_channels``
    input channels. Module names ``tdnnd1..N`` are state_dict keys.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNBlock, self).__init__()
        for idx in range(num_layers):
            self.add_module(
                'tdnnd%d' % (idx + 1),
                CAMDenseTDNNLayer(
                    in_channels=in_channels + idx * out_channels,
                    out_channels=out_channels,
                    bn_channels=bn_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=bias,
                    config_str=config_str,
                    memory_efficient=memory_efficient))

    def forward(self, x):
        for layer in self:
            # Dense connectivity: append this layer's output channels.
            x = torch.cat([x, layer(x)], dim=1)
        return x
+
+
class TransitLayer(nn.Module):
    """Transition layer: norm/activation followed by a 1x1 conv, used between
    dense blocks to reduce the channel count."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
+
+
class DenseLayer(nn.Module):
    """1x1 conv followed by norm/activation.

    Accepts either (B, C) vectors — handled by temporarily adding a length-1
    time axis for Conv1d — or (B, C, T) feature maps."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if len(x.shape) == 2:
            # (B, C): fake a time axis for Conv1d, then drop it again.
            out = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            out = self.linear(x)
        return self.nonlinear(out)
+
def postprocess(segments: list, vad_segments: list,
                labels: np.ndarray, embeddings: np.ndarray) -> list:
    """Turn chunk-level cluster labels into a speaker diarization timeline.

    Args:
        segments: Chunk list ``[[start_s, end_s, samples], ...]`` aligned
            with ``labels``.
        vad_segments: Original VAD segments.
            NOTE(review): currently unused in this function.
        labels: Cluster label per chunk.
        embeddings: Embedding per chunk, aligned with ``labels``.

    Returns:
        list: Merged, overlap-resolved, smoothed ``[start, end, spk]``
        entries.
    """
    assert len(segments) == len(labels)
    labels = correct_labels(labels)

    # One [start, end, speaker] entry per chunk, merged chronologically.
    timeline = [[seg[0], seg[1], labels[idx]]
                for idx, seg in enumerate(segments)]
    timeline = merge_seque(timeline)

    # Per-speaker centroid embeddings.
    # NOTE(review): computed but not used below — kept for parity.
    spk_embs = np.stack([
        embeddings[labels == spk].mean(0) for spk in range(labels.max() + 1)
    ])

    def is_overlapped(t1, t2):
        return t1 > t2 + 1e-4

    # Split each overlapped boundary at its midpoint.
    for idx in range(1, len(timeline)):
        if is_overlapped(timeline[idx - 1][1], timeline[idx][0]):
            mid = (timeline[idx][0] + timeline[idx - 1][1]) / 2
            timeline[idx][0] = mid
            timeline[idx - 1][1] = mid

    # Reassign too-short segments to a neighboring speaker and re-merge.
    return smooth(timeline)
+
+
def correct_labels(labels):
    """Remap cluster labels to consecutive ids (0, 1, ...) in order of first
    appearance.

    Args:
        labels: Iterable of arbitrary hashable cluster labels.

    Returns:
        np.ndarray: Relabeled array of the same length.
    """
    remap = {}
    relabeled = []
    for lab in labels:
        if lab not in remap:
            remap[lab] = len(remap)
        relabeled.append(remap[lab])
    return np.array(relabeled)
+
def merge_seque(distribute_res):
    """Merge chronologically adjacent/overlapping entries with the same
    speaker label.

    Consecutive entries are merged when they share a label and the later one
    starts at or before the earlier one's end.

    Fixes vs. the original: the result previously aliased the input's inner
    lists (merging silently mutated the caller's data), and an empty input
    raised IndexError; entries are now copied and ``[]`` yields ``[]``.

    Args:
        distribute_res: List of ``[start, end, label]`` sorted by time.

    Returns:
        list: New list of merged ``[start, end, label]`` entries.
    """
    if not distribute_res:
        return []
    res = [list(distribute_res[0])]
    for seg in distribute_res[1:]:
        if seg[2] != res[-1][2] or seg[0] > res[-1][1]:
            res.append(list(seg))
        else:
            # Same speaker and touching/overlapping: extend the last entry.
            res[-1][1] = seg[1]
    return res

def smooth(res, mindur=1):
    """Round boundaries to 2 decimals, reassign segments shorter than
    ``mindur`` seconds to the nearest neighboring speaker, then re-merge.

    Fixes vs. the original: a single-element timeline whose only segment was
    shorter than ``mindur`` raised IndexError (``res[i + 1]`` at i == 0); it
    is now left as-is since there is no neighbor to borrow a label from.

    Args:
        res: List of ``[start, end, label]`` entries (modified in place
            before merging, as before).
        mindur: Minimum duration in seconds a segment may keep its own label.

    Returns:
        list: Smoothed and merged ``[start, end, label]`` entries.
    """
    for i in range(len(res)):
        res[i][0] = round(res[i][0], 2)
        res[i][1] = round(res[i][1], 2)
        if res[i][1] - res[i][0] >= mindur or len(res) == 1:
            continue
        if i == 0:
            res[i][2] = res[i + 1][2]
        elif i == len(res) - 1:
            res[i][2] = res[i - 1][2]
        elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
            # Closer to (or tied with) the previous segment.
            res[i][2] = res[i - 1][2]
        else:
            res[i][2] = res[i + 1][2]
    # Merge equal-speaker neighbors created by the reassignment.
    res = merge_seque(res)

    return res
+
+
def distribute_spk(sentence_list, sd_time_list):
    """Assign each sentence the speaker whose diarized interval overlaps it
    most.

    Args:
        sentence_list: Dicts with a 'ts_list' of [start_ms, end_ms]
            timestamps; each dict gains a 'spk' key (mutated in place).
        sd_time_list: Diarization entries ``[start_s, end_s, speaker]``
            (seconds; converted to ms here for comparison).

    Returns:
        list: The same dicts, with 'spk' set to the best-overlapping speaker
        (0 when nothing overlaps).
    """
    labeled = []
    for sentence in sentence_list:
        sent_start = sentence['ts_list'][0][0]
        sent_end = sentence['ts_list'][-1][1]
        best_spk = 0
        best_overlap = 0
        for spk_start, spk_end, spk in sd_time_list:
            spk_start_ms = spk_start * 1000
            spk_end_ms = spk_end * 1000
            overlap = max(
                min(sent_end, spk_end_ms) - max(sent_start, spk_start_ms), 0)
            if overlap > best_overlap:
                best_overlap = overlap
                best_spk = spk
        sentence['spk'] = best_spk
        labeled.append(sentence)
    return labeled
\ No newline at end of file

--
Gitblit v1.9.1