From 4ace5a95b052d338947fc88809a440ccd55cf6b4 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 16 十一月 2023 16:39:52 +0800
Subject: [PATCH] funasr pages
---
funasr/models/encoder/sanm_encoder.py | 129 +++++++++++++++++++++++++++++++++++++++----
1 files changed, 117 insertions(+), 12 deletions(-)
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a353..c15343e 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -6,12 +6,13 @@
import logging
import torch
import torch.nn as nn
+import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
import numpy as np
+from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
@@ -25,9 +26,10 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -112,12 +114,48 @@
if not self.normalize_before:
x = self.norm2(x)
-
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+
+ def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
+ """Compute encoded features.
+
+ Args:
+ x_input (torch.Tensor): Input tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if self.in_size == self.size:
+ attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+ x = residual + attn
+ else:
+ x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
+
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + self.feed_forward(x)
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ return x, cache
+
class SANMEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -144,11 +182,14 @@
interctc_use_conditioning: bool = False,
kernel_size : int = 11,
sanm_shfit : int = 0,
+ lora_list: List[str] = None,
+ lora_rank: int = 8,
+ lora_alpha: int = 16,
+ lora_dropout: float = 0.1,
selfattention_layer_type: str = "sanm",
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -180,6 +221,8 @@
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -226,6 +269,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
encoder_selfattn_layer_args = (
@@ -235,6 +282,10 @@
attention_dropout_rate,
kernel_size,
sanm_shfit,
+ lora_list,
+ lora_rank,
+ lora_alpha,
+ lora_dropout,
)
self.encoders0 = repeat(
1,
@@ -347,6 +398,14 @@
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+ def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
def forward_chunk(self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
@@ -357,8 +416,11 @@
if self.embed is None:
xs_pad = xs_pad
else:
- xs_pad = self.embed.forward_chunk(xs_pad, cache)
-
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
encoder_outs = self.encoders0(xs_pad, None, None, None, None)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
@@ -549,7 +611,7 @@
class SANMEncoderChunkOpt(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -585,7 +647,6 @@
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -617,6 +678,8 @@
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -802,6 +865,49 @@
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+ def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
+ def forward_chunk(self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ ):
+ xs_pad *= self.output_size() ** 0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ else:
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
+ if cache["opt"] is None:
+ cache_layer_num = len(self.encoders0) + len(self.encoders)
+ new_cache = [None] * cache_layer_num
+ else:
+ new_cache = cache["opt"]
+
+ for layer_idx, encoder_layer in enumerate(self.encoders0):
+ encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
+
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+ if cache["encoder_chunk_look_back"] > 0 or cache["encoder_chunk_look_back"] == -1:
+ cache["opt"] = new_cache
+
+ return xs_pad, ilens, None
+
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
@@ -962,7 +1068,7 @@
class SANMVadEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
@@ -989,7 +1095,6 @@
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
--
Gitblit v1.9.1