Shi Xian
2024-06-18 6c467e6f0abfc6d20d0621fbbf67b4dbd81776cc
funasr/auto/auto_model.py
@@ -466,25 +466,22 @@
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += restored_data[j][k]
            if not len(result["text"].strip()):
                continue
            return_raw_text = kwargs.get("return_raw_text", False)
            # step.3 compute punc model
            raw_text = None
            if self.punc_model is not None:
                if not len(result["text"].strip()):
                    if return_raw_text:
                        result["raw_text"] = ""
                else:
                    deep_update(self.punc_kwargs, cfg)
                    punc_res = self.inference(
                        result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
                    )
                    raw_text = copy.copy(result["text"])
                    if return_raw_text:
                        result["raw_text"] = raw_text
                    result["text"] = punc_res[0]["text"]
            else:
                raw_text = None
                deep_update(self.punc_kwargs, cfg)
                punc_res = self.inference(
                    result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
                )
                raw_text = copy.copy(result["text"])
                if return_raw_text:
                    result["raw_text"] = raw_text
                result["text"] = punc_res[0]["text"]
            # speaker embedding cluster after resorted
            if self.spk_model is not None and kwargs.get("return_spk_res", True):
                if raw_text is None:
@@ -605,12 +602,6 @@
        )
        with torch.no_grad():
            if type == "onnx":
                export_dir = export_utils.export_onnx(model=model, data_in=data_list, **kwargs)
            else:
                export_dir = export_utils.export_torchscripts(
                    model=model, data_in=data_list, **kwargs
                )
            export_dir = export_utils.export(model=model, data_in=data_list,  **kwargs)
        return export_dir