游雁
2023-03-30 795b6e04864d7a8ea1cb8e41a412152651c47eed
export
2个文件已修改
17 ■■■■ 已修改文件
funasr/export/models/encoder/sanm_encoder.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/vad_realtime_transformer.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/encoder/sanm_encoder.py
@@ -151,12 +151,7 @@
    
    def prepare_mask(self, mask, sub_masks):
        mask_3d_btd = mask[:, :, None]
        # sub_masks = subsequent_mask(mask.size(-1)).type(torch.float32)
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - sub_masks[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - sub_masks[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        mask_4d_bhlt = (1 - sub_masks) * -10000.0
        
        return mask_3d_btd, mask_4d_bhlt
    
funasr/export/models/vad_realtime_transformer.py
@@ -63,11 +63,11 @@
        text_lengths = torch.tensor([length], dtype=torch.int32)
        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
        sub_masks = torch.ones(length, length, dtype=torch.float32)
        sub_masks = torch.tril(sub_masks)
        return (text_indexes, text_lengths, vad_mask, sub_masks)
        sub_masks = torch.tril(sub_masks).type(torch.float32)
        return (text_indexes, text_lengths, vad_mask, sub_masks[None, None, :, :])
    def get_input_names(self):
        return ['input', 'text_lengths', 'vad_mask']
        return ['input', 'text_lengths', 'vad_mask', 'sub_masks']
    def get_output_names(self):
        return ['logits']
@@ -81,6 +81,10 @@
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'sub_masks': {
                2: 'feats_length1',
                3: 'feats_length2'
            },
            'logits': {
                1: 'logits_length'
            },