From 2cdb2d654f2109ef4e648bae6f169143e267e5db Mon Sep 17 00:00:00 2001
From: zhuzizyf <42790740+zhuzizyf@users.noreply.github.com>
Date: Sat, 11 Mar 2023 14:33:14 +0800
Subject: [PATCH] Update target_delay_transformer.py

---
 funasr/punctuation/target_delay_transformer.py |   35 ++++++++++++++---------------------
 1 file changed, 14 insertions(+), 21 deletions(-)

diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/punctuation/target_delay_transformer.py
index 66025cb..219af26 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/punctuation/target_delay_transformer.py
@@ -8,12 +8,13 @@
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.embedding import SinusoidalPositionEncoder
 #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
+from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
 #from funasr.modules.mask import subsequent_n_mask
 from funasr.punctuation.abs_model import AbsPunctuation
 
 
 class TargetDelayTransformer(AbsPunctuation):
+
     def __init__(
         self,
         vocab_size: int,
@@ -28,7 +29,7 @@
     ):
         super().__init__()
         if pos_enc == "sinusoidal":
-#            pos_enc_class = PositionalEncoding
+            #            pos_enc_class = PositionalEncoding
             pos_enc_class = SinusoidalPositionEncoder
         elif pos_enc is None:
 
@@ -47,16 +48,16 @@
             num_blocks=layer,
             dropout_rate=dropout_rate,
             input_layer="pe",
-           # pos_enc_class=pos_enc_class,
+            # pos_enc_class=pos_enc_class,
             padding_idx=0,
         )
         self.decoder = nn.Linear(att_unit, punc_size)
+
 
 #    def _target_mask(self, ys_in_pad):
 #        ys_mask = ys_in_pad != 0
 #        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
 #        return ys_mask.unsqueeze(-2) & m
-
 
     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
@@ -67,14 +68,15 @@
 
         """
         x = self.embed(input)
-       # mask = self._target_mask(input)
+        # mask = self._target_mask(input)
         h, _, _ = self.encoder(x, text_lengths)
         y = self.decoder(h)
         return y, None
 
-    def score(
-        self, y: torch.Tensor, state: Any, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, Any]:
+    def with_vad(self):
+        return False
+
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
         """Score new token.
 
         Args:
@@ -89,16 +91,12 @@
 
         """
         y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(
-            self.embed(y), self._target_mask(y), cache=state
-        )
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1).squeeze(0)
         return logp, cache
 
-    def batch_score(
-        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
-    ) -> Tuple[torch.Tensor, List[Any]]:
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
         """Score new token batch.
 
         Args:
@@ -120,15 +118,10 @@
             batch_state = None
         else:
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [
-                torch.stack([states[b][i] for b in range(n_batch)])
-                for i in range(n_layers)
-            ]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
 
         # batch decoding
-        h, _, states = self.encoder.forward_one_step(
-            self.embed(ys), self._target_mask(ys), cache=batch_state
-        )
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
 

--
Gitblit v1.9.1