游雁
2024-05-16 48a894c8e3babed74c2f9ab8832fe6cefe3967aa
funasr/models/sense_voice/decoder.py
@@ -337,6 +337,29 @@
        return x
    def init_state(self, x):
        state = {}
        return state
    def final_score(self, state) -> float:
        """Score eos (optional).
        Args:
            state: Scorer state for prefix tokens
        Returns:
            float: final score
        """
        return 0.0
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        return logp.squeeze(0)[-1, :], state
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """Multi-Head Attention layer.
@@ -449,7 +472,7 @@
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None
        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
@@ -576,5 +599,6 @@
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state