| | |
| | | from itertools import zip_longest |
| | | |
| | | |
def cif_wo_hidden(alphas, threshold):
    """Run continuous integrate-and-fire over ``alphas`` and return the running sums.

    The per-frame weights are accumulated along the time axis. Each frame's
    accumulated value is recorded, and whenever that value reaches
    ``threshold`` one unit (1.0) is subtracted from the accumulator before
    moving on to the next frame.

    Args:
        alphas: 2-D tensor unpacked as ``(batch_size, len_time)`` holding the
            per-frame weights.
        threshold: Accumulated value at or above which a frame "fires".

    Returns:
        Tensor of shape ``(batch_size, len_time)`` whose entry ``[b, t]`` is
        the accumulated weight at frame ``t`` *before* the post-fire reset,
        so frames with a value ``>= threshold`` are the firing positions.
        The dtype is at least float32; float64 inputs stay float64. An input
        with ``len_time == 0`` yields an empty ``(batch_size, 0)`` tensor.
    """
    batch_size, len_time = alphas.size()
    # Accumulate in at least float32 but never below the input precision:
    # a hard-coded float32 buffer silently downcast float64 inputs.
    acc_dtype = torch.promote_types(alphas.dtype, torch.float32)
    if len_time == 0:
        # torch.stack() rejects an empty list, so build the empty result directly.
        return torch.zeros([batch_size, 0], dtype=acc_dtype, device=alphas.device)

    integrate = torch.zeros([batch_size], dtype=acc_dtype, device=alphas.device)
    list_fires = []
    for t in range(len_time):
        # Out-of-place add: each stored tensor is distinct from the next accumulator.
        integrate = integrate + alphas[:, t]
        # Record the value before the reset so callers can detect fires
        # by comparing the output against the threshold.
        list_fires.append(integrate)
        fire_place = integrate >= threshold
        # A fire consumes exactly one unit of weight (not `threshold`);
        # the remainder carries over into the next frame.
        integrate = torch.where(fire_place, integrate - 1.0, integrate)
    return torch.stack(list_fires, 1)
| | | |
| | | |
| | | def ts_prediction_lfr6_standard(us_alphas, |
| | | us_peaks, |
| | | char_list, |
| | |
| | | MAX_TOKEN_DURATION = 12 |
| | | TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled |
| | | if len(us_alphas.shape) == 2: |
| | | _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only |
| | | alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only |
| | | else: |
| | | _, peaks = us_alphas, us_peaks |
| | | num_frames = peaks.shape[0] |
| | | alphas, peaks = us_alphas, us_peaks |
| | | if char_list[-1] == '</s>': |
| | | char_list = char_list[:-1] |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | if len(fire_place) != len(char_list) + 1: |
| | | alphas /= (alphas.sum() / (len(char_list) + 1)) |
| | | alphas = alphas.unsqueeze(0) |
| | | peaks = cif_wo_hidden(alphas, threshold=1.0-1e-4)[0] |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | num_frames = peaks.shape[0] |
| | | timestamp_list = [] |
| | | new_char_list = [] |
| | | # for bicif model trained with large data, cif2 actually fires when a character starts |
| | | # so treat the frames between two peaks as the duration of the former token |
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | num_peak = len(fire_place) |
| | | if num_peak != len(char_list) + 1: |
| | | logging.warning("length mismatch, result might be incorrect.") |
| | | logging.warning("num_peaks: {}, num_chars+1: {}, which is supposed to be same.".format(num_peak, len(char_list)+1)) |
| | | if num_peak > len(char_list) + 1: |
| | | fire_place = fire_place[:len(char_list) - 1] |
| | | elif num_peak < len(char_list) + 1: |
| | | char_list = char_list[:num_peak + 1] |
| | | # assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 |
| | | # begin silence |
| | | if fire_place[0] > START_END_THRESHOLD: |