From bc723ea200144bd6fa8a5dff4b9a780feda144fc Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 29 六月 2023 18:55:01 +0800
Subject: [PATCH] dcos
---
funasr/models/encoder/sanm_encoder.py | 89 +++++++++++++++++++++++++++++++++++++++-----
1 files changed, 78 insertions(+), 11 deletions(-)
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a353..45163df 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -6,12 +6,13 @@
import logging
import torch
import torch.nn as nn
+import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
import numpy as np
+from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask
-from funasr.modules.embedding import SinusoidalPositionEncoder
+from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.modules.layer_norm import LayerNorm
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
@@ -25,9 +26,10 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
+from funasr.modules.mask import subsequent_mask, vad_mask
+
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -117,7 +119,7 @@
class SANMEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
San-m: Memory equipped self-attention for end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -148,7 +150,6 @@
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -180,6 +181,8 @@
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -347,6 +350,14 @@
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+ def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
def forward_chunk(self,
xs_pad: torch.Tensor,
ilens: torch.Tensor,
@@ -357,8 +368,11 @@
if self.embed is None:
xs_pad = xs_pad
else:
- xs_pad = self.embed.forward_chunk(xs_pad, cache)
-
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
encoder_outs = self.encoders0(xs_pad, None, None, None, None)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
@@ -549,7 +563,7 @@
class SANMEncoderChunkOpt(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01713
@@ -585,7 +599,6 @@
tf2torch_tensor_name_prefix_torch: str = "encoder",
tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
@@ -617,6 +630,8 @@
self.embed = torch.nn.Linear(input_size, output_size)
elif input_layer == "pe":
self.embed = SinusoidalPositionEncoder()
+ elif input_layer == "pe_online":
+ self.embed = StreamSinusoidalPositionEncoder()
else:
raise ValueError("unknown input_layer: " + input_layer)
self.normalize_before = normalize_before
@@ -802,6 +817,59 @@
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+ def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}):
+ if len(cache) == 0:
+ return feats
+ cache["feats"] = to_device(cache["feats"], device=feats.device)
+ overlap_feats = torch.cat((cache["feats"], feats), dim=1)
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ return overlap_feats
+
+ def forward_chunk(self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ ctc: CTC = None,
+ ):
+ xs_pad *= self.output_size() ** 0.5
+ if self.embed is None:
+ xs_pad = xs_pad
+ else:
+ xs_pad = self.embed(xs_pad, cache)
+ if cache["tail_chunk"]:
+ xs_pad = to_device(cache["feats"], device=xs_pad.device)
+ else:
+ xs_pad = self._add_overlap_chunk(xs_pad, cache)
+ encoder_outs = self.encoders0(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ intermediate_outs = []
+ if len(self.interctc_layer_idx) == 0:
+ encoder_outs = self.encoders(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ else:
+ for layer_idx, encoder_layer in enumerate(self.encoders):
+ encoder_outs = encoder_layer(xs_pad, None, None, None, None)
+ xs_pad, masks = encoder_outs[0], encoder_outs[1]
+ if layer_idx + 1 in self.interctc_layer_idx:
+ encoder_out = xs_pad
+
+ # intermediate outputs are also normalized
+ if self.normalize_before:
+ encoder_out = self.after_norm(encoder_out)
+
+ intermediate_outs.append((layer_idx + 1, encoder_out))
+
+ if self.interctc_use_conditioning:
+ ctc_out = ctc.softmax(encoder_out)
+ xs_pad = xs_pad + self.conditioning_layer(ctc_out)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ if len(intermediate_outs) > 0:
+ return (xs_pad, intermediate_outs), None, None
+ return xs_pad, ilens, None
+
def gen_tf2torch_map_dict(self):
tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
@@ -962,7 +1030,7 @@
class SANMVadEncoder(AbsEncoder):
"""
- author: Speech Lab, Alibaba Group, China
+ Author: Speech Lab of DAMO Academy, Alibaba Group
"""
@@ -989,7 +1057,6 @@
sanm_shfit : int = 0,
selfattention_layer_type: str = "sanm",
):
- assert check_argument_types()
super().__init__()
self._output_size = output_size
--
Gitblit v1.9.1