游雁
2024-06-24 1596f6f414f6f41da66506debb1dff19fffeb3ec
funasr/models/sense_voice/decoder.py
@@ -146,7 +146,9 @@
                qk = qk + mask[:n_ctx, :n_ctx]
            else:
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                min_value = -float(
                    "inf"
                )  # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
                qk = qk.masked_fill(mask, min_value)
        qk = qk.float()
@@ -358,6 +360,7 @@
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state
@@ -472,7 +475,7 @@
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
        fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 or cache is None else None
        fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
        # if fsmn_cache is not None:
        #     x = x[:, -1:]
        att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
@@ -599,5 +602,6 @@
    def score(self, ys, state, x):
        """Score."""
        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
        logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
        logp = torch.log_softmax(logp, dim=-1)
        return logp.squeeze(0)[-1, :], state