From a57dc4a93f9815f943733926d5b8bf285f37e211 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 26 六月 2023 21:46:26 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/datasets/large_datasets/dataset.py | 9 ++
funasr/runtime/websocket/funasr-wss-server.cpp | 158 +++++++++++++++++++++++++++-----------
funasr/utils/prepare_data.py | 7 +
funasr/utils/wav_utils.py | 13 ++
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 2
setup.py | 2
funasr/datasets/iterable_dataset.py | 9 ++
funasr/bin/asr_inference_launch.py | 6 +
funasr/utils/asr_utils.py | 6 +
README.md | 2
10 files changed, 158 insertions(+), 56 deletions(-)
diff --git a/README.md b/README.md
index 8368b3b..4338992 100644
--- a/README.md
+++ b/README.md
@@ -96,10 +96,12 @@
### runtime
An example with websocket:
+
For the server:
```shell
python wss_srv_asr.py --port 10095
```
+
For the client:
```shell
python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "5,10,5"
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 656a965..ce1f984 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import yaml
from typeguard import check_argument_types
@@ -863,7 +864,10 @@
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ try:
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ except:
+ raw_inputs = torch.tensor(soundfile.read(data_path_and_name_and_type[0])[0])
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 4b2fb1a..fa0f0c7 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,14 @@
bytes = f.read()
return load_bytes(bytes)
+def load_wav(input):
+ try:
+ return torchaudio.load(input)[0].numpy()
+ except:
+ return np.expand_dims(soundfile.read(input)[0], axis=0)
+
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0].numpy(),
+ "sound": load_wav,
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 68b63e1..844dde7 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
+import numpy as np
+import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,12 @@
sample_dict["key"] = key
elif data_type == "sound":
key, path = item.strip().split()
- waveform, sampling_rate = torchaudio.load(path)
+ try:
+ waveform, sampling_rate = torchaudio.load(path)
+ except:
+ waveform, sampling_rate = soundfile.read(path)
+ waveform = np.expand_dims(waveform, axis=0)
+ waveform = torch.tensor(waveform)
if self.frontend_conf is not None:
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index a4ee7f7..ee05d75 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -59,7 +59,7 @@
if(result){
string msg = FunASRGetResult(result, 0);
- LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg.c_str();
+ LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
float snippet_time = FunASRGetRetSnippetTime(result);
n_total_length += snippet_time;
diff --git a/funasr/runtime/websocket/funasr-wss-server.cpp b/funasr/runtime/websocket/funasr-wss-server.cpp
index 874f306..e479888 100644
--- a/funasr/runtime/websocket/funasr-wss-server.cpp
+++ b/funasr/runtime/websocket/funasr-wss-server.cpp
@@ -25,7 +25,7 @@
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
- TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0");
+ TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
TCLAP::ValueArg<std::string> download_model_dir(
"", "download-model-dir",
"Download model from Modelscope to download_model_dir",
@@ -105,63 +105,127 @@
// Download model form Modelscope
try{
std::string s_download_model_dir = download_model_dir.getValue();
+ // download model from modelscope when the model-dir is model ID or local path
+ bool is_download = false;
if(download_model_dir.isSet() && !s_download_model_dir.empty()){
+ is_download = true;
if (access(s_download_model_dir.c_str(), F_OK) != 0){
LOG(ERROR) << s_download_model_dir << " do not exists.";
exit(-1);
}
- std::string s_vad_path = model_path[VAD_DIR];
- std::string s_asr_path = model_path[MODEL_DIR];
- std::string s_punc_path = model_path[PUNC_DIR];
- std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True ";
- if(vad_dir.isSet() && !s_vad_path.empty()){
- std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir;
+ }else{
+ s_download_model_dir="./";
+ }
+ std::string s_vad_path = model_path[VAD_DIR];
+ std::string s_vad_quant = model_path[VAD_QUANT];
+ std::string s_asr_path = model_path[MODEL_DIR];
+ std::string s_asr_quant = model_path[QUANTIZE];
+ std::string s_punc_path = model_path[PUNC_DIR];
+ std::string s_punc_quant = model_path[PUNC_QUANT];
+ std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True ";
+ if(vad_dir.isSet() && !s_vad_path.empty()){
+ std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir;
+ if(is_download){
LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
- system(python_cmd_vad.c_str());
- std::string down_vad_path = s_download_model_dir+"/"+s_vad_path;
- std::string down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx";
- if (access(down_vad_model.c_str(), F_OK) != 0){
- LOG(ERROR) << down_vad_model << " do not exists.";
- exit(-1);
- }else{
- model_path[VAD_DIR]=down_vad_path;
- LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
- }
}else{
- LOG(INFO) << "VAD model is not set, use default.";
+ LOG(INFO) << "Check local model: " << s_vad_path;
+ if (access(s_vad_path.c_str(), F_OK) != 0){
+ LOG(ERROR) << s_vad_path << " do not exists.";
+ exit(-1);
+ }
}
- if(model_dir.isSet() && !s_asr_path.empty()){
- std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir;
+ system(python_cmd_vad.c_str());
+ std::string down_vad_path;
+ std::string down_vad_model;
+ if(is_download){
+ down_vad_path = s_download_model_dir+"/"+s_vad_path;
+ down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx";
+ }else{
+ down_vad_path = s_vad_path;
+ down_vad_model = s_vad_path+"/model_quant.onnx";
+ if(s_vad_quant=="false" || s_vad_quant=="False" || s_vad_quant=="FALSE"){
+ down_vad_model = s_vad_path+"/model.onnx";
+ }
+ }
+ if (access(down_vad_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_vad_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[VAD_DIR]=down_vad_path;
+ LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
+ }
+ }else{
+ LOG(INFO) << "VAD model is not set, use default.";
+ }
+
+ if(model_dir.isSet() && !s_asr_path.empty()){
+ std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir;
+ if(is_download){
LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
- system(python_cmd_asr.c_str());
- std::string down_asr_path = s_download_model_dir+"/"+s_asr_path;
- std::string down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx";
- if (access(down_asr_model.c_str(), F_OK) != 0){
- LOG(ERROR) << down_asr_model << " do not exists.";
- exit(-1);
- }else{
- model_path[MODEL_DIR]=down_asr_path;
- LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
- }
}else{
- LOG(INFO) << "ASR model is not set, use default.";
+ LOG(INFO) << "Check local model: " << s_asr_path;
+ if (access(s_asr_path.c_str(), F_OK) != 0){
+ LOG(ERROR) << s_asr_path << " do not exists.";
+ exit(-1);
+ }
}
- if(punc_dir.isSet() && !s_punc_path.empty()){
- std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir;
- LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
- system(python_cmd_punc.c_str());
- std::string down_punc_path = s_download_model_dir+"/"+s_punc_path;
- std::string down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx";
- if (access(down_punc_model.c_str(), F_OK) != 0){
- LOG(ERROR) << down_punc_model << " do not exists.";
- exit(-1);
- }else{
- model_path[PUNC_DIR]=down_punc_path;
- LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
- }
+ system(python_cmd_asr.c_str());
+ std::string down_asr_path;
+ std::string down_asr_model;
+ if(is_download){
+ down_asr_path = s_download_model_dir+"/"+s_asr_path;
+ down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx";
}else{
- LOG(INFO) << "PUNC model is not set, use default.";
- }
+ down_asr_path = s_asr_path;
+ down_asr_model = s_asr_path+"/model_quant.onnx";
+ if(s_asr_quant=="false" || s_asr_quant=="False" || s_asr_quant=="FALSE"){
+ down_asr_model = s_asr_path+"/model.onnx";
+ }
+ }
+ if (access(down_asr_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_asr_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[MODEL_DIR]=down_asr_path;
+ LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
+ }
+ }else{
+ LOG(INFO) << "ASR model is not set, use default.";
+ }
+
+ if(punc_dir.isSet() && !s_punc_path.empty()){
+ std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir;
+ if(is_download){
+ LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
+ }else{
+ LOG(INFO) << "Check local model: " << s_punc_path;
+ if (access(s_punc_path.c_str(), F_OK) != 0){
+ LOG(ERROR) << s_punc_path << " do not exists.";
+ exit(-1);
+ }
+ }
+ system(python_cmd_punc.c_str());
+ std::string down_punc_path;
+ std::string down_punc_model;
+ if(is_download){
+ down_punc_path = s_download_model_dir+"/"+s_punc_path;
+ down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx";
+ }else{
+ down_punc_path = s_punc_path;
+ down_punc_model = s_punc_path+"/model_quant.onnx";
+ if(s_punc_quant=="false" || s_punc_quant=="False" || s_punc_quant=="FALSE"){
+ down_punc_model = s_punc_path+"/model.onnx";
+ }
+ }
+ if (access(down_punc_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_punc_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[PUNC_DIR]=down_punc_path;
+ LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
+ }
+ }else{
+ LOG(INFO) << "PUNC model is not set, use default.";
}
} catch (std::exception const& e) {
LOG(ERROR) << "Error: " << e.what();
@@ -247,4 +311,4 @@
}
return 0;
-}
\ No newline at end of file
+}
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
index 4067b04..5aa40ec 100644
--- a/funasr/utils/asr_utils.py
+++ b/funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
+import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
if support_audio_type == "pcm":
fs = None
else:
- audio, fs = torchaudio.load(fname)
+ try:
+ audio, fs = torchaudio.load(fname)
+ except:
+ audio, fs = soundfile.read(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index 7602740..0e773bb 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
+import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ except:
+ waveform, sampling_rate = soundfile.read(wav_path)
+ waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
return n_frames, feature_dim
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index ebb80d2..a6e394f 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,11 @@
waveform = torch.from_numpy(waveform.reshape(1, -1))
else:
# load pcm from wav, and resample
- waveform, audio_sr = torchaudio.load(wav_file)
+ try:
+ waveform, audio_sr = torchaudio.load(wav_file)
+ except:
+ waveform, audio_sr = soundfile.read(wav_file)
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
waveform = waveform * (1 << 15)
waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +186,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, audio_sr = torchaudio.load(wav_file)
+ except:
+ waveform, audio_sr = soundfile.read(wav_file)
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
diff --git a/setup.py b/setup.py
index 5b49d06..f13a2c2 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
"librosa",
"jamo==0.4.1", # For kss
"PyYAML>=5.1.2",
- "soundfile>=0.10.2",
+ "soundfile>=0.11.0",
"h5py>=2.10.0",
"kaldiio>=2.17.0",
"torch_complex",
--
Gitblit v1.9.1