From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 30 八月 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/models/encoder/sanm_encoder.py |   16 ++++++++++++----
 1 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 46eabd1..9e27d4a 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -8,7 +8,6 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
-from typeguard import check_argument_types
 import numpy as np
 from funasr.torch_utils.device_funcs import to_device
 from funasr.modules.nets_utils import make_pad_mask
@@ -147,11 +146,14 @@
         interctc_use_conditioning: bool = False,
         kernel_size : int = 11,
         sanm_shfit : int = 0,
+        lora_list: List[str] = None,
+        lora_rank: int = 8,
+        lora_alpha: int = 16,
+        lora_dropout: float = 0.1,
         selfattention_layer_type: str = "sanm",
         tf2torch_tensor_name_prefix_torch: str = "encoder",
         tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
     ):
-        assert check_argument_types()
         super().__init__()
         self._output_size = output_size
 
@@ -231,6 +233,10 @@
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
 
             encoder_selfattn_layer_args = (
@@ -240,6 +246,10 @@
                 attention_dropout_rate,
                 kernel_size,
                 sanm_shfit,
+                lora_list,
+                lora_rank,
+                lora_alpha,
+                lora_dropout,
             )
         self.encoders0 = repeat(
             1,
@@ -601,7 +611,6 @@
             tf2torch_tensor_name_prefix_torch: str = "encoder",
             tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
     ):
-        assert check_argument_types()
         super().__init__()
         self._output_size = output_size
 
@@ -1060,7 +1069,6 @@
         sanm_shfit : int = 0,
         selfattention_layer_type: str = "sanm",
     ):
-        assert check_argument_types()
         super().__init__()
         self._output_size = output_size
 

--
Gitblit v1.9.1