From 6e26ad0e149ae51e3fc8b89b3178684979e6bbd1 Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: 星期四, 09 十一月 2023 11:03:45 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR into main

---
 runtime/python/websocket/funasr_wss_client.py |  110 ++++++++++++++++++++++++++++++++++++++-----------------
 1 files changed, 76 insertions(+), 34 deletions(-)

diff --git a/runtime/python/websocket/funasr_wss_client.py b/runtime/python/websocket/funasr_wss_client.py
index 7c96553..66b3ce0 100644
--- a/runtime/python/websocket/funasr_wss_client.py
+++ b/runtime/python/websocket/funasr_wss_client.py
@@ -27,20 +27,16 @@
                     help="grpc server port")
 parser.add_argument("--chunk_size",
                     type=str,
-                    default="0, 10, 5",
+                    default="5, 10, 5",
                     help="chunk")
-parser.add_argument("--encoder_chunk_look_back",
-                    type=int,
-                    default=4,
-                    help="number of chunks to lookback for encoder self-attention")
-parser.add_argument("--decoder_chunk_look_back",
-                    type=int,
-                    default=1,
-                    help="number of encoder chunks to lookback for decoder cross-attention")
 parser.add_argument("--chunk_interval",
                     type=int,
                     default=10,
                     help="chunk")
+parser.add_argument("--hotword",
+                    type=str,
+                    default="",
+                    help="hotword file path, one hotword perline (e.g.:闃块噷宸村反 20)")
 parser.add_argument("--audio_in",
                     type=str,
                     default=None,
@@ -61,11 +57,14 @@
                     type=str,
                     default=None,
                     help="output_dir")
-
 parser.add_argument("--ssl",
                     type=int,
                     default=1,
                     help="1 for ssl connect, 0 for no ssl")
+parser.add_argument("--use_itn",
+                    type=int,
+                    default=1,
+                    help="1 for using itn, 0 for not itn")
 parser.add_argument("--mode",
                     type=str,
                     default="2pass",
@@ -106,10 +105,29 @@
                     rate=RATE,
                     input=True,
                     frames_per_buffer=CHUNK)
+    # hotwords
+    fst_dict = {}
+    hotword_msg = ""
+    if args.hotword.strip() != "":
+        f_scp = open(args.hotword)
+        hot_lines = f_scp.readlines()
+        for line in hot_lines:
+            words = line.strip().split(" ")
+            if len(words) < 2:
+                print("Please checkout format of hotwords")
+                continue
+            try:
+                fst_dict[" ".join(words[:-1])] = int(words[-1])
+            except ValueError:
+                print("Please checkout format of hotwords")
+        hotword_msg=json.dumps(fst_dict)
 
-    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "encoder_chunk_look_back": args.encoder_chunk_look_back,
-                          "decoder_chunk_look_back": args.decoder_chunk_look_back, "chunk_interval": args.chunk_interval, 
-                          "wav_name": "microphone", "is_speaking": True})
+    use_itn=True
+    if args.use_itn == 0:
+        use_itn=False
+    
+    message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
+                          "wav_name": "microphone", "is_speaking": True, "hotwords":hotword_msg, "itn": use_itn})
     #voices.put(message)
     await websocket.send(message)
     while True:
@@ -127,6 +145,31 @@
         wavs = f_scp.readlines()
     else:
         wavs = [args.audio_in]
+
+    # hotwords
+    fst_dict = {}
+    hotword_msg = ""
+    if args.hotword.strip() != "":
+        f_scp = open(args.hotword)
+        hot_lines = f_scp.readlines()
+        for line in hot_lines:
+            words = line.strip().split(" ")
+            if len(words) < 2:
+                print("Please checkout format of hotwords")
+                continue
+            try:
+                fst_dict[" ".join(words[:-1])] = int(words[-1])
+            except ValueError:
+                print("Please checkout format of hotwords")
+        hotword_msg=json.dumps(fst_dict)
+        print (hotword_msg)
+
+    sample_rate = 16000
+    wav_format = "pcm"
+    use_itn=True
+    if args.use_itn == 0:
+        use_itn=False
+     
     if chunk_size > 0:
         wavs = wavs[chunk_begin:chunk_begin + chunk_size]
     for wav in wavs:
@@ -143,20 +186,13 @@
             import wave
             with wave.open(wav_path, "rb") as wav_file:
                 params = wav_file.getparams()
+                sample_rate = wav_file.getframerate()
                 frames = wav_file.readframes(wav_file.getnframes())
                 audio_bytes = bytes(frames)
         else:
-            import ffmpeg
-            try:
-                # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
-                # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
-                audio_bytes, _ = (
-                    ffmpeg.input(wav_path, threads=0)
-                    .output("-", format="s16le", acodec="pcm_s16le", ac=1, ar=16000)
-                    .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
-                )
-            except ffmpeg.Error as e:
-                raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+            wav_format = "others"
+            with open(wav_path, "rb") as f:
+                audio_bytes = f.read()
 
         # stride = int(args.chunk_size/1000*16000*2)
         stride = int(60 * args.chunk_size[1] / args.chunk_interval / 1000 * 16000 * 2)
@@ -164,8 +200,9 @@
         # print(stride)
 
         # send first time
-        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval,
-                              "wav_name": wav_name, "is_speaking": True})
+        message = json.dumps({"mode": args.mode, "chunk_size": args.chunk_size, "chunk_interval": args.chunk_interval, "audio_fs":sample_rate,
+                          "wav_name": wav_name, "wav_format": wav_format, "is_speaking": True, "hotwords":hotword_msg, "itn": use_itn})
+
         #voices.put(message)
         await websocket.send(message)
         is_speaking = True
@@ -213,12 +250,17 @@
         
             meg = await websocket.recv()
             meg = json.loads(meg)
-            # print(meg)
             wav_name = meg.get("wav_name", "demo")
             text = meg["text"]
+            timestamp=""
+            if "timestamp" in meg:
+                timestamp = meg["timestamp"]
 
             if ibest_writer is not None:
-                text_write_line = "{}\t{}\n".format(wav_name, text)
+                if timestamp !="":
+                    text_write_line = "{}\t{}\t{}\n".format(wav_name, text, timestamp)
+                else:
+                    text_write_line = "{}\t{}\n".format(wav_name, text)
                 ibest_writer.write(text_write_line)
                 
             if meg["mode"] == "online":
@@ -227,15 +269,15 @@
                 os.system('clear')
                 print("\rpid" + str(id) + ": " + text_print)
             elif meg["mode"] == "offline":
-                text_print += "{}".format(text)
+                if timestamp !="":
+                    text_print += "{} timestamp: {}".format(text, timestamp)
+                else:
+                    text_print += "{}".format(text)
+
                 # text_print = text_print[-args.words_max_print:]
                 # os.system('clear')
                 print("\rpid" + str(id) + ": " + wav_name + ": " + text_print)
-                if ("is_final" in meg and meg["is_final"]==False):
-                    offline_msg_done = True
-                
-                if not "is_final" in meg:
-                    offline_msg_done = True
+                offline_msg_done = True
             else:
                 if meg["mode"] == "2pass-online":
                     text_print_2pass_online += "{}".format(text)

--
Gitblit v1.9.1