From b454a1054fadbff0ee963944ff42f66b98317582 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期二, 08 八月 2023 11:17:43 +0800
Subject: [PATCH] update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)
---
funasr/runtime/python/grpc/grpc_main_client.py | 128 ++++++++++++++++++++++++------------------
1 files changed, 72 insertions(+), 56 deletions(-)
diff --git a/funasr/runtime/python/grpc/grpc_main_client.py b/funasr/runtime/python/grpc/grpc_main_client.py
index b6491df..92888bd 100644
--- a/funasr/runtime/python/grpc/grpc_main_client.py
+++ b/funasr/runtime/python/grpc/grpc_main_client.py
@@ -1,62 +1,78 @@
-import grpc
-import json
-import time
-import asyncio
-import soundfile as sf
+import logging
import argparse
+import soundfile as sf
+import time
-from grpc_client import transcribe_audio_bytes
-from paraformer_pb2_grpc import ASRStub
+import grpc
+import paraformer_pb2_grpc
+from paraformer_pb2 import Request, WavFormat, DecodeMode
-# send the audio data once
-async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
- with grpc.insecure_channel(grpc_uri) as channel:
- stub = ASRStub(channel)
- for line in wav_scp:
- wav_file = line.split()[1]
- wav, _ = sf.read(wav_file, dtype='int16')
-
- b = time.time()
- response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
- resp = response.next()
- text = ''
- if 'decoding' == resp.action:
- resp = response.next()
- if 'finish' == resp.action:
- text = json.loads(resp.sentence)['text']
- response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
- res= {'text': text, 'time': time.time() - b}
- print(res)
+class GrpcClient:
+ def __init__(self, wav_path, uri, mode):
+ self.wav, self.sampling_rate = sf.read(wav_path, dtype='int16')
+ self.wav_format = WavFormat.pcm
+ self.audio_chunk_duration = 1000 # ms
+ self.audio_chunk_size = int(self.sampling_rate * self.audio_chunk_duration / 1000)
+ self.send_interval = 100 # ms
+ self.mode = mode
-async def test(args):
- wav_scp = open(args.wav_scp, "r").readlines()
- uri = '{}:{}'.format(args.host, args.port)
- res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+ # connect to grpc server
+ channel = grpc.insecure_channel(uri)
+ self.stub = paraformer_pb2_grpc.ASRStub(channel)
+
+ # start request
+ for respond in self.stub.Recognize(self.request_iterator()):
+ logging.info("[receive] mode {}, text {}, is final {}".format(
+ DecodeMode.Name(respond.mode), respond.text, respond.is_final))
+
+ def request_iterator(self, mode = DecodeMode.two_pass):
+ is_first_pack = True
+ is_final = False
+ for start in range(0, len(self.wav), self.audio_chunk_size):
+ request = Request()
+ audio_chunk = self.wav[start : start + self.audio_chunk_size]
+
+ if is_first_pack:
+ is_first_pack = False
+ request.sampling_rate = self.sampling_rate
+ request.mode = self.mode
+ request.wav_format = self.wav_format
+ if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
+ request.chunk_size.extend([5, 10, 5])
+
+ if start + self.audio_chunk_size >= len(self.wav):
+ is_final = True
+ request.is_final = is_final
+ request.audio_data = audio_chunk.tobytes()
+ logging.info("[request] audio_data len {}, is final {}".format(
+ len(request.audio_data), request.is_final)) # int16 = 2bytes
+ time.sleep(self.send_interval / 1000)
+ yield request
if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument("--host",
- type=str,
- default="127.0.0.1",
- required=False,
- help="grpc server host ip")
- parser.add_argument("--port",
- type=int,
- default=10108,
- required=False,
- help="grpc server port")
- parser.add_argument("--user_allowed",
- type=str,
- default="project1_user1",
- help="allowed user for grpc client")
- parser.add_argument("--sample_rate",
- type=int,
- default=16000,
- help="audio sample_rate from client")
- parser.add_argument("--wav_scp",
- type=str,
- required=True,
- help="audio wav scp")
- args = parser.parse_args()
-
- asyncio.run(test(args))
+ logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host",
+ type=str,
+ default="127.0.0.1",
+ required=False,
+ help="grpc server host ip")
+ parser.add_argument("--port",
+ type=int,
+ default=10100,
+ required=False,
+ help="grpc server port")
+ parser.add_argument("--wav_path",
+ type=str,
+ required=True,
+ help="audio wav path")
+ args = parser.parse_args()
+
+ for mode in [DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass]:
+ mode_name = DecodeMode.Name(mode)
+ logging.info("[request] start requesting with mode {}".format(mode_name))
+
+ st = time.time()
+ uri = '{}:{}'.format(args.host, args.port)
+ client = GrpcClient(args.wav_path, uri, mode)
+ logging.info("mode {}, time pass: {}".format(mode_name, time.time() - st))
--
Gitblit v1.9.1