From b454a1054fadbff0ee963944ff42f66b98317582 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: Tue, 08 Aug 2023 11:17:43 +0800
Subject: [PATCH] update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)

---
 funasr/runtime/python/grpc/grpc_main_client.py |  128 ++++++++++++++++++++++++------------------
 1 file changed, 72 insertions(+), 56 deletions(-)

diff --git a/funasr/runtime/python/grpc/grpc_main_client.py b/funasr/runtime/python/grpc/grpc_main_client.py
index b6491df..92888bd 100644
--- a/funasr/runtime/python/grpc/grpc_main_client.py
+++ b/funasr/runtime/python/grpc/grpc_main_client.py
@@ -1,62 +1,78 @@
-import grpc
-import json
-import time
-import asyncio
-import soundfile as sf
+import logging
 import argparse
+import soundfile as sf
+import time
 
-from grpc_client import transcribe_audio_bytes
-from paraformer_pb2_grpc import ASRStub
+import grpc
+import paraformer_pb2_grpc
+from paraformer_pb2 import Request, WavFormat, DecodeMode
 
-# send the audio data once
-async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
-    with grpc.insecure_channel(grpc_uri) as channel:
-        stub = ASRStub(channel)
-        for line in wav_scp:
-            wav_file = line.split()[1]
-            wav, _ = sf.read(wav_file, dtype='int16')
-            
-            b = time.time()
-            response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
-            resp = response.next()
-            text = ''
-            if 'decoding' == resp.action:
-                resp = response.next()
-                if 'finish' == resp.action:
-                    text = json.loads(resp.sentence)['text']
-            response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
-            res= {'text': text, 'time': time.time() - b}
-            print(res)
+class GrpcClient:
+  """Stream one wav file to the ASR gRPC server and log every response."""
+  def __init__(self, wav_path, uri, mode):
+    self.wav, self.sampling_rate = sf.read(wav_path, dtype='int16')
+    self.wav_format = WavFormat.pcm
+    self.audio_chunk_duration = 1000 # ms of audio per request
+    self.audio_chunk_size = int(self.sampling_rate * self.audio_chunk_duration / 1000)
+    self.send_interval = 100 # ms slept before each request is yielded
+    self.mode = mode
 
-async def test(args):
-    wav_scp = open(args.wav_scp, "r").readlines()
-    uri = '{}:{}'.format(args.host, args.port)
-    res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+    # Run the whole request inside a `with` block so the channel is closed
+    # when the response stream ends; self.stub is not usable afterwards.
+    with grpc.insecure_channel(uri) as channel:
+      self.stub = paraformer_pb2_grpc.ASRStub(channel)
+      for respond in self.stub.Recognize(self.request_iterator()):
+        logging.info("[receive] mode {}, text {}, is final {}".format(
+          DecodeMode.Name(respond.mode), respond.text, respond.is_final))
+
+  def request_iterator(self, mode = DecodeMode.two_pass):
+    """Yield one Request per audio chunk; `mode` is unused, self.mode is sent."""
+    total = len(self.wav)
+    for start in range(0, total, self.audio_chunk_size):
+      request = Request()
+      audio_chunk = self.wav[start : start + self.audio_chunk_size]
+
+      # Stream-level settings are filled in on the first request only.
+      if start == 0:
+        request.sampling_rate = self.sampling_rate
+        request.mode = self.mode
+        request.wav_format = self.wav_format
+        if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
+          request.chunk_size.extend([5, 10, 5])
+
+      # Flag the chunk that reaches the end of the audio as final.
+      request.is_final = start + self.audio_chunk_size >= total
+      request.audio_data = audio_chunk.tobytes()
+      logging.info("[request] audio_data len {}, is final {}".format(
+        len(request.audio_data), request.is_final)) # int16 = 2bytes
+      # Wait send_interval ms before handing the request to the caller.
+      time.sleep(self.send_interval / 1000)
+      yield request
 
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--host",
-                        type=str,
-                        default="127.0.0.1",
-                        required=False,
-                        help="grpc server host ip")
-    parser.add_argument("--port",
-                        type=int,
-                        default=10108,
-                        required=False,
-                        help="grpc server port")              
-    parser.add_argument("--user_allowed",
-                        type=str,
-                        default="project1_user1",
-                        help="allowed user for grpc client")
-    parser.add_argument("--sample_rate",
-                        type=int,
-                        default=16000,
-                        help="audio sample_rate from client") 
-    parser.add_argument("--wav_scp",
-                        type=str,
-                        required=True,
-                        help="audio wav scp")                    
-    args = parser.parse_args()
-    
-    asyncio.run(test(args))
+  # Command-line entry point: send the same wav file once per decode mode.
+  logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
+  # The server address and the input file come from the command line.
+  parser = argparse.ArgumentParser()
+  parser.add_argument(
+      "--host", type=str, default="127.0.0.1", required=False,
+      help="grpc server host ip")
+  parser.add_argument(
+      "--port", type=int, default=10100, required=False,
+      help="grpc server port")
+  parser.add_argument(
+      "--wav_path", type=str, required=True,
+      help="audio wav path")
+  args = parser.parse_args()
+
+  # Log the wall-clock time taken by each mode.
+  for mode in (DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass):
+    mode_name = DecodeMode.Name(mode)
+    logging.info("[request] start requesting with mode {}".format(mode_name))
+
+    # Constructing the client performs the whole recognition request.
+    st = time.time()
+    uri = '{}:{}'.format(args.host, args.port)
+    client = GrpcClient(args.wav_path, uri, mode)
+    elapsed = time.time() - st
+    logging.info("mode {}, time pass: {}".format(mode_name, elapsed))

--
Gitblit v1.9.1