| | |
| | | alphas, peaks = us_alphas, us_peaks |
| | | if char_list[-1] == '</s>': |
| | | char_list = char_list[:-1] |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | if len(fire_place) != len(char_list) + 1: |
| | | alphas /= (alphas.sum() / (len(char_list) + 1)) |
| | | alphas = alphas.unsqueeze(0) |
| | | peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0] |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | num_frames = peaks.shape[0] |
| | | timestamp_list = [] |
| | | new_char_list = [] |
| | | # for bicif model trained with large data, cif2 actually fires when a character starts |
| | | # so treat the frames between two peaks as the duration of the former token |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | # fire_place = torch.where(peaks>=1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 |
| | | # begin silence |
| | | if fire_place[0] > START_END_THRESHOLD: |
| | |
| | | return res_txt, res |
| | | |
| | | |
| | | def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed): |
| | | def timestamp_sentence(punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False): |
| | | punc_list = [',', '。', '?', '、'] |
| | | res = [] |
| | | if text_postprocessed is None: |
| | |
| | | |
| | | punc_id = int(punc_id) if punc_id is not None else 1 |
| | | sentence_end = timestamp[1] if timestamp is not None else sentence_end |
| | | |
| | | sentence_text_seg = sentence_text_seg[:-1] if sentence_text_seg[-1] == ' ' else sentence_text_seg |
| | | if punc_id > 1: |
| | | sentence_text += punc_list[punc_id - 2] |
| | | res.append({ |
| | | 'text': sentence_text, |
| | | "start": sentence_start, |
| | | "end": sentence_end, |
| | | "timestamp": ts_list |
| | | }) |
| | | if return_raw_text: |
| | | res.append({ |
| | | 'text': sentence_text, |
| | | "start": sentence_start, |
| | | "end": sentence_end, |
| | | "timestamp": ts_list, |
| | | 'raw_text': sentence_text_seg, |
| | | }) |
| | | else: |
| | | res.append({ |
| | | 'text': sentence_text, |
| | | "start": sentence_start, |
| | | "end": sentence_end, |
| | | "timestamp": ts_list, |
| | | }) |
| | | sentence_text = '' |
| | | sentence_text_seg = '' |
| | | ts_list = [] |