| | |
| | | |
| | | |
| | | def ts_prediction_lfr6_standard(us_alphas, |
| | | us_cif_peak, |
| | | us_peaks, |
| | | char_list, |
| | | vad_offset=0.0, |
| | | end_time=None, |
| | | force_time_shift=-1.5 |
| | | ): |
| | | if not len(char_list): |
| | |
| | | MAX_TOKEN_DURATION = 12 |
| | | TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled |
| | | if len(us_alphas.shape) == 2: |
| | | alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only |
| | | _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only |
| | | else: |
| | | alphas, cif_peak = us_alphas, us_cif_peak |
| | | num_frames = cif_peak.shape[0] |
| | | _, peaks = us_alphas, us_peaks |
| | | num_frames = peaks.shape[0] |
| | | if char_list[-1] == '</s>': |
| | | char_list = char_list[:-1] |
| | | timestamp_list = [] |
| | | new_char_list = [] |
| | | # for bicif model trained with large data, cif2 actually fires when a character starts |
| | | # so treat the frames between two peaks as the duration of the former token |
| | | fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | num_peak = len(fire_place) |
| | | assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 |
| | | # begin silence |