From dfcc5d47587d3e793cbfec2e9509c0e9a9e1732c Mon Sep 17 00:00:00 2001
From: Dogvane Huang <dogvane@gmail.com>
Date: Tue, 02 Jul 2024 12:24:13 +0800
Subject: [PATCH] fix c# demo project to new onnx model files (#1689)

---
 runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs |   83 +++++++++++++++++++++++++++++++++--------
 1 file changed, 66 insertions(+), 17 deletions(-)

diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
index f1f491a..c2d7f68 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
@@ -9,9 +9,20 @@
 using Microsoft.ML.OnnxRuntime.Tensors;
 using Microsoft.Extensions.Logging;
 using System.Text.RegularExpressions;
+using Newtonsoft.Json.Linq;
 
+// 模型文件地址： https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
 namespace AliParaformerAsr
 {
+    public enum OnnxRumtimeTypes
+    {
+        CPU = 0,
+
+        DML = 1,
+
+        CUDA = 2,
+    }
+
     /// <summary>
     /// offline recognizer package
     /// Copyright (c)  2023 by manyeyes
@@ -24,35 +35,70 @@
         private string _frontend;
         private FrontendConfEntity _frontendConfEntity;
         private string[] _tokens;
-        private int _batchSize = 1;
 
-        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,string tokensFilePath, int batchSize = 1,int threadsNum=1)
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="modelFilePath"></param>
+        /// <param name="configFilePath"></param>
+        /// <param name="mvnFilePath"></param>
+        /// <param name="tokensFilePath"></param>
+        /// <param name="rumtimeType">鍙互閫夋嫨gpu锛屼絾鏄洰鍓嶆儏鍐典笅锛屼笉寤鸿浣跨敤锛屽洜涓烘�ц兘鎻愬崌鏈夐檺</param>
+        /// <param name="deviceId">设备id，多显卡时用于指定执行的显卡</param>
+        /// <param name="batchSize"></param>
+        /// <param name="threadsNum"></param>
+        /// <exception cref="ArgumentException"></exception>
+        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
         {
-            Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions();
+            var options = new SessionOptions();
+            switch(rumtimeType)
+            {
+                case OnnxRumtimeTypes.DML:
+                    options.AppendExecutionProvider_DML(deviceId);
+                    break;
+                case OnnxRumtimeTypes.CUDA:
+                    options.AppendExecutionProvider_CUDA(deviceId);
+                    break;
+                default:
+                    options.AppendExecutionProvider_CPU(deviceId);
+                    break;
+            }
             //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
-            //options.AppendExecutionProvider_DML(0);
-            options.AppendExecutionProvider_CPU(0);
-            //options.AppendExecutionProvider_CUDA(0);
-            options.InterOpNumThreads = threadsNum;
-            _onnxSession = new InferenceSession(modelFilePath, options);
 
-            _tokens = File.ReadAllLines(tokensFilePath);
+            _onnxSession = new InferenceSession(modelFilePath, options);
+            
+            string[] tokenLines;
+            if (tokensFilePath.EndsWith(".txt"))
+            {
+                tokenLines = File.ReadAllLines(tokensFilePath);
+            }
+            else if (tokensFilePath.EndsWith(".json"))
+            {
+                string jsonContent = File.ReadAllText(tokensFilePath);
+                JArray tokenArray = JArray.Parse(jsonContent);
+                tokenLines = tokenArray.Select(t => t.ToString()).ToArray();
+            }
+            else
+            {
+                throw new ArgumentException("Invalid tokens file format. Only .txt and .json are supported.");
+            }
+
+            _tokens = tokenLines;
 
             OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
             _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
             _frontend = offlineYamlEntity.frontend;
             _frontendConfEntity = offlineYamlEntity.frontend_conf;
-            _batchSize = batchSize;
             ILoggerFactory loggerFactory = new LoggerFactory();
             _logger = new Logger<OfflineRecognizer>(loggerFactory);
         }
 
         public List<string> GetResults(List<float[]> samples)
         {
-            this._logger.LogInformation("get features begin");
+            _logger.LogInformation("get features begin");
             List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
-            OfflineOutputEntity modelOutput = this.Forward(offlineInputEntities);
-            List<string> text_results = this.DecodeMulti(modelOutput.Token_nums);
+            OfflineOutputEntity modelOutput = Forward(offlineInputEntities);
+            List<string> text_results = DecodeMulti(modelOutput.Token_nums);
             return text_results;
         }
 
@@ -156,15 +202,18 @@
                     {
                         break;
                     }
-                    if (_tokens[token] != "</s>" && _tokens[token] != "<s>" && _tokens[token] != "<blank>")
+
+                    string tokenChar = _tokens[token];
+
+                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>")
                     {                        
-                        if (IsChinese(_tokens[token],true))
+                        if (IsChinese(tokenChar, true))
                         {
-                            text_result += _tokens[token];
+                            text_result += tokenChar;
                         }
                         else
                         {
-                            text_result += "▁" + _tokens[token]+ "▁";
+                            text_result += "▁" + tokenChar + "▁";
                         }
                     }
                 }

--
Gitblit v1.9.1