querryton
2024-04-20 01df8f330ccc754223d5e2d688dc0a55d27f2dcc
funasr/utils/timestamp_tools.py
@@ -43,18 +43,18 @@
        alphas, peaks = us_alphas, us_peaks
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    if len(fire_place) != len(char_list) + 1:
        alphas /= (alphas.sum() / (len(char_list) + 1))
        alphas = alphas.unsqueeze(0)
        peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0]
        fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
        fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_frames = peaks.shape[0]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    # fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence
    if fire_place[0] > START_END_THRESHOLD:
@@ -98,7 +98,7 @@
    return res_txt, res
def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed):
def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False):
    punc_list = [',', '。', '?', '、']
    res = []
    if text_postprocessed is None:
@@ -142,15 +142,24 @@
        punc_id = int(punc_id) if punc_id is not None else 1
        sentence_end = timestamp[1] if timestamp is not None else sentence_end
        sentence_text_seg = sentence_text_seg[:-1] if sentence_text_seg[-1] == ' ' else sentence_text_seg
        if punc_id > 1:
            sentence_text += punc_list[punc_id - 2]
            res.append({
                'text': sentence_text,
                "start": sentence_start,
                "end": sentence_end,
                "timestamp": ts_list
            })
            if return_raw_text:
                res.append({
                    'text': sentence_text,
                    "start": sentence_start,
                    "end": sentence_end,
                    "timestamp": ts_list,
                    'raw_text': sentence_text_seg,
                })
            else:
                res.append({
                    'text': sentence_text,
                    "start": sentence_start,
                    "end": sentence_end,
                    "timestamp": ts_list,
                })
            sentence_text = ''
            sentence_text_seg = ''
            ts_list = []