jmwang66
2022-12-09 0b8348376a20a6888d116982e346ada5fa5d15ab
funasr/models/e2e_asr_paraformer.py
@@ -493,7 +493,7 @@
   def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
      tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
      ys_pad *= tgt_mask[:, :, 0]
      ys_pad = ys_pad * tgt_mask[:, :, 0]
      ys_pad_embed = self.decoder.embed(ys_pad)
      with torch.no_grad():
         decoder_outs = self.decoder(