| | |
| | | stride: tuple = (10,), |
| | | pad_left: tuple = (0,), |
| | | encoder_att_look_back_factor: tuple = (1,), |
| | | shfit_fsmn: int = 0, |
| | | shift_fsmn: int = 0, |
| | | decoder_att_look_back_factor: tuple = (1,), |
| | | ): |
| | | |
| | |
| | | encoder_att_look_back_factor, |
| | | decoder_att_look_back_factor, |
| | | ) |
| | | self.shfit_fsmn = shfit_fsmn |
| | | self.shift_fsmn = shift_fsmn |
| | | self.x_add_mask = None |
| | | self.x_rm_mask = None |
| | | self.x_len = None |
| | | self.mask_shfit_chunk = None |
| | | self.mask_shift_chunk = None |
| | | self.mask_chunk_predictor = None |
| | | self.mask_att_chunk_encoder = None |
| | | self.mask_shift_att_chunk_decoder = None |
| | |
| | | stride, |
| | | pad_left, |
| | | encoder_att_look_back_factor, |
| | | chunk_size + self.shfit_fsmn, |
| | | chunk_size + self.shift_fsmn, |
| | | decoder_att_look_back_factor, |
| | | ) |
| | | return ( |
| | |
| | | chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = ( |
| | | self.get_chunk_size(ind) |
| | | ) |
| | | shfit_fsmn = self.shfit_fsmn |
| | | shift_fsmn = self.shift_fsmn |
| | | pad_right = chunk_size - stride - pad_left |
| | | |
| | | chunk_num_batch = np.ceil(x_len / stride).astype(np.int32) |
| | | x_len_chunk = ( |
| | | (chunk_num_batch - 1) * chunk_size_pad_shift |
| | | + shfit_fsmn |
| | | + shift_fsmn |
| | | + pad_left |
| | | + 0 |
| | | + x_len |
| | |
| | | max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left) |
| | | x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype) |
| | | x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype) |
| | | mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype) |
| | | mask_shift_chunk = np.zeros([0, num_units], dtype=dtype) |
| | | mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype) |
| | | mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype) |
| | | mask_att_chunk_encoder = np.zeros([0, chunk_num * chunk_size_pad_shift], dtype=dtype) |
| | | for chunk_ids in range(chunk_num): |
| | | # x_mask add |
| | | fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype) |
| | | fsmn_padding = np.zeros((shift_fsmn, max_len_for_x_mask_tmp), dtype=dtype) |
| | | x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32)) |
| | | x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype) |
| | | x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype) |
| | |
| | | x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0) |
| | | |
| | | # x_mask rm |
| | | fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype) |
| | | fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shift_fsmn), dtype=dtype) |
| | | padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype) |
| | | padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype) |
| | | x_mask_cur = np.diag(np.ones(stride, dtype=dtype)) |
| | |
| | | x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1) |
| | | |
| | | # fsmn_padding_mask |
| | | pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype) |
| | | pad_shift_mask = np.zeros([shift_fsmn, num_units], dtype=dtype) |
| | | ones_1 = np.ones([chunk_size, num_units], dtype=dtype) |
| | | mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0) |
| | | mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0) |
| | | mask_shift_chunk_cur = np.concatenate([pad_shift_mask, ones_1], axis=0) |
| | | mask_shift_chunk = np.concatenate([mask_shift_chunk, mask_shift_chunk_cur], axis=0) |
| | | |
| | | # predictor mask |
| | | zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype) |
| | | zeros_1 = np.zeros([shift_fsmn + pad_left, num_units_predictor], dtype=dtype) |
| | | ones_2 = np.ones([stride, num_units_predictor], dtype=dtype) |
| | | zeros_3 = np.zeros( |
| | | [chunk_size - stride - pad_left, num_units_predictor], dtype=dtype |
| | |
| | | ) |
| | | |
| | | # encoder att mask |
| | | zeros_1_top = np.zeros([shfit_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype) |
| | | zeros_1_top = np.zeros([shift_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype) |
| | | |
| | | zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0) |
| | | zeros_2 = np.zeros([chunk_size, zeros_2_num * chunk_size_pad_shift], dtype=dtype) |
| | | |
| | | encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0) |
| | | zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) |
| | | zeros_2_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype) |
| | | ones_2_mid = np.ones([stride, stride], dtype=dtype) |
| | | zeros_2_bottom = np.zeros([chunk_size - stride, stride], dtype=dtype) |
| | | zeros_2_right = np.zeros([chunk_size, chunk_size - stride], dtype=dtype) |
| | |
| | | ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1) |
| | | ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num]) |
| | | |
| | | zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype) |
| | | zeros_3_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype) |
| | | ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype) |
| | | ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1) |
| | | |
| | |
| | | ) |
| | | |
| | | # decoder fsmn_shift_att_mask |
| | | zeros_1 = np.zeros([shfit_fsmn, 1]) |
| | | zeros_1 = np.zeros([shift_fsmn, 1]) |
| | | ones_1 = np.ones([chunk_size, 1]) |
| | | mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0) |
| | | mask_shift_att_chunk_decoder = np.concatenate( |
| | |
| | | self.x_len_chunk = x_len_chunk |
| | | self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max] |
| | | self.x_len = x_len |
| | | self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :] |
| | | self.mask_shift_chunk = mask_shift_chunk[:x_len_chunk_max, :] |
| | | self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :] |
| | | self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max] |
| | | self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :] |
| | |
| | | self.x_len_chunk, |
| | | self.x_rm_mask, |
| | | self.x_len, |
| | | self.mask_shfit_chunk, |
| | | self.mask_shift_chunk, |
| | | self.mask_chunk_predictor, |
| | | self.mask_att_chunk_encoder, |
| | | self.mask_shift_att_chunk_decoder, |
| | |
| | | x = torch.from_numpy(x).type(dtype).to(device) |
| | | return x |
| | | |
| | | def get_mask_shfit_chunk( |
| | | def get_mask_shift_chunk( |
| | | self, chunk_outs=None, device="cpu", batch_size=1, num_units=1, idx=4, dtype=torch.float32 |
| | | ): |
| | | with torch.no_grad(): |