Yabin Li
2024-06-25 b7060884fa4b8b85f79462644a5c99062d223da0
runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -46,11 +46,11 @@
                    model_dir
                )
        model_file = os.path.join(model_dir, "model.torchscripts")
        model_file = os.path.join(model_dir, "model.torchscript")
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.torchscripts")
            model_file = os.path.join(model_dir, "model_quant.torchscript")
        if not os.path.exists(model_file):
            print(".torchscripts does not exist, begin to export torchscripts")
            print(".torchscripts does not exist, begin to export torchscript")
            try:
                from funasr import AutoModel
            except:
@@ -268,11 +268,11 @@
                )
        if quantize:
            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscripts")
            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscripts")
            model_bb_file = os.path.join(model_dir, "model_bb_quant.torchscript")
            model_eb_file = os.path.join(model_dir, "model_eb_quant.torchscript")
        else:
            model_bb_file = os.path.join(model_dir, "model_bb.torchscripts")
            model_eb_file = os.path.join(model_dir, "model_eb.torchscripts")
            model_bb_file = os.path.join(model_dir, "model_bb.torchscript")
            model_eb_file = os.path.join(model_dir, "model_eb.torchscript")
        if not (os.path.exists(model_eb_file) and os.path.exists(model_bb_file)):
            print(".onnx does not exist, begin to export onnx")
@@ -282,7 +282,7 @@
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="torchscripts", quantize=quantize, **kwargs)
            model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
@@ -337,7 +337,7 @@
                        outputs = self.bb_infer(feats, feats_len, bias_embed)
                        am_scores, valid_token_lens = outputs[0], outputs[1]
                    else:
                        outputs = self.bb_infer_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
                        outputs = self.bb_infer(feats.cuda(), feats_len.cuda(), bias_embed.cuda())
                        am_scores, valid_token_lens = outputs[0].cpu(), outputs[1].cpu()
            except:
                # logging.warning(traceback.format_exc())