1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
| #!/usr/bin/env python3
| # -*- encoding: utf-8 -*-
| # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
| # MIT License (https://opensource.org/licenses/MIT)
|
| import os
|
| import logging
| import torch
| import numpy as np
| from funasr.utils.download_and_prepare_model import prepare_model
|
| from funasr.utils.types import str2bool
|
def infer(task_name: str = "asr",
          model: str = None,
          vad_model: str = None,
          disable_vad: bool = False,
          punc_model: str = None,
          disable_punc: bool = False,
          model_hub: str = "ms",
          cache_dir: str = None,
          **kwargs,
          ):
    """Build and return an inference callable for ``task_name``.

    The model arguments are resolved through ``prepare_model``. An inference
    pipeline is then created with ``inference_launch``, and a closure that
    runs that pipeline on a single input is returned.

    Args:
        task_name: Task to run. Only ``"asr"`` is currently supported.
        model: Model identifier, forwarded to ``prepare_model``.
        vad_model: VAD model identifier, forwarded to ``prepare_model``.
        disable_vad: NOTE(review): accepted but not currently used here
            (it is not forwarded to ``prepare_model``); confirm intent.
        punc_model: Punctuation model identifier, forwarded to
            ``prepare_model``.
        disable_punc: NOTE(review): accepted but not currently used here,
            same as ``disable_vad``.
        model_hub: Hub identifier forwarded to ``prepare_model``.
        cache_dir: Cache directory forwarded to ``prepare_model``.
        **kwargs: Extra options forwarded to ``prepare_model``. The kwargs it
            returns are passed on to ``inference_launch``.

    Returns:
        A function ``_infer_fn(input, **kwargs)``. It accepts either a
        non-array input, which is passed as ``[input, 'speech', data_type]``,
        or a NumPy array / torch tensor, which is passed as ``raw_inputs``.
        It returns whatever the pipeline returns.

    Raises:
        ValueError: If ``task_name`` is not a supported task.
    """
    # Fail fast, before prepare_model runs. Without this check, any other
    # task name left the pipeline unset, and the returned function failed
    # with a NameError on its first call.
    if task_name != "asr":
        raise ValueError(
            f"Unsupported task_name {task_name!r}; supported tasks: 'asr'"
        )

    # set logging messages
    logging.basicConfig(
        level=logging.ERROR,
    )

    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)

    from funasr.bin.asr_inference_launch import inference_launch

    inference_pipeline = inference_launch(**kwargs)

    def _infer_fn(input, **decode_kwargs):
        """Run the pipeline on one input; ``decode_kwargs`` go to the pipeline."""
        data_type = decode_kwargs.get('data_type', 'sound')
        data_path_and_name_and_type = [input, 'speech', data_type]
        raw_inputs = None
        if isinstance(input, torch.Tensor):
            # Tensor.numpy() raises for GPU tensors and for tensors that
            # require grad. detach().cpu() handles both cases, and is a
            # no-op copy-wise for plain CPU tensors.
            input = input.detach().cpu().numpy()
        if isinstance(input, np.ndarray):
            # In-memory array: hand the data over via raw_inputs instead of
            # the (path, name, type) triple.
            data_path_and_name_and_type = None
            raw_inputs = input

        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **decode_kwargs)

    return _infer_fn
|
|
def main(cmd=None):
    """Command-line entry point: parse arguments, run inference, print the result.

    Args:
        cmd: Optional argument list forwarded to ``parser.parse_args``.
            ``None`` lets the parser read the process arguments (assuming
            ``get_parser`` returns an ``argparse.ArgumentParser``).
    """
    from funasr.bin.argument import get_parser

    parser = get_parser()
    parser.add_argument('input', help='input file to transcribe')
    parser.add_argument(
        "--task_name",
        type=str,
        default="asr",
        help="The decoding mode",
    )
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="paraformer-zh",
        help="The asr mode name",
    )
    parser.add_argument(
        "-v",
        "--vad_model",
        type=str,
        default="fsmn-vad",
        help="vad model name",
    )
    parser.add_argument(
        "-dv",
        "--disable_vad",
        type=str2bool,
        default=False,
        help="",
    )
    parser.add_argument(
        "-p",
        "--punc_model",
        type=str,
        default="ct-punc",
        help="",
    )
    parser.add_argument(
        "-dp",
        "--disable_punc",
        type=str2bool,
        default=False,
        help="",
    )
    parser.add_argument(
        "--batch_size_token",
        type=int,
        default=5000,
        help="",
    )
    parser.add_argument(
        "--batch_size_token_threshold_s",
        type=int,
        default=35,
        help="",
    )
    parser.add_argument(
        "--max_single_segment_time",
        type=int,
        default=5000,
        help="",
    )
    args = parser.parse_args(cmd)
    kwargs = vars(args)

    # set logging messages
    logging.basicConfig(
        level=logging.ERROR,
    )
    # With the root logger at ERROR, this INFO record is normally suppressed.
    # The print() below is what actually shows the arguments.
    logging.info("Decoding args: {}".format(kwargs))

    # The shared parser may define this option. Drop it if present; the None
    # default avoids a KeyError when the parser does not define it.
    kwargs.pop("data_path_and_name_and_type", None)
    print("args: {}".format(kwargs))
    p = infer(**kwargs)

    # kwargs still holds the positional 'input', which binds to the returned
    # callable's `input` parameter. All other options are forwarded as well.
    res = p(**kwargs)
    print(res)
|
|