From b8dd1e310b631cba7f8124ff30783ea60026b724 Mon Sep 17 00:00:00 2001
From: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Date: 星期三, 24 五月 2023 15:43:48 +0800
Subject: [PATCH] add ssl for ws_clinet and 2pass offline srv (#546)

---
 funasr/runtime/python/websocket/ws_server_offline.py |   14 +++++-
 funasr/runtime/python/websocket/README.md            |   19 ++++++---
 funasr/runtime/python/websocket/ws_client.py         |   19 ++++++++-
 funasr/runtime/python/websocket/ws_server_2pass.py   |   13 +++++-
 4 files changed, 51 insertions(+), 14 deletions(-)

diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md
index 5d1639d..f7c3cc7 100644
--- a/funasr/runtime/python/websocket/README.md
+++ b/funasr/runtime/python/websocket/README.md
@@ -29,11 +29,13 @@
 --asr_model [asr model_name] \
 --punc_model [punc model_name] \
 --ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] 
 ```
 ##### Usage examples
 ```shell
-python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" --certfile ./server.crt --keyfile ./server.key
 ```
 
 #### ASR streaming server
@@ -43,11 +45,13 @@
 --port [port id] \
 --asr_model_online [asr model_name] \
 --ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] 
 ```
 ##### Usage examples
 ```shell
-python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online" --certfile ./server.crt --keyfile ./server.key
 ```
 
 #### ASR offline/online 2pass server
@@ -59,7 +63,9 @@
 --asr_model_online [asr model_name] \
 --punc_model [punc model_name] \
 --ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl] 
 ```
 ##### Usage examples
 ```shell
@@ -86,7 +92,8 @@
 --words_max_print [max number of words to print] \
 --audio_in [if set, loadding from wav.scp, else recording from mircrophone] \
 --output_dir [if set, write the results to output_dir] \
---send_without_sleep [only set for offline]
+--send_without_sleep [only set for offline] \
+--ssl [1 for wss connect, 0 for ws, default is 1]
 ```
 #### Usage examples
 ##### ASR offline client
diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
index de5a1d8..f7dfcaf 100644
--- a/funasr/runtime/python/websocket/ws_client.py
+++ b/funasr/runtime/python/websocket/ws_client.py
@@ -1,7 +1,7 @@
 # -*- encoding: utf-8 -*-
 import os
 import time
-import websockets
+import websockets,ssl
 import asyncio
 # import threading
 import argparse
@@ -53,6 +53,11 @@
                     type=str,
                     default=None,
                     help="output_dir")
+                    
+parser.add_argument("--ssl",
+                    type=int,
+                    default=1,
+                    help="1 for ssl connect, 0 for no ssl")
 
 args = parser.parse_args()
 args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
@@ -221,8 +226,16 @@
 
 async def ws_client(id,chunk_begin,chunk_size):
     global websocket
-    uri = "ws://{}:{}".format(args.host, args.port)
-    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
+    if  args.ssl==1:
+       ssl_context = ssl.SSLContext()
+       ssl_context.check_hostname = False
+       ssl_context.verify_mode = ssl.CERT_NONE
+       uri = "wss://{}:{}".format(args.host, args.port)
+    else:
+       uri = "ws://{}:{}".format(args.host, args.port)
+       ssl_context=None
+    print("connect to",uri)
+    async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
         if args.audio_in is not None:
             task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
         else:
diff --git a/funasr/runtime/python/websocket/ws_server_2pass.py b/funasr/runtime/python/websocket/ws_server_2pass.py
index e5cab9c..df13ad9 100644
--- a/funasr/runtime/python/websocket/ws_server_2pass.py
+++ b/funasr/runtime/python/websocket/ws_server_2pass.py
@@ -5,7 +5,7 @@
 import logging
 import tracemalloc
 import numpy as np
-
+import ssl
 from parse_args import args
 from modelscope.pipelines import pipeline
 from modelscope.utils.constant import Tasks
@@ -191,7 +191,16 @@
                 message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
                 await websocket.send(message)
 
+if len(args.certfile)>0:
+	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+	
+	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+	ssl_cert = args.certfile
+	ssl_key = args.keyfile
 
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 asyncio.get_event_loop().run_until_complete(start_server)
 asyncio.get_event_loop().run_forever()
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_server_offline.py b/funasr/runtime/python/websocket/ws_server_offline.py
index 1fcc246..1ea1ff7 100644
--- a/funasr/runtime/python/websocket/ws_server_offline.py
+++ b/funasr/runtime/python/websocket/ws_server_offline.py
@@ -5,6 +5,7 @@
 import logging
 import tracemalloc
 import numpy as np
+import ssl
 
 from parse_args import args
 from modelscope.pipelines import pipeline
@@ -147,9 +148,16 @@
                 await websocket.send(message)
                 
                 
- 
+if len(args.certfile)>0:
+	ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+	
+	# Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+	ssl_cert = args.certfile
+	ssl_key = args.keyfile
 
-
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+	ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+	start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
 asyncio.get_event_loop().run_until_complete(start_server)
 asyncio.get_event_loop().run_forever()
\ No newline at end of file

--
Gitblit v1.9.1