lzr265946
2022-12-29 bc3ad9c0f631ccb245b31b6cf6e4d757de6e6712
Merge branch 'dev' of https://github.com/alibaba/FunASR into dev
1个文件已修改
6 ■■■■ 已修改文件
funasr/bin/asr_inference_uniasr.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_uniasr.py
@@ -215,14 +215,14 @@
        lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
        # lengths: (1,)
        lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
        speech_raw = speech.clone().to(self.device)
        if self.frontend is not None:
            feats, feats_len = self.frontend.forward(speech, lengths)
            feats = to_device(feats, device=self.device)
            feats_len = feats_len.int()
        else:
-            feats = speech_raw
+            feats = speech
            feats_len = lengths
        feats_raw = feats.clone().to(self.device)
        batch = {"speech": feats, "speech_lengths": feats_len}
        # a. To device
@@ -235,7 +235,7 @@
        if self.decoding_mode == "model1":
            predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len)
        else:
-            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind)
+            enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind)
            predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len)
        scama_mask = predictor_outs[4]