// See https://github.com/manyeyes for more information
// Copyright (c) 2023 by manyeyes
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using AliParaformerAsr.Model;
using AliParaformerAsr.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Newtonsoft.Json.Linq;

// Model: https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
namespace AliParaformerAsr
{
    /// <summary>
    /// ONNX Runtime execution providers selectable at construction time.
    /// (Name kept as-is — it is part of the public API — despite the "Rumtime" typo.)
    /// </summary>
    public enum OnnxRumtimeTypes
    {
        CPU = 0,
        DML = 1,
        CUDA = 2,
    }

    /// <summary>
    /// Offline (non-streaming) Paraformer speech recognizer backed by ONNX Runtime.
    /// Copyright (c) 2023 by manyeyes
    /// </summary>
    public class OfflineRecognizer
    {
        // Stacked feature dimension fed to the model.
        // NOTE(review): 560 presumably = 80-dim fbank × LFR factor 7 — confirm against frontend_conf.
        private const int FeatureDim = 560;

        private InferenceSession _onnxSession;
        private readonly ILogger _logger;
        private WavFrontend _wavFrontend;
        private string _frontend;
        private FrontendConfEntity _frontendConfEntity;
        private string[] _tokens;

        /// <summary>
        /// Loads the model, token vocabulary, CMVN statistics and YAML config from disk.
        /// </summary>
        /// <param name="modelFilePath">Path to the exported Paraformer ONNX model.</param>
        /// <param name="configFilePath">Path to the YAML configuration file.</param>
        /// <param name="mvnFilePath">Path to the CMVN statistics file.</param>
        /// <param name="tokensFilePath">Vocabulary file: .txt (one token per line) or .json (array of tokens).</param>
        /// <param name="rumtimeType">
        /// Execution provider. GPU can be selected, but is currently not recommended:
        /// the performance gain is limited.
        /// </param>
        /// <param name="deviceId">Device id used to pick the card on multi-GPU machines.</param>
        /// <exception cref="ArgumentException">The tokens file is neither .txt nor .json.</exception>
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,
            string tokensFilePath, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
        {
            var options = new SessionOptions();
            switch (rumtimeType)
            {
                case OnnxRumtimeTypes.DML:
                    options.AppendExecutionProvider_DML(deviceId);
                    break;
                case OnnxRumtimeTypes.CUDA:
                    options.AppendExecutionProvider_CUDA(deviceId);
                    break;
                default:
                    options.AppendExecutionProvider_CPU(deviceId);
                    break;
            }
            //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
            _onnxSession = new InferenceSession(modelFilePath, options);

            string[] tokenLines;
            if (tokensFilePath.EndsWith(".txt", StringComparison.Ordinal))
            {
                tokenLines = File.ReadAllLines(tokensFilePath);
            }
            else if (tokensFilePath.EndsWith(".json", StringComparison.Ordinal))
            {
                string jsonContent = File.ReadAllText(tokensFilePath);
                JArray tokenArray = JArray.Parse(jsonContent);
                tokenLines = tokenArray.Select(t => t.ToString()).ToArray();
            }
            else
            {
                throw new ArgumentException("Invalid tokens file format. Only .txt and .json are supported.");
            }
            _tokens = tokenLines;

            OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
            _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
            _frontend = offlineYamlEntity.frontend;
            _frontendConfEntity = offlineYamlEntity.frontend_conf;
            ILoggerFactory loggerFactory = new LoggerFactory();
            _logger = new Logger<OfflineRecognizer>(loggerFactory);
        }

        /// <summary>
        /// Recognizes a batch of utterances.
        /// </summary>
        /// <param name="samples">One float PCM sample array per utterance.</param>
        /// <returns>One decoded text string per input utterance.</returns>
        public List<string> GetResults(List<float[]> samples)
        {
            _logger.LogInformation("get features begin");
            List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
            OfflineOutputEntity modelOutput = Forward(offlineInputEntities);
            List<string> text_results = DecodeMulti(modelOutput.Token_nums);
            return text_results;
        }

        /// <summary>
        /// Computes fbank + LFR/CMVN features for each waveform.
        /// </summary>
        private List<OfflineInputEntity> ExtractFeats(List<float[]> waveform_list)
        {
            List<OfflineInputEntity> offlineInputEntities = new List<OfflineInputEntity>();
            foreach (var waveform in waveform_list)
            {
                float[] fbanks = _wavFrontend.GetFbank(waveform);
                float[] features = _wavFrontend.LfrCmvn(fbanks);
                OfflineInputEntity offlineInputEntity = new OfflineInputEntity();
                offlineInputEntity.Speech = features;
                offlineInputEntity.SpeechLength = features.Length;
                offlineInputEntities.Add(offlineInputEntity);
            }
            return offlineInputEntities;
        }

        /// <summary>
        /// Runs the ONNX model on a padded batch and greedy-decodes token ids from the logits.
        /// On inference failure the error is logged and an empty <see cref="OfflineOutputEntity"/>
        /// is returned (best-effort contract — callers get no exception).
        /// </summary>
        private OfflineOutputEntity Forward(List<OfflineInputEntity> modelInputs)
        {
            int BatchSize = modelInputs.Count;
            float[] padSequence = PadSequence(modelInputs);
            var inputMeta = _onnxSession.InputMetadata;
            var container = new List<NamedOnnxValue>();
            foreach (var name in inputMeta.Keys)
            {
                if (name == "speech")
                {
                    // All utterances are padded to the same frame count, so frames = total / dim / batch.
                    int[] dim = new int[] { BatchSize, padSequence.Length / FeatureDim / BatchSize, FeatureDim };
                    var tensor = new DenseTensor<float>(padSequence, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
                }
                if (name == "speech_lengths")
                {
                    int[] dim = new int[] { BatchSize };
                    int[] speech_lengths = new int[BatchSize];
                    for (int i = 0; i < BatchSize; i++)
                    {
                        speech_lengths[i] = padSequence.Length / FeatureDim / BatchSize;
                    }
                    var tensor = new DenseTensor<int>(speech_lengths, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<int>(name, tensor));
                }
            }

            // Run() without explicit output names fetches every model output
            // (index 0: logits, index 1: token_num).
            IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = null;
            try
            {
                results = _onnxSession.Run(container);
            }
            catch (Exception ex)
            {
                // Best effort: a failed inference yields an empty result instead of crashing the caller.
                _logger.LogError(ex, "onnx inference failed");
            }

            OfflineOutputEntity modelOutput = new OfflineOutputEntity();
            if (results != null)
            {
                var resultsArray = results.ToArray();
                modelOutput.Logits = resultsArray[0].AsEnumerable<float>().ToArray();
                // NOTE(review): element type of the token_num output was garbled in the source;
                // int is assumed here — confirm against the exported model's output metadata.
                modelOutput.Token_nums_length = resultsArray[1].AsEnumerable<int>().ToArray();

                Tensor<float> logits_tensor = resultsArray[0].AsTensor<float>();
                List<int[]> token_nums = new List<int[]>();
                // Greedy decode: argmax over the vocabulary axis for every frame of every utterance.
                for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
                {
                    int[] item = new int[logits_tensor.Dimensions[1]];
                    for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
                    {
                        int token_num = 0;
                        for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
                        {
                            token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
                        }
                        item[j] = token_num;
                    }
                    token_nums.Add(item);
                }
                modelOutput.Token_nums = token_nums;
            }
            return modelOutput;
        }

        /// <summary>
        /// Maps token-id sequences to text. Stops each utterance at the end-of-sentence id,
        /// drops special tokens, and uses "▁" markers to reconstruct word spacing for
        /// non-Chinese (sub-word) tokens.
        /// </summary>
        private List<string> DecodeMulti(List<int[]> token_nums)
        {
            List<string> text_results = new List<string>();
#pragma warning disable CS8602 // Dereference of a possibly null reference.
            foreach (int[] token_num in token_nums)
            {
                string text_result = "";
                foreach (int token in token_num)
                {
                    // Token id 2 marks end of sentence for this vocabulary.
                    if (token == 2)
                    {
                        break;
                    }
                    string tokenChar = _tokens[token];
                    // NOTE(review): these special-token literals were garbled in the source;
                    // "<s>", "</s>" and "<blank>" are the usual Paraformer specials —
                    // confirm against the tokens file.
                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>")
                    {
                        if (IsChinese(tokenChar, true))
                        {
                            // Chinese characters join without separators.
                            text_result += tokenChar;
                        }
                        else
                        {
                            // Non-Chinese sub-word tokens get boundary markers, resolved below.
                            text_result += "▁" + tokenChar + "▁";
                        }
                    }
                }
                // "@@" marks word-internal BPE splits; adjacent "▁▁" becomes a space.
                text_results.Add(text_result.Replace("@@▁▁", "").Replace("▁▁", " ").Replace("▁", ""));
            }
#pragma warning restore CS8602 // Dereference of a possibly null reference.
            return text_results;
        }

        /// <summary>
        /// Verify if the string is in Chinese.
        /// </summary>
        /// <param name="checkedStr">The string to be verified.</param>
        /// <param name="allMatch">
        /// Is it an exact match. When the value is true, all characters must be Chinese;
        /// when the value is false, it is enough that Chinese is included.
        /// </param>
        private bool IsChinese(string checkedStr, bool allMatch)
        {
            string pattern = allMatch
                ? @"^[\u4e00-\u9fa5]+$"
                : @"[\u4e00-\u9fa5]";
            return Regex.IsMatch(checkedStr, pattern);
        }

        /// <summary>
        /// Flattens the batch into one float array, padding every utterance's feature
        /// vector to the length of the longest one. Padding values are the CMVN-normalized
        /// zero frames produced by <c>ApplyCmvn</c> (then overwritten by real features up
        /// to each utterance's true length).
        /// </summary>
        private float[] PadSequence(List<OfflineInputEntity> modelInputs)
        {
            int max_speech_length = modelInputs.Max(x => x.SpeechLength);
            int speech_length = max_speech_length * modelInputs.Count;
            float[] speech = new float[speech_length];
            float[,] padded = new float[modelInputs.Count, max_speech_length];
            for (int i = 0; i < modelInputs.Count; i++)
            {
                if (max_speech_length == modelInputs[i].SpeechLength)
                {
                    // Longest utterance: copy straight through, no padding needed.
                    for (int j = 0; j < padded.GetLength(1); j++)
                    {
#pragma warning disable CS8602 // Dereference of a possibly null reference.
                        padded[i, j] = modelInputs[i].Speech[j];
#pragma warning restore CS8602 // Dereference of a possibly null reference.
                    }
                    continue;
                }
                float[]? curr_speech = modelInputs[i].Speech;
                // CMVN-normalize a zero buffer so padded frames match the model's input statistics,
                // then overlay the real features at the front.
                float[] padspeech = new float[max_speech_length];
                padspeech = _wavFrontend.ApplyCmvn(padspeech);
                Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length);
                for (int j = 0; j < padspeech.Length; j++)
                {
                    padded[i, j] = padspeech[j];
                }
            }
            // Flatten row-major: batch-by-batch, frame-by-frame.
            int s = 0;
            for (int i = 0; i < padded.GetLength(0); i++)
            {
                for (int j = 0; j < padded.GetLength(1); j++)
                {
                    speech[s] = padded[i, j];
                    s++;
                }
            }
            return speech;
        }
    }
}