游雁
2023-06-26 a57dc4a93f9815f943733926d5b8bf285f37e211
Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
10个文件已修改
214 ■■■■ 已修改文件
README.md 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/asr_inference_launch.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/iterable_dataset.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/datasets/large_datasets/dataset.py 9 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/funasr-wss-server.cpp 158 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/asr_utils.py 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/prepare_data.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/wav_utils.py 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
setup.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
README.md
@@ -96,10 +96,12 @@
### runtime
An example with websocket:
For the server:
```shell
python wss_srv_asr.py --port 10095
```
For the client:
```shell
python wss_client_asr.py --host "0.0.0.0" --port 10095 --mode 2pass --chunk_size "5,10,5"
funasr/bin/asr_inference_launch.py
@@ -19,6 +19,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
import yaml
from typeguard import check_argument_types
@@ -863,7 +864,10 @@
            raw_inputs = _load_bytes(data_path_and_name_and_type[0])
            raw_inputs = torch.tensor(raw_inputs)
        if data_path_and_name_and_type is not None and data_path_and_name_and_type[2] == "sound":
            raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            try:
                raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
            except:
                raw_inputs = torch.tensor(soundfile.read(data_path_and_name_and_type[0])[0])
        if data_path_and_name_and_type is None and raw_inputs is not None:
            if isinstance(raw_inputs, np.ndarray):
                raw_inputs = torch.tensor(raw_inputs)
funasr/datasets/iterable_dataset.py
@@ -14,6 +14,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
from torch.utils.data.dataset import IterableDataset
from typeguard import check_argument_types
import os.path
@@ -66,8 +67,14 @@
        bytes = f.read()
    return load_bytes(bytes)
def load_wav(input):
    """Load an audio file and return its samples as a NumPy array.

    The primary path is ``torchaudio.load``. If that raises (for example when
    no torchaudio backend can decode the file), the file is read with
    ``soundfile`` instead. The fallback is shaped and typed to match the
    torchaudio path: a float32 array laid out as (channels, samples).

    Args:
        input: Path (or file-like object) of the audio to load. The name is
            kept for backward compatibility with existing callers.

    Returns:
        numpy.ndarray with one row per channel.
    """
    try:
        return torchaudio.load(input)[0].numpy()
    except Exception:
        # Catch Exception rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not swallowed by the fallback.
        # ``soundfile.read`` puts channels LAST: (samples,) for mono, or
        # (samples, channels) with always_2d=True. Transpose to channels-first
        # so the fallback matches the torchaudio branch above. Request float32
        # so the dtype matches too; soundfile's default is float64.
        data, _ = soundfile.read(input, dtype="float32", always_2d=True)
        return np.ascontiguousarray(data.T)
DATA_TYPES = {
    "sound": lambda x: torchaudio.load(x)[0].numpy(),
    "sound": load_wav,
    "pcm": load_pcm,
    "kaldi_ark": load_kaldi,
    "bytes": load_bytes,
funasr/datasets/large_datasets/dataset.py
@@ -6,6 +6,8 @@
import torch
import torch.distributed as dist
import torchaudio
import numpy as np
import soundfile
from kaldiio import ReadHelper
from torch.utils.data import IterableDataset
@@ -123,7 +125,12 @@
                            sample_dict["key"] = key
                    elif data_type == "sound":
                        key, path = item.strip().split()
                        waveform, sampling_rate = torchaudio.load(path)
                        try:
                            waveform, sampling_rate = torchaudio.load(path)
                        except:
                            waveform, sampling_rate = soundfile.read(path)
                            waveform = np.expand_dims(waveform, axis=0)
                            waveform = torch.tensor(waveform)
                        if self.frontend_conf is not None:
                            if sampling_rate != self.frontend_conf["fs"]:
                                waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate,
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -59,7 +59,7 @@
        if(result){
            string msg = FunASRGetResult(result, 0);
            LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg.c_str();
            LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
            float snippet_time = FunASRGetRetSnippetTime(result);
            n_total_length += snippet_time;
funasr/runtime/websocket/funasr-wss-server.cpp
@@ -25,7 +25,7 @@
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-ws-server", ' ', "1.0");
    TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
    TCLAP::ValueArg<std::string> download_model_dir(
        "", "download-model-dir",
        "Download model from Modelscope to download_model_dir",
@@ -105,63 +105,127 @@
    // Download model form Modelscope
    try{
        std::string s_download_model_dir = download_model_dir.getValue();
        // download model from modelscope when the model-dir is model ID or local path
        bool is_download = false;
        if(download_model_dir.isSet() && !s_download_model_dir.empty()){
            is_download = true;
            if (access(s_download_model_dir.c_str(), F_OK) != 0){
                LOG(ERROR) << s_download_model_dir << " do not exists."; 
                exit(-1);
            }
            std::string s_vad_path = model_path[VAD_DIR];
            std::string s_asr_path = model_path[MODEL_DIR];
            std::string s_punc_path = model_path[PUNC_DIR];
            std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True ";
            if(vad_dir.isSet() && !s_vad_path.empty()){
                std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir;
        }else{
            s_download_model_dir="./";
        }
        std::string s_vad_path = model_path[VAD_DIR];
        std::string s_vad_quant = model_path[VAD_QUANT];
        std::string s_asr_path = model_path[MODEL_DIR];
        std::string s_asr_quant = model_path[QUANTIZE];
        std::string s_punc_path = model_path[PUNC_DIR];
        std::string s_punc_quant = model_path[PUNC_QUANT];
        std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True ";
        if(vad_dir.isSet() && !s_vad_path.empty()){
            std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir;
            if(is_download){
                LOG(INFO) << "Download model: " <<  s_vad_path << " from modelscope: ";
                system(python_cmd_vad.c_str());
                std::string down_vad_path = s_download_model_dir+"/"+s_vad_path;
                std::string down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx";
                if (access(down_vad_model.c_str(), F_OK) != 0){
                  LOG(ERROR) << down_vad_model << " do not exists.";
                  exit(-1);
                }else{
                  model_path[VAD_DIR]=down_vad_path;
                  LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
                }
            }else{
              LOG(INFO) << "VAD model is not set, use default.";
                LOG(INFO) << "Check local model: " <<  s_vad_path;
                if (access(s_vad_path.c_str(), F_OK) != 0){
                    LOG(ERROR) << s_vad_path << " do not exists.";
                    exit(-1);
                }
            }
            if(model_dir.isSet() && !s_asr_path.empty()){
                std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir;
            system(python_cmd_vad.c_str());
            std::string down_vad_path;
            std::string down_vad_model;
            if(is_download){
                down_vad_path  = s_download_model_dir+"/"+s_vad_path;
                down_vad_model = s_download_model_dir+"/"+s_vad_path+"/model_quant.onnx";
            }else{
                down_vad_path  = s_vad_path;
                down_vad_model = s_vad_path+"/model_quant.onnx";
                if(s_vad_quant=="false" || s_vad_quant=="False" || s_vad_quant=="FALSE"){
                    down_vad_model = s_vad_path+"/model.onnx";
                }
            }
            if (access(down_vad_model.c_str(), F_OK) != 0){
                LOG(ERROR) << down_vad_model << " do not exists.";
                exit(-1);
            }else{
                model_path[VAD_DIR]=down_vad_path;
                LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
            }
        }else{
            LOG(INFO) << "VAD model is not set, use default.";
        }
        if(model_dir.isSet() && !s_asr_path.empty()){
            std::string python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir;
            if(is_download){
                LOG(INFO) << "Download model: " <<  s_asr_path << " from modelscope: ";
                system(python_cmd_asr.c_str());
                std::string down_asr_path = s_download_model_dir+"/"+s_asr_path;
                std::string down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx";
                if (access(down_asr_model.c_str(), F_OK) != 0){
                  LOG(ERROR) << down_asr_model << " do not exists.";
                  exit(-1);
                }else{
                  model_path[MODEL_DIR]=down_asr_path;
                  LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
                }
            }else{
              LOG(INFO) << "ASR model is not set, use default.";
                LOG(INFO) << "Check local model: " <<  s_asr_path;
                if (access(s_asr_path.c_str(), F_OK) != 0){
                    LOG(ERROR) << s_asr_path << " do not exists.";
                    exit(-1);
                }
            }
            if(punc_dir.isSet() && !s_punc_path.empty()){
                std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir;
                LOG(INFO) << "Download model: " << s_punc_path << " from modelscope: ";
                system(python_cmd_punc.c_str());
                std::string down_punc_path = s_download_model_dir+"/"+s_punc_path;
                std::string down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx";
                if (access(down_punc_model.c_str(), F_OK) != 0){
                  LOG(ERROR) << down_punc_model << " do not exists.";
                  exit(-1);
                }else{
                  model_path[PUNC_DIR]=down_punc_path;
                  LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
                }
            system(python_cmd_asr.c_str());
            std::string down_asr_path;
            std::string down_asr_model;
            if(is_download){
                down_asr_path  = s_download_model_dir+"/"+s_asr_path;
                down_asr_model = s_download_model_dir+"/"+s_asr_path+"/model_quant.onnx";
            }else{
              LOG(INFO) << "PUNC model is not set, use default.";
            }
                down_asr_path  = s_asr_path;
                down_asr_model = s_asr_path+"/model_quant.onnx";
                if(s_asr_quant=="false" || s_asr_quant=="False" || s_asr_quant=="FALSE"){
                    down_asr_model = s_asr_path+"/model.onnx";
                }
            }
            if (access(down_asr_model.c_str(), F_OK) != 0){
              LOG(ERROR) << down_asr_model << " do not exists.";
              exit(-1);
            }else{
              model_path[MODEL_DIR]=down_asr_path;
              LOG(INFO) << "Set " << MODEL_DIR << " : " << model_path[MODEL_DIR];
            }
        }else{
          LOG(INFO) << "ASR model is not set, use default.";
        }
        if(punc_dir.isSet() && !s_punc_path.empty()){
            std::string python_cmd_punc = python_cmd + " --model-name " + s_punc_path + " --export-dir " + s_download_model_dir;
            if(is_download){
                LOG(INFO) << "Download model: " <<  s_punc_path << " from modelscope: ";
            }else{
                LOG(INFO) << "Check local model: " <<  s_punc_path;
                if (access(s_punc_path.c_str(), F_OK) != 0){
                    LOG(ERROR) << s_punc_path << " do not exists.";
                    exit(-1);
                }
            }
            system(python_cmd_punc.c_str());
            std::string down_punc_path;
            std::string down_punc_model;
            if(is_download){
                down_punc_path  = s_download_model_dir+"/"+s_punc_path;
                down_punc_model = s_download_model_dir+"/"+s_punc_path+"/model_quant.onnx";
            }else{
                down_punc_path  = s_punc_path;
                down_punc_model = s_punc_path+"/model_quant.onnx";
                if(s_punc_quant=="false" || s_punc_quant=="False" || s_punc_quant=="FALSE"){
                    down_punc_model = s_punc_path+"/model.onnx";
                }
            }
            if (access(down_punc_model.c_str(), F_OK) != 0){
              LOG(ERROR) << down_punc_model << " do not exists.";
              exit(-1);
            }else{
              model_path[PUNC_DIR]=down_punc_path;
              LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
            }
        }else{
          LOG(INFO) << "PUNC model is not set, use default.";
        }
    } catch (std::exception const& e) {
        LOG(ERROR) << "Error: " << e.what();
@@ -247,4 +311,4 @@
  }
  return 0;
}
}
funasr/utils/asr_utils.py
@@ -5,6 +5,7 @@
from typing import Any, Dict, List, Union
import torchaudio
import soundfile
import numpy as np
import pkg_resources
from modelscope.utils.logger import get_logger
@@ -135,7 +136,10 @@
                if support_audio_type == "pcm":
                    fs = None
                else:
                    audio, fs = torchaudio.load(fname)
                    try:
                        audio, fs = torchaudio.load(fname)
                    except:
                        audio, fs = soundfile.read(fname)
                break
        if audio_type.rfind(".scp") >= 0:
            with open(fname, encoding="utf-8") as f:
funasr/utils/prepare_data.py
@@ -7,6 +7,7 @@
import numpy as np
import torch.distributed as dist
import torchaudio
import soundfile
def filter_wav_text(data_dir, dataset):
@@ -42,7 +43,11 @@
def wav2num_frame(wav_path, frontend_conf):
    """Estimate the number of model frames and the feature dimension for a wav.

    The audio is loaded with torchaudio, falling back to soundfile if
    torchaudio raises. The frame count is derived from the number of samples
    and the sampling rate, using the frontend's frame shift and LFR
    subsampling factor.

    Args:
        wav_path: Path to the audio file.
        frontend_conf: Mapping that must provide the keys "frame_shift",
            "lfr_n", "n_mels" and "lfr_m".

    Returns:
        Tuple ``(n_frames, feature_dim)``. ``n_frames`` is a float computed as
        samples * 1000 / (sampling_rate * frame_shift * lfr_n).
        ``feature_dim`` is n_mels * lfr_m.
    """
    try:
        # Load only inside the try: a load placed before it would raise
        # before the soundfile fallback could ever run.
        waveform, sampling_rate = torchaudio.load(wav_path)
    except Exception:
        # soundfile returns (samples, channels) with always_2d=True. Transpose
        # so that shape[1] is the sample count, as with torchaudio's
        # (channels, samples) layout.
        waveform, sampling_rate = soundfile.read(wav_path, always_2d=True)
        waveform = waveform.T
    num_samples = waveform.shape[1]
    n_frames = (num_samples * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
    return n_frames, feature_dim
funasr/utils/wav_utils.py
@@ -11,6 +11,7 @@
import numpy as np
import torch
import torchaudio
import soundfile
import torchaudio.compliance.kaldi as kaldi
@@ -162,7 +163,11 @@
        waveform = torch.from_numpy(waveform.reshape(1, -1))
    else:
        # load pcm from wav, and resample
        waveform, audio_sr = torchaudio.load(wav_file)
        try:
            waveform, audio_sr = torchaudio.load(wav_file)
        except:
            waveform, audio_sr = soundfile.read(wav_file)
            waveform = torch.tensor(np.expand_dims(waveform, axis=0))
        waveform = waveform * (1 << 15)
        waveform = torch_resample(waveform, audio_sr, model_sr)
@@ -181,7 +186,11 @@
def wav2num_frame(wav_path, frontend_conf):
    waveform, sampling_rate = torchaudio.load(wav_path)
    try:
        waveform, audio_sr = torchaudio.load(wav_file)
    except:
        waveform, audio_sr = soundfile.read(wav_file)
        waveform = torch.tensor(np.expand_dims(waveform, axis=0))
    speech_length = (waveform.shape[1] / sampling_rate) * 1000.
    n_frames = (waveform.shape[1] * 1000.0) / (sampling_rate * frontend_conf["frame_shift"] * frontend_conf["lfr_n"])
    feature_dim = frontend_conf["n_mels"] * frontend_conf["lfr_m"]
setup.py
@@ -20,7 +20,7 @@
        "librosa",
        "jamo==0.4.1",  # For kss
        "PyYAML>=5.1.2",
        "soundfile>=0.10.2",
        "soundfile>=0.11.0",
        "h5py>=2.10.0",
        "kaldiio>=2.17.0",
        "torch_complex",