From dfcc5d47587d3e793cbfec2e9509c0e9a9e1732c Mon Sep 17 00:00:00 2001
From: Dogvane Huang <dogvane@gmail.com>
Date: 星期二, 02 七月 2024 12:24:13 +0800
Subject: [PATCH] fix c# demo project to new onnx model files (#1689)

---
 runtime/csharp/AliParaformerAsr/AliParaformerAsr/Utils/YamlHelper.cs                       |   13 +
 runtime/csharp/AliFsmnVad/AliFsmnVadSharp/AliFsmnVad.cs                                    |    6 
 runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj         |    1 
 runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs                      |   83 +++++++++--
 runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/AliParaformerAsr.Examples.csproj |    2 
 runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/Program.cs                              |  140 +++++++++++++------
 runtime/csharp/AliParaformerAsr/AliParaformerAsr/AliParaformerAsr.csproj                   |    4 
 runtime/csharp/AliFsmnVad/AliFsmnVadSharp/Model/VadYamlEntity.cs                           |    2 
 runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/Program.cs                       |  140 +++++++++++++------
 9 files changed, 273 insertions(+), 118 deletions(-)

diff --git a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj
index b494bb5..cdb4122 100644
--- a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj
+++ b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/AliFsmnVadSharp.Examples.csproj
@@ -8,6 +8,7 @@
   </PropertyGroup>
 
   <ItemGroup>
+    <PackageReference Include="CommandLineParser" Version="2.9.1" />
     <PackageReference Include="NAudio" Version="2.1.0" />
   </ItemGroup>
 
diff --git a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/Program.cs b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/Program.cs
index dd3bf78..c3fc52f 100644
--- a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/Program.cs
+++ b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp.Examples/Program.cs
@@ -1,61 +1,107 @@
 锘縰sing AliFsmnVadSharp;
 using AliFsmnVadSharp.Model;
+using CommandLine;
 using NAudio.Wave;
 
 internal static class Program
 {
-	[STAThread]
-	private static void Main()
-	{
-		string applicationBase = AppDomain.CurrentDomain.BaseDirectory;
-		string modelFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/model.onnx";
-		string configFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.yaml";
-		string mvnFilePath = applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/vad.mvn";
-		int batchSize = 2;
-		TimeSpan start_time0 = new TimeSpan(DateTime.Now.Ticks);
-		AliFsmnVad aliFsmnVad = new AliFsmnVad(modelFilePath, configFilePath, mvnFilePath, batchSize);
-		TimeSpan end_time0 = new TimeSpan(DateTime.Now.Ticks);
-		double elapsed_milliseconds0 = end_time0.TotalMilliseconds - start_time0.TotalMilliseconds;
-		Console.WriteLine("load model and init config elapsed_milliseconds:{0}", elapsed_milliseconds0.ToString());
-		List<float[]> samples = new List<float[]>();
-		TimeSpan total_duration = new TimeSpan(0L);
-		for (int i = 0; i < 2; i++)
-		{
-			string wavFilePath = string.Format(applicationBase + "./speech_fsmn_vad_zh-cn-16k-common-pytorch/example/{0}.wav", i.ToString());//vad_example
-			if (!File.Exists(wavFilePath))
-			{
-				continue;
-			}
-			AudioFileReader _audioFileReader = new AudioFileReader(wavFilePath);
-			byte[] datas = new byte[_audioFileReader.Length];
-			_audioFileReader.Read(datas, 0, datas.Length);
-			TimeSpan duration = _audioFileReader.TotalTime;
-			float[] wavdata = new float[datas.Length / 4];
-			Buffer.BlockCopy(datas, 0, wavdata, 0, datas.Length);
-			float[] sample = wavdata.Select((float x) => x * 32768f).ToArray();
-			samples.Add(wavdata);
-			total_duration += duration;			
-		}
-		TimeSpan start_time = new TimeSpan(DateTime.Now.Ticks);
-		//SegmentEntity[] segments_duration = aliFsmnVad.GetSegments(samples);
-		SegmentEntity[] segments_duration = aliFsmnVad.GetSegmentsByStep(samples);
-		TimeSpan end_time = new TimeSpan(DateTime.Now.Ticks);
-		Console.WriteLine("vad infer result:");
-		foreach (SegmentEntity segment in segments_duration)
-		{
-			Console.Write("[");
-			foreach (var x in segment.Segment) 
-			{
-				Console.Write("[" + string.Join(",", x.ToArray()) + "]");
-			}
-			Console.Write("]\r\n");
-		}
+    public class ProgramParams
+    {
+        [Option('i', "input", Required = true, HelpText = "Input wav file/folder path.")]
+        public string WavFilePath { get; set; }
 
-		double elapsed_milliseconds = end_time.TotalMilliseconds - start_time.TotalMilliseconds;
+        [Option('m', "model", Default = "speech_fsmn_vad_zh-cn-16k-common-onnx", HelpText = "Model path.")]
+        public string Model { get; set; }
+    }
+
+    [STAThread]
+	private static void Main(string[] args)
+	{
+        var argParams = Parser.Default.ParseArguments<ProgramParams>(args).Value;
+
+        string modelPath = argParams.Model;
+        if (!Directory.Exists(argParams.Model))
+        {
+            modelPath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, modelPath);
+            if (!Directory.Exists(modelPath))
+            {
+                throw new DirectoryNotFoundException($"Model not found: {argParams.Model}");
+            }
+        }
+
+        string modelFilePath = Path.Combine(modelPath, "model_quant.onnx");
+        string configFilePath = Path.Combine(modelPath, "config.yaml");
+        string mvnFilePath = Path.Combine(modelPath, "am.mvn");
+
+        int batchSize = 1;
+        AliFsmnVad aliFsmnVad = new AliFsmnVad(modelFilePath, configFilePath, mvnFilePath, batchSize);
+
+        List<string> wavFiles = new List<string>();
+
+        if (File.Exists(argParams.WavFilePath))
+        {
+            wavFiles.Add(argParams.WavFilePath);
+        }
+        else if (Directory.Exists(argParams.WavFilePath))
+        {
+            foreach (var wavFilePath in Directory.GetFiles(argParams.WavFilePath, "*.wav"))
+            {
+                wavFiles.Add(wavFilePath);
+            }
+        }
+        else
+        {
+            throw new Exception($"Invalid wav input path. {argParams.WavFilePath}");
+        }
+
+        var start_time = DateTime.Now;
+
+        TimeSpan total_duration = new TimeSpan(0L);
+        for (int i = 0; i < wavFiles.Count; i += batchSize)
+        {
+            List<float[]> samples = new List<float[]>();
+            
+            foreach(var wavFile in wavFiles.Skip(i).Take(batchSize))
+            {
+                (var sample, var duration) = LoadWavFile(wavFile);
+                samples.Add(sample);
+                total_duration += duration;
+            }
+
+            SegmentEntity[] segments_duration = aliFsmnVad.GetSegments(samples);
+            Console.WriteLine("vad infer result:");
+            foreach (SegmentEntity segment in segments_duration)
+            {
+                Console.Write("[");
+                foreach (var x in segment.Segment)
+                {
+                    Console.Write("[" + string.Join(",", x.ToArray()) + "]");
+                }
+                Console.Write("]\r\n");
+            }
+        }
+
+        var end_time = DateTime.Now;
+
+		double elapsed_milliseconds = (end_time - start_time).TotalMilliseconds;
+
 		double rtf = elapsed_milliseconds / total_duration.TotalMilliseconds;
 		Console.WriteLine("elapsed_milliseconds:{0}", elapsed_milliseconds.ToString());
 		Console.WriteLine("total_duration:{0}", total_duration.TotalMilliseconds.ToString());
 		Console.WriteLine("rtf:{1}", "0".ToString(), rtf.ToString());
 		Console.WriteLine("------------------------");
 	}
+
+    private static (float[] sample, TimeSpan duration) LoadWavFile(string wavFilePath)
+    {
+        AudioFileReader _audioFileReader = new AudioFileReader(wavFilePath);
+        byte[] datas = new byte[_audioFileReader.Length];
+        _audioFileReader.Read(datas, 0, datas.Length);
+        var duration = _audioFileReader.TotalTime;
+        float[] wavdata = new float[datas.Length / 4];
+        Buffer.BlockCopy(datas, 0, wavdata, 0, datas.Length);
+        var sample = wavdata.Select((float x) => x * 32768f).ToArray();
+
+        return (sample, duration);
+    }
 }
\ No newline at end of file
diff --git a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/AliFsmnVad.cs b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/AliFsmnVad.cs
index 672eac2..aad8bf5 100644
--- a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/AliFsmnVad.cs
+++ b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/AliFsmnVad.cs
@@ -4,6 +4,8 @@
 using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
 
+// 妯″瀷鏂囦欢涓嬭浇鍦板潃锛歨ttps://modelscope.cn/models/iic/speech_fsmn_vad_zh-cn-16k-common-onnx/
+
 namespace AliFsmnVadSharp
 {
     public class AliFsmnVad : IDisposable
@@ -28,9 +30,9 @@
             VadYamlEntity vadYamlEntity = YamlHelper.ReadYaml<VadYamlEntity>(configFilePath);
             _wavFrontend = new WavFrontend(mvnFilePath, vadYamlEntity.frontend_conf);
             _frontend = vadYamlEntity.frontend;
-            _vad_post_conf = vadYamlEntity.vad_post_conf;
+            _vad_post_conf = vadYamlEntity.model_conf;
             _batchSize = batchSize;
-            _max_end_sil = _max_end_sil != int.MinValue ? _max_end_sil : vadYamlEntity.vad_post_conf.max_end_silence_time;
+            _max_end_sil = _max_end_sil != int.MinValue ? _max_end_sil : vadYamlEntity.model_conf.max_end_silence_time;
             _encoderConfEntity = vadYamlEntity.encoder_conf;
 
             ILoggerFactory loggerFactory = new LoggerFactory();
diff --git a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/Model/VadYamlEntity.cs b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/Model/VadYamlEntity.cs
index 65e77ed..665b415 100644
--- a/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/Model/VadYamlEntity.cs
+++ b/runtime/csharp/AliFsmnVad/AliFsmnVadSharp/Model/VadYamlEntity.cs
@@ -22,6 +22,6 @@
         public string encoder { get => _encoder; set => _encoder = value; }
         public FrontendConfEntity frontend_conf { get => _frontend_conf; set => _frontend_conf = value; }
         public EncoderConfEntity encoder_conf { get => _encoder_conf; set => _encoder_conf = value; }
-        public VadPostConfEntity vad_post_conf { get => _vad_post_conf; set => _vad_post_conf = value; }
+        public VadPostConfEntity model_conf { get => _vad_post_conf; set => _vad_post_conf = value; }
     }
 }
diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/AliParaformerAsr.Examples.csproj b/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/AliParaformerAsr.Examples.csproj
index 7d76968..094753a 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/AliParaformerAsr.Examples.csproj
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/AliParaformerAsr.Examples.csproj
@@ -5,6 +5,7 @@
     <TargetFramework>net6.0</TargetFramework>
     <ImplicitUsings>enable</ImplicitUsings>
     <Nullable>enable</Nullable>
+    <Platforms>AnyCPU;x64</Platforms>
   </PropertyGroup>
 
   <ItemGroup>
@@ -12,6 +13,7 @@
   </ItemGroup>
 
   <ItemGroup>
+    <PackageReference Include="CommandLineParser" Version="2.9.1" />
     <PackageReference Include="NAudio" Version="2.1.0" />
   </ItemGroup>
 
diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/Program.cs b/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/Program.cs
index 26b11f9..03bbb16 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/Program.cs
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr.Examples/Program.cs
@@ -1,66 +1,116 @@
 锘縰sing AliParaformerAsr;
+using CommandLine;
 using NAudio.Wave;
+
 internal static class Program
 {
-	[STAThread]
-	private static void Main()
-	{
-        string applicationBase = AppDomain.CurrentDomain.BaseDirectory;
-        string modelName = "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx";
-        string modelFilePath = applicationBase + "./"+ modelName + "/model_quant.onnx";
-        string configFilePath = applicationBase + "./" + modelName + "/asr.yaml";
-        string mvnFilePath = applicationBase + "./" + modelName + "/am.mvn";
-        string tokensFilePath = applicationBase + "./" + modelName + "/tokens.txt";
-        AliParaformerAsr.OfflineRecognizer offlineRecognizer = new OfflineRecognizer(modelFilePath, configFilePath, mvnFilePath, tokensFilePath);
-        List<float[]>? samples = null;
-        TimeSpan total_duration = new TimeSpan(0L);
-        if (samples == null)
+
+    public class ProgramParams
+    {
+        [Option('i', "input", Required = true, HelpText = "Input wav file/folder path.")]
+        public string WavFilePath { get; set; }
+
+        [Option('m', "model", Default = "speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx", HelpText = "Model path.")]
+        public string Model { get; set; }
+    }
+
+    [STAThread]
+    private static void Main(string[] args)
+    {
+        var argParams = Parser.Default.ParseArguments<ProgramParams>(args).Value;
+
+        string modelPath = argParams.Model;
+        if (!Directory.Exists(argParams.Model))
         {
-            samples = new List<float[]>();
-            for (int i = 0; i < 5; i++)
+            modelPath = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, modelPath);
+            if (!Directory.Exists(modelPath))
             {
-                string wavFilePath = string.Format(applicationBase + "./" + modelName + "/example/{0}.wav", i.ToString());
-                if (!File.Exists(wavFilePath))
-                {
-                    break;
-                }
-                AudioFileReader _audioFileReader = new AudioFileReader(wavFilePath);
-                byte[] datas = new byte[_audioFileReader.Length];
-                _audioFileReader.Read(datas, 0, datas.Length);
-                TimeSpan duration = _audioFileReader.TotalTime;
-                float[] wavdata = new float[datas.Length / 4];
-                Buffer.BlockCopy(datas, 0, wavdata, 0, datas.Length);
-                float[] sample = wavdata.Select((float x) => x * 32768f).ToArray();
-                samples.Add(sample);
-                total_duration += duration;
+                throw new DirectoryNotFoundException($"Model not found: {argParams.Model}");
             }
         }
-        TimeSpan start_time = new TimeSpan(DateTime.Now.Ticks);
-        //1.Non batch method
-        foreach (var sample in samples)
+
+        string modelFilePath = Path.Combine(modelPath, "model_quant.onnx");
+        string configFilePath = Path.Combine(modelPath, "asr.yaml");
+        string mvnFilePath = Path.Combine(modelPath, "am.mvn");
+        string tokensFilePath = Path.Combine(modelPath, "tokens.json");
+
+        
+        var offlineRecognizer = new OfflineRecognizer(modelFilePath, configFilePath, mvnFilePath, tokensFilePath, OnnxRumtimeTypes.CPU);
+
+        List<float[]> samples = new List<float[]>();
+        TimeSpan total_duration = new TimeSpan(0L);
+
+        if (File.Exists(argParams.WavFilePath))
         {
-            List<float[]> temp_samples = new List<float[]>();
-            temp_samples.Add(sample);
+            (var sample, var duration) = LoadWavFile(argParams.WavFilePath);
+
+            samples.Add(sample);
+            total_duration += duration;
+        }
+        else if (Directory.Exists(argParams.WavFilePath)) 
+        {
+            var findWavCount = 0;
+
+            foreach (var wavFilePath in Directory.EnumerateFiles(argParams.WavFilePath, "*.wav"))
+            {
+                (var sample, var duration) = LoadWavFile(wavFilePath);
+
+                samples.Add(sample);
+                total_duration += duration;
+                findWavCount++;
+            }
+
+            Console.WriteLine($"Total WAV files found: {findWavCount} duration锛歿total_duration}");
+        }
+        else
+        {
+            throw new Exception($"Invalid wav input path. {argParams.WavFilePath}");
+        }
+
+        var start_time = DateTime.Now;
+
+        int batchSize = 1; // 杈撳叆鍙傛暟鏀寔鎵瑰鐞嗭紝浣嗘槸瀹為檯鏁堟灉鎻愬崌鏈夐檺锛屾劅瑙夎繕鏄礋浼樺寲锛岀瓑GPU鐗堟湰浼樺寲鍚庡啀璇�
+        for (int i = 0; i < samples.Count; i += batchSize)
+        {
+            List<float[]> temp_samples = samples.Skip(i).Take(batchSize).ToList();
+
             List<string> results = offlineRecognizer.GetResults(temp_samples);
+
             foreach (string result in results)
             {
                 Console.WriteLine(result);
                 Console.WriteLine("");
             }
         }
-        //2.batch method
-        //List<string> results_batch = offlineRecognizer.GetResults(samples);
-        //foreach (string result in results_batch)
-        //{
-        //    Console.WriteLine(result);
-        //    Console.WriteLine("");
-        //}
-        TimeSpan end_time = new TimeSpan(DateTime.Now.Ticks);
-        double elapsed_milliseconds = end_time.TotalMilliseconds - start_time.TotalMilliseconds;
+
+
+        var end_time = DateTime.Now;
+
+        double elapsed_milliseconds = (end_time - start_time).TotalMilliseconds;
         double rtf = elapsed_milliseconds / total_duration.TotalMilliseconds;
+
         Console.WriteLine("elapsed_milliseconds:{0}", elapsed_milliseconds.ToString());
         Console.WriteLine("total_duration:{0}", total_duration.TotalMilliseconds.ToString());
-        Console.WriteLine("rtf:{1}", "0".ToString(), rtf.ToString());
+
+        // 瀹炴椂鍥犲瓙鏄鐞嗘椂闂翠笌闊抽鏃堕暱鐨勬瘮鍊笺��
+        // 渚嬪锛屽鏋滀竴涓� 10 绉掔殑闊抽鐗囨闇�瑕� 5 绉掓潵澶勭悊锛岄偅涔堝疄鏃跺洜瀛愬氨鏄� 0.5銆�
+        // 濡傛灉澶勭悊鏃堕棿鍜岄煶棰戞椂闀跨浉绛夛紝閭d箞瀹炴椂鍥犲瓙灏辨槸 1锛岃繖鎰忓懗鐫�绯荤粺浠ュ疄鏃堕�熷害杩涜澶勭悊銆� 
+        // 鏁板�艰秺灏忥紝琛ㄧず澶勭悊閫熷害瓒婂揩銆�
+        // from chatgpt 瑙i噴
+        Console.WriteLine("Real-Time Factor :{0}", rtf.ToString());
         Console.WriteLine("end!");
     }
-}
\ No newline at end of file
+
+    private static (float[] sample, TimeSpan duration) LoadWavFile(string wavFilePath)
+    {
+        AudioFileReader _audioFileReader = new AudioFileReader(wavFilePath);
+        byte[] datas = new byte[_audioFileReader.Length];
+        _audioFileReader.Read(datas, 0, datas.Length);
+        var duration = _audioFileReader.TotalTime;
+        float[] wavdata = new float[datas.Length / 4];
+        Buffer.BlockCopy(datas, 0, wavdata, 0, datas.Length);
+        var sample = wavdata.Select((float x) => x * 32768f).ToArray();
+
+        return (sample, duration);
+    }
+}
diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/AliParaformerAsr.csproj b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/AliParaformerAsr.csproj
index f451b13..b3fe558 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/AliParaformerAsr.csproj
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/AliParaformerAsr.csproj
@@ -9,7 +9,9 @@
   <ItemGroup>
     <PackageReference Include="KaldiNativeFbankSharp" Version="1.0.8" />
     <PackageReference Include="Microsoft.Extensions.Logging" Version="7.0.0" />
-    <PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.15.1" />
+    <PackageReference Include="Microsoft.ML.OnnxRuntime" Version="1.17.3" />
+    <PackageReference Include="Microsoft.ML.OnnxRuntime.Managed" Version="1.17.3" />
+    <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
     <PackageReference Include="YamlDotNet" Version="13.1.1" />
   </ItemGroup>
 
diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
index f1f491a..c2d7f68 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
@@ -9,9 +9,20 @@
 using Microsoft.ML.OnnxRuntime.Tensors;
 using Microsoft.Extensions.Logging;
 using System.Text.RegularExpressions;
+using Newtonsoft.Json.Linq;
 
+// 妯″瀷鏂囦欢鍦板潃锛� https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
 namespace AliParaformerAsr
 {
+    public enum OnnxRumtimeTypes
+    {
+        CPU = 0,
+
+        DML = 1,
+
+        CUDA = 2,
+    }
+
     /// <summary>
     /// offline recognizer package
     /// Copyright (c)  2023 by manyeyes
@@ -24,35 +35,70 @@
         private string _frontend;
         private FrontendConfEntity _frontendConfEntity;
         private string[] _tokens;
-        private int _batchSize = 1;
 
-        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,string tokensFilePath, int batchSize = 1,int threadsNum=1)
+        /// <summary>
+        /// 
+        /// </summary>
+        /// <param name="modelFilePath"></param>
+        /// <param name="configFilePath"></param>
+        /// <param name="mvnFilePath"></param>
+        /// <param name="tokensFilePath"></param>
+        /// <param name="rumtimeType">鍙互閫夋嫨gpu锛屼絾鏄洰鍓嶆儏鍐典笅锛屼笉寤鸿浣跨敤锛屽洜涓烘�ц兘鎻愬崌鏈夐檺</param>
+        /// <param name="deviceId">璁惧id锛屽鏄惧崱鏃剁敤浜庢寚瀹氭墽琛岀殑鏄惧崱</param>
+        /// <param name="batchSize"></param>
+        /// <param name="threadsNum"></param>
+        /// <exception cref="ArgumentException"></exception>
+        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
         {
-            Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions();
+            var options = new SessionOptions();
+            switch(rumtimeType)
+            {
+                case OnnxRumtimeTypes.DML:
+                    options.AppendExecutionProvider_DML(deviceId);
+                    break;
+                case OnnxRumtimeTypes.CUDA:
+                    options.AppendExecutionProvider_CUDA(deviceId);
+                    break;
+                default:
+                    options.AppendExecutionProvider_CPU(deviceId);
+                    break;
+            }
             //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
-            //options.AppendExecutionProvider_DML(0);
-            options.AppendExecutionProvider_CPU(0);
-            //options.AppendExecutionProvider_CUDA(0);
-            options.InterOpNumThreads = threadsNum;
-            _onnxSession = new InferenceSession(modelFilePath, options);
 
-            _tokens = File.ReadAllLines(tokensFilePath);
+            _onnxSession = new InferenceSession(modelFilePath, options);
+            
+            string[] tokenLines;
+            if (tokensFilePath.EndsWith(".txt"))
+            {
+                tokenLines = File.ReadAllLines(tokensFilePath);
+            }
+            else if (tokensFilePath.EndsWith(".json"))
+            {
+                string jsonContent = File.ReadAllText(tokensFilePath);
+                JArray tokenArray = JArray.Parse(jsonContent);
+                tokenLines = tokenArray.Select(t => t.ToString()).ToArray();
+            }
+            else
+            {
+                throw new ArgumentException("Invalid tokens file format. Only .txt and .json are supported.");
+            }
+
+            _tokens = tokenLines;
 
             OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
             _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
             _frontend = offlineYamlEntity.frontend;
             _frontendConfEntity = offlineYamlEntity.frontend_conf;
-            _batchSize = batchSize;
             ILoggerFactory loggerFactory = new LoggerFactory();
             _logger = new Logger<OfflineRecognizer>(loggerFactory);
         }
 
         public List<string> GetResults(List<float[]> samples)
         {
-            this._logger.LogInformation("get features begin");
+            _logger.LogInformation("get features begin");
             List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
-            OfflineOutputEntity modelOutput = this.Forward(offlineInputEntities);
-            List<string> text_results = this.DecodeMulti(modelOutput.Token_nums);
+            OfflineOutputEntity modelOutput = Forward(offlineInputEntities);
+            List<string> text_results = DecodeMulti(modelOutput.Token_nums);
             return text_results;
         }
 
@@ -156,15 +202,18 @@
                     {
                         break;
                     }
-                    if (_tokens[token] != "</s>" && _tokens[token] != "<s>" && _tokens[token] != "<blank>")
+
+                    string tokenChar = _tokens[token];
+
+                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>")
                     {                        
-                        if (IsChinese(_tokens[token],true))
+                        if (IsChinese(tokenChar, true))
                         {
-                            text_result += _tokens[token];
+                            text_result += tokenChar;
                         }
                         else
                         {
-                            text_result += "鈻�" + _tokens[token]+ "鈻�";
+                            text_result += "鈻�" + tokenChar + "鈻�";
                         }
                     }
                 }
diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/Utils/YamlHelper.cs b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/Utils/YamlHelper.cs
index 15c13e3..225d4c8 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/Utils/YamlHelper.cs
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/Utils/YamlHelper.cs
@@ -14,16 +14,19 @@
     /// YamlHelper
     /// Copyright (c)  2023 by manyeyes
     /// </summary>
-    internal class YamlHelper
+    internal class YamlHelper 
     {
-        public static T ReadYaml<T>(string yamlFilePath)
+        public static T ReadYaml<T>(string yamlFilePath) where T:new()
         {
             if (!File.Exists(yamlFilePath))
             {
-#pragma warning disable CS8603 // 鍙兘杩斿洖 null 寮曠敤銆�
-                return default(T);
-#pragma warning restore CS8603 // 鍙兘杩斿洖 null 寮曠敤銆�
+                // 濡傛灉鍏佽杩斿洖榛樿瀵硅薄锛屽垯鏂板缓涓�涓粯璁ゅ璞★紝鍚﹀垯搴旇鏄姏鍑哄紓甯�
+                // If allowing to return a default object, create a new default object; otherwise, throw an exception
+
+                return new T();
+                // throw new Exception($"not find yaml config file: {yamlFilePath}");
             }
+
             StreamReader yamlReader = File.OpenText(yamlFilePath);
             Deserializer yamlDeserializer = new Deserializer();
             T info = yamlDeserializer.Deserialize<T>(yamlReader);

--
Gitblit v1.9.1