From a016617c7ec98ab9c7475ff7d3b6150b98d5beeb Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 28 二月 2023 18:36:52 +0800
Subject: [PATCH] Merge pull request #165 from alibaba-damo-academy/dev_cmz

---
 funasr/punctuation/espnet_model.py |   55 ++++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 36 insertions(+), 19 deletions(-)

diff --git a/funasr/punctuation/espnet_model.py b/funasr/punctuation/espnet_model.py
index 65edaad..c513779 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/punctuation/espnet_model.py
@@ -14,15 +14,18 @@
 
 class ESPnetPunctuationModel(AbsESPnetModel):
 
-    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0):
+    def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
         assert check_argument_types()
         super().__init__()
         self.punc_model = punc_model
+        self.punc_weight = torch.Tensor(punc_weight)
         self.sos = 1
         self.eos = 2
 
         # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
         self.ignore_id = ignore_id
+        if self.punc_model.with_vad():
+            print("This is a vad puncuation model.")
 
     def nll(
         self,
@@ -31,6 +34,8 @@
         text_lengths: torch.Tensor,
         punc_lengths: torch.Tensor,
         max_length: Optional[int] = None,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Compute negative log likelihood(nll)
 
@@ -49,19 +54,16 @@
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-        # 1. Create a sentence pair like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'
-        # text: (Batch, Length) -> x, y: (Batch, Length + 1)
-        #x = F.pad(text, [1, 0], "constant", self.eos)
-        #t = F.pad(text, [0, 1], "constant", self.ignore_id)
-        #for i, l in enumerate(text_lengths):
-        #    t[i, l] = self.sos
-        #x_lengths = text_lengths + 1
+       
+        if self.punc_model.with_vad():
+            # Should be VadRealtimeTransformer
+            assert vad_indexes is not None
+            y, _ = self.punc_model(text, text_lengths, vad_indexes)
+        else:
+            # Should be TargetDelayTransformer,
+            y, _ = self.punc_model(text, text_lengths)
 
-        # 2. Forward Language model
-        # x: (Batch, Length) -> y: (Batch, Length, NVocab)
-        y, _ = self.punc_model(text, text_lengths)
-
-        # 3. Calc negative log likelihood
+        # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
             _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
@@ -72,7 +74,8 @@
             nll = torch.Tensor([f1_score]).repeat(text_lengths.sum())
             return nll, text_lengths
         else:
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), reduction="none", ignore_index=self.ignore_id)
+            self.punc_weight = self.punc_weight.to(punc.device)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
         # nll: (BxL,) -> (BxL,)
         if max_length is None:
             nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -130,9 +133,16 @@
         assert x_lengths.size(0) == total_num
         return nll, x_lengths
 
-    def forward(self, text: torch.Tensor, punc: torch.Tensor, text_lengths: torch.Tensor,
-                punc_lengths: torch.Tensor) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths)
+    def forward(
+        self,
+        text: torch.Tensor,
+        punc: torch.Tensor,
+        text_lengths: torch.Tensor,
+        punc_lengths: torch.Tensor,
+        vad_indexes: Optional[torch.Tensor] = None,
+        vad_indexes_lengths: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        nll, y_lengths = self.nll(text, punc, text_lengths, punc_lengths, vad_indexes=vad_indexes)
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
@@ -145,5 +155,12 @@
                       text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {}
 
-    def inference(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
-        return self.punc_model(text, text_lengths)
+    def inference(self,
+                  text: torch.Tensor,
+                  text_lengths: torch.Tensor,
+                  vad_indexes: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, None]:
+        if self.punc_model.with_vad():
+            assert vad_indexes is not None
+            return self.punc_model(text, text_lengths, vad_indexes)
+        else:
+            return self.punc_model(text, text_lengths)

--
Gitblit v1.9.1