chenmengzheAAA
2023-09-14 30c40c643c19f6e2ac8679fa76d09d0f9ceccc65
funasr/models/encoder/sanm_encoder.py
@@ -8,7 +8,6 @@
import torch.nn as nn
import torch.nn.functional as F
from funasr.modules.streaming_utils.chunk_utilis import overlap_chunk
from typeguard import check_argument_types
import numpy as np
from funasr.torch_utils.device_funcs import to_device
from funasr.modules.nets_utils import make_pad_mask
@@ -147,11 +146,14 @@
        interctc_use_conditioning: bool = False,
        kernel_size : int = 11,
        sanm_shfit : int = 0,
        lora_list: List[str] = None,
        lora_rank: int = 8,
        lora_alpha: int = 16,
        lora_dropout: float = 0.1,
        selfattention_layer_type: str = "sanm",
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
@@ -231,6 +233,10 @@
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
                lora_list,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )
            encoder_selfattn_layer_args = (
@@ -240,6 +246,10 @@
                attention_dropout_rate,
                kernel_size,
                sanm_shfit,
                lora_list,
                lora_rank,
                lora_alpha,
                lora_dropout,
            )
        self.encoders0 = repeat(
            1,
@@ -601,7 +611,6 @@
            tf2torch_tensor_name_prefix_torch: str = "encoder",
            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size
@@ -1060,7 +1069,6 @@
        sanm_shfit : int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        assert check_argument_types()
        super().__init__()
        self._output_size = output_size