From 1367973f9818d8e15c7bf52ad6ffba4ddb6ac2b2 Mon Sep 17 00:00:00 2001
From: Rin Arakaki <rnarkkx@gmail.com>
Date: Tue, 24 Dec 2024 17:51:31 +0800
Subject: [PATCH] shfit to shift (#2266)

---
 funasr/models/scama/chunk_utilis.py |   40 ++++++++++++++++++++--------------------
 1 file changed, 20 insertions(+), 20 deletions(-)

diff --git a/funasr/models/scama/chunk_utilis.py b/funasr/models/scama/chunk_utilis.py
index 2fe3fa4..d9b4aa9 100644
--- a/funasr/models/scama/chunk_utilis.py
+++ b/funasr/models/scama/chunk_utilis.py
@@ -21,7 +21,7 @@
         stride: tuple = (10,),
         pad_left: tuple = (0,),
         encoder_att_look_back_factor: tuple = (1,),
-        shfit_fsmn: int = 0,
+        shift_fsmn: int = 0,
         decoder_att_look_back_factor: tuple = (1,),
     ):
 
@@ -45,11 +45,11 @@
             encoder_att_look_back_factor,
             decoder_att_look_back_factor,
         )
-        self.shfit_fsmn = shfit_fsmn
+        self.shift_fsmn = shift_fsmn
         self.x_add_mask = None
         self.x_rm_mask = None
         self.x_len = None
-        self.mask_shfit_chunk = None
+        self.mask_shift_chunk = None
         self.mask_chunk_predictor = None
         self.mask_att_chunk_encoder = None
         self.mask_shift_att_chunk_decoder = None
@@ -88,7 +88,7 @@
             stride,
             pad_left,
             encoder_att_look_back_factor,
-            chunk_size + self.shfit_fsmn,
+            chunk_size + self.shift_fsmn,
             decoder_att_look_back_factor,
         )
         return (
@@ -118,13 +118,13 @@
             chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = (
                 self.get_chunk_size(ind)
             )
-            shfit_fsmn = self.shfit_fsmn
+            shift_fsmn = self.shift_fsmn
             pad_right = chunk_size - stride - pad_left
 
             chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
             x_len_chunk = (
                 (chunk_num_batch - 1) * chunk_size_pad_shift
-                + shfit_fsmn
+                + shift_fsmn
                 + pad_left
                 + 0
                 + x_len
@@ -138,13 +138,13 @@
             max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
             x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
             x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
-            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
+            mask_shift_chunk = np.zeros([0, num_units], dtype=dtype)
             mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
             mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
             mask_att_chunk_encoder = np.zeros([0, chunk_num * chunk_size_pad_shift], dtype=dtype)
             for chunk_ids in range(chunk_num):
                 # x_mask add
-                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
+                fsmn_padding = np.zeros((shift_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                 x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                 x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                 x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
@@ -154,7 +154,7 @@
                 x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
 
                 # x_mask rm
-                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype)
+                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shift_fsmn), dtype=dtype)
                 padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype)
                 padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                 x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
@@ -170,13 +170,13 @@
                 x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
 
                 # fsmn_padding_mask
-                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
+                pad_shift_mask = np.zeros([shift_fsmn, num_units], dtype=dtype)
                 ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
-                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
-                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
+                mask_shift_chunk_cur = np.concatenate([pad_shift_mask, ones_1], axis=0)
+                mask_shift_chunk = np.concatenate([mask_shift_chunk, mask_shift_chunk_cur], axis=0)
 
                 # predictor mask
-                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
+                zeros_1 = np.zeros([shift_fsmn + pad_left, num_units_predictor], dtype=dtype)
                 ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                 zeros_3 = np.zeros(
                     [chunk_size - stride - pad_left, num_units_predictor], dtype=dtype
@@ -188,13 +188,13 @@
                 )
 
                 # encoder att mask
-                zeros_1_top = np.zeros([shfit_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
+                zeros_1_top = np.zeros([shift_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
 
                 zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                 zeros_2 = np.zeros([chunk_size, zeros_2_num * chunk_size_pad_shift], dtype=dtype)
 
                 encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
-                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                zeros_2_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype)
                 ones_2_mid = np.ones([stride, stride], dtype=dtype)
                 zeros_2_bottom = np.zeros([chunk_size - stride, stride], dtype=dtype)
                 zeros_2_right = np.zeros([chunk_size, chunk_size - stride], dtype=dtype)
@@ -202,7 +202,7 @@
                 ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                 ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
 
-                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                zeros_3_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype)
                 ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                 ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
 
@@ -218,7 +218,7 @@
                 )
 
                 # decoder fsmn_shift_att_mask
-                zeros_1 = np.zeros([shfit_fsmn, 1])
+                zeros_1 = np.zeros([shift_fsmn, 1])
                 ones_1 = np.ones([chunk_size, 1])
                 mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                 mask_shift_att_chunk_decoder = np.concatenate(
@@ -229,7 +229,7 @@
             self.x_len_chunk = x_len_chunk
             self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
             self.x_len = x_len
-            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
+            self.mask_shift_chunk = mask_shift_chunk[:x_len_chunk_max, :]
             self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
             self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
             self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
@@ -238,7 +238,7 @@
                 self.x_len_chunk,
                 self.x_rm_mask,
                 self.x_len,
-                self.mask_shfit_chunk,
+                self.mask_shift_chunk,
                 self.mask_chunk_predictor,
                 self.mask_att_chunk_encoder,
                 self.mask_shift_att_chunk_decoder,
@@ -309,7 +309,7 @@
             x = torch.from_numpy(x).type(dtype).to(device)
         return x
 
-    def get_mask_shfit_chunk(
+    def get_mask_shift_chunk(
         self, chunk_outs=None, device="cpu", batch_size=1, num_units=1, idx=4, dtype=torch.float32
     ):
         with torch.no_grad():

--
Gitblit v1.9.1