lzr265946
2023-02-03 1d97d628f2f19674fa50495e984db8185604ca8e
funasr/punctuation/target_delay_transformer.py
@@ -14,6 +14,7 @@
 class TargetDelayTransformer(AbsPunctuation):
     def __init__(
         self,
         vocab_size: int,
@@ -52,11 +53,11 @@
         )
         self.decoder = nn.Linear(att_unit, punc_size)

 #    def _target_mask(self, ys_in_pad):
 #        ys_mask = ys_in_pad != 0
 #        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
 #        return ys_mask.unsqueeze(-2) & m

     def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
         """Compute loss value from buffer sequences.
@@ -72,9 +73,7 @@
         y = self.decoder(h)
         return y, None

-    def score(
-        self, y: torch.Tensor, state: Any, x: torch.Tensor
-    ) -> Tuple[torch.Tensor, Any]:
+    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
         """Score new token.

         Args:
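The truncated Args here follow the ESPnet-style single-hypothesis scorer contract, which this class appears to implement: y is the 1-D prefix of token ids emitted so far, state carries the per-layer attention caches from the previous call, and x is the encoder feature, which this particular scorer does not use. A hedged shape illustration (model is a stand-in name):

    y = torch.tensor([2, 15, 7])                 # (ylen,) prefix of token ids
    logp, state = model.score(y, state=None, x=torch.empty(0))
    # logp: (punc_size,) log-probabilities for the next token; state: updated cache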
@@ -89,16 +88,12 @@
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
        Args:
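Taken together, score consumes the growing prefix one step at a time, reusing the cache between calls. A hypothetical token-by-token decode loop driving it; model, bos_id, and max_len are illustrative stand-ins, not part of the commit:

    ys = torch.tensor([bos_id], dtype=torch.long)
    state = None
    for _ in range(max_len):
        logp, state = model.score(ys, state, x=torch.empty(0))  # x is unused here
        next_id = int(logp.argmax())
        ys = torch.cat([ys, torch.tensor([next_id])])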
@@ -120,15 +115,10 @@
             batch_state = None
         else:
             # transpose state of [batch, layer] into [layer, batch]
-            batch_state = [
-                torch.stack([states[b][i] for b in range(n_batch)])
-                for i in range(n_layers)
-            ]
+            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]

         # batch decoding
-        h, _, states = self.encoder.forward_one_step(
-            self.embed(ys), self._target_mask(ys), cache=batch_state
-        )
+        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
         h = self.decoder(h[:, -1])
         logp = h.log_softmax(dim=-1)
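The list comprehension above realizes the transpose noted in the comment: per-hypothesis caches ([batch][layer]) become one stacked tensor per layer, so a single forward_one_step call serves the whole batch. A standalone sketch of the same reshaping with toy shapes, independent of the model:

    import torch

    n_batch, n_layers = 3, 2
    # states[b][i]: layer-i cache of hypothesis b, e.g. (ylen, att_unit)
    states = [[torch.randn(4, 8) for _ in range(n_layers)] for _ in range(n_batch)]

    # [batch][layer] -> [layer] of (batch, ylen, att_unit), as in batch_score
    batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
    assert batch_state[0].shape == (3, 4, 8)

    # splitting back to per-hypothesis states (the usual inverse step after decoding)
    state_list = [[batch_state[i][b] for i in range(n_layers)] for b in range(n_batch)]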