From 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Dec 2024 17:16:11 +0800
Subject: [PATCH] Revert "shfit to shift (#2266)" (#2336)

---
 funasr/models/paraformer/decoder.py |   18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/funasr/models/paraformer/decoder.py b/funasr/models/paraformer/decoder.py
index fafb8d4..7edd91a 100644
--- a/funasr/models/paraformer/decoder.py
+++ b/funasr/models/paraformer/decoder.py
@@ -248,7 +248,7 @@
         concat_after: bool = False,
         att_layer_num: int = 6,
         kernel_size: int = 21,
-        sanm_shift: int = 0,
+        sanm_shfit: int = 0,
         lora_list: List[str] = None,
         lora_rank: int = 8,
         lora_alpha: int = 16,
@@ -298,14 +298,14 @@
 
         self.att_layer_num = att_layer_num
         self.num_blocks = num_blocks
-        if sanm_shift is None:
-            sanm_shift = (kernel_size - 1) // 2
+        if sanm_shfit is None:
+            sanm_shfit = (kernel_size - 1) // 2
         self.decoders = repeat(
             att_layer_num,
             lambda lnum: DecoderLayerSANM(
                 attention_dim,
                 MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
                     attention_heads,
@@ -330,7 +330,7 @@
                 lambda lnum: DecoderLayerSANM(
                     attention_dim,
                     MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=0
+                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
                     ),
                     None,
                     PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -785,20 +785,20 @@
             for _ in range(cache_num)
         ]
         return (tgt, memory, pre_acoustic_embeds, cache)
-
+    
     def is_optimizable(self):
         return True
-
+    
     def get_input_names(self):
         cache_num = len(self.model.decoders) + len(self.model.decoders2)
         return ['tgt', 'memory', 'pre_acoustic_embeds'] \
                + ['cache_%d' % i for i in range(cache_num)]
-
+    
     def get_output_names(self):
         cache_num = len(self.model.decoders) + len(self.model.decoders2)
         return ['y'] \
                + ['out_cache_%d' % i for i in range(cache_num)]
-
+    
     def get_dynamic_axes(self):
         ret = {
             'tgt': {

--
Gitblit v1.9.1