Rin Arakaki
2024-12-24 1367973f9818d8e15c7bf52ad6ffba4ddb6ac2b2
funasr/models/scama/chunk_utilis.py
@@ -21,7 +21,7 @@
        stride: tuple = (10,),
        pad_left: tuple = (0,),
        encoder_att_look_back_factor: tuple = (1,),
        shfit_fsmn: int = 0,
        shift_fsmn: int = 0,
        decoder_att_look_back_factor: tuple = (1,),
    ):
@@ -45,11 +45,11 @@
            encoder_att_look_back_factor,
            decoder_att_look_back_factor,
        )
        self.shfit_fsmn = shfit_fsmn
        self.shift_fsmn = shift_fsmn
        self.x_add_mask = None
        self.x_rm_mask = None
        self.x_len = None
        self.mask_shfit_chunk = None
        self.mask_shift_chunk = None
        self.mask_chunk_predictor = None
        self.mask_att_chunk_encoder = None
        self.mask_shift_att_chunk_decoder = None
@@ -88,7 +88,7 @@
            stride,
            pad_left,
            encoder_att_look_back_factor,
            chunk_size + self.shfit_fsmn,
            chunk_size + self.shift_fsmn,
            decoder_att_look_back_factor,
        )
        return (
@@ -118,13 +118,13 @@
            chunk_size, stride, pad_left, encoder_att_look_back_factor, chunk_size_pad_shift = (
                self.get_chunk_size(ind)
            )
            shfit_fsmn = self.shfit_fsmn
            shift_fsmn = self.shift_fsmn
            pad_right = chunk_size - stride - pad_left
            chunk_num_batch = np.ceil(x_len / stride).astype(np.int32)
            x_len_chunk = (
                (chunk_num_batch - 1) * chunk_size_pad_shift
                + shfit_fsmn
                + shift_fsmn
                + pad_left
                + 0
                + x_len
@@ -138,13 +138,13 @@
            max_len_for_x_mask_tmp = max(chunk_size, x_len_max + pad_left)
            x_add_mask = np.zeros([0, max_len_for_x_mask_tmp], dtype=dtype)
            x_rm_mask = np.zeros([max_len_for_x_mask_tmp, 0], dtype=dtype)
            mask_shfit_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_shift_chunk = np.zeros([0, num_units], dtype=dtype)
            mask_chunk_predictor = np.zeros([0, num_units_predictor], dtype=dtype)
            mask_shift_att_chunk_decoder = np.zeros([0, 1], dtype=dtype)
            mask_att_chunk_encoder = np.zeros([0, chunk_num * chunk_size_pad_shift], dtype=dtype)
            for chunk_ids in range(chunk_num):
                # x_mask add
                fsmn_padding = np.zeros((shfit_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                fsmn_padding = np.zeros((shift_fsmn, max_len_for_x_mask_tmp), dtype=dtype)
                x_mask_cur = np.diag(np.ones(chunk_size, dtype=np.float32))
                x_mask_pad_left = np.zeros((chunk_size, chunk_ids * stride), dtype=dtype)
                x_mask_pad_right = np.zeros((chunk_size, max_len_for_x_mask_tmp), dtype=dtype)
@@ -154,7 +154,7 @@
                x_add_mask = np.concatenate([x_add_mask, x_add_mask_fsmn], axis=0)
                # x_mask rm
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shfit_fsmn), dtype=dtype)
                fsmn_padding = np.zeros((max_len_for_x_mask_tmp, shift_fsmn), dtype=dtype)
                padding_mask_left = np.zeros((max_len_for_x_mask_tmp, pad_left), dtype=dtype)
                padding_mask_right = np.zeros((max_len_for_x_mask_tmp, pad_right), dtype=dtype)
                x_mask_cur = np.diag(np.ones(stride, dtype=dtype))
@@ -170,13 +170,13 @@
                x_rm_mask = np.concatenate([x_rm_mask, x_rm_mask_cur_fsmn], axis=1)
                # fsmn_padding_mask
                pad_shfit_mask = np.zeros([shfit_fsmn, num_units], dtype=dtype)
                pad_shift_mask = np.zeros([shift_fsmn, num_units], dtype=dtype)
                ones_1 = np.ones([chunk_size, num_units], dtype=dtype)
                mask_shfit_chunk_cur = np.concatenate([pad_shfit_mask, ones_1], axis=0)
                mask_shfit_chunk = np.concatenate([mask_shfit_chunk, mask_shfit_chunk_cur], axis=0)
                mask_shift_chunk_cur = np.concatenate([pad_shift_mask, ones_1], axis=0)
                mask_shift_chunk = np.concatenate([mask_shift_chunk, mask_shift_chunk_cur], axis=0)
                # predictor mask
                zeros_1 = np.zeros([shfit_fsmn + pad_left, num_units_predictor], dtype=dtype)
                zeros_1 = np.zeros([shift_fsmn + pad_left, num_units_predictor], dtype=dtype)
                ones_2 = np.ones([stride, num_units_predictor], dtype=dtype)
                zeros_3 = np.zeros(
                    [chunk_size - stride - pad_left, num_units_predictor], dtype=dtype
@@ -188,13 +188,13 @@
                )
                # encoder att mask
                zeros_1_top = np.zeros([shfit_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
                zeros_1_top = np.zeros([shift_fsmn, chunk_num * chunk_size_pad_shift], dtype=dtype)
                zeros_2_num = max(chunk_ids - encoder_att_look_back_factor, 0)
                zeros_2 = np.zeros([chunk_size, zeros_2_num * chunk_size_pad_shift], dtype=dtype)
                encoder_att_look_back_num = max(chunk_ids - zeros_2_num, 0)
                zeros_2_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                zeros_2_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype)
                ones_2_mid = np.ones([stride, stride], dtype=dtype)
                zeros_2_bottom = np.zeros([chunk_size - stride, stride], dtype=dtype)
                zeros_2_right = np.zeros([chunk_size, chunk_size - stride], dtype=dtype)
@@ -202,7 +202,7 @@
                ones_2 = np.concatenate([zeros_2_left, ones_2, zeros_2_right], axis=1)
                ones_2 = np.tile(ones_2, [1, encoder_att_look_back_num])
                zeros_3_left = np.zeros([chunk_size, shfit_fsmn], dtype=dtype)
                zeros_3_left = np.zeros([chunk_size, shift_fsmn], dtype=dtype)
                ones_3_right = np.ones([chunk_size, chunk_size], dtype=dtype)
                ones_3 = np.concatenate([zeros_3_left, ones_3_right], axis=1)
@@ -218,7 +218,7 @@
                )
                # decoder fsmn_shift_att_mask
                zeros_1 = np.zeros([shfit_fsmn, 1])
                zeros_1 = np.zeros([shift_fsmn, 1])
                ones_1 = np.ones([chunk_size, 1])
                mask_shift_att_chunk_decoder_cur = np.concatenate([zeros_1, ones_1], axis=0)
                mask_shift_att_chunk_decoder = np.concatenate(
@@ -229,7 +229,7 @@
            self.x_len_chunk = x_len_chunk
            self.x_rm_mask = x_rm_mask[:x_len_max, :x_len_chunk_max]
            self.x_len = x_len
            self.mask_shfit_chunk = mask_shfit_chunk[:x_len_chunk_max, :]
            self.mask_shift_chunk = mask_shift_chunk[:x_len_chunk_max, :]
            self.mask_chunk_predictor = mask_chunk_predictor[:x_len_chunk_max, :]
            self.mask_att_chunk_encoder = mask_att_chunk_encoder[:x_len_chunk_max, :x_len_chunk_max]
            self.mask_shift_att_chunk_decoder = mask_shift_att_chunk_decoder[:x_len_chunk_max, :]
@@ -238,7 +238,7 @@
                self.x_len_chunk,
                self.x_rm_mask,
                self.x_len,
                self.mask_shfit_chunk,
                self.mask_shift_chunk,
                self.mask_chunk_predictor,
                self.mask_att_chunk_encoder,
                self.mask_shift_att_chunk_decoder,
@@ -309,7 +309,7 @@
            x = torch.from_numpy(x).type(dtype).to(device)
        return x
    def get_mask_shfit_chunk(
    def get_mask_shift_chunk(
        self, chunk_outs=None, device="cpu", batch_size=1, num_units=1, idx=4, dtype=torch.float32
    ):
        with torch.no_grad():