jmwang66
2022-12-26 682204f0bb1335eb9ba3a2f0eb5605bdf42e8505
funasr/models/e2e_asr_paraformer.py
@@ -330,9 +330,10 @@
   def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
      decoder_out, _ = self.decoder(
      decoder_outs = self.decoder(
         encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens
      )
      decoder_out = decoder_outs[0]
      decoder_out = torch.log_softmax(decoder_out, dim=-1)
      return decoder_out, ys_pad_lens
@@ -492,7 +493,7 @@
   def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
      tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
      ys_pad *= tgt_mask[:, :, 0]
      ys_pad = ys_pad * tgt_mask[:, :, 0]
      ys_pad_embed = self.decoder.embed(ys_pad)
      with torch.no_grad():
         decoder_outs = self.decoder(
@@ -553,7 +554,6 @@
      postencoder: Optional[AbsPostEncoder],
      decoder: AbsDecoder,
      ctc: CTC,
      joint_network: Optional[torch.nn.Module],
      ctc_weight: float = 0.5,
      interctc_weight: float = 0.0,
      ignore_id: int = -1,
@@ -590,7 +590,6 @@
      postencoder=postencoder,
      decoder=decoder,
      ctc=ctc,
      joint_network=joint_network,
      ctc_weight=ctc_weight,
      interctc_weight=interctc_weight,
      ignore_id=ignore_id,