shixian.shi
2024-01-11 70379713923e9938dcfcb3791b17f7b469233432
update asr with speaker
6个文件已修改
1个文件已添加
608 ■■■■■ 已修改文件
examples/industrial_data_pretraining/bicif_paraformer/demo.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/inference.py 134 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/campplus/cluster_backend.py 191 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/campplus/model.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/campplus/utils.py 40 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/ct_transformer/model.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/timestamp_tools.py 209 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/bicif_paraformer/demo.py
@@ -11,7 +11,23 @@
                    vad_model_revision="v2.0.0",
                    punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                    punc_model_revision="v2.0.0",
                    spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_vad_punc_example.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
'''try asr with speaker label with
model = AutoModel(model="damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    model_revision="v2.0.0",
                    vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    vad_model_revision="v2.0.0",
                    punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                    punc_model_revision="v2.0.0",
                    spk_model="/Users/shixian/code/modelscope_models/speech_campplus_sv_zh-cn_16k-common",
                    spk_mode='punc_segment',
                  )
res = model(input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_speaker_demo.wav", batch_size_s=300, batch_size_threshold_s=60)
print(res)
'''
funasr/bin/inference.py
@@ -1,26 +1,26 @@
import os.path
import torch
import numpy as np
import hydra
import json
from omegaconf import DictConfig, OmegaConf, ListConfig
import logging
from funasr.download.download_from_hub import download_model
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.utils.load_utils import load_bytes
from funasr.train_utils.device_funcs import to_device
from tqdm import tqdm
from funasr.train_utils.load_pretrained_model import load_pretrained_model
import time
import torch
import hydra
import random
import string
from funasr.register import tables
import logging
import os.path
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.timestamp_tools import time_stamp_sentence
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
from funasr.models.campplus.cluster_backend import ClusterBackend
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
@@ -126,13 +126,27 @@
            punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)
            
        # if spk_model is not None, build spk model else None
        spk_model = kwargs.get("spk_model", None)
        spk_kwargs = kwargs.get("spk_model_revision", None)
        if spk_model is not None:
            spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            self.cb_model = ClusterBackend()
            spk_mode = kwargs.get("spk_mode", 'punc_segment')
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
            self.spk_mode = spk_mode
            logging.warning("Many to print when using speaker model...")
        self.kwargs = kwargs
        self.model = model
        self.vad_model = vad_model
        self.vad_kwargs = vad_kwargs
        self.punc_model = punc_model
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        
    def build_model(self, **kwargs):
@@ -198,7 +212,6 @@
            return self.generate_with_vad(input, input_len=input_len, **cfg)
        
    def generate(self, input, input_len=None, model=None, kwargs=None, key=None, **cfg):
        # import pdb; pdb.set_trace()
        kwargs = self.kwargs if kwargs is None else kwargs
        kwargs.update(cfg)
        model = self.model if model is None else model
@@ -260,6 +273,7 @@
        kwargs.update(cfg)
        beg_vad = time.time()
        res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
        vad_res = res
        end_vad = time.time()
        print(f"time cost vad: {end_vad - beg_vad:0.3f}")
@@ -314,10 +328,20 @@
                batch_size_ms_cum = 0
                end_idx = j + 1
                speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
                beg_idx = end_idx
                results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
                if self.spk_model is not None:
                    all_segments = []
                    # compose vad segments: [[start_time_sec, end_time_sec, speech], [...]]
                    for _b in range(len(speech_j)):
                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0, \
                                        sorted_data[beg_idx:end_idx][_b][0][1]/1000.0, \
                                        speech_j[_b]]]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.generate(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
                        results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
                beg_idx = end_idx
                if len(results) < 1:
                    continue
                results_sorted.extend(results)
@@ -336,39 +360,63 @@
                restored_data[index] = results_sorted[j]
            result = {}
            
            # results combine for texts, timestamps, speaker embeddings and others
            # TODO: rewrite for clean code
            for j in range(n):
                for k, v in restored_data[j].items():
                    if not k.startswith("timestamp"):
                    if k.startswith("timestamp"):
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += restored_data[j][k]
                    else:
                        result[k] = []
                        for t in restored_data[j][k]:
                            t[0] += vadsegments[j][0]
                            t[1] += vadsegments[j][0]
                        result[k].extend(restored_data[j][k])
                    elif k == 'spk_embedding':
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
                    elif k == 'text':
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += " " + restored_data[j][k]
                    else:
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
                        result[k] += restored_data[j][k]
            # step.3 compute punc model
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.generate(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                result["text_with_punc"] = punc_res[0]["text"]
            # speaker embedding cluster after resorted
            if self.spk_model is not None:
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result['spk_embedding']
                labels = self.cb_model(spk_embedding)
                del result['spk_embedding']
                sv_output = postprocess(all_segments, None, labels, spk_embedding)
                if self.spk_mode == 'vad_segment':
                    sentence_list = []
                    for res, vadsegment in zip(restored_data, vadsegments):
                        sentence_list.append({"start": vadsegment[0],\
                                                "end": vadsegment[1],
                                                "sentence": res['text'],
                                                "timestamp": res['timestamp']})
                else: # punc_segment
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['text'])
                distribute_spk(sentence_list, sv_output)
                result['sentence_info'] = sentence_list
                        
            result["key"] = key
            results_ret_list.append(result)
            pbar_total.update(1)
        # step.3 compute punc model
        model = self.punc_model
        kwargs = self.punc_kwargs
        kwargs.update(cfg)
        for i, result in enumerate(results_ret_list):
            beg_punc = time.time()
            res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
            end_punc = time.time()
            print(f"time punc: {end_punc - beg_punc:0.3f}")
            # sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
            # results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
            # results_ret_list[i]["sentences"] = sentences
            results_ret_list[i]["text_with_punc"] = res[i]["text"]
        
        pbar_total.update(1)
        end_total = time.time()
funasr/models/campplus/cluster_backend.py
New file
@@ -0,0 +1,191 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Any, Dict, Union
import hdbscan
import numpy as np
import scipy
import sklearn
import umap
from sklearn.cluster._kmeans import k_means
from torch import nn
class SpectralCluster:
    r"""A spectral clustering method using the unnormalized Laplacian of an affinity matrix.

    This implementation is adapted from https://github.com/speechbrain/speechbrain.

    Args:
        min_num_spks: Lower bound of the cluster count searched by the eigen-gap heuristic.
        max_num_spks: Upper bound of the cluster count searched by the eigen-gap heuristic.
        pval: Fraction of entries kept in every affinity row while pruning
            (raised automatically so that about 6 entries per row survive).
    """

    def __init__(self, min_num_spks=1, max_num_spks=15, pval=0.022):
        self.min_num_spks = min_num_spks
        self.max_num_spks = max_num_spks
        self.pval = pval

    def __call__(self, X, oracle_num=None):
        """Cluster the rows of ``X`` and return one label per row.

        Args:
            X: Embedding matrix; presumably shaped (N, C) — confirm against caller.
            oracle_num: Known number of clusters; when ``None`` it is estimated
                from the largest eigen-gap.
        """
        # Similarity matrix computation
        sim_mat = self.get_sim_mat(X)
        # Refining similarity matrix with pval
        prunned_sim_mat = self.p_pruning(sim_mat)
        # Symmetrization
        sym_prund_sim_mat = 0.5 * (prunned_sim_mat + prunned_sim_mat.T)
        # Laplacian calculation
        laplacian = self.get_laplacian(sym_prund_sim_mat)
        # Get Spectral Embeddings
        emb, num_of_spk = self.get_spec_embs(laplacian, oracle_num)
        # Perform clustering
        labels = self.cluster_embs(emb, num_of_spk)
        return labels

    def get_sim_mat(self, X):
        """Return the pairwise cosine-similarity matrix of the rows of ``X``."""
        M = sklearn.metrics.pairwise.cosine_similarity(X, X)
        return M

    def p_pruning(self, A):
        """Zero the smallest entries of every row of ``A`` in place and return ``A``.

        Roughly a ``pval`` fraction of each row is kept. When that would keep
        fewer than 6 entries, the fraction is raised to ``6 / N``.
        """
        n_rows = A.shape[0]
        if n_rows == 0:
            # Nothing to prune; also avoids dividing by zero below.
            return A
        if n_rows * self.pval < 6:
            pval = 6. / n_rows
        else:
            pval = self.pval
        # With fewer than 6 rows ``pval`` exceeds 1 and the count would turn
        # negative; a negative slice bound would then wrongly prune entries,
        # so clamp to "prune nothing".
        n_elems = max(0, int((1 - pval) * n_rows))
        # For each row in the affinity matrix
        for i in range(n_rows):
            low_indexes = np.argsort(A[i, :])[:n_elems]
            # Replace smaller similarity values by 0s
            A[i, low_indexes] = 0
        return A

    def get_laplacian(self, M):
        """Return the unnormalized Laplacian ``D - M``; the diagonal of ``M`` is zeroed in place first."""
        M[np.diag_indices(M.shape[0])] = 0
        D = np.diag(np.sum(np.abs(M), axis=1))
        return D - M

    def get_spec_embs(self, L, k_oracle=None):
        """Return ``(embeddings, num_of_spk)`` built from the leading eigenvectors of ``L``.

        ``scipy.linalg.eigh`` returns eigenvalues in ascending order, so the first
        ``num_of_spk`` columns belong to the smallest eigenvalues.
        """
        lambdas, eig_vecs = scipy.linalg.eigh(L)
        if k_oracle is not None:
            num_of_spk = k_oracle
        else:
            # Pick the cluster count at the largest gap between consecutive
            # eigenvalues inside the [min_num_spks, max_num_spks] window.
            lambda_gap_list = self.getEigenGaps(
                lambdas[self.min_num_spks - 1:self.max_num_spks + 1])
            num_of_spk = np.argmax(lambda_gap_list) + self.min_num_spks
        emb = eig_vecs[:, :num_of_spk]
        return emb, num_of_spk

    def cluster_embs(self, emb, k):
        """Run k-means with ``k`` clusters on ``emb`` and return the labels."""
        _, labels, _ = k_means(emb, k)
        return labels

    def getEigenGaps(self, eig_vals):
        """Return the list of differences between consecutive values of ``eig_vals``."""
        return [
            float(eig_vals[i + 1]) - float(eig_vals[i])
            for i in range(len(eig_vals) - 1)
        ]
class UmapHdbscan:
    r"""Cluster embeddings by UMAP projection followed by HDBSCAN.

    Reference:
    - Siqi Zheng, Hongbin Suo. Reformulating Speaker Diarization as Community Detection With
      Emphasis On Topological Structure. ICASSP2022

    Args:
        n_neighbors: ``n_neighbors`` passed to ``umap.UMAP``.
        n_components: Upper bound for the UMAP output dimension.
        min_samples: ``min_samples`` passed to ``hdbscan.HDBSCAN``.
        min_cluster_size: ``min_cluster_size`` passed to ``hdbscan.HDBSCAN``.
        metric: Distance metric passed to ``umap.UMAP``.
    """

    def __init__(self,
                 n_neighbors=20,
                 n_components=60,
                 min_samples=10,
                 min_cluster_size=10,
                 metric='cosine'):
        # UMAP settings
        self.n_neighbors = n_neighbors
        self.n_components = n_components
        # HDBSCAN settings
        self.min_samples = min_samples
        self.min_cluster_size = min_cluster_size
        self.metric = metric

    def __call__(self, X):
        """Project ``X`` with UMAP, cluster the projection and return the labels."""
        # The projected dimension is capped at (number of rows - 2).
        target_dim = min(self.n_components, X.shape[0] - 2)
        reducer = umap.UMAP(
            n_neighbors=self.n_neighbors,
            min_dist=0.0,
            n_components=target_dim,
            metric=self.metric,
        )
        projected = reducer.fit_transform(X)
        clusterer = hdbscan.HDBSCAN(
            min_samples=self.min_samples,
            min_cluster_size=self.min_cluster_size,
            allow_single_cluster=True)
        return clusterer.fit_predict(projected)
class ClusterBackend(nn.Module):
    r"""Perform clustering for input embeddings and output the labels.

    The constructor takes no arguments; the merge threshold is fixed in
    ``self.model_config['merge_thr']``.
    """

    def __init__(self):
        super().__init__()
        self.model_config = {'merge_thr':0.78}
        self.spectral_cluster = SpectralCluster()
        self.umap_hdbscan_cluster = UmapHdbscan()

    def forward(self, X, **params):
        """Return one cluster label per row of the 2-D input ``X``.

        Fewer than 20 rows are all assigned label 0. Spectral clustering is used
        for fewer than 2048 rows or when ``oracle_num`` is given; otherwise
        UMAP+HDBSCAN. Without ``oracle_num`` similar clusters are merged afterwards.
        """
        oracle_num = params.get('oracle_num')
        assert len(
            X.shape
        ) == 2, 'modelscope error: the shape of input should be [N, C]'

        n_rows = X.shape[0]
        if n_rows < 20:
            return np.zeros(n_rows, dtype='int')

        if n_rows < 2048 or oracle_num is not None:
            labels = self.spectral_cluster(X, oracle_num)
        else:
            labels = self.umap_hdbscan_cluster(X)

        if oracle_num is None and 'merge_thr' in self.model_config:
            labels = self.merge_by_cos(labels, X,
                                       self.model_config['merge_thr'])
        return labels

    def merge_by_cos(self, labels, embs, cos_thr):
        """Repeatedly merge the two most similar clusters until none reaches ``cos_thr``.

        ``labels`` is modified in place and returned; labels stay contiguous
        because every label above the removed one is shifted down by one.
        """
        assert cos_thr > 0 and cos_thr <= 1
        while True:
            spk_num = labels.max() + 1
            if spk_num == 1:
                break
            # Mean embedding of every cluster.
            centers = [embs[labels == spk].mean(0) for spk in range(spk_num)]
            assert len(centers) > 0
            centers = np.stack(centers, axis=0)
            unit_centers = centers / np.linalg.norm(
                centers, axis=1, keepdims=True)
            # Keep the strict upper triangle so each pair is considered once.
            affinity = np.triu(np.matmul(unit_centers, unit_centers.T), 1)
            keep, drop = np.unravel_index(np.argmax(affinity), affinity.shape)
            if affinity[keep, drop] < cos_thr:
                break
            # Fold cluster `drop` into `keep`, then close the gap in the numbering.
            labels[labels == drop] = keep
            labels[labels > drop] -= 1
        return labels
funasr/models/campplus/model.py
@@ -109,13 +109,9 @@
        audio_sample_list = load_audio_text_image_video(data_in, fs=16000, audio_fs=kwargs.get("fs", 16000), data_type="sound")
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_feature(audio_sample_list)
        speech, speech_lengths, speech_times = extract_feature(audio_sample_list)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = np.array(speech_lengths).sum().item() / 16000.0
        # import pdb; pdb.set_trace()
        results = []
        embeddings = self.forward(speech)
        for embedding in embeddings:
            results.append({"spk_embedding":embedding})
        meta_data["batch_data_time"] = np.array(speech_times).sum().item() / 16000.0
        results = [{"spk_embedding": self.forward(speech)}]
        return results, meta_data
funasr/models/campplus/utils.py
@@ -2,23 +2,19 @@
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import io
from typing import Union
import librosa as sf
import numpy as np
import torch
import torch.nn.functional as F
import torchaudio.compliance.kaldi as Kaldi
from torch import nn
import contextlib
import os
import torch
import requests
import tempfile
from abc import ABCMeta, abstractmethod
import contextlib
import numpy as np
import librosa as sf
from typing import Union
from pathlib import Path
from typing import Generator, Union
import requests
from abc import ABCMeta, abstractmethod
import torchaudio.compliance.kaldi as Kaldi
from funasr.models.transformer.utils.nets_utils import pad_list
def check_audio_list(audio: list):
@@ -105,15 +101,19 @@
def extract_feature(audio):
    features = []
    feature_times = []
    feature_lengths = []
    for au in audio:
        feature = Kaldi.fbank(
            au.unsqueeze(0), num_mel_bins=80)
        feature = feature - feature.mean(dim=0, keepdim=True)
        features.append(feature.unsqueeze(0))
        feature_lengths.append(au.shape[0])
    features = torch.cat(features)
    return features, feature_lengths
        features.append(feature)
        feature_times.append(au.shape[0])
        feature_lengths.append(feature.shape[0])
    # padding for batch inference
    features_padded = pad_list(features, pad_value=0)
    # features = torch.cat(features)
    return features_padded, feature_lengths, feature_times
def postprocess(segments: list, vad_segments: list,
@@ -195,8 +195,8 @@
def distribute_spk(sentence_list, sd_time_list):
    sd_sentence_list = []
    for d in sentence_list:
        sentence_start = d['ts_list'][0][0]
        sentence_end = d['ts_list'][-1][1]
        sentence_start = d['start']
        sentence_end = d['end']
        sentence_spk = 0
        max_overlap = 0
        for sd_time in sd_time_list:
@@ -211,8 +211,6 @@
        d['spk'] = sentence_spk
        sd_sentence_list.append(d)
    return sd_sentence_list
class Storage(metaclass=ABCMeta):
funasr/models/ct_transformer/model.py
@@ -239,6 +239,7 @@
        cache_pop_trigger_limit = 200
        results = []
        meta_data = {}
        punc_array = None
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -320,8 +321,13 @@
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                    new_mini_sentence_out = new_mini_sentence + "."
                    new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
            # keep a punctuations array for punc segment
            if punc_array is None:
                punc_array = punctuations
            else:
                punc_array = torch.cat([punc_array, punctuations], dim=0)
        result_i = {"key": key[0], "text": new_mini_sentence_out}
        result_i = {"key": key[0], "text": new_mini_sentence_out, "punc_array": punc_array}
        results.append(result_i)
    
        return results, meta_data
funasr/utils/timestamp_tools.py
@@ -98,14 +98,14 @@
    return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
    punc_list = [',', '。', '?', '、']
    res = []
    if text_postprocessed is None:
        return res
    if time_stamp_postprocessed is None:
    if timestamp_postprocessed is None:
        return res
    if len(time_stamp_postprocessed) == 0:
    if len(timestamp_postprocessed) == 0:
        return res
    if len(text_postprocessed) == 0:
        return res
@@ -113,23 +113,22 @@
    if punc_id_list is None or len(punc_id_list) == 0:
        res.append({
            'text': text_postprocessed.split(),
            "start": time_stamp_postprocessed[0][0],
            "end": time_stamp_postprocessed[-1][1],
            'text_seg': text_postprocessed.split(),
            "ts_list": time_stamp_postprocessed,
            "start": timestamp_postprocessed[0][0],
            "end": timestamp_postprocessed[-1][1],
            "timestamp": timestamp_postprocessed,
        })
        return res
    if len(punc_id_list) != len(time_stamp_postprocessed):
        print("  warning length mistach!!!!!!")
    if len(punc_id_list) != len(timestamp_postprocessed):
        logging.warning("length mismatch between punc and timestamp")
    sentence_text = ""
    sentence_text_seg = ""
    ts_list = []
    sentence_start = time_stamp_postprocessed[0][0]
    sentence_end = time_stamp_postprocessed[0][1]
    sentence_start = timestamp_postprocessed[0][0]
    sentence_end = timestamp_postprocessed[0][1]
    texts = text_postprocessed.split()
    punc_stamp_text_list = list(zip_longest(punc_id_list, time_stamp_postprocessed, texts, fillvalue=None))
    punc_stamp_text_list = list(zip_longest(punc_id_list, timestamp_postprocessed, texts, fillvalue=None))
    for punc_stamp_text in punc_stamp_text_list:
        punc_id, time_stamp, text = punc_stamp_text
        punc_id, timestamp, text = punc_stamp_text
        # sentence_text += text if text is not None else ''
        if text is not None:
            if 'a' <= text[0] <= 'z' or 'A' <= text[0] <= 'Z':
@@ -139,10 +138,10 @@
            else:
                sentence_text += text
            sentence_text_seg += text + ' '
        ts_list.append(time_stamp)
        ts_list.append(timestamp)
        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
        sentence_end = timestamp[1] if timestamp is not None else sentence_end
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
@@ -150,8 +149,7 @@
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "text_seg": sentence_text_seg,
                "ts_list": ts_list
                "timestamp": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
@@ -160,181 +158,4 @@
    return res
# class AverageShiftCalculator():
#     def __init__(self):
#         logging.warning("Calculating average shift.")
#     def __call__(self, file1, file2):
#         uttid_list1, ts_dict1 = self.read_timestamps(file1)
#         uttid_list2, ts_dict2 = self.read_timestamps(file2)
#         uttid_intersection = self._intersection(uttid_list1, uttid_list2)
#         res = self.as_cal(uttid_intersection, ts_dict1, ts_dict2)
#         logging.warning("Average shift of {} and {}: {}.".format(file1, file2, str(res)[:8]))
#         logging.warning("Following timestamp pair differs most: {}, detail:{}".format(self.max_shift, self.max_shift_uttid))
#
#     def _intersection(self, list1, list2):
#         set1 = set(list1)
#         set2 = set(list2)
#         if set1 == set2:
#             logging.warning("Uttid same checked.")
#             return set1
#         itsc = list(set1 & set2)
#         logging.warning("Uttid differs: file1 {}, file2 {}, lines same {}.".format(len(list1), len(list2), len(itsc)))
#         return itsc
#
#     def read_timestamps(self, file):
#         # read timestamps file in standard format
#         uttid_list = []
#         ts_dict = {}
#         with codecs.open(file, 'r') as fin:
#             for line in fin.readlines():
#                 text = ''
#                 ts_list = []
#                 line = line.rstrip()
#                 uttid = line.split()[0]
#                 uttid_list.append(uttid)
#                 body = " ".join(line.split()[1:])
#                 for pd in body.split(';'):
#                     if not len(pd): continue
#                     # pdb.set_trace()
#                     char, start, end = pd.lstrip(" ").split(' ')
#                     text += char + ','
#                     ts_list.append((float(start), float(end)))
#                 # ts_lists.append(ts_list)
#                 ts_dict[uttid] = (text[:-1], ts_list)
#         logging.warning("File {} read done.".format(file))
#         return uttid_list, ts_dict
#
#     def _shift(self, filtered_timestamp_list1, filtered_timestamp_list2):
#         shift_time = 0
#         for fts1, fts2 in zip(filtered_timestamp_list1, filtered_timestamp_list2):
#             shift_time += abs(fts1[0] - fts2[0]) + abs(fts1[1] - fts2[1])
#         num_tokens = len(filtered_timestamp_list1)
#         return shift_time, num_tokens
#
#     # def as_cal(self, uttid_list, ts_dict1, ts_dict2):
#     #     # calculate average shift between timestamp1 and timestamp2
#     #     # when characters differ, use edit distance alignment
#     #     # and calculate the error between the same characters
#     #     self._accumlated_shift = 0
#     #     self._accumlated_tokens = 0
#     #     self.max_shift = 0
#     #     self.max_shift_uttid = None
#     #     for uttid in uttid_list:
#     #         (t1, ts1) = ts_dict1[uttid]
#     #         (t2, ts2) = ts_dict2[uttid]
#     #         _align, _align2, _align3 = [], [], []
#     #         fts1, fts2 = [], []
#     #         _t1, _t2 = [], []
#     #         sm = edit_distance.SequenceMatcher(t1.split(','), t2.split(','))
#     #         s = sm.get_opcodes()
#     #         for j in range(len(s)):
#     #             if s[j][0] == "replace" or s[j][0] == "insert":
#     #                 _align.append(0)
#     #             if s[j][0] == "replace" or s[j][0] == "delete":
#     #                 _align3.append(0)
#     #             elif s[j][0] == "equal":
#     #                 _align.append(1)
#     #                 _align3.append(1)
#     #             else:
#     #                 continue
#     #         # use s to index t2
#     #         for a, ts , t in zip(_align, ts2, t2.split(',')):
#     #             if a:
#     #                 fts2.append(ts)
#     #                 _t2.append(t)
#     #         sm2 = edit_distance.SequenceMatcher(t2.split(','), t1.split(','))
#     #         s = sm2.get_opcodes()
#     #         for j in range(len(s)):
#     #             if s[j][0] == "replace" or s[j][0] == "insert":
#     #                 _align2.append(0)
#     #             elif s[j][0] == "equal":
#     #                 _align2.append(1)
#     #             else:
#     #                 continue
#     #         # use s2 tp index t1
#     #         for a, ts, t in zip(_align3, ts1, t1.split(',')):
#     #             if a:
#     #                 fts1.append(ts)
#     #                 _t1.append(t)
#     #         if len(fts1) == len(fts2):
#     #             shift_time, num_tokens = self._shift(fts1, fts2)
#     #             self._accumlated_shift += shift_time
#     #             self._accumlated_tokens += num_tokens
#     #             if shift_time/num_tokens > self.max_shift:
#     #                 self.max_shift = shift_time/num_tokens
#     #                 self.max_shift_uttid = uttid
#     #         else:
#     #             logging.warning("length mismatch")
#     #     return self._accumlated_shift / self._accumlated_tokens
def convert_external_alphas(alphas_file, text_file, output_file):
    """Convert externally dumped alphas plus transcripts into timestamp strings.

    Both input files are read line by line in parallel; each line starts with
    an utterance id that must match between the two files. One line
    ``"<uttid> <timestamps>"`` is written to ``output_file`` per utterance.

    Args:
        alphas_file: Text file with ``uttid alpha_0 alpha_1 ...`` per line.
        text_file: Text file with ``uttid token_0 token_1 ...`` per line.
        output_file: Destination path; overwritten.

    Raises:
        AssertionError: If the utterance ids of a line pair differ.
    """
    from funasr.models.paraformer.cif_predictor import cif_wo_hidden
    with open(alphas_file, 'r') as f1, open(text_file, 'r') as f2, open(output_file, 'w') as f3:
        # Stream both files instead of materialising them with readlines().
        for line1, line2 in zip(f1, f2):
            alpha_fields = line1.rstrip().split()
            text_fields = line2.rstrip().split()
            assert alpha_fields[0] == text_fields[0]
            uttid = alpha_fields[0]
            alphas = [float(i) for i in alpha_fields[1:]]
            new_alphas = np.array(remove_chunk_padding(alphas))
            new_alphas[-1] += 1e-4
            # `text` is already a token list (result of str.split), so no
            # further splitting is needed; the former `" " in text` branch
            # could never be taken.
            text = text_fields[1:]
            if len(text) + 1 != int(new_alphas.sum()):
                # force resize so the alphas sum to (number of tokens + 1)
                new_alphas *= (len(text) + 1) / int(new_alphas.sum())
            peaks = cif_wo_hidden(torch.Tensor(new_alphas).unsqueeze(0), 1.0-1e-4)
            res_str, _ = ts_prediction_lfr6_standard(new_alphas, peaks[0], text,
                                                     force_time_shift=-7.0,
                                                     sil_in_str=False)
            f3.write("{} {}\n".format(uttid, res_str))
def remove_chunk_padding(alphas):
    """Strip the padding frames from alphas produced by the chunk paraformer on GPU.

    The first ``START_ZERO`` values are dropped, then blocks of ``REAL_FRAMES``
    values are kept, each followed by ``MID_ZERO`` skipped values. Collection
    stops once fewer than ``REAL_FRAMES`` values remain, so a shorter tail
    after the first block is discarded.
    """
    START_ZERO = 45
    MID_ZERO = 75
    REAL_FRAMES = 360  # for chunk based encoder 10-120-10 and fsmn padding 5
    stride = REAL_FRAMES + MID_ZERO
    body = alphas[START_ZERO:]  # leading padding removed
    kept = []
    offset = 0
    while True:
        kept = kept + body[offset:offset + REAL_FRAMES]
        offset += stride
        if len(body) - offset < REAL_FRAMES:
            break
    return kept
SUPPORTED_MODES = ['cal_aas', 'read_ext_alphas']


def main(args):
    """Dispatch on ``args.mode``.

    Only ``'read_ext_alphas'`` is handled; every other mode (including
    ``'cal_aas'``, whose implementation is commented out) logs an error.
    """
    if args.mode != 'read_ext_alphas':
        logging.error("Mode {} not in SUPPORTED_MODES: {}.".format(args.mode, SUPPORTED_MODES))
        return
    convert_external_alphas(args.input, args.input2, args.output)
if __name__ == '__main__':
    # Command-line entry point: build the parser, parse argv and dispatch to main().
    parser = argparse.ArgumentParser(description='timestamp tools')
    parser.add_argument('--mode',
                        default=None,
                        type=str,
                        choices=SUPPORTED_MODES,
                        help='timestamp related toolbox')
    # The three path options share the same shape, so register them in one loop
    # (order preserved: --input, --output, --input2).
    for flag, help_text in (('--input', 'input file path'),
                            ('--output', 'output file name'),
                            ('--input2', 'input2 file path')):
        parser.add_argument(flag, default=None, type=str, help=help_text)
    # NOTE(review): --kaldi-ts-type is parsed but not read by the visible main().
    parser.add_argument('--kaldi-ts-type',
                        default='v2',
                        type=str,
                        choices=['v0', 'v1', 'v2'],
                        help='kaldi timestamp to write')
    main(parser.parse_args())