# 游雁
# 2023-11-09 adf32376629f6940c84b62167bee6c252e6c2fcc
"""Initialize funasr package."""
 
import os
from pathlib import Path
import torch
import numpy as np
 
# Package version, read from the ``version.txt`` file located next to this module.
dirname = os.path.dirname(__file__)
version_file = os.path.join(dirname, "version.txt")
# Explicit encoding so the result does not depend on the platform's locale default.
with open(version_file, "r", encoding="utf-8") as f:
    __version__ = f.read().strip()
 
 
def _get_download_tool(model_hub):
    """Return the ``snapshot_download`` callable for ``model_hub``.

    Raises:
        ImportError: ``model_hub`` is ModelScope but ``modelscope`` is not installed.
        NotImplementedError: ``model_hub`` is Hugging Face (no downloader is wired up).
        ValueError: ``model_hub`` is not a recognised hub name.
    """
    if model_hub in ("ms", "modelscope"):
        try:
            from modelscope.hub.snapshot_download import snapshot_download
        except ImportError as err:
            raise ImportError(
                "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n"
                "\npip3 install -U modelscope\n"
                "For the users in China, you could install with the command:\n"
                "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
            ) from err
        return snapshot_download
    if model_hub in ("hf", "huggingface"):
        # No Hugging Face downloader exists yet; fail early with a clear error.
        raise NotImplementedError("Downloading models from huggingface is not supported yet.")
    raise ValueError("model_hub must be one of ms or hf, but got {}".format(model_hub))


def _resolve_model_dir(name, download_tool, cache_dir, **download_kwargs):
    """Return a local directory for ``name``, downloading it if it is not a local path.

    ``None`` and existing paths are returned unchanged. Otherwise a short alias is
    translated through ``name_maps_ms``. Unknown names are passed through as-is, so
    full model ids work too. The result is then fetched with ``download_tool``.

    Raises:
        RuntimeError: the download failed (the original error is chained).
    """
    if name is None or Path(name).exists():
        return name
    model_id = name_maps_ms.get(name, name)
    try:
        local_dir = download_tool(model_id, cache_dir=cache_dir, **download_kwargs)
    except Exception as err:  # boundary around a third-party downloader; re-raised below
        raise RuntimeError(
            "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
                model_id)
        ) from err
    print("model have been downloaded to: {}".format(local_dir))
    return local_dir


def _read_mode_from_config(model_dir):
    """Read the inference mode from ``<model_dir>/configuration.json``.

    Punctuation models store it under ``model.punc_model_config.mode``; all other
    tasks store it under ``model.model_config.mode``.

    Raises:
        ValueError: ``model_dir`` is None, so there is no config to read.
        OSError / KeyError: the config file is missing or lacks the expected keys.
    """
    import json

    if model_dir is None:
        raise ValueError("mode must be passed explicitly when no model is given")
    json_file = os.path.join(model_dir, "configuration.json")
    with open(json_file, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    if config_data["task"] == "punctuation":
        return config_data["model"]["punc_model_config"]["mode"]
    return config_data["model"]["model_config"]["mode"]


def prepare_model(
    model: str = None,
    vad_model: str = None,
    punc_model: str = None,
    model_hub: str = "ms",
    cache_dir: str = None,
    **kwargs,
):
    """Resolve model directories and build the keyword arguments for inference.

    Each of ``model``, ``vad_model`` and ``punc_model`` may be an existing local
    directory, a short alias from ``name_maps_ms`` (e.g. ``"paraformer-zh"``), or a
    full hub model id. Anything that is not an existing local path is downloaded
    from ``model_hub``.

    Args:
        model: Main (ASR or punctuation) model.
        vad_model: Optional VAD model.
        punc_model: Optional punctuation model.
        model_hub: ``"ms"``/``"modelscope"`` or ``"hf"``/``"huggingface"``. Only
            consulted when something has to be downloaded.
        cache_dir: Download cache directory, forwarded to the downloader.
        **kwargs: Extra options, returned in the updated dict. ``mode``, when given,
            overrides the mode read from ``configuration.json``. ``revision`` is
            forwarded to the downloader for ``model``.

    Returns:
        ``(model, vad_model, punc_model, kwargs)``. The three models are local
        directories (or ``None``). ``kwargs`` is extended with the model file paths
        and the resolved ``"mode"``.

    Raises:
        ValueError: Unknown ``model_hub``, or ``mode`` cannot be determined.
        ImportError: ``modelscope`` is needed but not installed.
        NotImplementedError: A download from Hugging Face was requested.
        RuntimeError: A download failed.
    """
    models = (model, vad_model, punc_model)
    if any(m is not None and not Path(m).exists() for m in models):
        download_tool = _get_download_tool(model_hub)
        model = _resolve_model_dir(model, download_tool, cache_dir,
                                   revision=kwargs.get("revision", None))
        vad_model = _resolve_model_dir(vad_model, download_tool, cache_dir)
        punc_model = _resolve_model_dir(punc_model, download_tool, cache_dir)

    # asr
    kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
                   "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
                   "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
                   })

    mode = kwargs.get("mode", None)
    if mode is None:
        mode = _read_mode_from_config(model)
    # A separate VAD model forces a VAD-capable mode.
    if vad_model is not None and "vad" not in mode:
        mode = "paraformer_vad"
    kwargs["mode"] = mode

    # vad
    kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
                   "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
                   "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
                   })
    # punc
    kwargs.update({
        "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
        "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
    })

    return model, vad_model, punc_model, kwargs
 
# Short model aliases -> ModelScope model ids. ``prepare_model`` uses this table to
# translate a friendly name into the id handed to the ModelScope downloader.
# The group labels below are inferred from the ids themselves.
name_maps_ms = {
    # Paraformer ASR models
    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    # NOTE(review): this maps to the same id as "paraformer-en" (no "spk" in the id);
    # confirm whether a speaker-aware English model was intended.
    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
    # VAD model
    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    # Punctuation model
    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
    # Timestamp-prediction model (per the id name)
    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
}
 
def infer(task_name: str = "asr",
          model: str = None,
          vad_model: str = None,
          punc_model: str = None,
          model_hub: str = "ms",
          cache_dir: str = None,
          **kwargs,
          ):
    """Build an inference callable for ``task_name``.

    Model locations are resolved (and downloaded if needed) by ``prepare_model``.
    The resulting keyword arguments are forwarded to the task's pipeline builder.

    Args:
        task_name: Task to run. Only ``"asr"`` is currently supported.
        model, vad_model, punc_model, model_hub, cache_dir, **kwargs:
            Forwarded to ``prepare_model``.

    Returns:
        A function ``_infer_fn(input, **kwargs)`` that runs the pipeline on
        ``input``. ``input`` is either a data path, interpreted according to the
        ``data_type`` keyword (default ``"sound"``), or an in-memory array given as a
        ``numpy.ndarray`` or ``torch.Tensor``. Extra keywords are forwarded to the
        pipeline call.

    Raises:
        ValueError: If ``task_name`` is not supported. This is checked before any
            model is resolved or downloaded.
    """
    if task_name != "asr":
        raise ValueError("Unsupported task_name {!r}; only 'asr' is supported".format(task_name))

    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)

    from funasr.bin.asr_inference_launch import inference_launch

    inference_pipeline = inference_launch(**kwargs)

    # ``input`` shadows the builtin but is kept as the parameter name for callers
    # that pass it by keyword.
    def _infer_fn(input, **kwargs):
        """Run the pipeline on a data path or an in-memory array."""
        data_type = kwargs.get('data_type', 'sound')
        if isinstance(input, torch.Tensor):
            # ``Tensor.numpy()`` only works on CPU tensors that do not require grad.
            input = input.detach().cpu().numpy()
        if isinstance(input, np.ndarray):
            # In-memory data: no path/name/type triple; hand the array over directly.
            data_path_and_name_and_type = None
            raw_inputs = input
        else:
            data_path_and_name_and_type = [input, 'speech', data_type]
            raw_inputs = None
        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)

    return _infer_fn
 
# Running this file directly does nothing; it only defines the names above for import.
if __name__ == '__main__':
    pass