cdevelop
2023-11-15 eff2570faf3dae7908db87edf4ef1a6ea88e5b33
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
 
#pragma once 
 
namespace funasr {
class CTTransformer : public PuncModel {
/**
 * Author: Speech Lab of DAMO Academy, Alibaba Group
 * CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection
 * https://arxiv.org/pdf/2003.01309.pdf
*/
 
private:
 
    CTokenizer m_tokenizer;
    vector<string> m_strInputNames, m_strOutputNames;
    vector<const char*> m_szInputNames;
    vector<const char*> m_szOutputNames;
 
    std::shared_ptr<Ort::Session> m_session;
    Ort::Env env_;
    Ort::SessionOptions session_options;
public:
 
    CTTransformer();
    void InitPunc(const std::string &punc_model, const std::string &punc_config, int thread_num);
    ~CTTransformer();
    vector<int>  Infer(vector<int32_t> input_data);
    string AddPunc(const char* sz_input, std::string language="zh-cn");
};
} // namespace funasr