shixian.shi
2023-07-20 1054daf44a1250ae6be5416e79a1f113d306e635
remove assert in ts_prediction_lfr6_standard
1个文件已修改
17 ■■■■■ 已修改文件
funasr/utils/timestamp_tools.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/timestamp_tools.py
@@ -1,14 +1,10 @@
from itertools import zip_longest
import torch
import copy
import codecs
import logging
import edit_distance
import argparse
import pdb
import numpy as np
from typing import Any, List, Tuple, Union
import edit_distance
from itertools import zip_longest
def ts_prediction_lfr6_standard(us_alphas, 
@@ -36,7 +32,14 @@
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    if num_peak != len(char_list) + 1:
        logging.warning("length mismatch, result might be incorrect.")
        logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1))
    if num_peak > len(char_list) + 1:
        fire_place = fire_place[:len(char_list) - 1]
    elif num_peak < len(char_list) + 1:
        char_list = char_list[:num_peak + 1]
    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
        # char_list.insert(0, '<sil>')