shixian.shi
2023-03-13 f5aa97f7bff53169a11a1e20ef1ff965438d1bc1
funasr/utils/timestamp_tools.py
@@ -6,10 +6,9 @@
def ts_prediction_lfr6_standard(us_alphas, 
                       us_cif_peak,
                       us_peaks,
                       char_list, 
                       vad_offset=0.0, 
                       end_time=None,
                       force_time_shift=-1.5
                       ):
    if not len(char_list):
@@ -18,17 +17,17 @@
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
    if len(us_alphas.shape) == 2:
        alphas, cif_peak = us_alphas[0], us_cif_peak[0]  # support inference batch_size=1 only
        _, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
    else:
        alphas, cif_peak = us_alphas, us_cif_peak
    num_frames = cif_peak.shape[0]
        _, peaks = us_alphas, us_peaks
    num_frames = peaks.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence