| | |
| | | import logging |
| | | import torch |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk |
| | | from typeguard import check_argument_types |
| | | import numpy as np |
| | | from funasr.torch_utils.device_funcs import to_device |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.attention import MultiHeadedAttention, MultiHeadedAttentionSANM, MultiHeadedAttentionSANMwithMask |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder |
| | | from funasr.modules.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | | from funasr.modules.multi_layer_conv import MultiLayeredConv1d |
| | |
| | | from funasr.modules.subsampling import Conv2dSubsampling8 |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.modules.mask import subsequent_mask, vad_mask |
| | | |
| | | from funasr.models.ctc import CTC |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.mask import subsequent_mask, vad_mask |
| | | |
| | | class EncoderLayerSANM(nn.Module): |
| | | def __init__( |
| | |
| | | if not self.normalize_before: |
| | | x = self.norm2(x) |
| | | |
| | | |
| | | return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder |
| | | |
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
    """Run one streaming chunk through the attention and feed-forward sub-layers.

    Args:
        x: Input features of the current chunk (presumably a tensor shaped
            (#batch, time, size) -- TODO confirm against the caller).
        cache: Attention state from previous chunks; handed as-is to
            ``self.self_attn.forward_chunk``. Defaults to ``None``.
        chunk_size: Forwarded unchanged to ``self.self_attn.forward_chunk``.
        look_back: Forwarded unchanged to ``self.self_attn.forward_chunk``.

    Returns:
        tuple: ``(x, cache)`` -- the layer output and the cache object
        returned by the attention module.

    """
    pre_norm = self.normalize_before

    # --- self-attention sub-layer ---
    attn_in = self.norm1(x) if pre_norm else x
    attn_out, cache = self.self_attn.forward_chunk(attn_in, cache, chunk_size, look_back)
    # The skip connection is only used when input and output widths agree.
    hidden = x + attn_out if self.in_size == self.size else attn_out
    if not pre_norm:
        hidden = self.norm1(hidden)

    # --- feed-forward sub-layer (always has a skip connection) ---
    ff_in = self.norm2(hidden) if pre_norm else hidden
    out = hidden + self.feed_forward(ff_in)
    if not pre_norm:
        out = self.norm2(out)

    return out, cache
| | | |
| | | |
| | | class SANMEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | San-m: Memory equipped self-attention for end-to-end speech recognition |
| | | https://arxiv.org/abs/2006.01713 |
| | | |
| | |
| | | interctc_use_conditioning: bool = False, |
| | | kernel_size : int = 11, |
| | | sanm_shfit : int = 0, |
| | | lora_list: List[str] = None, |
| | | lora_rank: int = 8, |
| | | lora_alpha: int = 16, |
| | | lora_dropout: float = 0.1, |
| | | selfattention_layer_type: str = "sanm", |
| | | tf2torch_tensor_name_prefix_torch: str = "encoder", |
| | | tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | |
| | | self.embed = torch.nn.Linear(input_size, output_size) |
| | | elif input_layer == "pe": |
| | | self.embed = SinusoidalPositionEncoder() |
| | | elif input_layer == "pe_online": |
| | | self.embed = StreamSinusoidalPositionEncoder() |
| | | else: |
| | | raise ValueError("unknown input_layer: " + input_layer) |
| | | self.normalize_before = normalize_before |
| | |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | lora_list, |
| | | lora_rank, |
| | | lora_alpha, |
| | | lora_dropout, |
| | | ) |
| | | |
| | | encoder_selfattn_layer_args = ( |
| | |
| | | attention_dropout_rate, |
| | | kernel_size, |
| | | sanm_shfit, |
| | | lora_list, |
| | | lora_rank, |
| | | lora_alpha, |
| | | lora_dropout, |
| | | ) |
| | | self.encoders0 = repeat( |
| | | 1, |
| | |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |
| | | def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): |
| | | if len(cache) == 0: |
| | | return feats |
| | | cache["feats"] = to_device(cache["feats"], device=feats.device) |
| | | overlap_feats = torch.cat((cache["feats"], feats), dim=1) |
| | | cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] |
| | | return overlap_feats |
| | | |
| | | def forward_chunk(self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | |
| | | if self.embed is None: |
| | | xs_pad = xs_pad |
| | | else: |
| | | xs_pad = self.embed.forward_chunk(xs_pad, cache) |
| | | |
| | | xs_pad = self.embed(xs_pad, cache) |
| | | if cache["tail_chunk"]: |
| | | xs_pad = to_device(cache["feats"], device=xs_pad.device) |
| | | else: |
| | | xs_pad = self._add_overlap_chunk(xs_pad, cache) |
| | | encoder_outs = self.encoders0(xs_pad, None, None, None, None) |
| | | xs_pad, masks = encoder_outs[0], encoder_outs[1] |
| | | intermediate_outs = [] |
| | |
| | | |
| | | class SANMEncoderChunkOpt(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition |
| | | https://arxiv.org/abs/2006.01712 |
| | | |
| | |
| | | tf2torch_tensor_name_prefix_torch: str = "encoder", |
| | | tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder", |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | |
| | | self.embed = torch.nn.Linear(input_size, output_size) |
| | | elif input_layer == "pe": |
| | | self.embed = SinusoidalPositionEncoder() |
| | | elif input_layer == "pe_online": |
| | | self.embed = StreamSinusoidalPositionEncoder() |
| | | else: |
| | | raise ValueError("unknown input_layer: " + input_layer) |
| | | self.normalize_before = normalize_before |
| | |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |
| | | def _add_overlap_chunk(self, feats: np.ndarray, cache: dict = {}): |
| | | if len(cache) == 0: |
| | | return feats |
| | | cache["feats"] = to_device(cache["feats"], device=feats.device) |
| | | overlap_feats = torch.cat((cache["feats"], feats), dim=1) |
| | | cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :] |
| | | return overlap_feats |
| | | |
def forward_chunk(self,
                  xs_pad: torch.Tensor,
                  ilens: torch.Tensor,
                  cache: dict = None,
                  ):
    """Encode one streaming chunk, carrying per-layer state in ``cache``.

    Args:
        xs_pad: Features of the current chunk. Scaled in place by
            ``self.output_size() ** 0.5``, so a tensor passed by the caller
            is modified as well.
        ilens: Returned unchanged as the second output.
        cache: Streaming state. Although the default is ``None``, the body
            indexes it unconditionally, so a dict must be supplied. Keys
            read: ``"tail_chunk"``, ``"feats"`` (tail chunk only),
            ``"opt"``, ``"chunk_size"`` and ``"encoder_chunk_look_back"``.
            ``"opt"`` is written back when look-back is ``> 0`` or ``-1``.

    Returns:
        tuple: ``(xs_pad, ilens, None)``.
    """
    # NOTE: in-place multiply -- mutates the caller's tensor.
    xs_pad *= self.output_size() ** 0.5
    if self.embed is not None:
        xs_pad = self.embed(xs_pad, cache)

    if cache["tail_chunk"]:
        # Tail chunk: the input is replaced by the cached features.
        xs_pad = to_device(cache["feats"], device=xs_pad.device)
    else:
        xs_pad = self._add_overlap_chunk(xs_pad, cache)

    # One cache slot per layer: the ``encoders0`` layers first, then ``encoders``.
    num_front = len(self.encoders0)
    if cache["opt"] is None:
        new_cache = [None] * (num_front + len(self.encoders))
    else:
        new_cache = cache["opt"]

    chunk_size = cache["chunk_size"]
    look_back = cache["encoder_chunk_look_back"]

    for layer_idx, encoder_layer in enumerate(self.encoders0):
        encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], chunk_size, look_back)
        # Store into the slot that was read (previously always slot 0).
        xs_pad, new_cache[layer_idx] = encoder_outs[0], encoder_outs[1]

    for layer_idx, encoder_layer in enumerate(self.encoders):
        slot = layer_idx + num_front
        encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[slot], chunk_size, look_back)
        xs_pad, new_cache[slot] = encoder_outs[0], encoder_outs[1]

    if self.normalize_before:
        xs_pad = self.after_norm(xs_pad)
    # Persist the layer caches only when look-back is enabled (-1 or positive).
    if look_back > 0 or look_back == -1:
        cache["opt"] = new_cache

    return xs_pad, ilens, None
| | | |
| | | def gen_tf2torch_map_dict(self): |
| | | tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch |
| | | tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf |
| | |
| | | |
| | | class SANMVadEncoder(AbsEncoder): |
| | | """ |
| | | author: Speech Lab, Alibaba Group, China |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | |
| | | """ |
| | | |
| | |
| | | sanm_shfit : int = 0, |
| | | selfattention_layer_type: str = "sanm", |
| | | ): |
| | | assert check_argument_types() |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |