游雁
2023-03-28 30433b58d68b93f85e61f304af32e6c7c8ef1f11
export
2个文件已添加
2 文件已重命名
26 ■■■■■ 已修改文件
funasr/export/test/__init__.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_onnx_vad.py 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/test_torchscripts.py 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/test/__init__.py
funasr/export/test/test_onnx.py
funasr/export/test/test_onnx_vad.py
New file
@@ -0,0 +1,26 @@
import onnxruntime
import numpy as np
if __name__ == '__main__':
    onnx_path = "/mnt/workspace/export/damo/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx"
    sess = onnxruntime.InferenceSession(onnx_path)
    input_name = [nd.name for nd in sess.get_inputs()]
    output_name = [nd.name for nd in sess.get_outputs()]
    def _get_feed_dict(feats_length):
        return {'speech': np.random.rand(1, feats_length, 400).astype(np.float32),
                'in_cache0': np.random.rand(1, 128, 19, 1).astype(np.float32),
                'in_cache1': np.random.rand(1, 128, 19, 1).astype(np.float32),
                'in_cache2': np.random.rand(1, 128, 19, 1).astype(np.float32),
                'in_cache3': np.random.rand(1, 128, 19, 1).astype(np.float32),
                }
    def _run(feed_dict):
        output = sess.run(output_name, input_feed=feed_dict)
        for name, value in zip(output_name, output):
            print('{}: {}'.format(name, value.shape))
    _run(_get_feed_dict(100))
    _run(_get_feed_dict(200))
funasr/export/test/test_torchscripts.py