mengzhe.cmz
2023-05-17 ce7914034dd8496409af3b6b368218be1c71d3a1
increase vad realtime punc
3个文件已修改
9 ■■■■■ 已修改文件
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx_punc_vadrealtime.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/demo.py
@@ -9,7 +9,7 @@
inference_pipeline = pipeline(
    task=Tasks.punctuation,
    model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727',
    output_dir="./tmp/"
    model_revision = 'v1.0.2'
)
##################text二进制数据#####################
funasr/export/test/test_onnx_punc_vadrealtime.py
@@ -12,7 +12,7 @@
        return {'inputs': np.ones((1, text_length), dtype=np.int64),
                'text_lengths': np.array([text_length,], dtype=np.int32),
                'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
                'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
                'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32),
                }
    def _run(feed_dict):
funasr/runtime/python/onnxruntime/funasr_onnx/punc_bin.py
@@ -186,11 +186,12 @@
            mini_sentence = cache_sent + mini_sentence
            mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32')
            text_length = len(mini_sentence_id)
            vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32)
            data = {
                "input": mini_sentence_id[None,:],
                "text_lengths": np.array([text_length], dtype='int32'),
                "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32),
                "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32)
                "vad_mask": vad_mask
                "sub_masks": vad_mask
            }
            try:
                outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"])