From ddbc8b5eded1fff6084001d160d46b532020ecb7 Mon Sep 17 00:00:00 2001
From: Shi Xian <40013335+R1ckShi@users.noreply.github.com>
Date: 星期一, 15 一月 2024 20:36:20 +0800
Subject: [PATCH] Merge pull request #1247 from alibaba-damo-academy/funasr1.0

---
 funasr/models/scama/chunk_utilis.py |  644 +++++++++++++++++++++++++++++-----------------------------
 1 files changed, 321 insertions(+), 323 deletions(-)

diff --git a/funasr/models/scama/chunk_utilis.py b/funasr/models/scama/chunk_utilis.py
index e90ab62..245d282 100644
--- a/funasr/models/scama/chunk_utilis.py
+++ b/funasr/models/scama/chunk_utilis.py
@@ -1,289 +1,287 @@
-
+import math
 import torch
 import numpy as np
-import math
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-import logging
 import torch.nn.functional as F
-from funasr.models.scama.utils import sequence_mask
 
+from funasr.models.scama.utils import sequence_mask
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
 class overlap_chunk():
-	"""
-	Author: Speech Lab of DAMO Academy, Alibaba Group
-	San-m: Memory equipped self-attention for end-to-end speech recognition
-	https://arxiv.org/abs/2006.01713
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    San-m: Memory equipped self-attention for end-to-end speech recognition
+    https://arxiv.org/abs/2006.01713
 
-	"""
-	def __init__(self,
-		chunk_size: tuple = (16,),
-		stride: tuple = (10,),
-		pad_left: tuple = (0,),
-		encoder_att_look_back_factor: tuple = (1,),
+    """
+    def __init__(self,
+        chunk_size: tuple = (16,),
+        stride: tuple = (10,),
+        pad_left: tuple = (0,),
+        encoder_att_look_back_factor: tuple = (1,),
         shfit_fsmn: int = 0,
         decoder_att_look_back_factor: tuple = (1,),
-	):
+    ):
 
-		pad_left = self.check_chunk_size_args(chunk_size, pad_left)
-		encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
-		decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
-		self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
-			= chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
-		self.shfit_fsmn = shfit_fsmn
-		self.x_add_mask = None
-		self.x_rm_mask = None
-		self.x_len = None
-		self.mask_shfit_chunk = None
-		self.mask_chunk_predictor = None
-		self.mask_att_chunk_encoder = None
-		self.mask_shift_att_chunk_decoder = None
-		self.chunk_outs = None
-		self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
-			= None, None, None, None, None
+        pad_left = self.check_chunk_size_args(chunk_size, pad_left)
+        encoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, encoder_att_look_back_factor)
+        decoder_att_look_back_factor = self.check_chunk_size_args(chunk_size, decoder_att_look_back_factor)
+        self.chunk_size, self.stride, self.pad_left, self.encoder_att_look_back_factor, self.decoder_att_look_back_factor \
+            = chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor
+        self.shfit_fsmn = shfit_fsmn
+        self.x_add_mask = None
+        self.x_rm_mask = None
+        self.x_len = None
+        self.mask_shfit_chunk = None
+        self.mask_chunk_predictor = None
+        self.mask_att_chunk_encoder = None
+        self.mask_shift_att_chunk_decoder = None
+        self.chunk_outs = None
+        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur \
+            = None, None, None, None, None
 
-	def check_chunk_size_args(self, chunk_size, x):
-		if len(x) < len(chunk_size):
-			x = [x[0] for i in chunk_size]
-		return x
+    def check_chunk_size_args(self, chunk_size, x):
+        if len(x) < len(chunk_size):
+            x = [x[0] for i in chunk_size]
+        return x
 
-	def get_chunk_size(self,
-		ind: int = 0
-	):
-		# with torch.no_grad:
-		chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
-			self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
-		self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
-			= chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
-		return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
+    def get_chunk_size(self,
+        ind: int = 0
+    ):
+        # with torch.no_grad:
+        chunk_size, stride, pad_left, encoder_att_look_back_factor, decoder_att_look_back_factor = \
+            self.chunk_size[ind], self.stride[ind], self.pad_left[ind], self.encoder_att_look_back_factor[ind], self.decoder_att_look_back_factor[ind]
+        self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur, self.decoder_att_look_back_factor_cur \
+            = chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size + self.shfit_fsmn, decoder_att_look_back_factor
+        return self.chunk_size_cur, self.stride_cur, self.pad_left_cur, self.encoder_att_look_back_factor_cur, self.chunk_size_pad_shift_cur
 
-	def random_choice(self, training=True, decoding_ind=None):
-		chunk_num = len(self.chunk_size)
-		ind = 0
-		if training and chunk_num > 1:
-			ind = torch.randint(0, chunk_num, ()).cpu().item()
-		if not training and decoding_ind is not None:
-			ind = int(decoding_ind)
+    def random_choice(self, training=True, decoding_ind=None):
+        chunk_num = len(self.chunk_size)
+        ind = 0
+        if training and chunk_num > 1:
+            ind = torch.randint(0, chunk_num, ()).cpu().item()
+        if not training and decoding_ind is not None:
+            ind = int(decoding_ind)
 
-		return ind
+        return ind
 
 
 
 
-	def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
+    def gen_chunk_mask(self, x_len, ind=0, num_units=1, num_units_predictor=1):
 
-		with torch.no_grad():
-			x_len = x_len.cpu().numpy()
-			x_len_max = x_len.max()
+        with torch.no_grad():
+            x_len = x_len.cpu().numpy()
+            x_len_max = x_len.max()
 
-			chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
-			shfit_fsmn = self.shfit_fsmn
-			pad_right = chunk_size - stride - pad_left
+            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = self.get_chunk_size(ind)
+            shfit_fsmn = self.shfit_fsmn
+            pad_right = chunk_size - stride - pad_left
 
-			chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
-			x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
-			x_len_chunk = x_len_chunk.astype(x_len.dtype)
-			x_len_chunk_max = x_len_chunk.max()
+            chunk_num_batch = np.ceil(x_len/stride).astype(np.int32)
+            x_len_chunk = (chunk_num_batch-1) * chunk_size_pad_shift + shfit_fsmn + pad_left + 0 + x_len - (chunk_num_batch-1) * stride
+            x_len_chunk = x_len_chunk.astype(x_len.dtype)
+            x_len_chunk_max = x_len_chunk.max()
 
-			chunk_num = int(math.ceil(x_len_max/stride))
-			dtype = np.int32
-			max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
-			x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
-			x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
-			mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
-			mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
-			mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
-			mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
-			for chunk_ids in range(chunk_num):
-				# x_mask add
-				fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
-				x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
-				x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
-				x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
-				x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
-				x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
-				x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
-				x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
+            chunk_num = int(math.ceil(x_len_max/stride))
+            dtype = np.int32
+            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
+            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
+            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
+            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
+            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
+            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
+            mask_att_chunk_encoder = np.zeros([0, chunk_num*chunk_size_pad_shift], dtype=dtype)
+            for chunk_ids in range(chunk_num):
+                # x_mask add
+                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
+                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
+                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
+                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
+                x_cur_pad = np.concatenate([x_mask_pad_left, x_mask_cur, x_mask_pad_right], axis=1)
+                x_cur_pad = x_cur_pad[:chunk_size, :max_len_for_x_mask_tmp]
+                x_add_mask_fsmn = np.concatenate([fsmn_padding, x_cur_pad], axis=0)
+                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
 
-				# x_mask rm
-				fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
-				padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
-				padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
-				x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
-				x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
-				x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
-				x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
-				x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
-				x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
-				x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
+                # x_mask rm
+                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn),dtype=dtype)
+                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left),dtype=dtype)
+                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
+                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
+                x_mask_cur_pad_top = np.zeros((chunk_ids*stride, stride), dtype=dtype)
+                x_mask_cur_pad_bottom = np.zeros((max_len_for_x_mask_tmp, stride), dtype=dtype)
+                x_rm_mask_cur = np.concatenate([x_mask_cur_pad_top, x_mask_cur, x_mask_cur_pad_bottom], axis=0)
+                x_rm_mask_cur = x_rm_mask_cur[:max_len_for_x_mask_tmp, :stride]
+                x_rm_mask_cur_fsmn = np.concatenate([fsmn_padding, padding_mask_left, x_rm_mask_cur, padding_mask_right], axis=1)
+                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
 
-				# fsmn_padding_mask
-				pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
-				ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
-				mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
-				mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
+                # fsmn_padding_mask
+                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
+                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
+                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
+                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
 
-				# predictor mask
-				zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
-				ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
-				zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
-				ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
-				mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
-				mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
+                # predictor mask
+                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
+                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
+                zeros_3 = np.zeros([chunk_size - stride - pad_left, num_units_predictor], dtype=dtype)
+                ones_zeros = np.concatenate([ones_2, zeros_3], axis=0)
+                mask_chunk_predictor_cur = np.concatenate([zeros_1, ones_zeros], axis=0)
+                mask_chunk_predictor = np.concatenate([mask_chunk_predictor, mask_chunk_predictor_cur], axis=0)
 
-				# encoder att mask
-				zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
+                # encoder att mask
+                zeros_1_top = np.zeros([shfit_fsmn, chunk_num*chunk_size_pad_shift], dtype=dtype)
 
-				zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
-				zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
+                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
+                zeros_2 = np.zeros([chunk_size, zeros_2_num*chunk_size_pad_shift], dtype=dtype)
 
-				encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
-				zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
-				ones_2_mid = np.ones([stride, stride], dtype=dtype)
-				zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
-				zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
-				ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
-				ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
-				ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
+                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
+                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                ones_2_mid = np.ones([stride, stride], dtype=dtype)
+                zeros_2_bottom = np.zeros([chunk_size-stride, stride], dtype=dtype)
+                zeros_2_right = np.zeros([chunk_size, chunk_size-stride], dtype=dtype)
+                ones_2 = np.concatenate([ones_2_mid, zeros_2_bottom], axis=0)
+                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
+                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
 
-				zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
-				ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
-				ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
+                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
+                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
+                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
 
-				zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
-				zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
+                zeros_remain_num = max(chunk_num - 1 - chunk_ids, 0)
+                zeros_remain = np.zeros([chunk_size, zeros_remain_num*chunk_size_pad_shift], dtype=dtype)
 
-				ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
-				mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
-				mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
+                ones2_bottom = np.concatenate([zeros_2, ones_2, ones_3, zeros_remain], axis=1)
+                mask_att_chunk_encoder_cur = np.concatenate([zeros_1_top, ones2_bottom], axis=0)
+                mask_att_chunk_encoder = np.concatenate([mask_att_chunk_encoder, mask_att_chunk_encoder_cur], axis=0)
 
 
-				# decoder fsmn_shift_att_mask
-				zeros_1 = np.zeros([shfit_fsmn, 1])
-				ones_1 = np.ones([chunk_size, 1])
-				mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
-				mask_shift_att_chunk_decoder = np.concatenate(
-					[mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
+                # decoder fsmn_shift_att_mask
+                zeros_1 = np.zeros([shfit_fsmn, 1])
+                ones_1 = np.ones([chunk_size, 1])
+                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
+                mask_shift_att_chunk_decoder = np.concatenate(
+                    [mask_shift_att_chunk_decoder, mask_shift_att_chunk_decoder_cur], axis=0)
 
-			self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
-			self.x_len_chunk = x_len_chunk
-			self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
-			self.x_len = x_len
-			self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
-			self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
-			self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
-			self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
-			self.chunk_outs = (self.x_add_mask,
-		        self.x_len_chunk,
-		        self.x_rm_mask,
-		        self.x_len,
-		        self.mask_shfit_chunk,
-		        self.mask_chunk_predictor,
-		        self.mask_att_chunk_encoder,
-		        self.mask_shift_att_chunk_decoder)
+            self.x_add_mask = x_add_mask[:x_len_chunk_max, :x_len_max+pad_left]
+            self.x_len_chunk = x_len_chunk
+            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
+            self.x_len = x_len
+            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
+            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
+            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
+            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
+            self.chunk_outs = (self.x_add_mask,
+                self.x_len_chunk,
+                self.x_rm_mask,
+                self.x_len,
+                self.mask_shfit_chunk,
+                self.mask_chunk_predictor,
+                self.mask_att_chunk_encoder,
+                self.mask_shift_att_chunk_decoder)
 
-		return self.chunk_outs
+        return self.chunk_outs
 
 
-	def split_chunk(self, x, x_len, chunk_outs):
-		"""
-		:param x: (b, t, d)
-		:param x_length: (b)
-		:param ind: int
-		:return:
-		"""
-		x = x[:, :x_len.max(), :]
-		b, t, d = x.size()
-		x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
-			x.device)
-		x *= x_len_mask[:, :, None]
+    def split_chunk(self, x, x_len, chunk_outs):
+        """
+        :param x: (b, t, d)
+        :param x_length: (b)
+        :param ind: int
+        :return:
+        """
+        x = x[:, :x_len.max(), :]
+        b, t, d = x.size()
+        x_len_mask = (~make_pad_mask(x_len, maxlen=t)).to(
+            x.device)
+        x *= x_len_mask[:, :, None]
 
-		x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
-		x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
-		pad = (0, 0, self.pad_left_cur, 0)
-		x = F.pad(x, pad, "constant", 0.0)
-		b, t, d = x.size()
-		x = torch.transpose(x, 1, 0)
-		x = torch.reshape(x, [t, -1])
-		x_chunk = torch.mm(x_add_mask, x)
-		x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
+        x_add_mask = self.get_x_add_mask(chunk_outs, x.device, dtype=x.dtype)
+        x_len_chunk = self.get_x_len_chunk(chunk_outs, x_len.device, dtype=x_len.dtype)
+        pad = (0, 0, self.pad_left_cur, 0)
+        x = F.pad(x, pad, "constant", 0.0)
+        b, t, d = x.size()
+        x = torch.transpose(x, 1, 0)
+        x = torch.reshape(x, [t, -1])
+        x_chunk = torch.mm(x_add_mask, x)
+        x_chunk = torch.reshape(x_chunk, [-1, b, d]).transpose(1, 0)
 
-		return x_chunk, x_len_chunk
+        return x_chunk, x_len_chunk
 
-	def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
-		x_chunk = x_chunk[:, :x_len_chunk.max(), :]
-		b, t, d = x_chunk.size()
-		x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
-			x_chunk.device)
-		x_chunk *= x_len_chunk_mask[:, :, None]
+    def remove_chunk(self, x_chunk, x_len_chunk, chunk_outs):
+        x_chunk = x_chunk[:, :x_len_chunk.max(), :]
+        b, t, d = x_chunk.size()
+        x_len_chunk_mask = (~make_pad_mask(x_len_chunk, maxlen=t)).to(
+            x_chunk.device)
+        x_chunk *= x_len_chunk_mask[:, :, None]
 
-		x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
-		x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
-		x_chunk = torch.transpose(x_chunk, 1, 0)
-		x_chunk = torch.reshape(x_chunk, [t, -1])
-		x = torch.mm(x_rm_mask, x_chunk)
-		x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
+        x_rm_mask = self.get_x_rm_mask(chunk_outs, x_chunk.device, dtype=x_chunk.dtype)
+        x_len = self.get_x_len(chunk_outs, x_len_chunk.device, dtype=x_len_chunk.dtype)
+        x_chunk = torch.transpose(x_chunk, 1, 0)
+        x_chunk = torch.reshape(x_chunk, [t, -1])
+        x = torch.mm(x_rm_mask, x_chunk)
+        x = torch.reshape(x, [-1, b, d]).transpose(1, 0)
 
-		return x, x_len
+        return x, x_len
 
-	def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_add_mask(self, chunk_outs=None, device='cpu', idx=0, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_len_chunk(self, chunk_outs=None, device='cpu', idx=1, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
 
-	def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_rm_mask(self, chunk_outs=None, device='cpu', idx=2, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_x_len(self, chunk_outs=None, device='cpu', idx=3, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
 
-	def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_shfit_chunk(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=4, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_chunk_predictor(self, chunk_outs=None, device='cpu', batch_size=1, num_units=1, idx=5, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, :, :, ], [batch_size, 1, num_units])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_att_chunk_encoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=6, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, :, :, ], [batch_size, 1, 1])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
-	def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
-		with torch.no_grad():
-			x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
-			x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
-			x = torch.from_numpy(x).type(dtype).to(device)
-		return x
+    def get_mask_shift_att_chunk_decoder(self, chunk_outs=None, device='cpu', batch_size=1, idx=7, dtype=torch.float32):
+        with torch.no_grad():
+            x = chunk_outs[idx] if chunk_outs is not None else  self.chunk_outs[idx]
+            x = np.tile(x[None, None, :, 0], [batch_size, 1, 1])
+            x = torch.from_numpy(x).type(dtype).to(device)
+        return x
 
 
 
 def build_scama_mask_for_cross_attention_decoder(
-							  predictor_alignments: torch.Tensor,
+                              predictor_alignments: torch.Tensor,
                               encoder_sequence_length: torch.Tensor,
                               chunk_size: int = 5,
                               encoder_chunk_size: int = 5,
@@ -291,100 +289,100 @@
                               attention_chunk_size: int = 1,
                               attention_chunk_type: str = 'chunk',
                               step=None,
-							  predictor_mask_chunk_hopping: torch.Tensor = None,
-							  decoder_att_look_back_factor: int = 1,
-							  mask_shift_att_chunk_decoder: torch.Tensor = None,
-							  target_length: torch.Tensor = None,
-							  is_training=True,
+                              predictor_mask_chunk_hopping: torch.Tensor = None,
+                              decoder_att_look_back_factor: int = 1,
+                              mask_shift_att_chunk_decoder: torch.Tensor = None,
+                              target_length: torch.Tensor = None,
+                              is_training=True,
                               dtype: torch.dtype = torch.float32):
-	with torch.no_grad():
-		device = predictor_alignments.device
-		batch_size, chunk_num = predictor_alignments.size()
-		maximum_encoder_length = encoder_sequence_length.max().item()
-		int_type = predictor_alignments.dtype
-		if not is_training:
-			target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
-		maximum_target_length = target_length.max()
-		predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
-		predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
-	
-	
-		index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
-		index = torch.cumsum(index, dim=1)
-		index = index[:, :, None].repeat(1, 1, chunk_num)
-	
-		index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
-		index_div_bool_zeros = index_div == 0
-		index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
-	
-		index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
-	
-		index_div_bool_zeros_count *= chunk_size
-		index_div_bool_zeros_count += attention_chunk_center_bias
-		index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
-		index_div_bool_zeros_count_ori = index_div_bool_zeros_count
-	
-		index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
-		max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
-	
-		mask_flip, mask_flip2 = None, None
-		if attention_chunk_size is not None:
-			index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
-			index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
-			index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
-			mask_flip = 1 - index_div_bool_zeros_count_beg_mask
-			attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
-			index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
-	
-			index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
-			index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
-			mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
-	
-		mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
-	
-		if predictor_mask_chunk_hopping is not None:
-				b, k, t = mask.size()
-				predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
-	
-				mask_mask_flip = mask
-				if mask_flip is not None:
-						mask_mask_flip = mask_flip * mask
-	
-				def _fn():
-						mask_sliced = mask[:b, :k, encoder_chunk_size:t]
-						zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
-						mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
-						_, _, tt = predictor_mask_chunk_hopping.size()
-						pad_right_p = max_len_chunk - tt
-						predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
-						masked = mask_sliced * predictor_mask_chunk_hopping_pad
-	
-						mask_true = mask_mask_flip + masked
-						return mask_true
-	
-				mask = _fn() if t > chunk_size else mask_mask_flip
-	
-	
-	
-		if mask_flip2 is not None:
-			mask *= mask_flip2
-	
-		mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
-		mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
-	
-	
-	
-		mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
-		mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
-	
-	
-	
-	
-		if attention_chunk_type == 'full':
-			mask = torch.ones_like(mask).to(device)
-		if mask_shift_att_chunk_decoder is not None:
-			mask = mask * mask_shift_att_chunk_decoder
-		mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
+    with torch.no_grad():
+        device = predictor_alignments.device
+        batch_size, chunk_num = predictor_alignments.size()
+        maximum_encoder_length = encoder_sequence_length.max().item()
+        int_type = predictor_alignments.dtype
+        if not is_training:
+            target_length = predictor_alignments.sum(dim=-1).type(encoder_sequence_length.dtype)
+        maximum_target_length = target_length.max()
+        predictor_alignments_cumsum = torch.cumsum(predictor_alignments, dim=1)
+        predictor_alignments_cumsum = predictor_alignments_cumsum[:, None, :].repeat(1, maximum_target_length, 1)
+    
+    
+        index = torch.ones([batch_size, maximum_target_length], dtype=int_type).to(device)
+        index = torch.cumsum(index, dim=1)
+        index = index[:, :, None].repeat(1, 1, chunk_num)
+    
+        index_div = torch.floor(torch.divide(predictor_alignments_cumsum, index)).type(int_type)
+        index_div_bool_zeros = index_div == 0
+        index_div_bool_zeros_count = torch.sum(index_div_bool_zeros.type(int_type), dim=-1) + 1
+    
+        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count, min=1, max=chunk_num)
+    
+        index_div_bool_zeros_count *= chunk_size
+        index_div_bool_zeros_count += attention_chunk_center_bias
+        index_div_bool_zeros_count = torch.clip(index_div_bool_zeros_count-1, min=0, max=maximum_encoder_length)
+        index_div_bool_zeros_count_ori = index_div_bool_zeros_count
+    
+        index_div_bool_zeros_count = (torch.floor(index_div_bool_zeros_count / encoder_chunk_size)+1)*encoder_chunk_size
+        max_len_chunk = math.ceil(maximum_encoder_length / encoder_chunk_size) * encoder_chunk_size
+    
+        mask_flip, mask_flip2 = None, None
+        if attention_chunk_size is not None:
+            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size
+            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
+            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
+            mask_flip = 1 - index_div_bool_zeros_count_beg_mask
+            attention_chunk_size2 = attention_chunk_size * (decoder_att_look_back_factor+1)
+            index_div_bool_zeros_count_beg = index_div_bool_zeros_count - attention_chunk_size2
+    
+            index_div_bool_zeros_count_beg = torch.clip(index_div_bool_zeros_count_beg, 0, max_len_chunk)
+            index_div_bool_zeros_count_beg_mask = sequence_mask(index_div_bool_zeros_count_beg, maxlen=max_len_chunk, dtype=int_type, device=device)
+            mask_flip2 = 1 - index_div_bool_zeros_count_beg_mask
+    
+        mask = sequence_mask(index_div_bool_zeros_count, maxlen=max_len_chunk, dtype=dtype, device=device)
+    
+        if predictor_mask_chunk_hopping is not None:
+                b, k, t = mask.size()
+                predictor_mask_chunk_hopping = predictor_mask_chunk_hopping[:, None, :, 0].repeat(1, k, 1)
+    
+                mask_mask_flip = mask
+                if mask_flip is not None:
+                        mask_mask_flip = mask_flip * mask
+    
+                def _fn():
+                        mask_sliced = mask[:b, :k, encoder_chunk_size:t]
+                        zero_pad_right = torch.zeros([b, k, encoder_chunk_size], dtype=mask_sliced.dtype).to(device)
+                        mask_sliced = torch.cat([mask_sliced, zero_pad_right], dim=2)
+                        _, _, tt = predictor_mask_chunk_hopping.size()
+                        pad_right_p = max_len_chunk - tt
+                        predictor_mask_chunk_hopping_pad = torch.nn.functional.pad(predictor_mask_chunk_hopping, [0, pad_right_p], "constant", 0)
+                        masked = mask_sliced * predictor_mask_chunk_hopping_pad
+    
+                        mask_true = mask_mask_flip + masked
+                        return mask_true
+    
+                mask = _fn() if t > chunk_size else mask_mask_flip
+    
+    
+    
+        if mask_flip2 is not None:
+            mask *= mask_flip2
+    
+        mask_target = sequence_mask(target_length, maxlen=maximum_target_length, dtype=mask.dtype, device=device)
+        mask = mask[:, :maximum_target_length, :] * mask_target[:, :, None]
+    
+    
+    
+        mask_len = sequence_mask(encoder_sequence_length, maxlen=maximum_encoder_length, dtype=mask.dtype, device=device)
+        mask = mask[:, :, :maximum_encoder_length] * mask_len[:, None, :]
+    
+    
+    
+    
+        if attention_chunk_type == 'full':
+            mask = torch.ones_like(mask).to(device)
+        if mask_shift_att_chunk_decoder is not None:
+            mask = mask * mask_shift_att_chunk_decoder
+        mask = mask[:, :maximum_target_length, :maximum_encoder_length].type(dtype).to(device)
 
-	return mask
+    return mask
 

--
Gitblit v1.9.1