From 1d97d628f2f19674fa50495e984db8185604ca8e Mon Sep 17 00:00:00 2001
From: lzr265946 <lzr265946@alibaba-inc.com>
Date: 星期五, 03 二月 2023 14:11:22 +0800
Subject: [PATCH] Merge branch 'main' into dev

---
 funasr/punctuation/target_delay_transformer.py |   30 ++++++++++--------------------
 1 files changed, 10 insertions(+), 20 deletions(-)

diff --git a/funasr/punctuation/target_delay_transformer.py b/funasr/punctuation/target_delay_transformer.py
index 66025cb..10cc5a8 100644
--- a/funasr/punctuation/target_delay_transformer.py
+++ b/funasr/punctuation/target_delay_transformer.py
@@ -14,6 +14,7 @@
 
 
 class TargetDelayTransformer(AbsPunctuation):
+
     def __init__(
         self,
         vocab_size: int,
@@ -28,7 +29,7 @@
     ):
         super().__init__()
         if pos_enc == "sinusoidal":
-#            pos_enc_class = PositionalEncoding
+            #            pos_enc_class = PositionalEncoding
             pos_enc_class = SinusoidalPositionEncoder
         elif pos_enc is None:
 
@@ -47,16 +48,16 @@
             num_blocks=layer,
             dropout_rate=dropout_rate,
             input_layer="pe",
-           # pos_enc_class=pos_enc_class,
+            # pos_enc_class=pos_enc_class,
             padding_idx=0,
         )
         self.decoder = nn.Linear(att_unit, punc_size)
+
 
 #    def _target_mask(self, ys_in_pad):
 #        ys_mask = ys_in_pad != 0
 #        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
 #        return ys_mask.unsqueeze(-2) & m
-
 
     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
@@ -67,14 +68,12 @@
 
         """
         x = self.embed(input)
-       # mask = self._target_mask(input)
+        # mask = self._target_mask(input)
         h, _, _ = self.encoder(x, text_lengths)
         y = self.decoder(h)
         return y, None
 
-    def score(
-        self, y: torch.Tensor, state: Any, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, Any]:
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
         """Score new token.
 
         Args:
@@ -89,16 +88,12 @@
 
         """
         y = y.unsqueeze(0)
-        h, _, cache = self.encoder.forward_one_step(
-            self.embed(y), self._target_mask(y), cache=state
-        )
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1).squeeze(0)
         return logp, cache
 
-    def batch_score(
-        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
-    ) -> Tuple[torch.Tensor, List[Any]]:
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
         """Score new token batch.
 
         Args:
@@ -120,15 +115,10 @@
             batch_state = None
         else:
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [
-                torch.stack([states[b][i] for b in range(n_batch)])
-                for i in range(n_layers)
-            ]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
 
         # batch decoding
-        h, _, states = self.encoder.forward_one_step(
-            self.embed(ys), self._target_mask(ys), cache=batch_state
-        )
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
 

--
Gitblit v1.9.1