From 87d5f69b819df11969263cf99f7cc2f80bea30da Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Tue, 07 May 2024 13:43:53 +0800
Subject: [PATCH] decoding key

---
 funasr/models/sense_voice/model.py |   11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index d5e4130..bcaaca3 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -378,14 +378,19 @@
         stats = {}
 
         # 1. Forward decoder
+        # ys_pad: [sos, task, lid, text, eos]
         decoder_out = self.model.decoder(
             x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
         )
 
         # 2. Compute attention loss
-        mask = torch.ones_like(ys_pad) * (-1)
-        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
-        ys_pad_mask[ys_pad_mask == 0] = -1
+        mask = torch.ones_like(ys_pad) * (-1)  # [sos, task, lid, text, eos]: [-1, -1, -1, -1, -1]
+        ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
+            torch.int64
+        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
+        ys_pad_mask[ys_pad_mask == 0] = -1  # [-1, -1, lid, text, eos]
+        # decoder_out: [sos, task, lid, text]
+        # ys_pad_mask: [-1, lid, text, eos]
         loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
 
         with torch.no_grad():

--
Gitblit v1.9.1