cgisky1980
2023-03-23 61089908a0e8e7a8b6a05493b35a1c42854dd8d9
websocket runtime
2个文件已添加
216 ■■■■■ 已修改文件
funasr/runtime/websocket/ASR_client.py 73 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/ASR_server.py 143 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/ASR_client.py
New file
@@ -0,0 +1,73 @@
import pyaudio
import websocket #区别服务端这里是 websocket-client库
import time
import websockets
import asyncio
from queue import Queue
import threading
voices = Queue()
async def hello():
    global ws # 定义一个全局变量ws,用于保存websocket连接对象
    uri = "ws://localhost:8899"
    ws = await websockets.connect(uri, subprotocols=["binary"]) # 创建一个长连接
    ws.max_size = 1024 * 1024 * 20
    print("connected ws server")
async def send(data):
    global ws # 引用全局变量ws
    try:
        await ws.send(data) # 通过ws对象发送数据
    except Exception as e:
        print('Exception occurred:', e)
asyncio.get_event_loop().run_until_complete(hello()) # 启动协程
# 其他函数可以通过调用send(data)来发送数据,例如:
async def test():
    #print("2")
    global voices
    FORMAT = pyaudio.paInt16
    CHANNELS = 1
    RATE = 16000
    CHUNK = int(RATE / 1000 * 300)
    p = pyaudio.PyAudio()
    stream = p.open(format=FORMAT,
                    channels=CHANNELS,
                    rate=RATE,
                    input=True,
                    frames_per_buffer=CHUNK)
    while True:
        data = stream.read(CHUNK)
        voices.put(data)
        #print(voices.qsize())
        await asyncio.sleep(0.01)
async def ws_send():
    global voices
    print("started to sending data!")
    while True:
        while not voices.empty():
            data = voices.get()
            voices.task_done()
            await send(data)
            await asyncio.sleep(0.01)
        await asyncio.sleep(0.01)
async def main():
    task = asyncio.create_task(test()) # 创建一个后台任务
    task2 = asyncio.create_task(ws_send()) # 创建一个后台任务
    await asyncio.gather(task, task2)
asyncio.run(main())
funasr/runtime/websocket/ASR_server.py
New file
@@ -0,0 +1,143 @@
# server.py   注意本例仅处理单个clent发送的语音数据,并未对多client连接进行判断和处理
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
import logging
logger = get_logger(log_level=logging.CRITICAL)
logger.setLevel(logging.CRITICAL)
import asyncio
import websockets  #区别客户端这里是 websockets库
import time
from queue import Queue
import  threading
print("model loading")
voices = Queue()
speek = Queue()
# 创建一个VAD对象
vad_pipline = pipeline(
    task=Tasks.voice_activity_detection,
    model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
    model_revision="v1.2.0",
    output_dir=None,
    batch_size=1,
)
# 创建一个ASR对象
param_dict = dict()
param_dict["hotword"] = "小五 小五月"  # 设置热词,用空格隔开
inference_pipeline2 = pipeline(
    task=Tasks.auto_speech_recognition,
    model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
    param_dict=param_dict,
)
print("model loaded")
async def echo(websocket, path):
    global voices
    try:
        async for message in websocket:
            voices.put(message)
            #print("put")
    except websockets.exceptions.ConnectionClosedError as e:
        print('Connection closed with exception:', e)
    except Exception as e:
        print('Exception occurred:', e)
start_server = websockets.serve(echo, "localhost", 8899, subprotocols=["binary"],ping_interval=None)
def vad(data):  # 推理
    global vad_pipline
    #print(type(data))
    segments_result = vad_pipline(audio_in=data)
    #print(segments_result)
    if len(segments_result) == 0:
        return False
    else:
        return True
def asr():  # 推理
    global inference_pipeline2
    global speek
    while True:
        while not speek.empty():
            audio_in = speek.get()
            speek.task_done()
            rec_result = inference_pipeline2(audio_in=audio_in)
            print(rec_result)
            time.sleep(0.1)
        time.sleep(0.1)
def main():  # 推理
    frames = []  # 存储所有的帧数据
    buffer = []  # 存储缓存中的帧数据(最多两个片段)
    silence_count = 0  # 统计连续静音的次数
    speech_detected = False  # 标记是否检测到语音
    RECORD_NUM = 0
    global voices
    global speek
    while True:
        while not voices.empty():
            data = voices.get()
            #print("队列排队数",voices.qsize())
            voices.task_done()
            buffer.append(data)
            if len(buffer) > 2:
                buffer.pop(0)  # 如果缓存超过两个片段,则删除最早的一个
            if speech_detected:
                frames.append(data)
                RECORD_NUM += 1
            if  vad(data):
                if not speech_detected:
                    print("检测到人声...")
                    speech_detected = True  # 标记为检测到语音
                    frames = []
                    frames.extend(buffer)  # 把之前2个语音数据快加入
                silence_count = 0  # 重置静音次数
            else:
                silence_count += 1  # 增加静音次数
                if speech_detected and (silence_count > 4 or RECORD_NUM > 50): #这里 50 可根据需求改为合适的数据快数量
                    print("说话结束或者超过设置最长时间...")
                    audio_in = b"".join(frames)
                    #asrt = threading.Thread(target=asr,args=(audio_in,))
                    #asrt.start()
                    speek.put(audio_in)
                    #rec_result = inference_pipeline2(audio_in=audio_in)  # ASR 模型里跑一跑
                    frames = []  # 清空所有的帧数据
                    buffer = []  # 清空缓存中的帧数据(最多两个片段)
                    silence_count = 0  # 统计连续静音的次数清零
                    speech_detected = False  # 标记是否检测到语音
                    RECORD_NUM = 0
            time.sleep(0.01)
        time.sleep(0.01)
s = threading.Thread(target=main)
s.start()
s = threading.Thread(target=asr)
s.start()
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()