游雁
2024-05-17 d3ff05837bbc14749d09f44947633b87e8f2db0e
funasr/models/sense_voice/decoder.py
@@ -15,6 +15,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
from funasr.models.transformer.utils.mask import subsequent_mask
class LayerNorm(nn.LayerNorm):
@@ -145,7 +146,9 @@
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                min_value = -float(
                    "inf"
                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                qk = qk.masked_fill(mask, min_value)
        qk = qk.float()
@@ -336,6 +339,29 @@
        return x
    def init_state(self, x):
        state = {}
        return state
    def final_score(self, state) -> float:
        """Score eos (optional).
        Args:
            state: Scorer state for prefix tokens
        Returns:
            float: final score
        """
        return 0.0
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        return logp.squeeze(0)[-1, :], state
class MultiHeadedAttentionSANMDecoder(nn.Module):
    """Multi-Head Attention layer.
@@ -365,7 +391,7 @@
        self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
        self.kernel_size = kernel_size
    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None):
    def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
        """
        :param x: (#batch, time1, size).
        :param mask: Mask tensor (#batch, 1, time)
@@ -443,9 +469,19 @@
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        cache = kwargs.get("cache", {})
        layer = kwargs.get("layer", 0)
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
        # if len(cache)>1:
        #     cache[layer]["fsmn_cache"] = fsmn_cache
        #     x = x[:, -1:]
        x = x + att_res
        if self.cross_attn:
            x = (
                x
@@ -510,10 +546,9 @@
        ys_in_lens = kwargs.get("ys_in_lens", None)
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        tgt, memory = x, xa
        tgt[tgt == -1] = 0
        tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
        tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
        # tgt = self.dropout(tgt)
        x = tgt.to(memory.dtype)
@@ -531,9 +566,41 @@
                memory_mask=memory_mask,
                is_pad_mask=False,
                is_pad_memory_mask=True,
                cache=kwargs.get("cache", None),
                layer=layer,
            )
        x = self.ln(x)
        x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
        return x
    def init_state(self, x):
        state = {}
        for layer, block in enumerate(self.blocks):
            state[layer] = {
                "fsmn_cache": None,
                "memory_key": None,
                "memory_value": None,
            }
        return state
    def final_score(self, state) -> float:
        """Score eos (optional).
        Args:
            state: Scorer state for prefix tokens
        Returns:
            float: final score
        """
        return 0.0
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state