From d0cd484fdc21c06b8bc892bb2ab1c2a25fb1da8a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 31 Mar 2023 15:05:37 +0800
Subject: [PATCH] export

---
 funasr/train/abs_model.py |   56 ++++++++++++++++++++++++++++++++++++++++++--------------
 1 file changed, 42 insertions(+), 14 deletions(-)

diff --git a/funasr/punctuation/espnet_model.py b/funasr/train/abs_model.py
similarity index 85%
rename from funasr/punctuation/espnet_model.py
rename to funasr/train/abs_model.py
index 7266b38..8bfba45 100644
--- a/funasr/punctuation/espnet_model.py
+++ b/funasr/train/abs_model.py
@@ -1,3 +1,9 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
 from typing import Dict
 from typing import Optional
 from typing import Tuple
@@ -7,13 +13,34 @@
 from typeguard import check_argument_types
 
 from funasr.modules.nets_utils import make_pad_mask
-from funasr.punctuation.abs_model import AbsPunctuation
 from funasr.torch_utils.device_funcs import force_gatherable
 from funasr.train.abs_espnet_model import AbsESPnetModel
 
+from funasr.modules.scorers.scorer_interface import BatchScorerInterface
 
-class ESPnetPunctuationModel(AbsESPnetModel):
 
+class AbsPunctuation(torch.nn.Module, BatchScorerInterface, ABC):
+    """The abstract class
+
+    To share the loss calculation logic among different models,
+    we use the delegate pattern here:
+    The instance of this class should be passed to "LanguageModel"
+
+    This "model" is one of mediator objects for "Task" class.
+
+    """
+
+    @abstractmethod
+    def forward(self, input: torch.Tensor, hidden: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError
+
+    @abstractmethod
+    def with_vad(self) -> bool:
+        raise NotImplementedError
+
+
+class PunctuationModel(AbsESPnetModel):
+    
     def __init__(self, punc_model: AbsPunctuation, vocab_size: int, ignore_id: int = 0, punc_weight: list = None):
         assert check_argument_types()
         super().__init__()
@@ -21,12 +48,12 @@
         self.punc_weight = torch.Tensor(punc_weight)
         self.sos = 1
         self.eos = 2
-
+        
         # ignore_id may be assumed as 0, shared with CTC-blank symbol for ASR.
         self.ignore_id = ignore_id
-        #if self.punc_model.with_vad():
+        # if self.punc_model.with_vad():
         #    print("This is a vad puncuation model.")
-
+    
     def nll(
         self,
         text: torch.Tensor,
@@ -54,7 +81,7 @@
         else:
             text = text[:, :max_length]
             punc = punc[:, :max_length]
-       
+        
         if self.punc_model.with_vad():
             # Should be VadRealtimeTransformer
             assert vad_indexes is not None
@@ -62,7 +89,7 @@
         else:
             # Should be TargetDelayTransformer,
             y, _ = self.punc_model(text, text_lengths)
-
+        
         # Calc negative log likelihood
         # nll: (BxL,)
         if self.training == False:
@@ -75,7 +102,8 @@
             return nll, text_lengths
         else:
             self.punc_weight = self.punc_weight.to(punc.device)
-            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none", ignore_index=self.ignore_id)
+            nll = F.cross_entropy(y.view(-1, y.shape[-1]), punc.view(-1), self.punc_weight, reduction="none",
+                                  ignore_index=self.ignore_id)
         # nll: (BxL,) -> (BxL,)
         if max_length is None:
             nll.masked_fill_(make_pad_mask(text_lengths).to(nll.device).view(-1), 0.0)
@@ -87,7 +115,7 @@
         # nll: (BxL,) -> (B, L)
         nll = nll.view(batch_size, -1)
         return nll, text_lengths
-
+    
     def batchify_nll(self,
                      text: torch.Tensor,
                      punc: torch.Tensor,
@@ -113,7 +141,7 @@
             nlls = []
             x_lengths = []
             max_length = text_lengths.max()
-
+            
             start_idx = 0
             while True:
                 end_idx = min(start_idx + batch_size, total_num)
@@ -132,7 +160,7 @@
         assert nll.size(0) == total_num
         assert x_lengths.size(0) == total_num
         return nll, x_lengths
-
+    
     def forward(
         self,
         text: torch.Tensor,
@@ -146,15 +174,15 @@
         ntokens = y_lengths.sum()
         loss = nll.sum() / ntokens
         stats = dict(loss=loss.detach())
-
+        
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         loss, stats, weight = force_gatherable((loss, stats, ntokens), loss.device)
         return loss, stats, weight
-
+    
     def collect_feats(self, text: torch.Tensor, punc: torch.Tensor,
                       text_lengths: torch.Tensor) -> Dict[str, torch.Tensor]:
         return {}
-
+    
     def inference(self,
                   text: torch.Tensor,
                   text_lengths: torch.Tensor,

--
Gitblit v1.9.1