// See https://github.com/manyeyes for more information
// Copyright (c) 2023 by manyeyes
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using AliParaformerAsr.Model;
using AliParaformerAsr.Utils;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.Extensions.Logging;
using System.Text.RegularExpressions;
namespace AliParaformerAsr
{
/// <summary>
/// Offline recognizer package.
/// Copyright (c) 2023 by manyeyes
/// </summary>
public class OfflineRecognizer
{
// ONNX Runtime session created in the constructor from the model file; Forward runs inference on it.
private InferenceSession _onnxSession;
// Logger created in the constructor from a LoggerFactory that has no providers added in this file.
private readonly ILogger _logger;
// Feature front end; this file calls GetFbank, LfrCmvn and ApplyCmvn on it.
private WavFrontend _wavFrontend;
// Copy of the `frontend` value from the YAML configuration; not read anywhere in this file.
private string _frontend;
// Copy of the `frontend_conf` section of the YAML configuration; not read anywhere in this file.
private FrontendConfEntity _frontendConfEntity;
// Lines of the tokens file; DecodeMulti indexes it with the token ids produced by Forward.
private string[] _tokens;
// Batch size passed to the constructor; stored only, no code in this file reads it.
private int _batchSize = 1;
/// <summary>
/// Creates an offline recognizer: opens the ONNX model, reads the token table and the
/// YAML configuration, and builds the feature front end.
/// </summary>
/// <param name="modelFilePath">Path of the ONNX model handed to <c>InferenceSession</c>.</param>
/// <param name="configFilePath">Path of the YAML configuration read through <c>YamlHelper.ReadYaml</c>.</param>
/// <param name="mvnFilePath">Path passed to the <c>WavFrontend</c> constructor.</param>
/// <param name="tokensFilePath">Text file whose lines become the token table.</param>
/// <param name="batchSize">Stored in <c>_batchSize</c>; not used by any code in this file.</param>
/// <param name="threadsNum">Assigned to <c>SessionOptions.InterOpNumThreads</c>.</param>
public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,string tokensFilePath, int batchSize = 1,int threadsNum=1)
{
    // NOTE(review): SessionOptions is disposable and is not disposed here — confirm whether
    // it can be released once the session has been created.
    SessionOptions options = new SessionOptions();
    // Only the CPU execution provider is enabled; the original kept DML and CUDA
    // (AppendExecutionProvider_DML / AppendExecutionProvider_CUDA) as commented-out alternatives.
    options.AppendExecutionProvider_CPU(0);
    options.InterOpNumThreads = threadsNum;
    _onnxSession = new InferenceSession(modelFilePath, options);
    try
    {
        _tokens = File.ReadAllLines(tokensFilePath);
        // NOTE(review): if YamlHelper.ReadYaml is generic, its type argument appears to be
        // missing on this line — confirm against YamlHelper.
        OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml(configFilePath);
        _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
        _frontend = offlineYamlEntity.frontend;
        _frontendConfEntity = offlineYamlEntity.frontend_conf;
        _batchSize = batchSize;
        // No logging provider is added to this factory in this constructor.
        ILoggerFactory loggerFactory = new LoggerFactory();
        // The type argument was missing in the original text; the enclosing class is used
        // as the logger category.
        _logger = new Logger<OfflineRecognizer>(loggerFactory);
    }
    catch
    {
        // The session already exists at this point; release it before propagating the
        // failure so a bad tokens/config/mvn file does not leak it.
        _onnxSession.Dispose();
        throw;
    }
}
/// <summary>
/// Runs the full pipeline on a batch of waveforms: feature extraction, model inference
/// and token decoding.
/// </summary>
/// <param name="samples">
/// Waveforms to recognize. NOTE(review): the element type is reconstructed as
/// <c>float[]</c> because the generic arguments were missing — confirm against
/// <c>WavFrontend.GetFbank</c>.
/// </param>
/// <returns>
/// The texts produced by <c>DecodeMulti</c>; an empty list when <paramref name="samples"/>
/// is empty or the model output carries no token ids.
/// </returns>
/// <exception cref="ArgumentNullException"><paramref name="samples"/> is null.</exception>
public List<string> GetResults(List<float[]> samples)
{
    if (samples is null)
    {
        throw new ArgumentNullException(nameof(samples));
    }

    this._logger.LogInformation("get features begin");

    // An empty batch would reach Max() over an empty sequence in PadSequence.
    if (samples.Count == 0)
    {
        return new List<string>();
    }

    List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
    OfflineOutputEntity modelOutput = this.Forward(offlineInputEntities);

    // Forward leaves Token_nums unset when inference did not produce results.
    if (modelOutput.Token_nums is null)
    {
        this._logger.LogWarning("model returned no token ids; returning an empty result list");
        return new List<string>();
    }

    return this.DecodeMulti(modelOutput.Token_nums);
}
/// <summary>
/// Builds one model input per waveform by passing it through
/// <c>WavFrontend.GetFbank</c> and then <c>WavFrontend.LfrCmvn</c>.
/// </summary>
/// <param name="waveform_list">
/// Waveforms to convert. NOTE(review): element type reconstructed as <c>float[]</c>
/// because the generic arguments were missing — confirm.
/// </param>
/// <returns>
/// One <c>OfflineInputEntity</c> per waveform, in input order; <c>SpeechLength</c> is the
/// number of floats in the flat feature array.
/// </returns>
private List<OfflineInputEntity> ExtractFeats(List<float[]> waveform_list)
{
    List<OfflineInputEntity> offlineInputEntities = new List<OfflineInputEntity>(waveform_list.Count);
    foreach (var waveform in waveform_list)
    {
        float[] fbanks = _wavFrontend.GetFbank(waveform);
        float[] features = _wavFrontend.LfrCmvn(fbanks);
        offlineInputEntities.Add(new OfflineInputEntity
        {
            Speech = features,
            SpeechLength = features.Length,
        });
    }
    return offlineInputEntities;
}
/// <summary>
/// Pads the batch to a common length, runs the ONNX session and converts the logits
/// into token ids by taking the arg-max over the last dimension at every step.
/// </summary>
/// <param name="modelInputs">Feature entities produced by <c>ExtractFeats</c>.</param>
/// <returns>
/// An entity with <c>Logits</c>, <c>Token_nums_length</c> and <c>Token_nums</c> filled in.
/// When the batch is empty or inference throws, the entity is returned as constructed
/// (the exception is logged, not rethrown).
/// </returns>
private OfflineOutputEntity Forward(List<OfflineInputEntity> modelInputs)
{
    // Size of the last dimension of the "speech" input tensor.
    const int featureDim = 560;

    OfflineOutputEntity modelOutput = new OfflineOutputEntity();
    int batchSize = modelInputs.Count;
    if (batchSize == 0)
    {
        // Nothing to run; this also avoids dividing by zero below.
        return modelOutput;
    }

    float[] padSequence = PadSequence(modelInputs);
    int framesPerItem = padSequence.Length / featureDim / batchSize;

    var container = new List<NamedOnnxValue>();
    foreach (var name in _onnxSession.InputMetadata.Keys)
    {
        if (name == "speech")
        {
            int[] dim = new int[] { batchSize, framesPerItem, featureDim };
            var tensor = new DenseTensor<float>(padSequence, dim, false);
            container.Add(NamedOnnxValue.CreateFromTensor(name, tensor));
        }
        if (name == "speech_lengths")
        {
            // NOTE(review): every item reports the padded frame count, not its own
            // length (modelInputs[i].SpeechLength / featureDim) — confirm this is intended.
            int[] dim = new int[] { batchSize };
            int[] speech_lengths = new int[batchSize];
            for (int i = 0; i < batchSize; i++)
            {
                speech_lengths[i] = framesPerItem;
            }
            var tensor = new DenseTensor<int>(speech_lengths, dim, false);
            container.Add(NamedOnnxValue.CreateFromTensor(name, tensor));
        }
    }

    IDisposableReadOnlyCollection<DisposableNamedOnnxValue>? results = null;
    try
    {
        results = _onnxSession.Run(container);
    }
    catch (Exception ex)
    {
        // Keep the original best-effort contract (no rethrow) but record the cause.
        _logger.LogError(ex, "ONNX inference failed");
    }

    if (results is null)
    {
        return modelOutput;
    }

    // Everything needed is copied into managed arrays inside this scope, so the
    // result collection can be released on exit.
    using (results)
    {
        // Outputs are read by position: index 0 as logits, index 1 as token counts.
        var resultsArray = results.ToArray();
        modelOutput.Logits = resultsArray[0].AsEnumerable<float>().ToArray();
        // NOTE(review): element type reconstructed as int because the type argument was
        // missing — confirm against OfflineOutputEntity.Token_nums_length and the model output.
        modelOutput.Token_nums_length = resultsArray[1].AsEnumerable<int>().ToArray();

        Tensor<float> logits_tensor = resultsArray[0].AsTensor<float>();
        int batch = logits_tensor.Dimensions[0];
        int steps = logits_tensor.Dimensions[1];
        int vocab = logits_tensor.Dimensions[2];

        List<int[]> token_nums = new List<int[]>(batch);
        for (int i = 0; i < batch; i++)
        {
            int[] item = new int[steps];
            for (int j = 0; j < steps; j++)
            {
                // Arg-max over the last dimension; on a tie the later index wins,
                // matching the original comparison.
                int best = 0;
                for (int k = 1; k < vocab; k++)
                {
                    best = logits_tensor[i, j, best] > logits_tensor[i, j, k] ? best : k;
                }
                item[j] = best;
            }
            token_nums.Add(item);
        }
        modelOutput.Token_nums = token_nums;
    }

    return modelOutput;
}
/// <summary>
/// Converts each sequence of token ids into text using the token table.
/// </summary>
/// <param name="token_nums">One array of token ids per utterance.</param>
/// <returns>
/// One string per entry of <paramref name="token_nums"/>; an empty list when it is null.
/// </returns>
private List<string> DecodeMulti(List<int[]> token_nums)
{
    List<string> text_results = new List<string>();
    if (token_nums is null)
    {
        return text_results;
    }

    foreach (int[] ids in token_nums)
    {
        StringBuilder text = new StringBuilder();
        if (ids is not null)
        {
            foreach (int token in ids)
            {
                // Token id 2 ends the sequence.
                if (token == 2)
                {
                    break;
                }
                // Ids outside the token table are skipped instead of throwing.
                if (token < 0 || token >= _tokens.Length)
                {
                    continue;
                }

                string piece = _tokens[token];
                // NOTE(review): the original compared against three identical empty
                // literals, which looks like angle-bracket tokens lost in extraction.
                // They are reconstructed here as "</s>", "<s>" and "<blank>" — confirm
                // against the tokens file. Empty entries are still skipped.
                if (string.IsNullOrEmpty(piece) || piece == "</s>" || piece == "<s>" || piece == "<blank>")
                {
                    continue;
                }

                if (IsChinese(piece, true))
                {
                    text.Append(piece);
                }
                else
                {
                    // Non-Chinese pieces are wrapped in "▁" markers, resolved below.
                    text.Append("▁").Append(piece).Append("▁");
                }
            }
        }

        // "@@▁▁" removal joins a piece ending in "@@" to the following piece; a doubled
        // marker between two pieces becomes one space; remaining single markers are dropped.
        text_results.Add(text.ToString().Replace("@@▁▁", "").Replace("▁▁", " ").Replace("▁", ""));
    }
    return text_results;
}
/// <summary>
/// Checks a string for Chinese characters, i.e. characters in the range U+4E00 to U+9FA5.
/// </summary>
/// <param name="checkedStr">The string to be verified.</param>
/// <param name="allMatch">
/// When true, the string must be non-empty and consist only of characters in that range
/// (the <c>$</c> anchor still tolerates one trailing newline); when false, a single such
/// character anywhere in the string is enough.
/// </param>
/// <returns>True when the string satisfies the selected check; otherwise false.</returns>
private bool IsChinese(string checkedStr, bool allMatch)
{
    string pattern = allMatch ? @"^[\u4e00-\u9fa5]+$" : @"[\u4e00-\u9fa5]";
    return Regex.IsMatch(checkedStr, pattern);
}
/// <summary>
/// Lays the feature arrays of a batch out back to back in one flat array, every item
/// occupying as many floats as the longest item.
/// </summary>
/// <param name="modelInputs">Batch items; <c>SpeechLength</c> is the float count of <c>Speech</c>.</param>
/// <returns>
/// A flat array of <c>modelInputs.Count * maxLength</c> floats in row-major order; an
/// empty array for an empty batch. Shorter items are followed by the values
/// <c>WavFrontend.ApplyCmvn</c> returns for an all-zero buffer.
/// </returns>
private float[] PadSequence(List<OfflineInputEntity> modelInputs)
{
    if (modelInputs.Count == 0)
    {
        // Max() throws on an empty sequence.
        return Array.Empty<float>();
    }

    int max_speech_length = modelInputs.Max(x => x.SpeechLength);
    float[] speech = new float[max_speech_length * modelInputs.Count];

    for (int i = 0; i < modelInputs.Count; i++)
    {
        int rowStart = i * max_speech_length;
        float[] current = modelInputs[i].Speech ?? Array.Empty<float>();

        if (modelInputs[i].SpeechLength == max_speech_length)
        {
            // Already the longest length: copy straight into its row.
            Array.Copy(current, 0, speech, rowStart, max_speech_length);
            continue;
        }

        // Start from the front end's output for a zero buffer, then overwrite its
        // head with the real features.
        // NOTE(review): ApplyCmvn presumably returns an array of the same length as its
        // input — confirm; the Math.Min below only keeps the copy inside the row.
        float[] padded = _wavFrontend.ApplyCmvn(new float[max_speech_length]);
        Array.Copy(current, 0, padded, 0, current.Length);
        Array.Copy(padded, 0, speech, rowStart, Math.Min(padded.Length, max_speech_length));
    }

    return speech;
}
}
}