zhifu gao
2023-07-25 1dcdd5f8a618ba7ea7eb8d7c9b5f3d0acf5a3d9d
funasr/utils/timestamp_tools.py
@@ -1,14 +1,10 @@
from itertools import zip_longest
import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
import edit_distance
from itertools import zip_longest
def ts_prediction_lfr6_standard(us_alphas, 
@@ -36,7 +32,14 @@
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    if num_peak != len(char_list) + 1:
        logging.warning("length mismatch, result might be incorrect.")
        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
    if num_peak > len(char_list) + 1:
        fire_place = fire_place[:len(char_list) - 1]
    elif num_peak < len(char_list) + 1:
        char_list = char_list[:num_peak + 1]
    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')
@@ -80,6 +83,7 @@
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
    punc_list = [',', '。', '?', '、']
    res = []
    if text_postprocessed is None:
        return res
@@ -124,34 +128,8 @@
        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = time_stamp[1] if time_stamp is not None else sentence_end
        if punc_id == 2:
            sentence_text += ','
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "text_seg": sentence_text_seg,
                "ts_list": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
            ts_list = []
            sentence_start = sentence_end
        elif punc_id == 3:
            sentence_text += '.'
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "text_seg": sentence_text_seg,
                "ts_list": ts_list
            })
            sentence_text = ''
            sentence_text_seg = ''
            ts_list = []
            sentence_start = sentence_end
        elif punc_id == 4:
            sentence_text += '?'
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
            res.append({
                'text': sentence_text,
                "start": sentence_start,