游雁
2023-05-08 24f2aa224a49bee0de0a59504abce232e5f2683e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import asyncio
import json
import websockets
import time
import logging
import tracemalloc
import numpy as np
 
from parse_args import args
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
from funasr.runtime.python.onnxruntime.funasr_onnx.utils.frontend import load_bytes
 
# Allocation tracing is enabled for the whole process.
# NOTE(review): nothing in this file reads tracemalloc snapshots, and tracing
# adds overhead to every allocation -- confirm it is still wanted.
tracemalloc.start()

# Quiet the modelscope logger so that only CRITICAL records are emitted.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)

# Connections registered by ws_serve.
websocket_users = set()

print("model loading")

# Speech recognition pipeline.
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
    model=args.asr_model,
    model_revision=None,
    ngpu=args.ngpu,
    ncpu=args.ncpu,
)

# Voice activity detection pipeline, run in streaming ("online") mode one
# chunk at a time.
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision=None,
    mode='online',
    batch_size=1,
    output_dir=None,
    ncpu=args.ncpu,
    ngpu=args.ngpu,
)

# Punctuation restoration is optional: an empty-string args.punc_model
# leaves it disabled (None).
inference_pipeline_punc = (
    pipeline(
        task=Tasks.punctuation,
        model=args.punc_model,
        model_revision=None,
        ncpu=args.ncpu,
        ngpu=args.ngpu,
    )
    if args.punc_model != ""
    else None
)

print("model loaded")
 
async def ws_serve(websocket, path=None):
    """Handle one client: buffer audio, segment it with the VAD, recognise each utterance.

    Each message is a JSON object with an ``is_finished`` flag. Messages whose
    ``is_finished`` is false carry ``audio`` (bytes sent as an ISO-8859-1
    string), ``is_speaking`` and optionally ``wav_name`` (default ``"demo"``).
    Messages whose ``is_finished`` is true are ignored. The collected audio is
    passed to ``async_asr`` whenever the VAD reports the end of speech or the
    client reports ``is_speaking`` false.

    Args:
        websocket: the client connection. Per-connection pipeline state is
            stored as attributes on it.
        path: request path supplied by older ``websockets`` releases; unused.
            It defaults to ``None`` so that releases which call the handler
            with only the connection also work.
    """
    frames = []  # received chunks, kept so audio before the VAD start can be recovered
    frames_asr = []  # chunks of the utterance currently being collected for ASR
    websocket_users.add(websocket)
    websocket.param_dict_asr = {}
    websocket.param_dict_vad = {'in_cache': dict(), "is_final": False}
    websocket.param_dict_punc = {'cache': list()}
    websocket.vad_pre_idx = 0  # milliseconds of audio received in the current stream
    speech_start = False

    try:
        async for message in websocket:
            message = json.loads(message)
            is_finished = message["is_finished"]
            if not is_finished:
                audio = bytes(message['audio'], 'ISO-8859-1')
                frames.append(audio)
                # 32 bytes per ms: presumably 16 kHz, 16-bit mono PCM -- confirm with the client.
                duration_ms = len(audio)//32
                websocket.vad_pre_idx += duration_ms

                is_speaking = message["is_speaking"]
                websocket.param_dict_vad["is_final"] = not is_speaking
                websocket.wav_name = message.get("wav_name", "demo")
                if speech_start:
                    frames_asr.append(audio)
                speech_start_i, speech_end_i = await async_vad(websocket, audio)
                # async_vad returns False when no start was found. A start at
                # 0 ms is a valid value, so truthiness must not be used here.
                if speech_start_i is not False:
                    speech_start = True
                    # Count how many trailing chunks reach back to the speech start.
                    # Round up so the chunk containing the start is included.
                    # Take at least the current chunk: a count of 0 would turn
                    # frames[-0:] into the whole history.
                    elapsed_ms = websocket.vad_pre_idx - speech_start_i
                    if duration_ms > 0:
                        n_pre = max(1, -(-elapsed_ms // duration_ms))
                    else:
                        n_pre = 1  # chunk shorter than 1 ms: avoid dividing by zero
                    frames_asr = frames[-n_pre:]
                if speech_end_i or not is_speaking:
                    audio_in = b"".join(frames_asr)
                    await async_asr(websocket, audio_in)
                    frames_asr = []
                    speech_start = False
                    if not is_speaking:
                        websocket.vad_pre_idx = 0
                        frames = []
                        websocket.param_dict_vad = {'in_cache': dict()}
                    else:
                        # Keep a short tail as lead-in for the next utterance.
                        frames = frames[-20:]

    except websockets.ConnectionClosed:
        print("ConnectionClosed...", websocket_users)
    except websockets.InvalidState:
        print("InvalidState...")
    except Exception as e:
        print("Exception:", e)
    finally:
        # websockets ends the async-for normally on a clean close, and the
        # other handlers above do not deregister either. Clean up on every exit.
        websocket_users.discard(websocket)
 
 
async def async_vad(websocket, audio_in):
    """Run one audio chunk through the streaming VAD and report segment boundaries.

    Args:
        websocket: client connection. Its ``param_dict_vad`` is passed to the
            VAD pipeline on every call, and the caller reuses that dict
            (including ``in_cache``) across chunks.
        audio_in: raw audio bytes of the current chunk.

    Returns:
        A ``(speech_start, speech_end)`` tuple.

        - ``speech_start`` is ``False`` when no start boundary was reported.
          Otherwise it is the start value taken from the VAD output, which the
          caller compares with its millisecond counter ``vad_pre_idx``. That
          value could be ``0``, so test it with ``is not False`` rather than
          truthiness.
        - ``speech_end`` is ``True`` when an end boundary was reported.

        A result holding anything other than exactly one segment yields
        ``(False, False)``.
    """
    segments_result = inference_pipeline_vad(audio_in=audio_in, param_dict=websocket.param_dict_vad)

    speech_start = False
    speech_end = False

    # An empty result means no boundary in this chunk. A missing or empty
    # "text" list is treated the same way rather than raising KeyError/IndexError.
    segments = segments_result.get("text") if segments_result else None
    if segments is None or len(segments) != 1:
        return speech_start, speech_end

    # Each segment is [start, end]. -1 appears to mark a boundary that was not
    # reported in this chunk.
    seg_start, seg_end = segments[0][0], segments[0][1]
    if seg_start != -1:
        speech_start = seg_start
    if seg_end != -1:
        speech_end = True
    return speech_start, speech_end
 
 
async def async_asr(websocket, audio_in):
    """Recognise one buffered utterance and send the result to the client.

    Args:
        websocket: client connection. ``param_dict_asr`` and ``param_dict_punc``
            are passed to the pipelines, and ``wav_name`` is echoed in the reply.
        audio_in: raw audio bytes of the utterance. Empty input is ignored.

    Sends ``{"mode": "offline", "text": [<text>], "wav_name": ...}`` as JSON.
    When a punctuation pipeline is loaded and the recognised text is non-empty,
    the punctuated text is sent instead. Nothing is sent for empty audio or
    when the recognition result has no ``"text"`` entry.
    """
    if len(audio_in) == 0:
        return

    rec_result = inference_pipeline_asr(audio_in=load_bytes(audio_in),
                                        param_dict=websocket.param_dict_asr)
    if "text" not in rec_result:
        # Presumably nothing was recognised. Skip the reply: indexing "text"
        # would raise KeyError, and ws_serve's catch-all handler would then
        # end this client's receive loop.
        return

    if inference_pipeline_punc is not None and len(rec_result["text"]) > 0:
        rec_result = inference_pipeline_punc(text_in=rec_result["text"],
                                             param_dict=websocket.param_dict_punc)

    message = json.dumps({"mode": "offline", "text": [rec_result["text"]], "wav_name": websocket.wav_name})
    await websocket.send(message)
                
                    
 
 
 
async def _serve_forever():
    """Serve ``ws_serve`` on ``args.host``:``args.port`` until the process is stopped."""
    # The server is created inside a running loop, so neither this code nor
    # websockets needs asyncio.get_event_loop() with no loop running.
    # That usage is deprecated in recent Python versions.
    async with websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None):
        await asyncio.Future()  # never completes: keep serving


asyncio.run(_serve_forever())