Merge branch 'dev' of https://github.com/alibaba/FunASR into dev
| | |
| | | lfr_factor = max(1, (speech.size()[-1] // 80) - 1) |
| | | # lengths: (1,) |
| | | lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1)) |
| | | speech_raw = speech.clone().to(self.device) |
| | | if self.frontend is not None: |
| | | feats, feats_len = self.frontend.forward(speech, lengths) |
| | | feats = to_device(feats, device=self.device) |
| | | feats_len = feats_len.int() |
| | | else: |
| | | feats = speech_raw |
| | | feats = speech |
| | | feats_len = lengths |
| | | feats_raw = feats.clone().to(self.device) |
| | | batch = {"speech": feats, "speech_lengths": feats_len} |
| | | |
| | | # a. To device |
| | |
| | | if self.decoding_mode == "model1": |
| | | predictor_outs = self.asr_model.calc_predictor_mask(enc, enc_len) |
| | | else: |
| | | enc, enc_len = self.asr_model.encode2(enc, enc_len, feats, feats_len, ind=self.decoding_ind) |
| | | enc, enc_len = self.asr_model.encode2(enc, enc_len, feats_raw, feats_len, ind=self.decoding_ind) |
| | | predictor_outs = self.asr_model.calc_predictor_mask2(enc, enc_len) |
| | | |
| | | scama_mask = predictor_outs[4] |