zhifu gao
2024-03-11 0d9384c8c0161259192cc3d676ca0d60e0d18e5c
funasr/auto/auto_model.py
@@ -155,9 +155,8 @@
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device
        if kwargs.get("ncpu", None):
            torch.set_num_threads(kwargs.get("ncpu"))
        torch.set_num_threads(kwargs.get("ncpu", 4))
        
        # build tokenizer
        tokenizer = kwargs.get("tokenizer", None)
@@ -476,11 +475,13 @@
               calib_num: int = 100,
               opset_version: int = 14,
               **cfg):
        os.environ['EXPORTING_MODEL'] = 'TRUE'
        device = cfg.get("device", "cpu")
        model = self.model.to(device=device)
        kwargs = self.kwargs
        deep_update(kwargs, cfg)
        kwargs["device"] = device
        del kwargs["model"]
        model = self.model
        model.eval()
        batch_size = 1
@@ -493,11 +494,19 @@
                export_dir = export_utils.export_onnx(
                                        model=model,
                                        data_in=data_list,
                                        quantize=quantize,
                                        fallback_num=fallback_num,
                                        calib_num=calib_num,
                                        opset_version=opset_version,
                                        **kwargs)
            else:
                export_dir = export_utils.export_torchscripts(
                                        model=model,
                                        data_in=data_list,
                                        quantize=quantize,
                                        fallback_num=fallback_num,
                                        calib_num=calib_num,
                                        opset_version=opset_version,
                                        **kwargs)
        return export_dir