维石
2024-06-03 509d09f50d4ed73cec79e6a73de8bf30ab798e8c
runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -15,9 +15,16 @@
# Module-level logger returned by get_logger() (defined outside this excerpt).
# NOTE(review): the name `logging` shadows the stdlib module of the same name
# if that module is also imported in this file — confirm.
logging = get_logger()
import torch
import json
class Paraformer:
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    TorchScript inference wrapper: the constructor loads the scripted model
    with ``torch.jit.load`` and builds the token converter, tokenizer and
    ``WavFrontend`` from files found in ``model_dir``.
    """
    def __init__(
        self,
        model_dir: Union[str, Path] = None,
@@ -25,20 +32,43 @@
        device_id: Union[str, int] = "-1",
        plot_timestamp_to: str = "",
        quantize: bool = False,
        intra_op_num_threads: int = 1,
        intra_op_num_threads: int = 4,
        cache_dir: str = None,
        **kwargs,
    ):
        if not Path(model_dir).exists():
            raise FileNotFoundError(f"{model_dir} does not exist.")
            try:
                from modelscope.hub.snapshot_download import snapshot_download
            except:
                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" "\npip3 install -U modelscope\n" "For the users in China, you could install with the command:\n" "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                    model_dir
                )
        model_file = os.path.join(model_dir, "model.torchscripts")
        if quantize:
            model_file = os.path.join(model_dir, "model_quant.torchscripts")
        if not os.path.exists(model_file):
            print(".onnx is not exist, begin to export onnx")
            try:
                from funasr import AutoModel
            except:
                raise "You are exporting onnx, please install funasr and try it again. To install funasr, you could:\n" "\npip3 install -U funasr\n" "For the users in China, you could install with the command:\n" "\npip3 install -U funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            model = AutoModel(model=model_dir)
            model_dir = model.export(type="torchscript", quantize=quantize, **kwargs)
        config_file = os.path.join(model_dir, "config.yaml")
        cmvn_file = os.path.join(model_dir, "am.mvn")
        config = read_yaml(config_file)
        token_list = os.path.join(model_dir, "tokens.json")
        with open(token_list, "r", encoding="utf-8") as f:
            token_list = json.load(f)
        self.converter = TokenIDConverter(config["token_list"])
        self.converter = TokenIDConverter(token_list)
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
        self.ort_infer = torch.jit.load(model_file)
@@ -49,6 +79,10 @@
            self.pred_bias = config["model_conf"]["predictor_bias"]
        else:
            self.pred_bias = 0
        if "lang" in config:
            self.language = config["lang"]
        else:
            self.language = None
    def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List:
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
@@ -202,4 +236,4 @@
        token = self.converter.ids2tokens(token_int)
        token = token[: valid_token_num - self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
        return token