From a5a06217c17dc69aca53d1e247b7a6671d35373c Mon Sep 17 00:00:00 2001
From: manyeyes <32889020+manyeyes@users.noreply.github.com>
Date: 星期二, 23 七月 2024 14:41:00 +0800
Subject: [PATCH] Add support for SenseVoiceSmall onnx model in c# lib (#1946)

---
 runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs |  187 ++++++++++++----------------------------------
 1 files changed, 50 insertions(+), 137 deletions(-)

diff --git a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
index c2d7f68..3011d13 100644
--- a/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
+++ b/runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
@@ -1,28 +1,17 @@
 锘�// See https://github.com/manyeyes for more information
-// Copyright (c)  2023 by manyeyes
-using System.Linq;
-using System.Text;
-using System.Threading.Tasks;
+// Copyright (c)  2024 by manyeyes
 using AliParaformerAsr.Model;
 using AliParaformerAsr.Utils;
+using Microsoft.Extensions.Logging;
 using Microsoft.ML.OnnxRuntime;
 using Microsoft.ML.OnnxRuntime.Tensors;
-using Microsoft.Extensions.Logging;
-using System.Text.RegularExpressions;
 using Newtonsoft.Json.Linq;
+using System.Text.RegularExpressions;
 
 // 妯″瀷鏂囦欢鍦板潃锛� https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
+// 妯″瀷鏂囦欢鍦板潃锛� https://www.modelscope.cn/models/manyeyes/sensevoice-small-onnx
 namespace AliParaformerAsr
 {
-    public enum OnnxRumtimeTypes
-    {
-        CPU = 0,
-
-        DML = 1,
-
-        CUDA = 2,
-    }
-
     /// <summary>
     /// offline recognizer package
     /// Copyright (c)  2023 by manyeyes
@@ -35,6 +24,8 @@
         private string _frontend;
         private FrontendConfEntity _frontendConfEntity;
         private string[] _tokens;
+        private IOfflineProj? _offlineProj;
+        private OfflineModel _offlineModel;
 
         /// <summary>
         /// 
@@ -48,24 +39,9 @@
         /// <param name="batchSize"></param>
         /// <param name="threadsNum"></param>
         /// <exception cref="ArgumentException"></exception>
-        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
+        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, int threadsNum = 1, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
         {
-            var options = new SessionOptions();
-            switch(rumtimeType)
-            {
-                case OnnxRumtimeTypes.DML:
-                    options.AppendExecutionProvider_DML(deviceId);
-                    break;
-                case OnnxRumtimeTypes.CUDA:
-                    options.AppendExecutionProvider_CUDA(deviceId);
-                    break;
-                default:
-                    options.AppendExecutionProvider_CPU(deviceId);
-                    break;
-            }
-            //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
-
-            _onnxSession = new InferenceSession(modelFilePath, options);
+            _offlineModel = new OfflineModel(modelFilePath, threadsNum);
             
             string[] tokenLines;
             if (tokensFilePath.EndsWith(".txt"))
@@ -86,6 +62,18 @@
             _tokens = tokenLines;
 
             OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
+            switch (offlineYamlEntity.model.ToLower())
+            {
+                case "paraformer":
+                    _offlineProj = new OfflineProjOfParaformer(_offlineModel);
+                    break;
+                case "sensevoicesmall":
+                    _offlineProj = new OfflineProjOfSenseVoiceSmall(_offlineModel);
+                    break;
+                default:
+                    _offlineProj = null;
+                    break;
+            }
             _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
             _frontend = offlineYamlEntity.frontend;
             _frontendConfEntity = offlineYamlEntity.frontend_conf;
@@ -120,73 +108,40 @@
 
         private OfflineOutputEntity Forward(List<OfflineInputEntity> modelInputs)
         {
-            int BatchSize = modelInputs.Count;
-            float[] padSequence = PadSequence(modelInputs);
-            var inputMeta = _onnxSession.InputMetadata;
-            var container = new List<NamedOnnxValue>(); 
-            foreach (var name in inputMeta.Keys)
-            {
-                if (name == "speech")
-                {
-                    int[] dim = new int[] { BatchSize, padSequence.Length / 560 / BatchSize, 560 };//inputMeta["speech"].Dimensions[2]
-                    var tensor = new DenseTensor<float>(padSequence, dim, false);
-                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
-                }
-                if (name == "speech_lengths")
-                {
-                    int[] dim = new int[] { BatchSize };
-                    int[] speech_lengths = new int[BatchSize];
-                    for (int i = 0; i < BatchSize; i++)
-                    {
-                        speech_lengths[i] = padSequence.Length / 560 / BatchSize;
-                    }
-                    var tensor = new DenseTensor<int>(speech_lengths, dim, false);
-                    container.Add(NamedOnnxValue.CreateFromTensor<int>(name, tensor));
-                }
-            }
-
-
-            IReadOnlyCollection<string> outputNames = new List<string>();
-            outputNames.Append("logits");
-            outputNames.Append("token_num");
-            IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = null;
+            OfflineOutputEntity offlineOutputEntity = new OfflineOutputEntity();            
             try
             {
-                results = _onnxSession.Run(container);
+                ModelOutputEntity modelOutputEntity = _offlineProj.ModelProj(modelInputs);
+                if (modelOutputEntity != null)
+                {
+                    offlineOutputEntity.Token_nums_length = modelOutputEntity.model_out_lens.AsEnumerable<int>().ToArray();
+                    Tensor<float> logits_tensor = modelOutputEntity.model_out;
+                    List<int[]> token_nums = new List<int[]> { };
+
+                    for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
+                    {
+                        int[] item = new int[logits_tensor.Dimensions[1]];
+                        for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
+                        {
+                            int token_num = 0;
+                            for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
+                            {
+                                token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
+                            }
+                            item[j] = (int)token_num;
+                        }
+                        token_nums.Add(item);
+                    }
+                    offlineOutputEntity.Token_nums = token_nums;
+                }
             }
             catch (Exception ex)
             {
                 //
             }
-            OfflineOutputEntity modelOutput = new OfflineOutputEntity();
-            if (results != null)
-            {
-                var resultsArray = results.ToArray();
-                modelOutput.Logits = resultsArray[0].AsEnumerable<float>().ToArray();
-                modelOutput.Token_nums_length = resultsArray[1].AsEnumerable<int>().ToArray();
-
-                Tensor<float> logits_tensor = resultsArray[0].AsTensor<float>();
-                Tensor<Int64> token_nums_tensor = resultsArray[1].AsTensor<Int64>();
-
-                List<int[]> token_nums = new List<int[]> { };
-
-                for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
-                {
-                    int[] item = new int[logits_tensor.Dimensions[1]];
-                    for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
-                    {                        
-                        int token_num = 0;
-                        for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
-                        {
-                            token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
-                        }
-                        item[j] = (int)token_num;                        
-                    }
-                    token_nums.Add(item);
-                }                
-                modelOutput.Token_nums = token_nums;
-            }
-            return modelOutput;
+            
+            
+            return offlineOutputEntity;
         }
 
         private List<string> DecodeMulti(List<int[]> token_nums)
@@ -203,9 +158,9 @@
                         break;
                     }
 
-                    string tokenChar = _tokens[token];
+                    string tokenChar = _tokens[token].Split("\t")[0];
 
-                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>")
+                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>" && tokenChar != "<unk>")
                     {                        
                         if (IsChinese(tokenChar, true))
                         {
@@ -244,48 +199,6 @@
             else
                 return false;
         }
-
-        private float[] PadSequence(List<OfflineInputEntity> modelInputs)
-        {
-            int max_speech_length = modelInputs.Max(x => x.SpeechLength);
-            int speech_length = max_speech_length * modelInputs.Count;
-            float[] speech = new float[speech_length];
-            float[,] xxx = new float[modelInputs.Count, max_speech_length];
-            for (int i = 0; i < modelInputs.Count; i++)
-            {
-                if (max_speech_length == modelInputs[i].SpeechLength)
-                {
-                    for (int j = 0; j < xxx.GetLength(1); j++)
-                    {
-#pragma warning disable CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
-                        xxx[i, j] = modelInputs[i].Speech[j];
-#pragma warning restore CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
-                    }
-                    continue;
-                }
-                float[] nullspeech = new float[max_speech_length - modelInputs[i].SpeechLength];
-                float[]? curr_speech = modelInputs[i].Speech;
-                float[] padspeech = new float[max_speech_length];
-                padspeech = _wavFrontend.ApplyCmvn(padspeech);
-                Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length);
-                for (int j = 0; j < padspeech.Length; j++)
-                {
-#pragma warning disable CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
-                    xxx[i, j] = padspeech[j];
-#pragma warning restore CS8602 // 瑙e紩鐢ㄥ彲鑳藉嚭鐜扮┖寮曠敤銆�
-                }
-
-            }
-            int s = 0;
-            for (int i = 0; i < xxx.GetLength(0); i++)
-            {
-                for (int j = 0; j < xxx.GetLength(1); j++)
-                {
-                    speech[s] = xxx[i, j];
-                    s++;
-                }
-            }
-            return speech;
-        }
+        
     }
 }
\ No newline at end of file

--
Gitblit v1.9.1