From 7375292887752a995cef6f77f3d837db67c981a4 Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Fri, 16 Jun 2023 10:09:39 +0800
Subject: [PATCH] Merge branch 'main' into dev_wjm_infer
---
docs/images/dingding.jpg | 0
funasr/runtime/websocket/CMakeLists.txt | 8 +-
funasr/runtime/websocket/funasr-wss-server.cpp | 39 +++++----
funasr/runtime/websocket/funasr-wss-client.cpp | 40 ++++++---
egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py | 2
funasr/runtime/websocket/websocket-server.cpp | 28 +++---
funasr/runtime/python/websocket/wss_client_asr.py | 23 ++++-
funasr/runtime/websocket/readme.md | 27 +++---
tests/test_sv_inference_pipeline.py | 7 -
funasr/runtime/onnxruntime/src/offline-stream.cpp | 26 +++++-
tests/test_asr_inference_pipeline.py | 6 +
tests/test_asr_vad_punc_inference_pipeline.py | 1
12 files changed, 123 insertions(+), 84 deletions(-)
diff --git a/docs/images/dingding.jpg b/docs/images/dingding.jpg
index 6ac3ab8..9c9166c 100644
--- a/docs/images/dingding.jpg
+++ b/docs/images/dingding.jpg
Binary files differ
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
index dc867b0..aa0db93 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch/infer.py
@@ -17,7 +17,7 @@
diar_model_config="sond.yaml",
model='damo/speech_diarization_sond-zh-cn-alimeeting-16k-n16k4-pytorch',
sv_model="damo/speech_xvector_sv-zh-cn-cnceleb-16k-spk3465-pytorch",
- sv_model_revision="master",
+ sv_model_revision="v1.2.2",
)
# use audio_list as the input, where the first one is the record to be detected
diff --git a/funasr/runtime/onnxruntime/src/offline-stream.cpp b/funasr/runtime/onnxruntime/src/offline-stream.cpp
index 8170129..d96cf27 100644
--- a/funasr/runtime/onnxruntime/src/offline-stream.cpp
+++ b/funasr/runtime/onnxruntime/src/offline-stream.cpp
@@ -1,11 +1,11 @@
#include "precomp.h"
+#include <unistd.h>
namespace funasr {
OfflineStream::OfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
{
// VAD model
if(model_path.find(VAD_DIR) != model_path.end()){
- use_vad = true;
string vad_model_path;
string vad_cmvn_path;
string vad_config_path;
@@ -16,8 +16,16 @@
}
vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
- vad_handle = make_unique<FsmnVad>();
- vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+ if (access(vad_model_path.c_str(), F_OK) != 0 ||
+ access(vad_cmvn_path.c_str(), F_OK) != 0 ||
+ access(vad_config_path.c_str(), F_OK) != 0 )
+ {
+ LOG(INFO) << "VAD model file is not exist, skip load vad model.";
+ }else{
+ vad_handle = make_unique<FsmnVad>();
+ vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+ use_vad = true;
+ }
}
// AM model
@@ -39,7 +47,6 @@
// PUNC model
if(model_path.find(PUNC_DIR) != model_path.end()){
- use_punc = true;
string punc_model_path;
string punc_config_path;
@@ -49,8 +56,15 @@
}
punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
- punc_handle = make_unique<CTTransformer>();
- punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ if (access(punc_model_path.c_str(), F_OK) != 0 ||
+ access(punc_config_path.c_str(), F_OK) != 0 )
+ {
+ LOG(INFO) << "PUNC model file is not exist, skip load punc model.";
+ }else{
+ punc_handle = make_unique<CTTransformer>();
+ punc_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+ use_punc = true;
+ }
}
}
diff --git a/funasr/runtime/python/websocket/wss_client_asr.py b/funasr/runtime/python/websocket/wss_client_asr.py
index bd9e89f..2ea8a16 100644
--- a/funasr/runtime/python/websocket/wss_client_asr.py
+++ b/funasr/runtime/python/websocket/wss_client_asr.py
@@ -71,6 +71,8 @@
from queue import Queue
voices = Queue()
+offline_msg_done=False
+
ibest_writer = None
if args.output_dir is not None:
writer = DatadirWriter(args.output_dir)
@@ -158,13 +160,20 @@
message = json.dumps({"is_speaking": is_speaking})
#voices.put(message)
await websocket.send(message)
- # print("data_chunk: ", len(data_chunk))
- # print(voices.qsize())
+
sleep_duration = 0.001 if args.send_without_sleep else 60 * args.chunk_size[1] / args.chunk_interval / 1000
await asyncio.sleep(sleep_duration)
+ # when all data sent, we need to close websocket
while not voices.empty():
await asyncio.sleep(1)
await asyncio.sleep(3)
+ # offline mode must wait until the result message has been received
+
+ if args.mode=="offline":
+ global offline_msg_done
+ while not offline_msg_done:
+ await asyncio.sleep(1)
+
await websocket.close()
@@ -173,7 +182,7 @@
async def message(id):
- global websocket,voices
+ global websocket,voices,offline_msg_done
text_print = ""
text_print_2pass_online = ""
text_print_2pass_offline = ""
@@ -183,7 +192,6 @@
meg = await websocket.recv()
meg = json.loads(meg)
wav_name = meg.get("wav_name", "demo")
- # print(wav_name)
text = meg["text"]
if ibest_writer is not None:
ibest_writer["text"][wav_name] = text
@@ -198,6 +206,7 @@
text_print = text_print[-args.words_max_print:]
os.system('clear')
print("\rpid" + str(id) + ": " + text_print)
+ offline_msg_done=True
else:
if meg["mode"] == "2pass-online":
text_print_2pass_online += "{}".format(text)
@@ -233,8 +242,10 @@
if args.audio_in is None:
chunk_begin=0
chunk_size=1
- global websocket,voices
+ global websocket,voices,offline_msg_done
+
for i in range(chunk_begin,chunk_begin+chunk_size):
+ offline_msg_done=False
voices = Queue()
if args.ssl == 1:
ssl_context = ssl.SSLContext()
@@ -251,7 +262,7 @@
else:
task = asyncio.create_task(record_microphone())
#task2 = asyncio.create_task(ws_send())
- task3 = asyncio.create_task(message(id))
+ task3 = asyncio.create_task(message(str(id)+"_"+str(i))) #processid+fileid
await asyncio.gather(task, task3)
exit(0)
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index c1715d8..513e48d 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -56,8 +56,8 @@
# install openssl first apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
-add_executable(funasr-ws-server "funasr-ws-server.cpp" "websocket-server.cpp")
-add_executable(funasr-ws-client "funasr-ws-client.cpp")
+add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
+add_executable(funasr-wss-client "funasr-wss-client.cpp")
-target_link_libraries(funasr-ws-client PUBLIC funasr ssl crypto)
-target_link_libraries(funasr-ws-server PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
diff --git a/funasr/runtime/websocket/funasr-ws-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
similarity index 92%
rename from funasr/runtime/websocket/funasr-ws-client.cpp
rename to funasr/runtime/websocket/funasr-wss-client.cpp
index 23c68cc..4a3c751 100644
--- a/funasr/runtime/websocket/funasr-ws-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -5,7 +5,14 @@
/* 2022-2023 by zhaomingwork */
// client for websocket, support multiple threads
-// Usage: websocketclient server_ip port wav_path threads_num
+// ./funasr-wss-client --server-ip <string>
+// --port <string>
+// --wav-path <string>
+// [--thread-num <int>]
+// [--is-ssl <int>] [--]
+// [--version] [-h]
+// example:
+// ./funasr-wss-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
#define ASIO_STANDALONE 1
#include <websocketpp/client.hpp>
@@ -55,7 +62,7 @@
asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);
} catch (std::exception& e) {
- std::cout << e.what() << std::endl;
+ LOG(ERROR) << e.what();
}
return ctx;
}
@@ -99,7 +106,16 @@
const std::string& payload = msg->get_payload();
switch (msg->get_opcode()) {
case websocketpp::frame::opcode::text:
- std::cout << "on_message = " << payload << std::endl;
+ total_num=total_num+1;
+ LOG(INFO)<<total_num<<",on_message = " << payload;
+ if((total_num+1)==wav_index)
+ {
+ websocketpp::lib::error_code ec;
+ m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
+ if (ec){
+ LOG(ERROR)<< "Error closing connection " << ec.message();
+ }
+ }
}
}
@@ -132,12 +148,8 @@
}
send_wav_data(wav_list[i], wav_ids[i]);
}
- WaitABit();
- m_client.close(m_hdl,websocketpp::close::status::going_away, "", ec);
- if (ec) {
- std::cout << "> Error closing connection " << ec.message() << std::endl;
- }
- //send_wav_data();
+ WaitABit();
+
asio_thread.join();
}
@@ -206,7 +218,7 @@
}
}
if (wait) {
- std::cout << "wait.." << m_open << std::endl;
+ LOG(INFO) << "wait.." << m_open;
WaitABit();
continue;
}
@@ -236,7 +248,7 @@
// send data to server
m_client.send(m_hdl, iArray, len * sizeof(short),
websocketpp::frame::opcode::binary, ec);
- std::cout << "sended data len=" << len * sizeof(short) << std::endl;
+ LOG(INFO) << "sended data len=" << len * sizeof(short);
// The most likely error that we will get is that the connection is
// not in the right state. Usually this means we tried to send a
// message to a connection that was closed or in the process of
@@ -247,14 +259,13 @@
"Send Error: " + ec.message());
break;
}
-
- WaitABit();
+ // WaitABit();
}
nlohmann::json jsonresult;
jsonresult["is_speaking"] = false;
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
ec);
- WaitABit();
+ // WaitABit();
}
websocketpp::client<T> m_client;
@@ -263,6 +274,7 @@
websocketpp::lib::mutex m_lock;
bool m_open;
bool m_done;
+ int total_num=0;
};
int main(int argc, char* argv[]) {
diff --git a/funasr/runtime/websocket/funasr-ws-server.cpp b/funasr/runtime/websocket/funasr-wss-server.cpp
similarity index 80%
rename from funasr/runtime/websocket/funasr-ws-server.cpp
rename to funasr/runtime/websocket/funasr-wss-server.cpp
index 67a0f2d..5f2af5c 100644
--- a/funasr/runtime/websocket/funasr-ws-server.cpp
+++ b/funasr/runtime/websocket/funasr-wss-server.cpp
@@ -5,7 +5,7 @@
/* 2022-2023 by zhaomingwork */
// io server
-// Usage:websocketmain [--model_thread_num <int>] [--decoder_thread_num <int>]
+// Usage: funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num <int>]
// [--io_thread_num <int>] [--port <int>] [--listen_ip
// <string>] [--punc-quant <string>] [--punc-dir <string>]
// [--vad-quant <string>] [--vad-dir <string>] [--quantize
@@ -15,44 +15,43 @@
using namespace std;
void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
std::map<std::string, std::string>& model_path) {
- if (value_arg.isSet()) {
model_path.insert({key, value_arg.getValue()});
LOG(INFO) << key << " : " << value_arg.getValue();
- }
}
int main(int argc, char* argv[]) {
try {
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
- TCLAP::CmdLine cmd("websocketmain", ' ', "1.0");
+ TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0");
TCLAP::ValueArg<std::string> model_dir(
"", MODEL_DIR,
- "the asr model path, which contains model.onnx, config.yaml, am.mvn",
- true, "", "string");
+ "default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn",
+ false, "/workspace/models/asr", "string");
TCLAP::ValueArg<std::string> quantize(
"", QUANTIZE,
- "false (Default), load the model of model.onnx in model_dir. If set "
+ "true (Default), load the model of model.onnx in model_dir. If set "
"true, load the model of model_quant.onnx in model_dir",
- false, "false", "string");
+ false, "true", "string");
TCLAP::ValueArg<std::string> vad_dir(
"", VAD_DIR,
- "the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
- false, "", "string");
+ "default: /workspace/models/vad, the vad model path, which contains model.onnx, vad.yaml, vad.mvn",
+ false, "/workspace/models/vad", "string");
TCLAP::ValueArg<std::string> vad_quant(
"", VAD_QUANT,
- "false (Default), load the model of model.onnx in vad_dir. If set "
+ "true (Default), load the model of model.onnx in vad_dir. If set "
"true, load the model of model_quant.onnx in vad_dir",
- false, "false", "string");
+ false, "true", "string");
TCLAP::ValueArg<std::string> punc_dir(
"", PUNC_DIR,
- "the punc model path, which contains model.onnx, punc.yaml", false, "",
+ "default: /workspace/models/punc, the punc model path, which contains model.onnx, punc.yaml",
+ false, "/workspace/models/punc",
"string");
TCLAP::ValueArg<std::string> punc_quant(
"", PUNC_QUANT,
- "false (Default), load the model of model.onnx in punc_dir. If set "
+ "true (Default), load the model of model.onnx in punc_dir. If set "
"true, load the model of model_quant.onnx in punc_dir",
- false, "false", "string");
+ false, "true", "string");
TCLAP::ValueArg<std::string> listen_ip("", "listen_ip", "listen_ip", false,
"0.0.0.0", "string");
@@ -64,10 +63,12 @@
TCLAP::ValueArg<int> model_thread_num("", "model_thread_num",
"model_thread_num", false, 1, "int");
- TCLAP::ValueArg<std::string> certfile("", "certfile", "certfile", false, "",
- "string");
- TCLAP::ValueArg<std::string> keyfile("", "keyfile", "keyfile", false, "",
- "string");
+ TCLAP::ValueArg<std::string> certfile("", "certfile",
+ "default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.",
+ false, "../../../ssl_key/server.crt", "string");
+ TCLAP::ValueArg<std::string> keyfile("", "keyfile",
+ "default: ../../../ssl_key/server.key, path of keyfile for WSS connection",
+ false, "../../../ssl_key/server.key", "string");
cmd.add(certfile);
cmd.add(keyfile);
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
index 4a1a9d4..0cebe64 100644
--- a/funasr/runtime/websocket/readme.md
+++ b/funasr/runtime/websocket/readme.md
@@ -51,7 +51,7 @@
```shell
cd bin
- ./funasr-ws-server [--model_thread_num <int>] [--decoder_thread_num <int>]
+./funasr-wss-server [--model_thread_num <int>] [--decoder_thread_num <int>]
[--io_thread_num <int>] [--port <int>] [--listen_ip
<string>] [--punc-quant <string>] [--punc-dir <string>]
[--vad-quant <string>] [--vad-dir <string>] [--quantize
@@ -59,19 +59,19 @@
[--certfile <string>] [--] [--version] [-h]
Where:
--model-dir <string>
- (required) the asr model path, which contains model.onnx, config.yaml, am.mvn
+ default: /workspace/models/asr, the asr model path, which contains model.onnx, config.yaml, am.mvn
--quantize <string>
- false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+ true (Default), load the model of model_quant.onnx in model_dir. If set false, load the model of model.onnx in model_dir
--vad-dir <string>
- the vad model path, which contains model.onnx, vad.yaml, vad.mvn
+ default: /workspace/models/vad, the vad model path, which contains model.onnx, vad.yaml, vad.mvn
--vad-quant <string>
- false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+ true (Default), load the model of model_quant.onnx in vad_dir. If set false, load the model of model.onnx in vad_dir
--punc-dir <string>
- the punc model path, which contains model.onnx, punc.yaml
+ default: /workspace/models/punc, the punc model path, which contains model.onnx, punc.yaml
--punc-quant <string>
- false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
+ true (Default), load the model of model_quant.onnx in punc_dir. If set false, load the model of model.onnx in punc_dir
--decoder_thread_num <int>
number of threads for decoder, default:8
@@ -80,21 +80,18 @@
--port <int>
listen port, default:8889
--certfile <string>
- path of certficate for WSS connection. if it is empty, it will be in WS mode.
+ default: ../../../ssl_key/server.crt, path of certificate for WSS connection. if it is empty, it will be in WS mode.
--keyfile <string>
- path of keyfile for WSS connection
+ default: ../../../ssl_key/server.key, path of keyfile for WSS connection
- Required: --model-dir <string>
- If use vad, please add: --vad-dir <string>
- If use punc, please add: --punc-dir <string>
example:
- funasr-ws-server --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+./funasr-wss-server --model-dir /FunASR/funasr/runtime/onnxruntime/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
```
## Run websocket client test
```shell
-./funasr-ws-client --server-ip <string>
+./funasr-wss-client --server-ip <string>
--port <string>
--wav-path <string>
[--thread-num <int>]
@@ -119,7 +116,7 @@
is-ssl is 1 means use wss connection, or use ws connection
example:
-./funasr-ws-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 0
+./funasr-wss-client --server-ip 127.0.0.1 --port 8889 --wav-path test.wav --thread-num 1 --is-ssl 1
result json, example like:
{"mode":"offline","text":"欢迎大家来体验达摩院推出的语音识别模型","wav_name":"wav2"}
diff --git a/funasr/runtime/websocket/websocket-server.cpp b/funasr/runtime/websocket/websocket-server.cpp
index 91b6c9d..a311c23 100644
--- a/funasr/runtime/websocket/websocket-server.cpp
+++ b/funasr/runtime/websocket/websocket-server.cpp
@@ -22,12 +22,11 @@
std::string& s_keyfile) {
namespace asio = websocketpp::lib::asio;
- std::cout << "on_tls_init called with hdl: " << hdl.lock().get() << std::endl;
- std::cout << "using TLS mode: "
+ LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
+ LOG(INFO) << "using TLS mode: "
<< (mode == MOZILLA_MODERN ? "Mozilla Modern"
- : "Mozilla Intermediate")
- << std::endl;
-
+ : "Mozilla Intermediate");
+
context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
asio::ssl::context::sslv23);
@@ -49,7 +48,7 @@
ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
} catch (std::exception& e) {
- std::cout << "Exception: " << e.what() << std::endl;
+ LOG(INFO) << "Exception: " << e.what();
}
return ctx;
}
@@ -86,8 +85,7 @@
ec);
}
- std::cout << "buffer.size=" << buffer.size()
- << ",result json=" << jsonresult.dump() << std::endl;
+ LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
if (!isonline) {
// close the client if it is not online asr
// server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
@@ -110,14 +108,14 @@
data_msg->samples = std::make_shared<std::vector<char>>();
data_msg->msg = nlohmann::json::parse("{}");
data_map.emplace(hdl, data_msg);
- std::cout << "on_open, active connections: " << data_map.size() << std::endl;
+ LOG(INFO) << "on_open, active connections: " << data_map.size();
}
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
scoped_lock guard(m_lock);
data_map.erase(hdl); // remove data vector when connection is closed
- std::cout << "on_close, active connections: " << data_map.size() << std::endl;
+ LOG(INFO) << "on_close, active connections: " << data_map.size();
}
// remove closed connection
@@ -143,7 +141,7 @@
}
for (auto hdl : to_remove) {
data_map.erase(hdl);
- std::cout << "remove one connection " << std::endl;
+ LOG(INFO)<< "remove one connection ";
}
}
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
@@ -161,7 +159,7 @@
lock.unlock();
if (sample_data_p == nullptr) {
- std::cout << "error when fetch sample data vector" << std::endl;
+ LOG(INFO) << "error when fetch sample data vector";
return;
}
@@ -176,7 +174,7 @@
if (jsonresult["is_speaking"] == false ||
jsonresult["is_finished"] == true) {
- std::cout << "client done" << std::endl;
+ LOG(INFO) << "client done";
if (isonline) {
// do_close(ws);
@@ -225,9 +223,9 @@
// init model with api
asr_hanlde = FunOfflineInit(model_path, thread_num);
- std::cout << "model ready" << std::endl;
+ LOG(INFO) << "model successfully inited";
} catch (const std::exception& e) {
- std::cout << e.what() << std::endl;
+ LOG(INFO) << e.what();
}
}
diff --git a/tests/test_asr_inference_pipeline.py b/tests/test_asr_inference_pipeline.py
index 9098ea6..2b21acf 100644
--- a/tests/test_asr_inference_pipeline.py
+++ b/tests/test_asr_inference_pipeline.py
@@ -87,6 +87,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_hotword.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "鍥藉姟闄㈠彂灞曠爺绌朵腑蹇冨競鍦虹粡娴庣爺绌舵墍鍓墍闀块倱閮佹澗璁や负"
def test_paraformer_large_aishell1(self):
inference_pipeline = pipeline(
@@ -95,6 +96,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�"
def test_paraformer_large_aishell2(self):
inference_pipeline = pipeline(
@@ -103,6 +105,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�"
def test_paraformer_large_common(self):
inference_pipeline = pipeline(
@@ -111,6 +114,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨�"
def test_paraformer_large_online_common(self):
inference_pipeline = pipeline(
@@ -119,6 +123,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶� 瀹舵潵 浣撻獙杈� 鎽╅櫌鎺� 鍑虹殑 璇煶璇� 鍒ā 鍨�"
def test_paraformer_online_common(self):
inference_pipeline = pipeline(
@@ -127,6 +132,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋 澶у鏉� 浣撻獙杈� 鎽╅櫌鎺� 鍑虹殑 璇煶璇� 鍒ā 鍨�"
def test_paraformer_tiny_commandword(self):
inference_pipeline = pipeline(
diff --git a/tests/test_asr_vad_punc_inference_pipeline.py b/tests/test_asr_vad_punc_inference_pipeline.py
index 628b256..f86f23d 100644
--- a/tests/test_asr_vad_punc_inference_pipeline.py
+++ b/tests/test_asr_vad_punc_inference_pipeline.py
@@ -26,6 +26,7 @@
rec_result = inference_pipeline(
audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
logger.info("asr_vad_punc inference result: {0}".format(rec_result))
+ assert rec_result["text"] == "娆㈣繋澶у鏉ヤ綋楠岃揪鎽╅櫌鎺ㄥ嚭鐨勮闊宠瘑鍒ā鍨嬨��"
if __name__ == '__main__':
diff --git a/tests/test_sv_inference_pipeline.py b/tests/test_sv_inference_pipeline.py
index 09139b9..c4e427e 100644
--- a/tests/test_sv_inference_pipeline.py
+++ b/tests/test_sv_inference_pipeline.py
@@ -24,16 +24,15 @@
rec_result = inference_sv_pipline(audio_in=(
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_same.wav'))
- assert abs(rec_result["scores"][0] - 0.85) < 0.1 and abs(rec_result["scores"][1] - 0.14) < 0.1
+ assert abs(rec_result["scores"][0]-0.85) < 0.1 and abs(rec_result["scores"][1]-0.14) < 0.1
logger.info(f"Similarity {rec_result['scores']}")
-
+
# different speaker
rec_result = inference_sv_pipline(audio_in=(
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_enroll.wav',
'https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/sv_example_different.wav'))
- assert abs(rec_result["scores"][0] - 0.0) < 0.1 and abs(rec_result["scores"][1] - 1.0) < 0.1
+ assert abs(rec_result["scores"][0]-0.0) < 0.1 and abs(rec_result["scores"][1]-1.0) < 0.1
logger.info(f"Similarity {rec_result['scores']}")
-
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
--
Gitblit v1.9.1