From 70645e48072bf193fbf069949f1d2b10fddac8a3 Mon Sep 17 00:00:00 2001
From: pointerhacker <145901472+pointerhacker@users.noreply.github.com>
Date: 星期二, 15 十月 2024 17:50:51 +0800
Subject: [PATCH] 数据并行可能导致的模型训练报错 (#2139)

---
 runtime/onnxruntime/include/model.h |    8 +++++---
 1 files changed, 5 insertions(+), 3 deletions(-)

diff --git a/runtime/onnxruntime/include/model.h b/runtime/onnxruntime/include/model.h
index b8eb491..a49baeb 100644
--- a/runtime/onnxruntime/include/model.h
+++ b/runtime/onnxruntime/include/model.h
@@ -16,14 +16,16 @@
     virtual void StartUtterance() = 0;
     virtual void EndUtterance() = 0;
     virtual void Reset() = 0;
-    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
-    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
-    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
+    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, const std::string &token_file, int thread_num){};
     virtual void InitLm(const std::string &lm_file, const std::string &lm_config, const std::string &lex_file){};
     virtual void InitFstDecoder(){};
     virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr){return "";};
     virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}, void* wfst_decoder=nullptr, int batch_in=1)
       {return std::vector<string>();};
+    virtual std::vector<std::string> Forward(float** din, int* len, bool input_finished, std::string svs_lang="auto", bool svs_itn=false, int batch_in=1)
+      {return std::vector<string>();};
     virtual std::string Rescoring() = 0;
     virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
     virtual void InitSegDict(const std::string &seg_dict_model){};

--
Gitblit v1.9.1