Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
| | |
| | |
|
// --- File-recognition mode state ---
var isfilemode = false;   // true when recognizing an uploaded file instead of the microphone
var file_data_array;      // buffer holding the uploaded file's audio data
var isconnected = 0;      // file-mode connection state: 0 = not started, 1 = connected, -1 = error
// Total number of audio samples sent to the server so far.
// (Merge-conflict residue removed: this was declared twice.)
var totalsend = 0;
|
// Clicking the file chooser switches the UI into file mode:
// recording buttons are disabled and only "connect" stays available.
upfile.onclick = function () {
    btnStart.disabled = true;
    btnStop.disabled = true;
    btnConnect.disabled = false;
};
|
| | | upfile.onchange = function () {
|
| | | var len = this.files.length;
|
| | | for(let i = 0; i < len; i++) {
|
| | |
| | | var audioblob= fileAudio.result;
|
| | | file_data_array=audioblob;
|
| | | console.log(audioblob);
|
| | | btnConnect.disabled = false;
|
| | | |
| | | info_div.innerHTML='请点击连接进行识别';
|
| | |
|
| | | }
|
| | |
| | | sendBuf=sampleBuf.slice(0,chunk_size);
|
| | | totalsend=totalsend+sampleBuf.length;
|
| | | sampleBuf=sampleBuf.slice(chunk_size,sampleBuf.length);
|
| | | wsconnecter.wsSend(sendBuf,false);
|
| | | wsconnecter.wsSend(sendBuf);
|
| | |
|
| | |
|
| | | }
|
| | |
| | |
|
| | |
|
| | | }
|
// Poll the websocket connection state until the file upload can start.
// isconnected: -1 = connection failed (abort), 0 = still connecting
// (retry in 1 s), anything else = connected (begin sending the file).
function start_file_offline() {
    console.log("start_file_offline", isconnected);
    switch (isconnected) {
        case -1:
            // Connection failed — give up silently; the error UI is handled elsewhere.
            return;
        case 0:
            // Not connected yet — re-check in one second.
            setTimeout(start_file_offline, 1000);
            return;
        default:
            start_file_send();
    }
}
|
| | |
|
| | | function on_recoder_mode_change()
|
| | | {
|
| | |
| | | document.getElementById("mic_mode_div").style.display = 'block';
|
| | | document.getElementById("rec_mode_div").style.display = 'none';
|
| | |
|
| | | btnConnect.disabled=false;
|
| | | |
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=false;
|
| | | isfilemode=false;
|
| | | }
|
| | | else
|
| | | {
|
| | | document.getElementById("mic_mode_div").style.display = 'none';
|
| | | document.getElementById("rec_mode_div").style.display = 'block';
|
| | | btnConnect.disabled = true;
|
| | | |
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=true;
|
| | | isfilemode=true;
|
| | | info_div.innerHTML='请点击选择文件';
|
| | |
|
| | |
| | | wsconnecter.wsStop();
|
| | |
|
| | | info_div.innerHTML="请点击连接";
|
| | | isconnected=0;
|
| | | |
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=false;
|
| | |
| | |
|
| | | // 连接状态响应
|
| | | function getConnState( connState ) {
|
| | | if ( connState === 0 ) {
|
| | | if ( connState === 0 ) { //on open
|
| | |
|
| | |
|
| | | info_div.innerHTML='连接成功!请点击开始';
|
| | | if (isfilemode==true){
|
| | | info_div.innerHTML='请耐心等待,大文件等待时间更长';
|
| | | start_file_send();
|
| | | }
|
| | | else
|
| | | {
|
| | | btnStart.disabled = false;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=true;
|
| | | }
|
| | | } else if ( connState === 1 ) {
|
| | | //stop();
|
| | |
| | |
|
| | | alert("连接地址"+document.getElementById('wssip').value+"失败,请检查asr地址和端口,并确保h5服务和asr服务在同一个域内。或换个浏览器试试。");
|
| | | btnStart.disabled = true;
|
| | | isconnected=0;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=false;
|
| | | |
| | |
|
| | | info_div.innerHTML='请点击连接';
|
| | | }
|
| | |
| | | rec.open( function(){
|
| | | rec.start();
|
| | | console.log("开始");
|
| | | btnStart.disabled = true;
|
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = false;
|
| | | btnConnect.disabled=true;
|
| | | });
|
| | |
|
| | | }
|
| | |
| | | // 清除显示
|
| | | clear();
|
| | | //控件状态更新
|
| | | console.log("isfilemode"+isfilemode+","+isconnected);
|
| | | info_div.innerHTML="正在连接asr服务器,请等待...";
|
| | | console.log("isfilemode"+isfilemode);
|
| | | |
| | | //启动连接
|
| | | var ret=wsconnecter.wsStart();
|
| | | // 1 is ok, 0 is error
|
| | | if(ret==1){
|
| | | info_div.innerHTML="正在连接asr服务器,请等待...";
|
| | | isRec = true;
|
| | | btnStart.disabled = false;
|
| | | btnStop.disabled = false;
|
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=true;
|
| | | if (isfilemode)
|
| | | {
|
| | | console.log("start file now");
|
| | | start_file_offline();
|
| | |
|
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled = true;
|
| | | }
|
| | | return 1;
|
| | | }
|
| | | return 0;
|
| | | else
|
| | | {
|
| | | info_div.innerHTML="请点击开始";
|
| | | btnStart.disabled = true;
|
| | | btnStop.disabled = true;
|
| | | btnConnect.disabled=false;
|
| | | |
| | | return 0;
|
| | | }
|
| | | }
|
| | |
|
| | |
|
| | |
| | | };
|
| | | console.log(request);
|
| | | if(sampleBuf.length>0){
|
| | | wsconnecter.wsSend(sampleBuf,false);
|
| | | wsconnecter.wsSend(sampleBuf);
|
| | | console.log("sampleBuf.length"+sampleBuf.length);
|
| | | sampleBuf=new Int16Array();
|
| | | }
|
| | | wsconnecter.wsSend( JSON.stringify(request) ,false);
|
| | | wsconnecter.wsSend( JSON.stringify(request) );
|
| | |
|
| | |
|
| | |
|
| | |
|
| | |
|
| | | //isconnected=0;
|
| | | |
| | | // 控件状态更新
|
| | |
|
| | | isRec = false;
|
| | |
| | | if(isfilemode==false){
|
| | | btnStop.disabled = true;
|
| | | btnStart.disabled = true;
|
| | | btnConnect.disabled=false;
|
| | | btnConnect.disabled=true;
|
| | | //wait 3s for asr result
|
| | | setTimeout(function(){
|
| | | console.log("call stop ws!");
|
| | | wsconnecter.wsStop();
|
| | | isconnected=0;
|
| | | btnConnect.disabled=false;
|
| | | info_div.innerHTML="请点击连接";}, 3000 );
|
| | | |
| | | |
| | |
|
| | | rec.stop(function(blob,duration){
|
| | |
|
| | |
| | | while(sampleBuf.length>=chunk_size){
|
| | | sendBuf=sampleBuf.slice(0,chunk_size);
|
| | | sampleBuf=sampleBuf.slice(chunk_size,sampleBuf.length);
|
| | | wsconnecter.wsSend(sendBuf,false);
|
| | | wsconnecter.wsSend(sendBuf);
|
| | |
|
| | |
|
| | |
|
| | |
| | | speechSokt.onopen = function(e){onOpen(e);}; // 定义响应函数
|
| | | speechSokt.onclose = function(e){
|
| | | console.log("onclose ws!");
|
| | | speechSokt.close();
|
| | | //speechSokt.close();
|
| | | onClose(e);
|
| | | };
|
| | | speechSokt.onmessage = function(e){onMessage(e);};
|
| | |
| | | }
|
| | | };
|
| | |
|
// Send one payload over the websocket.
// oneData: string (JSON control message) or binary audio chunk.
// stop: when truthy, this is the final send — schedule the socket to close
//       3 seconds later so the server has time to flush the last results.
// (Merge-conflict residue removed: the signature was duplicated with and
//  without `stop`; callers that omit `stop` get undefined, which is falsy.)
this.wsSend = function (oneData, stop) {
    if (speechSokt == undefined) return;
    if (speechSokt.readyState === 1) { // 0:CONNECTING, 1:OPEN, 2:CLOSING, 3:CLOSED
        speechSokt.send(oneData);
        if (stop) {
            // BUG FIX: the original `setTimeout(speechSokt.close(), 3000)`
            // called close() immediately and passed its return value to
            // setTimeout. Wrap it so the close really happens after 3 s.
            setTimeout(function () { speechSokt.close(); }, 3000);
        }
    }
};
|
| | |
| | | speechSokt.send( JSON.stringify(request) );
|
| | | console.log("连接成功");
|
| | | stateHandle(0);
|
| | | isconnected=1;
|
| | | |
| | | }
|
| | |
|
// Websocket 'close' callback; intentionally a no-op here — the registered
// onclose wrapper (which logs and closes the socket) handles the event.
function onClose( e ) {

}
|
| | |
|
| | | function onError( e ) {
|
| | | isconnected=-1;
|
| | | |
| | | info_div.innerHTML="连接"+e;
|
| | | console.log(e);
|
| | | stateHandle(2);
|
| | |
# Demo executables for punctuation restoration: offline (whole utterance),
# online (streaming), and an offline RTF (real-time-factor) benchmark.
# All link against the funasr runtime library.
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)

add_executable(funasr-onnx-online-punc "funasr-onnx-online-punc.cpp")
target_link_libraries(funasr-onnx-online-punc PUBLIC funasr)

add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
| | |
| | | long taking_micros = 0; |
| | | for(auto& txt_str : txt_list){ |
| | | gettimeofday(&start, NULL); |
| | | string result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL); |
| | | FUNASR_RESULT result=CTTransformerInfer(punc_hanlde, txt_str.c_str(), RASR_NONE, NULL); |
| | | gettimeofday(&end, NULL); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | | taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | | LOG(INFO)<<"Results: "<<result; |
| | | string msg = FunASRGetResult(result, 0); |
| | | LOG(INFO)<<"Results: "<<msg; |
| | | CTTransformerFreeResult(result); |
| | | } |
| | | |
| | | LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s"; |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | |
| | | #ifndef _WIN32 |
| | | #include <sys/time.h> |
| | | #else |
| | | #include <win_func.h> |
| | | #endif |
| | | |
| | | #include <iostream> |
| | | #include <fstream> |
| | | #include <sstream> |
| | | #include <map> |
| | | #include <glog/logging.h> |
| | | #include "funasrruntime.h" |
| | | #include "tclap/CmdLine.h" |
| | | #include "com-define.h" |
| | | |
| | | using namespace std; |
| | | |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path) |
| | | { |
| | | if (value_arg.isSet()){ |
| | | model_path.insert({key, value_arg.getValue()}); |
| | | LOG(INFO)<< key << " : " << value_arg.getValue(); |
| | | } |
| | | } |
| | | |
// Split `org_string` on delimiter `seq`, appending the non-empty pieces to
// `strings`. Consecutive delimiters are collapsed (empty fields are skipped).
void splitString(vector<string>& strings, const string& org_string, const string& seq) {
    if (seq.empty()) {
        // Guard: an empty delimiter would loop forever; treat the whole
        // input as a single token instead.
        if (!org_string.empty()) {
            strings.push_back(org_string);
        }
        return;
    }
    string::size_type p1 = 0;
    string::size_type p2 = org_string.find(seq);

    while (p2 != string::npos) {
        if (p2 == p1) {
            // Delimiter at the current position: skip the WHOLE delimiter.
            // BUG FIX: the original advanced by one character (++p1), which
            // left fragments of multi-character delimiters in the output
            // (e.g. splitting "--x" on "--" produced "-x").
            p1 += seq.size();
            p2 = org_string.find(seq, p1);
            continue;
        }
        strings.push_back(org_string.substr(p1, p2 - p1));
        p1 = p2 + seq.size();
        p2 = org_string.find(seq, p1);
    }

    if (p1 < org_string.size()) {
        strings.push_back(org_string.substr(p1));
    }
}
| | | |
/*
 * Demo: streaming (online) punctuation restoration with a CT-Transformer.
 * Reads sentences from --txt-path (one per line). A '|' inside a line marks
 * a VAD segment boundary; segments are fed to the model one by one, sharing
 * one result handle so the sentence cache carries across segments.
 */
int main(int argc, char *argv[])
{
    google::InitGoogleLogging(argv[0]);
    FLAGS_logtostderr = true;

    // Command-line definition: model dir and txt path are required.
    TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0");
    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");

    cmd.add(model_dir);
    cmd.add(quantize);
    cmd.add(txt_path);
    cmd.parse(argc, argv);

    // Collect only the options that were actually set into a single map.
    std::map<std::string, std::string> model_path;
    GetValue(model_dir, MODEL_DIR, model_path);
    GetValue(quantize, QUANTIZE, model_path);
    GetValue(txt_path, TXT_PATH, model_path);

    // Initialize the ONLINE punctuation model, timing the load.
    struct timeval start, end;
    gettimeofday(&start, NULL);
    int thread_num = 1;
    FUNASR_HANDLE punc_hanlde=CTTransformerInit(model_path, thread_num, PUNC_ONLINE);

    if (!punc_hanlde)
    {
        LOG(ERROR) << "FunASR init failed";
        exit(-1);
    }

    gettimeofday(&end, NULL);
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";

    // read txt_path: one input sentence per line
    vector<string> txt_list;

    if(model_path.find(TXT_PATH)!=model_path.end()){
        ifstream in(model_path.at(TXT_PATH));
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << model_path.at(TXT_PATH) ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            txt_list.emplace_back(line);
        }
        in.close();
    }

    // Accumulated pure-inference time across all lines (model load excluded).
    long taking_micros = 0;
    for(auto& txt_str : txt_list){
        // Split the line into VAD segments and stream them through the model,
        // reusing `result` so the unfinished-sentence cache rolls forward.
        vector<string> vad_strs;
        splitString(vad_strs, txt_str, "|");
        string str_out;
        FUNASR_RESULT result = nullptr;
        gettimeofday(&start, NULL);
        for(auto& vad_str:vad_strs){
            result=CTTransformerInfer(punc_hanlde, vad_str.c_str(), RASR_NONE, NULL, PUNC_ONLINE, result);
            if(result){
                string msg = CTTransformerGetResult(result, 0);
                str_out += msg;
                LOG(INFO)<<"Online result: "<<msg;
            }
        }
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
        LOG(INFO)<<"Results: "<<str_out;
        // One free per line: the same handle was reused for every segment.
        CTTransformerFreeResult(result);
    }

    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
    CTTransformerUninit(punc_hanlde);
    return 0;
}
| | | |
| | |
| | | |
| | | #define CANDIDATE_NUM 6 |
| | | #define UNKNOW_INDEX 0 |
| | | #define NOTPUNC "_" |
| | | #define NOTPUNC_INDEX 1 |
| | | #define COMMA_INDEX 2 |
| | | #define PERIOD_INDEX 3 |
| | |
| | | FUNASR_MODEL_PARAFORMER = 3, |
| | | }FUNASR_MODEL_TYPE; |
| | | |
// Selects which punctuation pipeline a handle runs.
typedef enum {
    PUNC_OFFLINE=0, // whole-utterance (batch) punctuation
    PUNC_ONLINE=1,  // streaming punctuation with a cross-call sentence cache
}PUNC_TYPE;
| | | |
| | | typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step. |
| | | |
| | | // ASR |
| | |
| | | _FUNASRAPI const float FsmnVadGetRetSnippetTime(FUNASR_RESULT result); |
| | | |
| | | // PUNC |
| | | _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num); |
| | | _FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback); |
| | | _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE); |
| | | _FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type=PUNC_OFFLINE, FUNASR_RESULT pre_result=nullptr); |
| | | _FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index); |
| | | _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result); |
| | | _FUNASRAPI void CTTransformerUninit(FUNASR_HANDLE handle); |
| | | |
| | | //OfflineStream |
| | |
| | | #include <string> |
| | | #include <map> |
| | | #include <vector> |
| | | #include "funasrruntime.h" |
| | | |
| | | namespace funasr { |
| | | class PuncModel { |
| | | public: |
| | | virtual ~PuncModel(){}; |
| | | virtual void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num)=0; |
| | | virtual std::vector<int> Infer(std::vector<int32_t> input_data)=0; |
| | | virtual std::string AddPunc(const char* sz_input)=0; |
| | | virtual std::string AddPunc(const char* sz_input){return "";}; |
| | | virtual std::string AddPunc(const char* sz_input, std::vector<std::string>& arr_cache){return "";}; |
| | | }; |
| | | |
| | | PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num); |
| | | PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type=PUNC_OFFLINE); |
| | | } // namespace funasr |
| | | #endif |
| | |
| | | float snippet_time; |
| | | }FUNASR_VAD_RESULT; |
| | | |
// Result payload returned by CTTransformerInfer.
typedef struct
{
    string msg;               // punctuated text produced by the latest call
    vector<string> arr_cache; // tail words of the last unfinished sentence, fed back on the next online call
}FUNASR_PUNC_RESULT;
| | | |
| | | #ifdef _WIN32 |
| | | #include <codecvt> |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | |
| | | #include "precomp.h" |
| | | |
| | | namespace funasr { |
// Construct with an error-level-only ORT logging environment and
// default session options; the model itself is loaded later by InitPunc().
CTTransformerOnline::CTTransformerOnline()
:env_(ORT_LOGGING_LEVEL_ERROR, ""),session_options{}
{
}
| | | |
// Load the streaming punctuation ONNX model and its tokenizer config.
// punc_model: path to the .onnx file; punc_config: path to the yaml config;
// thread_num: ORT intra-op thread count. Exits the process on load failure.
void CTTransformerOnline::InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num){
    session_options.SetIntraOpNumThreads(thread_num);
    session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    session_options.DisableCpuMemArena();

    try{
        m_session = std::make_unique<Ort::Session>(env_, punc_model.c_str(), session_options);
        LOG(INFO) << "Successfully load model from " << punc_model;
    }
    catch (std::exception const &e) {
        LOG(ERROR) << "Error when load punc onnx model: " << e.what();
        exit(0);
    }
    // read inputnames outputnames — this online model has 4 inputs and 1 output
    string strName;
    GetInputName(m_session.get(), strName);
    m_strInputNames.push_back(strName.c_str());
    GetInputName(m_session.get(), strName, 1);
    m_strInputNames.push_back(strName);
    GetInputName(m_session.get(), strName, 2);
    m_strInputNames.push_back(strName);
    GetInputName(m_session.get(), strName, 3);
    m_strInputNames.push_back(strName);

    GetOutputName(m_session.get(), strName);
    m_strOutputNames.push_back(strName);

    // Build the const char* views ORT's Run() expects. Done only after the
    // string vectors stop growing, so the c_str() pointers stay valid.
    for (auto& item : m_strInputNames)
        m_szInputNames.push_back(item.c_str());
    for (auto& item : m_strOutputNames)
        m_szOutputNames.push_back(item.c_str());

    m_tokenizer.OpenYaml(punc_config.c_str());
}
| | | |
| | | CTTransformerOnline::~CTTransformerOnline() |
| | | { |
| | | } |
| | | |
// Streaming punctuation for one text chunk.
// sz_input:  newly recognized (unpunctuated) text.
// arr_cache: words of the last unfinished sentence from previous calls; it is
//            prepended to the input and UPDATED IN PLACE before returning.
// Returns the punctuated text for the new part only (cached words are skipped
// on output so callers can simply concatenate successive return values).
string CTTransformerOnline::AddPunc(const char* sz_input, vector<string> &arr_cache)
{
    string strResult;
    vector<string> strOut;
    vector<int> InputData;
    string strText; //full_text
    strText = accumulate(arr_cache.begin(), arr_cache.end(), strText);
    strText += sz_input; // full_text = precache + text
    m_tokenizer.Tokenize(strText.c_str(), strOut, InputData);

    // Process the token stream in mini-batches of TOKEN_LEN tokens.
    int nTotalBatch = ceil((float)InputData.size() / TOKEN_LEN);
    int nCurBatch = -1;
    int nSentEnd = -1, nLastCommaIndex = -1;
    vector<int32_t> RemainIDs; // token ids carried into the next mini-batch
    vector<string> RemainStr;  // tokens carried into the next mini-batch
    vector<int> new_mini_sentence_punc; // punctuation ids accumulated over all batches
    vector<string> sentenceOut; // punctuated output pieces
    vector<string> sentence_punc_list,sentence_words_list,sentence_punc_list_out; // per-token punctuation / words
    // Number of leading tokens to skip on output (they belong to arr_cache).
    int nSkipNum = 0;
    int nDiff = 0;
    for (size_t i = 0; i < InputData.size(); i += TOKEN_LEN)
    {
        // nDiff shortens the last slice so it doesn't run past the end.
        nDiff = (i + TOKEN_LEN) < InputData.size() ? (0) : (i + TOKEN_LEN - InputData.size());
        vector<int32_t> InputIDs(InputData.begin() + i, InputData.begin() + i + TOKEN_LEN - nDiff);
        vector<string> InputStr(strOut.begin() + i, strOut.begin() + i + TOKEN_LEN - nDiff);
        InputIDs.insert(InputIDs.begin(), RemainIDs.begin(), RemainIDs.end()); // RemainIDs+InputIDs;
        InputStr.insert(InputStr.begin(), RemainStr.begin(), RemainStr.end()); // RemainStr+InputStr;

        auto Punction = Infer(InputIDs, arr_cache.size());
        nCurBatch = i / TOKEN_LEN;
        if (nCurBatch < nTotalBatch - 1) // not the last mini-sentence
        {
            // Find the last sentence-final punctuation (period/question) so
            // the unfinished tail can be re-fed with the next batch.
            nSentEnd = -1;
            nLastCommaIndex = -1;
            for (int nIndex = Punction.size() - 2; nIndex > 0; nIndex--)
            {
                if (m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(PERIOD_INDEX) || m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(QUESTION_INDEX))
                {
                    nSentEnd = nIndex;
                    break;
                }
                if (nLastCommaIndex < 0 && m_tokenizer.Id2Punc(Punction[nIndex]) == m_tokenizer.Id2Punc(COMMA_INDEX))
                {
                    nLastCommaIndex = nIndex;
                }
            }
            // If no sentence end was found and the carry-over would grow too
            // large, promote the last comma to a period to force a split.
            if (nSentEnd < 0 && InputStr.size() > CACHE_POP_TRIGGER_LIMIT && nLastCommaIndex > 0)
            {
                nSentEnd = nLastCommaIndex;
                Punction[nSentEnd] = PERIOD_INDEX;
            }
            RemainStr.assign(InputStr.begin() + nSentEnd + 1, InputStr.end());
            RemainIDs.assign(InputIDs.begin() + nSentEnd + 1, InputIDs.end());
            InputStr.assign(InputStr.begin(), InputStr.begin() + nSentEnd + 1); // mini_sentence
            Punction.assign(Punction.begin(), Punction.begin() + nSentEnd + 1);
        }

        for (auto& item : Punction)
        {
            sentence_punc_list.push_back(m_tokenizer.Id2Punc(item));
        }

        sentence_words_list.insert(sentence_words_list.end(), InputStr.begin(), InputStr.end());

        new_mini_sentence_punc.insert(new_mini_sentence_punc.end(), Punction.begin(), Punction.end());
    }
    vector<string> WordWithPunc;
    for (int i = 0; i < sentence_words_list.size(); i++) // for i in range(0, len(sentence_words_list)):
    {
        // Insert a space between two adjacent ASCII (non-multibyte) words;
        // the 0x80 test checks the UTF-8 lead byte.
        if (i > 0 && !(sentence_words_list[i][0] & 0x80) && (i + 1) < sentence_words_list.size() && !(sentence_words_list[i + 1][0] & 0x80))
        {
            sentence_words_list[i] = sentence_words_list[i] + " ";
        }
        if (nSkipNum < arr_cache.size()) // if skip_num < len(cache): skip cached words on output
            nSkipNum++;
        else
            WordWithPunc.push_back(sentence_words_list[i]);

        if (nSkipNum >= arr_cache.size())
        {
            sentence_punc_list_out.push_back(sentence_punc_list[i]);
            if (sentence_punc_list[i] != NOTPUNC)
            {
                WordWithPunc.push_back(sentence_punc_list[i]);
            }
        }
    }

    sentenceOut.insert(sentenceOut.end(), WordWithPunc.begin(), WordWithPunc.end());
    // Recompute the last sentence end over the full output and keep everything
    // after it as the cache for the next call.
    nSentEnd = -1;
    for (int i = sentence_punc_list.size() - 2; i > 0; i--)
    {
        if (new_mini_sentence_punc[i] == PERIOD_INDEX || new_mini_sentence_punc[i] == QUESTION_INDEX)
        {
            nSentEnd = i;
            break;
        }
    }
    arr_cache.assign(sentence_words_list.begin() + nSentEnd + 1, sentence_words_list.end());

    // Never emit a trailing punctuation mark mid-stream: drop it and mark the
    // position as "no punctuation" so it can be re-decided next call.
    if (sentenceOut.size() > 0 && m_tokenizer.IsPunc(sentenceOut[sentenceOut.size() - 1]))
    {
        sentenceOut.assign(sentenceOut.begin(), sentenceOut.end() - 1);
        sentence_punc_list_out[sentence_punc_list_out.size() - 1] = m_tokenizer.Id2Punc(NOTPUNC_INDEX);
    }
    return accumulate(sentenceOut.begin(), sentenceOut.end(), string(""));
}
| | | |
// Run one forward pass of the online punctuation model.
// input_data: token ids for (cache + new) text; nCacheSize: number of leading
// cached tokens, used to build the VAD attention mask.
// Returns one predicted punctuation index per input token.
vector<int> CTTransformerOnline::Infer(vector<int32_t> input_data, int nCacheSize)
{
    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    vector<int> punction;
    // Input 0: token ids, shape [1, T].
    std::array<int64_t, 2> input_shape_{ 1, (int64_t)input_data.size()};
    Ort::Value onnx_input = Ort::Value::CreateTensor(
        m_memoryInfo,
        input_data.data(),
        input_data.size() * sizeof(int32_t),
        input_shape_.data(),
        input_shape_.size(), ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);

    // Input 1: sequence length, shape [1].
    std::array<int32_t,1> text_lengths{ (int32_t)input_data.size() };
    std::array<int64_t,1> text_lengths_dim{ 1 };
    Ort::Value onnx_text_lengths = Ort::Value::CreateTensor<int32_t>(
        m_memoryInfo,
        text_lengths.data(),
        text_lengths.size(),
        text_lengths_dim.data(),
        text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);

    //vad_mask
    vector<float> arVadMask,arSubMask;
    int nTextLength = input_data.size();

    VadMask(nTextLength, nCacheSize, arVadMask);
    Triangle(nTextLength, arSubMask);
    // Input 2: VAD mask, shape [1, 1, T, T].
    std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
    Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
        m_memoryInfo,
        arVadMask.data(),
        arVadMask.size(), // * sizeof(float),
        VadMask_Dim.data(),
        VadMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
    //sub_masks
    // Input 3: lower-triangular (causal) mask, shape [1, 1, T, T].
    std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
    Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
        m_memoryInfo,
        arSubMask.data(),
        arSubMask.size() ,
        SubMask_Dim.data(),
        SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);

    std::vector<Ort::Value> input_onnx;
    input_onnx.emplace_back(std::move(onnx_input));
    input_onnx.emplace_back(std::move(onnx_text_lengths));
    input_onnx.emplace_back(std::move(onnx_vad_mask));
    input_onnx.emplace_back(std::move(onnx_sub_mask));

    try {
        auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), m_szInputNames.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();

        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float * floatData = outputTensor[0].GetTensorMutableData<float>();

        // Argmax over each token's CANDIDATE_NUM punctuation logits.
        // NOTE(review): the end pointer is i + CANDIDATE_NUM - 1; if Argmax
        // treats the end as exclusive this skips the last candidate class —
        // confirm against Argmax's convention.
        for (int i = 0; i < outputCount; i += CANDIDATE_NUM)
        {
            int index = Argmax(floatData + i, floatData + i + CANDIDATE_NUM-1);
            punction.push_back(index);
        }
    }
    catch (std::exception const &e)
    {
        LOG(ERROR) << "Error when run punc onnx forword: " << (e.what());
        exit(0);
    }
    return punction;
}
| | | |
// Build an nSize x nSize attention mask (row-major; 1 = attend, 0 = blocked):
// rows before the cache boundary are blocked from attending to columns at or
// beyond `vad_pos`. Leaves the mask all-ones when vad_pos is out of range.
void CTTransformerOnline::VadMask(int nSize, int vad_pos, vector<float>& Result)
{
    Result.resize(0);
    Result.assign(nSize * nSize, 1);
    if (vad_pos <= 0 || vad_pos >= nSize)
    {
        return;
    }
    // NOTE(review): the bound is vad_pos-1, so row vad_pos-1 itself stays
    // unmasked — confirm this off-by-one against the reference implementation.
    for (int i = 0; i < vad_pos-1; i++)
    {
        for (int j = vad_pos; j < nSize; j++)
        {
            Result[i * nSize + j] = 0.0f;
        }
    }
}
| | | |
| | | void CTTransformerOnline::Triangle(int text_length, vector<float>& Result) |
| | | { |
| | | Result.resize(0); |
| | | Result.assign(text_length * text_length,1); // generate a zeros: text_length x text_length |
| | | |
| | | for (int i = 0; i < text_length; i++) // rows |
| | | { |
| | | for (int j = i+1; j<text_length; j++) //cols |
| | | { |
| | | Result[i * text_length + j] = 0.0f; |
| | | } |
| | | |
| | | } |
| | | //Transport(Result, text_length, text_length); |
| | | } |
| | | |
| | | void CTTransformerOnline::Transport(vector<float>& In,int nRows, int nCols) |
| | | { |
| | | vector<float> Out; |
| | | Out.resize(nRows * nCols); |
| | | int i = 0; |
| | | for (int j = 0; j < nCols; j++) { |
| | | for (; i < nRows * nCols; i++) { |
| | | Out[i] = In[j + nCols * (i % nRows)]; |
| | | if ((i + 1) % nRows == 0) { |
| | | i++; |
| | | break; |
| | | } |
| | | } |
| | | } |
| | | In = Out; |
| | | } |
| | | |
| | | } // namespace funasr |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | |
| | | #pragma once |
| | | |
namespace funasr {
// Streaming punctuation-restoration model backed by ONNX Runtime.
class CTTransformerOnline : public PuncModel {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
 * https://arxiv.org/pdf/2003.01309.pdf
 */

  private:

    CTokenizer m_tokenizer; // maps text <-> token ids and punctuation ids
    vector<string> m_strInputNames, m_strOutputNames; // owned copies of the ONNX tensor names
    vector<const char*> m_szInputNames;  // raw views into m_strInputNames, as required by Ort::Session::Run
    vector<const char*> m_szOutputNames; // raw views into m_strOutputNames

    std::shared_ptr<Ort::Session> m_session; // ONNX Runtime inference session
    Ort::Env env_;
    Ort::SessionOptions session_options;
  public:

    CTTransformerOnline();
    // Load the model (.onnx) and tokenizer config; exits on failure.
    void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
    ~CTTransformerOnline();
    // Forward pass: punctuation index per token; nCacheSize = cached prefix length.
    vector<int> Infer(vector<int32_t> input_data, int nCacheSize);
    // Punctuate a chunk; arr_cache carries the unfinished sentence across calls.
    string AddPunc(const char* sz_input, vector<string> &arr_cache);
    // Transpose a row-major nRows x nCols matrix in place.
    void Transport(vector<float>& In, int nRows, int nCols);
    // Attention mask separating cached context from new text.
    void VadMask(int size, int vad_pos,vector<float>& Result);
    // Lower-triangular (causal) mask.
    void Triangle(int text_length, vector<float>& Result);
};
} // namespace funasr
| | |
| | | return mm; |
| | | } |
| | | |
| | | _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num) |
| | | _FUNASRAPI FUNASR_HANDLE CTTransformerInit(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type) |
| | | { |
| | | funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num); |
| | | funasr::PuncModel* mm = funasr::CreatePuncModel(model_path, thread_num, type); |
| | | return mm; |
| | | } |
| | | |
| | |
| | | } |
| | | |
| | | // APIs for PUNC Infer |
| | | _FUNASRAPI const std::string CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback) |
| | | _FUNASRAPI FUNASR_RESULT CTTransformerInfer(FUNASR_HANDLE handle, const char* sz_sentence, FUNASR_MODE mode, QM_CALLBACK fn_callback, PUNC_TYPE type, FUNASR_RESULT pre_result) |
| | | { |
| | | funasr::PuncModel* punc_obj = (funasr::PuncModel*)handle; |
| | | if (!punc_obj) |
| | | return nullptr; |
| | | |
| | | FUNASR_RESULT p_result = nullptr; |
| | | if (type==PUNC_OFFLINE){ |
| | | p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT; |
| | | ((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence); |
| | | }else if(type==PUNC_ONLINE){ |
| | | if (!pre_result) |
| | | p_result = (FUNASR_RESULT)new funasr::FUNASR_PUNC_RESULT; |
| | | else |
| | | p_result = pre_result; |
| | | ((funasr::FUNASR_PUNC_RESULT*)p_result)->msg = punc_obj->AddPunc(sz_sentence, ((funasr::FUNASR_PUNC_RESULT*)p_result)->arr_cache); |
| | | }else{ |
| | | LOG(ERROR) << "Wrong PUNC_TYPE"; |
| | | exit(-1); |
| | | } |
| | | |
| | | string punc_res = punc_obj->AddPunc(sz_sentence); |
| | | return punc_res; |
| | | return p_result; |
| | | } |
| | | |
| | | // APIs for Offline-stream Infer |
| | |
| | | return p_result->msg.c_str(); |
| | | } |
| | | |
| | | _FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index) |
| | | { |
| | | funasr::FUNASR_PUNC_RESULT * p_result = (funasr::FUNASR_PUNC_RESULT*)result; |
| | | if(!p_result) |
| | | return nullptr; |
| | | |
| | | return p_result->msg.c_str(); |
| | | } |
| | | |
| | | _FUNASRAPI vector<std::vector<int>>* FsmnVadGetResult(FUNASR_RESULT result,int n_index) |
| | | { |
| | | funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result; |
| | |
| | | } |
| | | } |
| | | |
| | | _FUNASRAPI void CTTransformerFreeResult(FUNASR_RESULT result) |
| | | { |
| | | if (result) |
| | | { |
| | | delete (funasr::FUNASR_PUNC_RESULT*)result; |
| | | } |
| | | } |
| | | |
| | | _FUNASRAPI void FsmnVadFreeResult(FUNASR_RESULT result) |
| | | { |
| | | funasr::FUNASR_VAD_RESULT * p_result = (funasr::FUNASR_VAD_RESULT*)result; |
| | |
| | | #include "offline-stream.h" |
| | | #include "tokenizer.h" |
| | | #include "ct-transformer.h" |
| | | #include "ct-transformer-online.h" |
| | | #include "e2e-vad.h" |
| | | #include "fsmn-vad.h" |
| | | #include "fsmn-vad-online.h" |
| | |
| | | #include "precomp.h" |
| | | |
| | | namespace funasr { |
| | | PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num) |
| | | PuncModel *CreatePuncModel(std::map<std::string, std::string>& model_path, int thread_num, PUNC_TYPE type) |
| | | { |
| | | PuncModel *mm; |
| | | mm = new CTTransformer(); |
| | | |
| | | if (type==PUNC_OFFLINE){ |
| | | mm = new CTTransformer(); |
| | | }else if(type==PUNC_ONLINE){ |
| | | mm = new CTTransformerOnline(); |
| | | }else{ |
| | | LOG(ERROR) << "Wrong PUNC TYPE"; |
| | | exit(-1); |
| | | } |
| | | string punc_model_path; |
| | | string punc_config_path; |
| | | |
| | |
| | | return result; |
| | | } |
| | | |
| | | bool CTokenizer::IsPunc(string& Punc) |
| | | { |
| | | if (m_punc2id.find(Punc) != m_punc2id.end()) |
| | | return true; |
| | | else |
| | | return false; |
| | | } |
| | | |
| | | vector<string> CTokenizer::SplitChineseString(const string & str_info) |
| | | { |
| | | vector<string> list; |
| | |
| | | vector<string> SplitChineseString(const string& str_info); |
| | | void StrSplit(const string& str, const char split, vector<string>& res); |
| | | void Tokenize(const char* str_info, vector<string>& str_out, vector<int>& id_out); |
| | | |
| | | bool IsPunc(string& Punc); |
| | | }; |
| | | |
| | | } // namespace funasr |
| | |
| | | std::string s_asr_quant = model_path[QUANTIZE]; |
| | | std::string s_punc_path = model_path[PUNC_DIR]; |
| | | std::string s_punc_quant = model_path[PUNC_QUANT]; |
| | | std::string python_cmd = "python -m funasr.export.export_model --type onnx --quantize True "; |
| | | std::string python_cmd = "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True "; |
| | | if(vad_dir.isSet() && !s_vad_path.empty()){ |
| | | std::string python_cmd_vad = python_cmd + " --model-name " + s_vad_path + " --export-dir " + s_download_model_dir; |
| | | if(is_download){ |
| | |
| | |
|
| | | ```shell
|
| | | # pip3 install torch torchaudio
|
pip3 install -U modelscope funasr
|
| | | # For the users in China, you could install with the command:
|
| | | # pip install -U modelscope funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
|
| | | # pip3 install -U modelscope funasr -f https://modelscope.oss-cn-beijing.aliyuncs.com/releases/repo.html -i https://mirror.sjtu.edu.cn/pypi/web/simple
|
| | | ```
|
| | |
|
| | | ### Export [onnx model](https://github.com/alibaba-damo-academy/FunASR/tree/main/funasr/export)
|
| | |
|
| | | ```shell
|
python3 -m funasr.export.export_model --model-name damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch --export-dir ./export --type onnx --quantize True
|
| | | ```
|
| | |
|
| | | ## Building for Linux/Unix
|
| | |
| | | required openssl lib
|
| | |
|
| | | ```shell
|
| | | #install openssl lib for ubuntu |
| | | apt-get install libssl-dev
|
| | | #install openssl lib for centos
|
| | | yum install openssl-devel
|
| | | apt-get install libssl-dev #ubuntu |
| | | # yum install openssl-devel #centos
|
| | |
|
| | |
|
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/websocket
|
| | | mkdir build && cd build
|
| | | cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
|
| | | make
|