shixian.shi
2023-03-13 f5aa97f7bff53169a11a1e20ef1ff965438d1bc1
update parameter names
5个文件已修改
27 ■■■■ 已修改文件
funasr/bin/asr_inference_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer_vad_punc.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr_paraformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_tp.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/timestamp_tools.py 11 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_paraformer.py
@@ -245,7 +245,7 @@
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
        results = []
@@ -292,7 +292,7 @@
                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], 
                                                            us_cif_peak[i],
                                                            us_peaks[i],
                                                            copy.copy(token), 
                                                            vad_offset=begin_time)
                    results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -256,7 +256,7 @@
            decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        if isinstance(self.asr_model, BiCifParaformer):
            _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
            _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
                                                                                   pre_token_length)  # test no bias cif2
        results = []
@@ -303,7 +303,7 @@
                if isinstance(self.asr_model, BiCifParaformer):
                    _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], 
                                                            us_cif_peak[i],
                                                            us_peaks[i],
                                                            copy.copy(token), 
                                                            vad_offset=begin_time)
                    results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
funasr/models/e2e_asr_paraformer.py
@@ -926,10 +926,10 @@
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                               encoder_out_mask,
                                                                                               token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    def forward(
            self,
funasr/models/e2e_tp.py
@@ -150,10 +150,10 @@
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                               encoder_out_mask,
                                                                                               token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    def collect_feats(
            self,
funasr/utils/timestamp_tools.py
@@ -6,10 +6,9 @@
def ts_prediction_lfr6_standard(us_alphas, 
                       us_cif_peak,
                       us_peaks,
                       char_list, 
                       vad_offset=0.0, 
                       end_time=None,
                       force_time_shift=-1.5
                       ):
    if not len(char_list):
@@ -18,17 +17,17 @@
    MAX_TOKEN_DURATION = 12
    TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
    if len(us_alphas.shape) == 2:
        alphas, cif_peak = us_alphas[0], us_cif_peak[0]  # support inference batch_size=1 only
        _, peaks = us_alphas[0], us_peaks[0]  # support inference batch_size=1 only
    else:
        alphas, cif_peak = us_alphas, us_cif_peak
    num_frames = cif_peak.shape[0]
        _, peaks = us_alphas, us_peaks
    num_frames = peaks.shape[0]
    if char_list[-1] == '</s>':
        char_list = char_list[:-1]
    timestamp_list = []
    new_char_list = []
    # for bicif model trained with large data, cif2 actually fires when a character starts
    # so treat the frames between two peaks as the duration of the former token
    fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift  # total offset
    num_peak = len(fire_place)
    assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
    # begin silence