游雁
2023-03-23 8873c2c21a23a67e861fb2ae1672763ac709e7f6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
# server.py -- NOTE: this example only handles audio sent by a single client; multiple client connections are neither detected nor handled.
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
import logging
 
# Set the logger returned by modelscope's get_logger to CRITICAL, both through
# the factory argument and explicitly via setLevel, so that lower-severity
# records are not emitted by it.
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
 
import asyncio
import websockets
import time
from queue import Queue
import threading
import argparse
 
# Command-line options for the websocket speech-recognition server.
parser = argparse.ArgumentParser()
parser.add_argument("--host",
                    type=str,
                    default="0.0.0.0",
                    required=False,
                    help="host ip, localhost, 0.0.0.0")
# The server is started with websockets.serve(), so this is a websocket port
# (the previous help text called it a grpc port).
parser.add_argument("--port",
                    type=int,
                    default=10095,
                    required=False,
                    help="websocket server port")
parser.add_argument("--asr_model",
                    type=str,
                    default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                    help="model from modelscope")
parser.add_argument("--vad_model",
                    type=str,
                    default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                    help="model from modelscope")

# NOTE(review): --punc_model is accepted but nothing in the visible part of
# this file reads args.punc_model.
parser.add_argument("--punc_model",
                    type=str,
                    default="",
                    help="model from modelscope")

# Parsed once at import time; the rest of the script reads the module-level `args`.
args = parser.parse_args()
 
print("model loading")

# Hand-off queues shared between the websocket handler and the worker threads:
#   voices - messages received from the websocket, consumed by main()
#   speek  - joined utterance audio produced by main(), consumed by asr()
voices = Queue()
speek = Queue()

# Voice-activity-detection pipeline.
vad_pipline = pipeline(task=Tasks.voice_activity_detection,
                       model=args.vad_model,
                       model_revision="v1.2.0",
                       output_dir=None,
                       batch_size=1)

# Speech-recognition pipeline. param_dict is passed through empty; to supply
# hotwords, set param_dict["hotword"] to a space-separated string before the
# pipeline is built.
param_dict = {}
inference_pipeline2 = pipeline(task=Tasks.auto_speech_recognition,
                               model=args.asr_model,
                               param_dict=param_dict)
print("model loaded")
 
 
 
async def ws_serve(websocket, path=None):
    """Receive messages from one websocket connection and queue them.

    Every message read from ``websocket`` is put, unchanged, on the
    module-level ``voices`` queue. Errors raised while reading are printed
    and the handler then returns.

    Args:
        websocket: Connection object supplied by ``websockets.serve``; it is
            iterated with ``async for``.
        path: Request path. Unused. It is optional so the handler works both
            with websockets releases that call ``handler(websocket, path)``
            and with releases that call ``handler(websocket)`` only.
    """
    try:
        async for message in websocket:
            voices.put(message)
    except websockets.exceptions.ConnectionClosedError as e:
        print('Connection closed with exception:', e)
    except Exception as e:
        # Top-level boundary for this connection: report and end the handler.
        print('Exception occurred:', e)
 
# Build the server object for the ws_serve handler on the configured host and
# port, advertising the "binary" subprotocol. It is only started when the
# event loop at the bottom of the file runs it.
# NOTE(review): ping_interval=None presumably turns off websockets' keepalive
# pings -- confirm against the installed websockets version.
start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 
 
def vad(data):
    """Run the VAD pipeline on one audio chunk and report whether it found anything.

    Args:
        data: Audio chunk, forwarded to ``vad_pipline`` as ``audio_in``.

    Returns:
        ``False`` when the pipeline result has length zero, ``True`` otherwise.
    """
    detection = vad_pipline(audio_in=data)
    return len(detection) != 0
 
def asr():
    """Worker loop: recognize each queued utterance and print the result.

    Takes audio from the module-level ``speek`` queue, runs it through
    ``inference_pipeline2`` and prints whatever the pipeline returns. The
    loop never returns; it is meant to be used as a thread target.
    """
    while True:
        # Block until an utterance is available instead of polling
        # empty() and sleeping between checks.
        audio_in = speek.get()
        try:
            rec_result = inference_pipeline2(audio_in=audio_in)
            print(rec_result)
        finally:
            # Mark the item done only once the work on it has finished.
            speek.task_done()
 
 
def main(max_silence_chunks=4, max_record_chunks=50):
    """Worker loop: group incoming audio chunks into utterances for ASR.

    Chunks are taken from the module-level ``voices`` queue and classified
    with ``vad()``. When speech is first detected, recording starts from the
    last two chunks seen (the current one and the one before it). Recording
    ends, and the joined audio is put on the ``speek`` queue, when either

    * more than ``max_silence_chunks`` consecutive chunks were non-speech, or
    * more than ``max_record_chunks`` chunks were recorded after the start.

    The length cap is checked after every chunk, so it also cuts an
    utterance during uninterrupted speech (previously it was only checked on
    a non-speech chunk, letting the recording grow without bound).

    The loop never returns; it is meant to be used as a thread target.

    Args:
        max_silence_chunks: Consecutive non-speech chunks tolerated before
            the utterance is considered finished.
        max_record_chunks: Maximum number of chunks appended to an utterance
            after speech detection before it is cut.
    """
    from collections import deque

    frames = []  # chunks of the utterance currently being recorded
    buffer = deque(maxlen=2)  # the two most recent chunks (pre-roll)
    silence_count = 0  # consecutive non-speech chunks
    speech_detected = False  # True while an utterance is being recorded
    record_num = 0  # chunks appended since speech was detected

    while True:
        # Block until a chunk arrives instead of polling empty() and sleeping.
        data = voices.get()
        try:
            buffer.append(data)  # deque(maxlen=2) discards the oldest chunk

            if speech_detected:
                frames.append(data)
                record_num += 1

            if vad(data):
                if not speech_detected:
                    print("检测到人声...")
                    speech_detected = True
                    # Start the utterance with the buffered chunks; the
                    # buffer already contains the current chunk.
                    frames = list(buffer)
                silence_count = 0
            else:
                silence_count += 1

            too_quiet = silence_count > max_silence_chunks
            too_long = record_num > max_record_chunks
            if speech_detected and (too_quiet or too_long):
                print("说话结束或者超过设置最长时间...")
                # assumes every chunk is a bytes object -- TODO confirm with client
                speek.put(b"".join(frames))
                # Reset all state for the next utterance.
                frames = []
                buffer.clear()
                silence_count = 0
                speech_detected = False
                record_num = 0
        finally:
            voices.task_done()
            
 
 
# Start the two worker loops. Both run forever, so they are daemon threads:
# otherwise the process could never exit once the event loop below stops
# (for example on Ctrl-C or when the server fails to bind).
vad_thread = threading.Thread(target=main, daemon=True)
vad_thread.start()
asr_thread = threading.Thread(target=asr, daemon=True)
asr_thread.start()

# Start the websocket server on the current event loop and serve until the
# process is stopped.
loop = asyncio.get_event_loop()
loop.run_until_complete(start_server)
loop.run_forever()