雾聪
2023-08-10 ffb05b9ae7eccc47416e9e7fae9dea54d400a245
funasr/export/export_model.py
@@ -55,18 +55,21 @@
        # export encoder1
        self.export_config["model_name"] = "model"
        model = get_model(
        models = get_model(
            model,
            self.export_config,
        )
        model.eval()
        # self._export_onnx(model, verbose, export_dir)
        if self.onnx:
            self._export_onnx(model, verbose, export_dir)
        else:
            self._export_torchscripts(model, verbose, export_dir)
        print("output dir: {}".format(export_dir))
        if not isinstance(models, tuple):
            models = (models,)
        for i, model in enumerate(models):
            model.eval()
            if self.onnx:
                self._export_onnx(model, verbose, export_dir)
            else:
                self._export_torchscripts(model, verbose, export_dir)
            print("output dir: {}".format(export_dir))
    def _torch_quantize(self, model):
@@ -192,6 +195,7 @@
                config, model_file, cmvn_file, 'cpu'
            )
            self.frontend = model.frontend
            self.export_config["feats_dim"] = 560
        elif mode.startswith('offline'):
            from funasr.tasks.vad import VADTask
            config = os.path.join(model_dir, 'vad.yaml')