From cfb2fda87c29db780e595b75f2de1c7710ebadd2 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 30 Jan 2023 17:50:36 +0800
Subject: [PATCH] fix bug: use a separate ys_pad_masked instead of overwriting ys_pad in sampler of paraformer

---
 funasr/models/e2e_asr_paraformer.py |    6 +++---
 1 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 65c70df..7596896 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -499,11 +499,11 @@
     def sampler(self, encoder_out, encoder_out_lens, ys_pad, ys_pad_lens, pre_acoustic_embeds):
 
         tgt_mask = (~make_pad_mask(ys_pad_lens, maxlen=ys_pad_lens.max())[:, :, None]).to(ys_pad.device)
-        ys_pad = ys_pad * tgt_mask[:, :, 0]
+        ys_pad_masked = ys_pad * tgt_mask[:, :, 0]
         if self.share_embedding:
-            ys_pad_embed = self.decoder.output_layer.weight[ys_pad]
+            ys_pad_embed = self.decoder.output_layer.weight[ys_pad_masked]
         else:
-            ys_pad_embed = self.decoder.embed(ys_pad)
+            ys_pad_embed = self.decoder.embed(ys_pad_masked)
         with torch.no_grad():
             decoder_outs = self.decoder(
                 encoder_out, encoder_out_lens, pre_acoustic_embeds, ys_pad_lens

--
Gitblit v1.9.1