From 0a4a1d5257dace9561d95b38a9386539908dcd5e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 23 Apr 2024 12:48:52 +0800
Subject: [PATCH] Dev gzf exp (#1645)

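This patch adds a SenseVoice decoder built from RWKV blocks: fp16/bf16-safe
LayerNorm and Linear wrappers, a MultiHeadAttention supporting additive causal
masks, boolean pad masks, and kv_cache reuse for cross-attention, a
ResidualAttentionBlockRWKV that swaps self-attention for an RWKVLayer, and a
SenseVoiceDecoder (registered in "decoder_classes") whose output logits are
tied to the token embedding.

Minimal usage sketch (hypothetical shapes and rwkv_cfg values; the real config
comes from the model YAML, and RWKVLayer in rwkv_v6 may require fields beyond
those read in this file):

    import torch
    from funasr.models.sense_voice.decoder import SenseVoiceDecoder

    decoder = SenseVoiceDecoder(
        n_vocab=8404, n_ctx=448, n_state=512, n_head=8, n_layer=2,
        rwkv_cfg={"n_embd": 512, "n_layer": 2, "datatype": "fp32",
                  "ln0": False, "ln1": False, "init_rwkv": True},
    )
    tokens = torch.randint(0, 8404, (1, 10))    # (batch, maxlen_out)
    memory = torch.randn(1, 100, 512)           # encoder output (batch, T, n_state)
    logits = decoder(tokens, memory, hlens=torch.tensor([100]))
    print(logits.shape)                         # (1, 10, 8404)
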
---
 funasr/models/sense_voice/decoder.py |  273 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 267 insertions(+), 6 deletions(-)

diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index bae2832..9087ea1 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -5,6 +5,32 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
+from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
+from omegaconf import OmegaConf
+from typing import Optional
+
+import numpy as np
+import torch
+from torch import Tensor
+
+
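+# fp16/bf16-safe wrappers (Whisper-style): LayerNorm computes in float32 and casts
+# back; Linear casts its weight and bias to the activation dtype before the matmul.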
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x: Tensor) -> Tensor:
+        return super().forward(x.float()).type(x.dtype)
+
+
+class Linear(nn.Linear):
+    def forward(self, x: Tensor) -> Tensor:
+        return F.linear(
+            x,
+            self.weight.to(x.dtype),
+            None if self.bias is None else self.bias.to(x.dtype),
+        )
+
 
 def sense_voice_decode_forward(
 	self,
@@ -38,10 +64,10 @@
 	
 	offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
 	tgt, memory = x, xa
-	tgt[tgt==-1] = 0
+	tgt[tgt == -1] = 0
 	tgt = (
 		self.token_embedding(tgt)
-		+ self.positional_embedding[offset : offset + tgt.size(1)]
+		+ self.positional_embedding[offset: offset + tgt.size(1)]
 	)
 	# tgt = self.dropout(tgt)
 	
@@ -54,13 +80,248 @@
 	
 	for layer, block in enumerate(self.blocks):
 		x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
-
-
+
 	x = self.ln(x)
 	x = (
 		x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
 	).float()
 	
-	
 	return x
-	
\ No newline at end of file
+
+
+class MultiHeadAttention(nn.Module):
+	def __init__(self, n_state: int, n_head: int):
+		super().__init__()
+		self.n_head = n_head
+		self.query = Linear(n_state, n_state)
+		self.key = Linear(n_state, n_state, bias=False)
+		self.value = Linear(n_state, n_state)
+		self.out = Linear(n_state, n_state)
+	
+	def forward(
+		self,
+		x: Tensor,
+		xa: Optional[Tensor] = None,
+		mask: Optional[Tensor] = None,
+		kv_cache: Optional[dict] = None,
+		**kwargs,
+	):
+		is_pad_mask = kwargs.get("is_pad_mask", False)
+		
+		q = self.query(x)
+		
+		if kv_cache is None or xa is None or self.key not in kv_cache:
+			# hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
+			# otherwise, perform key/value projections for self- or cross-attention as usual.
+			k = self.key(x if xa is None else xa)
+			v = self.value(x if xa is None else xa)
+		else:
+			# for cross-attention, calculate keys and values once and reuse in subsequent calls.
+			k = kv_cache[self.key]
+			v = kv_cache[self.value]
+		
+		wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
+		return self.out(wv), qk
+	
+	def qkv_attention(
+		self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, **kwargs,
+	):
+		is_pad_mask = kwargs.get("is_pad_mask", False)
+		n_batch, n_ctx, n_state = q.shape
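+		# scale q and k by d_k ** -0.25 each, so q @ k is scaled by 1 / sqrt(d_k) overall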
+		scale = (n_state // self.n_head) ** -0.25
+		q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
+		k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+		v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+		
+		qk = q @ k
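+		# mask is either an additive causal mask (sliced to the current context)
+		# or a boolean pad mask applied with masked_fill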
+		if mask is not None:
+			if not is_pad_mask:
+				qk = qk + mask[:n_ctx, :n_ctx]
+			else:
+				mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+				# torch.finfo also covers bfloat16, which numpy's finfo cannot
+				min_value = torch.finfo(qk.dtype).min
+				qk = qk.masked_fill(mask, min_value)
+		
+		qk = qk.float()
+		
+		w = F.softmax(qk, dim=-1).to(q.dtype)
+		if mask is not None and is_pad_mask:
+			w = w.masked_fill(mask, 0.0)
+		return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+
+
+class ResidualAttentionBlockRWKV(nn.Module):
+	def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs):
+		super().__init__()
+		
+		rwkv_cfg = kwargs.get("rwkv_cfg", {})
+		args = OmegaConf.create(rwkv_cfg)
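+		# RWKV time mixing stands in for the usual masked self-attention in this block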
+		self.attn = RWKVLayer(args=args, layer_id=layer_id)
+		if args.get("datatype", "bf16") == "bf16":
+			self.attn.to(torch.bfloat16)
+
+		self.ln0 = None
+		if layer_id == 0 and not args.get("ln0", True):
+			self.ln0 = LayerNorm(args.n_embd)
+			if args.get("init_rwkv", True):
+				print("init_rwkv")
+				scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+				nn.init.constant_(self.ln0.weight, scale)
+		self.layer_id = layer_id
+		self.args = args
+		
+		self.ln1 = None
+		if not args.get("ln1", True):
+			self.ln1 = LayerNorm(args.n_embd)
+			# init
+			if args.get("init_rwkv", True):
+				print("init_rwkv")
+				scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
+				nn.init.constant_(self.ln1.weight, scale)
+		
+		self.cross_attn = (
+			MultiHeadAttention(n_state, n_head) if cross_attention else None
+		)
+		self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
+		
+		n_mlp = n_state * 4
+		self.mlp = nn.Sequential(
+			Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
+		)
+		self.mlp_ln = LayerNorm(n_state)
+	
+	def forward(
+		self,
+		x: Tensor,
+		xa: Optional[Tensor] = None,
+		mask: Optional[Tensor] = None,
+		kv_cache: Optional[dict] = None,
+		**kwargs,
+	):
+		is_pad_mask = kwargs.get("is_pad_mask", False)
+		is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
+		
+		if self.layer_id == 0 and self.ln0 is not None:
+			x = self.ln0(x)
+		
+		if self.ln1 is None:
+			x = x + self.attn(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+		else:
+			x = x + self.attn(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
+
+		if self.cross_attn:
+			# apply the encoder pad mask (memory_mask) to the cross-attention scores
+			x = x + self.cross_attn(self.cross_attn_ln(x), xa, mask=kwargs.get("memory_mask", None), kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask)[0]
+		x = x + self.mlp(self.mlp_ln(x))
+		
+		return x
+
+
+@tables.register("decoder_classes", "SenseVoiceDecoder")
+class SenseVoiceDecoder(nn.Module):
+	def __init__(
+		self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs
+	):
+		super().__init__()
+		
+		self.token_embedding = nn.Embedding(n_vocab, n_state)
+		self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
+
+		self.blocks = nn.ModuleList(
+			[
+				ResidualAttentionBlockRWKV(n_state, n_head, cross_attention=True, layer_id=i, **kwargs)
+				for i in range(n_layer)
+			]
+		)
+		self.ln = LayerNorm(n_state)
+		
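+		# additive causal mask: -inf above the diagonal blocks attention to future positions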
+		mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
+		self.register_buffer("mask", mask, persistent=False)
+		
+		self.use_padmask = kwargs.get("use_padmask", True)
+
+	def forward(
+		self,
+		x: torch.Tensor,
+		xa: torch.Tensor,
+		kv_cache: Optional[dict] = None,
+		**kwargs,
+	):
+		"""Forward decoder.
+
+		Args:
+			x: input token ids, int64 (batch, maxlen_out)
+			xa: encoded memory, float32 (batch, maxlen_in, feat)
+			kv_cache: cached key/value projections for incremental decoding
+			hlens (in kwargs): lengths of the encoded memory (batch,)
+			ys_in_lens (in kwargs): lengths of the input token sequences (batch,)
+		Returns:
+			x: decoded token scores before softmax (batch, maxlen_out, n_vocab)
+		"""
+		use_padmask = self.use_padmask
+		hlens = kwargs.get("hlens", None)
+		
+		ys_in_lens = kwargs.get("ys_in_lens", None)
+		
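+		# for incremental decoding, offset positions by the length already in the kv cache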
+		offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
+		tgt, memory = x, xa
+		tgt[tgt == -1] = 0
+		tgt = (
+			self.token_embedding(tgt)
+			+ self.positional_embedding[offset: offset + tgt.size(1)]
+		)
+		# tgt = self.dropout(tgt)
+		
+		x = tgt.to(memory.dtype)
+		
+		if use_padmask and hlens is not None:
+			memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
+		else:
+			memory_mask = None
+		
+		for layer, block in enumerate(self.blocks):
+			x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
+		
+		x = self.ln(x)
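+		# tied output projection: logits reuse the token embedding weights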
+		x = (
+			x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
+		).float()
+		
+		return x

--
Gitblit v1.9.1