From b8dd1e310b631cba7f8124ff30783ea60026b724 Mon Sep 17 00:00:00 2001
From: zhaomingwork <61895407+zhaomingwork@users.noreply.github.com>
Date: 星期三, 24 五月 2023 15:43:48 +0800
Subject: [PATCH] add ssl for ws_clinet and 2pass offline srv (#546)
---
funasr/runtime/python/websocket/ws_server_offline.py | 14 +++++-
funasr/runtime/python/websocket/README.md | 19 ++++++---
funasr/runtime/python/websocket/ws_client.py | 19 ++++++++-
funasr/runtime/python/websocket/ws_server_2pass.py | 13 +++++-
4 files changed, 51 insertions(+), 14 deletions(-)
diff --git a/funasr/runtime/python/websocket/README.md b/funasr/runtime/python/websocket/README.md
index 5d1639d..f7c3cc7 100644
--- a/funasr/runtime/python/websocket/README.md
+++ b/funasr/runtime/python/websocket/README.md
@@ -29,11 +29,13 @@
--asr_model [asr model_name] \
--punc_model [punc model_name] \
--ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl]
```
##### Usage examples
```shell
-python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch"
+python ws_server_offline.py --port 10095 --asr_model "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" --certfile ./server.crt --keyfile ./server.key
```
#### ASR streaming server
@@ -43,11 +45,13 @@
--port [port id] \
--asr_model_online [asr model_name] \
--ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl]
```
##### Usage examples
```shell
-python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+python ws_server_online.py --port 10095 --asr_model_online "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online" --certfile ./server.crt --keyfile ./server.key
```
#### ASR offline/online 2pass server
@@ -59,7 +63,9 @@
--asr_model_online [asr model_name] \
--punc_model [punc model_name] \
--ngpu [0 or 1] \
---ncpu [1 or 4]
+--ncpu [1 or 4] \
+--certfile [path of certfile for ssl] \
+--keyfile [path of keyfile for ssl]
```
##### Usage examples
```shell
@@ -86,7 +92,8 @@
--words_max_print [max number of words to print] \
--audio_in [if set, loadding from wav.scp, else recording from mircrophone] \
--output_dir [if set, write the results to output_dir] \
---send_without_sleep [only set for offline]
+--send_without_sleep [only set for offline] \
+--ssl [1 for wss connect, 0 for ws, default is 1]
```
#### Usage examples
##### ASR offline client
diff --git a/funasr/runtime/python/websocket/ws_client.py b/funasr/runtime/python/websocket/ws_client.py
index de5a1d8..f7dfcaf 100644
--- a/funasr/runtime/python/websocket/ws_client.py
+++ b/funasr/runtime/python/websocket/ws_client.py
@@ -1,7 +1,7 @@
# -*- encoding: utf-8 -*-
import os
import time
-import websockets
+import websockets,ssl
import asyncio
# import threading
import argparse
@@ -53,6 +53,11 @@
type=str,
default=None,
help="output_dir")
+
+parser.add_argument("--ssl",
+ type=int,
+ default=1,
+ help="1 for ssl connect, 0 for no ssl")
args = parser.parse_args()
args.chunk_size = [int(x) for x in args.chunk_size.split(",")]
@@ -221,8 +226,16 @@
async def ws_client(id,chunk_begin,chunk_size):
global websocket
- uri = "ws://{}:{}".format(args.host, args.port)
- async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None):
+ if args.ssl==1:
+ ssl_context = ssl.SSLContext()
+ ssl_context.check_hostname = False
+ ssl_context.verify_mode = ssl.CERT_NONE
+ uri = "wss://{}:{}".format(args.host, args.port)
+ else:
+ uri = "ws://{}:{}".format(args.host, args.port)
+ ssl_context=None
+ print("connect to",uri)
+ async for websocket in websockets.connect(uri, subprotocols=["binary"], ping_interval=None,ssl=ssl_context):
if args.audio_in is not None:
task = asyncio.create_task(record_from_scp(chunk_begin,chunk_size))
else:
diff --git a/funasr/runtime/python/websocket/ws_server_2pass.py b/funasr/runtime/python/websocket/ws_server_2pass.py
index e5cab9c..df13ad9 100644
--- a/funasr/runtime/python/websocket/ws_server_2pass.py
+++ b/funasr/runtime/python/websocket/ws_server_2pass.py
@@ -5,7 +5,7 @@
import logging
import tracemalloc
import numpy as np
-
+import ssl
from parse_args import args
from modelscope.pipelines import pipeline
from modelscope.utils.constant import Tasks
@@ -191,7 +191,16 @@
message = json.dumps({"mode": "2pass-online", "text": rec_result["text"], "wav_name": websocket.wav_name})
await websocket.send(message)
+if len(args.certfile)>0:
+ ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+
+ # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+ ssl_cert = args.certfile
+ ssl_key = args.keyfile
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+ ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+ start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+ start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
\ No newline at end of file
diff --git a/funasr/runtime/python/websocket/ws_server_offline.py b/funasr/runtime/python/websocket/ws_server_offline.py
index 1fcc246..1ea1ff7 100644
--- a/funasr/runtime/python/websocket/ws_server_offline.py
+++ b/funasr/runtime/python/websocket/ws_server_offline.py
@@ -5,6 +5,7 @@
import logging
import tracemalloc
import numpy as np
+import ssl
from parse_args import args
from modelscope.pipelines import pipeline
@@ -147,9 +148,16 @@
await websocket.send(message)
-
+if len(args.certfile)>0:
+ ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
+
+ # Generate with Lets Encrypt, copied to this location, chown to current user and 400 permissions
+ ssl_cert = args.certfile
+ ssl_key = args.keyfile
-
-start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
+ ssl_context.load_cert_chain(ssl_cert, keyfile=ssl_key)
+ start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None,ssl=ssl_context)
+else:
+ start_server = websockets.serve(ws_serve, args.host, args.port, subprotocols=["binary"], ping_interval=None)
asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()
\ No newline at end of file
--
Gitblit v1.9.1