From 4137f5cf26e7c4b40853959cd2574edfde03aa60 Mon Sep 17 00:00:00 2001
From: 志浩 <neo.dzh@alibaba-inc.com>
Date: 星期五, 07 四月 2023 21:03:34 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR into dev_dzh

---
 funasr/models/target_delay_transformer.py |  129 +++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 129 insertions(+), 0 deletions(-)

diff --git a/funasr/models/target_delay_transformer.py b/funasr/models/target_delay_transformer.py
new file mode 100644
index 0000000..84a2e6c
--- /dev/null
+++ b/funasr/models/target_delay_transformer.py
@@ -0,0 +1,129 @@
+from typing import Any
+from typing import List
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from funasr.modules.embedding import SinusoidalPositionEncoder
+#from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
+#from funasr.modules.mask import subsequent_n_mask
+from funasr.train.abs_model import AbsPunctuation
+
+
+class TargetDelayTransformer(AbsPunctuation):
+
+    def __init__(
+        self,
+        vocab_size: int,
+        punc_size: int,
+        pos_enc: str = None,
+        embed_unit: int = 128,
+        att_unit: int = 256,
+        head: int = 2,
+        unit: int = 1024,
+        layer: int = 4,
+        dropout_rate: float = 0.5,
+    ):
+        super().__init__()
+        if pos_enc == "sinusoidal":
+            #            pos_enc_class = PositionalEncoding
+            pos_enc_class = SinusoidalPositionEncoder
+        elif pos_enc is None:
+
+            def pos_enc_class(*args, **kwargs):
+                return nn.Sequential()  # indentity
+
+        else:
+            raise ValueError(f"unknown pos-enc option: {pos_enc}")
+
+        self.embed = nn.Embedding(vocab_size, embed_unit)
+        self.encoder = Encoder(
+            input_size=embed_unit,
+            output_size=att_unit,
+            attention_heads=head,
+            linear_units=unit,
+            num_blocks=layer,
+            dropout_rate=dropout_rate,
+            input_layer="pe",
+            # pos_enc_class=pos_enc_class,
+            padding_idx=0,
+        )
+        self.decoder = nn.Linear(att_unit, punc_size)
+
+
+#    def _target_mask(self, ys_in_pad):
+#        ys_mask = ys_in_pad != 0
+#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
+#        return ys_mask.unsqueeze(-2) & m
+
+    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+        """Compute loss value from buffer sequences.
+
+        Args:
+            input (torch.Tensor): Input ids. (batch, len)
+            hidden (torch.Tensor): Target ids. (batch, len)
+
+        """
+        x = self.embed(input)
+        # mask = self._target_mask(input)
+        h, _, _ = self.encoder(x, text_lengths)
+        y = self.decoder(h)
+        return y, None
+
+    def with_vad(self):
+        return False
+
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
+        """Score new token.
+
+        Args:
+            y (torch.Tensor): 1D torch.int64 prefix tokens.
+            state: Scorer state for prefix tokens
+            x (torch.Tensor): encoder feature that generates ys.
+
+        Returns:
+            tuple[torch.Tensor, Any]: Tuple of
+                torch.float32 scores for next token (vocab_size)
+                and next state for ys
+
+        """
+        y = y.unsqueeze(0)
+        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1).squeeze(0)
+        return logp, cache
+
+    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
+        """Score new token batch.
+
+        Args:
+            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
+            states (List[Any]): Scorer states for prefix tokens.
+            xs (torch.Tensor):
+                The encoder feature that generates ys (n_batch, xlen, n_feat).
+
+        Returns:
+            tuple[torch.Tensor, List[Any]]: Tuple of
+                batchfied scores for next token with shape of `(n_batch, vocab_size)`
+                and next state list for ys.
+
+        """
+        # merge states
+        n_batch = len(ys)
+        n_layers = len(self.encoder.encoders)
+        if states[0] is None:
+            batch_state = None
+        else:
+            # transpose state of [batch, layer] into [layer, batch]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
+
+        # batch decoding
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
+        h = self.decoder(h[:, -1])
+        logp = h.log_softmax(dim=-1)
+
+        # transpose state of [layer, batch] into [batch, layer]
+        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
+        return logp, state_list

--
Gitblit v1.9.1