lzr265946
2023-02-03 1d97d628f2f19674fa50495e984db8185604ca8e
funasr/punctuation/target_delay_transformer.py
@@ -14,6 +14,7 @@
class TargetDelayTransformer(AbsPunctuation):
    def __init__(
        self,
        vocab_size: int,
@@ -28,7 +29,7 @@
    ):
        super().__init__()
        if pos_enc == "sinusoidal":
#            pos_enc_class = PositionalEncoding
            #            pos_enc_class = PositionalEncoding
            pos_enc_class = SinusoidalPositionEncoder
        elif pos_enc is None:
@@ -47,16 +48,16 @@
            num_blocks=layer,
            dropout_rate=dropout_rate,
            input_layer="pe",
           # pos_enc_class=pos_enc_class,
            # pos_enc_class=pos_enc_class,
            padding_idx=0,
        )
        self.decoder = nn.Linear(att_unit, punc_size)
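# NOTE(review): `_target_mask` is commented out below, yet `score` and
# `batch_score` still call `self._target_mask(...)`. Confirm that a base class
# provides it; otherwise those calls will raise AttributeError.
#    def _target_mask(self, ys_in_pad):
#        ys_mask = ys_in_pad != 0
#        m = subsequent_n_mask(ys_mask.size(-1), 5, device=ys_mask.device).unsqueeze(0)
#        return ys_mask.unsqueeze(-2) & m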
    def forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute loss value from buffer sequences.
@@ -67,14 +68,12 @@
        """
        x = self.embed(input)
       # mask = self._target_mask(input)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
        return y, None
    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.
        Args:
@@ -89,16 +88,12 @@
        """
        y = y.unsqueeze(0)
        h, _, cache = self.encoder.forward_one_step(
            self.embed(y), self._target_mask(y), cache=state
        )
        h, _, cache = self.encoder.forward_one_step(self.embed(y), self._target_mask(y), cache=state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
    def batch_score(self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.
        Args:
@@ -120,15 +115,10 @@
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]
            batch_state = [torch.stack([states[b][i] for b in range(n_batch)]) for i in range(n_layers)]
        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        h, _, states = self.encoder.forward_one_step(self.embed(ys), self._target_mask(ys), cache=batch_state)
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)