zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
@@ -1,28 +1,17 @@
// See https://github.com/manyeyes for more information
// Copyright (c)  2023 by manyeyes
using System.Linq;
using System.Text;
using System.Threading.Tasks;
// Copyright (c)  2024 by manyeyes
using AliParaformerAsr.Model;
using AliParaformerAsr.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.Extensions.Logging;
using System.Text.RegularExpressions;
using Newtonsoft.Json.Linq;
using System.Text.RegularExpressions;
// Model files: https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
// Model files: https://www.modelscope.cn/models/manyeyes/sensevoice-small-onnx
namespace AliParaformerAsr
{
    /// <summary>
    /// Selects which ONNX Runtime execution provider the recognizer appends to its
    /// <c>SessionOptions</c>. The numeric values are explicit and part of the public contract.
    /// </summary>
    public enum OnnxRumtimeTypes
    {
        /// <summary>CPU execution provider (<c>AppendExecutionProvider_CPU</c>); the constructor's default.</summary>
        CPU = 0,
        /// <summary>DML execution provider (<c>AppendExecutionProvider_DML</c>).</summary>
        DML = 1,
        /// <summary>CUDA execution provider (<c>AppendExecutionProvider_CUDA</c>).</summary>
        CUDA = 2,
    }
    /// <summary>
    /// offline recognizer package
    /// Copyright (c)  2023 by manyeyes
@@ -35,6 +24,8 @@
        private string _frontend;
        private FrontendConfEntity _frontendConfEntity;
        private string[] _tokens;
        private IOfflineProj? _offlineProj;
        private OfflineModel _offlineModel;
        /// <summary>
        /// 
@@ -48,24 +39,9 @@
        /// <param name="batchSize"></param>
        /// <param name="threadsNum"></param>
        /// <exception cref="ArgumentException"></exception>
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, int threadsNum = 1, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
        {
            var options = new SessionOptions();
            switch(rumtimeType)
            {
                case OnnxRumtimeTypes.DML:
                    options.AppendExecutionProvider_DML(deviceId);
                    break;
                case OnnxRumtimeTypes.CUDA:
                    options.AppendExecutionProvider_CUDA(deviceId);
                    break;
                default:
                    options.AppendExecutionProvider_CPU(deviceId);
                    break;
            }
            //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
            _onnxSession = new InferenceSession(modelFilePath, options);
            _offlineModel = new OfflineModel(modelFilePath, threadsNum);
            
            string[] tokenLines;
            if (tokensFilePath.EndsWith(".txt"))
@@ -86,6 +62,18 @@
            _tokens = tokenLines;
            OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
            switch (offlineYamlEntity.model.ToLower())
            {
                case "paraformer":
                    _offlineProj = new OfflineProjOfParaformer(_offlineModel);
                    break;
                case "sensevoicesmall":
                    _offlineProj = new OfflineProjOfSenseVoiceSmall(_offlineModel);
                    break;
                default:
                    _offlineProj = null;
                    break;
            }
            _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
            _frontend = offlineYamlEntity.frontend;
            _frontendConfEntity = offlineYamlEntity.frontend_conf;
@@ -120,73 +108,40 @@
        /// <summary>
        /// Runs the model on a batch of inputs and converts the logits into per-frame token ids
        /// by taking the arg-max over the last tensor dimension.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this body contains two inference paths side by side. One runs
        /// <c>_onnxSession</c> directly and fills <c>modelOutput</c>; the other calls
        /// <c>_offlineProj.ModelProj</c> and fills <c>offlineOutputEntity</c>. Only
        /// <c>modelOutput</c> is returned. Confirm which path is intended before changing logic.
        /// </remarks>
        /// <param name="modelInputs">Batch of inputs; its count is used as the batch size.</param>
        /// <returns>
        /// The entity built from the direct session results. It is left with default members
        /// when <c>results</c> is still null after the try block (i.e. the session run threw).
        /// </returns>
        private OfflineOutputEntity Forward(List<OfflineInputEntity> modelInputs)
        {
            int BatchSize = modelInputs.Count;
            // Flattened, padded batch features (BatchSize rows of equal length).
            float[] padSequence = PadSequence(modelInputs);
            var inputMeta = _onnxSession.InputMetadata;
            var container = new List<NamedOnnxValue>();
            // Build one tensor per model input the session declares; unknown names are ignored.
            foreach (var name in inputMeta.Keys)
            {
                if (name == "speech")
                {
                    // 560 is a hard-coded size of the last dimension; the trailing comment
                    // suggests it mirrors inputMeta["speech"].Dimensions[2] - TODO confirm.
                    int[] dim = new int[] { BatchSize, padSequence.Length / 560 / BatchSize, 560 };//inputMeta["speech"].Dimensions[2]
                    var tensor = new DenseTensor<float>(padSequence, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
                }
                if (name == "speech_lengths")
                {
                    int[] dim = new int[] { BatchSize };
                    int[] speech_lengths = new int[BatchSize];
                    // Every batch item reports the same (padded) length, not its own length.
                    for (int i = 0; i < BatchSize; i++)
                    {
                        speech_lengths[i] = padSequence.Length / 560 / BatchSize;
                    }
                    var tensor = new DenseTensor<int>(speech_lengths, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<int>(name, tensor));
                }
            }
            // NOTE(review): Enumerable.Append returns a new sequence; both results are discarded,
            // so outputNames stays empty. It is not used again in this method.
            IReadOnlyCollection<string> outputNames = new List<string>();
            outputNames.Append("logits");
            outputNames.Append("token_num");
            // NOTE(review): results is a disposable collection and is never disposed here.
            IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = null;
            // NOTE(review): populated below but never returned (see the return statements).
            OfflineOutputEntity offlineOutputEntity = new OfflineOutputEntity();
            try
            {
                results = _onnxSession.Run(container);
                // _offlineProj is declared nullable; a null dereference here would be
                // swallowed by the catch below.
                ModelOutputEntity modelOutputEntity = _offlineProj.ModelProj(modelInputs);
                if (modelOutputEntity != null)
                {
                    offlineOutputEntity.Token_nums_length = modelOutputEntity.model_out_lens.AsEnumerable<int>().ToArray();
                    Tensor<float> logits_tensor = modelOutputEntity.model_out;
                    List<int[]> token_nums = new List<int[]> { };
                    // Arg-max over dimension 2 for every [i, j]; on ties the later index wins.
                    for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
                    {
                        int[] item = new int[logits_tensor.Dimensions[1]];
                        for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
                        {
                            int token_num = 0;
                            for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
                            {
                                token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
                            }
                            item[j] = (int)token_num;
                        }
                        token_nums.Add(item);
                    }
                    offlineOutputEntity.Token_nums = token_nums;
                }
            }
            catch (Exception ex)
            {
                // NOTE(review): every exception is silently swallowed; ex is unused and nothing is logged.
            }
            OfflineOutputEntity modelOutput = new OfflineOutputEntity();
            if (results != null)
            {
                // Outputs are read by position: index 0 as float logits, index 1 as token counts.
                var resultsArray = results.ToArray();
                modelOutput.Logits = resultsArray[0].AsEnumerable<float>().ToArray();
                // NOTE(review): output 1 is read as int here but as Int64 two lines below - confirm the element type.
                modelOutput.Token_nums_length = resultsArray[1].AsEnumerable<int>().ToArray();
                Tensor<float> logits_tensor = resultsArray[0].AsTensor<float>();
                // NOTE(review): token_nums_tensor is not used afterwards.
                Tensor<Int64> token_nums_tensor = resultsArray[1].AsTensor<Int64>();
                List<int[]> token_nums = new List<int[]> { };
                // Same arg-max over dimension 2 as above, applied to the direct session output.
                for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
                {
                    int[] item = new int[logits_tensor.Dimensions[1]];
                    for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
                    {
                        int token_num = 0;
                        for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
                        {
                            token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
                        }
                        item[j] = (int)token_num;
                    }
                    token_nums.Add(item);
                }
                modelOutput.Token_nums = token_nums;
            }
            return modelOutput;
            // NOTE(review): unreachable - the method has already returned modelOutput above.
            return offlineOutputEntity;
        }
        private List<string> DecodeMulti(List<int[]> token_nums)
@@ -203,9 +158,9 @@
                        break;
                    }
                    string tokenChar = _tokens[token];
                    string tokenChar = _tokens[token].Split("\t")[0];
                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>")
                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>" && tokenChar != "<unk>")
                    {                        
                        if (IsChinese(tokenChar, true))
                        {
@@ -244,48 +199,6 @@
            else
                return false;
        }
        private float[] PadSequence(List<OfflineInputEntity> modelInputs)
        {
            int max_speech_length = modelInputs.Max(x => x.SpeechLength);
            int speech_length = max_speech_length * modelInputs.Count;
            float[] speech = new float[speech_length];
            float[,] xxx = new float[modelInputs.Count, max_speech_length];
            for (int i = 0; i < modelInputs.Count; i++)
            {
                if (max_speech_length == modelInputs[i].SpeechLength)
                {
                    for (int j = 0; j < xxx.GetLength(1); j++)
                    {
#pragma warning disable CS8602 // 解引用可能出现空引用。
                        xxx[i, j] = modelInputs[i].Speech[j];
#pragma warning restore CS8602 // 解引用可能出现空引用。
                    }
                    continue;
                }
                float[] nullspeech = new float[max_speech_length - modelInputs[i].SpeechLength];
                float[]? curr_speech = modelInputs[i].Speech;
                float[] padspeech = new float[max_speech_length];
                padspeech = _wavFrontend.ApplyCmvn(padspeech);
                Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length);
                for (int j = 0; j < padspeech.Length; j++)
                {
#pragma warning disable CS8602 // 解引用可能出现空引用。
                    xxx[i, j] = padspeech[j];
#pragma warning restore CS8602 // 解引用可能出现空引用。
                }
            }
            int s = 0;
            for (int i = 0; i < xxx.GetLength(0); i++)
            {
                for (int j = 0; j < xxx.GetLength(1); j++)
                {
                    speech[s] = xxx[i, j];
                    s++;
                }
            }
            return speech;
        }
    }
}