From 3549c0106e5a35ef2ddffdfd7381e613ed5310bd Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 14 三月 2024 15:11:22 +0800
Subject: [PATCH] update com define
---
funasr/models/paraformer/decoder.py | 537 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
1 files changed, 518 insertions(+), 19 deletions(-)
diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index f59ce4d..7c370ba 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -1,25 +1,26 @@
-from typing import List
-from typing import Tuple
-import logging
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
import torch
-import torch.nn as nn
-import numpy as np
+from typing import List, Tuple
+from funasr.register import tables
from funasr.models.scama import utils as myutils
-from funasr.models.transformer.decoder import BaseTransformerDecoder
-
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-from funasr.models.transformer.layer_norm import LayerNorm
-from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.decoder import DecoderLayer
-from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.decoder import BaseTransformerDecoder
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.utils.register import register_class, registry_tables
+from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
+from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
-class DecoderLayerSANM(nn.Module):
+
+class DecoderLayerSANM(torch.nn.Module):
"""Single decoder layer module.
Args:
@@ -62,12 +63,12 @@
self.norm2 = LayerNorm(size)
if src_attn is not None:
self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
+ self.dropout = torch.nn.Dropout(dropout_rate)
self.normalize_before = normalize_before
self.concat_after = concat_after
if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
+ self.concat_linear1 = torch.nn.Linear(size + size, size)
+ self.concat_linear2 = torch.nn.Linear(size + size, size)
self.reserve_attn=False
self.attn_mat = []
@@ -115,6 +116,22 @@
# x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
return x, tgt_mask, memory, memory_mask, cache
+
+ def get_attn_mat(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
+ residual = x
+ x = self.norm3(x)
+ x_src_attn, attn_mat = self.src_attn(x, memory, memory_mask, ret_attn=True)
+ return attn_mat
def forward_one_step(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
"""Compute decoded features.
@@ -200,7 +217,7 @@
return x, memory, fsmn_cache, opt_cache
-@register_class("decoder_classes", "ParaformerSANMDecoder")
+@tables.register("decoder_classes", "ParaformerSANMDecoder")
class ParaformerSANMDecoder(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
@@ -395,6 +412,46 @@
ys.unsqueeze(0), ys_mask, x.unsqueeze(0), cache=state
)
return logp.squeeze(0), state
+
+ def forward_asf2(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ attn_mat = self.model.decoders[1].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
+
+ def forward_asf6(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ tgt, tgt_mask, memory, memory_mask, _ = self.decoders[0](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.decoders[1](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.decoders[2](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.decoders[3](tgt, tgt_mask, memory, memory_mask)
+ tgt, tgt_mask, memory, memory_mask, _ = self.decoders[4](tgt, tgt_mask, memory, memory_mask)
+ attn_mat = self.decoders[5].get_attn_mat(tgt, tgt_mask, memory, memory_mask)
+ return attn_mat
def forward_chunk(
self,
@@ -524,9 +581,335 @@
return y, new_cache
+class DecoderLayerSANMExport(torch.nn.Module):
-@register_class("decoder_classes", "ParaformerDecoderSAN")
-class ParaformerDecoderSAN(BaseTransformerDecoder):
+ def __init__(
+ self,
+ model
+ ):
+ super().__init__()
+ self.self_attn = model.self_attn
+ self.src_attn = model.src_attn
+ self.feed_forward = model.feed_forward
+ self.norm1 = model.norm1
+ self.norm2 = model.norm2 if hasattr(model, 'norm2') else None
+ self.norm3 = model.norm3 if hasattr(model, 'norm3') else None
+ self.size = model.size
+
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask=None, cache=None):
+
+ residual = tgt
+ tgt = self.norm1(tgt)
+ tgt = self.feed_forward(tgt)
+
+ x = tgt
+ if self.self_attn is not None:
+ tgt = self.norm2(tgt)
+ x, cache = self.self_attn(tgt, tgt_mask, cache=cache)
+ x = residual + x
+
+ if self.src_attn is not None:
+ residual = x
+ x = self.norm3(x)
+ x = residual + self.src_attn(x, memory, memory_mask)
+
+
+ return x, tgt_mask, memory, memory_mask, cache
+
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderExport")
+class ParaformerSANMDecoderExport(torch.nn.Module):
+ def __init__(self, model,
+ max_seq_len=512,
+ model_name='decoder',
+ onnx: bool = True,
+ **kwargs
+ ):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+ from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ if self.model.decoders2 is not None:
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
+ x, tgt_mask, memory, memory_mask
+ )
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, ys_in_lens
+
+ def get_dummy_inputs(self, enc_size):
+ tgt = torch.LongTensor([0]).unsqueeze(0)
+ memory = torch.randn(1, 100, enc_size)
+ pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ cache = [
+ torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+ for _ in range(cache_num)
+ ]
+ return (tgt, memory, pre_acoustic_embeds, cache)
+
+ def is_optimizable(self):
+ return True
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ + ['cache_%d' % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['y'] \
+ + ['out_cache_%d' % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ 'tgt': {
+ 0: 'tgt_batch',
+ 1: 'tgt_length'
+ },
+ 'memory': {
+ 0: 'memory_batch',
+ 1: 'memory_length'
+ },
+ 'pre_acoustic_embeds': {
+ 0: 'acoustic_embeds_batch',
+ 1: 'acoustic_embeds_length',
+ }
+ }
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ ret.update({
+ 'cache_%d' % d: {
+ 0: 'cache_%d_batch' % d,
+ 2: 'cache_%d_length' % d
+ }
+ for d in range(cache_num)
+ })
+ return ret
+
+@tables.register("decoder_classes", "ParaformerSANMDecoderOnlineExport")
+class ParaformerSANMDecoderOnlineExport(torch.nn.Module):
+ def __init__(self, model,
+ max_seq_len=512,
+ model_name='decoder',
+ onnx: bool = True, **kwargs):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ self.model = model
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoderExport
+ from funasr.models.sanm.attention import MultiHeadedAttentionCrossAttExport
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAttExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANMExport(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoderExport(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANMExport(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ self.model.decoders3[i] = DecoderLayerSANMExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ *args,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ out_caches = list()
+ for i, decoder in enumerate(self.model.decoders):
+ in_cache = args[i]
+ x, tgt_mask, memory, memory_mask, out_cache = decoder(
+ x, tgt_mask, memory, memory_mask, cache=in_cache
+ )
+ out_caches.append(out_cache)
+ if self.model.decoders2 is not None:
+ for i, decoder in enumerate(self.model.decoders2):
+ in_cache = args[i + len(self.model.decoders)]
+ x, tgt_mask, memory, memory_mask, out_cache = decoder(
+ x, tgt_mask, memory, memory_mask, cache=in_cache
+ )
+ out_caches.append(out_cache)
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, out_caches
+
+ def get_dummy_inputs(self, enc_size):
+ enc = torch.randn(2, 100, enc_size).type(torch.float32)
+ enc_len = torch.tensor([30, 100], dtype=torch.int32)
+ acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
+ acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ cache = [
+ torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size - 1),
+ dtype=torch.float32)
+ for _ in range(cache_num)
+ ]
+ return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
+ + ['in_cache_%d' % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ return ['logits', 'sample_ids'] \
+ + ['out_cache_%d' % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ 'enc': {
+ 0: 'batch_size',
+ 1: 'enc_length'
+ },
+ 'acoustic_embeds': {
+ 0: 'batch_size',
+ 1: 'token_length'
+ },
+ 'enc_len': {
+ 0: 'batch_size',
+ },
+ 'acoustic_embeds_len': {
+ 0: 'batch_size',
+ },
+
+ }
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ ret.update({
+ 'in_cache_%d' % d: {
+ 0: 'batch_size',
+ }
+ for d in range(cache_num)
+ })
+ ret.update({
+ 'out_cache_%d' % d: {
+ 0: 'batch_size',
+ }
+ for d in range(cache_num)
+ })
+ return ret
+
+
+@tables.register("decoder_classes", "ParaformerSANDecoder")
+class ParaformerSANDecoder(BaseTransformerDecoder):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
@@ -639,3 +1022,119 @@
else:
return x, olens
+@tables.register("decoder_classes", "ParaformerDecoderSANExport")
+class ParaformerDecoderSANExport(torch.nn.Module):
+ def __init__(self, model,
+ max_seq_len=512,
+ model_name='decoder',
+ onnx: bool = True, ):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ self.model = model
+
+
+ from funasr.utils.torch_function import sequence_mask
+
+ self.model = model
+
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+
+ from funasr.models.transformer.decoder import DecoderLayerExport
+ from funasr.models.transformer.attention import MultiHeadedAttentionExport
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.src_attn, MultiHeadedAttention):
+ d.src_attn = MultiHeadedAttentionExport(d.src_attn)
+ self.model.decoders[i] = DecoderLayerExport(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ x, tgt_mask, memory, memory_mask = self.model.decoders(
+ x, tgt_mask, memory, memory_mask
+ )
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, ys_in_lens
+
+ def get_dummy_inputs(self, enc_size):
+ tgt = torch.LongTensor([0]).unsqueeze(0)
+ memory = torch.randn(1, 100, enc_size)
+ pre_acoustic_embeds = torch.randn(1, 1, enc_size)
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ cache = [
+ torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
+ for _ in range(cache_num)
+ ]
+ return (tgt, memory, pre_acoustic_embeds, cache)
+
+ def is_optimizable(self):
+ return True
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['tgt', 'memory', 'pre_acoustic_embeds'] \
+ + ['cache_%d' % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ return ['y'] \
+ + ['out_cache_%d' % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ 'tgt': {
+ 0: 'tgt_batch',
+ 1: 'tgt_length'
+ },
+ 'memory': {
+ 0: 'memory_batch',
+ 1: 'memory_length'
+ },
+ 'pre_acoustic_embeds': {
+ 0: 'acoustic_embeds_batch',
+ 1: 'acoustic_embeds_length',
+ }
+ }
+ cache_num = len(self.model.decoders) + len(self.model.decoders2)
+ ret.update({
+ 'cache_%d' % d: {
+ 0: 'cache_%d_batch' % d,
+ 2: 'cache_%d_length' % d
+ }
+ for d in range(cache_num)
+ })
+ return ret
+
\ No newline at end of file
--
Gitblit v1.9.1