cgisky1980
2023-03-21 b1a5fbd433da55291f0a3d9df3fa1e85e6fcbc66
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
#""" from https://github.com/cgisky1980/550W_AI_Assistant """
 
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
import logging
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
import websocket
import pyaudio
import time
import json
import threading
 
 
# ---------WebsocketClient相关  主要处理 on_message on_open  已经做了断线重连处理
class WebsocketClient(object):
    def __init__(self, address, message_callback=None):
        super(WebsocketClient, self).__init__()
        self.address = address
        self.message_callback = None
 
    def on_message(self, ws, message):
        try:
            messages = json.loads(
                (message.encode("raw_unicode_escape")).decode()
            )  # 收到WS消息后的处理
            if messages.get("type") == "ping":
                self.ws.send('{"type":"pong"}')
        except json.JSONDecodeError as e:
            print(f"JSONDecodeError: {e}")
        except KeyError:
            print("KeyError!")
 
    def on_error(self, ws, error):
        print("client error:", error)
 
    def on_close(self, ws):
        print("### client closed ###")
        self.ws.close()
        self.is_running = False
 
    def on_open(self, ws):  # 连上ws后发布登录信息
        self.is_running = True
        self.ws.send(
            '{"type":"login","uid":"asr","pwd":"tts9102093109"}'
        )  # WS链接上后的登陆处理
 
    def close_connect(self):
        self.ws.close()
 
    def send_message(self, message):
        try:
            self.ws.send(message)
        except BaseException as err:
            pass
 
    def run(self):  # WS初始化
        websocket.enableTrace(True)
        self.ws = websocket.WebSocketApp(
            self.address,
            on_message=lambda ws, message: self.on_message(ws, message),
            on_error=lambda ws, error: self.on_error(ws, error),
            on_close=lambda ws: self.on_close(ws),
        )
        websocket.enableTrace(False)  # 要看ws调试信息,请把这行注释掉
        self.ws.on_open = lambda ws: self.on_open(ws)
        self.is_running = False
        # WS断线重连判断
        while True:  
            if not self.is_running:
                self.ws.run_forever()
            time.sleep(3)  # 3秒检测一次
 
 
class WSClient(object):
    def __init__(self, address, call_back):
        super(WSClient, self).__init__()
        self.client = WebsocketClient(address, call_back)
        self.client_thread = None
 
    def run(self):
        self.client_thread = threading.Thread(target=self.run_client)
        self.client_thread.start()
 
    def run_client(self):
        self.client.run()
 
    def send_message(self, message):
        self.client.send_message(message)
 
 
def vad(data):  # VAD推理
    segments_result = vad_pipline(audio_in=data)
    if segments_result["text"] == "[]":
        return False
    else:
        return True
 
 
# 创建一个VAD对象
vad_pipline = pipeline(
    task=Tasks.voice_activity_detection,
    model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    model_revision="v1.2.0",
    output_dir=None,
    batch_size=1,
)
 
param_dict = dict()
param_dict["hotword"] = "小五 小五月"  # 设置热词,用空格隔开
 
 
# 创建一个ASR对象
inference_pipeline2 = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
    param_dict=param_dict,
)
 
# 创建一个PyAudio对象
p = pyaudio.PyAudio()
 
# 定义一些参数
FORMAT = pyaudio.paInt16  # 采样格式
CHANNELS = 1  # 单声道
RATE = 16000  # 采样率
CHUNK = int(RATE / 1000 * 300)  # 每个片段的帧数(300毫秒)
RECORD_NUM = 0 # 录制时长(片段)
 
# 打开输入流
stream = p.open(
    format=FORMAT,
    channels=CHANNELS,
    rate=RATE,
    input=True,
    frames_per_buffer=CHUNK,
)
 
print("开始...")
 
# 创建一个WS连接
ws_client = WSClient("ws://localhost:7272", None)
ws_client.run()
 
frames = []  # 存储所有的帧数据
buffer = []  # 存储缓存中的帧数据(最多两个片段)
silence_count = 0  # 统计连续静音的次数
speech_detected = False  # 标记是否检测到语音
 
# 循环读取输入流中的数据
while True:
    data = stream.read(CHUNK)  # 读取一个片段的数据
    buffer.append(data)  # 将当前数据添加到缓存中
 
    if len(buffer) > 2:
        buffer.pop(0)  # 如果缓存超过两个片段,则删除最早的一个
 
    if speech_detected:
        frames.append(data)
        RECORD_NUM += 1
        # print(str(RECORD_NUM)+ "\r")
 
    if vad(data):  # VAD 判断是否有声音
        if not speech_detected:
            print("开始录音...")
            speech_detected = True  # 标记为检测到语音
            frames = []
            frames.extend(buffer)  # 把之前2个语音数据快加入
        silence_count = 0  # 重置静音次数
 
    else:
        silence_count += 1  # 增加静音次数
        #检测静音次数4次  或者已经录了50个数据块,则录音停止
        if speech_detected and (silence_count > 4 or RECORD_NUM > 50):  
            print("停止录音...")
            audio_in = b"".join(frames)
            rec_result = inference_pipeline2(audio_in=audio_in)  # ws播报数据
            rec_result["type"] = "nlp"  # 添加ws播报数据
            ws_client.send_message(
                json.dumps(rec_result, ensure_ascii=False)
            )  # ws发送到服务端
            print(rec_result)
            frames = []  # 清空所有的帧数据
            buffer = []  # 清空缓存中的帧数据(最多两个片段)
            silence_count = 0  # 统计连续静音的次数清零
            speech_detected = False  # 标记是否检测到语音
            RECORD_NUM = 0
 
print("结束录制...")
 
# 停止并关闭输入流
stream.stop_stream()
stream.close()
 
# 关闭PyAudio对象
p.terminate()