From c2e4e3c2e9be855277d9f4fa9cd0544892ff829a Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 30 八月 2023 09:57:30 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/runtime/onnxruntime/src/util.cpp |   79 ++++++++++++++++++++++++++++++++-------
 1 files changed, 65 insertions(+), 14 deletions(-)

diff --git a/funasr/runtime/onnxruntime/src/util.cpp b/funasr/runtime/onnxruntime/src/util.cpp
index 5a72c72..70059ca 100644
--- a/funasr/runtime/onnxruntime/src/util.cpp
+++ b/funasr/runtime/onnxruntime/src/util.cpp
@@ -1,7 +1,8 @@
 
 #include "precomp.h"
 
-float *loadparams(const char *filename)
+namespace funasr {
+float *LoadParams(const char *filename)
 {
 
     FILE *fp;
@@ -10,20 +11,20 @@
     uint32_t nFileLen = ftell(fp);
     fseek(fp, 0, SEEK_SET);
 
-    float *params_addr = (float *)aligned_malloc(32, nFileLen);
+    float *params_addr = (float *)AlignedMalloc(32, nFileLen);
     int n = fread(params_addr, 1, nFileLen, fp);
     fclose(fp);
 
     return params_addr;
 }
 
-int val_align(int val, int align)
+int ValAlign(int val, int align)
 {
     float tmp = ceil((float)val / (float)align) * (float)align;
     return (int)tmp;
 }
 
-void disp_params(float *din, int size)
+void DispParams(float *din, int size)
 {
     int i;
     for (i = 0; i < size; i++) {
@@ -39,7 +40,7 @@
     fclose(fp);
 }
 
-void basic_norm(Tensor<float> *&din, float norm)
+void BasicNorm(Tensor<float> *&din, float norm)
 {
 
     int Tmax = din->size[2];
@@ -59,7 +60,7 @@
     }
 }
 
-void findmax(float *din, int len, float &max_val, int &max_idx)
+void FindMax(float *din, int len, float &max_val, int &max_idx)
 {
     int i;
     max_val = -INFINITY;
@@ -72,7 +73,7 @@
     }
 }
 
-string pathAppend(const string &p1, const string &p2)
+string PathAppend(const string &p1, const string &p2)
 {
 
     char sep = '/';
@@ -89,7 +90,7 @@
         return (p1 + p2);
 }
 
-void relu(Tensor<float> *din)
+void Relu(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -98,7 +99,7 @@
     }
 }
 
-void swish(Tensor<float> *din)
+void Swish(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -107,7 +108,7 @@
     }
 }
 
-void sigmoid(Tensor<float> *din)
+void Sigmoid(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -116,7 +117,7 @@
     }
 }
 
-void doubleswish(Tensor<float> *din)
+void DoubleSwish(Tensor<float> *din)
 {
     int i;
     for (i = 0; i < din->buff_size; i++) {
@@ -125,7 +126,7 @@
     }
 }
 
-void softmax(float *din, int mask, int len)
+void Softmax(float *din, int mask, int len)
 {
     float *tmp = (float *)malloc(mask * sizeof(float));
     int i;
@@ -149,7 +150,7 @@
     }
 }
 
-void log_softmax(float *din, int len)
+void LogSoftmax(float *din, int len)
 {
     float *tmp = (float *)malloc(len * sizeof(float));
     int i;
@@ -164,7 +165,7 @@
     free(tmp);
 }
 
-void glu(Tensor<float> *din, Tensor<float> *dout)
+void Glu(Tensor<float> *din, Tensor<float> *dout)
 {
     int mm = din->buff_size / 1024;
     int i, j;
@@ -178,3 +179,53 @@
         }
     }
 }
+
+bool is_target_file(const std::string& filename, const std::string target) {
+    std::size_t pos = filename.find_last_of(".");
+    if (pos == std::string::npos) {
+        return false;
+    }
+    std::string extension = filename.substr(pos + 1);
+    return (extension == target);
+}
+
+void KeepChineseCharacterAndSplit(const std::string &input_str,
+                                  std::vector<std::string> &chinese_characters) {
+  chinese_characters.resize(0);
+  std::vector<U16CHAR_T> u16_buf;
+  u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
+  U16CHAR_T* pu16 = u16_buf.data();
+  U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
+  size_t ilen = input_str.size();
+  size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
+  for (size_t i = 0; i < len; i++) {
+    if (EncodeConverter::IsChineseCharacter(pu16[i])) {
+      U8CHAR_T u8buf[4];
+      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
+      u8buf[n] = '\0';
+      chinese_characters.push_back((const char*)u8buf);
+    }
+  }
+}
+
+std::vector<std::string> split(const std::string &s, char delim) {
+  std::vector<std::string> elems;
+  std::stringstream ss(s);
+  std::string item;
+  while(std::getline(ss, item, delim)) {
+    elems.push_back(item);
+  }
+  return elems;
+}
+
+template<typename T>
+void PrintMat(const std::vector<std::vector<T>> &mat, const std::string &name) {
+  std::cout << name << ":" << std::endl;
+  for (auto item : mat) {
+    for (auto item_ : item) {
+      std::cout << item_ << " ";
+    }
+    std::cout << std::endl;
+  }
+}
+} // namespace funasr

--
Gitblit v1.9.1