From 8a0930d682fe3206e0b41c694fc03d7d10c7eed2 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: Tue, 10 Oct 2023 11:35:42 +0800
Subject: [PATCH] paraformer-speaker inference pipeline
---
funasr/utils/modelscope_file.py | 328 +++++++++++++
funasr/utils/speaker_utils.py | 592 +++++++++++++++++++++++
funasr/utils/cluster_backend.py | 191 +++++++
funasr/bin/asr_inference_launch.py | 363 ++++++++++++++
 4 files changed, 1474 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 50b9886..15dbdd4 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -5,6 +5,7 @@
import argparse
import logging
+# NOTE(review): dropped unused "from optparse import Option" (optparse is deprecated; Option is never referenced)
import os
import sys
import time
@@ -45,6 +46,15 @@
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
from funasr.utils.vad_utils import slice_padding_fbank
+from funasr.utils.speaker_utils import (check_audio_list,
+ sv_preprocess,
+ sv_chunk,
+ CAMPPlus,
+ extract_feature,
+ postprocess,
+ distribute_spk)
+from funasr.build_utils.build_model_from_file import build_model_from_file
+from funasr.utils.cluster_backend import ClusterBackend
from tqdm import tqdm
def inference_asr(
@@ -739,6 +749,342 @@
logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
torch.cuda.empty_cache()
+ return asr_result_list
+
+ return _forward
+
+
+def inference_paraformer_vad_speaker(
+ maxlenratio: float,
+ minlenratio: float,
+ batch_size: int,
+ beam_size: int,
+ ngpu: int,
+ ctc_weight: float,
+ lm_weight: float,
+ penalty: float,
+ log_level: Union[int, str],
+ # data_path_and_name_and_type,
+ asr_train_config: Optional[str],
+ asr_model_file: Optional[str],
+ cmvn_file: Optional[str] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ token_type: Optional[str] = None,
+ key_file: Optional[str] = None,
+ word_lm_train_config: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ output_dir: Optional[str] = None,
+ dtype: str = "float32",
+ seed: int = 0,
+ ngram_weight: float = 0.9,
+ nbest: int = 1,
+ num_workers: int = 1,
+ vad_infer_config: Optional[str] = None,
+ vad_model_file: Optional[str] = None,
+ vad_cmvn_file: Optional[str] = None,
+ time_stamp_writer: bool = True,
+ punc_infer_config: Optional[str] = None,
+ punc_model_file: Optional[str] = None,
+ sv_model_file: Optional[str] = None,
+ streaming: bool = False,
+ embedding_node: str = "resnet1_dense",
+ sv_threshold: float = 0.9465,
+ outputs_dict: Optional[bool] = True,
+ param_dict: dict = None,
+
+ **kwargs,
+):
+ ncpu = kwargs.get("ncpu", 1)
+ torch.set_num_threads(ncpu)
+
+ if word_lm_train_config is not None:
+ raise NotImplementedError("Word LM is not implemented")
+ if ngpu > 1:
+ raise NotImplementedError("only single GPU decoding is supported")
+
+ logging.basicConfig(
+ level=log_level,
+ format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+ )
+
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+ else:
+ hotword_list_or_file = None
+
+ if ngpu >= 1 and torch.cuda.is_available():
+ device = "cuda"
+ else:
+ device = "cpu"
+
+ # 1. Set random-seed
+ set_all_random_seed(seed)
+
+ # 2. Build speech2vadsegment
+ speech2vadsegment_kwargs = dict(
+ vad_infer_config=vad_infer_config,
+ vad_model_file=vad_model_file,
+ vad_cmvn_file=vad_cmvn_file,
+ device=device,
+ dtype=dtype,
+ )
+ # logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+ speech2vadsegment = Speech2VadSegment(**speech2vadsegment_kwargs)
+
+ # 3. Build speech2text
+ speech2text_kwargs = dict(
+ asr_train_config=asr_train_config,
+ asr_model_file=asr_model_file,
+ cmvn_file=cmvn_file,
+ lm_train_config=lm_train_config,
+ lm_file=lm_file,
+ token_type=token_type,
+ bpemodel=bpemodel,
+ device=device,
+ maxlenratio=maxlenratio,
+ minlenratio=minlenratio,
+ dtype=dtype,
+ beam_size=beam_size,
+ ctc_weight=ctc_weight,
+ lm_weight=lm_weight,
+ ngram_weight=ngram_weight,
+ penalty=penalty,
+ nbest=nbest,
+ hotword_list_or_file=hotword_list_or_file,
+ )
+ speech2text = Speech2TextParaformer(**speech2text_kwargs)
+ text2punc = None
+ if punc_model_file is not None:
+ text2punc = Text2Punc(punc_infer_config, punc_model_file, device=device, dtype=dtype)
+
+ if output_dir is not None:
+ writer = DatadirWriter(output_dir)
+ ibest_writer = writer[f"1best_recog"]
+ ibest_writer["token_list"][""] = " ".join(speech2text.asr_train_args.token_list)
+
+ def _forward(data_path_and_name_and_type,
+ raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+ output_dir_v2: Optional[str] = None,
+ fs: dict = None,
+ param_dict: dict = None,
+ **kwargs,
+ ):
+
+ hotword_list_or_file = None
+ if param_dict is not None:
+ hotword_list_or_file = param_dict.get('hotword')
+
+ if 'hotword' in kwargs:
+ hotword_list_or_file = kwargs['hotword']
+
+ speech2vadsegment.vad_model.vad_opts.max_single_segment_time = kwargs.get("max_single_segment_time", 60000)
+ batch_size_token_threshold_s = kwargs.get("batch_size_token_threshold_s", int(speech2vadsegment.vad_model.vad_opts.max_single_segment_time*0.67/1000)) * 1000
+ batch_size_token = kwargs.get("batch_size_token", 6000)
+ print("batch_size_token: ", batch_size_token)
+
+ if speech2text.hotword_list is None:
+ speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file)
+
+ # 3. Build data-iterator
+ if data_path_and_name_and_type is None and raw_inputs is not None:
+ if isinstance(raw_inputs, torch.Tensor):
+ raw_inputs = raw_inputs.numpy()
+ data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+ loader = build_streaming_iterator(
+ task_name="asr",
+ preprocess_args=None,
+ data_path_and_name_and_type=data_path_and_name_and_type,
+ dtype=dtype,
+ fs=fs,
+ batch_size=1,
+ key_file=key_file,
+ num_workers=num_workers,
+ )
+
+ if param_dict is not None:
+ use_timestamp = param_dict.get('use_timestamp', True)
+ else:
+ use_timestamp = True
+
+ finish_count = 0
+ file_count = 1
+ lfr_factor = 6
+ # 7 .Start for-loop
+ asr_result_list = []
+ output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+ writer = None
+ if output_path is not None:
+ writer = DatadirWriter(output_path)
+ ibest_writer = writer[f"1best_recog"]
+
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ beg_vad = time.time()
+ vad_results = speech2vadsegment(**batch)
+ end_vad = time.time()
+ print("time cost vad: ", end_vad - beg_vad)
+ _, vadsegments = vad_results[0], vad_results[1][0]
+ ##################################
+ ##### speaker_verification #####
+ ##################################
+            # NOTE(review): sv model is torch.load-ed from disk on EVERY utterance here; hoist this load above the loader loop
+ sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
+ sv_model = CAMPPlus()
+ sv_model.load_state_dict(sv_model_dict)
+ sv_model.eval()
+ cb_model = ClusterBackend()
+ vad_segments = []
+ audio = batch['speech'].numpy().reshape(-1)
+ for vadsegment in vadsegments:
+ st = int(vadsegment[0]) / 1000
+ ed = int(vadsegment[1]) / 1000
+ vad_segments.append(
+ [st, ed, audio[int(st * 16000):int(ed * 16000)]])
+ check_audio_list(vad_segments)
+ # sv pipeline
+ segments = sv_chunk(vad_segments)
+ embeddings = []
+ for s in segments:
+ #_, embs = self.sv_pipeline([s[2]], output_emb=True)
+ # embeddings.append(embs)
+ wavs = sv_preprocess([s[2]])
+ # embs = self.forward(wavs)
+ embs = []
+ for x in wavs:
+ x = extract_feature([x])
+ embs.append(sv_model(x))
+ embs = torch.cat(embs)
+ embeddings.append(embs.detach().numpy())
+ embeddings = np.concatenate(embeddings)
+ labels = cb_model(embeddings)
+ sv_output = postprocess(segments, vad_segments, labels, embeddings)
+
+ speech, speech_lengths = batch["speech"], batch["speech_lengths"]
+
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+
+ if not len(sorted_data):
+ key = keys[0]
+ # no active segments after VAD
+ if writer is not None:
+ # Write empty results
+ ibest_writer["token"][key] = ""
+ ibest_writer["token_int"][key] = ""
+ ibest_writer["vad"][key] = ""
+ ibest_writer["text"][key] = ""
+ ibest_writer["text_with_punc"][key] = ""
+ if use_timestamp:
+ ibest_writer["time_stamp"][key] = ""
+
+ logging.info("decoding, utt: {}, empty speech".format(key))
+ continue
+
+ batch_size_token_ms = batch_size_token*60
+ if speech2text.device == "cpu":
+ batch_size_token_ms = 0
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+ batch_size_token_ms = max(batch_size_token_ms, sorted_data[0][0][1] - sorted_data[0][0][0])
+
+ batch_size_token_ms_cum = 0
+ beg_idx = 0
+ beg_asr_total = time.time()
+ for j, _ in enumerate(tqdm(range(0, n))):
+ batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
+ if j < n - 1 and (batch_size_token_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_ms and (sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_token_threshold_s:
+ continue
+ batch_size_token_ms_cum = 0
+ end_idx = j + 1
+ speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+ beg_idx = end_idx
+ batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
+ batch = to_device(batch, device=device)
+ # print("batch: ", speech_j.shape[0])
+ beg_asr = time.time()
+ results = speech2text(**batch)
+ end_asr = time.time()
+ # print("time cost asr: ", end_asr - beg_asr)
+
+ if len(results) < 1:
+ results = [["", [], [], [], [], [], []]]
+ results_sorted.extend(results)
+ end_asr_total = time.time()
+ print("total time cost asr: ", end_asr_total-beg_asr_total)
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ restored_data[index] = results_sorted[j]
+ result = ["", [], [], [], [], [], []]
+ for j in range(n):
+ result[0] += restored_data[j][0]
+ result[1] += restored_data[j][1]
+ result[2] += restored_data[j][2]
+ if len(restored_data[j][4]) > 0:
+ for t in restored_data[j][4]:
+ t[0] += vadsegments[j][0]
+ t[1] += vadsegments[j][0]
+ result[4] += restored_data[j][4]
+ # result = [result[k]+restored_data[j][k] for k in range(len(result[:-2]))]
+
+ key = keys[0]
+ # result = result_segments[0]
+ text, token, token_int = result[0], result[1], result[2]
+ time_stamp = result[4] if len(result[4]) > 0 else None
+
+ if use_timestamp and time_stamp is not None and len(time_stamp):
+ postprocessed_result = postprocess_utils.sentence_postprocess(token, time_stamp)
+ else:
+ postprocessed_result = postprocess_utils.sentence_postprocess(token)
+ text_postprocessed = ""
+ time_stamp_postprocessed = ""
+ text_postprocessed_punc = postprocessed_result
+ if len(postprocessed_result) == 3:
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocessed_result[0], \
+ postprocessed_result[1], \
+ postprocessed_result[2]
+ else:
+ text_postprocessed, word_lists = postprocessed_result[0], postprocessed_result[1]
+
+ text_postprocessed_punc = text_postprocessed
+ punc_id_list = []
+ if len(word_lists) > 0 and text2punc is not None:
+ beg_punc = time.time()
+ text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20)
+ end_punc = time.time()
+ print("time cost punc: ", end_punc - beg_punc)
+
+ item = {'key': key, 'value': text_postprocessed_punc}
+ if text_postprocessed != "":
+ item['text_postprocessed'] = text_postprocessed
+ if time_stamp_postprocessed != "":
+ item['time_stamp'] = time_stamp_postprocessed
+
+ item['sentences'] = time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed)
+
+ asr_result_list.append(item)
+ finish_count += 1
+ # asr_utils.print_progress(finish_count / file_count)
+ if writer is not None:
+ # Write the result to each file
+ ibest_writer["token"][key] = " ".join(token)
+ ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+ ibest_writer["vad"][key] = "{}".format(vadsegments)
+ ibest_writer["text"][key] = " ".join(word_lists)
+ ibest_writer["text_with_punc"][key] = text_postprocessed_punc
+ if time_stamp_postprocessed is not None:
+ ibest_writer["time_stamp"][key] = "{}".format(time_stamp_postprocessed)
+
+ logging.info("decoding, utt: {}, predictions: {}".format(key, text_postprocessed_punc))
+ torch.cuda.empty_cache()
+ distribute_spk(asr_result_list[0]['sentences'], sv_output)
+        # NOTE(review): removed leftover debugger breakpoint (import pdb; pdb.set_trace()); also, the distribute_spk() result above is discarded — confirm intended output
return asr_result_list
return _forward
@@ -1684,6 +2030,8 @@
return inference_paraformer(**kwargs)
elif mode == "paraformer_streaming":
return inference_paraformer_online(**kwargs)
+ elif mode == "paraformer_vad_speaker":
+ return inference_paraformer_vad_speaker(**kwargs)
elif mode.startswith("paraformer_vad"):
return inference_paraformer_vad_punc(**kwargs)
elif mode == "mfcca":
@@ -1782,6 +2130,16 @@
help="VAD model parameter file",
)
group.add_argument(
+ "--punc_infer_config",
+ type=str,
+ help="PUNC infer configuration",
+ )
+ group.add_argument(
+ "--punc_model_file",
+ type=str,
+ help="PUNC model parameter file",
+ )
+ group.add_argument(
"--cmvn_file",
type=str,
help="Global CMVN file",
@@ -1797,6 +2155,11 @@
help="ASR model parameter file",
)
group.add_argument(
+ "--sv_model_file",
+ type=str,
+ help="SV model parameter file",
+ )
+ group.add_argument(
"--lm_train_config",
type=str,
help="LM training configuration",
diff --git a/funasr/utils/cluster_backend.py b/funasr/utils/cluster_backend.py
new file mode 100644
index 0000000..47b45d2
--- /dev/null
+++ b/funasr/utils/cluster_backend.py
@@ -0,0 +1,191 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+from typing import Any, Dict, Union
+
+import hdbscan
+import numpy as np
+import scipy
+import sklearn
+import umap
+from sklearn.cluster._kmeans import k_means
+from torch import nn
+
+
+class SpectralCluster:
+    r"""A spectral clustering method using unnormalized Laplacian of affinity matrix.
+ This implementation is adapted from https://github.com/speechbrain/speechbrain.
+ """
+
+ def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
+ self.min_num_spks = min_num_spks
+ self.max_num_spks = max_num_spks
+ self.pval = pval
+
+ def __call__(self, X, oracle_num=None):
+ # Similarity matrix computation
+ sim_mat = self.get_sim_mat(X)
+
+ # Refining similarity matrix with pval
+ prunned_sim_mat = self.p_pruning(sim_mat)
+
+ # Symmetrization
+ sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
+
+ # Laplacian calculation
+ laplacian = self.get_laplacian(sym_prund_sim_mat)
+
+ # Get Spectral Embeddings
+ emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
+
+ # Perform clustering
+ labels = self.cluster_embs(emb, num_of_spk)
+
+ return labels
+
+ def get_sim_mat(self, X):
+ # Cosine similarities
+ M = sklearn.metrics.pairwise.cosine_similarity(X, X)
+ return M
+
+ def p_pruning(self, A):
+ if A.shape[0] * self.pval < 6:
+ pval = 6. / A.shape[0]
+ else:
+ pval = self.pval
+
+ n_elems = int((1 - pval) * A.shape[0])
+
+ # For each row in a affinity matrix
+ for i in range(A.shape[0]):
+ low_indexes = np.argsort(A[i, :])
+ low_indexes = low_indexes[0:n_elems]
+
+ # Replace smaller similarity values by 0s
+ A[i, low_indexes] = 0
+ return A
+
+ def get_laplacian(self, M):
+ M[np.diag_indices(M.shape[0])] = 0
+ D = np.sum(np.abs(M), axis=1)
+ D = np.diag(D)
+ L = D - M
+ return L
+
+ def get_spec_embs(self, L, k_oracle=None):
+ lambdas, eig_vecs = scipy.linalg.eigh(L)
+
+ if k_oracle is not None:
+ num_of_spk = k_oracle
+ else:
+ lambda_gap_list = self.getEigenGaps(
+ lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
+ num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
+
+ emb = eig_vecs[:, :num_of_spk]
+ return emb, num_of_spk
+
+ def cluster_embs(self, emb, k):
+ _, labels, _ = k_means(emb, k)
+ return labels
+
+ def getEigenGaps(self, eig_vals):
+ eig_vals_gap_list = []
+ for i in range(len(eig_vals) - 1):
+ gap = float(eig_vals[i + 1]) - float(eig_vals[i])
+ eig_vals_gap_list.append(gap)
+ return eig_vals_gap_list
+
+
+class UmapHdbscan:
+ r"""
+ Reference:
+ - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
+ Emphasis On Topological Structure. ICASSP2022
+ """
+
+ def __init__(self,
+ n_neighbors=20,
+ n_components=60,
+ min_samples=10,
+ min_cluster_size=10,
+ metric='cosine'):
+ self.n_neighbors = n_neighbors
+ self.n_components = n_components
+ self.min_samples = min_samples
+ self.min_cluster_size = min_cluster_size
+ self.metric = metric
+
+ def __call__(self, X):
+ umap_X = umap.UMAP(
+ n_neighbors=self.n_neighbors,
+ min_dist=0.0,
+ n_components=min(self.n_components, X.shape[0] - 2),
+ metric=self.metric,
+ ).fit_transform(X)
+ labels = hdbscan.HDBSCAN(
+ min_samples=self.min_samples,
+ min_cluster_size=self.min_cluster_size,
+ allow_single_cluster=True).fit_predict(umap_X)
+ return labels
+
+
+class ClusterBackend(nn.Module):
+    r"""Perform clustering for input embeddings and output the labels.
+ Args:
+ model_dir: A model dir.
+ model_config: The model config.
+ """
+
+ def __init__(self):
+ super().__init__()
+ self.model_config = {'merge_thr':0.78}
+ # self.other_config = kwargs
+
+ self.spectral_cluster = SpectralCluster()
+ self.umap_hdbscan_cluster = UmapHdbscan()
+
+ def forward(self, X, **params):
+ # clustering and return the labels
+ k = params['oracle_num'] if 'oracle_num' in params else None
+ assert len(
+ X.shape
+ ) == 2, 'modelscope error: the shape of input should be [N, C]'
+ if X.shape[0] < 20:
+ return np.zeros(X.shape[0], dtype='int')
+ if X.shape[0] < 2048 or k is not None:
+ labels = self.spectral_cluster(X, k)
+ else:
+ labels = self.umap_hdbscan_cluster(X)
+
+ if k is None and 'merge_thr' in self.model_config:
+ labels = self.merge_by_cos(labels, X,
+ self.model_config['merge_thr'])
+
+ return labels
+
+ def merge_by_cos(self, labels, embs, cos_thr):
+ # merge the similar speakers by cosine similarity
+ assert cos_thr > 0 and cos_thr <= 1
+ while True:
+ spk_num = labels.max() + 1
+ if spk_num == 1:
+ break
+ spk_center = []
+ for i in range(spk_num):
+ spk_emb = embs[labels == i].mean(0)
+ spk_center.append(spk_emb)
+ assert len(spk_center) > 0
+ spk_center = np.stack(spk_center, axis=0)
+ norm_spk_center = spk_center / np.linalg.norm(
+ spk_center, axis=1, keepdims=True)
+ affinity = np.matmul(norm_spk_center, norm_spk_center.T)
+ affinity = np.triu(affinity, 1)
+ spks = np.unravel_index(np.argmax(affinity), affinity.shape)
+ if affinity[spks] < cos_thr:
+ break
+ for i in range(len(labels)):
+ if labels[i] == spks[1]:
+ labels[i] = spks[0]
+ elif labels[i] > spks[1]:
+ labels[i] -= 1
+ return labels
diff --git a/funasr/utils/modelscope_file.py b/funasr/utils/modelscope_file.py
new file mode 100644
index 0000000..d93f24c
--- /dev/null
+++ b/funasr/utils/modelscope_file.py
@@ -0,0 +1,328 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+
+import contextlib
+import os
+import tempfile
+from abc import ABCMeta, abstractmethod
+from pathlib import Path
+from typing import Generator, Union
+
+import requests
+
+
+class Storage(metaclass=ABCMeta):
+ """Abstract class of storage.
+
+ All backends need to implement two apis: ``read()`` and ``read_text()``.
+ ``read()`` reads the file as a byte stream and ``read_text()`` reads
+ the file as texts.
+ """
+
+ @abstractmethod
+ def read(self, filepath: str):
+ pass
+
+ @abstractmethod
+ def read_text(self, filepath: str):
+ pass
+
+ @abstractmethod
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ pass
+
+ @abstractmethod
+ def write_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ pass
+
+
+class LocalStorage(Storage):
+ """Local hard disk storage"""
+
+ def read(self, filepath: Union[str, Path]) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ with open(filepath, 'rb') as f:
+ content = f.read()
+ return content
+
+ def read_text(self,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ with open(filepath, 'r', encoding=encoding) as f:
+ value_buf = f.read()
+ return value_buf
+
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``write`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ dirname = os.path.dirname(filepath)
+ if dirname and not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+
+ with open(filepath, 'wb') as f:
+ f.write(obj)
+
+ def write_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``write_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ dirname = os.path.dirname(filepath)
+ if dirname and not os.path.exists(dirname):
+ os.makedirs(dirname, exist_ok=True)
+
+ with open(filepath, 'w', encoding=encoding) as f:
+ f.write(obj)
+
+ @contextlib.contextmanager
+ def as_local_path(
+ self,
+ filepath: Union[str,
+ Path]) -> Generator[Union[str, Path], None, None]:
+ """Only for unified API and do nothing."""
+ yield filepath
+
+
+class HTTPStorage(Storage):
+ """HTTP and HTTPS storage."""
+
+ def read(self, url):
+ # TODO @wenmeng.zwm add progress bar if file is too large
+ r = requests.get(url)
+ r.raise_for_status()
+ return r.content
+
+ def read_text(self, url):
+ r = requests.get(url)
+ r.raise_for_status()
+ return r.text
+
+ @contextlib.contextmanager
+ def as_local_path(
+ self, filepath: str) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath``.
+
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> storage = HTTPStorage()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.as_local_path('http://path/to/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.read(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def write(self, obj: bytes, url: Union[str, Path]) -> None:
+ raise NotImplementedError('write is not supported by HTTP Storage')
+
+ def write_text(self,
+ obj: str,
+ url: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ raise NotImplementedError(
+ 'write_text is not supported by HTTP Storage')
+
+
+class OSSStorage(Storage):
+ """OSS storage."""
+
+ def __init__(self, oss_config_file=None):
+ # read from config file or env var
+ raise NotImplementedError(
+ 'OSSStorage.__init__ to be implemented in the future')
+
+ def read(self, filepath):
+ raise NotImplementedError(
+ 'OSSStorage.read to be implemented in the future')
+
+ def read_text(self, filepath, encoding='utf-8'):
+ raise NotImplementedError(
+ 'OSSStorage.read_text to be implemented in the future')
+
+ @contextlib.contextmanager
+ def as_local_path(
+ self, filepath: str) -> Generator[Union[str, Path], None, None]:
+ """Download a file from ``filepath``.
+
+ ``as_local_path`` is decorated by :meth:`contextlib.contextmanager`. It
+        can be called with ``with`` statement, and when exiting from the
+        ``with`` statement, the temporary path will be released.
+
+ Args:
+ filepath (str): Download a file from ``filepath``.
+
+ Examples:
+ >>> storage = OSSStorage()
+            >>> # After exiting from the ``with`` clause,
+            >>> # the path will be removed
+            >>> with storage.as_local_path('http://path/to/file') as path:
+ ... # do something here
+ """
+ try:
+ f = tempfile.NamedTemporaryFile(delete=False)
+ f.write(self.read(filepath))
+ f.close()
+ yield f.name
+ finally:
+ os.remove(f.name)
+
+ def write(self, obj: bytes, filepath: Union[str, Path]) -> None:
+ raise NotImplementedError(
+ 'OSSStorage.write to be implemented in the future')
+
+ def write_text(self,
+ obj: str,
+ filepath: Union[str, Path],
+ encoding: str = 'utf-8') -> None:
+ raise NotImplementedError(
+ 'OSSStorage.write_text to be implemented in the future')
+
+
+G_STORAGES = {}
+
+
+class File(object):
+ _prefix_to_storage: dict = {
+ 'oss': OSSStorage,
+ 'http': HTTPStorage,
+ 'https': HTTPStorage,
+ 'local': LocalStorage,
+ }
+
+ @staticmethod
+ def _get_storage(uri):
+ assert isinstance(uri,
+ str), f'uri should be str type, but got {type(uri)}'
+
+ if '://' not in uri:
+ # local path
+ storage_type = 'local'
+ else:
+ prefix, _ = uri.split('://')
+ storage_type = prefix
+
+ assert storage_type in File._prefix_to_storage, \
+ f'Unsupported uri {uri}, valid prefixs: '\
+ f'{list(File._prefix_to_storage.keys())}'
+
+ if storage_type not in G_STORAGES:
+ G_STORAGES[storage_type] = File._prefix_to_storage[storage_type]()
+
+ return G_STORAGES[storage_type]
+
+ @staticmethod
+ def read(uri: str) -> bytes:
+ """Read data from a given ``filepath`` with 'rb' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+
+ Returns:
+ bytes: Expected bytes object.
+ """
+ storage = File._get_storage(uri)
+ return storage.read(uri)
+
+ @staticmethod
+ def read_text(uri: Union[str, Path], encoding: str = 'utf-8') -> str:
+ """Read data from a given ``filepath`` with 'r' mode.
+
+ Args:
+ filepath (str or Path): Path to read data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+
+ Returns:
+ str: Expected text reading from ``filepath``.
+ """
+ storage = File._get_storage(uri)
+ return storage.read_text(uri)
+
+ @staticmethod
+ def write(obj: bytes, uri: Union[str, Path]) -> None:
+ """Write data to a given ``filepath`` with 'wb' mode.
+
+ Note:
+ ``write`` will create a directory if the directory of ``filepath``
+ does not exist.
+
+ Args:
+ obj (bytes): Data to be written.
+ filepath (str or Path): Path to write data.
+ """
+ storage = File._get_storage(uri)
+ return storage.write(obj, uri)
+
+ @staticmethod
+ def write_text(obj: str, uri: str, encoding: str = 'utf-8') -> None:
+ """Write data to a given ``filepath`` with 'w' mode.
+
+ Note:
+ ``write_text`` will create a directory if the directory of
+ ``filepath`` does not exist.
+
+ Args:
+ obj (str): Data to be written.
+ filepath (str or Path): Path to write data.
+ encoding (str): The encoding format used to open the ``filepath``.
+ Default: 'utf-8'.
+ """
+ storage = File._get_storage(uri)
+ return storage.write_text(obj, uri)
+
+ @contextlib.contextmanager
+ def as_local_path(uri: str) -> Generator[Union[str, Path], None, None]:
+ """Only for unified API and do nothing."""
+ storage = File._get_storage(uri)
+ with storage.as_local_path(uri) as local_path:
+ yield local_path
diff --git a/funasr/utils/speaker_utils.py b/funasr/utils/speaker_utils.py
new file mode 100644
index 0000000..a1c610f
--- /dev/null
+++ b/funasr/utils/speaker_utils.py
@@ -0,0 +1,592 @@
+# Copyright (c) Alibaba, Inc. and its affiliates.
+""" Some implementations are adapted from https://github.com/yuyq96/D-TDNN
+"""
+
+import torch
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from torch import nn
+
+import io
+import os
+from typing import Any, Dict, List, Union
+
+import numpy as np
+import soundfile as sf
+import torch
+import torchaudio
+import logging
+from funasr.utils.modelscope_file import File
+from collections import OrderedDict
+import torchaudio.compliance.kaldi as Kaldi
+
+
def check_audio_list(audio: list, fs: int = 16000, min_dur: float = 5):
    """Validate a list of VAD audio segments.

    Each element of ``audio`` must be ``[start_s, end_s, samples]`` where
    ``samples`` is a 1-D ``np.ndarray`` sampled at ``fs`` Hz. Segments must
    be chronological and non-overlapping, each sample array must match its
    time span, and the total speech duration must exceed ``min_dur``.

    Args:
        audio (list): Segments as described above.
        fs (int): Sampling rate in Hz. Default 16000 (previous hard-coded
            value, so existing callers are unaffected).
        min_dur (float): Minimum total duration in seconds. Default 5.

    Raises:
        AssertionError: On any violated invariant (messages unchanged so
            existing callers matching on them still work).
    """
    audio_dur = 0
    for i in range(len(audio)):
        seg = audio[i]
        assert seg[1] >= seg[0], 'modelscope error: Wrong time stamps.'
        assert isinstance(seg[2], np.ndarray), 'modelscope error: Wrong data type.'
        # Sample count must agree with the (end - start) span at rate fs.
        assert int(seg[1] * fs) - int(
            seg[0] * fs
        ) == seg[2].shape[
            0], 'modelscope error: audio data in list is inconsistent with time length.'
        if i > 0:
            # Segments must not overlap and must be in chronological order.
            assert seg[0] >= audio[
                i - 1][1], 'modelscope error: Wrong time stamps.'
        audio_dur += seg[1] - seg[0]
    assert audio_dur > min_dur, 'modelscope error: The effective audio duration is too short.'
+
+
def sv_preprocess(inputs: Union[np.ndarray, list]):
    """Normalize raw speaker-verification inputs to float32 torch tensors.

    Each element may be a file path/URI (read via ``File`` and decoded with
    soundfile; only the first channel of stereo audio is kept) or a 1-D
    ``np.ndarray``. Integer arrays are scaled by 1/2**15 to floats.

    NOTE(review): int32/int64 arrays are also divided by 2**15, not 2**31 —
    preserved as-is here; confirm upstream whether that is intentional.

    Args:
        inputs (list): Audio sources as described above.

    Returns:
        list: One 1-D float32 ``torch.Tensor`` per input.

    Raises:
        ValueError: If an element is neither a path string nor an ndarray.
    """
    output = []
    for item in inputs:
        if isinstance(item, str):
            raw = File.read(item)
            samples, fs = sf.read(io.BytesIO(raw), dtype='float32')
            if len(samples.shape) == 2:
                samples = samples[:, 0]
            data = torch.from_numpy(samples)
        elif isinstance(item, np.ndarray):
            assert len(
                item.shape
            ) == 1, 'modelscope error: Input array should be [N, T]'
            if item.dtype in ['int16', 'int32', 'int64']:
                scaled = (item / (1 << 15)).astype('float32')
            else:
                scaled = item.astype('float32')
            data = torch.from_numpy(scaled)
        else:
            raise ValueError(
                'modelscope error: The input type is restricted to audio address and nump array.'
            )
        output.append(data)
    return output
+
+
def sv_chunk(vad_segments: list, fs=16000, seg_dur: float = 1.5,
             seg_shift: float = 0.75) -> list:
    """Split VAD segments into fixed-length, overlapping chunks.

    Args:
        vad_segments (list): Items of the form ``[start_s, end_s, samples]``
            with ``samples`` a 1-D ``np.ndarray`` at rate ``fs``.
        fs (int): Sampling rate in Hz. Default 16000.
        seg_dur (float): Chunk duration in seconds. Default 1.5 (previously
            hard-coded, so default behavior is unchanged).
        seg_shift (float): Hop between chunk starts in seconds. Default 0.75.

    Returns:
        list: ``[chunk_start_s, chunk_end_s, chunk_samples]`` triples with
        absolute times (offset by each segment's start). The last chunk of a
        segment is left-extended to full length when possible and
        zero-padded otherwise.
    """
    chunk_len = int(seg_dur * fs)
    chunk_shift = int(seg_shift * fs)

    def seg_chunk(seg_data):
        seg_st = seg_data[0]
        data = seg_data[2]
        last_chunk_ed = 0
        seg_res = []
        for chunk_st in range(0, data.shape[0], chunk_shift):
            chunk_ed = min(chunk_st + chunk_len, data.shape[0])
            if chunk_ed <= last_chunk_ed:
                # Any further start would only repeat already-covered samples.
                break
            last_chunk_ed = chunk_ed
            # Anchor on the end so the final chunk keeps full length if it can.
            chunk_st = max(0, chunk_ed - chunk_len)
            chunk_data = data[chunk_st:chunk_ed]
            if chunk_data.shape[0] < chunk_len:
                chunk_data = np.pad(chunk_data,
                                    (0, chunk_len - chunk_data.shape[0]),
                                    'constant')
            seg_res.append([
                chunk_st / fs + seg_st, chunk_ed / fs + seg_st,
                chunk_data
            ])
        return seg_res

    segs = []
    for s in vad_segments:
        segs.extend(seg_chunk(s))

    return segs
+
+
class BasicResBlock(nn.Module):
    """2-D residual block whose stride applies only to the first spatial
    (frequency) axis, leaving the time axis length unchanged.

    Two 3x3 convs with batch norm; the skip path is the identity unless the
    stride or channel count changes, in which case it is a 1x1 projection.
    """

    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicResBlock, self).__init__()
        # NOTE: submodule creation order is kept stable (conv1, bn1, conv2,
        # bn2, shortcut) so seeded initialization is reproducible.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3,
                               stride=(stride, 1), padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        needs_projection = stride != 1 or in_planes != self.expansion * planes
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
                          stride=(stride, 1), bias=False),
                nn.BatchNorm2d(self.expansion * planes))
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        y = F.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        return F.relu(y + residual)
+
+
class FCM(nn.Module):
    """Front-end 2-D convolution module of CAM++.

    Treats the (feature, time) plane as a 1-channel image: a stem conv, two
    residual stages that halve the feature axis, then a final conv with
    stride (2, 1). Channel and feature axes are flattened into one on output.

    Args:
        block: Residual block class (default ``BasicResBlock``).
        num_blocks (tuple): Blocks per residual stage (two stages).
        m_channels (int): Channel width used throughout.
        feat_dim (int): Input feature dimension; reduced by a factor of 8
            across the three stride-2 steps, hence ``out_channels``
            is ``m_channels * (feat_dim // 8)``.
    """

    def __init__(self,
                 block=BasicResBlock,
                 num_blocks=(2, 2),
                 m_channels=32,
                 feat_dim=80):
        super(FCM, self).__init__()
        self.in_planes = m_channels
        self.conv1 = nn.Conv2d(
            1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(m_channels)

        self.layer1 = self._make_layer(
            block, m_channels, num_blocks[0], stride=2)
        # BUG FIX: this previously read num_blocks[0]; identical for the
        # default (2, 2), but a custom num_blocks was silently ignored for
        # the second stage.
        self.layer2 = self._make_layer(
            block, m_channels, num_blocks[1], stride=2)

        self.conv2 = nn.Conv2d(
            m_channels,
            m_channels,
            kernel_size=3,
            stride=(2, 1),
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(m_channels)
        self.out_channels = m_channels * (feat_dim // 8)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one is strided."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_planes, planes, stride))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """(B, F, T) features -> (B, out_channels, T)."""
        x = x.unsqueeze(1)  # add the image-channel axis: (B, 1, F, T)
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = F.relu(self.bn2(self.conv2(out)))

        # Fold channels and the reduced feature axis together.
        shape = out.shape
        out = out.reshape(shape[0], shape[1] * shape[2], shape[3])
        return out
+
+
class CAMPPlus(nn.Module):
    """CAM++ speaker-embedding backbone: FCM front-end + dense-TDNN x-vector.

    Args:
        feat_dim (int): Input feature dimension (e.g. fbank bins).
        embedding_size (int): Output embedding size when
            ``output_level='segment'``.
        growth_rate (int): Channels each dense TDNN layer appends.
        bn_size (int): Bottleneck multiplier inside dense layers.
        init_channels (int): Channels after the initial TDNN layer.
        config_str (str): Nonlinearity spec consumed by ``get_nonlinear``.
        memory_efficient (bool): Use gradient checkpointing in dense blocks.
        output_level (str): 'segment' -> one embedding per utterance;
            'frame' -> per-frame features.
    """

    def __init__(self,
                 feat_dim=80,
                 embedding_size=192,
                 growth_rate=32,
                 bn_size=4,
                 init_channels=128,
                 config_str='batchnorm-relu',
                 memory_efficient=True,
                 output_level='segment'):
        super(CAMPPlus, self).__init__()

        self.head = FCM(feat_dim=feat_dim)
        channels = self.head.out_channels
        self.output_level = output_level

        # Initial TDNN: stride 2 halves the time axis; padding=-1 requests
        # "same"-style padding inside TDNNLayer.
        self.xvector = nn.Sequential(
            OrderedDict([
                ('tdnn',
                 TDNNLayer(
                     channels,
                     init_channels,
                     5,
                     stride=2,
                     dilation=1,
                     padding=-1,
                     config_str=config_str)),
            ]))
        channels = init_channels
        # Three dense blocks (12/24/16 layers); each is followed by a transit
        # layer that halves the channel count.
        for i, (num_layers, kernel_size, dilation) in enumerate(
                zip((12, 24, 16), (3, 3, 3), (1, 2, 2))):
            block = CAMDenseTDNNBlock(
                num_layers=num_layers,
                in_channels=channels,
                out_channels=growth_rate,
                bn_channels=bn_size * growth_rate,
                kernel_size=kernel_size,
                dilation=dilation,
                config_str=config_str,
                memory_efficient=memory_efficient)
            self.xvector.add_module('block%d' % (i + 1), block)
            channels = channels + num_layers * growth_rate
            self.xvector.add_module(
                'transit%d' % (i + 1),
                TransitLayer(
                    channels, channels // 2, bias=False,
                    config_str=config_str))
            channels //= 2

        self.xvector.add_module('out_nonlinear',
                                get_nonlinear(config_str, channels))

        if self.output_level == 'segment':
            # Statistics pooling doubles the channel count (mean ++ std).
            self.xvector.add_module('stats', StatsPool())
            self.xvector.add_module(
                'dense',
                DenseLayer(
                    channels * 2, embedding_size, config_str='batchnorm_'))
        else:
            assert self.output_level == 'frame', '`output_level` should be set to \'segment\' or \'frame\'. '

        # Re-initialize all conv/linear weights (overrides PyTorch defaults).
        for m in self.modules():
            if isinstance(m, (nn.Conv1d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight.data)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        """(B, T, F) features -> (B, embedding_size) for 'segment' output,
        or (B, T', C) frame-level features for 'frame' output."""
        x = x.permute(0, 2, 1)  # (B,T,F) => (B,F,T)
        x = self.head(x)
        x = self.xvector(x)
        if self.output_level == 'frame':
            x = x.transpose(1, 2)
        return x
+
+
def get_nonlinear(config_str, channels):
    """Build a Sequential of activations/norms from a '-'-separated spec.

    Supported tokens: 'relu', 'prelu', 'batchnorm', and 'batchnorm_' (a
    BatchNorm1d without affine parameters, registered under the module
    name 'batchnorm').

    Raises:
        ValueError: For any unrecognized token.
    """
    factories = {
        'relu': lambda: nn.ReLU(inplace=True),
        'prelu': lambda: nn.PReLU(channels),
        'batchnorm': lambda: nn.BatchNorm1d(channels),
        'batchnorm_': lambda: nn.BatchNorm1d(channels, affine=False),
    }
    nonlinear = nn.Sequential()
    for name in config_str.split('-'):
        if name not in factories:
            raise ValueError('Unexpected module ({}).'.format(name))
        module_name = 'batchnorm' if name == 'batchnorm_' else name
        nonlinear.add_module(module_name, factories[name]())
    return nonlinear
+
+
def statistics_pooling(x, dim=-1, keepdim=False, unbiased=True, eps=1e-2):
    """Concatenate the mean and standard deviation of ``x`` along ``dim``.

    NOTE(review): ``eps`` is accepted but never used in this implementation —
    confirm whether a variance floor was intended upstream.
    """
    pooled = torch.cat(
        [x.mean(dim=dim), x.std(dim=dim, unbiased=unbiased)], dim=-1)
    if keepdim:
        pooled = pooled.unsqueeze(dim=dim)
    return pooled
+
+
class StatsPool(nn.Module):
    """Temporal statistics pooling: concat of mean and (unbiased) std over
    the last axis, via the module-level ``statistics_pooling`` helper."""

    def forward(self, x):
        pooled = statistics_pooling(x)
        return pooled
+
+
class TDNNLayer(nn.Module):
    """TDNN layer: a 1-D convolution followed by the configured
    nonlinearity stack (see ``get_nonlinear``).

    A negative ``padding`` requests "same"-style padding, which requires an
    odd kernel size.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(TDNNLayer, self).__init__()
        if padding < 0:
            assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
                kernel_size)
            padding = (kernel_size - 1) // 2 * dilation
        self.linear = nn.Conv1d(in_channels,
                                out_channels,
                                kernel_size,
                                stride=stride,
                                padding=padding,
                                dilation=dilation,
                                bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        return self.nonlinear(self.linear(x))
+
+
def extract_feature(audio):
    """Compute 80-bin fbank features for each waveform in ``audio``.

    Each feature matrix is mean-normalized over time (cepstral mean
    subtraction) before the per-utterance results are concatenated along
    a new leading batch axis.

    Args:
        audio: Iterable of 1-D waveform tensors of equal length.

    Returns:
        torch.Tensor: Stacked features, shape (len(audio), frames, 80).
    """
    feats = []
    for wav in audio:
        fbank = Kaldi.fbank(wav.unsqueeze(0), num_mel_bins=80)
        normalized = fbank - fbank.mean(dim=0, keepdim=True)
        feats.append(normalized.unsqueeze(0))
    return torch.cat(feats)
+
+
class CAMLayer(nn.Module):
    """Context-aware masking conv layer.

    A local Conv1d output is gated by a sigmoid mask computed from a global
    context (mean over time plus a segment-pooled summary) passed through a
    channel-reduction bottleneck.
    """

    def __init__(self,
                 bn_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 padding,
                 dilation,
                 bias,
                 reduction=2):
        super(CAMLayer, self).__init__()
        # NOTE: submodule creation order kept stable for seeded init.
        self.linear_local = nn.Conv1d(bn_channels,
                                      out_channels,
                                      kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      dilation=dilation,
                                      bias=bias)
        self.linear1 = nn.Conv1d(bn_channels, bn_channels // reduction, 1)
        self.relu = nn.ReLU(inplace=True)
        self.linear2 = nn.Conv1d(bn_channels // reduction, out_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        local_out = self.linear_local(x)
        context = x.mean(-1, keepdim=True) + self.seg_pooling(x)
        gate = self.sigmoid(self.linear2(self.relu(self.linear1(context))))
        return local_out * gate

    def seg_pooling(self, x, seg_len=100, stype='avg'):
        """Pool over fixed-size segments, then broadcast each pooled value
        back over its segment so the output matches x's time length."""
        pool_fns = {'avg': F.avg_pool1d, 'max': F.max_pool1d}
        if stype not in pool_fns:
            raise ValueError('Wrong segment pooling type.')
        seg = pool_fns[stype](
            x, kernel_size=seg_len, stride=seg_len, ceil_mode=True)
        shape = seg.shape
        # Repeat each pooled frame seg_len times, then trim to x's length.
        seg = seg.unsqueeze(-1).expand(*shape, seg_len).reshape(*shape[:-1], -1)
        return seg[..., :x.shape[-1]]
+
+
class CAMDenseTDNNLayer(nn.Module):
    """One densely-connected TDNN layer with context-aware masking.

    Pipeline: nonlinear1 -> 1x1 bottleneck conv -> nonlinear2 -> CAMLayer.
    With ``memory_efficient=True`` the bottleneck is gradient-checkpointed
    during training to trade compute for activation memory.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNLayer, self).__init__()
        # "Same" padding requires an odd kernel.
        assert kernel_size % 2 == 1, 'Expect equal paddings, but got even kernel size ({})'.format(
            kernel_size)
        padding = (kernel_size - 1) // 2 * dilation
        self.memory_efficient = memory_efficient
        self.nonlinear1 = get_nonlinear(config_str, in_channels)
        self.linear1 = nn.Conv1d(in_channels, bn_channels, 1, bias=False)
        self.nonlinear2 = get_nonlinear(config_str, bn_channels)
        self.cam_layer = CAMLayer(
            bn_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

    def bn_function(self, x):
        # Bottleneck: pre-activation followed by the 1x1 projection.
        return self.linear1(self.nonlinear1(x))

    def forward(self, x):
        # Checkpoint only while training; at eval time rerun cost would be
        # pure overhead with no memory benefit.
        if self.training and self.memory_efficient:
            x = cp.checkpoint(self.bn_function, x)
        else:
            x = self.bn_function(x)
        x = self.cam_layer(self.nonlinear2(x))
        return x
+
+
class CAMDenseTDNNBlock(nn.ModuleList):
    """Densely connected stack of ``CAMDenseTDNNLayer`` modules.

    Layer ``i`` receives the concatenation of the block input and all
    previous layers' outputs, so its input width is
    ``in_channels + i * out_channels``; the block output width is
    ``in_channels + num_layers * out_channels``.
    """

    def __init__(self,
                 num_layers,
                 in_channels,
                 out_channels,
                 bn_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 bias=False,
                 config_str='batchnorm-relu',
                 memory_efficient=False):
        super(CAMDenseTDNNBlock, self).__init__()
        for idx in range(num_layers):
            self.add_module(
                'tdnnd%d' % (idx + 1),
                CAMDenseTDNNLayer(
                    in_channels=in_channels + idx * out_channels,
                    out_channels=out_channels,
                    bn_channels=bn_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    dilation=dilation,
                    bias=bias,
                    config_str=config_str,
                    memory_efficient=memory_efficient))

    def forward(self, x):
        # Dense connectivity: append each layer's output to the running
        # feature map along the channel axis.
        for layer in self:
            x = torch.cat([x, layer(x)], dim=1)
        return x
+
+
class TransitLayer(nn.Module):
    """Transition layer: pre-activation followed by a 1x1 Conv1d that
    changes the channel count (used to halve channels between dense
    blocks)."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=True,
                 config_str='batchnorm-relu'):
        super(TransitLayer, self).__init__()
        self.nonlinear = get_nonlinear(config_str, in_channels)
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)

    def forward(self, x):
        return self.linear(self.nonlinear(x))
+
+
class DenseLayer(nn.Module):
    """1x1 Conv1d projection followed by the configured nonlinearity.

    Accepts either (B, C) vectors or (B, C, T) sequences; 2-D input is
    given a temporary singleton time axis for the convolution.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 bias=False,
                 config_str='batchnorm-relu'):
        super(DenseLayer, self).__init__()
        self.linear = nn.Conv1d(in_channels, out_channels, 1, bias=bias)
        self.nonlinear = get_nonlinear(config_str, out_channels)

    def forward(self, x):
        if len(x.shape) == 2:
            projected = self.linear(x.unsqueeze(dim=-1)).squeeze(dim=-1)
        else:
            projected = self.linear(x)
        return self.nonlinear(projected)
+
def postprocess(segments: list, vad_segments: list,
                labels: np.ndarray, embeddings: np.ndarray) -> list:
    """Turn per-chunk cluster labels into a smoothed speaker timeline.

    Args:
        segments (list): Chunk triples ``[start_s, end_s, data]`` aligned
            one-to-one with ``labels``.
        vad_segments (list): NOTE(review): accepted but never used here —
            confirm whether it was meant to constrain the timeline.
        labels (np.ndarray): Cluster label per chunk; assumed non-empty
            (``labels.max()`` below fails on empty input).
        embeddings (np.ndarray): Per-chunk embeddings, row-aligned with
            ``labels``.

    Returns:
        list: ``[start_s, end_s, speaker_id]`` entries, merged, overlap-free
        and smoothed.
    """
    assert len(segments) == len(labels)
    # Relabel clusters to 0..K-1 in order of first appearance.
    labels = correct_labels(labels)
    distribute_res = []
    for i in range(len(segments)):
        distribute_res.append([segments[i][0], segments[i][1], labels[i]])
    # Merge the same speakers chronologically.
    distribute_res = merge_seque(distribute_res)

    # Acquire speaker centers (mean embedding per cluster).
    # NOTE(review): spk_embs is computed but never used or returned —
    # dead code unless a later revision consumes it.
    spk_embs = []
    for i in range(labels.max() + 1):
        spk_emb = embeddings[labels == i].mean(0)
        spk_embs.append(spk_emb)
    spk_embs = np.stack(spk_embs)

    def is_overlapped(t1, t2):
        # True when the previous end t1 exceeds the next start t2
        # (with a small tolerance).
        if t1 > t2 + 1e-4:
            return True
        return False

    # Distribute the overlap region: split it at the midpoint.
    for i in range(1, len(distribute_res)):
        if is_overlapped(distribute_res[i - 1][1], distribute_res[i][0]):
            p = (distribute_res[i][0] + distribute_res[i - 1][1]) / 2
            distribute_res[i][0] = p
            distribute_res[i - 1][1] = p

    # Smooth the result (reassign sub-second fragments to neighbors).
    distribute_res = smooth(distribute_res)

    return distribute_res
+
+
def correct_labels(labels):
    """Remap arbitrary cluster labels to 0..K-1 in order of first
    appearance, so downstream code can index clusters densely."""
    remap = {}
    relabeled = []
    for lab in labels:
        if lab not in remap:
            remap[lab] = len(remap)
        relabeled.append(remap[lab])
    return np.array(relabeled)
+
def merge_seque(distribute_res):
    """Merge adjacent entries with the same speaker into one span.

    Entries are ``[start, end, label]``; an entry is folded into its
    predecessor when the label matches and it starts at or before the
    predecessor's end.

    Note:
        Merged entries alias the input's inner lists — end times are
        mutated in place (original behavior, preserved).

    Args:
        distribute_res (list): Chronologically ordered span triples.

    Returns:
        list: Merged spans. Empty input yields ``[]`` (previously raised
        IndexError).
    """
    if not distribute_res:
        return []
    res = [distribute_res[0]]
    for item in distribute_res[1:]:
        # Start a new span on a label change or a temporal gap.
        if item[2] != res[-1][2] or item[0] > res[-1][1]:
            res.append(item)
        else:
            res[-1][1] = item[1]
    return res
+
def smooth(res, mindur=1):
    """Reassign segments shorter than ``mindur`` to the nearest neighbor's
    speaker, then re-merge.

    Times are rounded to 2 decimals first. A short segment takes the
    following speaker at the start of the list, the preceding speaker at the
    end, and otherwise whichever neighbor is closer in time (ties go to the
    preceding one).

    Args:
        res (list): ``[start, end, label]`` spans, mutated in place.
        mindur (float): Minimum duration in seconds. Default 1.

    Returns:
        list: Smoothed and merged spans.
    """
    for i in range(len(res)):
        res[i][0] = round(res[i][0], 2)
        res[i][1] = round(res[i][1], 2)
        # BUG FIX: guard len(res) > 1 — a lone short segment used to hit
        # res[i + 1] and raise IndexError; now it simply keeps its label.
        if res[i][1] - res[i][0] < mindur and len(res) > 1:
            if i == 0:
                res[i][2] = res[i + 1][2]
            elif i == len(res) - 1:
                res[i][2] = res[i - 1][2]
            elif res[i][0] - res[i - 1][1] <= res[i + 1][0] - res[i][1]:
                res[i][2] = res[i - 1][2]
            else:
                res[i][2] = res[i + 1][2]
    # Merge the (possibly newly matching) neighboring speakers.
    res = merge_seque(res)

    return res
+
+
def distribute_spk(sentence_list, sd_time_list):
    """Assign each sentence the speaker whose diarization segment overlaps
    it the most.

    Sentence timestamps (``ts_list``) are in milliseconds; diarization
    entries ``[start_s, end_s, spk]`` are in seconds and converted here.
    A sentence with no positive overlap keeps speaker 0; ties keep the
    first maximal speaker (strict-greater comparison, as before).

    Note: sentence dicts are mutated in place (a ``'spk'`` key is added)
    and also collected into the returned list.
    """
    labelled = []
    for sentence in sentence_list:
        start_ms = sentence['ts_list'][0][0]
        end_ms = sentence['ts_list'][-1][1]
        best_spk = 0
        best_overlap = 0
        for seg_start, seg_end, spk in sd_time_list:
            seg_start_ms = seg_start * 1000
            seg_end_ms = seg_end * 1000
            overlap = max(
                min(end_ms, seg_end_ms) - max(start_ms, seg_start_ms), 0)
            if overlap > best_overlap:
                best_overlap = overlap
                best_spk = spk
        sentence['spk'] = best_spk
        labelled.append(sentence)
    return labelled
\ No newline at end of file
--
Gitblit v1.9.1