andyweiqiu
2023-09-21 b8909ab5b94e293605a9e04f42874e0320773a49
ITN is not supported on iOS for the time being (#969)

10个文件已修改
56 ■■■■ 已修改文件
funasr/runtime/ios/paraformer_online/paraformer_online.xcodeproj/project.pbxproj 16 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/funasrruntime.h 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/model.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/offline-stream.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/tpass-stream.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasrruntime.cpp 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/offline-stream.cpp 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/tpass-stream.cpp 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/ios/paraformer_online/paraformer_online.xcodeproj/project.pbxproj
@@ -90,6 +90,8 @@
        1A7F0DBE2A2F221C00A6EEB7 /* AudioCapture.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1A7F0DBB2A2F221C00A6EEB7 /* AudioCapture.mm */; };
        1A7F0DBF2A2F221C00A6EEB7 /* AudioRecorder.m in Sources */ = {isa = PBXBuildFile; fileRef = 1A7F0DBD2A2F221C00A6EEB7 /* AudioRecorder.m */; };
        1A7F0DC32A2F312D00A6EEB7 /* model in Resources */ = {isa = PBXBuildFile; fileRef = 1A7F0DC22A2F312D00A6EEB7 /* model */; };
        1ACBFB692AB99D55002FC7C7 /* seg_dict.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1ACBFB672AB99D55002FC7C7 /* seg_dict.cpp */; };
        1ACBFB6C2AB9A086002FC7C7 /* encode_converter.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 1ACBFB6B2AB9A086002FC7C7 /* encode_converter.cpp */; };
        59C4114F365C8D714BD515FB /* Pods_paraformer_online.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = EA7D0713E60886A787BAA0EA /* Pods_paraformer_online.framework */; };
/* End PBXBuildFile section */
@@ -324,6 +326,10 @@
        1AB8E1EE2AA086F200F4F795 /* model.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = model.h; sourceTree = "<group>"; };
        1AB8E1EF2AA086F200F4F795 /* offline-stream.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "offline-stream.h"; sourceTree = "<group>"; };
        1AB8E1F02AA086F200F4F795 /* vad-model.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = "vad-model.h"; sourceTree = "<group>"; };
        1ACBFB672AB99D55002FC7C7 /* seg_dict.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = seg_dict.cpp; sourceTree = "<group>"; };
        1ACBFB682AB99D55002FC7C7 /* seg_dict.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = seg_dict.h; sourceTree = "<group>"; };
        1ACBFB6A2AB9A086002FC7C7 /* encode_converter.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = encode_converter.h; sourceTree = "<group>"; };
        1ACBFB6B2AB9A086002FC7C7 /* encode_converter.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = encode_converter.cpp; sourceTree = "<group>"; };
        B9ED2A36675364C815C03C96 /* Pods-paraformer_online.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-paraformer_online.debug.xcconfig"; path = "Target Support Files/Pods-paraformer_online/Pods-paraformer_online.debug.xcconfig"; sourceTree = "<group>"; };
        EA7D0713E60886A787BAA0EA /* Pods_paraformer_online.framework */ = {isa = PBXFileReference; explicitFileType = wrapper.framework; includeInIndex = 0; path = Pods_paraformer_online.framework; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
@@ -355,6 +361,8 @@
                1A6C92FB2A84D64E007E36DC /* ct-transformer.cpp */,
                1A6C93032A84D64E007E36DC /* ct-transformer.h */,
                1A6C92F92A84D64E007E36DC /* e2e-vad.h */,
                1ACBFB6B2AB9A086002FC7C7 /* encode_converter.cpp */,
                1ACBFB6A2AB9A086002FC7C7 /* encode_converter.h */,
                1A6C92F72A84D64E007E36DC /* fsmn-vad-online.cpp */,
                1A6C92E92A84D64E007E36DC /* fsmn-vad-online.h */,
                1A6C92E82A84D64E007E36DC /* fsmn-vad.cpp */,
@@ -371,6 +379,8 @@
                1A6C93022A84D64E007E36DC /* punc-model.cpp */,
                1A6C92ED2A84D64E007E36DC /* resample.cpp */,
                1A6C92E32A84D64E007E36DC /* resample.h */,
                1ACBFB672AB99D55002FC7C7 /* seg_dict.cpp */,
                1ACBFB682AB99D55002FC7C7 /* seg_dict.h */,
                1A6C93012A84D64E007E36DC /* tensor.h */,
                1A6C92F02A84D64E007E36DC /* tokenizer.cpp */,
                1A6C92EF2A84D64E007E36DC /* tokenizer.h */,
@@ -917,6 +927,7 @@
                1A6C93F72A84D66E007E36DC /* symbolize.cc in Sources */,
                1A6C93062A84D64E007E36DC /* util.cpp in Sources */,
                1A6C94222A84D66E007E36DC /* nodebuilder.cpp in Sources */,
                1ACBFB692AB99D55002FC7C7 /* seg_dict.cpp in Sources */,
                1A6C94132A84D66E007E36DC /* exp.cpp in Sources */,
                1A6C930A2A84D64E007E36DC /* vocab.cpp in Sources */,
                1A6C94012A84D66E007E36DC /* logging.cc in Sources */,
@@ -931,6 +942,7 @@
                1A6C940F2A84D66E007E36DC /* emitter.cpp in Sources */,
                1A6C93DE2A84D66E007E36DC /* fftsg.c in Sources */,
                1A6C940B2A84D66E007E36DC /* ostream_wrapper.cpp in Sources */,
                1ACBFB6C2AB9A086002FC7C7 /* encode_converter.cpp in Sources */,
                1A6C93E12A84D66E007E36DC /* log.cc in Sources */,
                1A6C94092A84D66E007E36DC /* exceptions.cpp in Sources */,
                1A6C94152A84D66E007E36DC /* node.cpp in Sources */,
@@ -1108,7 +1120,7 @@
                    "@executable_path/Frameworks",
                );
                MARKETING_VERSION = 1.0;
                PRODUCT_BUNDLE_IDENTIFIER = "com.qiuwei.paraformer-online";
                PRODUCT_BUNDLE_IDENTIFIER = "com.qiuwei.paraformer-online1";
                PRODUCT_NAME = "$(TARGET_NAME)";
                SWIFT_EMIT_LOC_STRINGS = YES;
                TARGETED_DEVICE_FAMILY = "1,2";
@@ -1147,7 +1159,7 @@
                    "@executable_path/Frameworks",
                );
                MARKETING_VERSION = 1.0;
                PRODUCT_BUNDLE_IDENTIFIER = "com.qiuwei.paraformer-online";
                PRODUCT_BUNDLE_IDENTIFIER = "com.qiuwei.paraformer-online1";
                PRODUCT_NAME = "$(TARGET_NAME)";
                SWIFT_EMIT_LOC_STRINGS = YES;
                TARGETED_DEVICE_FAMILY = "1,2";
funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -105,7 +105,10 @@
_FUNASRAPI FUNASR_RESULT    FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, 
                                            QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, 
                                            int sampling_rate=16000, bool itn=true);
#if !defined(__APPLE__)
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode=ASR_OFFLINE);
#endif
_FUNASRAPI void                FunOfflineUninit(FUNASR_HANDLE handle);
//2passStream
funasr/runtime/onnxruntime/include/model.h
@@ -17,7 +17,7 @@
    virtual std::string Rescoring() = 0;
    virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){return std::vector<std::vector<float>>();};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
funasr/runtime/onnxruntime/include/offline-stream.h
@@ -7,7 +7,9 @@
#include "model.h"
#include "punc-model.h"
#include "vad-model.h"
#if !defined(__APPLE__)
#include "itn-model.h"
#endif
namespace funasr {
class OfflineStream {
@@ -18,7 +20,9 @@
    std::unique_ptr<VadModel> vad_handle= nullptr;
    std::unique_ptr<Model> asr_handle= nullptr;
    std::unique_ptr<PuncModel> punc_handle= nullptr;
#if !defined(__APPLE__)
    std::unique_ptr<ITNModel> itn_handle = nullptr;
#endif
    bool UseVad(){return use_vad;};
    bool UsePunc(){return use_punc;}; 
    bool UseITN(){return use_itn;};
funasr/runtime/onnxruntime/include/tpass-stream.h
@@ -7,7 +7,9 @@
#include "model.h"
#include "punc-model.h"
#include "vad-model.h"
#if !defined(__APPLE__)
#include "itn-model.h"
#endif
namespace funasr {
class TpassStream {
@@ -18,7 +20,9 @@
    std::unique_ptr<VadModel> vad_handle = nullptr;
    std::unique_ptr<Model> asr_handle = nullptr;
    std::unique_ptr<PuncModel> punc_online_handle = nullptr;
#if !defined(__APPLE__)
    std::unique_ptr<ITNModel> itn_handle = nullptr;
#endif
    bool UseVad(){return use_vad;};
    bool UsePunc(){return use_punc;}; 
    bool UseITN(){return use_itn;};
funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -285,10 +285,12 @@
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
            p_result->msg = punc_res;
        }
#if !defined(__APPLE__)
        if(offline_stream->UseITN() && itn){
            string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
            p_result->msg = msg_itn;
        }
#endif
        return p_result;
    }
@@ -364,13 +366,16 @@
            string punc_res = (offline_stream->punc_handle)->AddPunc((p_result->msg).c_str());
            p_result->msg = punc_res;
        }
#if !defined(__APPLE__)
        if(offline_stream->UseITN() && itn){
            string msg_itn = offline_stream->itn_handle->Normalize(p_result->msg);
            p_result->msg = msg_itn;
        }
#endif
        return p_result;
    }
#if !defined(__APPLE__)
    _FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords, ASR_TYPE mode)
    {
        if (mode == ASR_OFFLINE){
@@ -394,7 +399,7 @@
        }
        
    }
#endif
    // APIs for 2pass-stream Infer
    _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, 
@@ -450,13 +455,13 @@
                    string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
                    string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
                    p_result->tpass_msg = msg_punc;
#if !defined(__APPLE__)
                    // ITN
                    if(tpass_stream->UseITN() && itn){
                        string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
                        p_result->tpass_msg = msg_itn;
                    }
#endif
                    ((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
                    p_result->msg += msg;
                }else{
@@ -501,10 +506,12 @@
                msg_punc += "。";
            }
            p_result->tpass_msg = msg_punc;
#if !defined(__APPLE__)
            if(tpass_stream->UseITN() && itn){
                string msg_itn = tpass_stream->itn_handle->Normalize(msg_punc);
                p_result->tpass_msg = msg_itn;
            }
#endif
            if(frame != NULL){
                delete frame;
funasr/runtime/onnxruntime/src/offline-stream.cpp
@@ -84,7 +84,7 @@
            use_punc = true;
        }
    }
#if !defined(__APPLE__)
    // Optional: ITN, here we just support language_type=MandarinEnglish
    if(model_path.find(ITN_DIR) != model_path.end() && model_path.at(ITN_DIR) != ""){
        string itn_tagger_path = PathAppend(model_path.at(ITN_DIR), ITN_TAGGER_NAME);
@@ -100,6 +100,7 @@
            use_itn = true;
        }
    }
#endif
}
OfflineStream *CreateOfflineStream(std::map<std::string, std::string>& model_path, int thread_num)
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -662,7 +662,7 @@
                return "";
            }
            //PrintMat(hw_emb, "input_clas_emb");
            const int64_t hotword_shape[3] = {1, hw_emb.size(), hw_emb[0].size()};
            const int64_t hotword_shape[3] = {1, static_cast<int64_t>(hw_emb.size()), static_cast<int64_t>(hw_emb[0].size())};
            embedding.reserve(hw_emb.size() * hw_emb[0].size());
            for (auto item : hw_emb) {
                embedding.insert(embedding.end(), item.begin(), item.end());
funasr/runtime/onnxruntime/src/precomp.h
@@ -24,6 +24,8 @@
#else
#include "onnxruntime_run_options_config_keys.h"
#include "onnxruntime_cxx_api.h"
#include "itn-model.h"
#include "itn-processor.h"
#endif
#include "kaldi-native-fbank/csrc/feature-fbank.h"
@@ -38,11 +40,9 @@
#include "model.h"
#include "vad-model.h"
#include "punc-model.h"
#include "itn-model.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "ct-transformer-online.h"
#include "itn-processor.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "encode_converter.h"
funasr/runtime/onnxruntime/src/tpass-stream.cpp
@@ -89,7 +89,7 @@
            use_punc = true;
        }
    }
#if !defined(__APPLE__)
    // Optional: ITN, here we just support language_type=MandarinEnglish
    if(model_path.find(ITN_DIR) != model_path.end()){
        string itn_tagger_path = PathAppend(model_path.at(ITN_DIR), ITN_TAGGER_NAME);
@@ -105,6 +105,7 @@
            use_itn = true;
        }
    }
#endif
      
}
@@ -114,4 +115,4 @@
    mm = new TpassStream(model_path, thread_num);
    return mm;
}
} // namespace funasr
} // namespace funasr