From a57dc4a93f9815f943733926d5b8bf285f37e211 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 26 六月 2023 21:46:26 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/datasets/large_datasets/dataset.py | 9 ++
funasr/runtime/websocket/funasr-wss-server.cpp | 158 +++++++++++++++++++++++++++-----------
funasr/utils/prepare_data.py | 7 +
funasr/utils/wav_utils.py | 13 ++
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 2
setup.py | 2
funasr/datasets/iterable_dataset.py | 9 ++
funasr/bin/asr_inference_launch.py | 6 +
funasr/utils/asr_utils.py | 6 +
README.md | 2
10 files changed, 158 insertions(+), 56 deletions(-)
diff --git a/README.md b/README.md
index 8368b3b..4338992 100644
--- a/README.md
+++ b/README.md
@@ -96,10 +96,12 @@
### runtime
An example with websocket:
+
For the server:
```shell
python wss_srv_asr.py --port 10095
```
+
For the client:
```shell
python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "5,10,5"
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 656a965..ce1f984 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import yaml
from typeguard import check_argument_types
@@ -863,7 +864,10 @@
raw_inputs = _load_bytes(data_path_and_name_and_type[0])
raw_inputs = torch.tensor(raw_inputs)
if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
- raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ try:
+ raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
+ except:
+ raw_inputs = torch.tensor(soundfile.read(data_path_and_name_and_type[0])[0])
if data_path_and_name_and_type is None and raw_inputs is not None:
if isinstance(raw_inputs, np.ndarray):
raw_inputs = torch.tensor(raw_inputs)
diff --git a/funasr/datasets/iterable_dataset.py b/funasr/datasets/iterable_dataset.py
index 4b2fb1a..fa0f0c7 100644
--- a/funasr/datasets/iterable_dataset.py
+++ b/funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,14 @@
bytes = f.read()
return load_bytes(bytes)
+def load_wav(input):
+ try:
+ return torchaudio.load(input)[0].numpy()
+ except:
+ return np.expand_dims(soundfile.read(input)[0], axis=0)
+
DATA_TYPES = {
- "sound": lambda x: torchaudio.load(x)[0].numpy(),
+ "sound": load_wav,
"pcm": load_pcm,
"kaldi_ark": load_kaldi,
"bytes": load_bytes,
diff --git a/funasr/datasets/large_datasets/dataset.py b/funasr/datasets/large_datasets/dataset.py
index 68b63e1..844dde7 100644
--- a/funasr/datasets/large_datasets/dataset.py
+++ b/funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
+import numpy as np
+import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,12 @@
sample_dict["key"] = key
elif data_type == "sound":
key, path = item.strip().split()
- waveform, sampling_rate = torchaudio.load(path)
+ try:
+ waveform, sampling_rate = torchaudio.load(path)
+ except:
+ waveform, sampling_rate = soundfile.read(path)
+ waveform = np.expand_dims(waveform, axis=0)
+ waveform = torch.tensor(waveform)
if self.frontend_conf is not None:
if sampling_rate != self.frontend_conf["fs"]:
waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index a4ee7f7..ee05d75 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -59,7 +59,7 @@
if(result){
string msg = FunASRGetResult(result, 0);
- LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg.c_str();
+ LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
float snippet_time = FunASRGetRetSnippetTime(result);
n_total_length += snippet_time;
diff --git a/funasr/runtime/websocket/funasr-wss-server.cpp b/funasr/runtime/websocket/funasr-wss-server.cpp
index 874f306..e479888 100644
--- a/funasr/runtime/websocket/funasr-wss-server.cpp
+++ b/funasr/runtime/websocket/funasr-wss-server.cpp
@@ -25,7 +25,7 @@
google::InitGoogleLogging(argv[0]);
FLAGS_logtostderr = true;
- TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0");
+ TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
TCLAP::ValueArg<std::string> download_model_dir(
"", "download-model-dir",
"Download model from Modelscope to download_model_dir",
@@ -105,63 +105,127 @@
// Download model form Modelscope
try{
std::string s_download_model_dir = download_model_dir.getValue();
+ // download model from modelscope when the model-dir is model ID or local path
+ bool is_download = false;
if(download_model_dir.isSet() && !s_download_model_dir.empty()){
+ is_download = true;
if (access(s_download_model_dir.c_str(), F_OK) != 0){
LOG(ERROR) << s_download_model_dir << " do not exists.";
exit(-1);
}
- std::string s_vad_path = model_path[VAD_DIR];
- std::string s_asr_path = model_path[MODEL_DIR];
- std::string s_punc_path = model_path[PUNC_DIR];
- std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True ";
- if(vad_dir.isSet() && !s_vad_path.empty()){
- std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir;
+ }else{
+ s_download_model_dir="./";
+ }
+ std::string s_vad_path = model_path[VAD_DIR];
+ std::string s_vad_quant = model_path[VAD_QUANT];
+ std::string s_asr_path = model_path[MODEL_DIR];
+ std::string s_asr_quant = model_path[QUANTIZE];
+ std::string s_punc_path = model_path[PUNC_DIR];
+ std::string s_punc_quant = model_path[PUNC_QUANT];
+ std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True ";
+ if(vad_dir.isSet() && !s_vad_path.empty()){
+ std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir;
+ if(is_download){
LOG(INFO) << "Download model: " << s_vad_path << " from modelscope: ";
- system(python_cmd_vad.c_str());
- std::string down_vad_path = s_download_model_dir+"/"+s_vad_path;
- std::string down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx";
- if (access(down_vad_model.c_str(), F_OK) != 0){
- LOG(ERROR) << down_vad_model << " do not exists.";
- exit(-1);
- }else{
- model_path[VAD_DIR]=down_vad_path;
- LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
- }
}else{
- LOG(INFO) << "VAD model is not set, use default.";
+ LOG(INFO) << "Check local model: " << s_vad_path;
+ if (access(s_vad_path.c_str(), F_OK) != 0){
+ LOG(ERROR) << s_vad_path << " do not exists.";
+ exit(-1);
+ }
}
- if(model_dir.isSet() && !s_asr_path.empty()){
- std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir;
+ system(python_cmd_vad.c_str());
+ std::string down_vad_path;
+ std::string down_vad_model;
+ if(is_download){
+ down_vad_path = s_download_model_dir+"/"+s_vad_path;
+ down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx";
+ }else{
+ down_vad_path = s_vad_path;
+ down_vad_model = s_vad_path+"/model_quant.onnx";
+ if(s_vad_quant=="false" || s_vad_quant=="False" || s_vad_quant=="FALSE"){
+ down_vad_model = s_vad_path+"/model.onnx";
+ }
+ }
+ if (access(down_vad_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_vad_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[VAD_DIR]=down_vad_path;
+ LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
+ }
+ }else{
+ LOG(INFO) << "VAD model is not set, use default.";
+ }
+
+ if(model_dir.isSet() && !s_asr_path.empty()){
+ std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir;
+ if(is_download){
LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: ";
- system(python_cmd_asr.c_str());
- std::string down_asr_path = s_download_model_dir+"/"+s_asr_path;
- std::string down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx";
- if (access(down_asr_model.c_str(), F_OK) != 0){
- LOG(ERROR) << down_asr_model << " do not exists.";
- exit(-1);
- }else{
- model_path[MODEL_DIR]=down_asr_path;
- LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
- }
}else{
- LOG(INFO) << "ASR model is not set, use default.";
+ LOG(INFO) << "Check local model: " << s_asr_path;
+ if (access(s_asr_path.c_str(), F_OK) != 0){
+ LOG(ERROR) << s_asr_path << " do not exists.";
+ exit(-1);
+ }
}
- if(punc_dir.isSet() && !s_punc_path.empty()){
- std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir;
- LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
- system(python_cmd_punc.c_str());
- std::string down_punc_path = s_download_model_dir+"/"+s_punc_path;
- std::string down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx";
- if (access(down_punc_model.c_str(), F_OK) != 0){
- LOG(ERROR) << down_punc_model << " do not exists.";
- exit(-1);
- }else{
- model_path[PUNC_DIR]=down_punc_path;
- LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
- }
+ system(python_cmd_asr.c_str());
+ std::string down_asr_path;
+ std::string down_asr_model;
+ if(is_download){
+ down_asr_path = s_download_model_dir+"/"+s_asr_path;
+ down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx";
}else{
- LOG(INFO) << "PUNC model is not set, use default.";
- }
+ down_asr_path = s_asr_path;
+ down_asr_model = s_asr_path+"/model_quant.onnx";
+ if(s_asr_quant=="false" || s_asr_quant=="False" || s_asr_quant=="FALSE"){
+ down_asr_model = s_asr_path+"/model.onnx";
+ }
+ }
+ if (access(down_asr_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_asr_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[MODEL_DIR]=down_asr_path;
+ LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
+ }
+ }else{
+ LOG(INFO) << "ASR model is not set, use default.";
+ }
+
+ if(punc_dir.isSet() && !s_punc_path.empty()){
+ std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir;
+ if(is_download){
+ LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
+ }else{
+ LOG(INFO) << "Check local model: " << s_punc_path;
+ if (access(s_punc_path.c_str(), F_OK) != 0){
+ LOG(ERROR) << s_punc_path << " do not exists.";
+ exit(-1);
+ }
+ }
+ system(python_cmd_punc.c_str());
+ std::string down_punc_path;
+ std::string down_punc_model;
+ if(is_download){
+ down_punc_path = s_download_model_dir+"/"+s_punc_path;
+ down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx";
+ }else{
+ down_punc_path = s_punc_path;
+ down_punc_model = s_punc_path+"/model_quant.onnx";
+ if(s_punc_quant=="false" || s_punc_quant=="False" || s_punc_quant=="FALSE"){
+ down_punc_model = s_punc_path+"/model.onnx";
+ }
+ }
+ if (access(down_punc_model.c_str(), F_OK) != 0){
+ LOG(ERROR) << down_punc_model << " do not exists.";
+ exit(-1);
+ }else{
+ model_path[PUNC_DIR]=down_punc_path;
+ LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
+ }
+ }else{
+ LOG(INFO) << "PUNC model is not set, use default.";
}
} catch (std::exception const& e) {
LOG(ERROR) << "Error: " << e.what();
@@ -247,4 +311,4 @@
}
return 0;
-}
\ No newline at end of file
+}
diff --git a/funasr/utils/asr_utils.py b/funasr/utils/asr_utils.py
index 4067b04..5aa40ec 100644
--- a/funasr/utils/asr_utils.py
+++ b/funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
+import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
if support_audio_type == "pcm":
fs = None
else:
- audio, fs = torchaudio.load(fname)
+ try:
+ audio, fs = torchaudio.load(fname)
+ except:
+ audio, fs = soundfile.read(fname)
break
if audio_type.rfind(".scp") >= 0:
with open(fname, encoding="utf-8") as f:
diff --git a/funasr/utils/prepare_data.py b/funasr/utils/prepare_data.py
index 7602740..0e773bb 100644
--- a/funasr/utils/prepare_data.py
+++ b/funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
+import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+ try:
+ waveform, sampling_rate = torchaudio.load(wav_path)
+ except:
+ waveform, sampling_rate = soundfile.read(wav_path)
+ waveform = np.expand_dims(waveform, axis=0)
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
return n_frames, feature_dim
diff --git a/funasr/utils/wav_utils.py b/funasr/utils/wav_utils.py
index ebb80d2..a6e394f 100644
--- a/funasr/utils/wav_utils.py
+++ b/funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
+import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,11 @@
waveform = torch.from_numpy(waveform.reshape(1, -1))
else:
# load pcm from wav, and resample
- waveform, audio_sr = torchaudio.load(wav_file)
+ try:
+ waveform, audio_sr = torchaudio.load(wav_file)
+ except:
+ waveform, audio_sr = soundfile.read(wav_file)
+ waveform = torch.tensor(np.expand_dims(waveform, axis=0))
waveform = waveform * (1 << 15)
waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +186,11 @@
def wav2num_frame(wav_path, frontend_conf):
- waveform, sampling_rate = torchaudio.load(wav_path)
+    try:
+        waveform, sampling_rate = torchaudio.load(wav_path)
+    except:
+        waveform, sampling_rate = soundfile.read(wav_path)
+        waveform = torch.tensor(np.expand_dims(waveform, axis=0))
speech_length = (waveform.shape[1] / sampling_rate) * 1000.
n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
diff --git a/setup.py b/setup.py
index 5b49d06..f13a2c2 100644
--- a/setup.py
+++ b/setup.py
@@ -20,7 +20,7 @@
"librosa",
"jamo==0.4.1", # For kss
"PyYAML>=5.1.2",
- "soundfile>=0.10.2",
+ "soundfile>=0.11.0",
"h5py>=2.10.0",
"kaldiio>=2.17.0",
"torch_complex",
--
Gitblit v1.9.1