Yabin Li
2023-11-07 702ec03ad89d5c62e97eed770a6882d6412f8d58
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
import argparse
import ssl
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
 
# Start tracing Python memory allocations for the lifetime of the process.
tracemalloc.start()

# Fetch the logger provided by modelscope.utils.logger and raise its level to
# CRITICAL, both at creation time and explicitly afterwards.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
 
# Command-line interface of the websocket ASR server. Every option has a
# default, so the script can be launched without arguments.
parser = argparse.ArgumentParser()

# Network endpoint the websocket server binds to.
parser.add_argument("--host",
                    type=str,
                    default="0.0.0.0",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="websocket server port")

# Model identifiers handed to the modelscope pipelines at start-up.
parser.add_argument("--asr_model",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="model from modelscope")
parser.add_argument("--asr_model_online",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
                    help="model from modelscope")
parser.add_argument("--vad_model",
                    type=str,
                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    help="model from modelscope")
# An empty string disables the punctuation pipeline.
parser.add_argument("--punc_model",
                    type=str,
                    default="damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727",
                    help="model from modelscope")

# Compute resources forwarded to every pipeline.
parser.add_argument("--ngpu",
                    type=int,
                    default=1,
                    help="0 for cpu, 1 for gpu")
parser.add_argument("--ncpu",
                    type=int,
                    default=4,
                    help="cpu cores")

# TLS material. The server is started with SSL whenever --certfile is a
# non-empty string, which is the case for the default value.
parser.add_argument("--certfile",
                    type=str,
                    default="./ssl_key/server.crt",
                    required=False,
                    help="certfile for ssl")
parser.add_argument("--keyfile",
                    type=str,
                    default="./ssl_key/server.key",
                    required=False,
                    help="keyfile for ssl")
args = parser.parse_args()
 
 
# Registry of client connections: ws_serve() adds each new connection and
# clear_websocket() closes and removes all of them.
websocket_users = set()

print("model loading")
# Offline (non-streaming) speech recognition pipeline, used by async_asr().
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision=None)


# Voice activity detection pipeline in 'online' mode, used by async_vad().
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision=None,
    mode='online',
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

# Punctuation pipeline; optional. Passing --punc_model "" leaves it as None,
# in which case async_asr() sends the recognition result without this step.
if args.punc_model != "":
    inference_pipeline_punc = pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision="v1.0.2",
        ngpu=args.ngpu,
        ncpu=args.ncpu,
    )
else:
    inference_pipeline_punc = None

# Streaming speech recognition pipeline, used by async_asr_online().
# NOTE(review): the meaning of `update_model` is defined by modelscope and is
# not visible in this file; confirm against the modelscope documentation.
inference_pipeline_asr_online = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model_online,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
    model_revision='v1.0.7',
    update_model='v1.0.7',
    mode='paraformer_streaming')

# ws_serve() closes every previously registered connection when a new client
# connects, hence the single-client notice below.
print("model loaded! only support one client at the same time now!!!!")
 
async def ws_reset(websocket):
    """Reset the streaming state stored on a connection, then close it.

    The VAD and streaming-ASR parameter dicts attached to ``websocket`` are
    replaced by fresh ones with empty caches and ``is_final`` set to True.
    The current size of ``websocket_users`` is printed first.
    """
    print("ws reset now, total num is ", len(websocket_users))
    # Drop whatever the pipelines cached for this connection.
    websocket.param_dict_vad = {"in_cache": {}, "is_final": True}
    websocket.param_dict_asr_online = {"cache": {}, "is_final": True}
    await websocket.close()
    
    
async def clear_websocket():
    """Reset and close every registered connection, then empty the registry.

    Called by ws_serve() before it registers a new client, which is what
    limits the server to one active client at a time.
    """
    # Iterate over a snapshot: ws_reset() awaits websocket.close(), and while
    # this coroutine is suspended other handlers may add to or remove from
    # websocket_users. Iterating the live set across those awaits can fail
    # with "RuntimeError: Set changed size during iteration".
    for websocket in list(websocket_users):
        await ws_reset(websocket)
    websocket_users.clear()
 
 
       
def _apply_client_config(websocket, messagejson):
    """Copy the recognised keys of a JSON control message onto the connection.

    Keys that are absent from ``messagejson`` leave the matching setting
    untouched.
    """
    if "is_speaking" in messagejson:
        websocket.is_speaking = messagejson["is_speaking"]
        # Mark the streaming ASR input as final once the client reports that
        # it has stopped speaking.
        websocket.param_dict_asr_online["is_final"] = not websocket.is_speaking
    if "chunk_interval" in messagejson:
        websocket.chunk_interval = messagejson["chunk_interval"]
    if "wav_name" in messagejson:
        websocket.wav_name = messagejson.get("wav_name")
    # These values are stored as-is in the streaming ASR parameter dict.
    for key in ("chunk_size", "encoder_chunk_look_back", "decoder_chunk_look_back"):
        if key in messagejson:
            websocket.param_dict_asr_online[key] = messagejson[key]
    if "mode" in messagejson:
        websocket.mode = messagejson["mode"]


async def ws_serve(websocket, path):
    """Handle one client connection for its whole lifetime.

    On connect, every previously registered connection is reset and closed
    (clear_websocket) and this one is registered in ``websocket_users``.

    Incoming messages are handled as follows:

    * ``str`` messages are parsed as JSON control messages and may carry
      ``is_speaking``, ``chunk_interval``, ``wav_name``, ``chunk_size``,
      ``encoder_chunk_look_back``, ``decoder_chunk_look_back`` and ``mode``.
    * any other message is treated as an audio chunk. It is buffered, passed
      to the streaming recogniser every ``chunk_interval`` chunks (modes
      "2pass" and "online"), and run through VAD. When VAD reports an end
      point, or the client is no longer speaking, the buffered segment is
      sent to the offline recogniser (modes "2pass" and "offline").

    Args:
        websocket: the connection object; per-connection state is stored on
            it as attributes.
        path: request path supplied by the websockets library; unused.
    """
    frames = []             # recent audio chunks, kept for VAD look-back
    frames_asr = []         # chunks of the current speech segment (offline ASR)
    frames_asr_online = []  # chunks waiting for the next streaming ASR call
    await clear_websocket()
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_asr_online = {"cache": dict()}
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0
    # Defaults used until the client sends its own values. Without them, an
    # audio chunk that arrives before any control message fails with
    # AttributeError and the connection is dropped.
    websocket.is_speaking = True
    websocket.chunk_interval = 10
    speech_start = False
    speech_end_i = -1
    websocket.wav_name = "microphone"
    websocket.mode = "2pass"
    print("new user connected", flush=True)

    try:
        async for message in websocket:
            if isinstance(message, str):
                _apply_client_config(websocket, json.loads(message))

            # A control message only triggers processing when audio is still
            # buffered; an audio chunk always does.
            if len(frames_asr_online) > 0 or len(frames_asr) > 0 or not isinstance(message, str):
                if not isinstance(message, str):
                    frames.append(message)
                    # presumably 16 kHz 16-bit mono PCM, i.e. 32 bytes per
                    # millisecond -- TODO confirm against the client.
                    duration_ms = len(message) // 32
                    websocket.vad_pre_idx += duration_ms

                    # asr online
                    frames_asr_online.append(message)
                    # speech_end_i still holds the VAD result of the previous
                    # chunk at this point.
                    websocket.param_dict_asr_online["is_final"] = speech_end_i != -1
                    if (len(frames_asr_online) % websocket.chunk_interval == 0
                            or websocket.param_dict_asr_online["is_final"]):
                        if websocket.mode == "2pass" or websocket.mode == "online":
                            audio_in = b"".join(frames_asr_online)
                            await async_asr_online(websocket, audio_in)
                        frames_asr_online = []
                    if speech_start:
                        frames_asr.append(message)

                    # vad online
                    speech_start_i, speech_end_i = await async_vad(websocket, message)
                    if speech_start_i != -1:
                        speech_start = True
                        # Number of buffered chunks back to the detected start.
                        # max(..., 1) avoids ZeroDivisionError for chunks
                        # shorter than 32 bytes.
                        beg_bias = (websocket.vad_pre_idx - speech_start_i) // max(duration_ms, 1)
                        # NOTE(review): when beg_bias is 0 this slice is
                        # frames[-0:], i.e. the whole buffer -- confirm that
                        # this is intended.
                        frames_pre = frames[-beg_bias:]
                        frames_asr = []
                        frames_asr.extend(frames_pre)

                # asr punc offline
                if speech_end_i != -1 or not websocket.is_speaking:
                    if websocket.mode == "2pass" or websocket.mode == "offline":
                        audio_in = b"".join(frames_asr)
                        await async_asr(websocket, audio_in)
                    frames_asr = []
                    speech_start = False
                    if not websocket.is_speaking:
                        # End of the utterance: restart VAD from scratch.
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.param_dict_vad = {'in_cache': dict()}
                    else:
                        # Keep a short tail of audio for the next segment.
                        frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users, flush=True)
        await ws_reset(websocket)
        # discard(): a newer connection may already have emptied the registry
        # through clear_websocket(), in which case remove() raises KeyError.
        websocket_users.discard(websocket)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
 
 
async def async_vad(websocket, audio_in):
    """Run the VAD pipeline on one audio chunk and return its boundaries.

    Returns:
        A ``(speech_start, speech_end)`` tuple. Each element is the value
        reported by the pipeline for the single detected segment, or -1 when
        the pipeline reported none. ``(-1, -1)`` is returned when the result
        is empty or holds more than one segment.
    """
    vad_output = inference_pipeline_vad(audio_in=audio_in,
                                        param_dict=websocket.param_dict_vad)

    if len(vad_output) == 0:
        return -1, -1
    segments = vad_output["text"]
    if len(segments) > 1:
        return -1, -1

    # Exactly one segment is expected here; -1 marks "no boundary".
    only_segment = segments[0]
    begin = only_segment[0] if only_segment[0] != -1 else -1
    end = only_segment[1] if only_segment[1] != -1 else -1
    return begin, end
 
 
async def async_asr(websocket, audio_in):
    """Recognise a complete speech segment offline and send the result.

    Empty audio is ignored. When a punctuation pipeline is loaded and the
    recogniser produced non-empty text, that text is passed through it and
    the punctuation result replaces the recognition result. If the final
    result has a ``text`` entry, it is sent to the client as a JSON message.
    """
    if len(audio_in) == 0:
        return

    rec_result = inference_pipeline_asr(audio_in=audio_in,
                                        param_dict=websocket.param_dict_asr)
    if inference_pipeline_punc is not None and 'text' in rec_result and len(rec_result["text"]) > 0:
        rec_result = inference_pipeline_punc(text_in=rec_result['text'],
                                             param_dict=websocket.param_dict_punc)
    if 'text' not in rec_result:
        return

    # In two-pass mode the offline result is labelled separately from the
    # streaming one.
    if "2pass" in websocket.mode:
        mode = "2pass-offline"
    else:
        mode = websocket.mode
    payload = {
        "mode": mode,
        "text": rec_result["text"],
        "wav_name": websocket.wav_name,
        "is_final": websocket.is_speaking,
    }
    await websocket.send(json.dumps(payload))
 
 
async def async_asr_online(websocket, audio_in):
    """Feed buffered audio to the streaming recogniser and send partial text.

    Empty audio is ignored. The pipeline is always invoked for non-empty
    audio, but in "2pass" mode nothing is sent for a chunk flagged
    ``is_final``. The placeholder texts "sil" and "waiting_for_more_voice"
    are never sent.
    """
    if len(audio_in) == 0:
        return

    rec_result = inference_pipeline_asr_online(audio_in=audio_in,
                                               param_dict=websocket.param_dict_asr_online)

    final_chunk = websocket.param_dict_asr_online.get("is_final", False)
    if websocket.mode == "2pass" and final_chunk:
        return
    if "text" not in rec_result:
        return

    text = rec_result["text"]
    if text == "sil" or text == "waiting_for_more_voice":
        return

    # In two-pass mode the streaming result is labelled separately from the
    # offline one.
    if "2pass" in websocket.mode:
        mode = "2pass-online"
    else:
        mode = websocket.mode
    payload = {
        "mode": mode,
        "text": text,
        "wav_name": websocket.wav_name,
        "is_final": websocket.is_speaking,
    }
    await websocket.send(json.dumps(payload))
 
# Keyword arguments common to the TLS and the plain-text server.
_serve_options = {"subprotocols": ["binary"], "ping_interval": None}

# TLS is enabled whenever a certificate path is given.
if len(args.certfile) > 0:
    ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)

    # Author's note on the files: generated with Let's Encrypt, copied to
    # this location, chown'ed to the current user with 400 permissions.
    ssl_cert = args.certfile
    ssl_key = args.keyfile

    ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
    _serve_options["ssl"] = ssl_context

start_server = websockets.serve(ws_serve, args.host, args.port, **_serve_options)

# Start the server on the event loop and keep it running.
_event_loop = asyncio.get_event_loop()
_event_loop.run_until_complete(start_server)
_event_loop.run_forever()