majic31
2024-12-24 23e7ddebccd3b05cf7ef89809bcfe565ad6dfa1f
runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
@@ -1,15 +1,15 @@
// See https://github.com/manyeyes for more information
// Copyright (c)  2023 by manyeyes
using System.Linq;
using System.Text;
using System.Threading.Tasks;
// Copyright (c)  2024 by manyeyes
using AliParaformerAsr.Model;
using AliParaformerAsr.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.Extensions.Logging;
using Newtonsoft.Json.Linq;
using System.Text.RegularExpressions;
// 模型文件地址: https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
// 模型文件地址: https://www.modelscope.cn/models/manyeyes/sensevoice-small-onnx
namespace AliParaformerAsr
{
    /// <summary>
@@ -24,35 +24,69 @@
        private string _frontend;
        private FrontendConfEntity _frontendConfEntity;
        private string[] _tokens;
        private int _batchSize = 1;
        private IOfflineProj? _offlineProj;
        private OfflineModel _offlineModel;
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,string tokensFilePath, int batchSize = 1,int threadsNum=1)
        // NOTE(review): the signature line above has no body and is followed by a second signature with a
        // different parameter list (no batchSize; adds rumtimeType and deviceId). Only one of the two can be
        // intended - confirm which and remove the other.
        /// <summary>
        /// Creates an offline recognizer: loads the ONNX model, the token table, the YAML configuration
        /// and the wav front end, and selects the model-specific projection.
        /// </summary>
        /// <param name="modelFilePath">Path of the ONNX model file; passed to both <c>InferenceSession</c> and <c>OfflineModel</c>.</param>
        /// <param name="configFilePath">Path of the YAML configuration that is read into an <c>OfflineYamlEntity</c>.</param>
        /// <param name="mvnFilePath">Path handed to <c>WavFrontend</c> together with the front-end configuration.</param>
        /// <param name="tokensFilePath">Token table; must end with <c>.txt</c> (one entry per line) or <c>.json</c> (a JSON array).</param>
        /// <param name="rumtimeType">GPU can be selected, but it is currently not recommended because the performance gain is limited. NOTE(review): not referenced in the visible constructor body.</param>
        /// <param name="deviceId">Device id; selects the graphics card to run on when several are present. NOTE(review): not referenced in the visible constructor body.</param>
        /// <param name="batchSize">Stored in <c>_batchSize</c>. NOTE(review): only the first signature line declares this parameter.</param>
        /// <param name="threadsNum">Used as <c>InterOpNumThreads</c> of the session options and passed to <c>OfflineModel</c>.</param>
        /// <exception cref="ArgumentException">Thrown when <paramref name="tokensFilePath"/> ends with neither <c>.txt</c> nor <c>.json</c>.</exception>
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, int threadsNum = 1, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
        {
            // Build a CPU-only inference session directly from the model file.
            Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions();
            //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
            //options.AppendExecutionProvider_DML(0);
            options.AppendExecutionProvider_CPU(0);
            //options.AppendExecutionProvider_CUDA(0);
            options.InterOpNumThreads = threadsNum;
            _onnxSession = new InferenceSession(modelFilePath, options);
            // NOTE(review): the same model file is also handed to OfflineModel here, in addition to the
            // session created above - presumably only one of the two is meant to own the model; confirm.
            _offlineModel = new OfflineModel(modelFilePath, threadsNum);
            // Load the token table according to the file extension.
            string[] tokenLines;
            if (tokensFilePath.EndsWith(".txt"))
            {
                // Plain text: one token entry per line.
                tokenLines = File.ReadAllLines(tokensFilePath);
            }
            else if (tokensFilePath.EndsWith(".json"))
            {
                // JSON: the file content is parsed as an array and each element is converted to a string.
                string jsonContent = File.ReadAllText(tokensFilePath);
                JArray tokenArray = JArray.Parse(jsonContent);
                tokenLines = tokenArray.Select(t => t.ToString()).ToArray();
            }
            else
            {
                throw new ArgumentException("Invalid tokens file format. Only .txt and .json are supported.");
            }
            // NOTE(review): this assignment re-reads the file line by line and is overwritten by the very
            // next statement, so its result is never used.
            _tokens = File.ReadAllLines(tokensFilePath);
            _tokens = tokenLines;
            OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
            // Pick the projection implementation from the lower-cased model name in the configuration.
            switch (offlineYamlEntity.model.ToLower())
            {
                case "paraformer":
                    _offlineProj = new OfflineProjOfParaformer(_offlineModel);
                    break;
                case "sensevoicesmall":
                    _offlineProj = new OfflineProjOfSenseVoiceSmall(_offlineModel);
                    break;
                default:
                    // Unknown model names leave the projection unset; no error is raised here.
                    _offlineProj = null;
                    break;
            }
            _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
            _frontend = offlineYamlEntity.frontend;
            _frontendConfEntity = offlineYamlEntity.frontend_conf;
            // NOTE(review): batchSize is not a parameter of the signature directly above this body.
            _batchSize = batchSize;
            // The logger is created from a fresh LoggerFactory; no logging providers are added here.
            ILoggerFactory loggerFactory = new LoggerFactory();
            _logger = new Logger<OfflineRecognizer>(loggerFactory);
        }
        public List<string> GetResults(List<float[]> samples)
        {
            this._logger.LogInformation("get features begin");
            _logger.LogInformation("get features begin");
            List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
            OfflineOutputEntity modelOutput = this.Forward(offlineInputEntities);
            List<string> text_results = this.DecodeMulti(modelOutput.Token_nums);
            OfflineOutputEntity modelOutput = Forward(offlineInputEntities);
            List<string> text_results = DecodeMulti(modelOutput.Token_nums);
            return text_results;
        }
@@ -74,73 +108,40 @@
        /// <summary>
        /// Runs the model on a batch of inputs and converts the logits into per-frame token ids by taking
        /// the index of the largest value along the last tensor dimension.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this body contains two inference paths (a direct <c>_onnxSession.Run</c> call and
        /// <c>_offlineProj.ModelProj</c>), two result entities and two return statements. Only the first
        /// return is reachable, so the entity filled from <c>ModelProj</c> is never returned. Confirm which
        /// path is intended.
        /// </remarks>
        /// <param name="modelInputs">Batch entries; their count is used as the batch size.</param>
        /// <returns>The entity populated from the raw session results; its members stay unset when the run failed.</returns>
        private OfflineOutputEntity Forward(List<OfflineInputEntity> modelInputs)
        {
            int BatchSize = modelInputs.Count;
            // Flat buffer holding all batch entries padded to a common length.
            float[] padSequence = PadSequence(modelInputs);
            var inputMeta = _onnxSession.InputMetadata;
            var container = new List<NamedOnnxValue>();
            // Create one input tensor for each model input name that is recognised below.
            foreach (var name in inputMeta.Keys)
            {
                if (name == "speech")
                {
                    // Shape is [batch, frames, 560]; the feature width 560 is hard-coded here.
                    int[] dim = new int[] { BatchSize, padSequence.Length / 560 / BatchSize, 560 };//inputMeta["speech"].Dimensions[2]
                    var tensor = new DenseTensor<float>(padSequence, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
                }
                if (name == "speech_lengths")
                {
                    int[] dim = new int[] { BatchSize };
                    int[] speech_lengths = new int[BatchSize];
                    // Every entry reports the same (padded) frame count, not its own unpadded length.
                    for (int i = 0; i < BatchSize; i++)
                    {
                        speech_lengths[i] = padSequence.Length / 560 / BatchSize;
                    }
                    var tensor = new DenseTensor<int>(speech_lengths, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<int>(name, tensor));
                }
            }
            // NOTE(review): Append here is the LINQ extension, which returns a new sequence instead of
            // modifying the list; the results are discarded, so outputNames stays empty. It is also never
            // passed to Run below.
            IReadOnlyCollection<string> outputNames = new List<string>();
            outputNames.Append("logits");
            outputNames.Append("token_num");
            IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = null;
            OfflineOutputEntity offlineOutputEntity = new OfflineOutputEntity();
            try
            {
                results = _onnxSession.Run(container);
                // NOTE(review): _offlineProj is declared nullable and the constructor sets it to null for
                // unknown model names; a null here would throw and be swallowed by the catch below.
                ModelOutputEntity modelOutputEntity = _offlineProj.ModelProj(modelInputs);
                if (modelOutputEntity != null)
                {
                    offlineOutputEntity.Token_nums_length = modelOutputEntity.model_out_lens.AsEnumerable<int>().ToArray();
                    Tensor<float> logits_tensor = modelOutputEntity.model_out;
                    List<int[]> token_nums = new List<int[]> { };
                    // Argmax over the last dimension for every [i, j] position.
                    for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
                    {
                        int[] item = new int[logits_tensor.Dimensions[1]];
                        for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
                        {
                            int token_num = 0;
                            for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
                            {
                                // Keep the current best only when it is strictly greater; ties move to k.
                                token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
                            }
                            item[j] = (int)token_num;
                        }
                        token_nums.Add(item);
                    }
                    offlineOutputEntity.Token_nums = token_nums;
                }
            }
            catch (Exception ex)
            {
                // NOTE(review): every exception is swallowed without logging; after a failure during Run,
                // results stays null and an entity with unset members is returned.
            }
            OfflineOutputEntity modelOutput = new OfflineOutputEntity();
            // NOTE(review): results is an IDisposableReadOnlyCollection but is not disposed in this method.
            if (results != null)
            {
                // Outputs are addressed by position: index 0 is read as the logits, index 1 as the lengths.
                var resultsArray = results.ToArray();
                modelOutput.Logits = resultsArray[0].AsEnumerable<float>().ToArray();
                // NOTE(review): output 1 is read as int here but as Int64 two lines below - confirm the
                // actual element type; token_nums_tensor is not used afterwards.
                modelOutput.Token_nums_length = resultsArray[1].AsEnumerable<int>().ToArray();
                Tensor<float> logits_tensor = resultsArray[0].AsTensor<float>();
                Tensor<Int64> token_nums_tensor = resultsArray[1].AsTensor<Int64>();
                List<int[]> token_nums = new List<int[]> { };
                // Same argmax over the last dimension as in the try block above.
                for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
                {
                    int[] item = new int[logits_tensor.Dimensions[1]];
                    for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
                    {
                        int token_num = 0;
                        for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
                        {
                            token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
                        }
                        item[j] = (int)token_num;
                    }
                    token_nums.Add(item);
                }
                modelOutput.Token_nums = token_nums;
            }
            return modelOutput;
            // NOTE(review): unreachable - the method has already returned on the previous line.
            return offlineOutputEntity;
        }
        private List<string> DecodeMulti(List<int[]> token_nums)
@@ -156,15 +157,18 @@
                    {
                        break;
                    }
                    if (_tokens[token] != "</s>" && _tokens[token] != "<s>" && _tokens[token] != "<blank>")
                    string tokenChar = _tokens[token].Split("\t")[0];
                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>" && tokenChar != "<unk>")
                    {                        
                        if (IsChinese(_tokens[token],true))
                        if (IsChinese(tokenChar, true))
                        {
                            text_result += _tokens[token];
                            text_result += tokenChar;
                        }
                        else
                        {
                            text_result += "▁" + _tokens[token]+ "▁";
                            text_result += "▁" + tokenChar + "▁";
                        }
                    }
                }
@@ -195,48 +199,6 @@
            else
                return false;
        }
        private float[] PadSequence(List<OfflineInputEntity> modelInputs)
        {
            int max_speech_length = modelInputs.Max(x => x.SpeechLength);
            int speech_length = max_speech_length * modelInputs.Count;
            float[] speech = new float[speech_length];
            float[,] xxx = new float[modelInputs.Count, max_speech_length];
            for (int i = 0; i < modelInputs.Count; i++)
            {
                if (max_speech_length == modelInputs[i].SpeechLength)
                {
                    for (int j = 0; j < xxx.GetLength(1); j++)
                    {
#pragma warning disable CS8602 // 解引用可能出现空引用。
                        xxx[i, j] = modelInputs[i].Speech[j];
#pragma warning restore CS8602 // 解引用可能出现空引用。
                    }
                    continue;
                }
                float[] nullspeech = new float[max_speech_length - modelInputs[i].SpeechLength];
                float[]? curr_speech = modelInputs[i].Speech;
                float[] padspeech = new float[max_speech_length];
                padspeech = _wavFrontend.ApplyCmvn(padspeech);
                Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length);
                for (int j = 0; j < padspeech.Length; j++)
                {
#pragma warning disable CS8602 // 解引用可能出现空引用。
                    xxx[i, j] = padspeech[j];
#pragma warning restore CS8602 // 解引用可能出现空引用。
                }
            }
            int s = 0;
            for (int i = 0; i < xxx.GetLength(0); i++)
            {
                for (int j = 0; j < xxx.GetLength(1); j++)
                {
                    speech[s] = xxx[i, j];
                    s++;
                }
            }
            return speech;
        }
    }
}