游雁
2023-04-28 ea6903101b0f2da45770312b2cea2c78673a70fa
punc onnx
1个文件已修改
4 ■■■■ 已修改文件
funasr/export/models/CT_Transformer.py 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/CT_Transformer.py
@@ -53,7 +53,7 @@
    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length))
        text_indexes = torch.randint(0, self.embed.num_embeddings, (2, length)).type(torch.int32)
        text_lengths = torch.tensor([length-20, length], dtype=torch.int32)
        return (text_indexes, text_lengths)
@@ -130,7 +130,7 @@
    def get_dummy_inputs(self):
        length = 120
        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length))
        text_indexes = torch.randint(0, self.embed.num_embeddings, (1, length)).type(torch.int32)
        text_lengths = torch.tensor([length], dtype=torch.int32)
        vad_mask = torch.ones(length, length, dtype=torch.float32)[None, None, :, :]
        sub_masks = torch.ones(length, length, dtype=torch.float32)