From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 30 八月 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/models/decoder/sanm_decoder.py |   15 ++++++++-------
 1 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/funasr/models/decoder/sanm_decoder.py b/funasr/models/decoder/sanm_decoder.py
index 508eb73..c12e098 100644
--- a/funasr/models/decoder/sanm_decoder.py
+++ b/funasr/models/decoder/sanm_decoder.py
@@ -7,7 +7,6 @@
 
 from funasr.modules.streaming_utils import utils as myutils
 from funasr.models.decoder.transformer_decoder import BaseTransformerDecoder
-from typeguard import check_argument_types
 
 from funasr.modules.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
 from funasr.modules.embedding import PositionalEncoding
@@ -181,7 +180,6 @@
             tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
             embed_tensor_name_prefix_tf: str = None,
     ):
-        assert check_argument_types()
         super().__init__(
             vocab_size=vocab_size,
             encoder_output_size=encoder_output_size,
@@ -835,10 +833,13 @@
         att_layer_num: int = 6,
         kernel_size: int = 21,
         sanm_shfit: int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         tf2torch_tensor_name_prefix_torch: str = "decoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
     ):
-        assert check_argument_types()
         super().__init__(
             vocab_size=vocab_size,
             encoder_output_size=encoder_output_size,
@@ -888,7 +889,7 @@
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate
+                    attention_heads, attention_dim, src_attention_dropout_rate, lora_list, lora_rank, lora_alpha, lora_dropout
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
@@ -956,13 +957,13 @@
         """
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+        
+        memory = hs_pad
+        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         if chunk_mask is not None:
             memory_mask = memory_mask * chunk_mask
             if tgt_mask.size(1) != memory_mask.size(1):
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-
-        memory = hs_pad
-        memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
 
         x = tgt
         x, tgt_mask, memory, memory_mask, _ = self.decoders(

--
Gitblit v1.9.1