From 1367973f9818d8e15c7bf52ad6ffba4ddb6ac2b2 Mon Sep 17 00:00:00 2001
From: Rin Arakaki <rnarkkx@gmail.com>
Date: 星期二, 24 十二月 2024 17:51:31 +0800
Subject: [PATCH] shfit to shift (#2266)
---
funasr/models/scama/chunk_utilis.py | 40 ++++++++++++++++++++--------------------
1 files changed, 20 insertions(+), 20 deletions(-)
diff --git a/funasr/models/scama/chunk_utilis.py b/funasr/models/scama/chunk_utilis.py
index 2fe3fa4..d9b4aa9 100644
--- a/funasr/models/scama/chunk_utilis.py
+++ b/funasr/models/scama/chunk_utilis.py
@@ -21,7 +21,7 @@
stride: tuple = (10,),
pad_left: tuple = (0,),
encoder_att_look_back_factor: tuple = (1,),
- shfit_fsmn: int = 0,
+ shift_fsmn: int = 0,
decoder_att_look_back_factor: tuple = (1,),
):
@@ -45,11 +45,11 @@
encoder_att_look_back_factor,
decoder_att_look_back_factor,
)
- self.shfit_fsmn = shfit_fsmn
+ self.shift_fsmn = shift_fsmn
self.x_add_mask = None
self.x_rm_mask = None
self.x_len = None
- self.mask_shfit_chunk = None
+ self.mask_shift_chunk = None
self.mask_chunk_predictor = None
self.mask_att_chunk_encoder = None
self.mask_shift_att_chunk_decoder = None
@@ -88,7 +88,7 @@
stride,
pad_left,
encoder_att_look_back_factor,
- chunk_size + self.shfit_fsmn,
+ chunk_size + self.shift_fsmn,
decoder_att_look_back_factor,
)
return (
@@ -118,13 +118,13 @@
chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = (
self.get_chunk_size(ind)
)
- shfit_fsmn = self.shfit_fsmn
+ shift_fsmn = self.shift_fsmn
pad_right = chunk_size - stride - pad_left
chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
x_len_chunk = (
(chunk_num_batch - 1) * chunk_size_pad_shift
- + shfit_fsmn
+ + shift_fsmn
+ pad_left
+ 0
+ x_len
@@ -138,13 +138,13 @@
max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
- mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
+ mask_shift_chunk = np.zeros([0, num_units], dtype=dtype)
mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
mask_att_chunk_encoder = np.zeros([0, chunk_num * chunk_size_pad_shift], dtype=dtype)
for chunk_ids in range(chunk_num):
# x_mask add
- fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
+ fsmn_padding = np.zeros((shift_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
@@ -154,7 +154,7 @@
x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
# x_mask rm
- fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype)
+ fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shift_fsmn), dtype=dtype)
padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype)
padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
@@ -170,13 +170,13 @@
x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
# fsmn_padding_mask
- pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
+ pad_shift_mask = np.zeros([shift_fsmn, num_units], dtype=dtype)
ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
- mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
- mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
+ mask_shift_chunk_cur = np.concatenate([pad_shift_mask, ones_1], axis=0)
+ mask_shift_chunk = np.concatenate([mask_shift_chunk, mask_shift_chunk_cur], axis=0)
# predictor mask
- zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
+ zeros_1 = np.zeros([shift_fsmn + pad_left, num_units_predictor], dtype=dtype)
ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
zeros_3 = np.zeros(
[chunk_size - stride - pad_left, num_units_predictor], dtype=dtype
@@ -188,13 +188,13 @@
)
# encoder att mask
- zeros_1_top = np.zeros([shfit_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
+ zeros_1_top = np.zeros([shift_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
zeros_2 = np.zeros([chunk_size, zeros_2_num * chunk_size_pad_shift], dtype=dtype)
encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
- zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+ zeros_2_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype)
ones_2_mid = np.ones([stride, stride], dtype=dtype)
zeros_2_bottom = np.zeros([chunk_size - stride, stride], dtype=dtype)
zeros_2_right = np.zeros([chunk_size, chunk_size - stride], dtype=dtype)
@@ -202,7 +202,7 @@
ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
- zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+ zeros_3_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype)
ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
@@ -218,7 +218,7 @@
)
# decoder fsmn_shift_att_mask
- zeros_1 = np.zeros([shfit_fsmn, 1])
+ zeros_1 = np.zeros([shift_fsmn, 1])
ones_1 = np.ones([chunk_size, 1])
mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
mask_shift_att_chunk_decoder = np.concatenate(
@@ -229,7 +229,7 @@
self.x_len_chunk = x_len_chunk
self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
self.x_len = x_len
- self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
+ self.mask_shift_chunk = mask_shift_chunk[:x_len_chunk_max, :]
self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
@@ -238,7 +238,7 @@
self.x_len_chunk,
self.x_rm_mask,
self.x_len,
- self.mask_shfit_chunk,
+ self.mask_shift_chunk,
self.mask_chunk_predictor,
self.mask_att_chunk_encoder,
self.mask_shift_att_chunk_decoder,
@@ -309,7 +309,7 @@
x = torch.from_numpy(x).type(dtype).to(device)
return x
- def get_mask_shfit_chunk(
+ def get_mask_shift_chunk(
self, chunk_outs=None, device="cpu", batch_size=1, num_units=1, idx=4, dtype=torch.float32
):
with torch.no_grad():
--
Gitblit v1.9.1