Yuekai Zhang
2023-03-06 80e6c258cf89b5f11f4e52a4cc5a9cf2e95aa7be
funasr/punctuation/target_delay_transformer.py
@@ -8,12 +8,13 @@
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.embedding import SinusoidalPositionEncoder
 #from funasr.models.encoder.transformer_encoder import TransformerEncoder as Encoder
-from funasr.models.encoder.sanm_encoder import SANMEncoder as Encoder
+from funasr.punctuation.sanm_encoder import SANMEncoder as Encoder
 #from funasr.modules.mask import subsequent_n_mask
 from funasr.punctuation.abs_model import AbsPunctuation


 class TargetDelayTransformer(AbsPunctuation):
     def __init__(
         self,
         vocab_size: int,
@@ -28,7 +29,7 @@
     ):
         super().__init__()
         if pos_enc == "sinusoidal":
-#            pos_enc_class = PositionalEncoding
+            #            pos_enc_class = PositionalEncoding
             pos_enc_class = SinusoidalPositionEncoder
         elif pos_enc is None:
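For reference, the `SinusoidalPositionEncoder` selected here implements the standard Transformer sinusoidal scheme. A minimal sketch of that scheme, assuming an even d_model (this is not FunASR's actual class):

import math
import torch

def sinusoidal_position_encoding(length: int, d_model: int) -> torch.Tensor:
    # PE[pos, 2i] = sin(pos / 10000^(2i/d)); PE[pos, 2i+1] = cos(pos / 10000^(2i/d))
    position = torch.arange(length, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float32) * (-math.log(10000.0) / d_model)
    )
    pe = torch.zeros(length, d_model)
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe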
@@ -47,16 +48,16 @@
             num_blocks=layer,
             dropout_rate=dropout_rate,
             input_layer="pe",
-           # pos_enc_class=pos_enc_class,
+            # pos_enc_class=pos_enc_class,
             padding_idx=0,
         )
         self.decoder = nn.Linear(att_unit, punc_size)

 #    def _target_mask(self, ys_in_pad):
 #        ys_mask = ys_in_pad != 0
 #        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
 #        return ys_mask.unsqueeze(-2) & m

     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
@@ -67,14 +68,15 @@
        """
        x = self.embed(input)
       # mask = self._target_mask(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None
    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
    def with_vad(self):
        return False
    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.
        Args:
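For orientation, a minimal forward-pass call might look like this; the sizes are hypothetical and the remaining constructor arguments are assumed to keep their defaults:

import torch

model = TargetDelayTransformer(vocab_size=8000, punc_size=5)  # hypothetical sizes

tokens = torch.randint(1, 8000, (2, 30))  # (batch, time) token ids, 0 is padding
lengths = torch.tensor([30, 25])          # valid length of each sequence
logits, _ = model(tokens, lengths)        # (batch, time, punc_size)
punc_ids = logits.argmax(dim=-1)          # predicted punctuation id per token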
@@ -89,16 +91,12 @@
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
        Args:
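`score` advances the encoder one step with its per-layer cache and returns log-probabilities for the newest position only. Note that, as of this commit, `score` and `batch_score` still call `self._target_mask`, which is commented out above, so these paths would raise AttributeError unless a helper like the sketch above is restored. An illustrative call, with hypothetical inputs:

# Assumes `model` as above and a restored _target_mask helper.
state = None
prefix = torch.tensor([5, 17, 42])                # token ids seen so far
logp, state = model.score(prefix, state, x=None)  # x is unused in the body shown
print(logp.shape)                                 # (punc_size,) log-probabilities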
@@ -120,15 +118,10 @@
             batch_state = None
         else:
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [
-                torch.stack([states[b][i] for b in range(n_batch)])
-                for i in range(n_layers)
-            ]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

         # batch decoding
-        h, _, states = self.encoder.forward_one_step(
-            self.embed(ys), self._target_mask(ys), cache=batch_state
-        )
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
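The restack in `batch_score` turns a per-hypothesis list of per-layer caches into one batched tensor per layer, so the encoder can advance every hypothesis in a single `forward_one_step` call. A standalone illustration of the same transpose:

import torch

n_batch, n_layers = 3, 2
# states[b][i]: cached activations of layer i for hypothesis b.
states = [[torch.randn(7, 4) for _ in range(n_layers)] for _ in range(n_batch)]

batch_state = [
    torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)
]
print(batch_state[0].shape)  # torch.Size([3, 7, 4])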