游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/utils/timestamp_tools.py
@@ -1,14 +1,28 @@
from itertools import zip_longest
import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
import edit_distance
from itertools import zip_longest
def cif_wo_hidden(alphas, threshold):
    """Run a hidden-free CIF (continuous integrate-and-fire) pass over ``alphas``.

    Weights are accumulated left to right along the time axis. Whenever the
    running sum reaches ``threshold`` it "fires": ``threshold`` is subtracted
    from the accumulator and the remainder carries into the next frame.

    Args:
        alphas: 2-D tensor of shape ``(batch_size, len_time)`` holding the
            per-frame weights.
        threshold: scalar firing threshold.

    Returns:
        Tensor of shape ``(batch_size, len_time)`` on ``alphas.device``. Each
        entry is the accumulator value at that frame *before* the threshold is
        subtracted, so frames that fired hold a value ``>= threshold``.

        The dtype is the promotion of ``alphas.dtype`` with the default float
        dtype. Half-precision input still accumulates in float32, while
        float64 input is no longer truncated.
    """
    batch_size, len_time = alphas.size()
    # Accumulate in at least the default float dtype. This matches the previous
    # behaviour for half/float32 input, but no longer silently downcasts
    # float64 alphas to float32.
    acc_dtype = torch.promote_types(alphas.dtype, torch.get_default_dtype())
    if len_time == 0:
        # torch.stack rejects an empty list; return an empty (batch, 0) result.
        return torch.zeros([batch_size, 0], dtype=acc_dtype, device=alphas.device)
    # Running integration state, one value per batch item.
    integrate = torch.zeros([batch_size], dtype=acc_dtype, device=alphas.device)
    # Accumulator value recorded at every time step.
    list_fires = []
    for t in range(len_time):
        # Out-of-place add: each recorded tensor is a fresh object, so later
        # accumulator updates can never alias an entry already in list_fires.
        integrate = integrate + alphas[:, t]
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        # On a fire, carry only the remainder past the threshold forward.
        integrate = torch.where(fire_place, integrate - threshold, integrate)
    return torch.stack(list_fires, 1)
def ts_prediction_lfr6_standard(us_alphas, 
@@ -24,19 +38,24 @@
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
    if len(us_alphas.shape) == 2:
        _, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
        alphas, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
    else:
        _, peaks = us_alphas, us_peaks
    num_frames = peaks.shape[0]
        alphas, peaks = us_alphas, us_peaks
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    if len(fire_place) != len(char_list) + 1:
        alphas /= (alphas.sum() / (len(char_list) + 1))
        alphas = alphas.unsqueeze(0)
        peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
        fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_frames = peaks.shape[0]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
@@ -80,6 +99,7 @@
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
    punc_list = [',', '。', '?', '、']
    res = []
    if text_postprocessed is None:
        return res
@@ -124,34 +144,8 @@
        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
        if punc_id == 2:
            sentence_text += ','
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "text_seg": sentence_text_seg,
                "ts_list": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
            ts_list = []
            sentence_start = sentence_end
        elif punc_id == 3:
            sentence_text += '.'
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "text_seg": sentence_text_seg,
                "ts_list": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
            ts_list = []
            sentence_start = sentence_end
        elif punc_id == 4:
            sentence_text += '?'
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
            res.append({
                'text': sentence_text,
                "start": sentence_start,