From 00d0df3a1018c63ec8c5d13e611f53c564c0a7e2 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 06 May 2024 22:17:25 +0800
Subject: [PATCH] Dev gzf decoding (#1695)
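
Add scorer-style entry points (init_state / score / final_score) to the
SenseVoice decoder classes so the decoder can be scored step by step
during search, thread a per-layer cache dict (fsmn_cache / memory_key /
memory_value) through the decoder blocks, and index the positional
embedding from 0 instead of deriving an offset from kv_cache.

A minimal sketch of how the new methods might be driven in a greedy
loop follows; `decoder`, `enc` (encoder output for one utterance),
`sos_id`, `eos_id` and `max_len` are assumed names for illustration,
not defined by this patch:

    import torch

    state = decoder.init_state(enc)        # one cache dict per decoder block
    ys = torch.tensor([sos_id], dtype=torch.long, device=enc.device)
    for _ in range(max_len):
        # score() re-runs the full prefix and returns the next-token
        # scores together with the state dict.
        scores, state = decoder.score(ys, state, enc)
        next_id = int(scores.argmax())
        ys = torch.cat([ys, ys.new_tensor([next_id])])
        if next_id == eos_id:
            break
    total = decoder.final_score(state)     # always 0.0: eos adds no extra score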
---
funasr/models/sense_voice/decoder.py | 70 +++++++++++++++++++++++++++++++++-
 1 file changed, 67 insertions(+), 3 deletions(-)
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index f5b8825..19d9c16 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -15,6 +15,7 @@
import torch
import torch.nn.functional as F
from torch import Tensor, nn
+from funasr.models.transformer.utils.mask import subsequent_mask
class LayerNorm(nn.LayerNorm):
@@ -336,6 +337,29 @@
return x
+ def init_state(self, x):
+ state = {}
+
+ return state
+
+ def final_score(self, state) -> float:
+ """Score eos (optional).
+
+ Args:
+ state: Scorer state for prefix tokens
+
+ Returns:
+ float: final score
+
+ """
+ return 0.0
+
+ def score(self, ys, state, x):
+ """Score."""
+        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)  # note: ys_mask is not passed to forward() here
+ logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+ return logp.squeeze(0)[-1, :], state
+
class MultiHeadedAttentionSANMDecoder(nn.Module):
"""Multi-Head Attention layer.
@@ -443,9 +467,19 @@
kv_cache: Optional[dict] = None,
**kwargs,
):
+        cache = kwargs.get("cache", None) or {}  # the caller may pass cache=None explicitly
+ layer = kwargs.get("layer", 0)
is_pad_mask = kwargs.get("is_pad_mask", False)
is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
- x = x + self.attn(self.attn_ln(x), mask=None, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+
+ fsmn_cache = cache[layer]["fsmn_cache"] if len(cache) > 0 else None
+ # if fsmn_cache is not None:
+ # x = x[:, -1:]
+ att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
+ # if len(cache)>1:
+ # cache[layer]["fsmn_cache"] = fsmn_cache
+ # x = x[:, -1:]
+ x = x + att_res
if self.cross_attn:
x = (
x
@@ -510,10 +544,9 @@
ys_in_lens = kwargs.get("ys_in_lens", None)
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
tgt, memory = x, xa
tgt[tgt == -1] = 0
- tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
+ tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
# tgt = self.dropout(tgt)
x = tgt.to(memory.dtype)
@@ -531,9 +564,40 @@
memory_mask=memory_mask,
is_pad_mask=False,
is_pad_memory_mask=True,
+ cache=kwargs.get("cache", None),
+ layer=layer,
)
x = self.ln(x)
x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
return x
+
+ def init_state(self, x):
+ state = {}
+        for layer in range(len(self.blocks)):
+ state[layer] = {
+ "fsmn_cache": None,
+ "memory_key": None,
+ "memory_value": None,
+ }
+
+ return state
+
+ def final_score(self, state) -> float:
+ """Score eos (optional).
+
+ Args:
+ state: Scorer state for prefix tokens
+
+ Returns:
+ float: final score
+
+ """
+ return 0.0
+
+ def score(self, ys, state, x):
+ """Score."""
+        ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)  # note: ys_mask is not passed to forward() here
+ logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+ return logp.squeeze(0)[-1, :], state
--
Gitblit v1.9.1