Dogvane Huang
2024-07-02 dfcc5d47587d3e793cbfec2e9509c0e9a9e1732c
runtime/csharp/AliParaformerAsr/AliParaformerAsr/OfflineRecognizer.cs
@@ -9,9 +9,20 @@
using Microsoft.ML.OnnxRuntime.Tensors;
using Microsoft.Extensions.Logging;
using System.Text.RegularExpressions;
using Newtonsoft.Json.Linq;
// 模型文件地址: https://modelscope.cn/models/iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
namespace AliParaformerAsr
{
    public enum OnnxRumtimeTypes
    {
        CPU = 0,
        DML = 1,
        CUDA = 2,
    }
    /// <summary>
    /// offline recognizer package
    /// Copyright (c)  2023 by manyeyes
@@ -24,35 +35,70 @@
        // Frontend name taken from the `frontend` entry of the YAML configuration.
        private string _frontend;
        // Frontend settings taken from the `frontend_conf` entry of the YAML configuration.
        private FrontendConfEntity _frontendConfEntity;
        // Token table loaded from the tokens file; indexed by token id when decoding.
        private string[] _tokens;
        // Batch size handed to the constructor (1 unless overridden).
        private int _batchSize = 1;
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath,string tokensFilePath, int batchSize = 1,int threadsNum=1)
        /// <summary>
        ///
        /// </summary>
        /// <param name="modelFilePath"></param>
        /// <param name="configFilePath"></param>
        /// <param name="mvnFilePath"></param>
        /// <param name="tokensFilePath"></param>
        /// <param name="rumtimeType">可以选择gpu,但是目前情况下,不建议使用,因为性能提升有限</param>
        /// <param name="deviceId">设备id,多显卡时用于指定执行的显卡</param>
        /// <param name="batchSize"></param>
        /// <param name="threadsNum"></param>
        /// <exception cref="ArgumentException"></exception>
        public OfflineRecognizer(string modelFilePath, string configFilePath, string mvnFilePath, string tokensFilePath, OnnxRumtimeTypes rumtimeType = OnnxRumtimeTypes.CPU, int deviceId = 0)
        {
            Microsoft.ML.OnnxRuntime.SessionOptions options = new Microsoft.ML.OnnxRuntime.SessionOptions();
            var options = new SessionOptions();
            switch(rumtimeType)
            {
                case OnnxRumtimeTypes.DML:
                    options.AppendExecutionProvider_DML(deviceId);
                    break;
                case OnnxRumtimeTypes.CUDA:
                    options.AppendExecutionProvider_CUDA(deviceId);
                    break;
                default:
                    options.AppendExecutionProvider_CPU(deviceId);
                    break;
            }
            //options.LogSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_INFO;
            //options.AppendExecutionProvider_DML(0);
            options.AppendExecutionProvider_CPU(0);
            //options.AppendExecutionProvider_CUDA(0);
            options.InterOpNumThreads = threadsNum;
            _onnxSession = new InferenceSession(modelFilePath, options);
            _tokens = File.ReadAllLines(tokensFilePath);
            _onnxSession = new InferenceSession(modelFilePath, options);
            string[] tokenLines;
            if (tokensFilePath.EndsWith(".txt"))
            {
                tokenLines = File.ReadAllLines(tokensFilePath);
            }
            else if (tokensFilePath.EndsWith(".json"))
            {
                string jsonContent = File.ReadAllText(tokensFilePath);
                JArray tokenArray = JArray.Parse(jsonContent);
                tokenLines = tokenArray.Select(t => t.ToString()).ToArray();
            }
            else
            {
                throw new ArgumentException("Invalid tokens file format. Only .txt and .json are supported.");
            }
            _tokens = tokenLines;
            OfflineYamlEntity offlineYamlEntity = YamlHelper.ReadYaml<OfflineYamlEntity>(configFilePath);
            _wavFrontend = new WavFrontend(mvnFilePath, offlineYamlEntity.frontend_conf);
            _frontend = offlineYamlEntity.frontend;
            _frontendConfEntity = offlineYamlEntity.frontend_conf;
            _batchSize = batchSize;
            ILoggerFactory loggerFactory = new LoggerFactory();
            _logger = new Logger<OfflineRecognizer>(loggerFactory);
        }
        public List<string> GetResults(List<float[]> samples)
        {
            this._logger.LogInformation("get features begin");
            _logger.LogInformation("get features begin");
            List<OfflineInputEntity> offlineInputEntities = ExtractFeats(samples);
            OfflineOutputEntity modelOutput = this.Forward(offlineInputEntities);
            List<string> text_results = this.DecodeMulti(modelOutput.Token_nums);
            OfflineOutputEntity modelOutput = Forward(offlineInputEntities);
            List<string> text_results = DecodeMulti(modelOutput.Token_nums);
            return text_results;
        }
@@ -156,15 +202,18 @@
                    {
                        break;
                    }
                    if (_tokens[token] != "</s>" && _tokens[token] != "<s>" && _tokens[token] != "<blank>")
                    string tokenChar = _tokens[token];
                    if (tokenChar != "</s>" && tokenChar != "<s>" && tokenChar != "<blank>")
                    {                        
                        if (IsChinese(_tokens[token],true))
                        if (IsChinese(tokenChar, true))
                        {
                            text_result += _tokens[token];
                            text_result += tokenChar;
                        }
                        else
                        {
                            text_result += "▁" + _tokens[token]+ "▁";
                            text_result += "▁" + tokenChar + "▁";
                        }
                    }
                }