游雁
2024-02-19 94de39dde2e616a01683c518023d0fab72b4e103
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
// See https://github.com/manyeyes for more information
// Copyright (c)  2023 by manyeyes
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;
using System.Threading.Tasks;
using AliParaformerAsr.Model;
using AliParaformerAsr.Utils;
using Microsoft.Extensions.Logging;
using Microsoft.ML.OnnxRuntime;
using Microsoft.ML.OnnxRuntime.Tensors;
 
namespace AliParaformerAsr
{
    /// <summary>
    /// offline recognizer package
    /// Copyright (c)  2023 by manyeyes
    /// </summary>
    public class OfflineRecognizer
    {
        private InferenceSession _onnxSession;
        private readonly ILogger<OfflineRecognizer> _logger;
        private WavFrontend _wavFrontend;
        private string _frontend;
        private FrontendConfEntity _frontendConfEntity;
        private string[] _tokens;
        private int _batchSize = 1;
 
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,string tokensFilePath, int batchSize = 1,int threadsNum=1)
        {
            Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions();
            //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
            //options.AppendExecutionProvider_DML(0);
            options.AppendExecutionProvider_CPU(0);
            //options.AppendExecutionProvider_CUDA(0);
            options.InterOpNumThreads = threadsNum;
            _onnxSession = new InferenceSession(modelFilePath, options);
 
            _tokens = File.ReadAllLines(tokensFilePath);
 
            OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
            _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
            _frontend = offlineYamlEntity.frontend;
            _frontendConfEntity = offlineYamlEntity.frontend_conf;
            _batchSize = batchSize;
            ILoggerFactory loggerFactory = new LoggerFactory();
            _logger = new Logger<OfflineRecognizer>(loggerFactory);
        }
 
        public List<string> GetResults(List<float[]> samples)
        {
            this._logger.LogInformation("get features begin");
            List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
            OfflineOutputEntity modelOutput = this.Forward(offlineInputEntities);
            List<string> text_results = this.DecodeMulti(modelOutput.Token_nums);
            return text_results;
        }
 
        private List<OfflineInputEntity> ExtractFeats(List<float[]> waveform_list)
        {
            List<float[]> in_cache = new List<float[]>();
            List<OfflineInputEntity> offlineInputEntities = new List<OfflineInputEntity>();
            foreach (var waveform in waveform_list)
            {
                float[] fbanks = _wavFrontend.GetFbank(waveform);
                float[] features = _wavFrontend.LfrCmvn(fbanks);
                OfflineInputEntity offlineInputEntity = new OfflineInputEntity();
                offlineInputEntity.Speech = features;
                offlineInputEntity.SpeechLength = features.Length;
                offlineInputEntities.Add(offlineInputEntity);
            }
            return offlineInputEntities;
        }
 
        private OfflineOutputEntity Forward(List<OfflineInputEntity> modelInputs)
        {
            int BatchSize = modelInputs.Count;
            float[] padSequence = PadSequence(modelInputs);
            var inputMeta = _onnxSession.InputMetadata;
            var container = new List<NamedOnnxValue>(); 
            foreach (var name in inputMeta.Keys)
            {
                if (name == "speech")
                {
                    int[] dim = new int[] { BatchSize, padSequence.Length / 560 / BatchSize, 560 };//inputMeta["speech"].Dimensions[2]
                    var tensor = new DenseTensor<float>(padSequence, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<float>(name, tensor));
                }
                if (name == "speech_lengths")
                {
                    int[] dim = new int[] { BatchSize };
                    int[] speech_lengths = new int[BatchSize];
                    for (int i = 0; i < BatchSize; i++)
                    {
                        speech_lengths[i] = padSequence.Length / 560 / BatchSize;
                    }
                    var tensor = new DenseTensor<int>(speech_lengths, dim, false);
                    container.Add(NamedOnnxValue.CreateFromTensor<int>(name, tensor));
                }
            }
 
 
            IReadOnlyCollection<string> outputNames = new List<string>();
            outputNames.Append("logits");
            outputNames.Append("token_num");
            IDisposableReadOnlyCollection<DisposableNamedOnnxValue> results = null;
            try
            {
                results = _onnxSession.Run(container);
            }
            catch (Exception ex)
            {
                //
            }
            OfflineOutputEntity modelOutput = new OfflineOutputEntity();
            if (results != null)
            {
                var resultsArray = results.ToArray();
                modelOutput.Logits = resultsArray[0].AsEnumerable<float>().ToArray();
                modelOutput.Token_nums_length = resultsArray[1].AsEnumerable<int>().ToArray();
 
                Tensor<float> logits_tensor = resultsArray[0].AsTensor<float>();
                Tensor<Int64> token_nums_tensor = resultsArray[1].AsTensor<Int64>();
 
                List<int[]> token_nums = new List<int[]> { };
 
                for (int i = 0; i < logits_tensor.Dimensions[0]; i++)
                {
                    int[] item = new int[logits_tensor.Dimensions[1]];
                    for (int j = 0; j < logits_tensor.Dimensions[1]; j++)
                    {                        
                        int token_num = 0;
                        for (int k = 1; k < logits_tensor.Dimensions[2]; k++)
                        {
                            token_num = logits_tensor[i, j, token_num] > logits_tensor[i, j, k] ? token_num : k;
                        }
                        item[j] = (int)token_num;                        
                    }
                    token_nums.Add(item);
                }                
                modelOutput.Token_nums = token_nums;
            }
            return modelOutput;
        }
 
        private List<string> DecodeMulti(List<int[]> token_nums)
        {
            List<string> text_results = new List<string>();
#pragma warning disable CS8602 // 解引用可能出现空引用。
            foreach (int[] token_num in token_nums)
            {
                string text_result = "";
                foreach (int token in token_num)
                {
                    if (token == 2)
                    {
                        break;
                    }
                    if (_tokens[token] != "</s>" && _tokens[token] != "<s>" && _tokens[token] != "<blank>")
                    {                        
                        if (IsChinese(_tokens[token],true))
                        {
                            text_result += _tokens[token];
                        }
                        else
                        {
                            text_result += "▁" + _tokens[token]+ "▁";
                        }
                    }
                }
                text_results.Add(text_result.Replace("@@▁▁", "").Replace("▁▁", " ").Replace("▁", ""));
            }
#pragma warning restore CS8602 // 解引用可能出现空引用。
 
            return text_results;
        }
 
        /// <summary>
        /// Verify if the string is in Chinese.
        /// </summary>
        /// <param name="checkedStr">The string to be verified.</param>
        /// <param name="allMatch">Is it an exact match. When the value is true,all are in Chinese; 
        /// When the value is false, only Chinese is included.
        /// </param>
        /// <returns></returns>
        private bool IsChinese(string checkedStr, bool allMatch)
        {
            string pattern;
            if (allMatch)
                pattern = @"^[\u4e00-\u9fa5]+$";
            else
                pattern = @"[\u4e00-\u9fa5]";
            if (Regex.IsMatch(checkedStr, pattern))
                return true;
            else
                return false;
        }
 
        private float[] PadSequence(List<OfflineInputEntity> modelInputs)
        {
            int max_speech_length = modelInputs.Max(x => x.SpeechLength);
            int speech_length = max_speech_length * modelInputs.Count;
            float[] speech = new float[speech_length];
            float[,] xxx = new float[modelInputs.Count, max_speech_length];
            for (int i = 0; i < modelInputs.Count; i++)
            {
                if (max_speech_length == modelInputs[i].SpeechLength)
                {
                    for (int j = 0; j < xxx.GetLength(1); j++)
                    {
#pragma warning disable CS8602 // 解引用可能出现空引用。
                        xxx[i, j] = modelInputs[i].Speech[j];
#pragma warning restore CS8602 // 解引用可能出现空引用。
                    }
                    continue;
                }
                float[] nullspeech = new float[max_speech_length - modelInputs[i].SpeechLength];
                float[]? curr_speech = modelInputs[i].Speech;
                float[] padspeech = new float[max_speech_length];
                padspeech = _wavFrontend.ApplyCmvn(padspeech);
                Array.Copy(curr_speech, 0, padspeech, 0, curr_speech.Length);
                for (int j = 0; j < padspeech.Length; j++)
                {
#pragma warning disable CS8602 // 解引用可能出现空引用。
                    xxx[i, j] = padspeech[j];
#pragma warning restore CS8602 // 解引用可能出现空引用。
                }
 
            }
            int s = 0;
            for (int i = 0; i < xxx.GetLength(0); i++)
            {
                for (int j = 0; j < xxx.GetLength(1); j++)
                {
                    speech[s] = xxx[i, j];
                    s++;
                }
            }
            return speech;
        }
    }
}