From a456ab57a8af26457b3845654863d3601ab199ab Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 01 七月 2024 11:15:08 +0800
Subject: [PATCH] update
---
/dev/null | 1 -
1 files changed, 0 insertions(+), 1 deletions(-)
diff --git a/funasr/models/sense_voice/__init__.py b/funasr/models/sense_voice/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/sense_voice/__init__.py
+++ /dev/null
diff --git a/funasr/models/sense_voice/cuda/wkv5_cuda.cu b/funasr/models/sense_voice/cuda/wkv5_cuda.cu
deleted file mode 100644
index 3e6b859..0000000
--- a/funasr/models/sense_voice/cuda/wkv5_cuda.cu
+++ /dev/null
@@ -1,202 +0,0 @@
-#include <stdio.h>
-#include <assert.h>
-#include "ATen/ATen.h"
-typedef at::BFloat16 bf16;
-
-template <typename F>
-__global__ void kernel_forward(const int B, const int T, const int C, const int H,
- const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
- F *__restrict__ const _y)
-{
- const int b = blockIdx.x / H;
- const int h = blockIdx.x % H;
- const int i = threadIdx.x;
- _w += h*_N_;
- _u += h*_N_;
-
- __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
- float state[_N_] = {0};
-
- __syncthreads();
- w[i] = _w[i];
- u[i] = float(_u[i]);
- __syncthreads();
-
- for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
- {
- __syncthreads();
- r[i] = float(_r[t]);
- k[i] = float(_k[t]);
- __syncthreads();
-
- const float v = float(_v[t]);
- float y = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j+=4)
- {
- const float4& r_ = (float4&)(r[j]);
- const float4& k_ = (float4&)(k[j]);
- const float4& w_ = (float4&)(w[j]);
- const float4& u_ = (float4&)(u[j]);
- float4& s = (float4&)(state[j]);
- float4 x;
-
- x.x = k_.x * v;
- x.y = k_.y * v;
- x.z = k_.z * v;
- x.w = k_.w * v;
-
- y += r_.x * (u_.x * x.x + s.x);
- y += r_.y * (u_.y * x.y + s.y);
- y += r_.z * (u_.z * x.z + s.z);
- y += r_.w * (u_.w * x.w + s.w);
-
- s.x = s.x * w_.x + x.x;
- s.y = s.y * w_.y + x.y;
- s.z = s.z * w_.z + x.z;
- s.w = s.w * w_.w + x.w;
- }
- _y[t] = F(y);
- }
-}
-
-template <typename F>
-__global__ void kernel_backward(const int B, const int T, const int C, const int H,
- const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const float *__restrict__ __w, const F *__restrict__ _u, const F *__restrict__ const _gy,
- F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gw, F *__restrict__ const _gu)
-{
- const int b = blockIdx.x / H;
- const int h = blockIdx.x % H;
- const int i = threadIdx.x;
- _w += h*_N_;
- _u += h*_N_;
- __w += h*_N_;
-
- __shared__ float w_[_N_], u_[_N_];
- __shared__ float r[_N_], k[_N_], v[_N_], gy[_N_];
- __syncthreads();
- w_[i] = _w[i];
- u_[i] = float(_u[i]);
- __syncthreads();
-
- const float w = w_[i];
- const float ww = __w[i];
- const float u = u_[i];
-
- float state[_N_] = {0}, saaaa[_N_] = {0}, sbbbb[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
-
- float gw = 0, gu = 0;
- const int t000 = b*T*C + h*_N_ + i;
- const int t111 = (b+1)*T*C + h*_N_ + i;
- const int t222 = t111 - 2*C;
-
- for (int t = t000; t < t111; t += C)
- {
- __syncthreads();
- v[i] = float(_v[t]);
- gy[i] = float(_gy[t]);
- __syncthreads();
-
- const float k = float(_k[t]);
- float gr = 0, gu_ = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = state[j];
- float x = k * v[j];
-
- gr += (u * x + s) * gy[j];
- gu_ += x * gy[j];
- s = s * w + x;
- }
- _gr[t] = F(gr);
- gu += float(_r[t]) * gu_;
- }
- _gu[b*C + h*_N_ + i] = F(gu);
-
- for (int t = t000; t < t222; t += C)
- {
- __syncthreads();
- v[i] = float(_v[t]);
- gy[i] = float(_gy[t + 2*C]);
- __syncthreads();
-
- const float k = float(_k[t]);
- float gw_ = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = saaaa[j];
- float& s2 = sbbbb[j];
- float x = k * v[j];
-
- float tmp = w * (x + s);
- s = tmp;
- s2 = tmp + w * s2;
- gw_ += s2 * gy[j];
- }
- gw += float(_r[t + 2*C]) * gw_;
- }
- _gw[b*C + h*_N_ + i] = F(ww * gw);
-
- for (int t = t111 - C; t >= t000; t -= C)
- {
- __syncthreads();
- v[i] = float(_v[t]);
- gy[i] = float(_gy[t]);
- __syncthreads();
-
- const float rr = float(_r[t]);
- float gk = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = scccc[j];
- float x = rr * gy[j];
-
- gk += (u * x + s) * v[j];
- s = x + s * w;
- }
- _gk[t] = F(gk);
- }
-
- for (int t = t111 - C; t >= t000; t -= C)
- {
- __syncthreads();
- r[i] = float(_r[t]);
- k[i] = float(_k[t]);
- __syncthreads();
-
- const float gyy = float(_gy[t]);
- float gv = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = sdddd[j];
- float x = gyy * r[j];
-
- gv += (u_[j] * x + s) * k[j];
- s = x + s * w_[j];
- }
- _gv[t] = F(gv);
- }
-}
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
-{
- assert(H*_N_ == C);
- assert(_N_%4 == 0);
- kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
-}
-
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
-{
- assert(H*_N_ == C);
- assert(_N_%4 == 0);
- kernel_backward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, ww, u, gy, gr, gk, gv, gw, gu);
-}
diff --git a/funasr/models/sense_voice/cuda/wkv5_op.cpp b/funasr/models/sense_voice/cuda/wkv5_op.cpp
deleted file mode 100644
index 4c9ece1..0000000
--- a/funasr/models/sense_voice/cuda/wkv5_op.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-#include <torch/extension.h>
-#include "ATen/ATen.h"
-typedef at::BFloat16 bf16;
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, float *ww, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
-
-void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
- cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
-}
-void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &ww, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
- cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), ww.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
-}
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("forward", &forward, "wkv5 forward");
- m.def("backward", &backward, "wkv5 backward");
-}
-
-TORCH_LIBRARY(wkv5, m) {
- m.def("forward", forward);
- m.def("backward", backward);
-}
diff --git a/funasr/models/sense_voice/cuda/wkv6_cuda.cu b/funasr/models/sense_voice/cuda/wkv6_cuda.cu
deleted file mode 100644
index d98f57f..0000000
--- a/funasr/models/sense_voice/cuda/wkv6_cuda.cu
+++ /dev/null
@@ -1,243 +0,0 @@
-#include <stdio.h>
-#include <assert.h>
-#include "ATen/ATen.h"
-typedef at::BFloat16 bf16;
-// typedef float bf16;
-
-template <typename F>
-__global__ void kernel_forward(const int B, const int T, const int C, const int H,
- const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u,
- F *__restrict__ const _y)
-{
- const int b = blockIdx.x / H;
- const int h = blockIdx.x % H;
- const int i = threadIdx.x;
- _u += h*_N_;
-
- __shared__ float r[_N_], k[_N_], u[_N_], w[_N_];
- float state[_N_] = {0};
-
- __syncthreads();
- u[i] = float(_u[i]);
- __syncthreads();
-
- for (int t = b*T*C + h*_N_ + i; t < (b+1)*T*C + h*_N_ + i; t += C)
- {
- __syncthreads();
- w[i] = exp(_w[t]);
- r[i] = float(_r[t]);
- k[i] = float(_k[t]);
- __syncthreads();
-
- const float v = float(_v[t]);
- float y = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j+=4)
- {
- const float4& r_ = (float4&)(r[j]);
- const float4& k_ = (float4&)(k[j]);
- const float4& w_ = (float4&)(w[j]);
- const float4& u_ = (float4&)(u[j]);
- float4& s = (float4&)(state[j]);
- float4 x;
-
- x.x = k_.x * v;
- x.y = k_.y * v;
- x.z = k_.z * v;
- x.w = k_.w * v;
-
- y += r_.x * (u_.x * x.x + s.x);
- y += r_.y * (u_.y * x.y + s.y);
- y += r_.z * (u_.z * x.z + s.z);
- y += r_.w * (u_.w * x.w + s.w);
-
- s.x = s.x * w_.x + x.x;
- s.y = s.y * w_.y + x.y;
- s.z = s.z * w_.z + x.z;
- s.w = s.w * w_.w + x.w;
- }
- _y[t] = F(y);
- }
-}
-
-template <typename F>
-__global__ void kernel_backward_111(const int B, const int T, const int C, const int H,
- const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
- F *__restrict__ const _gr, F *__restrict__ const _gk, F *__restrict__ const _gv, F *__restrict__ const _gu)
-{
- const int b = blockIdx.x / H;
- const int h = blockIdx.x % H;
- const int i = threadIdx.x;
- _u += h*_N_;
-
- __shared__ float u_[_N_];
- __shared__ float r[_N_], k[_N_], v[_N_], w_[_N_], gy[_N_];
- __syncthreads();
- u_[i] = float(_u[i]);
- __syncthreads();
-
- const float u = u_[i];
-
- float state[_N_] = {0}, scccc[_N_] = {0}, sdddd[_N_] = {0};
-
- const int t_0 = b*T*C + h*_N_ + i;
- const int t_T_1 = t_0 + (T-1)*C;
- const int t_T = t_0 + T*C;
-
- float gu = 0;
- for (int t = t_0; t < t_T; t += C)
- {
- __syncthreads();
- v[i] = float(_v[t]);
- gy[i] = float(_gy[t]);
- __syncthreads();
-
- const float k = float(_k[t]);
- const float w = exp(_w[t]);
- float gr = 0, gu_ = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = state[j];
- float x = k * v[j];
-
- gr += (u * x + s) * gy[j];
- gu_ += x * gy[j];
- s = s * w + x;
- }
- _gr[t] = F(gr);
- gu += float(_r[t]) * gu_;
- }
- _gu[b*C + h*_N_ + i] = F(gu);
-
- for (int t = t_T_1; t >= t_0; t -= C)
- {
- __syncthreads();
- v[i] = float(_v[t]);
- gy[i] = float(_gy[t]);
- __syncthreads();
-
- const float rr = float(_r[t]);
- const float w = exp(_w[t]);
- float gk = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = scccc[j];
- float x = rr * gy[j];
-
- gk += (u * x + s) * v[j];
- s = x + s * w;
- }
- _gk[t] = F(gk);
- }
-
- for (int t = t_T_1; t >= t_0; t -= C)
- {
- __syncthreads();
- r[i] = float(_r[t]);
- k[i] = float(_k[t]);
- w_[i] = exp(_w[t]);
- __syncthreads();
-
- const float gyy = float(_gy[t]);
- float gv = 0;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = sdddd[j];
- float x = gyy * r[j];
-
- gv += (u_[j] * x + s) * k[j];
- s = x + s * w_[j];
- }
- _gv[t] = F(gv);
- }
-}
-
-template <typename F>
-__global__ void kernel_backward_222(const int B, const int T, const int C, const int H,
- const F *__restrict__ const _r, const F *__restrict__ const _k, const F *__restrict__ const _v, const float *__restrict__ _w, const F *__restrict__ _u, const F *__restrict__ const _gy,
- F *__restrict__ const _gw)
-{
- const int b = blockIdx.x / H;
- const int h = blockIdx.x % H;
- const int i = threadIdx.x;
-
- __shared__ float v[_N_], gy[_N_];
- float saaaa[_N_] = {0}, sbbbb[_T_-2] = {0}, scccc[_N_] = {0};
-
- const int t_0 = b*T*C + h*_N_ + i;
- const int t_1 = t_0 + C;
- const int t_2 = t_0 + 2*C;
- const int t_T_1 = t_0 + (T-1)*C;
-
- for (int t = t_T_1; t > t_1; t -= C)
- {
- __syncthreads();
- gy[i] = float(_gy[t]);
- v[i] = float(_v[t-2*C]);
- __syncthreads();
-
- const float r = float(_r[t]);
- const float w = exp(_w[t-C]);
- float sum = 0.0f;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = saaaa[j];
- float x = r * gy[j];
- s = (s + x) * w;
- sum += s * v[j];
- }
- sbbbb[(t-t_2)/C] = sum * float(_k[t-2*C]);
- }
-
- float sss = sbbbb[0];
- _gw[t_0] = 0;
- _gw[t_1] = F(sss * _w[t_1]);
-
- for (int t = t_2; t < t_T_1; t += C)
- {
- __syncthreads();
- gy[i] = float(_gy[t]);
- v[i] = float(_v[t-2*C]);
- __syncthreads();
-
- const float w = exp(_w[t-C]);
- const float k = float(_k[t-2*C]);
- float sum = 0.0f;
-
- #pragma unroll
- for (int j = 0; j < _N_; j++)
- {
- float& s = scccc[j];
- float x = k * v[j];
- s = (s + x) * w;
- sum += s * gy[j];
- }
- sss += sbbbb[(t-t_1)/C] - (sum * float(_r[t]));
- _gw[t] = F(sss * _w[t]);
- }
- _gw[t_T_1] = 0;
-}
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y)
-{
- assert(H*_N_ == C);
- assert(_N_%4 == 0);
- kernel_forward<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, y);
-}
-
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu)
-{
- assert(H*_N_ == C);
- assert(_N_%4 == 0);
- kernel_backward_111<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gr, gk, gv, gu);
- kernel_backward_222<<<dim3(B * H), dim3(_N_)>>>(B, T, C, H, r, k, v, w, u, gy, gw);
-}
diff --git a/funasr/models/sense_voice/cuda/wkv6_op.cpp b/funasr/models/sense_voice/cuda/wkv6_op.cpp
deleted file mode 100644
index 22da520..0000000
--- a/funasr/models/sense_voice/cuda/wkv6_op.cpp
+++ /dev/null
@@ -1,23 +0,0 @@
-#include <torch/extension.h>
-#include "ATen/ATen.h"
- typedef at::BFloat16 bf16;
-//typedef float bf16;
-
-void cuda_forward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *y);
-void cuda_backward(int B, int T, int C, int H, bf16 *r, bf16 *k, bf16 *v, float *w, bf16 *u, bf16 *gy, bf16 *gr, bf16 *gk, bf16 *gv, bf16 *gw, bf16 *gu);
-
-void forward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &y) {
- cuda_forward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), y.data_ptr<bf16>());
-}
-void backward(int64_t B, int64_t T, int64_t C, int64_t H, torch::Tensor &r, torch::Tensor &k, torch::Tensor &v, torch::Tensor &w, torch::Tensor &u, torch::Tensor &gy, torch::Tensor &gr, torch::Tensor &gk, torch::Tensor &gv, torch::Tensor &gw, torch::Tensor &gu) {
- cuda_backward(B, T, C, H, r.data_ptr<bf16>(), k.data_ptr<bf16>(), v.data_ptr<bf16>(), w.data_ptr<float>(), u.data_ptr<bf16>(), gy.data_ptr<bf16>(), gr.data_ptr<bf16>(), gk.data_ptr<bf16>(), gv.data_ptr<bf16>(), gw.data_ptr<bf16>(), gu.data_ptr<bf16>());
-}
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("forward", &forward, "wkv6 forward");
- m.def("backward", &backward, "wkv6 backward");
-}
-
-TORCH_LIBRARY(wkv6, m) {
- m.def("forward", forward);
- m.def("backward", backward);
-}
diff --git a/funasr/models/sense_voice/cuda/wkv_cuda.cu b/funasr/models/sense_voice/cuda/wkv_cuda.cu
deleted file mode 100644
index 6acd0f3..0000000
--- a/funasr/models/sense_voice/cuda/wkv_cuda.cu
+++ /dev/null
@@ -1,125 +0,0 @@
-#include <stdio.h>
-#include <assert.h>
-
-#define MIN_VALUE (-1e38)
-
-template <typename F>
-__global__ void kernel_forward(const int B, const int T, const int C,
- const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v,
- F *__restrict__ const _y) {
- const int idx = blockIdx.x * blockDim.x + threadIdx.x;
- const int _b = idx / C;
- const int _c = idx % C;
- const int _offset = _b * T * C + _c;
-
- F u = _u[_c];
- F w = _w[_c];
- const F *__restrict__ const k = _k + _offset;
- const F *__restrict__ const v = _v + _offset;
- F *__restrict__ const y = _y + _offset;
-
- F p = 0, q = 0, o = MIN_VALUE;
- // p and q are running sums divided by exp(o) (to avoid overflows)
- for (int i = 0; i < T; i++) {
- const int ii = i * C;
-
- F no = max(o, u + k[ii]);
- F A = exp(o - no);
- F B = exp(u + k[ii] - no);
- y[ii] = (A * p + B * v[ii]) / (A * q + B);
-
- no = max(w + o, k[ii]);
- A = exp(w + o - no);
- B = exp(k[ii] - no);
- p = A * p + B * v[ii];
- q = A * q + B;
- o = no;
- }
-}
-
-template <typename F>
-__global__ void kernel_backward(const int B, const int T, const int C,
- const F *__restrict__ const _w, const F *__restrict__ const _u, const F *__restrict__ const _k, const F *__restrict__ const _v, const F *__restrict__ const _gy,
- F *__restrict__ const _gw, F *__restrict__ const _gu, F *__restrict__ const _gk, F *__restrict__ const _gv) {
- const int idx = blockIdx.x * blockDim.x + threadIdx.x;
- const int _b = idx / C;
- const int _c = idx % C;
- const int _offset = _b * T * C + _c;
-
- F u = _u[_c];
- F w = _w[_c];
- const F *__restrict__ const k = _k + _offset;
- const F *__restrict__ const v = _v + _offset;
- const F *__restrict__ const gy = _gy + _offset;
-
- F *__restrict__ const gk = _gk + _offset;
- F *__restrict__ const gv = _gv + _offset;
-
- F y[Tmax], z[Tmax], zexp[Tmax];
-
- F gw = 0, gu = 0;
- F p = 0, q = 0;
- F dpdw = 0, dqdw = 0;
- F o = MIN_VALUE;
- for (int i = 0; i < T; i++) {
- const int ii = i * C;
- F no = max(o, k[ii] + u);
- F A = exp(o - no);
- F B = exp(k[ii] + u - no);
-
- F num = A * p + B * v[ii];
- F iden = 1 / (A * q + B);
-
- y[i] = num * iden;
- z[i] = iden;
- zexp[i] = k[ii] + u - no;
-
- gw += gy[ii] * (dpdw - dqdw * y[i]) * iden * A;
- gu += gy[ii] * (v[ii] - y[i]) * B * iden;
-
- no = max(w + o, k[ii]);
- A = exp(w + o - no);
- B = exp(k[ii] - no);
- dpdw = A * (p + dpdw);
- dqdw = A * (q + dqdw);
- p = A * p + B * v[ii];
- q = A * q + B;
- o = no;
- }
-
- F gp = 0, gq = 0;
- o = MIN_VALUE;
- for (int i = T - 1; i >= 0; i--) {
- const int ii = i * C;
- F A = gy[ii] * z[i] * exp(zexp[i]);
- F B = exp(k[ii] + o);
- gk[ii] = A * (v[ii] - y[i]) + B * (gp * v[ii] + gq);
- gv[ii] = A + B * gp;
-
- F no = max(w + o, zexp[i] - k[ii] - u);
- A = exp(w + o - no);
- B = gy[ii] * z[i] * exp(zexp[i] - k[ii] - u - no);
- gp = A * gp + B;
- gq = A * gq - B * y[i];
- o = no;
- }
-
- // Multiply by w because the w -> -exp(w) preprocessing is halfway in the backwards pass, even though it's not in the forward pass
- const int _offsetBC = _b * C + _c;
- _gw[_offsetBC] += gw * _w[_c];
- _gu[_offsetBC] += gu;
-}
-
-void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y) {
- dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
- assert(B * C % threadsPerBlock.x == 0);
- dim3 numBlocks(B * C / threadsPerBlock.x);
- kernel_forward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, y);
-}
-
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv) {
- dim3 threadsPerBlock( min(C, 32) ); // requires --maxrregcount 60 for optimal performance
- assert(B * C % threadsPerBlock.x == 0);
- dim3 numBlocks(B * C / threadsPerBlock.x);
- kernel_backward<<<numBlocks, threadsPerBlock>>>(B, T, C, w, u, k, v, gy, gw, gu, gk, gv);
-}
diff --git a/funasr/models/sense_voice/cuda/wkv_op.cpp b/funasr/models/sense_voice/cuda/wkv_op.cpp
deleted file mode 100644
index efe56d8..0000000
--- a/funasr/models/sense_voice/cuda/wkv_op.cpp
+++ /dev/null
@@ -1,21 +0,0 @@
-#include <torch/extension.h>
-
-void cuda_forward(int B, int T, int C, float *w, float *u, float *k, float *v, float *y);
-void cuda_backward(int B, int T, int C, float *w, float *u, float *k, float *v, float *gy, float *gw, float *gu, float *gk, float *gv);
-
-void forward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &y) {
- cuda_forward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), y.data_ptr<float>());
-}
-void backward(int64_t B, int64_t T, int64_t C, torch::Tensor &w, torch::Tensor &u, torch::Tensor &k, torch::Tensor &v, torch::Tensor &gy, torch::Tensor &gw, torch::Tensor &gu, torch::Tensor &gk, torch::Tensor &gv) {
- cuda_backward(B, T, C, w.data_ptr<float>(), u.data_ptr<float>(), k.data_ptr<float>(), v.data_ptr<float>(), gy.data_ptr<float>(), gw.data_ptr<float>(), gu.data_ptr<float>(), gk.data_ptr<float>(), gv.data_ptr<float>());
-}
-
-PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
- m.def("forward", &forward, "wkv forward");
- m.def("backward", &backward, "wkv backward");
-}
-
-TORCH_LIBRARY(wkv, m) {
- m.def("forward", forward);
- m.def("backward", backward);
-}
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
deleted file mode 100644
index ff933d7..0000000
--- a/funasr/models/sense_voice/decoder.py
+++ /dev/null
@@ -1,607 +0,0 @@
-import copy
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.register import tables
-import base64
-import gzip
-from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import Tensor, nn
-from funasr.models.transformer.utils.mask import subsequent_mask
-
-
-class LayerNorm(nn.LayerNorm):
- def forward(self, x: Tensor) -> Tensor:
- return super().forward(x.float()).type(x.dtype)
-
-
-class Linear(nn.Linear):
- def forward(self, x: Tensor) -> Tensor:
- return F.linear(
- x,
- self.weight.to(x.dtype),
- None if self.bias is None else self.bias.to(x.dtype),
- )
-
-
-def sense_voice_decode_forward(
- self,
- x: torch.Tensor,
- xa: torch.Tensor,
- kv_cache: Optional[dict] = None,
- **kwargs,
-):
- """Forward decoder.
-
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
-
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- # import pdb;pdb.set_trace()
- use_padmask = self.use_padmask
- hlens = kwargs.get("hlens", None)
-
- ys_in_lens = kwargs.get("ys_in_lens", None)
-
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
- tgt, memory = x, xa
- tgt[tgt == -1] = 0
- tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
- # tgt = self.dropout(tgt)
-
- x = tgt.to(memory.dtype)
-
- if use_padmask and hlens is not None:
- memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
- else:
- memory_mask = None
-
- for layer, block in enumerate(self.blocks):
- x = block(
- x,
- memory,
- mask=self.mask,
- memory_mask=memory_mask,
- is_pad_mask=False,
- is_pad_memory_mask=True,
- )
-
- x = self.ln(x)
- x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
-
- return x
-
-
-class MultiHeadAttention(nn.Module):
- def __init__(self, n_state: int, n_head: int):
- super().__init__()
- self.n_head = n_head
- self.query = Linear(n_state, n_state)
- self.key = Linear(n_state, n_state, bias=False)
- self.value = Linear(n_state, n_state)
- self.out = Linear(n_state, n_state)
-
- def forward(
- self,
- x: Tensor,
- xa: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
- kv_cache: Optional[dict] = None,
- **kwargs,
- ):
- is_pad_mask = kwargs.get("is_pad_mask", False)
-
- q = self.query(x)
-
- if kv_cache is None or xa is None or self.key not in kv_cache:
- # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
- # otherwise, perform key/value projections for self- or cross-attention as usual.
- k = self.key(x if xa is None else xa)
- v = self.value(x if xa is None else xa)
- else:
- # for cross-attention, calculate keys and values once and reuse in subsequent calls.
- k = kv_cache[self.key]
- v = kv_cache[self.value]
-
- wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
- return self.out(wv), qk
-
- def qkv_attention(
- self,
- q: Tensor,
- k: Tensor,
- v: Tensor,
- mask: Optional[Tensor] = None,
- **kwargs,
- ):
- is_pad_mask = kwargs.get("is_pad_mask", False)
- n_batch, n_ctx, n_state = q.shape
- scale = (n_state // self.n_head) ** -0.25
- q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
- k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
- v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
-
- qk = q @ k
- if mask is not None:
- if not is_pad_mask:
- qk = qk + mask[:n_ctx, :n_ctx]
- else:
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = -float(
- "inf"
- ) # min_value = float(np.finfo(torch.tensor(0, dtype=qk.dtype).numpy().dtype).min)
- qk = qk.masked_fill(mask, min_value)
-
- qk = qk.float()
-
- w = F.softmax(qk, dim=-1).to(q.dtype)
- if mask is not None and is_pad_mask:
- w = w.masked_fill(mask, 0.0)
- return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
-
-
-from omegaconf import OmegaConf
-
-
-class ResidualAttentionBlockRWKV(nn.Module):
- def __init__(
- self, n_state: int, n_head: int, cross_attention: bool = False, layer_id=0, **kwargs
- ):
- super().__init__()
-
- rwkv_cfg = kwargs.get("rwkv_cfg", {})
- args = OmegaConf.create(rwkv_cfg)
- if args.get("version", "v4") == "v4":
- from funasr.models.sense_voice.rwkv_v4 import RWKVLayer
- from funasr.models.sense_voice.rwkv_v4 import RWKV_TimeMix as RWKV_Tmix
- elif args.get("version", "v5") == "v5":
- from funasr.models.sense_voice.rwkv_v5 import RWKVLayer
- from funasr.models.sense_voice.rwkv_v5 import RWKV_Tmix_x052 as RWKV_Tmix
- else:
- from funasr.models.sense_voice.rwkv_v6 import RWKVLayer
- from funasr.models.sense_voice.rwkv_v6 import RWKV_Tmix_x060 as RWKV_Tmix
- # self.att = RWKVLayer(args=args, layer_id=layer_id)
- self.att = RWKV_Tmix(args, layer_id=layer_id)
-
- if args.get("init_rwkv", True):
- print("init_rwkv")
- nn.init.orthogonal_(self.att.receptance.weight, gain=1)
- nn.init.orthogonal_(self.att.key.weight, gain=0.1)
- nn.init.orthogonal_(self.att.value.weight, gain=1)
- nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
- nn.init.zeros_(self.att.output.weight)
-
- self.ln0 = None
- if layer_id == 0 and not args.get("ln0", True):
- self.ln0 = LayerNorm(args.n_embd)
- if args.get("init_rwkv", True):
- print("init_rwkv")
- layer_id = 0
- scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
- nn.init.constant_(self.ln0.weight, scale)
-
- self.layer_id = layer_id
- self.args = args
-
- self.ln1 = None
- if not args.get("ln1", True):
- self.ln1 = LayerNorm(args.n_embd)
- # init
- if args.get("init_rwkv", True):
- print("init_rwkv")
- scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
- nn.init.constant_(self.ln1.weight, scale)
-
- if args.get("datatype", "bf16") == "bf16":
- self.att.to(torch.bfloat16)
- # if self.ln1 is not None:
- # self.ln1.to(torch.bfloat16)
-
- self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
- self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
-
- n_mlp = n_state * 4
- self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
- self.mlp_ln = LayerNorm(n_state)
-
- def forward(
- self,
- x: Tensor,
- xa: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
- kv_cache: Optional[dict] = None,
- **kwargs,
- ):
- is_pad_mask = kwargs.get("is_pad_mask", False)
- is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
-
- if self.layer_id == 0 and self.ln0 is not None:
- x = self.ln0(x)
-
- if self.args.get("datatype", "bf16") == "bf16":
- x = x.bfloat16()
- if self.ln1 is None:
- x = x + self.att(x, mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
- else:
- x = x + self.att(self.ln1(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)[0]
- if self.args.get("datatype", "bf16") == "bf16":
- x = x.to(torch.float32)
-
- if self.cross_attn:
- x = (
- x
- + self.cross_attn(
- self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
- )[0]
- )
- x = x + self.mlp(self.mlp_ln(x))
-
- return x
-
-
-@tables.register("decoder_classes", "SenseVoiceDecoder")
-class SenseVoiceDecoder(nn.Module):
- def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
- super().__init__()
-
- self.token_embedding = nn.Embedding(n_vocab, n_state)
- self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
-
- self.blocks = nn.ModuleList(
- [
- ResidualAttentionBlockRWKV(
- n_state, n_head, cross_attention=True, layer_id=i, **kwargs
- )
- for i in range(n_layer)
- ]
- )
- self.ln = LayerNorm(n_state)
-
- mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
- self.register_buffer("mask", mask, persistent=False)
-
- self.use_padmask = kwargs.get("use_padmask", True)
-
- def forward(
- self,
- x: torch.Tensor,
- xa: torch.Tensor,
- kv_cache: Optional[dict] = None,
- **kwargs,
- ):
- """Forward decoder.
-
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
-
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- # import pdb;pdb.set_trace()
- use_padmask = self.use_padmask
- hlens = kwargs.get("hlens", None)
-
- ys_in_lens = kwargs.get("ys_in_lens", None)
-
- offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
- tgt, memory = x, xa
- tgt[tgt == -1] = 0
- tgt = self.token_embedding(tgt) + self.positional_embedding[offset : offset + tgt.size(1)]
- # tgt = self.dropout(tgt)
-
- x = tgt.to(memory.dtype)
-
- if use_padmask and hlens is not None:
- memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
- else:
- memory_mask = None
-
- for layer, block in enumerate(self.blocks):
- x = block(
- x,
- memory,
- mask=self.mask,
- memory_mask=memory_mask,
- is_pad_mask=False,
- is_pad_memory_mask=True,
- )
-
- x = self.ln(x)
- x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
-
- return x
-
- def init_state(self, x):
- state = {}
-
- return state
-
- def final_score(self, state) -> float:
- """Score eos (optional).
-
- Args:
- state: Scorer state for prefix tokens
-
- Returns:
- float: final score
-
- """
- return 0.0
-
- def score(self, ys, state, x):
- """Score."""
- ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
- logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
- logp = torch.log_softmax(logp, dim=-1)
- return logp.squeeze(0)[-1, :], state
-
-
-class MultiHeadedAttentionSANMDecoder(nn.Module):
- """Multi-Head Attention layer.
-
- Args:
- n_head (int): The number of heads.
- n_feat (int): The number of features.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, n_feat, dropout_rate, kernel_size, sanm_shfit=0):
- """Construct an MultiHeadedAttention object."""
- super().__init__()
-
- self.dropout = nn.Dropout(p=dropout_rate)
-
- self.fsmn_block = nn.Conv1d(
- n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
- )
- # padding
- # padding
- left_padding = (kernel_size - 1) // 2
- if sanm_shfit > 0:
- left_padding = left_padding + sanm_shfit
- right_padding = kernel_size - 1 - left_padding
- self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
- self.kernel_size = kernel_size
-
- def forward(self, inputs, mask, cache=None, mask_shfit_chunk=None, **kwargs):
- """
- :param x: (#batch, time1, size).
- :param mask: Mask tensor (#batch, 1, time)
- :return:
- """
- # print("in fsmn, inputs", inputs.size())
- b, t, d = inputs.size()
- # logging.info(
- # "mask: {}".format(mask.size()))
- if mask is not None:
- mask = torch.reshape(mask, (b, -1, 1))
- # logging.info("in fsmn, mask: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
- if mask_shfit_chunk is not None:
- # logging.info("in fsmn, mask_fsmn: {}, {}".format(mask_shfit_chunk.size(), mask_shfit_chunk[0:100:50, :, :]))
- mask = mask * mask_shfit_chunk
- # logging.info("in fsmn, mask_after_fsmn: {}, {}".format(mask.size(), mask[0:100:50, :, :]))
- # print("in fsmn, mask", mask.size())
- # print("in fsmn, inputs", inputs.size())
- inputs = inputs * mask
-
- x = inputs.transpose(1, 2)
- b, d, t = x.size()
- if cache is None:
- # print("in fsmn, cache is None, x", x.size())
-
- x = self.pad_fn(x)
- if not self.training:
- cache = x
- else:
- # print("in fsmn, cache is not None, x", x.size())
- # x = torch.cat((x, cache), dim=2)[:, :, :-1]
- # if t < self.kernel_size:
- # x = self.pad_fn(x)
- x = torch.cat((cache[:, :, 1:], x), dim=2)
- x = x[:, :, -(self.kernel_size + t - 1) :]
- # print("in fsmn, cache is not None, x_cat", x.size())
- cache = x
- x = self.fsmn_block(x)
- x = x.transpose(1, 2)
- # print("in fsmn, fsmn_out", x.size())
- if x.size(1) != inputs.size(1):
- inputs = inputs[:, -1, :]
-
- x = x + inputs
- x = self.dropout(x)
- if mask is not None:
- x = x * mask
- return x, cache
-
-
-class ResidualAttentionBlockFSMN(nn.Module):
- def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, **kwargs):
- super().__init__()
-
- self.attn = MultiHeadedAttentionSANMDecoder(
- n_state,
- kwargs.get("self_attention_dropout_rate"),
- kwargs.get("kernel_size", 20),
- kwargs.get("sanm_shfit", 10),
- )
- self.attn_ln = LayerNorm(n_state)
-
- self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
- self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None
-
- n_mlp = n_state * 4
- self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
- self.mlp_ln = LayerNorm(n_state)
-
- def forward(
- self,
- x: Tensor,
- xa: Optional[Tensor] = None,
- mask: Optional[Tensor] = None,
- kv_cache: Optional[dict] = None,
- **kwargs,
- ):
- cache = kwargs.get("cache", {})
- layer = kwargs.get("layer", 0)
- is_pad_mask = kwargs.get("is_pad_mask", False)
- is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)
-
- fsmn_cache = cache[layer]["fsmn_cache"] if cache is not None and len(cache) > 0 else None
- # if fsmn_cache is not None:
- # x = x[:, -1:]
- att_res, fsmn_cache = self.attn(self.attn_ln(x), mask=None, cache=fsmn_cache)
- # if len(cache)>1:
- # cache[layer]["fsmn_cache"] = fsmn_cache
- # x = x[:, -1:]
- x = x + att_res
- if self.cross_attn:
- x = (
- x
- + self.cross_attn(
- self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
- )[0]
- )
- x = x + self.mlp(self.mlp_ln(x))
- return x
-
-
-@tables.register("decoder_classes", "SenseVoiceDecoderFSMN")
-class SenseVoiceDecoderFSMN(nn.Module):
- def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int, **kwargs):
- super().__init__()
-
- self.token_embedding = nn.Embedding(n_vocab, n_state)
- self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
-
- self.blocks = nn.ModuleList(
- [
- ResidualAttentionBlockFSMN(
- n_state, n_head, cross_attention=True, layer_id=i, **kwargs
- )
- for i in range(n_layer)
- ]
- )
- self.ln = LayerNorm(n_state)
-
- mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
- self.register_buffer("mask", mask, persistent=False)
-
- self.use_padmask = kwargs.get("use_padmask", True)
-
- def forward(
- self,
- x: torch.Tensor,
- xa: torch.Tensor,
- kv_cache: Optional[dict] = None,
- **kwargs,
- ):
- """Forward decoder.
-
- Args:
- hs_pad: encoded memory, float32 (batch, maxlen_in, feat)
- hlens: (batch)
- ys_in_pad:
- input token ids, int64 (batch, maxlen_out)
- if input_layer == "embed"
- input tensor (batch, maxlen_out, #mels) in the other cases
- ys_in_lens: (batch)
- Returns:
- (tuple): tuple containing:
-
- x: decoded token score before softmax (batch, maxlen_out, token)
- if use_output_layer is True,
- olens: (batch, )
- """
- # import pdb;pdb.set_trace()
- use_padmask = self.use_padmask
- hlens = kwargs.get("hlens", None)
-
- ys_in_lens = kwargs.get("ys_in_lens", None)
-
- tgt, memory = x, xa
- tgt[tgt == -1] = 0
- tgt = self.token_embedding(tgt) + self.positional_embedding[: tgt.size(1)]
- # tgt = self.dropout(tgt)
-
- x = tgt.to(memory.dtype)
-
- if use_padmask and hlens is not None:
- memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
- else:
- memory_mask = None
-
- for layer, block in enumerate(self.blocks):
- x = block(
- x,
- memory,
- mask=self.mask,
- memory_mask=memory_mask,
- is_pad_mask=False,
- is_pad_memory_mask=True,
- cache=kwargs.get("cache", None),
- layer=layer,
- )
-
- x = self.ln(x)
- x = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()
-
- return x
-
- def init_state(self, x):
- state = {}
- for layer, block in enumerate(self.blocks):
- state[layer] = {
- "fsmn_cache": None,
- "memory_key": None,
- "memory_value": None,
- }
-
- return state
-
- def final_score(self, state) -> float:
- """Score eos (optional).
-
- Args:
- state: Scorer state for prefix tokens
-
- Returns:
- float: final score
-
- """
- return 0.0
-
- def score(self, ys, state, x):
- """Score."""
- ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
- logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=None)
- logp = torch.log_softmax(logp, dim=-1)
- return logp.squeeze(0)[-1, :], state
diff --git a/funasr/models/sense_voice/encoder.py b/funasr/models/sense_voice/encoder.py
deleted file mode 100644
index d464f1c..0000000
--- a/funasr/models/sense_voice/encoder.py
+++ /dev/null
@@ -1,57 +0,0 @@
-import copy
-from typing import Optional, Tuple, Union
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-
-
-def sense_voice_encode_forward(
- self,
- x: torch.Tensor,
- ilens: torch.Tensor = None,
- **kwargs,
-):
- use_padmask = self.use_padmask
- x = F.gelu(self.conv1(x))
- x = F.gelu(self.conv2(x))
- x = x.permute(0, 2, 1)
-
- n_frames = x.size(1)
- max_pos = self.positional_embedding.size(0)
- max_pos = n_frames if n_frames < max_pos else max_pos
- x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
-
- if ilens is not None:
- if self.downsample_rate == 4:
- olens = (
- 1
- + (ilens - self.conv1.kernel_size[0] + 2 * self.conv1.padding[0])
- // self.conv1.stride[0]
- )
- else:
- olens = ilens
- olens = (
- 1
- + (olens - self.conv2.kernel_size[0] + 2 * self.conv2.padding[0])
- // self.conv2.stride[0]
- )
- olens = torch.clamp(olens, max=max_pos)
- else:
- olens = None
-
- if use_padmask and olens is not None:
- padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
- else:
- padding_mask = None
-
- for layer, block in enumerate(self.blocks):
- x = block(x, mask=padding_mask, is_pad_mask=True)
-
- x = self.ln_post(x)
-
- if ilens is None:
- return x
- else:
- return x, olens
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
deleted file mode 100644
index 9db6539..0000000
--- a/funasr/models/sense_voice/model.py
+++ /dev/null
@@ -1,1394 +0,0 @@
-import logging
-from dataclasses import dataclass
-from typing import Dict
-from typing import Iterable, Optional
-import types
-import time
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import Tensor
-from torch import nn
-from torch.cuda.amp import autocast
-from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
-from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
-from funasr.train_utils.device_funcs import force_gatherable
-from . import whisper_lib as whisper
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.models.ctc.ctc import CTC
-
-from funasr.register import tables
-
-
-@tables.register("model_classes", "SenseVoice")
-class SenseVoice(nn.Module):
- def __init__(self, *args, **kwargs):
- super().__init__()
-
- dims = kwargs.get("dims", {})
- dims = whisper.model.ModelDimensions(**dims)
- model = whisper.model.Whisper(dims=dims)
-
- # encoder
- model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
- model.encoder.use_padmask = kwargs.get("use_padmask", True)
- from .encoder import sense_voice_encode_forward
-
- model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
-
- # decoder
- model.decoder.use_padmask = kwargs.get("use_padmask", True)
- from .decoder import sense_voice_decode_forward
-
- model.decoder.forward = types.MethodType(sense_voice_decode_forward, model.decoder)
-
- self.model = model
-
- self.encoder_output_size = self.model.dims.n_audio_state
-
- self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
- self.ignore_id = kwargs.get("ignore_id", -1)
- self.vocab_size = kwargs.get("vocab_size", -1)
- self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
- self.criterion_att = LabelSmoothingLoss(
- size=self.vocab_size,
- padding_idx=self.ignore_id,
- smoothing=kwargs.get("lsm_weight", 0.0),
- normalize_length=self.length_normalized_loss,
- )
-
- specaug = kwargs.get("specaug", None)
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug)
- specaug = specaug_class(**kwargs.get("specaug_conf", {}))
- self.specaug = specaug
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
-
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- if self.activation_checkpoint:
- from torch.utils.checkpoint import checkpoint
-
- encoder_out, encoder_out_lens = checkpoint(
- self.encode, speech, speech_lengths, use_reentrant=False
- )
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
- )
- loss = loss_att
- stats = {}
- stats["acc"] = acc_att
- stats["loss"] = torch.clone(loss.detach())
- stats["batch_size"] = batch_size
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- **kwargs,
- ):
- """Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
-
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Forward encoder
- encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
-
- return encoder_out, encoder_out_lens
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
- stats = {}
-
- # 1. Forward decoder
- decoder_out = self.model.decoder(
- x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
- )
-
- # 2. Compute attention loss
- mask = torch.ones_like(ys_pad) * (-1)
- ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
- ys_pad_mask[ys_pad_mask == 0] = -1
- loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
-
- with torch.no_grad():
- preds = torch.argmax(decoder_out, -1)
- acc_att = compute_accuracy(
- preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
- )
-
- return loss_att, acc_att, None, None
-
- def inference(
- self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- if frontend is None and not hasattr(self, "frontend"):
- frontend_class = tables.frontend_classes.get("WhisperFrontend")
- frontend = frontend_class(
- n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
- )
- self.frontend = frontend
- else:
- frontend = frontend if frontend is not None else self.frontend
-
- meta_data = {}
- if (
- isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
- ): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(
- data_in,
- fs=frontend.fs if hasattr(frontend, "fs") else 16000,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer,
- )
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(
- audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
- )
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
- lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
-
- speech = speech.to(device=kwargs["device"])[0, :, :]
- speech_lengths = speech_lengths.to(device=kwargs["device"])
-
- DecodingOptions = kwargs.get("DecodingOptions", {})
- task = DecodingOptions.get("task", "ASR")
- if isinstance(task, str):
- task = [task]
- task = "".join([f"<|{x}|>" for x in task])
- initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
- DecodingOptions["initial_prompt"] = initial_prompt
-
- language = DecodingOptions.get("language", None)
- language = None if language == "auto" else language
- DecodingOptions["language"] = language
-
- DecodingOptions["vocab_path"] = kwargs["tokenizer_conf"].get("vocab_path", None)
-
- if "without_timestamps" not in DecodingOptions:
- DecodingOptions["without_timestamps"] = True
-
- options = whisper.DecodingOptions(**DecodingOptions)
-
- result = whisper.decode(self.model, speech, options)
- text = f"{result.text}"
- results = []
- result_i = {"key": key[0], "text": text}
-
- results.append(result_i)
-
- return results, meta_data
-
-
-@tables.register("model_classes", "SenseVoiceRWKV")
-class SenseVoiceRWKV(nn.Module):
- def __init__(self, *args, **kwargs):
- super().__init__()
-
- dims = kwargs.get("dims", {})
- dims = whisper.model.ModelDimensions(**dims)
- model = whisper.model.Whisper(dims=dims)
-
- # encoder
- model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
- model.encoder.use_padmask = kwargs.get("use_padmask", True)
- from .encoder import sense_voice_encode_forward
-
- model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
-
- # decoder
- del model.decoder
- decoder = kwargs.get("decoder", "SenseVoiceDecoder")
- decoder_class = tables.decoder_classes.get(decoder)
- decoder = decoder_class(
- n_vocab=dims.n_vocab,
- n_ctx=dims.n_text_ctx,
- n_state=dims.n_text_state,
- n_head=dims.n_text_head,
- n_layer=dims.n_text_layer,
- **kwargs.get("decoder_conf"),
- )
- model.decoder = decoder
-
- self.model = model
-
- self.encoder_output_size = self.model.dims.n_audio_state
-
- self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
- self.ignore_id = kwargs.get("ignore_id", -1)
- self.vocab_size = kwargs.get("vocab_size", -1)
- self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
- self.criterion_att = LabelSmoothingLoss(
- size=self.vocab_size,
- padding_idx=self.ignore_id,
- smoothing=kwargs.get("lsm_weight", 0.0),
- normalize_length=self.length_normalized_loss,
- )
-
- specaug = kwargs.get("specaug", None)
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug)
- specaug = specaug_class(**kwargs.get("specaug_conf", {}))
- self.specaug = specaug
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
-
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size, frames, _ = speech.shape
- _, text_tokens = text.shape
-
- if self.activation_checkpoint:
- from torch.utils.checkpoint import checkpoint
-
- encoder_out, encoder_out_lens = checkpoint(
- self.encode, speech, speech_lengths, use_reentrant=False
- )
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
- )
- loss = loss_att
- stats = {}
- stats["acc"] = acc_att
- stats["loss"] = torch.clone(loss.detach())
- stats["batch_size"] = batch_size
- stats["batch_size_x_frames"] = frames * batch_size
- stats["batch_size_real_frames"] = speech_lengths.sum().item()
- stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
- stats["batch_size_x_tokens"] = text_tokens * batch_size
- stats["batch_size_real_tokens"] = text_lengths.sum().item()
- stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
- stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- **kwargs,
- ):
- """Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Forward encoder
- encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
-
- return encoder_out, encoder_out_lens
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
- stats = {}
-
- # 1. Forward decoder
- # ys_pad: [sos, task, lid, text, eos]
- decoder_out = self.model.decoder(
- x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
- )
-
- # 2. Compute attention loss
- mask = torch.ones_like(ys_pad) * (-1) # [sos, task, lid, text, eos]: [-1, -1, -1, -1]
- ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
- torch.int64
- ) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
- ys_pad_mask[ys_pad_mask == 0] = -1 # [-1, -1, lid, text, eos]
- # decoder_out: [sos, task, lid, text]
- # ys_pad_mask: [-1, lid, text, eos]
- loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
-
- with torch.no_grad():
- preds = torch.argmax(decoder_out, -1)
- acc_att = compute_accuracy(
- preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
- )
-
- return loss_att, acc_att, None, None
-
- def init_beam_search(
- self,
- **kwargs,
- ):
- from .search import BeamSearch
-
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
- scorers.update(
- decoder=self.model.decoder,
- length_bonus=LengthBonus(self.vocab_size),
- )
-
- weights = dict(
- decoder=1.0,
- ctc=0.0,
- lm=0.0,
- ngram=0.0,
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 5),
- weights=weights,
- scorers=scorers,
- sos=None,
- eos=None,
- vocab_size=self.vocab_size,
- token_list=None,
- pre_beam_score_key="full",
- )
-
- self.beam_search = beam_search
-
- def inference(
- self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- if not hasattr(self, "beam_search") or self.beam_search is None:
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- if frontend is None and not hasattr(self, "frontend"):
- frontend_class = tables.frontend_classes.get("WhisperFrontend")
- frontend = frontend_class(
- n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
- )
- self.frontend = frontend
- else:
- frontend = frontend if frontend is not None else self.frontend
-
- meta_data = {}
- if (
- isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
- ): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(
- data_in,
- fs=frontend.fs if hasattr(frontend, "fs") else 16000,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer,
- )
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(
- audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
- )
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
- lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
-
- speech = speech.to(device=kwargs["device"])[0, :, :]
- speech_lengths = speech_lengths.to(device=kwargs["device"])
-
- DecodingOptions = kwargs.get("DecodingOptions", {})
- task = DecodingOptions.get("task", "ASR")
- if isinstance(task, str):
- task = [task]
- task = "".join([f"<|{x}|>" for x in task])
- initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-
- language = DecodingOptions.get("language", None)
- language = None if language == "auto" else language
-
- sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
- sos_int = tokenizer.encode(sos, allowed_special="all")
- eos = kwargs.get("model_conf").get("eos")
- eos_int = tokenizer.encode(eos, allowed_special="all")
- self.beam_search.sos = sos_int
- self.beam_search.eos = eos_int[0]
-
- # Paramterts for rich decoding
- self.beam_search.emo_unk = tokenizer.encode(
- DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
- )[0]
- self.beam_search.emo_unk_score = 1
- self.beam_search.emo_tokens = tokenizer.encode(
- DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
- allowed_special="all",
- )
- self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
-
- self.beam_search.event_bg_token = tokenizer.encode(
- DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
- allowed_special="all",
- )
- self.beam_search.event_ed_token = tokenizer.encode(
- DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
- allowed_special="all",
- )
- self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
-
- encoder_out, encoder_out_lens = self.encode(
- speech[None, :, :].permute(0, 2, 1), speech_lengths
- )
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0],
- maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0),
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- b, n, d = encoder_out.size()
- for i in range(b):
-
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # # remove blank symbol id, which is assumed to be 0
- # token_int = list(
- # filter(
- # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
- # )
- # )
-
- # Change integer-ids to tokens
- # token = tokenizer.ids2tokens(token_int)
- text = tokenizer.decode(token_int)
-
- result_i = {"key": key[i], "text": text}
- results.append(result_i)
-
- if ibest_writer is not None:
- # ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
-
- return results, meta_data
-
-
-@tables.register("model_classes", "SenseVoiceFSMN")
-class SenseVoiceFSMN(nn.Module):
- def __init__(self, *args, **kwargs):
- super().__init__()
-
- dims = kwargs.get("dims", {})
- dims = whisper.model.ModelDimensions(**dims)
- model = whisper.model.Whisper(dims=dims)
-
- # encoder
- model.encoder.downsample_rate = kwargs.get("downsample_rate", 4)
- model.encoder.use_padmask = kwargs.get("use_padmask", True)
- from .encoder import sense_voice_encode_forward
-
- model.encoder.forward = types.MethodType(sense_voice_encode_forward, model.encoder)
-
- # decoder
- del model.decoder
- decoder = kwargs.get("decoder", "SenseVoiceDecoder")
- decoder_class = tables.decoder_classes.get(decoder)
- decoder = decoder_class(
- n_vocab=dims.n_vocab,
- n_ctx=dims.n_text_ctx,
- n_state=dims.n_text_state,
- n_head=dims.n_text_head,
- n_layer=dims.n_text_layer,
- **kwargs.get("decoder_conf"),
- )
- model.decoder = decoder
-
- self.model = model
-
- self.encoder_output_size = self.model.dims.n_audio_state
-
- self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
- self.ignore_id = kwargs.get("ignore_id", -1)
- self.vocab_size = dims.n_vocab
- self.length_normalized_loss = kwargs.get("length_normalized_loss", True)
- self.criterion_att = LabelSmoothingLoss(
- size=self.vocab_size,
- padding_idx=self.ignore_id,
- smoothing=kwargs.get("lsm_weight", 0.0),
- normalize_length=self.length_normalized_loss,
- )
-
- specaug = kwargs.get("specaug", None)
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug)
- specaug = specaug_class(**kwargs.get("specaug_conf", {}))
- self.specaug = specaug
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
-
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size, frames, _ = speech.shape
- _, text_tokens = text.shape
-
- if self.activation_checkpoint:
- from torch.utils.checkpoint import checkpoint
-
- encoder_out, encoder_out_lens = checkpoint(
- self.encode, speech, speech_lengths, use_reentrant=False
- )
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- with autocast(False):
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
- )
-
- loss = loss_att
- stats = {}
- stats["acc"] = acc_att
- stats["loss"] = torch.clone(loss.detach())
- stats["batch_size"] = batch_size
- stats["batch_size_x_frames"] = frames * batch_size
- stats["batch_size_real_frames"] = speech_lengths.sum().item()
- stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
- stats["batch_size_x_tokens"] = text_tokens * batch_size
- stats["batch_size_real_tokens"] = text_lengths.sum().item()
- stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
- stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- **kwargs,
- ):
- """Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Forward encoder
- encoder_out, encoder_out_lens = self.model.encoder(speech.permute(0, 2, 1), speech_lengths)
-
- return encoder_out, encoder_out_lens
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
- stats = {}
-
- # 1. Forward decoder
- decoder_out = self.model.decoder(
- x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
- )
- # decoder_out, _ = self.model.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
- # 2. Compute attention loss
- mask = torch.ones_like(ys_pad) * (-1)
- ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
- ys_pad_mask[ys_pad_mask == 0] = -1
- loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
-
- with torch.no_grad():
- preds = torch.argmax(decoder_out, -1)
- acc_att = compute_accuracy(
- preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
- )
-
- return loss_att, acc_att, None, None
-
- def init_beam_search(
- self,
- **kwargs,
- ):
- from .search import BeamSearch
-
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
- scorers.update(
- decoder=self.model.decoder,
- length_bonus=LengthBonus(self.vocab_size),
- )
-
- weights = dict(
- decoder=1.0,
- ctc=0.0,
- lm=0.0,
- ngram=0.0,
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 5),
- weights=weights,
- scorers=scorers,
- sos=None,
- eos=None,
- vocab_size=self.vocab_size,
- token_list=None,
- pre_beam_score_key="full",
- )
-
- self.beam_search = beam_search
-
- def inference(
- self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- if not hasattr(self, "beam_search") or self.beam_search is None:
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- if frontend is None and not hasattr(self, "frontend"):
- frontend_class = tables.frontend_classes.get("WhisperFrontend")
- frontend = frontend_class(
- n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
- )
- self.frontend = frontend
- else:
- frontend = frontend if frontend is not None else self.frontend
-
- meta_data = {}
- if (
- isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
- ): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(
- data_in,
- fs=frontend.fs if hasattr(frontend, "fs") else 16000,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer,
- )
-
- if (
- isinstance(kwargs.get("data_type", None), (list, tuple))
- and len(kwargs.get("data_type", [])) > 1
- ):
- audio_sample_list, text_token_int_list = audio_sample_list
- text_token_int = text_token_int_list[0]
- else:
- text_token_int = None
-
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(
- audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
- )
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
- lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
-
- speech = speech.to(device=kwargs["device"])[0, :, :]
- speech_lengths = speech_lengths.to(device=kwargs["device"])
-
- DecodingOptions = kwargs.get("DecodingOptions", {})
- task = DecodingOptions.get("task", "ASR")
- if isinstance(task, str):
- task = [task]
- task = "".join([f"<|{x}|>" for x in task])
- initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-
- language = DecodingOptions.get("language", None)
- language = None if language == "auto" else language
-
- sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
- sos_int = tokenizer.encode(sos, allowed_special="all")
- eos = kwargs.get("model_conf").get("eos")
- eos_int = tokenizer.encode(eos, allowed_special="all")
- self.beam_search.sos = sos_int
- self.beam_search.eos = eos_int[0]
-
- # Paramterts for rich decoding
- self.beam_search.emo_unk = tokenizer.encode(
- DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
- )[0]
- self.beam_search.emo_unk_score = 1
- self.beam_search.emo_tokens = tokenizer.encode(
- DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
- allowed_special="all",
- )
- self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
-
- self.beam_search.event_bg_token = tokenizer.encode(
- DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
- allowed_special="all",
- )
- self.beam_search.event_ed_token = tokenizer.encode(
- DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
- allowed_special="all",
- )
- self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
-
- encoder_out, encoder_out_lens = self.encode(
- speech[None, :, :].permute(0, 2, 1), speech_lengths
- )
-
- if text_token_int is not None:
- i = 0
- results = []
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"1best_recog"]
-
- # 1. Forward decoder
- ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
- None, :
- ]
- ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
- kwargs["device"]
- )[None, :]
- decoder_out = self.model.decoder(
- x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
- )
-
- token_int = decoder_out.argmax(-1)[0, :].tolist()
- text = tokenizer.decode(token_int)
-
- result_i = {"key": key[i], "text": text}
- results.append(result_i)
-
- if ibest_writer is not None:
- # ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- return results, meta_data
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0],
- maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0),
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- b, n, d = encoder_out.size()
- for i in range(b):
-
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # # remove blank symbol id, which is assumed to be 0
- # token_int = list(
- # filter(
- # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
- # )
- # )
-
- # Change integer-ids to tokens
- # token = tokenizer.ids2tokens(token_int)
- text = tokenizer.decode(token_int)
-
- result_i = {"key": key[i], "text": text}
- results.append(result_i)
-
- if ibest_writer is not None:
- # ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
-
- return results, meta_data
-
-
-@tables.register("model_classes", "SenseVoiceSANM")
-class SenseVoiceSANM(nn.Module):
-
- def __init__(
- self,
- specaug: str = None,
- specaug_conf: dict = None,
- normalize: str = None,
- normalize_conf: dict = None,
- encoder: str = None,
- encoder_conf: dict = None,
- decoder: str = None,
- decoder_conf: dict = None,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- **kwargs,
- ):
-
- super().__init__()
-
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug)
- specaug = specaug_class(**specaug_conf)
-
- encoder_class = tables.encoder_classes.get(encoder)
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
-
- decoder_class = tables.decoder_classes.get(decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
-
- self.blank_id = blank_id
- self.sos = sos if sos is not None else vocab_size - 1
- self.eos = eos if eos is not None else vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
-
- self.specaug = specaug
-
- self.encoder = encoder
-
- self.decoder = decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
-
- self.error_calculator = None
-
- self.length_normalized_loss = length_normalized_loss
- self.beam_search = None
- self.activation_checkpoint = kwargs.get("activation_checkpoint", False)
- self.encoder_output_size = encoder_output_size
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
-
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size, frames, _ = speech.shape
- _, text_tokens = text.shape
-
- if self.activation_checkpoint:
- from torch.utils.checkpoint import checkpoint
-
- encoder_out, encoder_out_lens = checkpoint(
- self.encode, speech, speech_lengths, use_reentrant=False
- )
- else:
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths, target_mask=target_mask
- )
-
- loss = loss_att
- stats = {}
- stats["acc"] = acc_att
- stats["loss"] = torch.clone(loss.detach())
- stats["batch_size"] = batch_size
- stats["batch_size_x_frames"] = frames * batch_size
- stats["batch_size_real_frames"] = speech_lengths.sum().item()
- stats["padding_frames"] = stats["batch_size_x_frames"] - stats["batch_size_real_frames"]
- stats["batch_size_x_tokens"] = text_tokens * batch_size
- stats["batch_size_real_tokens"] = text_lengths.sum().item()
- stats["padding_tokens"] = stats["batch_size_x_tokens"] - stats["batch_size_real_tokens"]
- stats["batch_size_x_frames_plus_tokens"] = (text_tokens + frames) * batch_size
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
- def encode(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- **kwargs,
- ):
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- ind: int
- """
- with autocast(False):
-
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
-
- # Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
-
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
- if isinstance(encoder_out, (tuple, list)):
- encoder_out = encoder_out[0]
-
- return encoder_out, encoder_out_lens
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- **kwargs,
- ):
- target_mask = kwargs.get("target_mask", None)
- stats = {}
-
- # 1. Forward decoder
- ys_pad[ys_pad == -1] = 0
- decoder_out = self.decoder(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
- if isinstance(decoder_out, (list, tuple)):
- decoder_out = decoder_out[0]
-
- # 2. Compute attention loss
- mask = torch.ones_like(ys_pad) * (-1)
- ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
- ys_pad_mask[ys_pad_mask == 0] = -1
- loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
-
- with torch.no_grad():
- preds = torch.argmax(decoder_out, -1)
- acc_att = compute_accuracy(
- preds[:, :-1], ys_pad_mask[:, 1:], ignore_label=self.ignore_id
- )
-
- return loss_att, acc_att, None, None
-
- def init_beam_search(
- self,
- **kwargs,
- ):
- from .search import BeamSearch
-
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
- scorers.update(
- decoder=self.decoder,
- length_bonus=LengthBonus(self.vocab_size),
- )
-
- weights = dict(
- decoder=1.0,
- ctc=0.0,
- lm=0.0,
- ngram=0.0,
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 5),
- weights=weights,
- scorers=scorers,
- sos=None,
- eos=None,
- vocab_size=self.vocab_size,
- token_list=None,
- pre_beam_score_key="full",
- )
-
- self.beam_search = beam_search
-
- def inference(
- self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- if not hasattr(self, "beam_search") or self.beam_search is None:
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- if frontend is None and not hasattr(self, "frontend"):
- frontend_class = tables.frontend_classes.get("WhisperFrontend")
- frontend = frontend_class(
- n_mels=self.model.dims.n_mels, do_pad_trim=kwargs.get("do_pad_trim", True)
- )
- self.frontend = frontend
- else:
- frontend = frontend if frontend is not None else self.frontend
-
- meta_data = {}
- if (
- isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
- ): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(
- data_in,
- fs=frontend.fs if hasattr(frontend, "fs") else 16000,
- audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer,
- )
-
- if (
- isinstance(kwargs.get("data_type", None), (list, tuple))
- and len(kwargs.get("data_type", [])) > 1
- ):
- audio_sample_list, text_token_int_list = audio_sample_list
- text_token_int = text_token_int_list[0]
- else:
- text_token_int = None
-
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(
- audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
- )
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = frontend.frame_shift if hasattr(frontend, "frame_shift") else 10
- lfr_n = frontend.lfr_n if hasattr(frontend, "lfr_n") else 1
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift * lfr_n / 1000
-
- speech = speech.to(device=kwargs["device"])[0, :, :]
- speech_lengths = speech_lengths.to(device=kwargs["device"])
-
- DecodingOptions = kwargs.get("DecodingOptions", {})
- task = DecodingOptions.get("task", "ASR")
- if isinstance(task, str):
- task = [task]
- task = "".join([f"<|{x}|>" for x in task])
-
- sos = kwargs.get("model_conf").get("sos")
- if isinstance(sos, str):
- initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
-
- language = DecodingOptions.get("language", None)
- language = None if language == "auto" else language
-
- sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
- sos_int = tokenizer.encode(sos, allowed_special="all")
- else:
- language = DecodingOptions.get("language", None)
- language = None if language == "auto" else language
- initial_prompt = kwargs.get("initial_prompt", f"{task}")
- initial_prompt_lid = (
- f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
- )
- initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
- sos_int = [sos] + initial_prompt_lid_int
- eos = kwargs.get("model_conf").get("eos")
- if isinstance(eos, str):
- eos_int = tokenizer.encode(eos, allowed_special="all")
- else:
- eos_int = [eos]
-
- self.beam_search.sos = sos_int
- self.beam_search.eos = eos_int[0]
-
- # Paramterts for rich decoding
- self.beam_search.emo_unk = tokenizer.encode(
- DecodingOptions.get("emo_unk_token", "<|SPECIAL_TOKEN_1|>"), allowed_special="all"
- )[0]
- self.beam_search.emo_unk_score = 1
- self.beam_search.emo_tokens = tokenizer.encode(
- DecodingOptions.get("emo_target_tokens", "<|HAPPY|><|SAD|><|ANGRY|>"),
- allowed_special="all",
- )
- self.beam_search.emo_scores = DecodingOptions.get("emo_target_threshold", [0.1, 0.1, 0.1])
-
- self.beam_search.event_bg_token = tokenizer.encode(
- DecodingOptions.get("gain_tokens_bg", "<|Speech|><|BGM|><|Applause|><|Laughter|>"),
- allowed_special="all",
- )
- self.beam_search.event_ed_token = tokenizer.encode(
- DecodingOptions.get("gain_tokens_ed", "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"),
- allowed_special="all",
- )
- self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
-
- encoder_out, encoder_out_lens = self.encode(speech[None, :, :], speech_lengths)
-
- if text_token_int is not None:
- i = 0
- results = []
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"1best_recog"]
-
- # 1. Forward decoder
- ys_pad = torch.tensor(sos_int + text_token_int, dtype=torch.int64).to(kwargs["device"])[
- None, :
- ]
- ys_pad_lens = torch.tensor([len(sos_int + text_token_int)], dtype=torch.int64).to(
- kwargs["device"]
- )[None, :]
- decoder_out = self.model.decoder(
- x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
- )
-
- token_int = decoder_out.argmax(-1)[0, :].tolist()
- text = tokenizer.decode(token_int)
-
- result_i = {"key": key[i], "text": text}
- results.append(result_i)
-
- if ibest_writer is not None:
- # ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- return results, meta_data
-
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0],
- maxlenratio=kwargs.get("maxlenratio", 0.0),
- minlenratio=kwargs.get("minlenratio", 0.0),
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- b, n, d = encoder_out.size()
- for i in range(b):
-
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # # remove blank symbol id, which is assumed to be 0
- # token_int = list(
- # filter(
- # lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int
- # )
- # )
-
- # Change integer-ids to tokens
- # token = tokenizer.ids2tokens(token_int)
- text = tokenizer.decode(token_int)
-
- result_i = {"key": key[i], "text": text}
- results.append(result_i)
-
- if ibest_writer is not None:
- # ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
-
- return results, meta_data
-
-
-from funasr.models.paraformer.search import Hypothesis
-from funasr.utils import postprocess_utils
diff --git a/funasr/models/sense_voice/rwkv_v4.py b/funasr/models/sense_voice/rwkv_v4.py
deleted file mode 100644
index c154ac0..0000000
--- a/funasr/models/sense_voice/rwkv_v4.py
+++ /dev/null
@@ -1,412 +0,0 @@
-########################################################################################################
-# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
-########################################################################################################
-
-import os, math, gc, importlib
-import torch
-
-# torch._C._jit_set_profiling_executor(True)
-# torch._C._jit_set_profiling_mode(True)
-import torch.nn as nn
-from torch.nn import functional as F
-
-
-def __nop(ob):
- return ob
-
-
-MyModule = nn.Module
-MyFunction = __nop
-if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
- MyModule = torch.jit.ScriptModule
- MyFunction = torch.jit.script_method
-
-########################################################################################################
-# CUDA Kernel
-########################################################################################################
-
-wkv_cuda = None
-
-
-def load_rwkv_kernel(
- HEAD_SIZE: int = 64,
- RWKV_CTXLEN: int = 512,
- T_MAX: int = 512,
-):
- from torch.utils.cpp_extension import load
-
- global wkv_cuda
-
- if wkv_cuda is not None:
- return
-
- absolute_file_path = os.path.abspath(__file__)
- cur_dir = os.path.dirname(absolute_file_path)
- wkv_cuda = load(
- name="wkv",
- sources=[f"{cur_dir}/cuda/wkv_op.cpp", f"{cur_dir}/cuda/wkv_cuda.cu"],
- verbose=True,
- extra_cuda_cflags=[
- "-res-usage",
- "--maxrregcount 60",
- "--use_fast_math",
- "-O3",
- "-Xptxas -O3",
- f"-DTmax={T_MAX}",
- ],
- )
-
-
-class WKV(torch.autograd.Function):
- @staticmethod
- def forward(ctx, B, T, C, w, u, k, v):
- ctx.B = B
- ctx.T = T
- ctx.C = C
- # assert T <= T_MAX
- assert B * C % min(C, 1024) == 0
- if "32" in os.environ["RWKV_FLOAT_MODE"]:
- w = -torch.exp(w.contiguous())
- u = u.contiguous()
- k = k.contiguous()
- v = v.contiguous()
- else:
- w = -torch.exp(w.float().contiguous())
- u = u.float().contiguous()
- k = k.float().contiguous()
- v = v.float().contiguous()
- ctx.save_for_backward(w, u, k, v)
- y = torch.empty((B, T, C), device="cuda", memory_format=torch.contiguous_format)
- wkv_cuda.forward(B, T, C, w, u, k, v, y)
- if "32" in os.environ["RWKV_FLOAT_MODE"]:
- return y
- elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
- return y.half()
- elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
- return y.bfloat16()
-
- @staticmethod
- def backward(ctx, gy):
- B = ctx.B
- T = ctx.T
- C = ctx.C
- assert T <= T_MAX
- assert B * C % min(C, 1024) == 0
- w, u, k, v = ctx.saved_tensors
- gw = torch.zeros((B, C), device="cuda").contiguous()
- gu = torch.zeros((B, C), device="cuda").contiguous()
- gk = torch.zeros((B, T, C), device="cuda").contiguous()
- gv = torch.zeros((B, T, C), device="cuda").contiguous()
- if "32" in os.environ["RWKV_FLOAT_MODE"]:
- wkv_cuda.backward(B, T, C, w, u, k, v, gy.contiguous(), gw, gu, gk, gv)
- else:
- wkv_cuda.backward(B, T, C, w, u, k, v, gy.float().contiguous(), gw, gu, gk, gv)
- gw = torch.sum(gw, dim=0)
- gu = torch.sum(gu, dim=0)
- if "32" in os.environ["RWKV_FLOAT_MODE"]:
- return (None, None, None, gw, gu, gk, gv)
- elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
- return (None, None, None, gw.half(), gu.half(), gk.half(), gv.half())
- elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
- return (None, None, None, gw.bfloat16(), gu.bfloat16(), gk.bfloat16(), gv.bfloat16())
-
-
-def RUN_CUDA(B, T, C, w, u, k, v):
- return WKV.apply(B, T, C, w.cuda(), u.cuda(), k.cuda(), v.cuda())
-
-
-class RWKV_TimeMix(torch.jit.ScriptModule):
- def __init__(self, config, layer_id):
- super().__init__()
- load_rwkv_kernel()
- self.layer_id = layer_id
- self.ctx_len = config.ctx_len
- self.n_embd = config.n_embd
-
- attn_sz = config.n_embd
-
- with torch.no_grad(): # fancy init
- ratio_0_to_1 = layer_id / (config.n_layer - 1) # 0 to 1
- ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
-
- # fancy time_decay
- decay_speed = torch.ones(attn_sz)
- for h in range(attn_sz):
- decay_speed[h] = -5 + 8 * (h / (attn_sz - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
- self.time_decay = nn.Parameter(decay_speed)
- # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
-
- # fancy time_first
- zigzag = torch.tensor([(i + 1) % 3 - 1 for i in range(attn_sz)]) * 0.5
- self.time_first = nn.Parameter(torch.ones(attn_sz) * math.log(0.3) + zigzag)
-
- # fancy time_mix
- x = torch.ones(1, 1, config.n_embd)
- for i in range(config.n_embd):
- x[0, 0, i] = i / config.n_embd
- self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
- self.time_mix_v = nn.Parameter(torch.pow(x, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
- self.time_mix_r = nn.Parameter(torch.pow(x, 0.5 * ratio_1_to_almost0))
-
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-
- self.key = nn.Linear(config.n_embd, attn_sz, bias=False)
- self.value = nn.Linear(config.n_embd, attn_sz, bias=False)
- self.receptance = nn.Linear(config.n_embd, attn_sz, bias=False)
-
- self.output = nn.Linear(attn_sz, config.n_embd, bias=False)
-
- self.key.scale_init = 0
- self.receptance.scale_init = 0
- self.output.scale_init = 0
-
- @torch.jit.script_method
- def jit_func(self, x):
-
- # Mix x with the previous timestep to produce xk, xv, xr
- xx = self.time_shift(x)
- xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
- xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
- xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-
- # Use xk, xv, xr to produce k, v, r
- k = self.key(xk)
- v = self.value(xv)
- r = self.receptance(xr)
- sr = torch.sigmoid(r)
-
- return sr, k, v
-
- def forward(self, x):
- B, T, C = x.size() # x = (Batch,Time,Channel)
-
- sr, k, v = self.jit_func(x)
-
- rwkv = sr * RUN_CUDA(B, T, C, self.time_decay, self.time_first, k, v)
- rwkv = self.output(rwkv)
- return rwkv
-
-
-class RWKV_ChannelMix(torch.jit.ScriptModule):
- def __init__(self, config, layer_id):
- super().__init__()
- self.layer_id = layer_id
-
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-
- with torch.no_grad(): # fancy init of time_mix
- ratio_1_to_almost0 = 1.0 - (layer_id / config.n_layer) # 1 to ~0
-
- x = torch.ones(1, 1, config.n_embd)
- for i in range(config.n_embd):
- x[0, 0, i] = i / config.n_embd
-
- self.time_mix_k = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
- self.time_mix_r = nn.Parameter(torch.pow(x, ratio_1_to_almost0))
-
- hidden_sz = 4 * config.n_embd
- self.key = nn.Linear(config.n_embd, hidden_sz, bias=False)
- self.receptance = nn.Linear(config.n_embd, config.n_embd, bias=False)
- self.value = nn.Linear(hidden_sz, config.n_embd, bias=False)
-
- self.value.scale_init = 0
- self.receptance.scale_init = 0
-
- @torch.jit.script_method
- def forward(self, x):
- xx = self.time_shift(x)
- xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
- xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
-
- k = self.key(xk)
- k = torch.square(torch.relu(k))
- kv = self.value(k)
-
- rkv = torch.sigmoid(self.receptance(xr)) * kv
- return rkv
-
-
-# class Block(nn.Module):
-# def __init__(self, args, layer_id):
-# super().__init__()
-# self.args = args
-# self.layer_id = layer_id
-#
-# self.ln1 = nn.LayerNorm(args.n_embd)
-# self.ln2 = nn.LayerNorm(args.n_embd)
-#
-# if self.layer_id == 0:
-# self.ln0 = nn.LayerNorm(args.n_embd)
-#
-# self.att = RWKV_Tmix_x060(args, layer_id)
-#
-# self.ffn = RWKV_CMix_x060(args, layer_id)
-#
-# if args.dropout > 0:
-# self.drop0 = nn.Dropout(p=args.dropout)
-# self.drop1 = nn.Dropout(p=args.dropout)
-#
-# def forward(self, x, x_emb=None):
-# args = self.args
-# B, T, C = x.size()
-# if self.layer_id == 0:
-# x = self.ln0(x)
-#
-# if self.args.dropout == 0:
-# if self.layer_id == 0 and args.pre_ffn > 0:
-# x = x + self.ffnPre(self.ln1(x))
-# else:
-# x = x + self.att(self.ln1(x))
-# x = x + self.ffn(self.ln2(x))
-# else:
-# if self.layer_id == 0 and args.pre_ffn > 0:
-# x = self.drop0(x + self.ffnPre(self.ln1(x)))
-# else:
-# x = self.drop0(x + self.att(self.ln1(x)))
-# x = self.drop1(x + self.ffn(self.ln2(x)))
-#
-# return x
-
-
-class RWKVLayer(nn.Module):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- self.layer_id = layer_id
- if args.dim_ffn is None:
- args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
- self.ln0 = None
- if self.layer_id == 0 and args.get("ln0", True):
- self.ln0 = nn.LayerNorm(args.n_embd)
-
- self.ln1 = None
- if args.get("ln1", True):
- self.ln1 = nn.LayerNorm(args.n_embd)
- self.ln2 = nn.LayerNorm(args.n_embd)
-
- self.att = RWKV_TimeMix(args, layer_id)
-
- self.ffn = RWKV_ChannelMix(args, layer_id)
-
- if args.dropout > 0:
- self.drop0 = nn.Dropout(p=args.dropout)
- self.drop1 = nn.Dropout(p=args.dropout)
-
- # init
- if args.get("init_rwkv", True):
- print("init_rwkv")
- nn.init.orthogonal_(self.att.receptance.weight, gain=1)
- nn.init.orthogonal_(self.att.key.weight, gain=0.1)
- nn.init.orthogonal_(self.att.value.weight, gain=1)
- nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
- nn.init.zeros_(self.att.output.weight)
-
- nn.init.orthogonal_(self.ffn.key.weight, gain=1)
- nn.init.zeros_(self.ffn.value.weight)
- nn.init.zeros_(self.ffn.receptance.weight)
- scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
- nn.init.constant_(self.ln2.weight, scale)
- if self.ln0 is not None:
- nn.init.constant_(self.ln0.weight, scale)
- if self.ln1 is not None:
- nn.init.constant_(self.ln1.weight, scale)
-
- def forward(self, x, x_emb=None, mask=None, **kwargs):
-
- args = self.args
- if args.get("datatype", "bf16") == "bf16":
- x = x.bfloat16()
- B, T, C = x.size()
- if self.layer_id == 0 and self.ln0 is not None:
- x = self.ln0(x)
-
- if self.args.dropout == 0:
- if self.ln1 is None:
- x = x + self.att(x)
- else:
- x = x + self.att(self.ln1(x))
- x = x + self.ffn(self.ln2(x))
- else:
- if self.ln1 is None:
- x = self.drop0(x + self.att(x))
- else:
- x = self.drop0(x + self.att(self.ln1(x)))
- x = self.drop1(x + self.ffn(self.ln2(x)))
-
- if args.get("datatype", "bf16") == "bf16":
- x = x.to(torch.float32)
- return x
-
-
-class RWKV(nn.Module):
- def __init__(self, args):
- super().__init__()
- self.args = args
- if not hasattr(args, "dim_att"):
- args.dim_att = args.n_embd
- if not hasattr(args, "dim_ffn"):
- if "-f4" in os.environ["RWKV_MY_TESTING"]:
- args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
- else:
- args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
- if not hasattr(args, "tiny_att_layer"):
- args.tiny_att_layer = -1
- if not hasattr(args, "tiny_att_dim"):
- args.tiny_att_dim = -1
- assert args.n_embd % 32 == 0
- assert args.dim_att % 32 == 0
- assert args.dim_ffn % 32 == 0
-
- self.emb = nn.Embedding(args.vocab_size, args.n_embd)
-
- self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
-
- self.ln_out = nn.LayerNorm(args.n_embd)
- self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
-
- if args.dropout > 0:
- self.drop0 = nn.Dropout(p=args.dropout)
-
- def forward(self, idx):
- args = self.args
- B, T = idx.size()
- assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
-
- x = self.emb(idx)
- x_emb = x
-
- if args.dropout > 0:
- x = self.drop0(x)
- if args.tiny_att_dim > 0:
- for block in self.blocks:
- if args.grad_cp == 1:
- x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
- else:
- x = block(x, x_emb)
- else:
- for block in self.blocks:
- if args.grad_cp == 1:
- x = deepspeed.checkpointing.checkpoint(block, x)
- else:
- x = block(x)
-
- x = self.ln_out(x)
-
- if args.head_qk > 0:
- q = self.head_q(x)[:, :T, :]
- k = self.head_k(x)[:, :T, :]
- c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
- c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-
- if "32" in os.environ["RWKV_FLOAT_MODE"]:
- c = c @ F.one_hot(idx, num_classes=args.vocab_size)
- elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
- c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
- elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
- c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
-
- x = self.head(x) + c
- else:
- x = self.head(x)
-
- return x
diff --git a/funasr/models/sense_voice/rwkv_v5.py b/funasr/models/sense_voice/rwkv_v5.py
deleted file mode 100644
index f19ca79..0000000
--- a/funasr/models/sense_voice/rwkv_v5.py
+++ /dev/null
@@ -1,597 +0,0 @@
-########################################################################################################
-# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
-########################################################################################################
-
-import os, math, gc, importlib
-import torch
-
-# torch._C._jit_set_profiling_executor(True)
-# torch._C._jit_set_profiling_mode(True)
-import torch.nn as nn
-from torch.nn import functional as F
-
-
-def __nop(ob):
- return ob
-
-
-MyModule = nn.Module
-MyFunction = __nop
-if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
- MyModule = torch.jit.ScriptModule
- MyFunction = torch.jit.script_method
-
-########################################################################################################
-# CUDA Kernel
-########################################################################################################
-
-wkv5_cuda = None
-
-
-def load_rwkv_kernel(
- HEAD_SIZE: int = 64,
- RWKV_CTXLEN: int = 512,
-):
- from torch.utils.cpp_extension import load
-
- global wkv5_cuda
-
- if wkv5_cuda is not None:
- return
-
- absolute_file_path = os.path.abspath(__file__)
- cur_dir = os.path.dirname(absolute_file_path)
-
- wkv5_cuda = load(
- name="wkv5",
- sources=[f"{cur_dir}/cuda/wkv5_op.cpp", f"{cur_dir}/cuda/wkv5_cuda.cu"],
- verbose=True,
- extra_cuda_cflags=[
- "-res-usage",
- "--use_fast_math",
- "-O3",
- "-Xptxas -O3",
- "--extra-device-vectorization",
- f"-D_N_={HEAD_SIZE}",
- ],
- )
-
-
-# dtype = torch.float
-dtype = torch.bfloat16
-
-
-class WKV_5(torch.autograd.Function):
- @staticmethod
- def forward(ctx, B, T, C, H, r, k, v, w, u):
- with torch.no_grad():
- # assert r.dtype == torch.bfloat16
- # assert k.dtype == torch.bfloat16
- # assert v.dtype == torch.bfloat16
- # assert w.dtype == torch.bfloat16
- # assert u.dtype == torch.bfloat16
- # assert HEAD_SIZE == C // H
- ctx.B = B
- ctx.T = T
- ctx.C = C
- ctx.H = H
- assert r.is_contiguous()
- assert k.is_contiguous()
- assert v.is_contiguous()
- assert w.is_contiguous()
- assert u.is_contiguous()
- ew = (-torch.exp(w.float())).contiguous()
- eew = (torch.exp(ew)).contiguous()
- ctx.save_for_backward(r, k, v, eew, ew, u)
- y = torch.empty(
- (B, T, C),
- device=r.device,
- dtype=torch.bfloat16,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-1, 1)
- wkv5_cuda.forward(B, T, C, H, r, k, v, eew, u, y)
- return y
-
- @staticmethod
- def backward(ctx, gy):
- with torch.no_grad():
- assert gy.dtype == torch.bfloat16
- B = ctx.B
- T = ctx.T
- C = ctx.C
- H = ctx.H
- assert gy.is_contiguous()
- r, k, v, eew, ew, u = ctx.saved_tensors
- gr = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=torch.bfloat16,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-1, 1)
- gk = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=torch.bfloat16,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-1, 1)
- gv = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=torch.bfloat16,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-1, 1)
- gw = torch.empty(
- (B, C),
- device=gy.device,
- requires_grad=False,
- dtype=torch.bfloat16,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-1, 1)
- gu = torch.empty(
- (B, C),
- device=gy.device,
- requires_grad=False,
- dtype=torch.bfloat16,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-1, 1)
- wkv5_cuda.backward(B, T, C, H, r, k, v, eew, ew, u, gy, gr, gk, gv, gw, gu)
- gw = torch.sum(gw, 0).view(H, C // H)
- gu = torch.sum(gu, 0).view(H, C // H)
- return (None, None, None, None, gr, gk, gv, gw, gu)
-
-
-def RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w, u):
- return WKV_5.apply(B, T, C, H, r, k, v, w, u)
-
-
-class RWKV_Tmix_x052(MyModule):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- load_rwkv_kernel(args.head_size_a, args.ctx_len)
- self.layer_id = layer_id
-
- self.head_size = args.head_size_a
- # assert HEAD_SIZE == self.head_size # change HEAD_SIZE to match args.head_size_a
- self.n_head = args.dim_att // self.head_size
- assert args.dim_att % self.n_head == 0
- self.head_size_divisor = args.head_size_divisor
-
- with torch.no_grad():
- ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
- ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
- ddd = torch.ones(1, 1, args.n_embd)
- for i in range(args.n_embd):
- ddd[0, 0, i] = i / args.n_embd
-
- # fancy time_mix
- self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
- self.time_mix_v = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
- self.time_mix_r = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
- self.time_mix_g = nn.Parameter(torch.pow(ddd, 0.5 * ratio_1_to_almost0))
-
- # fancy time_decay
- decay_speed = torch.ones(args.dim_att)
- for n in range(args.dim_att):
- decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
- self.time_decay = nn.Parameter(decay_speed.reshape(self.n_head, self.head_size))
- # print(layer_id, self.time_decay.flatten()[:3].cpu().numpy(), '...', self.time_decay.flatten()[-3:].cpu().numpy())
-
- tmp = torch.zeros(args.dim_att)
- for n in range(args.dim_att):
- zigzag = ((n + 1) % 3 - 1) * 0.1
- tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
-
- self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
-
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
- self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
- self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
-
- self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
- self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
- self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
- self.ln_x = nn.GroupNorm(self.n_head, args.dim_att)
-
- @MyFunction
- def jit_func(self, x):
- B, T, C = x.size()
-
- xx = self.time_shift(x) # Mix x with the previous timestep to produce xk, xv, xr
- xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
- xv = x * self.time_mix_v + xx * (1 - self.time_mix_v)
- xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
- xg = x * self.time_mix_g + xx * (1 - self.time_mix_g)
-
- r = self.receptance(xr)
- k = self.key(xk)
- v = self.value(xv)
- g = F.silu(self.gate(xg))
-
- return r, k, v, g
-
- @MyFunction
- def jit_func_2(self, x, g):
- B, T, C = x.size()
- x = x.view(B * T, C)
-
- x = self.ln_x(x / self.head_size_divisor).view(B, T, C)
- x = self.output(x * g)
- return x
-
- def forward(self, x, **kwargs):
- B, T, C = x.size()
- H = self.n_head
-
- r, k, v, g = self.jit_func(x)
-
- x = RUN_CUDA_RWKV5(B, T, C, H, r, k, v, w=self.time_decay, u=self.time_faaaa)
-
- return self.jit_func_2(x, g)
-
-
-#
-# class RWKV_Tmix_x060(MyModule):
-# def __init__(self, args, layer_id):
-# super().__init__()
-# self.args = args
-#
-# load_rwkv_kernel(args.head_size_a, args.ctx_len)
-#
-# self.layer_id = layer_id
-#
-# self.head_size = args.head_size_a
-# self.n_head = args.dim_att // self.head_size
-# assert args.dim_att % self.n_head == 0
-#
-# with torch.no_grad():
-# ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
-# ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
-# ddd = torch.ones(1, 1, args.n_embd)
-# for i in range(args.n_embd):
-# ddd[0, 0, i] = i / args.n_embd
-#
-# # fancy time_mix
-# self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-# self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-# self.time_maa_v = nn.Parameter(
-# 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
-# )
-# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
-# self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
-#
-# D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
-# self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
-# self.time_maa_w2 = nn.Parameter(
-# torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
-# )
-#
-# # fancy time_decay
-# decay_speed = torch.ones(args.dim_att)
-# for n in range(args.dim_att):
-# decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
-# self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
-#
-# D_DECAY_LORA = 64
-# self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
-# self.time_decay_w2 = nn.Parameter(
-# torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
-# )
-#
-# tmp = torch.zeros(args.dim_att)
-# for n in range(args.dim_att):
-# zigzag = ((n + 1) % 3 - 1) * 0.1
-# tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
-#
-# self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
-#
-# self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-# self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
-# self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
-#
-# self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
-# self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
-# self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
-# self.ln_x = nn.GroupNorm(
-# self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
-# )
-#
-# @MyFunction
-# def jit_func(self, x):
-# B, T, C = x.size()
-#
-# xx = self.time_shift(x) - x
-#
-# xxx = x + xx * self.time_maa_x
-# xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
-# xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
-# mw, mk, mv, mr, mg = xxx.unbind(dim=0)
-#
-# xw = x + xx * (self.time_maa_w + mw)
-# xk = x + xx * (self.time_maa_k + mk)
-# xv = x + xx * (self.time_maa_v + mv)
-# xr = x + xx * (self.time_maa_r + mr)
-# xg = x + xx * (self.time_maa_g + mg)
-#
-# r = self.receptance(xr)
-# k = self.key(xk)
-# v = self.value(xv)
-# g = F.silu(self.gate(xg))
-#
-# ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
-# w = self.time_decay + ww
-#
-# return r, k, v, g, w
-#
-# @MyFunction
-# def jit_func_2(self, x, g):
-# B, T, C = x.size()
-# x = x.view(B * T, C)
-#
-# x = self.ln_x(x).view(B, T, C)
-# x = self.output(x * g)
-# return x
-#
-# def forward(self, x):
-# B, T, C = x.size()
-# H = self.n_head
-#
-# r, k, v, g, w = self.jit_func(x)
-# x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
-#
-# return self.jit_func_2(x, g)
-
-
-class RWKV_CMix_x052(MyModule):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- self.layer_id = layer_id
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-
- with torch.no_grad(): # fancy init of time_mix
- ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
- ddd = torch.ones(1, 1, args.n_embd)
- for i in range(args.n_embd):
- ddd[0, 0, i] = i / args.n_embd
- self.time_mix_k = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
- self.time_mix_r = nn.Parameter(torch.pow(ddd, ratio_1_to_almost0))
-
- self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
- self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
- self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
-
- @MyFunction
- def forward(self, x):
- xx = self.time_shift(x)
- xk = x * self.time_mix_k + xx * (1 - self.time_mix_k)
- xr = x * self.time_mix_r + xx * (1 - self.time_mix_r)
- k = self.key(xk)
- k = torch.relu(k) ** 2
- kv = self.value(k)
- return torch.sigmoid(self.receptance(xr)) * kv
-
-
-# class RWKV_CMix_x060(MyModule):
-# def __init__(self, args, layer_id):
-# super().__init__()
-# self.args = args
-# self.layer_id = layer_id
-# self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-#
-# with torch.no_grad(): # fancy init of time_mix
-# ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
-# ddd = torch.ones(1, 1, args.n_embd)
-# for i in range(args.n_embd):
-# ddd[0, 0, i] = i / args.n_embd
-# self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-# self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-#
-# self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
-# self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
-# self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
-#
-# @MyFunction
-# def forward(self, x):
-# xx = self.time_shift(x) - x
-# xk = x + xx * self.time_maa_k
-# xr = x + xx * self.time_maa_r
-#
-# k = self.key(xk)
-# k = torch.relu(k) ** 2
-# kv = self.value(k)
-# return torch.sigmoid(self.receptance(xr)) * kv
-
-
-# class Block(nn.Module):
-# def __init__(self, args, layer_id):
-# super().__init__()
-# self.args = args
-# self.layer_id = layer_id
-#
-# self.ln1 = nn.LayerNorm(args.n_embd)
-# self.ln2 = nn.LayerNorm(args.n_embd)
-#
-# if self.layer_id == 0:
-# self.ln0 = nn.LayerNorm(args.n_embd)
-#
-# self.att = RWKV_Tmix_x060(args, layer_id)
-#
-# self.ffn = RWKV_CMix_x060(args, layer_id)
-#
-# if args.dropout > 0:
-# self.drop0 = nn.Dropout(p=args.dropout)
-# self.drop1 = nn.Dropout(p=args.dropout)
-#
-# def forward(self, x, x_emb=None):
-# args = self.args
-# B, T, C = x.size()
-# if self.layer_id == 0:
-# x = self.ln0(x)
-#
-# if self.args.dropout == 0:
-# if self.layer_id == 0 and args.pre_ffn > 0:
-# x = x + self.ffnPre(self.ln1(x))
-# else:
-# x = x + self.att(self.ln1(x))
-# x = x + self.ffn(self.ln2(x))
-# else:
-# if self.layer_id == 0 and args.pre_ffn > 0:
-# x = self.drop0(x + self.ffnPre(self.ln1(x)))
-# else:
-# x = self.drop0(x + self.att(self.ln1(x)))
-# x = self.drop1(x + self.ffn(self.ln2(x)))
-#
-# return x
-
-
-class RWKVLayer(nn.Module):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- self.layer_id = layer_id
- if args.dim_ffn is None:
- args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
- self.ln0 = None
- if self.layer_id == 0 and args.get("ln0", True):
- self.ln0 = nn.LayerNorm(args.n_embd)
-
- self.ln1 = None
- if args.get("ln1", True):
- self.ln1 = nn.LayerNorm(args.n_embd)
-
- self.att = RWKV_Tmix_x052(args, layer_id)
- self.ln2 = None
- self.ffn = None
- if args.get("use_rwkv_ffn", True):
- self.ln2 = nn.LayerNorm(args.n_embd)
- self.ffn = RWKV_CMix_x052(args, layer_id)
-
- if args.dropout > 0:
- self.drop0 = nn.Dropout(p=args.dropout)
- self.drop1 = nn.Dropout(p=args.dropout)
-
- # init
- if args.get("init_rwkv", True):
- print("init_rwkv")
- nn.init.orthogonal_(self.att.receptance.weight, gain=1)
- nn.init.orthogonal_(self.att.key.weight, gain=0.1)
- nn.init.orthogonal_(self.att.value.weight, gain=1)
- nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
- nn.init.zeros_(self.att.output.weight)
-
- nn.init.orthogonal_(self.ffn.key.weight, gain=1)
- nn.init.zeros_(self.ffn.value.weight)
- nn.init.zeros_(self.ffn.receptance.weight)
- scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
- nn.init.constant_(self.ln2.weight, scale)
- if self.ln0 is not None:
- nn.init.constant_(self.ln0.weight, scale)
- if self.ln1 is not None:
- nn.init.constant_(self.ln1.weight, scale)
-
- def forward(self, x, x_emb=None, mask=None, **kwargs):
-
- args = self.args
- if args.get("datatype", "bf16") == "bf16":
- x = x.bfloat16()
- B, T, C = x.size()
- if self.layer_id == 0 and self.ln0 is not None:
- x = self.ln0(x)
-
- if self.args.dropout == 0:
- if self.ln1 is None:
- x = x + self.att(x)
- else:
- x = x + self.att(self.ln1(x))
- if self.ffn is not None:
- x = x + self.ffn(self.ln2(x))
- else:
- if self.ln1 is None:
- x = self.drop0(x + self.att(x))
- else:
- x = self.drop0(x + self.att(self.ln1(x)))
- if self.ffn is not None:
- x = self.drop1(x + self.ffn(self.ln2(x)))
-
- if args.get("datatype", "bf16") == "bf16":
- x = x.to(torch.float32)
- return x
-
-
-# class RWKV(nn.Module):
-# def __init__(self, args):
-# super().__init__()
-# self.args = args
-# if not hasattr(args, "dim_att"):
-# args.dim_att = args.n_embd
-# if not hasattr(args, "dim_ffn"):
-# if "-f4" in os.environ["RWKV_MY_TESTING"]:
-# args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
-# else:
-# args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
-# if not hasattr(args, "tiny_att_layer"):
-# args.tiny_att_layer = -1
-# if not hasattr(args, "tiny_att_dim"):
-# args.tiny_att_dim = -1
-# assert args.n_embd % 32 == 0
-# assert args.dim_att % 32 == 0
-# assert args.dim_ffn % 32 == 0
-#
-# self.emb = nn.Embedding(args.vocab_size, args.n_embd)
-#
-# self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
-#
-# self.ln_out = nn.LayerNorm(args.n_embd)
-# self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
-#
-# if args.dropout > 0:
-# self.drop0 = nn.Dropout(p=args.dropout)
-#
-# def forward(self, idx):
-# args = self.args
-# B, T = idx.size()
-# assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
-#
-# x = self.emb(idx)
-# x_emb = x
-#
-# if args.dropout > 0:
-# x = self.drop0(x)
-# if args.tiny_att_dim > 0:
-# for block in self.blocks:
-# if args.grad_cp == 1:
-# x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
-# else:
-# x = block(x, x_emb)
-# else:
-# for block in self.blocks:
-# if args.grad_cp == 1:
-# x = deepspeed.checkpointing.checkpoint(block, x)
-# else:
-# x = block(x)
-#
-# x = self.ln_out(x)
-#
-# if args.head_qk > 0:
-# q = self.head_q(x)[:, :T, :]
-# k = self.head_k(x)[:, :T, :]
-# c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
-# c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-#
-# if "32" in os.environ["RWKV_FLOAT_MODE"]:
-# c = c @ F.one_hot(idx, num_classes=args.vocab_size)
-# elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
-# c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
-# elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
-# c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
-#
-# x = self.head(x) + c
-# else:
-# x = self.head(x)
-#
-# return x
diff --git a/funasr/models/sense_voice/rwkv_v6.py b/funasr/models/sense_voice/rwkv_v6.py
deleted file mode 100644
index b91d47a..0000000
--- a/funasr/models/sense_voice/rwkv_v6.py
+++ /dev/null
@@ -1,478 +0,0 @@
-########################################################################################################
-# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
-########################################################################################################
-
-import os, math, gc, importlib
-import torch
-
-# torch._C._jit_set_profiling_executor(True)
-# torch._C._jit_set_profiling_mode(True)
-import torch.nn as nn
-from torch.nn import functional as F
-
-
-def __nop(ob):
- return ob
-
-
-MyModule = nn.Module
-MyFunction = __nop
-if "RWKV_JIT_ON" in os.environ and os.environ["RWKV_JIT_ON"] == "1":
- MyModule = torch.jit.ScriptModule
- MyFunction = torch.jit.script_method
-
-########################################################################################################
-# CUDA Kernel
-########################################################################################################
-
-wkv6_cuda = None
-
-
-def load_rwkv_kernel(
- HEAD_SIZE: int = 64,
- RWKV_CTXLEN: int = 512,
-):
- from torch.utils.cpp_extension import load
-
- global wkv6_cuda
-
- if wkv6_cuda is not None:
- return
-
- absolute_file_path = os.path.abspath(__file__)
- cur_dir = os.path.dirname(absolute_file_path)
- wkv6_cuda = load(
- name="wkv6",
- sources=[f"{cur_dir}/cuda/wkv6_op.cpp", f"{cur_dir}/cuda/wkv6_cuda.cu"],
- verbose=True,
- extra_cuda_cflags=[
- "-res-usage",
- "--use_fast_math",
- "-O3",
- "-Xptxas -O3",
- "--extra-device-vectorization",
- f"-D_N_={HEAD_SIZE}",
- f"-D_T_={RWKV_CTXLEN}",
- ],
- )
-
-
-# dtype = torch.float
-dtype = torch.bfloat16
-
-
-class WKV_6(torch.autograd.Function):
- @staticmethod
- def forward(ctx, B, T, C, H, r, k, v, w, u):
- with torch.no_grad():
- # assert r.dtype == torch.bfloat16
- # assert k.dtype == torch.bfloat16
- # assert v.dtype == torch.bfloat16
- # assert w.dtype == torch.bfloat16
- # assert u.dtype == torch.bfloat16
- # assert HEAD_SIZE == C // H
- ctx.B = B
- ctx.T = T
- ctx.C = C
- ctx.H = H
- assert r.is_contiguous()
- assert k.is_contiguous()
- assert v.is_contiguous()
- assert w.is_contiguous()
- assert u.is_contiguous()
- ew = (-torch.exp(w.float())).contiguous()
- ctx.save_for_backward(r, k, v, ew, u)
- y = torch.empty(
- (B, T, C), device=r.device, dtype=dtype, memory_format=torch.contiguous_format
- ) # .uniform_(-100, 100)
- wkv6_cuda.forward(B, T, C, H, r, k, v, ew, u, y)
- return y
-
- @staticmethod
- def backward(ctx, gy):
- with torch.no_grad():
- # assert gy.dtype == torch.bfloat16
- B = ctx.B
- T = ctx.T
- C = ctx.C
- H = ctx.H
- assert gy.is_contiguous()
- r, k, v, ew, u = ctx.saved_tensors
- gr = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=dtype,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-100, 100)
- gk = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=dtype,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-100, 100)
- gv = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=dtype,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-100, 100)
- gw = torch.empty(
- (B, T, C),
- device=gy.device,
- requires_grad=False,
- dtype=dtype,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-100, 100)
- gu = torch.empty(
- (B, C),
- device=gy.device,
- requires_grad=False,
- dtype=dtype,
- memory_format=torch.contiguous_format,
- ) # .uniform_(-100, 100)
- wkv6_cuda.backward(B, T, C, H, r, k, v, ew, u, gy, gr, gk, gv, gw, gu)
- gu = torch.sum(gu, 0).view(H, C // H)
- return (None, None, None, None, gr, gk, gv, gw, gu)
-
-
-def RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u):
- return WKV_6.apply(B, T, C, H, r, k, v, w, u)
-
-
-class RWKV_Tmix_x060(MyModule):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
-
- load_rwkv_kernel(args.head_size_a, args.ctx_len)
-
- self.layer_id = layer_id
-
- self.head_size = args.head_size_a
- self.n_head = args.dim_att // self.head_size
- assert args.dim_att % self.n_head == 0
-
- with torch.no_grad():
- ratio_0_to_1 = layer_id / (args.n_layer - 1) # 0 to 1
- ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
- ddd = torch.ones(1, 1, args.n_embd)
- for i in range(args.n_embd):
- ddd[0, 0, i] = i / args.n_embd
-
- # fancy time_mix
- self.time_maa_x = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
- self.time_maa_w = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
- self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
- self.time_maa_v = nn.Parameter(
- 1.0 - (torch.pow(ddd, ratio_1_to_almost0) + 0.3 * ratio_0_to_1)
- )
- self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
- self.time_maa_g = nn.Parameter(1.0 - torch.pow(ddd, 0.5 * ratio_1_to_almost0))
-
- D_MIX_LORA = 32 # generate TIME_MIX for w,k,v,r,g
- self.time_maa_w1 = nn.Parameter(torch.zeros(args.n_embd, D_MIX_LORA * 5))
- self.time_maa_w2 = nn.Parameter(
- torch.zeros(5, D_MIX_LORA, args.n_embd).uniform_(-0.01, 0.01)
- )
-
- # fancy time_decay
- decay_speed = torch.ones(args.dim_att)
- for n in range(args.dim_att):
- decay_speed[n] = -6 + 5 * (n / (args.dim_att - 1)) ** (0.7 + 1.3 * ratio_0_to_1)
- self.time_decay = nn.Parameter(decay_speed.reshape(1, 1, args.dim_att))
-
- D_DECAY_LORA = 64
- self.time_decay_w1 = nn.Parameter(torch.zeros(args.n_embd, D_DECAY_LORA))
- self.time_decay_w2 = nn.Parameter(
- torch.zeros(D_DECAY_LORA, args.dim_att).uniform_(-0.01, 0.01)
- )
-
- tmp = torch.zeros(args.dim_att)
- for n in range(args.dim_att):
- zigzag = ((n + 1) % 3 - 1) * 0.1
- tmp[n] = ratio_0_to_1 * (1 - (n / (args.dim_att - 1))) + zigzag
-
- self.time_faaaa = nn.Parameter(tmp.reshape(self.n_head, self.head_size))
-
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
- self.receptance = nn.Linear(args.n_embd, args.dim_att, bias=False)
- self.key = nn.Linear(args.n_embd, args.dim_att, bias=False)
-
- self.value = nn.Linear(args.n_embd, args.dim_att, bias=False)
- self.output = nn.Linear(args.dim_att, args.n_embd, bias=False)
- self.gate = nn.Linear(args.n_embd, args.dim_att, bias=False)
- self.ln_x = nn.GroupNorm(
- self.n_head, args.dim_att, eps=(1e-5) * (args.head_size_divisor**2)
- )
-
- @MyFunction
- def jit_func(self, x):
- B, T, C = x.size()
-
- xx = self.time_shift(x) - x
-
- xxx = x + xx * self.time_maa_x
- xxx = torch.tanh(xxx @ self.time_maa_w1).view(B * T, 5, -1).transpose(0, 1)
- xxx = torch.bmm(xxx, self.time_maa_w2).view(5, B, T, -1)
- mw, mk, mv, mr, mg = xxx.unbind(dim=0)
-
- xw = x + xx * (self.time_maa_w + mw)
- xk = x + xx * (self.time_maa_k + mk)
- xv = x + xx * (self.time_maa_v + mv)
- xr = x + xx * (self.time_maa_r + mr)
- xg = x + xx * (self.time_maa_g + mg)
-
- r = self.receptance(xr)
- k = self.key(xk)
- v = self.value(xv)
- g = F.silu(self.gate(xg))
-
- ww = torch.tanh(xw @ self.time_decay_w1) @ self.time_decay_w2
- w = self.time_decay + ww
-
- return r, k, v, g, w
-
- @MyFunction
- def jit_func_2(self, x, g):
- B, T, C = x.size()
- x = x.view(B * T, C)
-
- x = self.ln_x(x).view(B, T, C)
- x = self.output(x * g)
- return x
-
- def forward(self, x, **kwargs):
- B, T, C = x.size()
- H = self.n_head
-
- r, k, v, g, w = self.jit_func(x)
- x = RUN_CUDA_RWKV6(B, T, C, H, r, k, v, w, u=self.time_faaaa)
-
- return self.jit_func_2(x, g)
-
-
-class RWKV_CMix_x060(MyModule):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- self.layer_id = layer_id
- self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
-
- with torch.no_grad(): # fancy init of time_mix
- ratio_1_to_almost0 = 1.0 - (layer_id / args.n_layer) # 1 to ~0
- ddd = torch.ones(1, 1, args.n_embd)
- for i in range(args.n_embd):
- ddd[0, 0, i] = i / args.n_embd
- self.time_maa_k = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
- self.time_maa_r = nn.Parameter(1.0 - torch.pow(ddd, ratio_1_to_almost0))
-
- self.key = nn.Linear(args.n_embd, args.dim_ffn, bias=False)
- self.receptance = nn.Linear(args.n_embd, args.n_embd, bias=False)
- self.value = nn.Linear(args.dim_ffn, args.n_embd, bias=False)
-
- @MyFunction
- def forward(self, x):
- xx = self.time_shift(x) - x
- xk = x + xx * self.time_maa_k
- xr = x + xx * self.time_maa_r
-
- k = self.key(xk)
- k = torch.relu(k) ** 2
- kv = self.value(k)
- return torch.sigmoid(self.receptance(xr)) * kv
-
-
-class Block(nn.Module):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- self.layer_id = layer_id
-
- self.ln1 = nn.LayerNorm(args.n_embd)
- self.ln2 = nn.LayerNorm(args.n_embd)
-
- if self.layer_id == 0:
- self.ln0 = nn.LayerNorm(args.n_embd)
-
- self.att = RWKV_Tmix_x060(args, layer_id)
-
- self.ffn = RWKV_CMix_x060(args, layer_id)
-
- if args.dropout > 0:
- self.drop0 = nn.Dropout(p=args.dropout)
- self.drop1 = nn.Dropout(p=args.dropout)
-
- def forward(self, x, x_emb=None):
- args = self.args
- B, T, C = x.size()
- if self.layer_id == 0:
- x = self.ln0(x)
-
- if self.args.dropout == 0:
- if self.layer_id == 0 and args.pre_ffn > 0:
- x = x + self.ffnPre(self.ln1(x))
- else:
- x = x + self.att(self.ln1(x))
- x = x + self.ffn(self.ln2(x))
- else:
- if self.layer_id == 0 and args.pre_ffn > 0:
- x = self.drop0(x + self.ffnPre(self.ln1(x)))
- else:
- x = self.drop0(x + self.att(self.ln1(x)))
- x = self.drop1(x + self.ffn(self.ln2(x)))
-
- return x
-
-
-class RWKVLayer(nn.Module):
- def __init__(self, args, layer_id):
- super().__init__()
- self.args = args
- self.layer_id = layer_id
- if args.dim_ffn is None:
- args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32)
- self.ln0 = None
- if self.layer_id == 0 and args.get("ln0", True):
- self.ln0 = nn.LayerNorm(args.n_embd)
-
- self.ln1 = None
- if args.get("ln1", True):
- self.ln1 = nn.LayerNorm(args.n_embd)
-
- self.att = RWKV_Tmix_x060(args, layer_id)
-
- self.ln2 = None
- self.ffn = None
- if args.get("use_rwkv_ffn", True):
- self.ln2 = nn.LayerNorm(args.n_embd)
- self.ffn = RWKV_CMix_x060(args, layer_id)
-
- if args.dropout > 0:
- self.drop0 = nn.Dropout(p=args.dropout)
- self.drop1 = nn.Dropout(p=args.dropout)
-
- # init
- if args.get("init_rwkv", True):
- print("init_rwkv")
- nn.init.orthogonal_(self.att.receptance.weight, gain=1)
- nn.init.orthogonal_(self.att.key.weight, gain=0.1)
- nn.init.orthogonal_(self.att.value.weight, gain=1)
- nn.init.orthogonal_(self.att.gate.weight, gain=0.1)
- nn.init.zeros_(self.att.output.weight)
-
- nn.init.orthogonal_(self.ffn.key.weight, gain=1)
- nn.init.zeros_(self.ffn.value.weight)
- nn.init.zeros_(self.ffn.receptance.weight)
- scale = ((1 + layer_id) / args.get("n_layer")) ** 0.7
-
- if self.ln0 is not None:
- nn.init.constant_(self.ln0.weight, scale)
- if self.ln1 is not None:
- nn.init.constant_(self.ln1.weight, scale)
- if self.ln2 is not None:
- nn.init.constant_(self.ln2.weight, scale)
-
- def forward(self, x, x_emb=None, mask=None, **kwargs):
-
- args = self.args
- if args.get("datatype", "bf16") == "bf16":
- x = x.bfloat16()
- B, T, C = x.size()
- if self.layer_id == 0 and self.ln0 is not None:
- x = self.ln0(x)
-
- if self.args.dropout == 0:
- if self.ln1 is None:
- x = x + self.att(x)
- else:
- x = x + self.att(self.ln1(x))
- if self.ffn is not None:
- x = x + self.ffn(self.ln2(x))
- else:
- if self.ln1 is None:
- x = self.drop0(x + self.att(x))
- else:
- x = self.drop0(x + self.att(self.ln1(x)))
- if self.ffn is not None:
- x = self.drop1(x + self.ffn(self.ln2(x)))
-
- if args.get("datatype", "bf16") == "bf16":
- x = x.to(torch.float32)
- return x
-
-
-class RWKV(nn.Module):
- def __init__(self, args):
- super().__init__()
- self.args = args
- if not hasattr(args, "dim_att"):
- args.dim_att = args.n_embd
- if not hasattr(args, "dim_ffn"):
- if "-f4" in os.environ["RWKV_MY_TESTING"]:
- args.dim_ffn = int((args.n_embd * 4) // 32 * 32)
- else:
- args.dim_ffn = int((args.n_embd * 3.5) // 32 * 32) # default = 3.5x emb size
- if not hasattr(args, "tiny_att_layer"):
- args.tiny_att_layer = -1
- if not hasattr(args, "tiny_att_dim"):
- args.tiny_att_dim = -1
- assert args.n_embd % 32 == 0
- assert args.dim_att % 32 == 0
- assert args.dim_ffn % 32 == 0
-
- self.emb = nn.Embedding(args.vocab_size, args.n_embd)
-
- self.blocks = nn.ModuleList([Block(args, i) for i in range(args.n_layer)])
-
- self.ln_out = nn.LayerNorm(args.n_embd)
- self.head = nn.Linear(args.n_embd, args.vocab_size, bias=False)
-
- if args.dropout > 0:
- self.drop0 = nn.Dropout(p=args.dropout)
-
- def forward(self, idx):
- args = self.args
- B, T = idx.size()
- assert T <= args.ctx_len, "Cannot forward, model ctx_len is exhausted."
-
- x = self.emb(idx)
- x_emb = x
-
- if args.dropout > 0:
- x = self.drop0(x)
- if args.tiny_att_dim > 0:
- for block in self.blocks:
- if args.grad_cp == 1:
- x = deepspeed.checkpointing.checkpoint(block, x, x_emb)
- else:
- x = block(x, x_emb)
- else:
- for block in self.blocks:
- if args.grad_cp == 1:
- x = deepspeed.checkpointing.checkpoint(block, x)
- else:
- x = block(x)
-
- x = self.ln_out(x)
-
- if args.head_qk > 0:
- q = self.head_q(x)[:, :T, :]
- k = self.head_k(x)[:, :T, :]
- c = (q @ k.transpose(-2, -1)) * (1.0 / args.head_qk)
- c = c.masked_fill(self.copy_mask[:T, :T] == 0, 0)
-
- if "32" in os.environ["RWKV_FLOAT_MODE"]:
- c = c @ F.one_hot(idx, num_classes=args.vocab_size)
- elif os.environ["RWKV_FLOAT_MODE"] == "fp16":
- c = c @ F.one_hot(idx, num_classes=args.vocab_size).half()
- elif os.environ["RWKV_FLOAT_MODE"] == "bf16":
- c = c @ F.one_hot(idx, num_classes=args.vocab_size).bfloat16()
-
- x = self.head(x) + c
- else:
- x = self.head(x)
-
- return x
diff --git a/funasr/models/sense_voice/search.py b/funasr/models/sense_voice/search.py
deleted file mode 100644
index 3a1a049..0000000
--- a/funasr/models/sense_voice/search.py
+++ /dev/null
@@ -1,513 +0,0 @@
-from itertools import chain
-from dataclasses import field
-import logging
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import NamedTuple
-from typing import Tuple
-from typing import Union
-
-import torch
-import numpy as np
-
-from funasr.metrics.common import end_detect
-from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
-from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
-
-
-class Hypothesis(NamedTuple):
- """Hypothesis data type."""
-
- yseq: torch.Tensor
- score: Union[float, torch.Tensor] = 0
- scores: Dict[str, Union[float, torch.Tensor]] = dict()
- states: Dict[str, Any] = dict()
-
- def asdict(self) -> dict:
- """Convert data to JSON-friendly dict."""
- return self._replace(
- yseq=self.yseq.tolist(),
- score=float(self.score),
- scores={k: float(v) for k, v in self.scores.items()},
- )._asdict()
-
-
-class BeamSearch(torch.nn.Module):
- """Beam search implementation."""
-
- def __init__(
- self,
- scorers: Dict[str, ScorerInterface],
- weights: Dict[str, float],
- beam_size: int,
- vocab_size: int,
- sos=None,
- eos=None,
- # NOTE add rich decoding parameters
- # [SPECIAL_TOKEN_1, HAPPY, SAD, ANGRY, NEUTRAL]
- emo_unk: int = 58964,
- emo_unk_score: float = 1.0,
- emo_tokens: List[int] = field(default_factory=lambda: [58954, 58955, 58956, 58957]),
- emo_scores: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1, 0.1]),
- # [Speech, BGM, Laughter, Applause]
- event_bg_token: List[int] = field(default_factory=lambda: [58946, 58948, 58950, 58952]),
- event_ed_token: List[int] = field(default_factory=lambda: [58947, 58949, 58951, 58953]),
- event_score_ga: List[float] = field(default_factory=lambda: [1, 1, 5, 25]),
- token_list: List[str] = None,
- pre_beam_ratio: float = 1.5,
- pre_beam_score_key: str = None,
- ):
- """Initialize beam search.
-
- Args:
- scorers (dict[str, ScorerInterface]): Dict of decoder modules
- e.g., Decoder, CTCPrefixScorer, LM
- The scorer will be ignored if it is `None`
- weights (dict[str, float]): Dict of weights for each scorers
- The scorer will be ignored if its weight is 0
- beam_size (int): The number of hypotheses kept during search
- vocab_size (int): The number of vocabulary
- sos (int): Start of sequence id
- eos (int): End of sequence id
- token_list (list[str]): List of tokens for debug log
- pre_beam_score_key (str): key of scores to perform pre-beam search
- pre_beam_ratio (float): beam size in the pre-beam search
- will be `int(pre_beam_ratio * beam_size)`
-
- """
- super().__init__()
- # set scorers
- self.weights = weights
- self.scorers = dict()
- self.full_scorers = dict()
- self.part_scorers = dict()
- # this module dict is required for recursive cast
- # `self.to(device, dtype)` in `recog.py`
- self.nn_dict = torch.nn.ModuleDict()
- for k, v in scorers.items():
- w = weights.get(k, 0)
- if w == 0 or v is None:
- continue
- # assert isinstance(
- # v, ScorerInterface
- # ), f"{k} ({type(v)}) does not implement ScorerInterface"
- self.scorers[k] = v
- if isinstance(v, PartialScorerInterface):
- self.part_scorers[k] = v
- else:
- self.full_scorers[k] = v
- if isinstance(v, torch.nn.Module):
- self.nn_dict[k] = v
-
- # set configurations
- self.sos = sos
- self.eos = eos
- if isinstance(self.eos, (list, tuple)):
- self.eos = eos[0]
- self.token_list = token_list
- self.pre_beam_size = int(pre_beam_ratio * beam_size)
- self.beam_size = beam_size
- self.n_vocab = vocab_size
- if (
- pre_beam_score_key is not None
- and pre_beam_score_key != "full"
- and pre_beam_score_key not in self.full_scorers
- ):
- raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
- self.pre_beam_score_key = pre_beam_score_key
- self.do_pre_beam = (
- self.pre_beam_score_key is not None
- and self.pre_beam_size < self.n_vocab
- and len(self.part_scorers) > 0
- )
-
- self.emo_unk = emo_unk
- self.emo_unk_score = emo_unk_score
- self.emo_tokens = emo_tokens
- self.emo_scores = emo_scores
- self.event_bg_token = event_bg_token
- self.event_ed_token = event_ed_token
- self.event_score_ga = event_score_ga
-
- def init_hyp(self, x: torch.Tensor) -> List[Hypothesis]:
- """Get an initial hypothesis data.
-
- Args:
- x (torch.Tensor): The encoder output feature
-
- Returns:
- Hypothesis: The initial hypothesis.
-
- """
- init_states = dict()
- init_scores = dict()
- for k, d in self.scorers.items():
- init_states[k] = d.init_state(x)
- init_scores[k] = 0.0
- if not isinstance(self.sos, (list, tuple)):
- self.sos = [self.sos]
- return [
- Hypothesis(
- score=0.0,
- scores=init_scores,
- states=init_states,
- yseq=torch.tensor(self.sos, device=x.device),
- )
- ]
-
- @staticmethod
- def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
- """Append new token to prefix tokens.
-
- Args:
- xs (torch.Tensor): The prefix token
- x (int): The new token to append
-
- Returns:
- torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
-
- """
- x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
- return torch.cat((xs, x))
-
- def score_full(
- self, hyp: Hypothesis, x: torch.Tensor
- ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
- """Score new hypothesis by `self.full_scorers`.
-
- Args:
- hyp (Hypothesis): Hypothesis with prefix tokens to score
- x (torch.Tensor): Corresponding input feature
-
- Returns:
- Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
- score dict of `hyp` that has string keys of `self.full_scorers`
- and tensor score values of shape: `(self.n_vocab,)`,
- and state dict that has string keys
- and state values of `self.full_scorers`
-
- """
- scores = dict()
- states = dict()
-
- def get_score(yseq, sp1, sp2):
- score = [0, 0]
- last_token = yseq[-1]
- last_token2 = yseq[-2] if len(yseq) > 1 else yseq[-1]
- sum_sp1 = sum([1 if x == sp1 else 0 for x in yseq])
- sum_sp2 = sum([1 if x == sp2 else 0 for x in yseq])
- if sum_sp1 > sum_sp2 or last_token in [sp1, sp2]:
- score[0] = -np.inf
- if sum_sp2 >= sum_sp1:
- score[1] = -np.inf
- return score
-
- def struct_score(yseq, score):
- import math
-
- last_token = yseq[-1]
- if last_token in self.emo_tokens + [self.emo_unk]:
- # prevent output event after emotation token
- score[self.event_bg_token] = -np.inf
-
- for eve_bg, eve_ed, eve_ga in zip(
- self.event_bg_token, self.event_ed_token, self.event_score_ga
- ):
- score_offset = get_score(yseq, eve_bg, eve_ed)
- score[eve_bg] += score_offset[0]
- score[eve_ed] += score_offset[1]
- score[eve_bg] += math.log(eve_ga)
-
- score[self.emo_unk] += math.log(self.emo_unk_score)
- for emo, emo_th in zip(self.emo_tokens, self.emo_scores):
- if score.argmax() == emo and score[emo] < math.log(emo_th):
- score[self.emo_unk] = max(score[emo], score[self.emo_unk])
- score[emo] = -np.inf
- return score
-
- for k, d in self.full_scorers.items():
- scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x)
- scores[k] = struct_score(hyp.yseq, scores[k])
-
- return scores, states
-
- def score_partial(
- self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
- ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
- """Score new hypothesis by `self.part_scorers`.
-
- Args:
- hyp (Hypothesis): Hypothesis with prefix tokens to score
- ids (torch.Tensor): 1D tensor of new partial tokens to score
- x (torch.Tensor): Corresponding input feature
-
- Returns:
- Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
- score dict of `hyp` that has string keys of `self.part_scorers`
- and tensor score values of shape: `(len(ids),)`,
- and state dict that has string keys
- and state values of `self.part_scorers`
-
- """
- scores = dict()
- states = dict()
- for k, d in self.part_scorers.items():
- scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
- return scores, states
-
- def beam(
- self, weighted_scores: torch.Tensor, ids: torch.Tensor
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Compute topk full token ids and partial token ids.
-
- Args:
- weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
- Its shape is `(self.n_vocab,)`.
- ids (torch.Tensor): The partial token ids to compute topk
-
- Returns:
- Tuple[torch.Tensor, torch.Tensor]:
- The topk full token ids and partial token ids.
- Their shapes are `(self.beam_size,)`
-
- """
- # no pre beam performed
- if weighted_scores.size(0) == ids.size(0):
- top_ids = weighted_scores.topk(self.beam_size)[1]
- return top_ids, top_ids
-
- # mask pruned in pre-beam not to select in topk
- tmp = weighted_scores[ids]
- weighted_scores[:] = -float("inf")
- weighted_scores[ids] = tmp
- top_ids = weighted_scores.topk(self.beam_size)[1]
- local_ids = weighted_scores[ids].topk(self.beam_size)[1]
- return top_ids, local_ids
-
- @staticmethod
- def merge_scores(
- prev_scores: Dict[str, float],
- next_full_scores: Dict[str, torch.Tensor],
- full_idx: int,
- next_part_scores: Dict[str, torch.Tensor],
- part_idx: int,
- ) -> Dict[str, torch.Tensor]:
- """Merge scores for new hypothesis.
-
- Args:
- prev_scores (Dict[str, float]):
- The previous hypothesis scores by `self.scorers`
- next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
- full_idx (int): The next token id for `next_full_scores`
- next_part_scores (Dict[str, torch.Tensor]):
- scores of partial tokens by `self.part_scorers`
- part_idx (int): The new token id for `next_part_scores`
-
- Returns:
- Dict[str, torch.Tensor]: The new score dict.
- Its keys are names of `self.full_scorers` and `self.part_scorers`.
- Its values are scalar tensors by the scorers.
-
- """
- new_scores = dict()
- for k, v in next_full_scores.items():
- new_scores[k] = prev_scores[k] + v[full_idx]
- for k, v in next_part_scores.items():
- new_scores[k] = prev_scores[k] + v[part_idx]
- return new_scores
-
- def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
- """Merge states for new hypothesis.
-
- Args:
- states: states of `self.full_scorers`
- part_states: states of `self.part_scorers`
- part_idx (int): The new token id for `part_scores`
-
- Returns:
- Dict[str, torch.Tensor]: The new score dict.
- Its keys are names of `self.full_scorers` and `self.part_scorers`.
- Its values are states of the scorers.
-
- """
- new_states = dict()
- for k, v in states.items():
- new_states[k] = v
- for k, d in self.part_scorers.items():
- new_states[k] = d.select_state(part_states[k], part_idx)
- return new_states
-
- def search(self, running_hyps: List[Hypothesis], x: torch.Tensor) -> List[Hypothesis]:
- """Search new tokens for running hypotheses and encoded speech x.
-
- Args:
- running_hyps (List[Hypothesis]): Running hypotheses on beam
- x (torch.Tensor): Encoded speech feature (T, D)
-
- Returns:
- List[Hypotheses]: Best sorted hypotheses
-
- """
- best_hyps = []
- part_ids = torch.arange(self.n_vocab, device=x.device) # no pre-beam
- for hyp in running_hyps:
- # scoring
- weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
- scores, states = self.score_full(hyp, x)
- for k in self.full_scorers:
- weighted_scores += self.weights[k] * scores[k]
- # partial scoring
- if self.do_pre_beam:
- pre_beam_scores = (
- weighted_scores
- if self.pre_beam_score_key == "full"
- else scores[self.pre_beam_score_key]
- )
- part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
- part_scores, part_states = self.score_partial(hyp, part_ids, x)
- for k in self.part_scorers:
- weighted_scores[part_ids] += self.weights[k] * part_scores[k]
- # add previous hyp score
- weighted_scores += hyp.score
-
- # update hyps
- for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
- # will be (2 x beam at most)
- best_hyps.append(
- Hypothesis(
- score=weighted_scores[j],
- yseq=self.append_token(hyp.yseq, j),
- scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
- states=self.merge_states(states, part_states, part_j),
- )
- )
-
- # sort and prune 2 x beam -> beam
- best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
- : min(len(best_hyps), self.beam_size)
- ]
- return best_hyps
-
- def forward(
- self, x: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
- ) -> List[Hypothesis]:
- """Perform beam search.
-
- Args:
- x (torch.Tensor): Encoded speech feature (T, D)
- maxlenratio (float): Input length ratio to obtain max output length.
- If maxlenratio=0.0 (default), it uses a end-detect function
- to automatically find maximum hypothesis lengths
- If maxlenratio<0.0, its absolute value is interpreted
- as a constant max output length.
- minlenratio (float): Input length ratio to obtain min output length.
-
- Returns:
- list[Hypothesis]: N-best decoding results
-
- """
- # set length bounds
- if maxlenratio == 0:
- maxlen = x.shape[0]
- elif maxlenratio < 0:
- maxlen = -1 * int(maxlenratio)
- else:
- maxlen = max(1, int(maxlenratio * x.size(0)))
- minlen = int(minlenratio * x.size(0))
- logging.info("decoder input length: " + str(x.shape[0]))
- logging.info("max output length: " + str(maxlen))
- logging.info("min output length: " + str(minlen))
-
- # main loop of prefix search
- running_hyps = self.init_hyp(x)
- ended_hyps = []
- for i in range(maxlen):
- logging.debug("position " + str(i))
- best = self.search(running_hyps, x)
- # post process of one iteration
- running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
- # end detection
- # if len(ended_hyps) > 0:
- # print(f"ended_hyps: {ended_hyps}")
- if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
- logging.info(f"end detected at {i}")
- break
- if len(running_hyps) == 0:
- logging.info("no hypothesis. Finish decoding.")
- break
- else:
- logging.debug(f"remained hypotheses: {len(running_hyps)}")
-
- nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
- # check the number of hypotheses reaching to eos
- if len(nbest_hyps) == 0:
- logging.warning(
- "there is no N-best results, perform recognition " "again with smaller minlenratio."
- )
- return (
- []
- if minlenratio < 0.1
- else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
- )
-
- # report the best result
- best = nbest_hyps[0]
- for k, v in best.scores.items():
- logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
- logging.info(f"total log probability: {best.score:.2f}")
- logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
- logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
- if self.token_list is not None:
- logging.info(
- "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
- )
- return nbest_hyps
-
- def post_process(
- self,
- i: int,
- maxlen: int,
- maxlenratio: float,
- running_hyps: List[Hypothesis],
- ended_hyps: List[Hypothesis],
- ) -> List[Hypothesis]:
- """Perform post-processing of beam search iterations.
-
- Args:
- i (int): The length of hypothesis tokens.
- maxlen (int): The maximum length of tokens in beam search.
- maxlenratio (int): The maximum length ratio in beam search.
- running_hyps (List[Hypothesis]): The running hypotheses in beam search.
- ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
-
- Returns:
- List[Hypothesis]: The new running hypotheses.
-
- """
- logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
- if self.token_list is not None:
- logging.debug(
- "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
- )
- # add eos in the final loop to avoid that there are no ended hyps
- if i == maxlen - 1:
- logging.info("adding <eos> in the last position in the loop")
- running_hyps = [
- h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
- ]
-
- # add ended hypotheses to a final list, and removed them from current hypotheses
- # (this will be a problem, number of hyps < beam)
- remained_hyps = []
- for hyp in running_hyps:
- if hyp.yseq[-1] == self.eos:
- # e.g., Word LM needs to add final <eos> score
- for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
- s = d.final_score(hyp.states[k])
- hyp.scores[k] += s
- hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
- ended_hyps.append(hyp)
- else:
- remained_hyps.append(hyp)
- return remained_hyps
diff --git a/funasr/models/sense_voice/template.yaml b/funasr/models/sense_voice/template.yaml
deleted file mode 100644
index 1a25ea4..0000000
--- a/funasr/models/sense_voice/template.yaml
+++ /dev/null
@@ -1,108 +0,0 @@
-# network architecture
-model: SenseVoiceRWKV
-model_conf:
- lsm_weight: 0.1
- length_normalized_loss: true
- activation_checkpoint: true
- sos: "<|startoftranscript|>"
- eos: "<|endoftext|>"
- downsample_rate: 4
- use_padmask: true
-
- dims:
- n_mels: 128
- n_vocab: 60515
- n_audio_ctx: 1500
- n_audio_state: 1280
- n_audio_head: 20
- n_audio_layer: 32
- n_text_ctx: 448
- n_text_state: 1280
- n_text_head: 20
- n_text_layer: 32
-
-
-# decoder
-decoder: SenseVoiceDecoder
-decoder_conf:
- rwkv_cfg:
- n_embd: 1280
- dropout: 0
- head_size_a: 64
- ctx_len: 1280
- dim_att: 1280 #${model_conf.rwkv_cfg.n_embd}
- dim_ffn: null
- head_size_divisor: 8
- n_layer: 32
- pre_ffn: 0
- ln0: false
- ln1: false
- init_rwkv: false
- datatype: bf16
-
-
-# frontend related
-frontend: WhisperFrontend
-frontend_conf:
- fs: 16000
- n_mels: ${model_conf.dims.n_mels}
- do_pad_trim: false
-
-tokenizer: SenseVoiceTokenizer
-tokenizer_conf:
- vocab_path: null
- is_multilingual: true
- num_languages: 8749
-
-dataset: SenseVoiceDataset
-dataset_conf:
- index_ds: IndexDSJsonlRankSplit
- batch_sampler: EspnetStyleBatchSampler
- rank_split: true
- batch_type: token # example or length
- batch_size: 3500 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
- max_token_length: 2200
- min_token_length: 60
- max_source_length: 2000
- min_source_length: 60
- max_target_length: 150
- min_target_length: 0
- shuffle: True
- num_workers: 4
- sos: ${model_conf.sos}
- eos: ${model_conf.eos}
-
-train_conf:
- accum_grad: 2
- grad_clip: 5
- max_epoch: 20
- keep_nbest_models: 20
- avg_nbest_model: ${train_conf.keep_nbest_models}
- log_interval: 50
- reset_gpu_cache: true
-
-optim: adamw
-optim_conf:
- lr: 0.00002
-
-scheduler: warmuplr
-scheduler_conf:
- warmup_steps: 10000
-
-specaug: SpecAug
-specaug_conf:
- apply_time_warp: true
- time_warp_window: 5
- time_warp_mode: bicubic
- apply_freq_mask: true
- freq_mask_width_range:
- - 0
- - 40
- num_freq_mask: 2
- apply_time_mask: true
- time_mask_width_ratio_range:
- - 0.0
- - 0.12
- num_time_mask: 2
-
-scope_map: ['encoder.encoders', 'model.encoder', 'decoder.decoders', 'model.decoder']
\ No newline at end of file
diff --git a/funasr/models/sense_voice/whisper_lib/__init__.py b/funasr/models/sense_voice/whisper_lib/__init__.py
deleted file mode 100644
index 855b8cf..0000000
--- a/funasr/models/sense_voice/whisper_lib/__init__.py
+++ /dev/null
@@ -1,152 +0,0 @@
-import hashlib
-import io
-import os
-import urllib
-import warnings
-from typing import List, Optional, Union
-
-import torch
-from tqdm import tqdm
-
-from .audio import load_audio, log_mel_spectrogram, pad_or_trim
-from .decoding import DecodingOptions, DecodingResult, decode, detect_language
-from .model import ModelDimensions, Whisper
-from .transcribe import transcribe
-from .version import __version__
-
-_MODELS = {
- "tiny.en": "https://openaipublic.azureedge.net/main/whisper/models/d3dd57d32accea0b295c96e26691aa14d8822fac7d9d27d5dc00b4ca2826dd03/tiny.en.pt",
- "tiny": "https://openaipublic.azureedge.net/main/whisper/models/65147644a518d12f04e32d6f3b26facc3f8dd46e5390956a9424a650c0ce22b9/tiny.pt",
- "base.en": "https://openaipublic.azureedge.net/main/whisper/models/25a8566e1d0c1e2231d1c762132cd20e0f96a85d16145c3a00adf5d1ac670ead/base.en.pt",
- "base": "https://openaipublic.azureedge.net/main/whisper/models/ed3a0b6b1c0edf879ad9b11b1af5a0e6ab5db9205f891f668f8b0e6c6326e34e/base.pt",
- "small.en": "https://openaipublic.azureedge.net/main/whisper/models/f953ad0fd29cacd07d5a9eda5624af0f6bcf2258be67c92b79389873d91e0872/small.en.pt",
- "small": "https://openaipublic.azureedge.net/main/whisper/models/9ecf779972d90ba49c06d968637d720dd632c55bbf19d441fb42bf17a411e794/small.pt",
- "medium.en": "https://openaipublic.azureedge.net/main/whisper/models/d7440d1dc186f76616474e0ff0b3b6b879abc9d1a4926b7adfa41db2d497ab4f/medium.en.pt",
- "medium": "https://openaipublic.azureedge.net/main/whisper/models/345ae4da62f9b3d59415adc60127b97c714f32e89e936602e85993674d08dcb1/medium.pt",
- "large-v1": "https://openaipublic.azureedge.net/main/whisper/models/e4b87e7e0bf463eb8e6956e646f1e277e901512310def2c24bf0e11bd3c28e9a/large-v1.pt",
- "large-v2": "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt",
- "large-v3": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
- "large": "https://openaipublic.azureedge.net/main/whisper/models/e5b1a55b89c1367dacf97e3e19bfd829a01529dbfdeefa8caeb59b3f1b81dadb/large-v3.pt",
-}
-
-# base85-encoded (n_layers, n_heads) boolean arrays indicating the cross-attention heads that are
-# highly correlated to the word-level timing, i.e. the alignment between audio and text tokens.
-_ALIGNMENT_HEADS = {
- "tiny.en": b"ABzY8J1N>@0{>%R00Bk>$p{7v037`oCl~+#00",
- "tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO",
- "base.en": b"ABzY8;40c<0{>%RzzG;p*o+Vo09|#PsxSZm00",
- "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m",
- "small.en": b"ABzY8>?_)10{>%RpeA61k&I|OI3I$65C{;;pbCHh0B{qLQ;+}v00",
- "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000",
- "medium.en": b"ABzY8usPae0{>%R7<zz_OvQ{)4kMa0BMw6u5rT}kRKX;$NfYBv00*Hl@qhsU00",
- "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9",
- "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj",
- "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj",
- "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
- "large": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00",
-}
-
-
-def _download(url: str, root: str, in_memory: bool) -> Union[bytes, str]:
- os.makedirs(root, exist_ok=True)
-
- expected_sha256 = url.split("/")[-2]
- download_target = os.path.join(root, os.path.basename(url))
-
- if os.path.exists(download_target) and not os.path.isfile(download_target):
- raise RuntimeError(f"{download_target} exists and is not a regular file")
-
- if os.path.isfile(download_target):
- with open(download_target, "rb") as f:
- model_bytes = f.read()
- if hashlib.sha256(model_bytes).hexdigest() == expected_sha256:
- return model_bytes if in_memory else download_target
- else:
- warnings.warn(
- f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
- )
-
- with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
- with tqdm(
- total=int(source.info().get("Content-Length")),
- ncols=80,
- unit="iB",
- unit_scale=True,
- unit_divisor=1024,
- ) as loop:
- while True:
- buffer = source.read(8192)
- if not buffer:
- break
-
- output.write(buffer)
- loop.update(len(buffer))
-
- model_bytes = open(download_target, "rb").read()
- if hashlib.sha256(model_bytes).hexdigest() != expected_sha256:
- raise RuntimeError(
- "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
- )
-
- return model_bytes if in_memory else download_target
-
-
-def available_models() -> List[str]:
- """Returns the names of available models"""
- return list(_MODELS.keys())
-
-
-def load_model(
- name: str,
- device: Optional[Union[str, torch.device]] = None,
- download_root: str = None,
- in_memory: bool = False,
-) -> Whisper:
- """
- Load a Whisper ASR model
-
- Parameters
- ----------
- name : str
- one of the official model names listed by `whisper.available_models()`, or
- path to a model checkpoint containing the model dimensions and the model state_dict.
- device : Union[str, torch.device]
- the PyTorch device to put the model into
- download_root: str
- path to download the model files; by default, it uses "~/.cache/whisper"
- in_memory: bool
- whether to preload the model weights into host memory
-
- Returns
- -------
- model : Whisper
- The Whisper ASR model instance
- """
-
- if device is None:
- device = "cuda" if torch.cuda.is_available() else "cpu"
- if download_root is None:
- default = os.path.join(os.path.expanduser("~"), ".cache")
- download_root = os.path.join(os.getenv("XDG_CACHE_HOME", default), "whisper")
-
- if name in _MODELS:
- checkpoint_file = _download(_MODELS[name], download_root, in_memory)
- alignment_heads = _ALIGNMENT_HEADS[name]
- elif os.path.isfile(name):
- checkpoint_file = open(name, "rb").read() if in_memory else name
- alignment_heads = None
- else:
- raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
-
- with io.BytesIO(checkpoint_file) if in_memory else open(checkpoint_file, "rb") as fp:
- checkpoint = torch.load(fp, map_location=device)
- del checkpoint_file
-
- dims = ModelDimensions(**checkpoint["dims"])
- model = Whisper(dims)
- model.load_state_dict(checkpoint["model_state_dict"])
-
- if alignment_heads is not None:
- model.set_alignment_heads(alignment_heads)
-
- return model.to(device)
diff --git a/funasr/models/sense_voice/whisper_lib/__main__.py b/funasr/models/sense_voice/whisper_lib/__main__.py
deleted file mode 100644
index 8874c7a..0000000
--- a/funasr/models/sense_voice/whisper_lib/__main__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# from .transcribe import cli
-#
-# cli()
diff --git a/funasr/models/sense_voice/whisper_lib/audio.py b/funasr/models/sense_voice/whisper_lib/audio.py
deleted file mode 100644
index 2f688e9..0000000
--- a/funasr/models/sense_voice/whisper_lib/audio.py
+++ /dev/null
@@ -1,155 +0,0 @@
-import os
-from functools import lru_cache
-from subprocess import CalledProcessError, run
-from typing import Optional, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .utils import exact_div
-
-# hard-coded audio hyperparameters
-SAMPLE_RATE = 16000
-N_FFT = 400
-HOP_LENGTH = 160
-CHUNK_LENGTH = 30
-N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE # 480000 samples in a 30-second chunk
-N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH) # 3000 frames in a mel spectrogram input
-
-N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 # the initial convolutions has stride 2
-FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
-TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
-
-
-def load_audio(file: str, sr: int = SAMPLE_RATE):
- """
- Open an audio file and read as mono waveform, resampling as necessary
-
- Parameters
- ----------
- file: str
- The audio file to open
-
- sr: int
- The sample rate to resample the audio if necessary
-
- Returns
- -------
- A NumPy array containing the audio waveform, in float32 dtype.
- """
-
- # This launches a subprocess to decode audio while down-mixing
- # and resampling as necessary. Requires the ffmpeg CLI in PATH.
- # fmt: off
- cmd = [
- "ffmpeg",
- "-nostdin",
- "-threads", "0",
- "-i", file,
- "-f", "s16le",
- "-ac", "1",
- "-acodec", "pcm_s16le",
- "-ar", str(sr),
- "-"
- ]
- # fmt: on
- try:
- out = run(cmd, capture_output=True, check=True).stdout
- except CalledProcessError as e:
- raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
-
- return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
-
-
-def pad_or_trim(array, length: int = N_SAMPLES, *, axis: int = -1):
- """
- Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
- """
- if torch.is_tensor(array):
- if array.shape[axis] > length:
- array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
-
- if array.shape[axis] < length:
- pad_widths = [(0, 0)] * array.ndim
- pad_widths[axis] = (0, length - array.shape[axis])
- array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
- else:
- if array.shape[axis] > length:
- array = array.take(indices=range(length), axis=axis)
-
- if array.shape[axis] < length:
- pad_widths = [(0, 0)] * array.ndim
- pad_widths[axis] = (0, length - array.shape[axis])
- array = np.pad(array, pad_widths)
-
- return array
-
-
-@lru_cache(maxsize=None)
-def mel_filters(device, n_mels: int, filters_path: str = None) -> torch.Tensor:
- """
- load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
- Allows decoupling librosa dependency; saved using:
-
- np.savez_compressed(
- "mel_filters.npz",
- mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
- mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
- )
- """
- assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
- if filters_path is None:
- filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
- with np.load(filters_path, allow_pickle=False) as f:
- return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
-
-
-def log_mel_spectrogram(
- audio: Union[str, np.ndarray, torch.Tensor],
- n_mels: int = 80,
- padding: int = 0,
- device: Optional[Union[str, torch.device]] = None,
-):
- """
- Compute the log-Mel spectrogram of
-
- Parameters
- ----------
- audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
- The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
-
- n_mels: int
- The number of Mel-frequency filters, only 80 is supported
-
- padding: int
- Number of zero samples to pad to the right
-
- device: Optional[Union[str, torch.device]]
- If given, the audio tensor is moved to this device before STFT
-
- Returns
- -------
- torch.Tensor, shape = (80, n_frames)
- A Tensor that contains the Mel spectrogram
- """
- if not torch.is_tensor(audio):
- if isinstance(audio, str):
- audio = load_audio(audio)
- audio = torch.from_numpy(audio)
-
- if device is not None:
- audio = audio.to(device)
- if padding > 0:
- audio = F.pad(audio, (0, padding))
- window = torch.hann_window(N_FFT).to(audio.device)
- stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
- magnitudes = stft[..., :-1].abs() ** 2
-
- filters = mel_filters(audio.device, n_mels)
- mel_spec = filters @ magnitudes
-
- log_spec = torch.clamp(mel_spec, min=1e-10).log10()
- log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
- log_spec = (log_spec + 4.0) / 4.0
- return log_spec
diff --git a/funasr/models/sense_voice/whisper_lib/decoding.py b/funasr/models/sense_voice/whisper_lib/decoding.py
deleted file mode 100644
index a468efa..0000000
--- a/funasr/models/sense_voice/whisper_lib/decoding.py
+++ /dev/null
@@ -1,890 +0,0 @@
-from dataclasses import dataclass, field, replace
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import Tensor
-from torch.distributions import Categorical
-
-from .audio import CHUNK_LENGTH
-from .tokenizer import Tokenizer, get_tokenizer
-from .utils import compression_ratio
-from funasr.models.transformer.utils.nets_utils import to_device
-
-
-if TYPE_CHECKING:
- from .model import Whisper
-
-
@torch.no_grad()
def detect_language(
    model: "Whisper",
    mel: Tensor,
    tokenizer: Tokenizer = None,
    initial_prompt=None,
    x=None,
) -> Tuple[Tensor, List[dict]]:
    """
    Detect the spoken language in the audio, and return them as list of strings, along with the ids
    of the most probable language tokens and the probability distribution over all language tokens.
    This is performed outside the main decode loop in order to not interfere with kv-caching.

    Parameters
    ----------
    model: Whisper
        model whose encoder/decoder are used to score language tokens
    mel: Tensor
        a Mel spectrogram, or already-encoded audio features (distinguished
        below by the size of the last dimension)
    tokenizer: Tokenizer
        tokenizer providing the language tokens; built on demand when omitted
    initial_prompt: str
        text encoded to form the decoder input when `x` is not given
    x: Tensor
        pre-built decoder input tokens; takes precedence over `initial_prompt`

    Returns
    -------
    language_tokens : Tensor, shape = (n_audio,)
        ids of the most probable language tokens, which appears after the startoftranscript token.
    language_probs : List[Dict[str, float]], length = n_audio
        list of dictionaries containing the probability distribution over all languages.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
        raise ValueError("This model doesn't have language tokens so it can't perform lang id")

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    # skip encoder forward pass if already-encoded audio features were given
    # FIX(funasr): SenseVoice — encoded features are recognized by their last dimension
    if mel.shape[-1] != model.dims.n_audio_state:
        mel = model.encoder(mel)

    # forward pass using a single token, startoftranscript
    n_audio = mel.shape[0]
    # FIX(funasr): SenseVoice builds the decoder input from the initial prompt
    # instead of a bare SOT token:
    # x = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)  # [n_audio, 1]
    if x is None:
        x = torch.tensor([tokenizer.encode(initial_prompt, allowed_special="all")] * n_audio).to(
            mel.device
        )  # [n_audio, 1]

    else:
        x = x.to(mel.device)
    # FIX(funasr): SenseVoice — feed all but the last token and read the logits
    # at the final position
    logits = model.logits(x[:, :-1], mel)[:, -1]
    # logits = model.logits(x[:, :], mel)[:, -1]

    # collect detected languages; suppress all non-language tokens
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    mask[tokenizer.no_speech] = False

    logits[:, mask] = -np.inf
    language_tokens = logits.argmax(dim=-1)
    language_token_probs = logits.softmax(dim=-1).cpu()

    # one {language code or "nospeech": probability} dict per audio item
    language_probs = [
        {
            c: language_token_probs[i, j].item()
            for j, c in zip(
                list(tokenizer.all_language_tokens) + [tokenizer.no_speech],
                list(tokenizer.all_language_codes) + ["nospeech"],
            )
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]

    return language_tokens, language_probs
-
-
@dataclass(frozen=True)
class DecodingOptions:
    """Immutable options controlling a single call to `decode`."""

    # whether to perform X->X "transcribe" or X->English "translate"
    task: str = "transcribe"

    # language that the audio is in; uses detected language if None
    language: Optional[str] = None

    # sampling-related options
    temperature: float = 0.0
    sample_len: Optional[int] = None  # maximum number of tokens to sample
    best_of: Optional[int] = None  # number of independent sample trajectories, if t > 0
    beam_size: Optional[int] = None  # number of beams in beam search, if t == 0
    patience: Optional[float] = None  # patience in beam search (arxiv:2204.05424)

    # "alpha" in Google NMT, or None for length norm, when ranking generations
    # to select which to return among the beams or best-of-N samples
    length_penalty: Optional[float] = None

    # text or tokens to feed as the prompt or the prefix; for more info:
    # https://github.com/openai/whisper/discussions/117#discussioncomment-3727051
    prompt: Optional[Union[str, List[int]]] = None  # for the previous context
    prefix: Optional[Union[str, List[int]]] = None  # to prefix the current context

    # list of tokens ids (or comma-separated token ids) to suppress
    # "-1" will suppress a set of symbols as defined in `tokenizer.non_speech_tokens()`
    suppress_tokens: Optional[Union[str, Iterable[int]]] = "-1"
    suppress_blank: bool = True  # this will suppress blank outputs

    # audio-event options, consumed by the GainEventToken logit filter
    gain_event: bool = False  # enable gain/pairing rules for event tags
    gain_tokens_bg: Optional[Union[str, List[int]]] = "<|Speech|><|BGM|><|Applause|><|Laughter|>"
    gain_tokens_ed: Optional[Union[str, List[int]]] = (
        "<|/Speech|><|/BGM|><|/Applause|><|/Laughter|>"
    )
    gain_tokens_score: List[float] = field(default_factory=lambda: [1, 1, 25.0, 5.0])  # per-event gains

    # emotion-token options, consumed by the ThresholdEmoToken logit filter
    use_emo_threshold: bool = False  # enable thresholding of emotion tokens
    emo_unk_token: Optional[Union[str, List[int]]] = "<|SPECIAL_TOKEN_1|>"
    emo_target_tokens: Optional[Union[str, List[int]]] = "<|HAPPY|><|SAD|><|ANGRY|>"
    emo_target_threshold: List[float] = field(default_factory=lambda: [0.1, 0.1, 0.1])  # per-token thresholds

    # timestamp sampling options
    without_timestamps: bool = False  # use <|notimestamps|> to sample text tokens only
    max_initial_timestamp: Optional[float] = 1.0

    # implementation details
    fp16: bool = True  # use fp16 for most of the calculation

    # FIX(funasr): SenseVoice-specific additions
    initial_prompt: Optional[str] = None  # prompt text used to build the initial tokens
    vocab_path: Optional[str] = None  # path to the tokenizer vocabulary
-
-
@dataclass(frozen=True)
class DecodingResult:
    """The outcome of decoding one audio segment."""

    audio_features: Tensor  # encoder output for this segment
    language: str  # detected or requested language
    language_probs: Optional[Dict[str, float]] = None  # per-language probabilities, when language id was run
    tokens: List[int] = field(default_factory=list)  # sampled token ids
    text: str = ""  # decoded text
    avg_logprob: float = np.nan  # cumulative logprob divided by (token count + 1)
    no_speech_prob: float = np.nan  # probability of the no-speech token at the SOT position
    temperature: float = np.nan  # sampling temperature that produced this result
    compression_ratio: float = np.nan  # compression ratio of `text` (see utils.compression_ratio)
-
-
class Inference:
    """Abstract interface the decoding loop uses to query the decoder."""

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Run the decoder forward and return the per-token logits."""
        raise NotImplementedError

    def rearrange_kv_cache(self, source_indices) -> None:
        """Reorder the key-value cache to follow the updated beam indices."""
        raise NotImplementedError

    def cleanup_caching(self) -> None:
        """Release any caching resources or hooks once decoding has finished."""
        pass
-
-
class PyTorchInference(Inference):
    """Concrete `Inference` backed by a PyTorch Whisper model with kv caching."""

    def __init__(self, model: "Whisper", initial_token_length: int):
        self.model: "Whisper" = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []

        # the key and value projection modules of every decoder block, whose
        # outputs are cached and must be reordered when beams are re-ranked
        self.kv_modules = [block.attn.key for block in self.model.decoder.blocks] + [
            block.attn.value for block in self.model.decoder.blocks
        ]

    def logits(self, tokens: Tensor, audio_features: Tensor) -> Tensor:
        """Run the decoder on `tokens`, reusing cached keys/values when possible."""
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()

        # after the first forward pass only the newest token needs to be fed
        if tokens.shape[-1] > self.initial_token_length:
            tokens = tokens[:, -1:]

        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        """Remove the installed forward hooks and drop the cached keys/values."""
        while self.hooks:
            self.hooks.pop().remove()

        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        """Reorder cached keys/values so they match the surviving beam order."""
        if source_indices != list(range(len(source_indices))):
            for module in self.kv_modules:
                # keep only the selected sequences, detached from the graph
                self.kv_cache[module] = self.kv_cache[module][source_indices].detach()
-
-
class SequenceRanker:
    """Abstract strategy for choosing the best candidate within each group."""

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]) -> List[int]:
        """
        Given a list of groups of samples and their cumulative log probabilities,
        return the index of the sample to select from each group as the final result.
        """
        raise NotImplementedError
-
-
class MaximumLikelihoodRanker(SequenceRanker):
    """
    Select the sample with the highest log probabilities, penalized using either
    a simple length normalization or Google NMT paper's length penalty
    """

    def __init__(self, length_penalty: Optional[float]):
        self.length_penalty = length_penalty

    def rank(self, tokens: List[List[Tensor]], sum_logprobs: List[List[float]]):
        """Return, for each group, the index of its highest-scoring sample."""

        def penalized(logprob, length):
            if self.length_penalty is None:
                # plain length normalization
                return logprob / length
            # length penalty from the Google NMT paper
            return logprob / (((5 + length) / 6) ** self.length_penalty)

        best = []
        for group_logprobs, group in zip(sum_logprobs, tokens):
            group_scores = [penalized(lp, len(seq)) for lp, seq in zip(group_logprobs, group)]
            best.append(np.argmax(group_scores))
        return best
-
-
class TokenDecoder:
    """Abstract policy for choosing the next token at each decoding step."""

    def reset(self):
        """Initialize any stateful variables for decoding a new sequence"""

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Select the next token for every sequence and append it to `tokens`.

        Parameters
        ----------
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step
        sum_logprobs : Tensor, shape = (n_batch)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Tensor, shape = (n_batch, current_sequence_length + 1)
            the tokens, appended with the selected next token
        completed : bool
            True if all sequences have reached the end of text
        """
        raise NotImplementedError

    def finalize(
        self, tokens: Tensor, sum_logprobs: Tensor
    ) -> Tuple[Sequence[Sequence[Tensor]], List[List[float]]]:
        """Finish the search and return the candidate sequences per audio input.

        Parameters
        ----------
        tokens : Tensor, shape = (n_audio, n_group, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence
        sum_logprobs : Tensor, shape = (n_audio, n_group)
            cumulative log probabilities for each sequence

        Returns
        -------
        tokens : Sequence[Sequence[Tensor]], length = n_audio
            candidate token sequences for each audio input
        sum_logprobs : List[List[float]], length = n_audio
            cumulative log probabilities corresponding to the above
        """
        raise NotImplementedError
-
-
class GreedyDecoder(TokenDecoder):
    """Pick the next token by argmax (t == 0) or tempered sampling (t > 0)."""

    def __init__(self, temperature: float, eot: int):
        self.temperature = temperature  # 0 means pure argmax decoding
        self.eot = eot  # end-of-transcript token id

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Append one token per sequence; finished sequences keep emitting EOT."""
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        # accumulate log probabilities only for sequences still running
        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        # once a sequence has produced EOT, pin it to EOT forever
        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        """Guarantee a trailing EOT on every sequence and return plain floats."""
        padded = F.pad(tokens, (0, 1), value=self.eot)
        return padded, sum_logprobs.tolist()
-
-
class BeamSearchDecoder(TokenDecoder):
    """Beam search with "patience" (arxiv:2204.05424).

    Keeps the `beam_size` best unfinished sequences per audio at every step and
    collects up to `beam_size * patience` finished candidates before stopping.
    """

    def __init__(
        self,
        beam_size: int,
        eot: int,
        inference: Inference,
        patience: Optional[float] = None,
    ):
        self.beam_size = beam_size
        self.eot = eot
        # needed to reorder the kv cache whenever beams are re-ranked
        self.inference = inference
        self.patience = patience or 1.0
        # number of finished candidates to collect per audio before stopping
        self.max_candidates: int = round(beam_size * self.patience)
        # per-audio dict mapping finished token tuples -> cumulative logprob;
        # created lazily on the first update()
        self.finished_sequences = None

        assert self.max_candidates > 0, f"Invalid beam size ({beam_size}) or patience ({patience})"

    def reset(self):
        """Drop finished sequences so the decoder can start a new batch."""
        self.finished_sequences = None

    def update(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor) -> Tuple[Tensor, bool]:
        """Advance every beam by one token and harvest newly finished beams."""
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for i in range(n_audio):
            scores, sources, finished = {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                # beam_size + 1 candidates so a beam remains even if one ends in EOT
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                else:
                    # overwrite sum_logprobs in-place, in surviving-beam order
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]

        # mark as completed if all audio has enough number of samples
        completed = all(
            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        """Return candidates per audio, topping up with unfinished beams if needed."""
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        sum_logprobs = sum_logprobs.cpu()
        for i, sequences in enumerate(self.finished_sequences):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                # take unfinished beams from best to worst until the set is full
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    sequences[tuple(sequence)] = sum_logprobs[i][j].item()
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        return tokens, sum_logprobs
-
-
class LogitFilter:
    """Abstract in-place transformation applied to the logits at every step."""

    def apply(self, logits: Tensor, tokens: Tensor) -> None:
        """Apply any filtering or masking to `logits` in-place.

        Parameters
        ----------
        logits : Tensor, shape = (n_batch, vocab_size)
            per-token logits of the probability distribution at the current step
        tokens : Tensor, shape = (n_batch, current_sequence_length)
            all tokens in the context so far, including the prefix and sot_sequence tokens
        """
        raise NotImplementedError
-
-
class SuppressBlank(LogitFilter):
    """Forbid emitting a blank (" " or EOT) as the very first sampled token."""

    def __init__(self, tokenizer: Tokenizer, sample_begin: int):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits: Tensor, tokens: Tensor):
        # only active at the first sampled position
        if tokens.shape[1] != self.sample_begin:
            return
        blank_ids = self.tokenizer.encode(" ") + [self.tokenizer.eot]
        logits[:, blank_ids] = -np.inf
-
-
class SuppressTokens(LogitFilter):
    """Unconditionally mask a fixed collection of token ids."""

    def __init__(self, suppress_tokens: Sequence[int]):
        # keep a private copy so later mutation of the argument has no effect
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits: Tensor, tokens: Tensor):
        logits[:, self.suppress_tokens] = -np.inf
-
-
class GainEventToken(LogitFilter):
    """Boost begin-of-event tokens while keeping begin/end event tags paired."""

    def __init__(
        self, bg_tokens: Sequence[int], ed_tokens: Sequence[int], gain_values: Sequence[float]
    ):
        self.bg_tokens = list(bg_tokens)
        self.ed_tokens = list(ed_tokens)
        # store the gains in log space; clamp to avoid log(0)
        self.gain_value = [np.log(max(ga, 1e-9)) for ga in gain_values]
        assert len(self.ed_tokens) == len(self.gain_value)
        assert len(self.bg_tokens) == len(self.gain_value)

    def apply(self, logits: Tensor, tokens: Tensor):
        for i, history in enumerate(tokens):
            for bg, ed, gain in zip(self.bg_tokens, self.ed_tokens, self.gain_value):
                opened = sum(1 if tok == bg else 0 for tok in history)
                closed = sum(1 if tok == ed else 0 for tok in history)
                logits[i, bg] += gain
                # never re-open an event that is open, nor open right after a tag
                if opened > closed or history[-1] in [bg, ed]:
                    logits[i, bg] = -np.inf
                # an end tag is only legal while its event is open
                if opened <= closed:
                    logits[i, ed] = -np.inf
-
-
class ThresholdEmoToken(LogitFilter):
    """Replace a low-confidence emotion prediction with the unknown-emotion token."""

    def __init__(
        self, unk_tokens: Sequence[int], emo_tokens: Sequence[int], th_values: Sequence[float]
    ):
        # only the first "unknown" token id is used as the fallback target
        self.unk_token = list(unk_tokens)[0]
        self.emo_tokens = list(emo_tokens)
        self.th_values = list(th_values)
        assert len(self.emo_tokens) == len(self.th_values)

    def apply(self, logits: Tensor, tokens: Tensor):
        for i in range(len(tokens)):
            row = logits[i]  # view into logits; writes below mutate it in place
            for emo, threshold in zip(self.emo_tokens, self.th_values):
                # demote an emotion token that wins with too little probability
                if row.argmax() == emo and row.softmax(dim=-1)[emo] < threshold:
                    row[self.unk_token] = max(row[emo], row[self.unk_token])
                    row[emo] = -np.inf
-
-
class ApplyTimestampRules(LogitFilter):
    """Constrain sampling so timestamp tokens obey Whisper's pairing rules."""

    def __init__(
        self,
        tokenizer: Tokenizer,
        sample_begin: int,
        max_initial_timestamp_index: Optional[int],
    ):
        self.tokenizer = tokenizer
        # index of the first sampled (non-prompt) position
        self.sample_begin = sample_begin
        # upper bound on the very first timestamp token, or None for no bound
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits: Tensor, tokens: Tensor):
        # suppress <|notimestamps|> which is handled by without_timestamps
        if self.tokenizer.no_timestamps is not None:
            logits[:, self.tokenizer.no_timestamps] = -np.inf

        # timestamps have to appear in pairs, except directly before EOT; mask logits accordingly
        for k in range(tokens.shape[0]):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = [t for t in sampled_tokens.tolist()]
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.tokenizer.timestamp_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.tokenizer.timestamp_begin

            if last_was_timestamp:
                if penultimate_was_timestamp:  # has to be non-timestamp
                    logits[k, self.tokenizer.timestamp_begin :] = -np.inf
                else:  # cannot be normal text tokens
                    logits[k, : self.tokenizer.eot] = -np.inf

            timestamps = sampled_tokens[sampled_tokens.ge(self.tokenizer.timestamp_begin)]
            if timestamps.numel() > 0:
                # timestamps shouldn't decrease; forbid timestamp tokens smaller than the last
                # also force each segment to have a nonzero length, to prevent infinite looping
                if last_was_timestamp and not penultimate_was_timestamp:
                    timestamp_last = timestamps[-1]
                else:
                    timestamp_last = timestamps[-1] + 1
                logits[k, self.tokenizer.timestamp_begin : timestamp_last] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # suppress generating non-timestamp tokens at the beginning
            logits[:, : self.tokenizer.timestamp_begin] = -np.inf

            # apply the `max_initial_timestamp` option
            if self.max_initial_timestamp_index is not None:
                last_allowed = self.tokenizer.timestamp_begin + self.max_initial_timestamp_index
                logits[:, last_allowed + 1 :] = -np.inf

        # if sum of probability over timestamps is above any other token, sample timestamp
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(tokens.shape[0]):
            timestamp_logprob = logprobs[k, self.tokenizer.timestamp_begin :].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, : self.tokenizer.timestamp_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, : self.tokenizer.timestamp_begin] = -np.inf
-
-
class DecodingTask:
    """Carries out one decoding pass over a batch of audio segments.

    Wires together the tokenizer, the kv-cached inference wrapper, the sequence
    ranker, the token decoder and the logit filters described by a
    `DecodingOptions`, and exposes `run(mel)` returning one `DecodingResult`
    per audio segment.
    """

    inference: Inference
    sequence_ranker: SequenceRanker
    decoder: TokenDecoder
    logit_filters: List[LogitFilter]

    def __init__(self, model: "Whisper", options: DecodingOptions):
        self.model = model

        language = options.language or "en"
        tokenizer = get_tokenizer(
            model.is_multilingual,
            num_languages=model.num_languages,
            language=language,
            task=options.task,
            vocab_path=options.vocab_path,
        )
        self.tokenizer: Tokenizer = tokenizer
        self.options: DecodingOptions = self._verify_options(options)

        self.n_group: int = options.beam_size or options.best_of or 1
        self.n_ctx: int = model.dims.n_text_ctx
        self.sample_len: int = options.sample_len or model.dims.n_text_ctx // 2

        self.sot_sequence: Tuple[int] = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps

        self.initial_tokens: Tuple[int] = self._get_initial_tokens()
        self.sample_begin: int = len(self.initial_tokens)
        self.sot_index: int = self.initial_tokens.index(tokenizer.sot)

        # inference: implements the forward pass through the decoder, including kv caching
        self.inference = PyTorchInference(model, len(self.initial_tokens))

        # sequence ranker: implements how to rank a group of sampled sequences
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)

        # decoder: implements how to select the next tokens, given the autoregressive distribution
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(
                options.beam_size, tokenizer.eot, self.inference, options.patience
            )
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)

        # logit filters: applies various rules to suppress or penalize certain tokens
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if self.options.gain_event:
            self.logit_filters.append(
                GainEventToken(
                    self.tokenizer.encode(self.options.gain_tokens_bg, allowed_special="all"),
                    self.tokenizer.encode(self.options.gain_tokens_ed, allowed_special="all"),
                    self.options.gain_tokens_score,
                )
            )
        if self.options.use_emo_threshold:
            self.logit_filters.append(
                ThresholdEmoToken(
                    self.tokenizer.encode(self.options.emo_unk_token, allowed_special="all"),
                    self.tokenizer.encode(self.options.emo_target_tokens, allowed_special="all"),
                    self.options.emo_target_threshold,
                )
            )
        if not options.without_timestamps:
            precision = CHUNK_LENGTH / model.dims.n_audio_ctx  # usually 0.02 seconds
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                max_initial_timestamp_index = round(self.options.max_initial_timestamp / precision)
            self.logit_filters.append(
                ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index)
            )

    def _verify_options(self, options: DecodingOptions) -> DecodingOptions:
        """Reject option combinations that cannot be honored together."""
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0:
            if options.best_of is not None:
                raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")

        return options

    def _get_initial_tokens(self) -> Tuple[int]:
        """Build the token sequence the decoder is conditioned on before sampling."""
        tokens = list(self.sot_sequence)

        if prefix := self.options.prefix:
            prefix_tokens = (
                self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
            )
            if self.sample_len is not None:
                # leave room in the context for the tokens to be sampled
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens

        if prompt := self.options.prompt:
            prompt_tokens = (
                self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
            )
            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
        # FIX(funasr): SenseVoice — an initial_prompt replaces the whole token
        # prefix; the language tag is appended to it when a language was given
        if initial_prompt := self.options.initial_prompt:
            if self.options.language is not None:
                initial_prompt = f"{initial_prompt}<|{self.options.language}|>"
                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
            else:
                tokens = self.tokenizer.encode(initial_prompt, allowed_special="all")
                # NOTE(review): a literal token id 0 is appended here as a
                # placeholder for the (not yet detected) language token —
                # confirm against _detect_language, which overwrites tokens[:, -1:]
                tokens += [0]

        return tuple(tokens)

    def _get_suppress_tokens(self) -> Tuple[int]:
        """Resolve the `suppress_tokens` option into a sorted tuple of token ids."""
        suppress_tokens = self.options.suppress_tokens

        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]

        if -1 in suppress_tokens:
            # -1 is shorthand for the tokenizer's non-speech token set
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        elif suppress_tokens is None or len(suppress_tokens) == 0:
            suppress_tokens = []  # interpret empty string as an empty list
        else:
            assert isinstance(suppress_tokens, list), "suppress_tokens must be a list"

        suppress_tokens.extend(
            [
                self.tokenizer.transcribe,
                self.tokenizer.translate,
                self.tokenizer.sot,
                self.tokenizer.sot_prev,
                self.tokenizer.sot_lm,
            ]
        )
        if self.tokenizer.no_speech is not None:
            # no-speech probability is collected separately
            suppress_tokens.append(self.tokenizer.no_speech)

        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel: Tensor):
        """Run the encoder (unless `mel` is already encoded) and check its dtype."""
        if self.options.fp16:
            mel = mel.half()

        if mel.shape[-2:] == (
            self.model.dims.n_audio_ctx,
            self.model.dims.n_audio_state,
        ):
            # encoded audio features are given; skip audio encoding
            audio_features = mel
        else:
            audio_features = self.model.encoder(mel)

        if audio_features.dtype != (torch.float16 if self.options.fp16 else torch.float32):
            # BUG FIX: this used to be `return TypeError(...)`, which handed the
            # exception object back to the caller instead of raising it
            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")

        return audio_features

    def _detect_language(self, audio_features: Tensor, tokens: Tensor):
        """Detect languages when needed and write their tokens into `tokens`."""
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None

        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(
                audio_features, self.tokenizer, x=tokens
            )
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            # FIX(funasr): SenseVoice — instead of writing a single language
            # token at sot_index + 1:
            # if self.options.language is None:
            #     tokens[:, self.sot_index + 1] = lang_tokens  # write language tokens
            if self.options.language is None:
                # encode all detected languages as "<|lang|>" tags and write
                # them over the trailing placeholder position of `tokens`
                languages = "".join([f"<|{language}|>" for language in languages])

                n_audio = audio_features.shape[0]
                lang_tokens = torch.tensor(
                    [self.tokenizer.encode(languages, allowed_special="all")] * n_audio
                ).to(
                    audio_features.device
                )  # [n_audio, 1]

                tokens[:, -1:] = lang_tokens[:, :]
                languages = [languages]

        return languages, lang_probs

    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        """Autoregressively sample up to `sample_len` tokens for every sequence."""
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)

                if i == 0 and self.tokenizer.no_speech is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                # apply the logit filters, e.g. for suppressing or applying penalty to
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            # always drop the kv cache hooks, even if sampling raised
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel: Tensor) -> List[DecodingResult]:
        """Decode a batch of Mel spectrograms into one result per audio segment."""
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(audio_features=features, language=language, language_probs=probs)
                for features, language, probs in zip(audio_features, languages, language_probs)
            ]

        # repeat text tensors by the group size, for beam search or best-of-n sampling
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
        ]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

        fields = (
            texts,
            languages,
            tokens,
            audio_features,
            avg_logprobs,
            no_speech_probs,
        )
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                no_speech_prob=no_speech_prob,
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
        ]
-
-
@torch.no_grad()
def decode(
    model: "Whisper",
    mel: Tensor,
    options: DecodingOptions = DecodingOptions(),
    **kwargs,
) -> Union[DecodingResult, List[DecodingResult]]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    **kwargs
        per-call overrides applied on top of `options` via dataclasses.replace

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        # promote a lone spectrogram to a batch of one
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    results = DecodingTask(model, options).run(mel)
    return results[0] if single else results
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
deleted file mode 100644
index 3d0d6a8..0000000
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ /dev/null
@@ -1,333 +0,0 @@
-import base64
-import gzip
-from dataclasses import dataclass
-from typing import Dict, Iterable, Optional
-
-import numpy as np
-import torch
-import torch.nn.functional as F
-from torch import Tensor, nn
-
-from .decoding import decode as decode_function
-from .decoding import detect_language as detect_language_function
-from .transcribe import transcribe as transcribe_function
-
-
@dataclass
class ModelDimensions:
    """Architecture hyper-parameters of a Whisper checkpoint.

    The first five fields configure the audio encoder, the last five the
    text decoder (see `Whisper.__init__`, which consumes them in this order).
    """

    n_mels: int  # mel filterbank channels fed to the encoder's first conv
    n_audio_ctx: int  # encoder positional-embedding length
    n_audio_state: int  # encoder hidden width
    n_audio_head: int  # encoder attention heads
    n_audio_layer: int  # encoder transformer layers
    n_vocab: int  # vocabulary size (ties decoder output to token embedding)
    n_text_ctx: int  # decoder context length
    n_text_state: int  # decoder hidden width
    n_text_head: int  # decoder attention heads
    n_text_layer: int  # decoder transformer layers
-
-
-# class LayerNorm(nn.LayerNorm):
-# def forward(self, x: Tensor) -> Tensor:
-# return super().forward(x.float()).type(x.dtype)
-
-
class LayerNorm(nn.LayerNorm):
    """LayerNorm that normalizes in float32, then casts back to the input dtype.

    Useful when the surrounding model runs in reduced precision (fp16/bf16):
    the statistics are computed in float32 for numerical stability while the
    output dtype matches the input.
    """

    # NOTE: the previous no-op `__init__(*args, **kwargs)` override (which only
    # forwarded to `super().__init__`) has been removed; the inherited
    # constructor is used directly.

    def forward(self, input):
        output = F.layer_norm(
            input.float(),
            self.normalized_shape,
            # With elementwise_affine=False, weight/bias are None — guard both.
            self.weight.float() if self.weight is not None else None,
            self.bias.float() if self.bias is not None else None,
            self.eps,
        )
        # Restore the caller's dtype.
        return output.type_as(input)
-
-
class Linear(nn.Linear):
    """nn.Linear whose weight/bias are cast to the input's dtype at call time."""

    def forward(self, x: Tensor) -> Tensor:
        weight = self.weight.to(x.dtype)
        bias = self.bias
        if bias is not None:
            bias = bias.to(x.dtype)
        return F.linear(x, weight, bias)
-
-
class Conv1d(nn.Conv1d):
    """nn.Conv1d whose weight/bias are cast to the input's dtype at call time."""

    def _conv_forward(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:
        cast_bias = None if bias is None else bias.to(x.dtype)
        return super()._conv_forward(x, weight.to(x.dtype), cast_bias)
-
-
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    The returned (length, channels) table has sines in the first channels/2
    columns and cosines in the last channels/2, over a geometric progression
    of timescales up to `max_timescale`.
    """
    assert channels % 2 == 0
    half = channels // 2
    log_timescale_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    positions = torch.arange(length)[:, np.newaxis]
    angles = positions * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
-
-
class MultiHeadAttention(nn.Module):
    """Multi-head attention for self- or cross-attention, with an optional
    kv_cache and either an additive mask or a padding mask (`is_pad_mask`)."""

    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)

        q = self.query(x)

        # Cross-attention keys/values are computed once and reused from the
        # cache on subsequent calls; otherwise project from x (self-attention)
        # or xa (first cross-attention call).
        use_cached = kv_cache is not None and xa is not None and self.key in kv_cache
        if use_cached:
            k = kv_cache[self.key]
            v = kv_cache[self.value]
        else:
            source = x if xa is None else xa
            k = self.key(source)
            v = self.value(source)

        wv, qk = self.qkv_attention(q, k, v, mask, is_pad_mask=is_pad_mask)
        return self.out(wv), qk

    def qkv_attention(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        n_batch, n_ctx, n_state = q.shape
        # Split the scale between q and k so qk carries a full 1/sqrt(d_head).
        scale = (n_state // self.n_head) ** -0.25

        # (batch, head, time, d_head); k is additionally transposed so that
        # q @ k yields (batch, head, time1, time2).
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            if is_pad_mask:
                # Padding mask: positions equal to 0 are masked out with -inf.
                mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
                qk = qk.masked_fill(mask, -float("inf"))
            else:
                # Additive mask (e.g. causal), cropped to the current context.
                qk = qk + mask[:n_ctx, :n_ctx]

        qk = qk.float()
        w = F.softmax(qk, dim=-1).to(q.dtype)
        if mask is not None and is_pad_mask:
            # Zero attention weights on padded positions after the softmax.
            w = w.masked_fill(mask, 0.0)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
-
-
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention,
    and a 4x-expansion MLP, each wrapped in a residual connection."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = MultiHeadAttention(n_state, n_head) if cross_attention else None
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
        **kwargs,
    ):
        is_pad_mask = kwargs.get("is_pad_mask", False)
        is_pad_memory_mask = kwargs.get("is_pad_memory_mask", False)

        # Self-attention with residual; [0] drops the returned qk matrix.
        attn_out = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache, is_pad_mask=is_pad_mask)
        x = x + attn_out[0]

        # Cross-attention over encoder output, when configured.
        if self.cross_attn:
            cross_out = self.cross_attn(
                self.cross_attn_ln(x), xa, kv_cache=kv_cache, is_pad_mask=is_pad_memory_mask
            )
            x = x + cross_out[0]

        # Position-wise feed-forward with residual.
        return x + self.mlp(self.mlp_ln(x))
-
-
class AudioEncoder(nn.Module):
    """Two stride-2 conv layers (4x temporal downsampling) followed by a
    stack of self-attention blocks with sinusoidal positional embeddings."""

    def __init__(self, n_mels: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, stride=2, padding=1)
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = LayerNorm(n_state)

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_mels, n_ctx)
            the mel spectrogram of the audio
        """
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        # (batch, channels, time) -> (batch, time, channels)
        x = x.permute(0, 2, 1)

        # The positional table is sliced to the actual sequence length, so
        # inputs shorter than n_ctx are accepted (unlike a strict shape check).
        x = (x + self.positional_embedding[: x.size(1), :]).to(x.dtype)

        for block in self.blocks:
            x = block(x)

        return self.ln_post(x)
-
-
class TextDecoder(nn.Module):
    """Token embedding + learned positions, cross-attending transformer stack,
    with output logits tied to the token-embedding matrix."""

    def __init__(self, n_vocab: int, n_ctx: int, n_state: int, n_head: int, n_layer: int):
        super().__init__()

        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))

        self.blocks: Iterable[ResidualAttentionBlock] = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)]
        )
        self.ln = LayerNorm(n_state)

        # Causal mask: -inf above the diagonal so each position attends only
        # to itself and earlier positions.
        mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: Tensor, xa: Tensor, kv_cache: Optional[dict] = None):
        """
        x : torch.LongTensor, shape = (batch_size, <= n_ctx)
            the text tokens
        xa : torch.Tensor, shape = (batch_size, n_audio_ctx, n_audio_state)
            the encoded audio features to be attended on
        """
        # With a kv_cache only new tokens are fed; offset the positional slice
        # by the number of positions already cached.
        offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
        x = self.token_embedding(x) + self.positional_embedding[offset : offset + x.shape[-1]]
        x = x.to(xa.dtype)

        for block in self.blocks:
            x = block(x, xa, mask=self.mask, kv_cache=kv_cache)

        x = self.ln(x)
        # Weight tying: project onto the transposed token-embedding matrix.
        logits = (x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)).float()

        return logits
-
-
class Whisper(nn.Module):
    """End-to-end Whisper model: an AudioEncoder over mel spectrograms and a
    TextDecoder that cross-attends to the encoded audio features."""

    def __init__(self, dims: ModelDimensions):
        super().__init__()
        self.dims = dims
        self.encoder = AudioEncoder(
            self.dims.n_mels,
            self.dims.n_audio_ctx,
            self.dims.n_audio_state,
            self.dims.n_audio_head,
            self.dims.n_audio_layer,
        )
        self.decoder = TextDecoder(
            self.dims.n_vocab,
            self.dims.n_text_ctx,
            self.dims.n_text_state,
            self.dims.n_text_head,
            self.dims.n_text_layer,
        )
        # use the last half among the decoder layers for time alignment by default;
        # to use a specific set of heads, see `set_alignment_heads()` below.
        all_heads = torch.zeros(self.dims.n_text_layer, self.dims.n_text_head, dtype=torch.bool)
        all_heads[self.dims.n_text_layer // 2 :] = True
        # NOTE(review): buffer registration is commented out below, so `all_heads`
        # is computed but unused unless `set_alignment_heads()` is called later.
        # self.register_buffer("alignment_heads", all_heads.to_sparse(), persistent=False)
        # alignment_heads_dense = model.get_buffer("alignment_heads").to_dense()
        # model.register_buffer("alignment_heads", alignment_heads_dense, persistent=False)

    def set_alignment_heads(self, dump: bytes):
        # The dump is a base85-encoded, gzip-compressed boolean array reshaped
        # to (n_text_layer, n_text_head); stored as a sparse, non-persistent buffer.
        array = np.frombuffer(gzip.decompress(base64.b85decode(dump)), dtype=bool).copy()
        mask = torch.from_numpy(array).reshape(self.dims.n_text_layer, self.dims.n_text_head)
        self.register_buffer("alignment_heads", mask.to_sparse(), persistent=False)

    def embed_audio(self, mel: torch.Tensor):
        # Encode a mel spectrogram into audio features.
        return self.encoder(mel)

    def logits(self, tokens: torch.Tensor, audio_features: torch.Tensor):
        # Compute decoder logits from tokens and precomputed audio features.
        return self.decoder(tokens, audio_features)

    def forward(self, mel: torch.Tensor, tokens: torch.Tensor) -> Dict[str, torch.Tensor]:
        # Full encode + decode pass.
        return self.decoder(tokens, self.encoder(mel))

    @property
    def device(self):
        # Device of the first parameter (assumes all parameters share a device).
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        # Multilingual checkpoints are identified by a vocabulary of >= 51865 tokens.
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        # Number of language tokens, derived from the vocabulary size.
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache: Optional[dict] = None):
        """
        The `MultiHeadAttention` module optionally accepts `kv_cache` which stores the key and value
        tensors calculated for the previous positions. This method returns a dictionary that stores
        all caches, and the necessary hooks for the key and value projection modules that save the
        intermediate tensors to be reused during later calculations.

        Returns
        -------
        cache : Dict[nn.Module, torch.Tensor]
            A dictionary object mapping the key/value projection modules to its cache
        hooks : List[RemovableHandle]
            List of PyTorch RemovableHandle objects to stop the hooks to be called
        """
        # Copy the incoming cache so the caller's dict is not mutated.
        cache = {**cache} if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                # save as-is, for the first token or cross attention
                cache[module] = output
            else:
                # append the new positions to the cached tensor along the time axis
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            # Hook the key/value projections of every attention module in the decoder.
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # Bind module-level implementations as methods of the model.
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
diff --git a/funasr/models/sense_voice/whisper_lib/normalizers/__init__.py b/funasr/models/sense_voice/whisper_lib/normalizers/__init__.py
deleted file mode 100644
index 896d5e3..0000000
--- a/funasr/models/sense_voice/whisper_lib/normalizers/__init__.py
+++ /dev/null
@@ -1,2 +0,0 @@
-from .basic import BasicTextNormalizer as BasicTextNormalizer
-from .english import EnglishTextNormalizer as EnglishTextNormalizer
diff --git a/funasr/models/sense_voice/whisper_lib/normalizers/basic.py b/funasr/models/sense_voice/whisper_lib/normalizers/basic.py
deleted file mode 100644
index 76addc5..0000000
--- a/funasr/models/sense_voice/whisper_lib/normalizers/basic.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import re
-import unicodedata
-
-import regex
-
# non-ASCII letters that are not separated by "NFKD" normalization.
# BUG FIX: the keys below were mojibake (UTF-8 bytes mis-decoded as a legacy
# codepage, e.g. "艙" in place of "œ"), so no real input character could ever
# match; restored to the intended single Latin characters.
ADDITIONAL_DIACRITICS = {
    "œ": "oe",
    "Œ": "OE",
    "ø": "o",
    "Ø": "O",
    "æ": "ae",
    "Æ": "AE",
    "ß": "ss",
    "ẞ": "SS",
    "đ": "d",
    "Đ": "D",
    "ð": "d",
    "Ð": "D",
    "þ": "th",
    "Þ": "th",
    "ł": "l",
    "Ł": "L",
}


def remove_symbols_and_diacritics(s: str, keep=""):
    """
    Replace any other markers, symbols, and punctuations with a space,
    and drop any diacritics (category 'Mn' and some manual mappings)
    """

    def _clean_char(c: str) -> str:
        # Characters explicitly kept by the caller pass through untouched.
        if c in keep:
            return c
        # Manual transliterations for letters NFKD does not decompose.
        if c in ADDITIONAL_DIACRITICS:
            return ADDITIONAL_DIACRITICS[c]
        category = unicodedata.category(c)
        if category == "Mn":  # combining mark left over from NFKD -> drop
            return ""
        if category[0] in "MSP":  # marks, symbols, punctuation -> space
            return " "
        return c

    return "".join(_clean_char(c) for c in unicodedata.normalize("NFKD", s))
-
-
def remove_symbols(s: str):
    """
    Replace any other markers, symbols, punctuations with a space, keeping diacritics
    """
    normalized = unicodedata.normalize("NFKC", s)
    # Marks (M), symbols (S), and punctuation (P) become spaces; all else kept.
    cleaned = [" " if unicodedata.category(c)[0] in "MSP" else c for c in normalized]
    return "".join(cleaned)
-
-
class BasicTextNormalizer:
    """Lower-cases text, strips bracketed/parenthesized spans, removes symbols
    (and optionally diacritics), and collapses runs of whitespace."""

    def __init__(self, remove_diacritics: bool = False, split_letters: bool = False):
        # Select the character-cleaning strategy once, at construction time.
        self.clean = remove_symbols_and_diacritics if remove_diacritics else remove_symbols
        self.split_letters = split_letters

    def __call__(self, s: str):
        s = s.lower()
        # Drop content enclosed in <...> or [...].
        s = re.sub(r"[<\[][^>\]]*[>\]]", "", s)
        # Drop content enclosed in (...).
        s = re.sub(r"\(([^)]+?)\)", "", s)
        s = self.clean(s).lower()

        if self.split_letters:
            # Separate every extended grapheme cluster with a space.
            s = " ".join(regex.findall(r"\X", s, regex.U))

        # Collapse any run of whitespace into a single space.
        s = re.sub(r"\s+", " ", s)

        return s
diff --git a/funasr/models/sense_voice/whisper_lib/normalizers/english.json b/funasr/models/sense_voice/whisper_lib/normalizers/english.json
deleted file mode 100644
index 74a1c35..0000000
--- a/funasr/models/sense_voice/whisper_lib/normalizers/english.json
+++ /dev/null
@@ -1,1741 +0,0 @@
-{
- "accessorise": "accessorize",
- "accessorised": "accessorized",
- "accessorises": "accessorizes",
- "accessorising": "accessorizing",
- "acclimatisation": "acclimatization",
- "acclimatise": "acclimatize",
- "acclimatised": "acclimatized",
- "acclimatises": "acclimatizes",
- "acclimatising": "acclimatizing",
- "accoutrements": "accouterments",
- "aeon": "eon",
- "aeons": "eons",
- "aerogramme": "aerogram",
- "aerogrammes": "aerograms",
- "aeroplane": "airplane",
- "aeroplanes": "airplanes",
- "aesthete": "esthete",
- "aesthetes": "esthetes",
- "aesthetic": "esthetic",
- "aesthetically": "esthetically",
- "aesthetics": "esthetics",
- "aetiology": "etiology",
- "ageing": "aging",
- "aggrandisement": "aggrandizement",
- "agonise": "agonize",
- "agonised": "agonized",
- "agonises": "agonizes",
- "agonising": "agonizing",
- "agonisingly": "agonizingly",
- "almanack": "almanac",
- "almanacks": "almanacs",
- "aluminium": "aluminum",
- "amortisable": "amortizable",
- "amortisation": "amortization",
- "amortisations": "amortizations",
- "amortise": "amortize",
- "amortised": "amortized",
- "amortises": "amortizes",
- "amortising": "amortizing",
- "amphitheatre": "amphitheater",
- "amphitheatres": "amphitheaters",
- "anaemia": "anemia",
- "anaemic": "anemic",
- "anaesthesia": "anesthesia",
- "anaesthetic": "anesthetic",
- "anaesthetics": "anesthetics",
- "anaesthetise": "anesthetize",
- "anaesthetised": "anesthetized",
- "anaesthetises": "anesthetizes",
- "anaesthetising": "anesthetizing",
- "anaesthetist": "anesthetist",
- "anaesthetists": "anesthetists",
- "anaesthetize": "anesthetize",
- "anaesthetized": "anesthetized",
- "anaesthetizes": "anesthetizes",
- "anaesthetizing": "anesthetizing",
- "analogue": "analog",
- "analogues": "analogs",
- "analyse": "analyze",
- "analysed": "analyzed",
- "analyses": "analyzes",
- "analysing": "analyzing",
- "anglicise": "anglicize",
- "anglicised": "anglicized",
- "anglicises": "anglicizes",
- "anglicising": "anglicizing",
- "annualised": "annualized",
- "antagonise": "antagonize",
- "antagonised": "antagonized",
- "antagonises": "antagonizes",
- "antagonising": "antagonizing",
- "apologise": "apologize",
- "apologised": "apologized",
- "apologises": "apologizes",
- "apologising": "apologizing",
- "appal": "appall",
- "appals": "appalls",
- "appetiser": "appetizer",
- "appetisers": "appetizers",
- "appetising": "appetizing",
- "appetisingly": "appetizingly",
- "arbour": "arbor",
- "arbours": "arbors",
- "archeological": "archaeological",
- "archaeologically": "archeologically",
- "archaeologist": "archeologist",
- "archaeologists": "archeologists",
- "archaeology": "archeology",
- "ardour": "ardor",
- "armour": "armor",
- "armoured": "armored",
- "armourer": "armorer",
- "armourers": "armorers",
- "armouries": "armories",
- "armoury": "armory",
- "artefact": "artifact",
- "artefacts": "artifacts",
- "authorise": "authorize",
- "authorised": "authorized",
- "authorises": "authorizes",
- "authorising": "authorizing",
- "axe": "ax",
- "backpedalled": "backpedaled",
- "backpedalling": "backpedaling",
- "bannister": "banister",
- "bannisters": "banisters",
- "baptise": "baptize",
- "baptised": "baptized",
- "baptises": "baptizes",
- "baptising": "baptizing",
- "bastardise": "bastardize",
- "bastardised": "bastardized",
- "bastardises": "bastardizes",
- "bastardising": "bastardizing",
- "battleax": "battleaxe",
- "baulk": "balk",
- "baulked": "balked",
- "baulking": "balking",
- "baulks": "balks",
- "bedevilled": "bedeviled",
- "bedevilling": "bedeviling",
- "behaviour": "behavior",
- "behavioural": "behavioral",
- "behaviourism": "behaviorism",
- "behaviourist": "behaviorist",
- "behaviourists": "behaviorists",
- "behaviours": "behaviors",
- "behove": "behoove",
- "behoved": "behooved",
- "behoves": "behooves",
- "bejewelled": "bejeweled",
- "belabour": "belabor",
- "belaboured": "belabored",
- "belabouring": "belaboring",
- "belabours": "belabors",
- "bevelled": "beveled",
- "bevvies": "bevies",
- "bevvy": "bevy",
- "biassed": "biased",
- "biassing": "biasing",
- "bingeing": "binging",
- "bougainvillaea": "bougainvillea",
- "bougainvillaeas": "bougainvilleas",
- "bowdlerise": "bowdlerize",
- "bowdlerised": "bowdlerized",
- "bowdlerises": "bowdlerizes",
- "bowdlerising": "bowdlerizing",
- "breathalyse": "breathalyze",
- "breathalysed": "breathalyzed",
- "breathalyser": "breathalyzer",
- "breathalysers": "breathalyzers",
- "breathalyses": "breathalyzes",
- "breathalysing": "breathalyzing",
- "brutalise": "brutalize",
- "brutalised": "brutalized",
- "brutalises": "brutalizes",
- "brutalising": "brutalizing",
- "busses": "buses",
- "bussing": "busing",
- "caesarean": "cesarean",
- "caesareans": "cesareans",
- "calibre": "caliber",
- "calibres": "calibers",
- "calliper": "caliper",
- "callipers": "calipers",
- "callisthenics": "calisthenics",
- "canalise": "canalize",
- "canalised": "canalized",
- "canalises": "canalizes",
- "canalising": "canalizing",
- "cancelation": "cancellation",
- "cancelations": "cancellations",
- "cancelled": "canceled",
- "cancelling": "canceling",
- "candour": "candor",
- "cannibalise": "cannibalize",
- "cannibalised": "cannibalized",
- "cannibalises": "cannibalizes",
- "cannibalising": "cannibalizing",
- "canonise": "canonize",
- "canonised": "canonized",
- "canonises": "canonizes",
- "canonising": "canonizing",
- "capitalise": "capitalize",
- "capitalised": "capitalized",
- "capitalises": "capitalizes",
- "capitalising": "capitalizing",
- "caramelise": "caramelize",
- "caramelised": "caramelized",
- "caramelises": "caramelizes",
- "caramelising": "caramelizing",
- "carbonise": "carbonize",
- "carbonised": "carbonized",
- "carbonises": "carbonizes",
- "carbonising": "carbonizing",
- "carolled": "caroled",
- "carolling": "caroling",
- "catalogue": "catalog",
- "catalogued": "cataloged",
- "catalogues": "catalogs",
- "cataloguing": "cataloging",
- "catalyse": "catalyze",
- "catalysed": "catalyzed",
- "catalyses": "catalyzes",
- "catalysing": "catalyzing",
- "categorise": "categorize",
- "categorised": "categorized",
- "categorises": "categorizes",
- "categorising": "categorizing",
- "cauterise": "cauterize",
- "cauterised": "cauterized",
- "cauterises": "cauterizes",
- "cauterising": "cauterizing",
- "cavilled": "caviled",
- "cavilling": "caviling",
- "centigramme": "centigram",
- "centigrammes": "centigrams",
- "centilitre": "centiliter",
- "centilitres": "centiliters",
- "centimetre": "centimeter",
- "centimetres": "centimeters",
- "centralise": "centralize",
- "centralised": "centralized",
- "centralises": "centralizes",
- "centralising": "centralizing",
- "centre": "center",
- "centred": "centered",
- "centrefold": "centerfold",
- "centrefolds": "centerfolds",
- "centrepiece": "centerpiece",
- "centrepieces": "centerpieces",
- "centres": "centers",
- "channelled": "channeled",
- "channelling": "channeling",
- "characterise": "characterize",
- "characterised": "characterized",
- "characterises": "characterizes",
- "characterising": "characterizing",
- "cheque": "check",
- "chequebook": "checkbook",
- "chequebooks": "checkbooks",
- "chequered": "checkered",
- "cheques": "checks",
- "chilli": "chili",
- "chimaera": "chimera",
- "chimaeras": "chimeras",
- "chiselled": "chiseled",
- "chiselling": "chiseling",
- "circularise": "circularize",
- "circularised": "circularized",
- "circularises": "circularizes",
- "circularising": "circularizing",
- "civilise": "civilize",
- "civilised": "civilized",
- "civilises": "civilizes",
- "civilising": "civilizing",
- "clamour": "clamor",
- "clamoured": "clamored",
- "clamouring": "clamoring",
- "clamours": "clamors",
- "clangour": "clangor",
- "clarinettist": "clarinetist",
- "clarinettists": "clarinetists",
- "collectivise": "collectivize",
- "collectivised": "collectivized",
- "collectivises": "collectivizes",
- "collectivising": "collectivizing",
- "colonisation": "colonization",
- "colonise": "colonize",
- "colonised": "colonized",
- "coloniser": "colonizer",
- "colonisers": "colonizers",
- "colonises": "colonizes",
- "colonising": "colonizing",
- "colour": "color",
- "colourant": "colorant",
- "colourants": "colorants",
- "coloured": "colored",
- "coloureds": "coloreds",
- "colourful": "colorful",
- "colourfully": "colorfully",
- "colouring": "coloring",
- "colourize": "colorize",
- "colourized": "colorized",
- "colourizes": "colorizes",
- "colourizing": "colorizing",
- "colourless": "colorless",
- "colours": "colors",
- "commercialise": "commercialize",
- "commercialised": "commercialized",
- "commercialises": "commercializes",
- "commercialising": "commercializing",
- "compartmentalise": "compartmentalize",
- "compartmentalised": "compartmentalized",
- "compartmentalises": "compartmentalizes",
- "compartmentalising": "compartmentalizing",
- "computerise": "computerize",
- "computerised": "computerized",
- "computerises": "computerizes",
- "computerising": "computerizing",
- "conceptualise": "conceptualize",
- "conceptualised": "conceptualized",
- "conceptualises": "conceptualizes",
- "conceptualising": "conceptualizing",
- "connexion": "connection",
- "connexions": "connections",
- "contextualise": "contextualize",
- "contextualised": "contextualized",
- "contextualises": "contextualizes",
- "contextualising": "contextualizing",
- "cosier": "cozier",
- "cosies": "cozies",
- "cosiest": "coziest",
- "cosily": "cozily",
- "cosiness": "coziness",
- "cosy": "cozy",
- "councillor": "councilor",
- "councillors": "councilors",
- "counselled": "counseled",
- "counselling": "counseling",
- "counsellor": "counselor",
- "counsellors": "counselors",
- "crenelated": "crenellated",
- "criminalise": "criminalize",
- "criminalised": "criminalized",
- "criminalises": "criminalizes",
- "criminalising": "criminalizing",
- "criticise": "criticize",
- "criticised": "criticized",
- "criticises": "criticizes",
- "criticising": "criticizing",
- "crueller": "crueler",
- "cruellest": "cruelest",
- "crystallisation": "crystallization",
- "crystallise": "crystallize",
- "crystallised": "crystallized",
- "crystallises": "crystallizes",
- "crystallising": "crystallizing",
- "cudgelled": "cudgeled",
- "cudgelling": "cudgeling",
- "customise": "customize",
- "customised": "customized",
- "customises": "customizes",
- "customising": "customizing",
- "cypher": "cipher",
- "cyphers": "ciphers",
- "decentralisation": "decentralization",
- "decentralise": "decentralize",
- "decentralised": "decentralized",
- "decentralises": "decentralizes",
- "decentralising": "decentralizing",
- "decriminalisation": "decriminalization",
- "decriminalise": "decriminalize",
- "decriminalised": "decriminalized",
- "decriminalises": "decriminalizes",
- "decriminalising": "decriminalizing",
- "defence": "defense",
- "defenceless": "defenseless",
- "defences": "defenses",
- "dehumanisation": "dehumanization",
- "dehumanise": "dehumanize",
- "dehumanised": "dehumanized",
- "dehumanises": "dehumanizes",
- "dehumanising": "dehumanizing",
- "demeanour": "demeanor",
- "demilitarisation": "demilitarization",
- "demilitarise": "demilitarize",
- "demilitarised": "demilitarized",
- "demilitarises": "demilitarizes",
- "demilitarising": "demilitarizing",
- "demobilisation": "demobilization",
- "demobilise": "demobilize",
- "demobilised": "demobilized",
- "demobilises": "demobilizes",
- "demobilising": "demobilizing",
- "democratisation": "democratization",
- "democratise": "democratize",
- "democratised": "democratized",
- "democratises": "democratizes",
- "democratising": "democratizing",
- "demonise": "demonize",
- "demonised": "demonized",
- "demonises": "demonizes",
- "demonising": "demonizing",
- "demoralisation": "demoralization",
- "demoralise": "demoralize",
- "demoralised": "demoralized",
- "demoralises": "demoralizes",
- "demoralising": "demoralizing",
- "denationalisation": "denationalization",
- "denationalise": "denationalize",
- "denationalised": "denationalized",
- "denationalises": "denationalizes",
- "denationalising": "denationalizing",
- "deodorise": "deodorize",
- "deodorised": "deodorized",
- "deodorises": "deodorizes",
- "deodorising": "deodorizing",
- "depersonalise": "depersonalize",
- "depersonalised": "depersonalized",
- "depersonalises": "depersonalizes",
- "depersonalising": "depersonalizing",
- "deputise": "deputize",
- "deputised": "deputized",
- "deputises": "deputizes",
- "deputising": "deputizing",
- "desensitisation": "desensitization",
- "desensitise": "desensitize",
- "desensitised": "desensitized",
- "desensitises": "desensitizes",
- "desensitising": "desensitizing",
- "destabilisation": "destabilization",
- "destabilise": "destabilize",
- "destabilised": "destabilized",
- "destabilises": "destabilizes",
- "destabilising": "destabilizing",
- "dialled": "dialed",
- "dialling": "dialing",
- "dialogue": "dialog",
- "dialogues": "dialogs",
- "diarrhoea": "diarrhea",
- "digitise": "digitize",
- "digitised": "digitized",
- "digitises": "digitizes",
- "digitising": "digitizing",
- "disc": "disk",
- "discolour": "discolor",
- "discoloured": "discolored",
- "discolouring": "discoloring",
- "discolours": "discolors",
- "discs": "disks",
- "disembowelled": "disemboweled",
- "disembowelling": "disemboweling",
- "disfavour": "disfavor",
- "dishevelled": "disheveled",
- "dishonour": "dishonor",
- "dishonourable": "dishonorable",
- "dishonourably": "dishonorably",
- "dishonoured": "dishonored",
- "dishonouring": "dishonoring",
- "dishonours": "dishonors",
- "disorganisation": "disorganization",
- "disorganised": "disorganized",
- "distil": "distill",
- "distils": "distills",
- "dramatisation": "dramatization",
- "dramatisations": "dramatizations",
- "dramatise": "dramatize",
- "dramatised": "dramatized",
- "dramatises": "dramatizes",
- "dramatising": "dramatizing",
- "draught": "draft",
- "draughtboard": "draftboard",
- "draughtboards": "draftboards",
- "draughtier": "draftier",
- "draughtiest": "draftiest",
- "draughts": "drafts",
- "draughtsman": "draftsman",
- "draughtsmanship": "draftsmanship",
- "draughtsmen": "draftsmen",
- "draughtswoman": "draftswoman",
- "draughtswomen": "draftswomen",
- "draughty": "drafty",
- "drivelled": "driveled",
- "drivelling": "driveling",
- "duelled": "dueled",
- "duelling": "dueling",
- "economise": "economize",
- "economised": "economized",
- "economises": "economizes",
- "economising": "economizing",
- "edoema": "edema",
- "editorialise": "editorialize",
- "editorialised": "editorialized",
- "editorialises": "editorializes",
- "editorialising": "editorializing",
- "empathise": "empathize",
- "empathised": "empathized",
- "empathises": "empathizes",
- "empathising": "empathizing",
- "emphasise": "emphasize",
- "emphasised": "emphasized",
- "emphasises": "emphasizes",
- "emphasising": "emphasizing",
- "enamelled": "enameled",
- "enamelling": "enameling",
- "enamoured": "enamored",
- "encyclopaedia": "encyclopedia",
- "encyclopaedias": "encyclopedias",
- "encyclopaedic": "encyclopedic",
- "endeavour": "endeavor",
- "endeavoured": "endeavored",
- "endeavouring": "endeavoring",
- "endeavours": "endeavors",
- "energise": "energize",
- "energised": "energized",
- "energises": "energizes",
- "energising": "energizing",
- "enrol": "enroll",
- "enrols": "enrolls",
- "enthral": "enthrall",
- "enthrals": "enthralls",
- "epaulette": "epaulet",
- "epaulettes": "epaulets",
- "epicentre": "epicenter",
- "epicentres": "epicenters",
- "epilogue": "epilog",
- "epilogues": "epilogs",
- "epitomise": "epitomize",
- "epitomised": "epitomized",
- "epitomises": "epitomizes",
- "epitomising": "epitomizing",
- "equalisation": "equalization",
- "equalise": "equalize",
- "equalised": "equalized",
- "equaliser": "equalizer",
- "equalisers": "equalizers",
- "equalises": "equalizes",
- "equalising": "equalizing",
- "eulogise": "eulogize",
- "eulogised": "eulogized",
- "eulogises": "eulogizes",
- "eulogising": "eulogizing",
- "evangelise": "evangelize",
- "evangelised": "evangelized",
- "evangelises": "evangelizes",
- "evangelising": "evangelizing",
- "exorcise": "exorcize",
- "exorcised": "exorcized",
- "exorcises": "exorcizes",
- "exorcising": "exorcizing",
- "extemporisation": "extemporization",
- "extemporise": "extemporize",
- "extemporised": "extemporized",
- "extemporises": "extemporizes",
- "extemporising": "extemporizing",
- "externalisation": "externalization",
- "externalisations": "externalizations",
- "externalise": "externalize",
- "externalised": "externalized",
- "externalises": "externalizes",
- "externalising": "externalizing",
- "factorise": "factorize",
- "factorised": "factorized",
- "factorises": "factorizes",
- "factorising": "factorizing",
- "faecal": "fecal",
- "faeces": "feces",
- "familiarisation": "familiarization",
- "familiarise": "familiarize",
- "familiarised": "familiarized",
- "familiarises": "familiarizes",
- "familiarising": "familiarizing",
- "fantasise": "fantasize",
- "fantasised": "fantasized",
- "fantasises": "fantasizes",
- "fantasising": "fantasizing",
- "favour": "favor",
- "favourable": "favorable",
- "favourably": "favorably",
- "favoured": "favored",
- "favouring": "favoring",
- "favourite": "favorite",
- "favourites": "favorites",
- "favouritism": "favoritism",
- "favours": "favors",
- "feminise": "feminize",
- "feminised": "feminized",
- "feminises": "feminizes",
- "feminising": "feminizing",
- "fertilisation": "fertilization",
- "fertilise": "fertilize",
- "fertilised": "fertilized",
- "fertiliser": "fertilizer",
- "fertilisers": "fertilizers",
- "fertilises": "fertilizes",
- "fertilising": "fertilizing",
- "fervour": "fervor",
- "fibre": "fiber",
- "fibreglass": "fiberglass",
- "fibres": "fibers",
- "fictionalisation": "fictionalization",
- "fictionalisations": "fictionalizations",
- "fictionalise": "fictionalize",
- "fictionalised": "fictionalized",
- "fictionalises": "fictionalizes",
- "fictionalising": "fictionalizing",
- "fillet": "filet",
- "filleted": "fileted",
- "filleting": "fileting",
- "fillets": "filets",
- "finalisation": "finalization",
- "finalise": "finalize",
- "finalised": "finalized",
- "finalises": "finalizes",
- "finalising": "finalizing",
- "flautist": "flutist",
- "flautists": "flutists",
- "flavour": "flavor",
- "flavoured": "flavored",
- "flavouring": "flavoring",
- "flavourings": "flavorings",
- "flavourless": "flavorless",
- "flavours": "flavors",
- "flavoursome": "flavorsome",
- "flyer / flier": "flier / flyer",
- "foetal": "fetal",
- "foetid": "fetid",
- "foetus": "fetus",
- "foetuses": "fetuses",
- "formalisation": "formalization",
- "formalise": "formalize",
- "formalised": "formalized",
- "formalises": "formalizes",
- "formalising": "formalizing",
- "fossilisation": "fossilization",
- "fossilise": "fossilize",
- "fossilised": "fossilized",
- "fossilises": "fossilizes",
- "fossilising": "fossilizing",
- "fraternisation": "fraternization",
- "fraternise": "fraternize",
- "fraternised": "fraternized",
- "fraternises": "fraternizes",
- "fraternising": "fraternizing",
- "fulfil": "fulfill",
- "fulfilment": "fulfillment",
- "fulfils": "fulfills",
- "funnelled": "funneled",
- "funnelling": "funneling",
- "galvanise": "galvanize",
- "galvanised": "galvanized",
- "galvanises": "galvanizes",
- "galvanising": "galvanizing",
- "gambolled": "gamboled",
- "gambolling": "gamboling",
- "gaol": "jail",
- "gaolbird": "jailbird",
- "gaolbirds": "jailbirds",
- "gaolbreak": "jailbreak",
- "gaolbreaks": "jailbreaks",
- "gaoled": "jailed",
- "gaoler": "jailer",
- "gaolers": "jailers",
- "gaoling": "jailing",
- "gaols": "jails",
- "gasses": "gases",
- "gage": "gauge",
- "gaged": "gauged",
- "gages": "gauges",
- "gaging": "gauging",
- "generalisation": "generalization",
- "generalisations": "generalizations",
- "generalise": "generalize",
- "generalised": "generalized",
- "generalises": "generalizes",
- "generalising": "generalizing",
- "ghettoise": "ghettoize",
- "ghettoised": "ghettoized",
- "ghettoises": "ghettoizes",
- "ghettoising": "ghettoizing",
- "gipsies": "gypsies",
- "glamorise": "glamorize",
- "glamorised": "glamorized",
- "glamorises": "glamorizes",
- "glamorising": "glamorizing",
- "glamor": "glamour",
- "globalisation": "globalization",
- "globalise": "globalize",
- "globalised": "globalized",
- "globalises": "globalizes",
- "globalising": "globalizing",
- "glueing": "gluing",
- "goitre": "goiter",
- "goitres": "goiters",
- "gonorrhoea": "gonorrhea",
- "gramme": "gram",
- "grammes": "grams",
- "gravelled": "graveled",
- "grey": "gray",
- "greyed": "grayed",
- "greying": "graying",
- "greyish": "grayish",
- "greyness": "grayness",
- "greys": "grays",
- "grovelled": "groveled",
- "grovelling": "groveling",
- "groyne": "groin",
- "groynes": "groins",
- "gruelling": "grueling",
- "gruellingly": "gruelingly",
- "gryphon": "griffin",
- "gryphons": "griffins",
- "gynaecological": "gynecological",
- "gynaecologist": "gynecologist",
- "gynaecologists": "gynecologists",
- "gynaecology": "gynecology",
- "haematological": "hematological",
- "haematologist": "hematologist",
- "haematologists": "hematologists",
- "haematology": "hematology",
- "haemoglobin": "hemoglobin",
- "haemophilia": "hemophilia",
- "haemophiliac": "hemophiliac",
- "haemophiliacs": "hemophiliacs",
- "haemorrhage": "hemorrhage",
- "haemorrhaged": "hemorrhaged",
- "haemorrhages": "hemorrhages",
- "haemorrhaging": "hemorrhaging",
- "haemorrhoids": "hemorrhoids",
- "harbour": "harbor",
- "harboured": "harbored",
- "harbouring": "harboring",
- "harbours": "harbors",
- "harmonisation": "harmonization",
- "harmonise": "harmonize",
- "harmonised": "harmonized",
- "harmonises": "harmonizes",
- "harmonising": "harmonizing",
- "homoeopath": "homeopath",
- "homoeopathic": "homeopathic",
- "homoeopaths": "homeopaths",
- "homoeopathy": "homeopathy",
- "homogenise": "homogenize",
- "homogenised": "homogenized",
- "homogenises": "homogenizes",
- "homogenising": "homogenizing",
- "honour": "honor",
- "honourable": "honorable",
- "honourably": "honorably",
- "honoured": "honored",
- "honouring": "honoring",
- "honours": "honors",
- "hospitalisation": "hospitalization",
- "hospitalise": "hospitalize",
- "hospitalised": "hospitalized",
- "hospitalises": "hospitalizes",
- "hospitalising": "hospitalizing",
- "humanise": "humanize",
- "humanised": "humanized",
- "humanises": "humanizes",
- "humanising": "humanizing",
- "humour": "humor",
- "humoured": "humored",
- "humouring": "humoring",
- "humourless": "humorless",
- "humours": "humors",
- "hybridise": "hybridize",
- "hybridised": "hybridized",
- "hybridises": "hybridizes",
- "hybridising": "hybridizing",
- "hypnotise": "hypnotize",
- "hypnotised": "hypnotized",
- "hypnotises": "hypnotizes",
- "hypnotising": "hypnotizing",
- "hypothesise": "hypothesize",
- "hypothesised": "hypothesized",
- "hypothesises": "hypothesizes",
- "hypothesising": "hypothesizing",
- "idealisation": "idealization",
- "idealise": "idealize",
- "idealised": "idealized",
- "idealises": "idealizes",
- "idealising": "idealizing",
- "idolise": "idolize",
- "idolised": "idolized",
- "idolises": "idolizes",
- "idolising": "idolizing",
- "immobilisation": "immobilization",
- "immobilise": "immobilize",
- "immobilised": "immobilized",
- "immobiliser": "immobilizer",
- "immobilisers": "immobilizers",
- "immobilises": "immobilizes",
- "immobilising": "immobilizing",
- "immortalise": "immortalize",
- "immortalised": "immortalized",
- "immortalises": "immortalizes",
- "immortalising": "immortalizing",
- "immunisation": "immunization",
- "immunise": "immunize",
- "immunised": "immunized",
- "immunises": "immunizes",
- "immunising": "immunizing",
- "impanelled": "impaneled",
- "impanelling": "impaneling",
- "imperilled": "imperiled",
- "imperilling": "imperiling",
- "individualise": "individualize",
- "individualised": "individualized",
- "individualises": "individualizes",
- "individualising": "individualizing",
- "industrialise": "industrialize",
- "industrialised": "industrialized",
- "industrialises": "industrializes",
- "industrialising": "industrializing",
- "inflexion": "inflection",
- "inflexions": "inflections",
- "initialise": "initialize",
- "initialised": "initialized",
- "initialises": "initializes",
- "initialising": "initializing",
- "initialled": "initialed",
- "initialling": "initialing",
- "instal": "install",
- "instalment": "installment",
- "instalments": "installments",
- "instals": "installs",
- "instil": "instill",
- "instils": "instills",
- "institutionalisation": "institutionalization",
- "institutionalise": "institutionalize",
- "institutionalised": "institutionalized",
- "institutionalises": "institutionalizes",
- "institutionalising": "institutionalizing",
- "intellectualise": "intellectualize",
- "intellectualised": "intellectualized",
- "intellectualises": "intellectualizes",
- "intellectualising": "intellectualizing",
- "internalisation": "internalization",
- "internalise": "internalize",
- "internalised": "internalized",
- "internalises": "internalizes",
- "internalising": "internalizing",
- "internationalisation": "internationalization",
- "internationalise": "internationalize",
- "internationalised": "internationalized",
- "internationalises": "internationalizes",
- "internationalising": "internationalizing",
- "ionisation": "ionization",
- "ionise": "ionize",
- "ionised": "ionized",
- "ioniser": "ionizer",
- "ionisers": "ionizers",
- "ionises": "ionizes",
- "ionising": "ionizing",
- "italicise": "italicize",
- "italicised": "italicized",
- "italicises": "italicizes",
- "italicising": "italicizing",
- "itemise": "itemize",
- "itemised": "itemized",
- "itemises": "itemizes",
- "itemising": "itemizing",
- "jeopardise": "jeopardize",
- "jeopardised": "jeopardized",
- "jeopardises": "jeopardizes",
- "jeopardising": "jeopardizing",
- "jewelled": "jeweled",
- "jeweller": "jeweler",
- "jewellers": "jewelers",
- "jewellery": "jewelry",
- "judgement": "judgment",
- "kilogramme": "kilogram",
- "kilogrammes": "kilograms",
- "kilometre": "kilometer",
- "kilometres": "kilometers",
- "labelled": "labeled",
- "labelling": "labeling",
- "labour": "labor",
- "laboured": "labored",
- "labourer": "laborer",
- "labourers": "laborers",
- "labouring": "laboring",
- "labours": "labors",
- "lacklustre": "lackluster",
- "legalisation": "legalization",
- "legalise": "legalize",
- "legalised": "legalized",
- "legalises": "legalizes",
- "legalising": "legalizing",
- "legitimise": "legitimize",
- "legitimised": "legitimized",
- "legitimises": "legitimizes",
- "legitimising": "legitimizing",
- "leukaemia": "leukemia",
- "levelled": "leveled",
- "leveller": "leveler",
- "levellers": "levelers",
- "levelling": "leveling",
- "libelled": "libeled",
- "libelling": "libeling",
- "libellous": "libelous",
- "liberalisation": "liberalization",
- "liberalise": "liberalize",
- "liberalised": "liberalized",
- "liberalises": "liberalizes",
- "liberalising": "liberalizing",
- "licence": "license",
- "licenced": "licensed",
- "licences": "licenses",
- "licencing": "licensing",
- "likeable": "likable",
- "lionisation": "lionization",
- "lionise": "lionize",
- "lionised": "lionized",
- "lionises": "lionizes",
- "lionising": "lionizing",
- "liquidise": "liquidize",
- "liquidised": "liquidized",
- "liquidiser": "liquidizer",
- "liquidisers": "liquidizers",
- "liquidises": "liquidizes",
- "liquidising": "liquidizing",
- "litre": "liter",
- "litres": "liters",
- "localise": "localize",
- "localised": "localized",
- "localises": "localizes",
- "localising": "localizing",
- "louvre": "louver",
- "louvred": "louvered",
- "louvres": "louvers",
- "lustre": "luster",
- "magnetise": "magnetize",
- "magnetised": "magnetized",
- "magnetises": "magnetizes",
- "magnetising": "magnetizing",
- "manoeuvrability": "maneuverability",
- "manoeuvrable": "maneuverable",
- "manoeuvre": "maneuver",
- "manoeuvred": "maneuvered",
- "manoeuvres": "maneuvers",
- "manoeuvring": "maneuvering",
- "manoeuvrings": "maneuverings",
- "marginalisation": "marginalization",
- "marginalise": "marginalize",
- "marginalised": "marginalized",
- "marginalises": "marginalizes",
- "marginalising": "marginalizing",
- "marshalled": "marshaled",
- "marshalling": "marshaling",
- "marvelled": "marveled",
- "marvelling": "marveling",
- "marvellous": "marvelous",
- "marvellously": "marvelously",
- "materialisation": "materialization",
- "materialise": "materialize",
- "materialised": "materialized",
- "materialises": "materializes",
- "materialising": "materializing",
- "maximisation": "maximization",
- "maximise": "maximize",
- "maximised": "maximized",
- "maximises": "maximizes",
- "maximising": "maximizing",
- "meagre": "meager",
- "mechanisation": "mechanization",
- "mechanise": "mechanize",
- "mechanised": "mechanized",
- "mechanises": "mechanizes",
- "mechanising": "mechanizing",
- "mediaeval": "medieval",
- "memorialise": "memorialize",
- "memorialised": "memorialized",
- "memorialises": "memorializes",
- "memorialising": "memorializing",
- "memorise": "memorize",
- "memorised": "memorized",
- "memorises": "memorizes",
- "memorising": "memorizing",
- "mesmerise": "mesmerize",
- "mesmerised": "mesmerized",
- "mesmerises": "mesmerizes",
- "mesmerising": "mesmerizing",
- "metabolise": "metabolize",
- "metabolised": "metabolized",
- "metabolises": "metabolizes",
- "metabolising": "metabolizing",
- "metre": "meter",
- "metres": "meters",
- "micrometre": "micrometer",
- "micrometres": "micrometers",
- "militarise": "militarize",
- "militarised": "militarized",
- "militarises": "militarizes",
- "militarising": "militarizing",
- "milligramme": "milligram",
- "milligrammes": "milligrams",
- "millilitre": "milliliter",
- "millilitres": "milliliters",
- "millimetre": "millimeter",
- "millimetres": "millimeters",
- "miniaturisation": "miniaturization",
- "miniaturise": "miniaturize",
- "miniaturised": "miniaturized",
- "miniaturises": "miniaturizes",
- "miniaturising": "miniaturizing",
- "minibusses": "minibuses",
- "minimise": "minimize",
- "minimised": "minimized",
- "minimises": "minimizes",
- "minimising": "minimizing",
- "misbehaviour": "misbehavior",
- "misdemeanour": "misdemeanor",
- "misdemeanours": "misdemeanors",
- "misspelt": "misspelled",
- "mitre": "miter",
- "mitres": "miters",
- "mobilisation": "mobilization",
- "mobilise": "mobilize",
- "mobilised": "mobilized",
- "mobilises": "mobilizes",
- "mobilising": "mobilizing",
- "modelled": "modeled",
- "modeller": "modeler",
- "modellers": "modelers",
- "modelling": "modeling",
- "modernise": "modernize",
- "modernised": "modernized",
- "modernises": "modernizes",
- "modernising": "modernizing",
- "moisturise": "moisturize",
- "moisturised": "moisturized",
- "moisturiser": "moisturizer",
- "moisturisers": "moisturizers",
- "moisturises": "moisturizes",
- "moisturising": "moisturizing",
- "monologue": "monolog",
- "monologues": "monologs",
- "monopolisation": "monopolization",
- "monopolise": "monopolize",
- "monopolised": "monopolized",
- "monopolises": "monopolizes",
- "monopolising": "monopolizing",
- "moralise": "moralize",
- "moralised": "moralized",
- "moralises": "moralizes",
- "moralising": "moralizing",
- "motorised": "motorized",
- "mould": "mold",
- "moulded": "molded",
- "moulder": "molder",
- "mouldered": "moldered",
- "mouldering": "moldering",
- "moulders": "molders",
- "mouldier": "moldier",
- "mouldiest": "moldiest",
- "moulding": "molding",
- "mouldings": "moldings",
- "moulds": "molds",
- "mouldy": "moldy",
- "moult": "molt",
- "moulted": "molted",
- "moulting": "molting",
- "moults": "molts",
- "moustache": "mustache",
- "moustached": "mustached",
- "moustaches": "mustaches",
- "moustachioed": "mustachioed",
- "multicoloured": "multicolored",
- "nationalisation": "nationalization",
- "nationalisations": "nationalizations",
- "nationalise": "nationalize",
- "nationalised": "nationalized",
- "nationalises": "nationalizes",
- "nationalising": "nationalizing",
- "naturalisation": "naturalization",
- "naturalise": "naturalize",
- "naturalised": "naturalized",
- "naturalises": "naturalizes",
- "naturalising": "naturalizing",
- "neighbour": "neighbor",
- "neighbourhood": "neighborhood",
- "neighbourhoods": "neighborhoods",
- "neighbouring": "neighboring",
- "neighbourliness": "neighborliness",
- "neighbourly": "neighborly",
- "neighbours": "neighbors",
- "neutralisation": "neutralization",
- "neutralise": "neutralize",
- "neutralised": "neutralized",
- "neutralises": "neutralizes",
- "neutralising": "neutralizing",
- "normalisation": "normalization",
- "normalise": "normalize",
- "normalised": "normalized",
- "normalises": "normalizes",
- "normalising": "normalizing",
- "odour": "odor",
- "odourless": "odorless",
- "odours": "odors",
- "oesophagus": "esophagus",
- "oesophaguses": "esophaguses",
- "oestrogen": "estrogen",
- "offence": "offense",
- "offences": "offenses",
- "omelette": "omelet",
- "omelettes": "omelets",
- "optimise": "optimize",
- "optimised": "optimized",
- "optimises": "optimizes",
- "optimising": "optimizing",
- "organisation": "organization",
- "organisational": "organizational",
- "organisations": "organizations",
- "organise": "organize",
- "organised": "organized",
- "organiser": "organizer",
- "organisers": "organizers",
- "organises": "organizes",
- "organising": "organizing",
- "orthopaedic": "orthopedic",
- "orthopaedics": "orthopedics",
- "ostracise": "ostracize",
- "ostracised": "ostracized",
- "ostracises": "ostracizes",
- "ostracising": "ostracizing",
- "outmanoeuvre": "outmaneuver",
- "outmanoeuvred": "outmaneuvered",
- "outmanoeuvres": "outmaneuvers",
- "outmanoeuvring": "outmaneuvering",
- "overemphasise": "overemphasize",
- "overemphasised": "overemphasized",
- "overemphasises": "overemphasizes",
- "overemphasising": "overemphasizing",
- "oxidisation": "oxidization",
- "oxidise": "oxidize",
- "oxidised": "oxidized",
- "oxidises": "oxidizes",
- "oxidising": "oxidizing",
- "paederast": "pederast",
- "paederasts": "pederasts",
- "paediatric": "pediatric",
- "paediatrician": "pediatrician",
- "paediatricians": "pediatricians",
- "paediatrics": "pediatrics",
- "paedophile": "pedophile",
- "paedophiles": "pedophiles",
- "paedophilia": "pedophilia",
- "palaeolithic": "paleolithic",
- "palaeontologist": "paleontologist",
- "palaeontologists": "paleontologists",
- "palaeontology": "paleontology",
- "panelled": "paneled",
- "panelling": "paneling",
- "panellist": "panelist",
- "panellists": "panelists",
- "paralyse": "paralyze",
- "paralysed": "paralyzed",
- "paralyses": "paralyzes",
- "paralysing": "paralyzing",
- "parcelled": "parceled",
- "parcelling": "parceling",
- "parlour": "parlor",
- "parlours": "parlors",
- "particularise": "particularize",
- "particularised": "particularized",
- "particularises": "particularizes",
- "particularising": "particularizing",
- "passivisation": "passivization",
- "passivise": "passivize",
- "passivised": "passivized",
- "passivises": "passivizes",
- "passivising": "passivizing",
- "pasteurisation": "pasteurization",
- "pasteurise": "pasteurize",
- "pasteurised": "pasteurized",
- "pasteurises": "pasteurizes",
- "pasteurising": "pasteurizing",
- "patronise": "patronize",
- "patronised": "patronized",
- "patronises": "patronizes",
- "patronising": "patronizing",
- "patronisingly": "patronizingly",
- "pedalled": "pedaled",
- "pedalling": "pedaling",
- "pedestrianisation": "pedestrianization",
- "pedestrianise": "pedestrianize",
- "pedestrianised": "pedestrianized",
- "pedestrianises": "pedestrianizes",
- "pedestrianising": "pedestrianizing",
- "penalise": "penalize",
- "penalised": "penalized",
- "penalises": "penalizes",
- "penalising": "penalizing",
- "pencilled": "penciled",
- "pencilling": "penciling",
- "personalise": "personalize",
- "personalised": "personalized",
- "personalises": "personalizes",
- "personalising": "personalizing",
- "pharmacopoeia": "pharmacopeia",
- "pharmacopoeias": "pharmacopeias",
- "philosophise": "philosophize",
- "philosophised": "philosophized",
- "philosophises": "philosophizes",
- "philosophising": "philosophizing",
- "philtre": "philter",
- "philtres": "philters",
- "phoney": "phony",
- "plagiarise": "plagiarize",
- "plagiarised": "plagiarized",
- "plagiarises": "plagiarizes",
- "plagiarising": "plagiarizing",
- "plough": "plow",
- "ploughed": "plowed",
- "ploughing": "plowing",
- "ploughman": "plowman",
- "ploughmen": "plowmen",
- "ploughs": "plows",
- "ploughshare": "plowshare",
- "ploughshares": "plowshares",
- "polarisation": "polarization",
- "polarise": "polarize",
- "polarised": "polarized",
- "polarises": "polarizes",
- "polarising": "polarizing",
- "politicisation": "politicization",
- "politicise": "politicize",
- "politicised": "politicized",
- "politicises": "politicizes",
- "politicising": "politicizing",
- "popularisation": "popularization",
- "popularise": "popularize",
- "popularised": "popularized",
- "popularises": "popularizes",
- "popularising": "popularizing",
- "pouffe": "pouf",
- "pouffes": "poufs",
- "practise": "practice",
- "practised": "practiced",
- "practises": "practices",
- "practising": "practicing",
- "praesidium": "presidium",
- "praesidiums": "presidiums",
- "pressurisation": "pressurization",
- "pressurise": "pressurize",
- "pressurised": "pressurized",
- "pressurises": "pressurizes",
- "pressurising": "pressurizing",
- "pretence": "pretense",
- "pretences": "pretenses",
- "primaeval": "primeval",
- "prioritisation": "prioritization",
- "prioritise": "prioritize",
- "prioritised": "prioritized",
- "prioritises": "prioritizes",
- "prioritising": "prioritizing",
- "privatisation": "privatization",
- "privatisations": "privatizations",
- "privatise": "privatize",
- "privatised": "privatized",
- "privatises": "privatizes",
- "privatising": "privatizing",
- "professionalisation": "professionalization",
- "professionalise": "professionalize",
- "professionalised": "professionalized",
- "professionalises": "professionalizes",
- "professionalising": "professionalizing",
- "programme": "program",
- "programmes": "programs",
- "prologue": "prolog",
- "prologues": "prologs",
- "propagandise": "propagandize",
- "propagandised": "propagandized",
- "propagandises": "propagandizes",
- "propagandising": "propagandizing",
- "proselytise": "proselytize",
- "proselytised": "proselytized",
- "proselytiser": "proselytizer",
- "proselytisers": "proselytizers",
- "proselytises": "proselytizes",
- "proselytising": "proselytizing",
- "psychoanalyse": "psychoanalyze",
- "psychoanalysed": "psychoanalyzed",
- "psychoanalyses": "psychoanalyzes",
- "psychoanalysing": "psychoanalyzing",
- "publicise": "publicize",
- "publicised": "publicized",
- "publicises": "publicizes",
- "publicising": "publicizing",
- "pulverisation": "pulverization",
- "pulverise": "pulverize",
- "pulverised": "pulverized",
- "pulverises": "pulverizes",
- "pulverising": "pulverizing",
- "pummelled": "pummeled",
- "pummelling": "pummeling",
- "pyjama": "pajama",
- "pyjamas": "pajamas",
- "pzazz": "pizzazz",
- "quarrelled": "quarreled",
- "quarrelling": "quarreling",
- "radicalise": "radicalize",
- "radicalised": "radicalized",
- "radicalises": "radicalizes",
- "radicalising": "radicalizing",
- "rancour": "rancor",
- "randomise": "randomize",
- "randomised": "randomized",
- "randomises": "randomizes",
- "randomising": "randomizing",
- "rationalisation": "rationalization",
- "rationalisations": "rationalizations",
- "rationalise": "rationalize",
- "rationalised": "rationalized",
- "rationalises": "rationalizes",
- "rationalising": "rationalizing",
- "ravelled": "raveled",
- "ravelling": "raveling",
- "realisable": "realizable",
- "realisation": "realization",
- "realisations": "realizations",
- "realise": "realize",
- "realised": "realized",
- "realises": "realizes",
- "realising": "realizing",
- "recognisable": "recognizable",
- "recognisably": "recognizably",
- "recognisance": "recognizance",
- "recognise": "recognize",
- "recognised": "recognized",
- "recognises": "recognizes",
- "recognising": "recognizing",
- "reconnoitre": "reconnoiter",
- "reconnoitred": "reconnoitered",
- "reconnoitres": "reconnoiters",
- "reconnoitring": "reconnoitering",
- "refuelled": "refueled",
- "refuelling": "refueling",
- "regularisation": "regularization",
- "regularise": "regularize",
- "regularised": "regularized",
- "regularises": "regularizes",
- "regularising": "regularizing",
- "remodelled": "remodeled",
- "remodelling": "remodeling",
- "remould": "remold",
- "remoulded": "remolded",
- "remoulding": "remolding",
- "remoulds": "remolds",
- "reorganisation": "reorganization",
- "reorganisations": "reorganizations",
- "reorganise": "reorganize",
- "reorganised": "reorganized",
- "reorganises": "reorganizes",
- "reorganising": "reorganizing",
- "revelled": "reveled",
- "reveller": "reveler",
- "revellers": "revelers",
- "revelling": "reveling",
- "revitalise": "revitalize",
- "revitalised": "revitalized",
- "revitalises": "revitalizes",
- "revitalising": "revitalizing",
- "revolutionise": "revolutionize",
- "revolutionised": "revolutionized",
- "revolutionises": "revolutionizes",
- "revolutionising": "revolutionizing",
- "rhapsodise": "rhapsodize",
- "rhapsodised": "rhapsodized",
- "rhapsodises": "rhapsodizes",
- "rhapsodising": "rhapsodizing",
- "rigour": "rigor",
- "rigours": "rigors",
- "ritualised": "ritualized",
- "rivalled": "rivaled",
- "rivalling": "rivaling",
- "romanticise": "romanticize",
- "romanticised": "romanticized",
- "romanticises": "romanticizes",
- "romanticising": "romanticizing",
- "rumour": "rumor",
- "rumoured": "rumored",
- "rumours": "rumors",
- "sabre": "saber",
- "sabres": "sabers",
- "saltpetre": "saltpeter",
- "sanitise": "sanitize",
- "sanitised": "sanitized",
- "sanitises": "sanitizes",
- "sanitising": "sanitizing",
- "satirise": "satirize",
- "satirised": "satirized",
- "satirises": "satirizes",
- "satirising": "satirizing",
- "saviour": "savior",
- "saviours": "saviors",
- "savour": "savor",
- "savoured": "savored",
- "savouries": "savories",
- "savouring": "savoring",
- "savours": "savors",
- "savoury": "savory",
- "scandalise": "scandalize",
- "scandalised": "scandalized",
- "scandalises": "scandalizes",
- "scandalising": "scandalizing",
- "sceptic": "skeptic",
- "sceptical": "skeptical",
- "sceptically": "skeptically",
- "scepticism": "skepticism",
- "sceptics": "skeptics",
- "sceptre": "scepter",
- "sceptres": "scepters",
- "scrutinise": "scrutinize",
- "scrutinised": "scrutinized",
- "scrutinises": "scrutinizes",
- "scrutinising": "scrutinizing",
- "secularisation": "secularization",
- "secularise": "secularize",
- "secularised": "secularized",
- "secularises": "secularizes",
- "secularising": "secularizing",
- "sensationalise": "sensationalize",
- "sensationalised": "sensationalized",
- "sensationalises": "sensationalizes",
- "sensationalising": "sensationalizing",
- "sensitise": "sensitize",
- "sensitised": "sensitized",
- "sensitises": "sensitizes",
- "sensitising": "sensitizing",
- "sentimentalise": "sentimentalize",
- "sentimentalised": "sentimentalized",
- "sentimentalises": "sentimentalizes",
- "sentimentalising": "sentimentalizing",
- "sepulchre": "sepulcher",
- "sepulchres": "sepulchers",
- "serialisation": "serialization",
- "serialisations": "serializations",
- "serialise": "serialize",
- "serialised": "serialized",
- "serialises": "serializes",
- "serialising": "serializing",
- "sermonise": "sermonize",
- "sermonised": "sermonized",
- "sermonises": "sermonizes",
- "sermonising": "sermonizing",
- "sheikh": "sheik",
- "shovelled": "shoveled",
- "shovelling": "shoveling",
- "shrivelled": "shriveled",
- "shrivelling": "shriveling",
- "signalise": "signalize",
- "signalised": "signalized",
- "signalises": "signalizes",
- "signalising": "signalizing",
- "signalled": "signaled",
- "signalling": "signaling",
- "smoulder": "smolder",
- "smouldered": "smoldered",
- "smouldering": "smoldering",
- "smoulders": "smolders",
- "snivelled": "sniveled",
- "snivelling": "sniveling",
- "snorkelled": "snorkeled",
- "snorkelling": "snorkeling",
- "snowplough": "snowplow",
- "snowploughs": "snowplows",
- "socialisation": "socialization",
- "socialise": "socialize",
- "socialised": "socialized",
- "socialises": "socializes",
- "socialising": "socializing",
- "sodomise": "sodomize",
- "sodomised": "sodomized",
- "sodomises": "sodomizes",
- "sodomising": "sodomizing",
- "solemnise": "solemnize",
- "solemnised": "solemnized",
- "solemnises": "solemnizes",
- "solemnising": "solemnizing",
- "sombre": "somber",
- "specialisation": "specialization",
- "specialisations": "specializations",
- "specialise": "specialize",
- "specialised": "specialized",
- "specialises": "specializes",
- "specialising": "specializing",
- "spectre": "specter",
- "spectres": "specters",
- "spiralled": "spiraled",
- "spiralling": "spiraling",
- "splendour": "splendor",
- "splendours": "splendors",
- "squirrelled": "squirreled",
- "squirrelling": "squirreling",
- "stabilisation": "stabilization",
- "stabilise": "stabilize",
- "stabilised": "stabilized",
- "stabiliser": "stabilizer",
- "stabilisers": "stabilizers",
- "stabilises": "stabilizes",
- "stabilising": "stabilizing",
- "standardisation": "standardization",
- "standardise": "standardize",
- "standardised": "standardized",
- "standardises": "standardizes",
- "standardising": "standardizing",
- "stencilled": "stenciled",
- "stencilling": "stenciling",
- "sterilisation": "sterilization",
- "sterilisations": "sterilizations",
- "sterilise": "sterilize",
- "sterilised": "sterilized",
- "steriliser": "sterilizer",
- "sterilisers": "sterilizers",
- "sterilises": "sterilizes",
- "sterilising": "sterilizing",
- "stigmatisation": "stigmatization",
- "stigmatise": "stigmatize",
- "stigmatised": "stigmatized",
- "stigmatises": "stigmatizes",
- "stigmatising": "stigmatizing",
- "storey": "story",
- "storeys": "stories",
- "subsidisation": "subsidization",
- "subsidise": "subsidize",
- "subsidised": "subsidized",
- "subsidiser": "subsidizer",
- "subsidisers": "subsidizers",
- "subsidises": "subsidizes",
- "subsidising": "subsidizing",
- "succour": "succor",
- "succoured": "succored",
- "succouring": "succoring",
- "succours": "succors",
- "sulphate": "sulfate",
- "sulphates": "sulfates",
- "sulphide": "sulfide",
- "sulphides": "sulfides",
- "sulphur": "sulfur",
- "sulphurous": "sulfurous",
- "summarise": "summarize",
- "summarised": "summarized",
- "summarises": "summarizes",
- "summarising": "summarizing",
- "swivelled": "swiveled",
- "swivelling": "swiveling",
- "symbolise": "symbolize",
- "symbolised": "symbolized",
- "symbolises": "symbolizes",
- "symbolising": "symbolizing",
- "sympathise": "sympathize",
- "sympathised": "sympathized",
- "sympathiser": "sympathizer",
- "sympathisers": "sympathizers",
- "sympathises": "sympathizes",
- "sympathising": "sympathizing",
- "synchronisation": "synchronization",
- "synchronise": "synchronize",
- "synchronised": "synchronized",
- "synchronises": "synchronizes",
- "synchronising": "synchronizing",
- "synthesise": "synthesize",
- "synthesised": "synthesized",
- "synthesiser": "synthesizer",
- "synthesisers": "synthesizers",
- "synthesises": "synthesizes",
- "synthesising": "synthesizing",
- "syphon": "siphon",
- "syphoned": "siphoned",
- "syphoning": "siphoning",
- "syphons": "siphons",
- "systematisation": "systematization",
- "systematise": "systematize",
- "systematised": "systematized",
- "systematises": "systematizes",
- "systematising": "systematizing",
- "tantalise": "tantalize",
- "tantalised": "tantalized",
- "tantalises": "tantalizes",
- "tantalising": "tantalizing",
- "tantalisingly": "tantalizingly",
- "tasselled": "tasseled",
- "technicolour": "technicolor",
- "temporise": "temporize",
- "temporised": "temporized",
- "temporises": "temporizes",
- "temporising": "temporizing",
- "tenderise": "tenderize",
- "tenderised": "tenderized",
- "tenderises": "tenderizes",
- "tenderising": "tenderizing",
- "terrorise": "terrorize",
- "terrorised": "terrorized",
- "terrorises": "terrorizes",
- "terrorising": "terrorizing",
- "theatre": "theater",
- "theatregoer": "theatergoer",
- "theatregoers": "theatergoers",
- "theatres": "theaters",
- "theorise": "theorize",
- "theorised": "theorized",
- "theorises": "theorizes",
- "theorising": "theorizing",
- "tonne": "ton",
- "tonnes": "tons",
- "towelled": "toweled",
- "towelling": "toweling",
- "toxaemia": "toxemia",
- "tranquillise": "tranquilize",
- "tranquillised": "tranquilized",
- "tranquilliser": "tranquilizer",
- "tranquillisers": "tranquilizers",
- "tranquillises": "tranquilizes",
- "tranquillising": "tranquilizing",
- "tranquillity": "tranquility",
- "tranquillize": "tranquilize",
- "tranquillized": "tranquilized",
- "tranquillizer": "tranquilizer",
- "tranquillizers": "tranquilizers",
- "tranquillizes": "tranquilizes",
- "tranquillizing": "tranquilizing",
- "tranquilly": "tranquilly",
- "transistorised": "transistorized",
- "traumatise": "traumatize",
- "traumatised": "traumatized",
- "traumatises": "traumatizes",
- "traumatising": "traumatizing",
- "travelled": "traveled",
- "traveller": "traveler",
- "travellers": "travelers",
- "travelling": "traveling",
- "travelog": "travelogue",
- "travelogs": "travelogues",
- "trialled": "trialed",
- "trialling": "trialing",
- "tricolour": "tricolor",
- "tricolours": "tricolors",
- "trivialise": "trivialize",
- "trivialised": "trivialized",
- "trivialises": "trivializes",
- "trivialising": "trivializing",
- "tumour": "tumor",
- "tumours": "tumors",
- "tunnelled": "tunneled",
- "tunnelling": "tunneling",
- "tyrannise": "tyrannize",
- "tyrannised": "tyrannized",
- "tyrannises": "tyrannizes",
- "tyrannising": "tyrannizing",
- "tyre": "tire",
- "tyres": "tires",
- "unauthorised": "unauthorized",
- "uncivilised": "uncivilized",
- "underutilised": "underutilized",
- "unequalled": "unequaled",
- "unfavourable": "unfavorable",
- "unfavourably": "unfavorably",
- "unionisation": "unionization",
- "unionise": "unionize",
- "unionised": "unionized",
- "unionises": "unionizes",
- "unionising": "unionizing",
- "unorganised": "unorganized",
- "unravelled": "unraveled",
- "unravelling": "unraveling",
- "unrecognisable": "unrecognizable",
- "unrecognised": "unrecognized",
- "unrivalled": "unrivaled",
- "unsavoury": "unsavory",
- "untrammelled": "untrammeled",
- "urbanisation": "urbanization",
- "urbanise": "urbanize",
- "urbanised": "urbanized",
- "urbanises": "urbanizes",
- "urbanising": "urbanizing",
- "utilisable": "utilizable",
- "utilisation": "utilization",
- "utilise": "utilize",
- "utilised": "utilized",
- "utilises": "utilizes",
- "utilising": "utilizing",
- "valour": "valor",
- "vandalise": "vandalize",
- "vandalised": "vandalized",
- "vandalises": "vandalizes",
- "vandalising": "vandalizing",
- "vaporisation": "vaporization",
- "vaporise": "vaporize",
- "vaporised": "vaporized",
- "vaporises": "vaporizes",
- "vaporising": "vaporizing",
- "vapour": "vapor",
- "vapours": "vapors",
- "verbalise": "verbalize",
- "verbalised": "verbalized",
- "verbalises": "verbalizes",
- "verbalising": "verbalizing",
- "victimisation": "victimization",
- "victimise": "victimize",
- "victimised": "victimized",
- "victimises": "victimizes",
- "victimising": "victimizing",
- "videodisc": "videodisk",
- "videodiscs": "videodisks",
- "vigour": "vigor",
- "visualisation": "visualization",
- "visualisations": "visualizations",
- "visualise": "visualize",
- "visualised": "visualized",
- "visualises": "visualizes",
- "visualising": "visualizing",
- "vocalisation": "vocalization",
- "vocalisations": "vocalizations",
- "vocalise": "vocalize",
- "vocalised": "vocalized",
- "vocalises": "vocalizes",
- "vocalising": "vocalizing",
- "vulcanised": "vulcanized",
- "vulgarisation": "vulgarization",
- "vulgarise": "vulgarize",
- "vulgarised": "vulgarized",
- "vulgarises": "vulgarizes",
- "vulgarising": "vulgarizing",
- "waggon": "wagon",
- "waggons": "wagons",
- "watercolour": "watercolor",
- "watercolours": "watercolors",
- "weaselled": "weaseled",
- "weaselling": "weaseling",
- "westernisation": "westernization",
- "westernise": "westernize",
- "westernised": "westernized",
- "westernises": "westernizes",
- "westernising": "westernizing",
- "womanise": "womanize",
- "womanised": "womanized",
- "womaniser": "womanizer",
- "womanisers": "womanizers",
- "womanises": "womanizes",
- "womanising": "womanizing",
- "woollen": "woolen",
- "woollens": "woolens",
- "woollies": "woolies",
- "woolly": "wooly",
- "worshipped": "worshiped",
- "worshipping": "worshiping",
- "worshipper": "worshiper",
- "yodelled": "yodeled",
- "yodelling": "yodeling",
- "yoghourt": "yogurt",
- "yoghourts": "yogurts",
- "yoghurt": "yogurt",
- "yoghurts": "yogurts",
- "mhm": "hmm",
- "mmm": "hmm"
-}
\ No newline at end of file
diff --git a/funasr/models/sense_voice/whisper_lib/normalizers/english.py b/funasr/models/sense_voice/whisper_lib/normalizers/english.py
deleted file mode 100644
index cb46530..0000000
--- a/funasr/models/sense_voice/whisper_lib/normalizers/english.py
+++ /dev/null
@@ -1,546 +0,0 @@
-import json
-import os
-import re
-from fractions import Fraction
-from typing import Iterator, List, Match, Optional, Union
-
-from more_itertools import windowed
-
-from .basic import remove_symbols_and_diacritics
-
-
-class EnglishNumberNormalizer:
- """
- Convert any spelled-out numbers into arabic numbers, while handling:
-
- - remove any commas
- - keep the suffixes such as: `1960s`, `274th`, `32nd`, etc.
- - spell out currency symbols after the number. e.g. `$20 million` -> `20000000 dollars`
- - spell out `one` and `ones`
- - interpret successive single-digit numbers as nominal: `one oh one` -> `101`
- """
-
- def __init__(self):
- super().__init__()
-
- self.zeros = {"o", "oh", "zero"}
- self.ones = {
- name: i
- for i, name in enumerate(
- [
- "one",
- "two",
- "three",
- "four",
- "five",
- "six",
- "seven",
- "eight",
- "nine",
- "ten",
- "eleven",
- "twelve",
- "thirteen",
- "fourteen",
- "fifteen",
- "sixteen",
- "seventeen",
- "eighteen",
- "nineteen",
- ],
- start=1,
- )
- }
- self.ones_plural = {
- "sixes" if name == "six" else name + "s": (value, "s")
- for name, value in self.ones.items()
- }
- self.ones_ordinal = {
- "zeroth": (0, "th"),
- "first": (1, "st"),
- "second": (2, "nd"),
- "third": (3, "rd"),
- "fifth": (5, "th"),
- "twelfth": (12, "th"),
- **{
- name + ("h" if name.endswith("t") else "th"): (value, "th")
- for name, value in self.ones.items()
- if value > 3 and value != 5 and value != 12
- },
- }
- self.ones_suffixed = {**self.ones_plural, **self.ones_ordinal}
-
- self.tens = {
- "twenty": 20,
- "thirty": 30,
- "forty": 40,
- "fifty": 50,
- "sixty": 60,
- "seventy": 70,
- "eighty": 80,
- "ninety": 90,
- }
- self.tens_plural = {
- name.replace("y", "ies"): (value, "s") for name, value in self.tens.items()
- }
- self.tens_ordinal = {
- name.replace("y", "ieth"): (value, "th") for name, value in self.tens.items()
- }
- self.tens_suffixed = {**self.tens_plural, **self.tens_ordinal}
-
- self.multipliers = {
- "hundred": 100,
- "thousand": 1_000,
- "million": 1_000_000,
- "billion": 1_000_000_000,
- "trillion": 1_000_000_000_000,
- "quadrillion": 1_000_000_000_000_000,
- "quintillion": 1_000_000_000_000_000_000,
- "sextillion": 1_000_000_000_000_000_000_000,
- "septillion": 1_000_000_000_000_000_000_000_000,
- "octillion": 1_000_000_000_000_000_000_000_000_000,
- "nonillion": 1_000_000_000_000_000_000_000_000_000_000,
- "decillion": 1_000_000_000_000_000_000_000_000_000_000_000,
- }
- self.multipliers_plural = {
- name + "s": (value, "s") for name, value in self.multipliers.items()
- }
- self.multipliers_ordinal = {
- name + "th": (value, "th") for name, value in self.multipliers.items()
- }
- self.multipliers_suffixed = {
- **self.multipliers_plural,
- **self.multipliers_ordinal,
- }
- self.decimals = {*self.ones, *self.tens, *self.zeros}
-
- self.preceding_prefixers = {
- "minus": "-",
- "negative": "-",
- "plus": "+",
- "positive": "+",
- }
- self.following_prefixers = {
- "pound": "拢",
- "pounds": "拢",
- "euro": "鈧�",
- "euros": "鈧�",
- "dollar": "$",
- "dollars": "$",
- "cent": "垄",
- "cents": "垄",
- }
- self.prefixes = set(
- list(self.preceding_prefixers.values()) + list(self.following_prefixers.values())
- )
- self.suffixers = {
- "per": {"cent": "%"},
- "percent": "%",
- }
- self.specials = {"and", "double", "triple", "point"}
-
- self.words = set(
- [
- key
- for mapping in [
- self.zeros,
- self.ones,
- self.ones_suffixed,
- self.tens,
- self.tens_suffixed,
- self.multipliers,
- self.multipliers_suffixed,
- self.preceding_prefixers,
- self.following_prefixers,
- self.suffixers,
- self.specials,
- ]
- for key in mapping
- ]
- )
- self.literal_words = {"one", "ones"}
-
- def process_words(self, words: List[str]) -> Iterator[str]:
- prefix: Optional[str] = None
- value: Optional[Union[str, int]] = None
- skip = False
-
- def to_fraction(s: str):
- try:
- return Fraction(s)
- except ValueError:
- return None
-
- def output(result: Union[str, int]):
- nonlocal prefix, value
- result = str(result)
- if prefix is not None:
- result = prefix + result
- value = None
- prefix = None
- return result
-
- if len(words) == 0:
- return
-
- for prev, current, next in windowed([None] + words + [None], 3):
- if skip:
- skip = False
- continue
-
- next_is_numeric = next is not None and re.match(r"^\d+(\.\d+)?$", next)
- has_prefix = current[0] in self.prefixes
- current_without_prefix = current[1:] if has_prefix else current
- if re.match(r"^\d+(\.\d+)?$", current_without_prefix):
- # arabic numbers (potentially with signs and fractions)
- f = to_fraction(current_without_prefix)
- assert f is not None
- if value is not None:
- if isinstance(value, str) and value.endswith("."):
- # concatenate decimals / ip address components
- value = str(value) + str(current)
- continue
- else:
- yield output(value)
-
- prefix = current[0] if has_prefix else prefix
- if f.denominator == 1:
- value = f.numerator # store integers as int
- else:
- value = current_without_prefix
- elif current not in self.words:
- # non-numeric words
- if value is not None:
- yield output(value)
- yield output(current)
- elif current in self.zeros:
- value = str(value or "") + "0"
- elif current in self.ones:
- ones = self.ones[current]
-
- if value is None:
- value = ones
- elif isinstance(value, str) or prev in self.ones:
- if prev in self.tens and ones < 10: # replace the last zero with the digit
- assert value[-1] == "0"
- value = value[:-1] + str(ones)
- else:
- value = str(value) + str(ones)
- elif ones < 10:
- if value % 10 == 0:
- value += ones
- else:
- value = str(value) + str(ones)
- else: # eleven to nineteen
- if value % 100 == 0:
- value += ones
- else:
- value = str(value) + str(ones)
- elif current in self.ones_suffixed:
- # ordinal or cardinal; yield the number right away
- ones, suffix = self.ones_suffixed[current]
- if value is None:
- yield output(str(ones) + suffix)
- elif isinstance(value, str) or prev in self.ones:
- if prev in self.tens and ones < 10:
- assert value[-1] == "0"
- yield output(value[:-1] + str(ones) + suffix)
- else:
- yield output(str(value) + str(ones) + suffix)
- elif ones < 10:
- if value % 10 == 0:
- yield output(str(value + ones) + suffix)
- else:
- yield output(str(value) + str(ones) + suffix)
- else: # eleven to nineteen
- if value % 100 == 0:
- yield output(str(value + ones) + suffix)
- else:
- yield output(str(value) + str(ones) + suffix)
- value = None
- elif current in self.tens:
- tens = self.tens[current]
- if value is None:
- value = tens
- elif isinstance(value, str):
- value = str(value) + str(tens)
- else:
- if value % 100 == 0:
- value += tens
- else:
- value = str(value) + str(tens)
- elif current in self.tens_suffixed:
- # ordinal or cardinal; yield the number right away
- tens, suffix = self.tens_suffixed[current]
- if value is None:
- yield output(str(tens) + suffix)
- elif isinstance(value, str):
- yield output(str(value) + str(tens) + suffix)
- else:
- if value % 100 == 0:
- yield output(str(value + tens) + suffix)
- else:
- yield output(str(value) + str(tens) + suffix)
- elif current in self.multipliers:
- multiplier = self.multipliers[current]
- if value is None:
- value = multiplier
- elif isinstance(value, str) or value == 0:
- f = to_fraction(value)
- p = f * multiplier if f is not None else None
- if f is not None and p.denominator == 1:
- value = p.numerator
- else:
- yield output(value)
- value = multiplier
- else:
- before = value // 1000 * 1000
- residual = value % 1000
- value = before + residual * multiplier
- elif current in self.multipliers_suffixed:
- multiplier, suffix = self.multipliers_suffixed[current]
- if value is None:
- yield output(str(multiplier) + suffix)
- elif isinstance(value, str):
- f = to_fraction(value)
- p = f * multiplier if f is not None else None
- if f is not None and p.denominator == 1:
- yield output(str(p.numerator) + suffix)
- else:
- yield output(value)
- yield output(str(multiplier) + suffix)
- else: # int
- before = value // 1000 * 1000
- residual = value % 1000
- value = before + residual * multiplier
- yield output(str(value) + suffix)
- value = None
- elif current in self.preceding_prefixers:
- # apply prefix (positive, minus, etc.) if it precedes a number
- if value is not None:
- yield output(value)
-
- if next in self.words or next_is_numeric:
- prefix = self.preceding_prefixers[current]
- else:
- yield output(current)
- elif current in self.following_prefixers:
- # apply prefix (dollars, cents, etc.) only after a number
- if value is not None:
- prefix = self.following_prefixers[current]
- yield output(value)
- else:
- yield output(current)
- elif current in self.suffixers:
- # apply suffix symbols (percent -> '%')
- if value is not None:
- suffix = self.suffixers[current]
- if isinstance(suffix, dict):
- if next in suffix:
- yield output(str(value) + suffix[next])
- skip = True
- else:
- yield output(value)
- yield output(current)
- else:
- yield output(str(value) + suffix)
- else:
- yield output(current)
- elif current in self.specials:
- if next not in self.words and not next_is_numeric:
- # apply special handling only if the next word can be numeric
- if value is not None:
- yield output(value)
- yield output(current)
- elif current == "and":
- # ignore "and" after hundreds, thousands, etc.
- if prev not in self.multipliers:
- if value is not None:
- yield output(value)
- yield output(current)
- elif current == "double" or current == "triple":
- if next in self.ones or next in self.zeros:
- repeats = 2 if current == "double" else 3
- ones = self.ones.get(next, 0)
- value = str(value or "") + str(ones) * repeats
- skip = True
- else:
- if value is not None:
- yield output(value)
- yield output(current)
- elif current == "point":
- if next in self.decimals or next_is_numeric:
- value = str(value or "") + "."
- else:
- # should all have been covered at this point
- raise ValueError(f"Unexpected token: {current}")
- else:
- # all should have been covered at this point
- raise ValueError(f"Unexpected token: {current}")
-
- if value is not None:
- yield output(value)
-
- def preprocess(self, s: str):
- # replace "<number> and a half" with "<number> point five"
- results = []
-
- segments = re.split(r"\band\s+a\s+half\b", s)
- for i, segment in enumerate(segments):
- if len(segment.strip()) == 0:
- continue
- if i == len(segments) - 1:
- results.append(segment)
- else:
- results.append(segment)
- last_word = segment.rsplit(maxsplit=2)[-1]
- if last_word in self.decimals or last_word in self.multipliers:
- results.append("point five")
- else:
- results.append("and a half")
-
- s = " ".join(results)
-
- # put a space at number/letter boundary
- s = re.sub(r"([a-z])([0-9])", r"\1 \2", s)
- s = re.sub(r"([0-9])([a-z])", r"\1 \2", s)
-
- # but remove spaces which could be a suffix
- s = re.sub(r"([0-9])\s+(st|nd|rd|th|s)\b", r"\1\2", s)
-
- return s
-
- def postprocess(self, s: str):
- def combine_cents(m: Match):
- try:
- currency = m.group(1)
- integer = m.group(2)
- cents = int(m.group(3))
- return f"{currency}{integer}.{cents:02d}"
- except ValueError:
- return m.string
-
- def extract_cents(m: Match):
- try:
- return f"垄{int(m.group(1))}"
- except ValueError:
- return m.string
-
- # apply currency postprocessing; "$2 and 垄7" -> "$2.07"
- s = re.sub(r"([鈧�$])([0-9]+) (?:and )?垄([0-9]{1,2})\b", combine_cents, s)
- s = re.sub(r"[鈧�$]0.([0-9]{1,2})\b", extract_cents, s)
-
- # write "one(s)" instead of "1(s)", just for the readability
- s = re.sub(r"\b1(s?)\b", r"one\1", s)
-
- return s
-
- def __call__(self, s: str):
- s = self.preprocess(s)
- s = " ".join(word for word in self.process_words(s.split()) if word is not None)
- s = self.postprocess(s)
-
- return s
-
-
-class EnglishSpellingNormalizer:
- """
- Applies British-American spelling mappings as listed in [1].
-
- [1] https://www.tysto.com/uk-us-spelling-list.html
- """
-
- def __init__(self):
- mapping_path = os.path.join(os.path.dirname(__file__), "english.json")
- self.mapping = json.load(open(mapping_path))
-
- def __call__(self, s: str):
- return " ".join(self.mapping.get(word, word) for word in s.split())
-
-
-class EnglishTextNormalizer:
- def __init__(self):
- self.ignore_patterns = r"\b(hmm|mm|mhm|mmm|uh|um)\b"
- self.replacers = {
- # common contractions
- r"\bwon't\b": "will not",
- r"\bcan't\b": "can not",
- r"\blet's\b": "let us",
- r"\bain't\b": "aint",
- r"\by'all\b": "you all",
- r"\bwanna\b": "want to",
- r"\bgotta\b": "got to",
- r"\bgonna\b": "going to",
- r"\bi'ma\b": "i am going to",
- r"\bimma\b": "i am going to",
- r"\bwoulda\b": "would have",
- r"\bcoulda\b": "could have",
- r"\bshoulda\b": "should have",
- r"\bma'am\b": "madam",
- # contractions in titles/prefixes
- r"\bmr\b": "mister ",
- r"\bmrs\b": "missus ",
- r"\bst\b": "saint ",
- r"\bdr\b": "doctor ",
- r"\bprof\b": "professor ",
- r"\bcapt\b": "captain ",
- r"\bgov\b": "governor ",
- r"\bald\b": "alderman ",
- r"\bgen\b": "general ",
- r"\bsen\b": "senator ",
- r"\brep\b": "representative ",
- r"\bpres\b": "president ",
- r"\brev\b": "reverend ",
- r"\bhon\b": "honorable ",
- r"\basst\b": "assistant ",
- r"\bassoc\b": "associate ",
- r"\blt\b": "lieutenant ",
- r"\bcol\b": "colonel ",
- r"\bjr\b": "junior ",
- r"\bsr\b": "senior ",
- r"\besq\b": "esquire ",
- # prefect tenses, ideally it should be any past participles, but it's harder..
- r"'d been\b": " had been",
- r"'s been\b": " has been",
- r"'d gone\b": " had gone",
- r"'s gone\b": " has gone",
- r"'d done\b": " had done", # "'s done" is ambiguous
- r"'s got\b": " has got",
- # general contractions
- r"n't\b": " not",
- r"'re\b": " are",
- r"'s\b": " is",
- r"'d\b": " would",
- r"'ll\b": " will",
- r"'t\b": " not",
- r"'ve\b": " have",
- r"'m\b": " am",
- }
- self.standardize_numbers = EnglishNumberNormalizer()
- self.standardize_spellings = EnglishSpellingNormalizer()
-
- def __call__(self, s: str):
- s = s.lower()
-
- s = re.sub(r"[<\[][^>\]]*[>\]]", "", s) # remove words between brackets
- s = re.sub(r"\(([^)]+?)\)", "", s) # remove words between parenthesis
- s = re.sub(self.ignore_patterns, "", s)
- s = re.sub(r"\s+'", "'", s) # when there's a space before an apostrophe
-
- for pattern, replacement in self.replacers.items():
- s = re.sub(pattern, replacement, s)
-
- s = re.sub(r"(\d),(\d)", r"\1\2", s) # remove commas between digits
- s = re.sub(r"\.([^0-9]|$)", r" \1", s) # remove periods not followed by numbers
- s = remove_symbols_and_diacritics(s, keep=".%$垄鈧�") # keep numeric symbols
-
- s = self.standardize_numbers(s)
- s = self.standardize_spellings(s)
-
- # now remove prefix/suffix symbols that are not preceded/followed by numbers
- s = re.sub(r"[.$垄鈧([^0-9])", r" \1", s)
- s = re.sub(r"([^0-9])%", r"\1 ", s)
-
- s = re.sub(r"\s+", " ", s) # replace any successive whitespaces with a space
-
- return s
diff --git a/funasr/models/sense_voice/whisper_lib/timing.py b/funasr/models/sense_voice/whisper_lib/timing.py
deleted file mode 100644
index ba9cb13..0000000
--- a/funasr/models/sense_voice/whisper_lib/timing.py
+++ /dev/null
@@ -1,362 +0,0 @@
-import itertools
-import subprocess
-import warnings
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, List
-
-import numba
-import numpy as np
-import torch
-import torch.nn.functional as F
-
-from .audio import HOP_LENGTH, SAMPLE_RATE, TOKENS_PER_SECOND
-from .tokenizer import Tokenizer
-
-if TYPE_CHECKING:
- from .model import Whisper
-
-
-def median_filter(x: torch.Tensor, filter_width: int):
- """Apply a median filter of width `filter_width` along the last dimension of `x`"""
- pad_width = filter_width // 2
- if x.shape[-1] <= pad_width:
- # F.pad requires the padding width to be smaller than the input dimension
- return x
-
- if (ndim := x.ndim) <= 2:
- # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
- x = x[None, None, :]
-
- assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
-
- result = None
- x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
- if x.is_cuda:
- try:
- from .triton_ops import median_filter_cuda
-
- result = median_filter_cuda(x, filter_width)
- except (RuntimeError, subprocess.CalledProcessError):
- warnings.warn(
- "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
- "falling back to a slower median kernel implementation..."
- )
-
- if result is None:
- # sort() is faster than torch.median (https://github.com/pytorch/pytorch/issues/51450)
- result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
-
- if ndim <= 2:
- result = result[0, 0]
-
- return result
-
-
-@numba.jit(nopython=True)
-def backtrace(trace: np.ndarray):
- i = trace.shape[0] - 1
- j = trace.shape[1] - 1
- trace[0, :] = 2
- trace[:, 0] = 1
-
- result = []
- while i > 0 or j > 0:
- result.append((i - 1, j - 1))
-
- if trace[i, j] == 0:
- i -= 1
- j -= 1
- elif trace[i, j] == 1:
- i -= 1
- elif trace[i, j] == 2:
- j -= 1
- else:
- raise ValueError("Unexpected trace[i, j]")
-
- result = np.array(result)
- return result[::-1, :].T
-
-
-@numba.jit(nopython=True, parallel=True)
-def dtw_cpu(x: np.ndarray):
- N, M = x.shape
- cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
- trace = -np.ones((N + 1, M + 1), dtype=np.float32)
-
- cost[0, 0] = 0
- for j in range(1, M + 1):
- for i in range(1, N + 1):
- c0 = cost[i - 1, j - 1]
- c1 = cost[i - 1, j]
- c2 = cost[i, j - 1]
-
- if c0 < c1 and c0 < c2:
- c, t = c0, 0
- elif c1 < c0 and c1 < c2:
- c, t = c1, 1
- else:
- c, t = c2, 2
-
- cost[i, j] = x[i - 1, j - 1] + c
- trace[i, j] = t
-
- return backtrace(trace)
-
-
-def dtw_cuda(x, BLOCK_SIZE=1024):
- from .triton_ops import dtw_kernel
-
- M, N = x.shape
- assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
-
- x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
- x_skew = x_skew.T.contiguous()
- cost = torch.ones(N + M + 2, M + 2) * np.inf
- cost[0, 0] = 0
- cost = cost.cuda()
- trace = torch.zeros_like(cost, dtype=torch.int32)
-
- dtw_kernel[(1,)](
- cost,
- trace,
- x_skew,
- x_skew.stride(0),
- cost.stride(0),
- trace.stride(0),
- N,
- M,
- BLOCK_SIZE=BLOCK_SIZE,
- )
-
- trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
- return backtrace(trace.cpu().numpy())
-
-
-def dtw(x: torch.Tensor) -> np.ndarray:
- if x.is_cuda:
- try:
- return dtw_cuda(x)
- except (RuntimeError, subprocess.CalledProcessError):
- warnings.warn(
- "Failed to launch Triton kernels, likely due to missing CUDA toolkit; "
- "falling back to a slower DTW implementation..."
- )
-
- return dtw_cpu(x.double().cpu().numpy())
-
-
-@dataclass
-class WordTiming:
- word: str
- tokens: List[int]
- start: float
- end: float
- probability: float
-
-
-def find_alignment(
- model: "Whisper",
- tokenizer: Tokenizer,
- text_tokens: List[int],
- mel: torch.Tensor,
- num_frames: int,
- *,
- medfilt_width: int = 7,
- qk_scale: float = 1.0,
-) -> List[WordTiming]:
- if len(text_tokens) == 0:
- return []
-
- tokens = torch.tensor(
- [
- *tokenizer.sot_sequence,
- tokenizer.no_timestamps,
- *text_tokens,
- tokenizer.eot,
- ]
- ).to(model.device)
-
- # install hooks on the cross attention layers to retrieve the attention weights
- QKs = [None] * model.dims.n_text_layer
- hooks = [
- block.cross_attn.register_forward_hook(
- lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])
- )
- for i, block in enumerate(model.decoder.blocks)
- ]
-
- with torch.no_grad():
- logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
- sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
- token_probs = sampled_logits.softmax(dim=-1)
- text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens]
- text_token_probs = text_token_probs.tolist()
-
- for hook in hooks:
- hook.remove()
-
- # heads * tokens * frames
- weights = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
- weights = weights[:, :, : num_frames // 2]
- weights = (weights * qk_scale).softmax(dim=-1)
- std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
- weights = (weights - mean) / std
- weights = median_filter(weights, medfilt_width)
-
- matrix = weights.mean(axis=0)
- matrix = matrix[len(tokenizer.sot_sequence) : -1]
- text_indices, time_indices = dtw(-matrix)
-
- words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
- if len(word_tokens) <= 1:
- # return on eot only
- # >>> np.pad([], (1, 0))
- # array([0.])
- # This results in crashes when we lookup jump_times with float, like
- # IndexError: arrays used as indices must be of integer (or boolean) type
- return []
- word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
-
- jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
- jump_times = time_indices[jumps] / TOKENS_PER_SECOND
- start_times = jump_times[word_boundaries[:-1]]
- end_times = jump_times[word_boundaries[1:]]
- word_probabilities = [
- np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
- ]
-
- return [
- WordTiming(word, tokens, start, end, probability)
- for word, tokens, start, end, probability in zip(
- words, word_tokens, start_times, end_times, word_probabilities
- )
- ]
-
-
-def merge_punctuations(alignment: List[WordTiming], prepended: str, appended: str):
- # merge prepended punctuations
- i = len(alignment) - 2
- j = len(alignment) - 1
- while i >= 0:
- previous = alignment[i]
- following = alignment[j]
- if previous.word.startswith(" ") and previous.word.strip() in prepended:
- # prepend it to the following word
- following.word = previous.word + following.word
- following.tokens = previous.tokens + following.tokens
- previous.word = ""
- previous.tokens = []
- else:
- j = i
- i -= 1
-
- # merge appended punctuations
- i = 0
- j = 1
- while j < len(alignment):
- previous = alignment[i]
- following = alignment[j]
- if not previous.word.endswith(" ") and following.word in appended:
- # append it to the previous word
- previous.word = previous.word + following.word
- previous.tokens = previous.tokens + following.tokens
- following.word = ""
- following.tokens = []
- else:
- i = j
- j += 1
-
-
-def add_word_timestamps(
- *,
- segments: List[dict],
- model: "Whisper",
- tokenizer: Tokenizer,
- mel: torch.Tensor,
- num_frames: int,
- prepend_punctuations: str = "\"'鈥溌�([{-",
- append_punctuations: str = "\"'.銆�,锛�!锛�?锛�:锛氣��)]}銆�",
- last_speech_timestamp: float,
- **kwargs,
-):
- if len(segments) == 0:
- return
-
- text_tokens_per_segment = [
- [token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments
- ]
-
- text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
- alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
- word_durations = np.array([t.end - t.start for t in alignment])
- word_durations = word_durations[word_durations.nonzero()]
- median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
- median_duration = min(0.7, float(median_duration))
- max_duration = median_duration * 2
-
- # hack: truncate long words at sentence boundaries.
- # a better segmentation algorithm based on VAD should be able to replace this.
- if len(word_durations) > 0:
- sentence_end_marks = ".銆�!锛�?锛�"
- # ensure words at sentence boundaries are not longer than twice the median word duration.
- for i in range(1, len(alignment)):
- if alignment[i].end - alignment[i].start > max_duration:
- if alignment[i].word in sentence_end_marks:
- alignment[i].end = alignment[i].start + max_duration
- elif alignment[i - 1].word in sentence_end_marks:
- alignment[i].start = alignment[i].end - max_duration
-
- merge_punctuations(alignment, prepend_punctuations, append_punctuations)
-
- time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
- word_index = 0
-
- for segment, text_tokens in zip(segments, text_tokens_per_segment):
- saved_tokens = 0
- words = []
-
- while word_index < len(alignment) and saved_tokens < len(text_tokens):
- timing = alignment[word_index]
-
- if timing.word:
- words.append(
- dict(
- word=timing.word,
- start=round(time_offset + timing.start, 2),
- end=round(time_offset + timing.end, 2),
- probability=timing.probability,
- )
- )
-
- saved_tokens += len(timing.tokens)
- word_index += 1
-
- # hack: truncate long words at segment boundaries.
- # a better segmentation algorithm based on VAD should be able to replace this.
- if len(words) > 0:
- # ensure the first and second word after a pause is not longer than
- # twice the median word duration.
- if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
- words[0]["end"] - words[0]["start"] > max_duration
- or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)
- ):
- if len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration:
- boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
- words[0]["end"] = words[1]["start"] = boundary
- words[0]["start"] = max(0, words[0]["end"] - max_duration)
-
- # prefer the segment-level start timestamp if the first word is too long.
- if segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]:
- words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
- else:
- segment["start"] = words[0]["start"]
-
- # prefer the segment-level end timestamp if the last word is too long.
- if segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]:
- words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
- else:
- segment["end"] = words[-1]["end"]
-
- last_speech_timestamp = segment["end"]
-
- segment["words"] = words
diff --git a/funasr/models/sense_voice/whisper_lib/tokenizer.py b/funasr/models/sense_voice/whisper_lib/tokenizer.py
deleted file mode 100644
index 5b276c2..0000000
--- a/funasr/models/sense_voice/whisper_lib/tokenizer.py
+++ /dev/null
@@ -1,456 +0,0 @@
-import base64
-import os
-import string
-from dataclasses import dataclass, field
-from functools import cached_property, lru_cache
-from typing import Dict, List, Optional, Tuple
-
-import tiktoken
-
-# FIX(funasr): sense vocie
-LANGUAGES = {
- "en": "english",
- "zh": "chinese",
- "de": "german",
- "es": "spanish",
- "ru": "russian",
- "ko": "korean",
- "fr": "french",
- "ja": "japanese",
- "pt": "portuguese",
- "tr": "turkish",
- "pl": "polish",
- "ca": "catalan",
- "nl": "dutch",
- "ar": "arabic",
- "sv": "swedish",
- "it": "italian",
- "id": "indonesian",
- "hi": "hindi",
- "fi": "finnish",
- "vi": "vietnamese",
- "he": "hebrew",
- "uk": "ukrainian",
- "el": "greek",
- "ms": "malay",
- "cs": "czech",
- "ro": "romanian",
- "da": "danish",
- "hu": "hungarian",
- "ta": "tamil",
- "no": "norwegian",
- "th": "thai",
- "ur": "urdu",
- "hr": "croatian",
- "bg": "bulgarian",
- "lt": "lithuanian",
- "la": "latin",
- "mi": "maori",
- "ml": "malayalam",
- "cy": "welsh",
- "sk": "slovak",
- "te": "telugu",
- "fa": "persian",
- "lv": "latvian",
- "bn": "bengali",
- "sr": "serbian",
- "az": "azerbaijani",
- "sl": "slovenian",
- "kn": "kannada",
- "et": "estonian",
- "mk": "macedonian",
- "br": "breton",
- "eu": "basque",
- "is": "icelandic",
- "hy": "armenian",
- "ne": "nepali",
- "mn": "mongolian",
- "bs": "bosnian",
- "kk": "kazakh",
- "sq": "albanian",
- "sw": "swahili",
- "gl": "galician",
- "mr": "marathi",
- "pa": "punjabi",
- "si": "sinhala",
- "km": "khmer",
- "sn": "shona",
- "yo": "yoruba",
- "so": "somali",
- "af": "afrikaans",
- "oc": "occitan",
- "ka": "georgian",
- "be": "belarusian",
- "tg": "tajik",
- "sd": "sindhi",
- "gu": "gujarati",
- "am": "amharic",
- "yi": "yiddish",
- "lo": "lao",
- "uz": "uzbek",
- "fo": "faroese",
- "ht": "haitian creole",
- "ps": "pashto",
- "tk": "turkmen",
- "nn": "nynorsk",
- "mt": "maltese",
- "sa": "sanskrit",
- "lb": "luxembourgish",
- "my": "myanmar",
- "bo": "tibetan",
- "tl": "tagalog",
- "mg": "malagasy",
- "as": "assamese",
- "tt": "tatar",
- "haw": "hawaiian",
- "ln": "lingala",
- "ha": "hausa",
- "ba": "bashkir",
- "jw": "javanese",
- "su": "sundanese",
- "yue": "cantonese",
- "minnan": "minnan",
- "wuyu": "wuyu",
- "dialect": "dialect",
- "zh/en": "zh/en",
- "en/zh": "en/zh",
-}
-
-# language code lookup by name, with a few language aliases
-TO_LANGUAGE_CODE = {
- **{language: code for code, language in LANGUAGES.items()},
- "burmese": "my",
- "valencian": "ca",
- "flemish": "nl",
- "haitian": "ht",
- "letzeburgesch": "lb",
- "pushto": "ps",
- "panjabi": "pa",
- "moldavian": "ro",
- "moldovan": "ro",
- "sinhalese": "si",
- "castilian": "es",
- "mandarin": "zh",
-}
-
-# FIX(funasr): sense vocie
-AUDIO_EVENT = {
- "ASR": "ASR",
- "AED": "AED",
- "SER": "SER",
- "Speech": "Speech",
- "/Speech": "/Speech",
- "BGM": "BGM",
- "/BGM": "/BGM",
- "Laughter": "Laughter",
- "/Laughter": "/Laughter",
- "Applause": "Applause",
- "/Applause": "/Applause",
-}
-
-EMOTION = {
- "HAPPY": "HAPPY",
- "SAD": "SAD",
- "ANGRY": "ANGRY",
- "NEUTRAL": "NEUTRAL",
-}
-
-
-@dataclass
-class Tokenizer:
- """A thin wrapper around `tiktoken` providing quick access to special tokens"""
-
- encoding: tiktoken.Encoding
- num_languages: int
- language: Optional[str] = None
- task: Optional[str] = None
- sot_sequence: Tuple[int] = ()
- special_tokens: Dict[str, int] = field(default_factory=dict)
-
- def __post_init__(self):
- for special in self.encoding.special_tokens_set:
- special_token = self.encoding.encode_single_token(special)
- self.special_tokens[special] = special_token
-
- sot: int = self.special_tokens["<|startoftranscript|>"]
- translate: int = self.special_tokens["<|translate|>"]
- transcribe: int = self.special_tokens["<|transcribe|>"]
-
- langs = tuple(LANGUAGES.keys())[: self.num_languages]
- sot_sequence = [sot]
- if self.language is not None:
- if self.language == "nospeech":
- sot_sequence.append(self.no_speech)
- else:
- sot_sequence.append(sot + 1 + langs.index(self.language))
- # if self.language is not None:
- # sot_sequence.append(sot + 1 + langs.index(self.language))
- if self.task is not None:
- task_token: int = transcribe if self.task == "transcribe" else translate
- sot_sequence.append(task_token)
-
- self.sot_sequence = tuple(sot_sequence)
-
- def encode(self, text, **kwargs):
- return self.encoding.encode(text, **kwargs)
-
- def decode(self, token_ids: List[int], **kwargs) -> str:
- token_ids = [t for t in token_ids if t < self.timestamp_begin]
- return self.encoding.decode(token_ids, **kwargs)
-
- def decode_with_timestamps(self, token_ids: List[int], **kwargs) -> str:
- """
- Timestamp tokens are above other special tokens' id range and are ignored by `decode()`.
- This method decodes given tokens with timestamps tokens annotated, e.g. "<|1.08|>".
- """
- return self.encoding.decode(token_ids, **kwargs)
-
- def get_vocab_size(self) -> int:
- return self.encoding.n_vocab
-
- @cached_property
- def eot(self) -> int:
- return self.encoding.eot_token
-
- @cached_property
- def transcribe(self) -> int:
- return self.special_tokens["<|transcribe|>"]
-
- @cached_property
- def translate(self) -> int:
- return self.special_tokens["<|translate|>"]
-
- @cached_property
- def sot(self) -> int:
- return self.special_tokens["<|startoftranscript|>"]
-
- @cached_property
- def sot_sense(self) -> int:
- return self.special_tokens["<|startoftranscript|>"]
-
- @cached_property
- def sot_lm(self) -> int:
- return self.special_tokens["<|startoflm|>"]
-
- @cached_property
- def sot_prev(self) -> int:
- return self.special_tokens["<|startofprev|>"]
-
- @cached_property
- def no_speech(self) -> int:
- return self.special_tokens["<|nospeech|>"]
-
- @cached_property
- def no_timestamps(self) -> int:
- return self.special_tokens["<|notimestamps|>"]
-
- @cached_property
- def timestamp_begin(self) -> int:
- return self.special_tokens["<|0.00|>"]
-
- @cached_property
- def language_token(self) -> int:
- """Returns the token id corresponding to the value of the `language` field"""
- if self.language is None:
- raise ValueError("This tokenizer does not have language token configured")
-
- return self.to_language_token(self.language)
-
- def to_language_token(self, language):
- if token := self.special_tokens.get(f"<|{language}|>", None):
- return token
-
- raise KeyError(f"Language {language} not found in tokenizer.")
-
- @cached_property
- def all_language_tokens(self) -> Tuple[int]:
- result = []
- for token, token_id in self.special_tokens.items():
- if token.strip("<|>") in LANGUAGES:
- result.append(token_id)
- return tuple(result)[: self.num_languages]
-
- @cached_property
- def all_language_codes(self) -> Tuple[str]:
- return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)
-
- @cached_property
- def sot_sequence_including_notimestamps(self) -> Tuple[int]:
- return tuple(list(self.sot_sequence) + [self.no_timestamps])
-
- @cached_property
- def non_speech_tokens(self) -> Tuple[int]:
- """
- Returns the list of tokens to suppress in order to avoid any speaker tags or non-speech
- annotations, to prevent sampling texts that are not actually spoken in the audio, e.g.
-
- - 鈾櫔鈾�
- - ( SPEAKING FOREIGN LANGUAGE )
- - [DAVID] Hey there,
-
- keeping basic punctuations like commas, periods, question marks, exclamation points, etc.
- """
- symbols = list('"#()*+/:;<=>@[\\]^_`{|}~銆屻�嶃�庛��')
- symbols += "<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} 鈾櫔 鈾櫔鈾�".split()
-
- # symbols that may be a single token or multiple tokens depending on the tokenizer.
- # In case they're multiple tokens, suppress the first token, which is safe because:
- # These are between U+2640 and U+267F miscellaneous symbols that are okay to suppress
- # in generations, and in the 3-byte UTF-8 representation they share the first two bytes.
- miscellaneous = set("鈾┾櫔鈾櫖鈾櫘鈾�")
- assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
-
- # allow hyphens "-" and single quotes "'" between words, but not at the beginning of a word
- result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
- for symbol in symbols + list(miscellaneous):
- for tokens in [
- self.encoding.encode(symbol),
- self.encoding.encode(" " + symbol),
- ]:
- if len(tokens) == 1 or symbol in miscellaneous:
- result.add(tokens[0])
-
- return tuple(sorted(result))
-
- def split_to_word_tokens(self, tokens: List[int]):
- if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
- # These languages don't typically use spaces, so it is difficult to split words
- # without morpheme analysis. Here, we instead split words at any
- # position where the tokens are decoded as valid unicode points
- return self.split_tokens_on_unicode(tokens)
-
- return self.split_tokens_on_spaces(tokens)
-
- def split_tokens_on_unicode(self, tokens: List[int]):
- decoded_full = self.decode_with_timestamps(tokens)
- replacement_char = "\ufffd"
-
- words = []
- word_tokens = []
- current_tokens = []
- unicode_offset = 0
-
- for token in tokens:
- current_tokens.append(token)
- decoded = self.decode_with_timestamps(current_tokens)
-
- if (
- replacement_char not in decoded
- or decoded_full[unicode_offset + decoded.index(replacement_char)]
- == replacement_char
- ):
- words.append(decoded)
- word_tokens.append(current_tokens)
- current_tokens = []
- unicode_offset += len(decoded)
-
- return words, word_tokens
-
- def split_tokens_on_spaces(self, tokens: List[int]):
- subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
- words = []
- word_tokens = []
-
- for subword, subword_tokens in zip(subwords, subword_tokens_list):
- special = subword_tokens[0] >= self.eot
- with_space = subword.startswith(" ")
- punctuation = subword.strip() in string.punctuation
- if special or with_space or punctuation or len(words) == 0:
- words.append(subword)
- word_tokens.append(subword_tokens)
- else:
- words[-1] = words[-1] + subword
- word_tokens[-1].extend(subword_tokens)
-
- return words, word_tokens
-
-
-@lru_cache(maxsize=None)
-def get_encoding(name: str = "gpt2", num_languages: int = 99, vocab_path: str = None):
- if vocab_path is None:
- vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
-
- ranks = {
- base64.b64decode(token): int(rank)
- for token, rank in (line.split() for line in open(vocab_path) if line)
- }
- n_vocab = len(ranks)
- special_tokens = {}
-
- if False: # name == "gpt2" or name == "multilingual":
- specials = [
- "<|endoftext|>",
- "<|startoftranscript|>",
- *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
- "<|translate|>",
- "<|transcribe|>",
- "<|startoflm|>",
- "<|startofprev|>",
- "<|nospeech|>",
- "<|notimestamps|>",
- *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
- ]
- else:
- specials = [
- "<|endoftext|>",
- "<|startoftranscript|>",
- *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
- *[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
- *[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
- "<|translate|>",
- "<|transcribe|>",
- "<|startoflm|>",
- "<|startofprev|>",
- "<|nospeech|>",
- "<|notimestamps|>",
- *[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 51)],
- *[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
- ]
-
- for token in specials:
- special_tokens[token] = n_vocab
- n_vocab += 1
-
- return tiktoken.Encoding(
- name=os.path.basename(vocab_path),
- explicit_n_vocab=n_vocab,
- pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
- mergeable_ranks=ranks,
- special_tokens=special_tokens,
- )
-
-
-@lru_cache(maxsize=None)
-def get_tokenizer(
- multilingual: bool,
- *,
- num_languages: int = 99,
- language: Optional[str] = None,
- task: Optional[str] = None, # Literal["transcribe", "translate", None]
- encoding_path: Optional[str] = None,
- vocab_path: Optional[str] = None,
-) -> Tokenizer:
- if language is not None:
- language = language.lower()
- if language not in LANGUAGES:
- if language in TO_LANGUAGE_CODE:
- language = TO_LANGUAGE_CODE[language]
- elif language == "nospeech":
- pass
- else:
- raise ValueError(f"Unsupported language: {language}")
-
- if multilingual:
- encoding_name = "multilingual"
- language = language or "en"
- task = task or "transcribe"
- else:
- encoding_name = "gpt2"
- language = None
- task = None
- if encoding_path is not None:
- encoding_name = encoding_path
-
- encoding = get_encoding(name=encoding_name, num_languages=num_languages, vocab_path=vocab_path)
-
- return Tokenizer(encoding=encoding, num_languages=num_languages, language=language, task=task)
diff --git a/funasr/models/sense_voice/whisper_lib/transcribe.py b/funasr/models/sense_voice/whisper_lib/transcribe.py
deleted file mode 100644
index 5c5f49d..0000000
--- a/funasr/models/sense_voice/whisper_lib/transcribe.py
+++ /dev/null
@@ -1,573 +0,0 @@
-import argparse
-import os
-import traceback
-import warnings
-from typing import TYPE_CHECKING, List, Optional, Tuple, Union
-
-import numpy as np
-import torch
-import tqdm
-
-from .audio import (
- FRAMES_PER_SECOND,
- HOP_LENGTH,
- N_FRAMES,
- N_SAMPLES,
- SAMPLE_RATE,
- log_mel_spectrogram,
- pad_or_trim,
-)
-from .decoding import DecodingOptions, DecodingResult
-from .timing import add_word_timestamps
-from .tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
-from .utils import (
- exact_div,
- format_timestamp,
- get_end,
- get_writer,
- make_safe,
- optional_float,
- optional_int,
- str2bool,
-)
-
-if TYPE_CHECKING:
- from .model import Whisper
-
-
-def transcribe(
- model: "Whisper",
- audio: Union[str, np.ndarray, torch.Tensor],
- *,
- verbose: Optional[bool] = None,
- temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
- compression_ratio_threshold: Optional[float] = 2.4,
- logprob_threshold: Optional[float] = -1.0,
- no_speech_threshold: Optional[float] = 0.6,
- condition_on_previous_text: bool = True,
- initial_prompt: Optional[str] = None,
- word_timestamps: bool = False,
- prepend_punctuations: str = "\"'鈥溌�([{-",
- append_punctuations: str = "\"'.銆�,锛�!锛�?锛�:锛氣��)]}銆�",
- clip_timestamps: Union[str, List[float]] = "0",
- hallucination_silence_threshold: Optional[float] = None,
- **decode_options,
-):
- """
- Transcribe an audio file using Whisper
-
- Parameters
- ----------
- model: Whisper
- The Whisper model instance
-
- audio: Union[str, np.ndarray, torch.Tensor]
- The path to the audio file to open, or the audio waveform
-
- verbose: bool
- Whether to display the text being decoded to the console. If True, displays all the details,
- If False, displays minimal details. If None, does not display anything
-
- temperature: Union[float, Tuple[float, ...]]
- Temperature for sampling. It can be a tuple of temperatures, which will be successively used
- upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
-
- compression_ratio_threshold: float
- If the gzip compression ratio is above this value, treat as failed
-
- logprob_threshold: float
- If the average log probability over sampled tokens is below this value, treat as failed
-
- no_speech_threshold: float
- If the no_speech probability is higher than this value AND the average log probability
- over sampled tokens is below `logprob_threshold`, consider the segment as silent
-
- condition_on_previous_text: bool
- if True, the previous output of the model is provided as a prompt for the next window;
- disabling may make the text inconsistent across windows, but the model becomes less prone to
- getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
-
- word_timestamps: bool
- Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
- and include the timestamps for each word in each segment.
-
- prepend_punctuations: str
- If word_timestamps is True, merge these punctuation symbols with the next word
-
- append_punctuations: str
- If word_timestamps is True, merge these punctuation symbols with the previous word
-
- initial_prompt: Optional[str]
- Optional text to provide as a prompt for the first window. This can be used to provide, or
- "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
- to make it more likely to predict those word correctly.
-
- decode_options: dict
- Keyword arguments to construct `DecodingOptions` instances
-
- clip_timestamps: Union[str, List[float]]
- Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process.
- The last end timestamp defaults to the end of the file.
-
- hallucination_silence_threshold: Optional[float]
- When word_timestamps is True, skip silent periods longer than this threshold (in seconds)
- when a possible hallucination is detected
-
- Returns
- -------
- A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
- the spoken language ("language"), which is detected when `decode_options["language"]` is None.
- """
- dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
- if model.device == torch.device("cpu"):
- if torch.cuda.is_available():
- warnings.warn("Performing inference on CPU when CUDA is available")
- if dtype == torch.float16:
- warnings.warn("FP16 is not supported on CPU; using FP32 instead")
- dtype = torch.float32
-
- if dtype == torch.float32:
- decode_options["fp16"] = False
-
- # Pad 30-seconds of silence to the input audio, for slicing
- mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
- content_frames = mel.shape[-1] - N_FRAMES
- content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
-
- if decode_options.get("language", None) is None:
- if not model.is_multilingual:
- decode_options["language"] = "en"
- else:
- if verbose:
- print(
- "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
- )
- mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
- _, probs = model.detect_language(mel_segment)
- decode_options["language"] = max(probs, key=probs.get)
- if verbose is not None:
- print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
-
- language: str = decode_options["language"]
- task: str = decode_options.get("task", "transcribe")
- tokenizer = get_tokenizer(
- model.is_multilingual,
- num_languages=model.num_languages,
- language=language,
- task=task,
- )
-
- if isinstance(clip_timestamps, str):
- clip_timestamps = [
- float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])
- ]
- seek_points: List[int] = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
- if len(seek_points) == 0:
- seek_points.append(0)
- if len(seek_points) % 2 == 1:
- seek_points.append(content_frames)
- seek_clips: List[Tuple[int, int]] = list(zip(seek_points[::2], seek_points[1::2]))
-
- punctuation = "\"'鈥溌�([{-\"'.銆�,锛�!锛�?锛�:锛氣��)]}銆�"
-
- if word_timestamps and task == "translate":
- warnings.warn("Word-level timestamps on translations may not be reliable.")
-
- def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
- temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
- decode_result = None
-
- for t in temperatures:
- kwargs = {**decode_options}
- if t > 0:
- # disable beam_size and patience when t > 0
- kwargs.pop("beam_size", None)
- kwargs.pop("patience", None)
- else:
- # disable best_of when t == 0
- kwargs.pop("best_of", None)
-
- options = DecodingOptions(**kwargs, temperature=t)
- decode_result = model.decode(segment, options)
-
- needs_fallback = False
- if (
- compression_ratio_threshold is not None
- and decode_result.compression_ratio > compression_ratio_threshold
- ):
- needs_fallback = True # too repetitive
- if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
- needs_fallback = True # average log probability is too low
- if (
- no_speech_threshold is not None
- and decode_result.no_speech_prob > no_speech_threshold
- ):
- needs_fallback = False # silence
- if not needs_fallback:
- break
-
- return decode_result
-
- clip_idx = 0
- seek = seek_clips[clip_idx][0]
- input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2
- time_precision = (
- input_stride * HOP_LENGTH / SAMPLE_RATE
- ) # time per output token: 0.02 (seconds)
- all_tokens = []
- all_segments = []
- prompt_reset_since = 0
-
- if initial_prompt is not None:
- initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
- all_tokens.extend(initial_prompt_tokens)
- else:
- initial_prompt_tokens = []
-
- def new_segment(*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult):
- tokens = tokens.tolist()
- text_tokens = [token for token in tokens if token < tokenizer.eot]
- return {
- "seek": seek,
- "start": start,
- "end": end,
- "text": tokenizer.decode(text_tokens),
- "tokens": tokens,
- "temperature": result.temperature,
- "avg_logprob": result.avg_logprob,
- "compression_ratio": result.compression_ratio,
- "no_speech_prob": result.no_speech_prob,
- }
-
- # show the progress bar when verbose is False (if True, transcribed text will be printed)
- with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
- last_speech_timestamp = 0.0
- # NOTE: This loop is obscurely flattened to make the diff readable.
- # A later commit should turn this into a simpler nested loop.
- # for seek_clip_start, seek_clip_end in seek_clips:
- # while seek < seek_clip_end
- while clip_idx < len(seek_clips):
- seek_clip_start, seek_clip_end = seek_clips[clip_idx]
- if seek < seek_clip_start:
- seek = seek_clip_start
- if seek >= seek_clip_end:
- clip_idx += 1
- if clip_idx < len(seek_clips):
- seek = seek_clips[clip_idx][0]
- continue
- time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
- window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
- segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
- mel_segment = mel[:, seek : seek + segment_size]
- segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
- mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
-
- decode_options["prompt"] = all_tokens[prompt_reset_since:]
- result: DecodingResult = decode_with_fallback(mel_segment)
- tokens = torch.tensor(result.tokens)
-
- if no_speech_threshold is not None:
- # no voice activity check
- should_skip = result.no_speech_prob > no_speech_threshold
- if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
- # don't skip if the logprob is high enough, despite the no_speech_prob
- should_skip = False
-
- if should_skip:
- seek += segment_size # fast-forward to the next segment boundary
- continue
-
- previous_seek = seek
- current_segments = []
-
- # anomalous words are very long/short/improbable
- def word_anomaly_score(word: dict) -> float:
- probability = word.get("probability", 0.0)
- duration = word["end"] - word["start"]
- score = 0.0
- if probability < 0.15:
- score += 1.0
- if duration < 0.133:
- score += (0.133 - duration) * 15
- if duration > 2.0:
- score += duration - 2.0
- return score
-
- def is_segment_anomaly(segment: Optional[dict]) -> bool:
- if segment is None or not segment["words"]:
- return False
- words = [w for w in segment["words"] if w["word"] not in punctuation]
- words = words[:8]
- score = sum(word_anomaly_score(w) for w in words)
- return score >= 3 or score + 0.01 >= len(words)
-
- def next_words_segment(segments: List[dict]) -> Optional[dict]:
- return next((s for s in segments if s["words"]), None)
-
- timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
- single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
-
- consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
- consecutive.add_(1)
- if len(consecutive) > 0:
- # if the output contains two consecutive timestamp tokens
- slices = consecutive.tolist()
- if single_timestamp_ending:
- slices.append(len(tokens))
-
- last_slice = 0
- for current_slice in slices:
- sliced_tokens = tokens[last_slice:current_slice]
- start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
- end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
- current_segments.append(
- new_segment(
- start=time_offset + start_timestamp_pos * time_precision,
- end=time_offset + end_timestamp_pos * time_precision,
- tokens=sliced_tokens,
- result=result,
- )
- )
- last_slice = current_slice
-
- if single_timestamp_ending:
- # single timestamp at the end means no speech after the last timestamp.
- seek += segment_size
- else:
- # otherwise, ignore the unfinished segment and seek to the last timestamp
- last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
- seek += last_timestamp_pos * input_stride
- else:
- duration = segment_duration
- timestamps = tokens[timestamp_tokens.nonzero().flatten()]
- if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
- # no consecutive timestamps but it has a timestamp; use the last one.
- last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
- duration = last_timestamp_pos * time_precision
-
- current_segments.append(
- new_segment(
- start=time_offset,
- end=time_offset + duration,
- tokens=tokens,
- result=result,
- )
- )
- seek += segment_size
-
- if word_timestamps:
- add_word_timestamps(
- segments=current_segments,
- model=model,
- tokenizer=tokenizer,
- mel=mel_segment,
- num_frames=segment_size,
- prepend_punctuations=prepend_punctuations,
- append_punctuations=append_punctuations,
- last_speech_timestamp=last_speech_timestamp,
- )
-
- if not single_timestamp_ending:
- last_word_end = get_end(current_segments)
- if last_word_end is not None and last_word_end > time_offset:
- seek = round(last_word_end * FRAMES_PER_SECOND)
-
- # skip silence before possible hallucinations
- if hallucination_silence_threshold is not None:
- threshold = hallucination_silence_threshold
- if not single_timestamp_ending:
- last_word_end = get_end(current_segments)
- if last_word_end is not None and last_word_end > time_offset:
- remaining_duration = window_end_time - last_word_end
- if remaining_duration > threshold:
- seek = round(last_word_end * FRAMES_PER_SECOND)
- else:
- seek = previous_seek + segment_size
-
- # if first segment might be a hallucination, skip leading silence
- first_segment = next_words_segment(current_segments)
- if first_segment is not None and is_segment_anomaly(first_segment):
- gap = first_segment["start"] - time_offset
- if gap > threshold:
- seek = previous_seek + round(gap * FRAMES_PER_SECOND)
- continue
-
- # skip silence before any possible hallucination that is surrounded
- # by silence or more hallucinations
- hal_last_end = last_speech_timestamp
- for si in range(len(current_segments)):
- segment = current_segments[si]
- if not segment["words"]:
- continue
- if is_segment_anomaly(segment):
- next_segment = next_words_segment(current_segments[si + 1 :])
- if next_segment is not None:
- hal_next_start = next_segment["words"][0]["start"]
- else:
- hal_next_start = time_offset + segment_duration
- silence_before = (
- segment["start"] - hal_last_end > threshold
- or segment["start"] < threshold
- or segment["start"] - time_offset < 2.0
- )
- silence_after = (
- hal_next_start - segment["end"] > threshold
- or is_segment_anomaly(next_segment)
- or window_end_time - segment["end"] < 2.0
- )
- if silence_before and silence_after:
- seek = round(
- max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND
- )
- if content_duration - segment["end"] < threshold:
- seek = content_frames
- current_segments[si:] = []
- break
- hal_last_end = segment["end"]
-
- last_word_end = get_end(current_segments)
- if last_word_end is not None:
- last_speech_timestamp = last_word_end
-
- if verbose:
- for segment in current_segments:
- start, end, text = segment["start"], segment["end"], segment["text"]
- line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
- print(make_safe(line))
-
- # if a segment is instantaneous or does not contain text, clear it
- for i, segment in enumerate(current_segments):
- if segment["start"] == segment["end"] or segment["text"].strip() == "":
- segment["text"] = ""
- segment["tokens"] = []
- segment["words"] = []
-
- all_segments.extend(
- [
- {"id": i, **segment}
- for i, segment in enumerate(current_segments, start=len(all_segments))
- ]
- )
- all_tokens.extend(
- [token for segment in current_segments for token in segment["tokens"]]
- )
-
- if not condition_on_previous_text or result.temperature > 0.5:
- # do not feed the prompt tokens if a high temperature was used
- prompt_reset_since = len(all_tokens)
-
- # update progress bar
- pbar.update(min(content_frames, seek) - previous_seek)
-
- return dict(
- text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
- segments=all_segments,
- language=language,
- )
-
-
-def cli():
- from . import available_models
-
- def valid_model_name(name):
- if name in available_models() or os.path.exists(name):
- return name
- raise ValueError(
- f"model should be one of {available_models()} or path to a model checkpoint"
- )
-
- # fmt: off
- parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
- parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
- parser.add_argument("--model", default="small", type=valid_model_name, help="name of the Whisper model to use")
- parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
- parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
- parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
- parser.add_argument("--output_format", "-f", type=str, default="all", choices=["txt", "vtt", "srt", "tsv", "json", "all"], help="format of the output file; if not specified, all available formats will be produced")
- parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")
-
- parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
- parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")
-
- parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
- parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
- parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
- parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
- parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
-
- parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
- parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
- parser.add_argument("--condition_on_previous_text", type=str2bool, default=True, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
- parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
-
- parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
- parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
- parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
- parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
- parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
- parser.add_argument("--prepend_punctuations", type=str, default="\"\'鈥溌�([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
- parser.add_argument("--append_punctuations", type=str, default="\"\'.銆�,锛�!锛�?锛�:锛氣��)]}銆�", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
- parser.add_argument("--highlight_words", type=str2bool, default=False, help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
- parser.add_argument("--max_line_width", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of characters in a line before breaking the line")
- parser.add_argument("--max_line_count", type=optional_int, default=None, help="(requires --word_timestamps True) the maximum number of lines in a segment")
- parser.add_argument("--max_words_per_line", type=optional_int, default=None, help="(requires --word_timestamps True, no effect with --max_line_width) the maximum number of words in a segment")
- parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
- parser.add_argument("--clip_timestamps", type=str, default="0", help="comma-separated list start,end,start,end,... timestamps (in seconds) of clips to process, where the last end timestamp defaults to the end of the file")
- parser.add_argument("--hallucination_silence_threshold", type=optional_float, help="(requires --word_timestamps True) skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected")
- # fmt: on
-
- args = parser.parse_args().__dict__
- model_name: str = args.pop("model")
- model_dir: str = args.pop("model_dir")
- output_dir: str = args.pop("output_dir")
- output_format: str = args.pop("output_format")
- device: str = args.pop("device")
- os.makedirs(output_dir, exist_ok=True)
-
- if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
- if args["language"] is not None:
- warnings.warn(
- f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
- )
- args["language"] = "en"
-
- temperature = args.pop("temperature")
- if (increment := args.pop("temperature_increment_on_fallback")) is not None:
- temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
- else:
- temperature = [temperature]
-
- if (threads := args.pop("threads")) > 0:
- torch.set_num_threads(threads)
-
- from . import load_model
-
- model = load_model(model_name, device=device, download_root=model_dir)
-
- writer = get_writer(output_format, output_dir)
- word_options = [
- "highlight_words",
- "max_line_count",
- "max_line_width",
- "max_words_per_line",
- ]
- if not args["word_timestamps"]:
- for option in word_options:
- if args[option]:
- parser.error(f"--{option} requires --word_timestamps True")
- if args["max_line_count"] and not args["max_line_width"]:
- warnings.warn("--max_line_count has no effect without --max_line_width")
- if args["max_words_per_line"] and args["max_line_width"]:
- warnings.warn("--max_words_per_line has no effect with --max_line_width")
- writer_args = {arg: args.pop(arg) for arg in word_options}
- for audio_path in args.pop("audio"):
- try:
- result = transcribe(model, audio_path, temperature=temperature, **args)
- writer(result, audio_path, **writer_args)
- except Exception as e:
- traceback.print_exc()
- print(f"Skipping {audio_path} due to {type(e).__name__}: {str(e)}")
-
-
-if __name__ == "__main__":
- cli()
diff --git a/funasr/models/sense_voice/whisper_lib/triton_ops.py b/funasr/models/sense_voice/whisper_lib/triton_ops.py
deleted file mode 100644
index 9919595..0000000
--- a/funasr/models/sense_voice/whisper_lib/triton_ops.py
+++ /dev/null
@@ -1,102 +0,0 @@
-from functools import lru_cache
-
-import numpy as np
-import torch
-
-try:
- import triton
- import triton.language as tl
-except ImportError:
- raise RuntimeError("triton import failed; try `pip install --pre triton`")
-
-
@triton.jit
def dtw_kernel(cost, trace, x, x_stride, cost_stride, trace_stride, N, M, BLOCK_SIZE: tl.constexpr):
    """Fill the DTW cost matrix and traceback matrix by sweeping anti-diagonals.

    Each program instance handles up to BLOCK_SIZE columns (lanes beyond M are
    masked off).  For every anti-diagonal k = i + j, it reads the three
    already-computed neighbor costs, adds the local distance from `x`, and
    writes the new cost row plus a traceback code into `trace`.

    NOTE(review): the exact geometric meaning of the three neighbor pointers
    (diagonal / vertical / horizontal moves) depends on how the caller pads and
    skews `cost` and `trace`; confirm against the host-side setup.
    """
    offsets = tl.arange(0, BLOCK_SIZE)
    mask = offsets < M  # only the first M lanes correspond to real columns

    for k in range(1, N + M + 1):  # k = i + j
        # Make the previous diagonal's stores visible before reading them.
        tl.debug_barrier()

        # Three candidate predecessors: one from row k-1, two from row k
        # (the second shifted by one element).
        p0 = cost + (k - 1) * cost_stride
        p1 = cost + k * cost_stride
        p2 = cost + k * cost_stride + 1

        c0 = tl.load(p0 + offsets, mask=mask)
        c1 = tl.load(p1 + offsets, mask=mask)
        c2 = tl.load(p2 + offsets, mask=mask)

        # Local distance for this diagonal; out-of-range lanes contribute 0.
        x_row = tl.load(x + (k - 1) * x_stride + offsets, mask=mask, other=0)
        cost_row = x_row + tl.minimum(tl.minimum(c0, c1), c2)

        # Write the new costs into row k+1, shifted by one column.
        cost_ptr = cost + (k + 1) * cost_stride + 1
        tl.store(cost_ptr + offsets, cost_row, mask=mask)

        # Traceback code for whichever neighbor was minimal.  On ties the
        # later store wins, so the priority is c0 (code 0) > c1 (1) > c2 (2).
        trace_ptr = trace + (k + 1) * trace_stride + 1
        tl.store(trace_ptr + offsets, 2, mask=mask & (c2 <= c0) & (c2 <= c1))
        tl.store(trace_ptr + offsets, 1, mask=mask & (c1 <= c0) & (c1 <= c2))
        tl.store(trace_ptr + offsets, 0, mask=mask & (c0 <= c1) & (c0 <= c2))
-
-
@lru_cache(maxsize=None)
def median_kernel(filter_width: int):
    """Build (and cache, one per filter width) a Triton kernel that writes the
    median over the last axis of an (rows, filter_width) view into `y`, one
    program instance per row.

    The kernel body below is a template: the ALL-CAPS placeholder names are not
    valid identifiers — they are textually substituted into the kernel source
    before JIT compilation by the `.src.replace` calls that follow.
    """

    @triton.jit
    def kernel(y, x, x_stride, y_stride, BLOCK_SIZE: tl.constexpr):  # x.shape[-1] == filter_width
        row_idx = tl.program_id(0)
        offsets = tl.arange(0, BLOCK_SIZE)
        mask = offsets < y_stride

        x_ptr = x + row_idx * x_stride  # noqa: F841
        y_ptr = y + row_idx * y_stride

        LOAD_ALL_ROWS_HERE  # noqa: F821

        BUBBLESORT_HERE  # noqa: F821

        tl.store(y_ptr + offsets, MIDDLE_ROW_HERE, mask=mask)  # noqa: F821

    # Re-wrap the raw function so its source (still containing the
    # placeholders) can be edited before Triton compiles it.
    kernel = triton.JITFunction(kernel.fn)
    # Load each of the filter_width window elements into its own variable.
    kernel.src = kernel.src.replace(
        "    LOAD_ALL_ROWS_HERE",
        "\n".join(
            [f"    row{i} = tl.load(x_ptr + offsets + {i}, mask=mask)" for i in range(filter_width)]
        ),
    )
    # Partial bubble sort: filter_width // 2 + 1 outer passes are enough to
    # settle the middle element, which is all the median needs.
    kernel.src = kernel.src.replace(
        "    BUBBLESORT_HERE",
        "\n\n".join(
            [
                "\n\n".join(
                    [
                        "\n".join(
                            [
                                f"    smaller = tl.where(row{j} < row{j + 1}, row{j}, row{j + 1})",
                                f"    larger = tl.where(row{j} > row{j + 1}, row{j}, row{j + 1})",
                                f"    row{j} = smaller",
                                f"    row{j + 1} = larger",
                            ]
                        )
                        for j in range(filter_width - i - 1)
                    ]
                )
                for i in range(filter_width // 2 + 1)
            ]
        ),
    )
    # After the passes above, the median sits in the middle "row" variable.
    kernel.src = kernel.src.replace("MIDDLE_ROW_HERE", f"row{filter_width // 2}")

    return kernel
-
-
def median_filter_cuda(x: torch.Tensor, filter_width: int):
    """Apply a median filter of given width along the last dimension of x."""
    # Sliding windows of length filter_width over the last axis.
    windows = x.contiguous().unfold(-1, filter_width, 1)
    # One Triton program per flattened row of windows.
    num_rows = np.prod(windows.shape[:-2])

    jit_kernel = median_kernel(filter_width)
    out = torch.empty_like(windows[..., 0])

    # Round the row length up to the next power of two for the block size.
    block_size = 1 << (out.stride(-2) - 1).bit_length()
    jit_kernel[(num_rows,)](out, x, x.stride(-2), out.stride(-2), BLOCK_SIZE=block_size)

    return out
diff --git a/funasr/models/sense_voice/whisper_lib/utils.py b/funasr/models/sense_voice/whisper_lib/utils.py
deleted file mode 100644
index 5fc6125..0000000
--- a/funasr/models/sense_voice/whisper_lib/utils.py
+++ /dev/null
@@ -1,283 +0,0 @@
-import json
-import os
-import re
-import sys
-import zlib
-from typing import Callable, List, Optional, TextIO
-
system_encoding = sys.getdefaultencoding()

if system_encoding == "utf-8":

    def make_safe(string):
        # utf-8 can encode any Unicode code point, so no round-trip is needed.
        return string

else:

    def make_safe(string):
        # Round-trip through the system default encoding, substituting '?' for
        # any character it cannot represent; this avoids UnicodeEncodeError
        # (https://github.com/openai/whisper/discussions/729).
        return string.encode(system_encoding, errors="replace").decode(system_encoding)
-
-
def exact_div(x, y):
    """Divide x by y, asserting that the division leaves no remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
-
-
def str2bool(string):
    """Parse the literal strings "True"/"False" into booleans; reject anything else."""
    str2val = {"True": True, "False": False}
    if string not in str2val:
        raise ValueError(f"Expected one of {set(str2val.keys())}, got {string}")
    return str2val[string]
-
-
def optional_int(string):
    """Parse an int, treating the literal string "None" as None."""
    if string == "None":
        return None
    return int(string)
-
-
def optional_float(string):
    """Parse a float, treating the literal string "None" as None."""
    if string == "None":
        return None
    return float(string)
-
-
def compression_ratio(text) -> float:
    """Ratio of UTF-8 byte length to zlib-compressed length.

    Higher values indicate more repetitive text (used as a degenerate-output
    heuristic by callers).
    """
    raw = text.encode("utf-8")
    return len(raw) / len(zlib.compress(raw))
-
-
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    """Render a timestamp as [HH:]MM:SS<marker>mmm.

    The hours field is emitted only when nonzero, unless
    always_include_hours is set (SRT requires it).
    """
    assert seconds >= 0, "non-negative timestamp expected"
    total_ms = round(seconds * 1000.0)

    total_seconds, ms = divmod(total_ms, 1_000)
    total_minutes, secs = divmod(total_seconds, 60)
    hours, minutes = divmod(total_minutes, 60)

    prefix = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return f"{prefix}{minutes:02d}:{secs:02d}{decimal_marker}{ms:03d}"
-
-
def get_start(segments: List[dict]) -> Optional[float]:
    """Start time of the first word across all segments.

    Falls back to the first segment's own start when no segment has words,
    and to None when there are no segments at all.
    """
    for segment in segments:
        for word in segment["words"]:
            return word["start"]
    return segments[0]["start"] if segments else None
-
-
def get_end(segments: List[dict]) -> Optional[float]:
    """End time of the last word across all segments.

    Falls back to the last segment's own end when no segment has words,
    and to None when there are no segments at all.
    """
    for segment in reversed(segments):
        for word in reversed(segment["words"]):
            return word["end"]
    return segments[-1]["end"] if segments else None
-
-
class ResultWriter:
    """Base class for transcript writers.

    A writer is called with a result dict and the source audio path; it writes
    the formatted transcript to "<output_dir>/<audio basename>.<extension>".
    Subclasses set `extension` and implement `write_result`.
    """

    extension: str

    def __init__(self, output_dir: str):
        self.output_dir = output_dir

    def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs):
        stem = os.path.splitext(os.path.basename(audio_path))[0]
        output_path = os.path.join(self.output_dir, stem + "." + self.extension)

        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options, **kwargs)

    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        # Subclasses must provide the actual formatting.
        raise NotImplementedError
-
-
class WriteTXT(ResultWriter):
    """Plain-text output: one stripped segment text per line, no timestamps."""

    extension: str = "txt"

    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        for segment in result["segments"]:
            line = segment["text"].strip()
            print(line, file=file, flush=True)
-
-
class SubtitlesWriter(ResultWriter):
    """Shared logic for cue-based subtitle formats (VTT, SRT).

    Subclasses set the two attributes below to control timestamp rendering.
    """

    # Whether "HH:" is emitted even for zero hours, and which character
    # separates seconds from milliseconds (SRT uses "," / VTT uses ".").
    always_include_hours: bool
    decimal_marker: str

    def iterate_result(
        self,
        result: dict,
        options: Optional[dict] = None,
        *,
        max_line_width: Optional[int] = None,
        max_line_count: Optional[int] = None,
        highlight_words: bool = False,
        max_words_per_line: Optional[int] = None,
    ):
        """Yield (start, end, text) cues for `result`.

        When segments carry word timings, words are regrouped into cues
        honoring the line-width / line-count / words-per-line limits; when
        highlight_words is set, one cue per word is emitted with that word
        wrapped in <u>...</u>.  Without word timings, each segment becomes
        one cue.  Keyword arguments take precedence over the same keys in
        `options`.
        """
        options = options or {}
        max_line_width = max_line_width or options.get("max_line_width")
        max_line_count = max_line_count or options.get("max_line_count")
        highlight_words = highlight_words or options.get("highlight_words", False)
        max_words_per_line = max_words_per_line or options.get("max_words_per_line")
        # Original segment boundaries are preserved unless BOTH line limits
        # were given; the 1000 fallbacks make the limits effectively inactive.
        preserve_segments = max_line_count is None or max_line_width is None
        max_line_width = max_line_width or 1000
        max_words_per_line = max_words_per_line or 1000

        def iterate_subtitles():
            # Regroup word timings into cues (lists of word dicts).
            line_len = 0
            line_count = 1
            # the next subtitle to yield (a list of word timings with whitespace)
            subtitle: List[dict] = []
            last: float = get_start(result["segments"]) or 0.0
            for segment in result["segments"]:
                chunk_index = 0
                words_count = max_words_per_line
                # Process the segment's words in chunks of max_words_per_line.
                while chunk_index < len(segment["words"]):
                    remaining_words = len(segment["words"]) - chunk_index
                    if max_words_per_line > len(segment["words"]) - chunk_index:
                        words_count = remaining_words
                    for i, original_timing in enumerate(
                        segment["words"][chunk_index : chunk_index + words_count]
                    ):
                        # Copy so the "\n" / strip mutations below don't touch
                        # the caller's result dict.
                        timing = original_timing.copy()
                        # A >3s silence forces a new cue (only when segment
                        # boundaries are not being preserved).
                        long_pause = not preserve_segments and timing["start"] - last > 3.0
                        has_room = line_len + len(timing["word"]) <= max_line_width
                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                        if line_len > 0 and has_room and not long_pause and not seg_break:
                            # line continuation
                            line_len += len(timing["word"])
                        else:
                            # new line
                            timing["word"] = timing["word"].strip()
                            if (
                                len(subtitle) > 0
                                and max_line_count is not None
                                and (long_pause or line_count >= max_line_count)
                                or seg_break
                            ):
                                # subtitle break
                                yield subtitle
                                subtitle = []
                                line_count = 1
                            elif line_len > 0:
                                # line break
                                line_count += 1
                                timing["word"] = "\n" + timing["word"]
                            line_len = len(timing["word"].strip())
                        subtitle.append(timing)
                        last = timing["start"]
                    chunk_index += max_words_per_line
            # Flush the trailing cue, if any.
            if len(subtitle) > 0:
                yield subtitle

        if len(result["segments"]) > 0 and "words" in result["segments"][0]:
            # Word-level path: emit cues built by iterate_subtitles().
            for subtitle in iterate_subtitles():
                subtitle_start = self.format_timestamp(subtitle[0]["start"])
                subtitle_end = self.format_timestamp(subtitle[-1]["end"])
                subtitle_text = "".join([word["word"] for word in subtitle])
                if highlight_words:
                    last = subtitle_start
                    all_words = [timing["word"] for timing in subtitle]
                    for i, this_word in enumerate(subtitle):
                        start = self.format_timestamp(this_word["start"])
                        end = self.format_timestamp(this_word["end"])
                        # Fill gaps between consecutive words with an
                        # un-highlighted copy of the cue text.
                        if last != start:
                            yield last, start, subtitle_text

                        # Underline word i, keeping its leading whitespace
                        # outside the <u> tags.
                        yield start, end, "".join(
                            [
                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) if j == i else word
                                for j, word in enumerate(all_words)
                            ]
                        )
                        last = end
                else:
                    yield subtitle_start, subtitle_end, subtitle_text
        else:
            # Segment-level fallback: no word timings available.
            for segment in result["segments"]:
                segment_start = self.format_timestamp(segment["start"])
                segment_end = self.format_timestamp(segment["end"])
                # "-->" would collide with the cue timing separator syntax.
                segment_text = segment["text"].strip().replace("-->", "->")
                yield segment_start, segment_end, segment_text

    def format_timestamp(self, seconds: float):
        # Delegates to the module-level helper using this subclass's settings.
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )
-
-
class WriteVTT(SubtitlesWriter):
    """WebVTT output: '.' decimal marker, hours omitted when zero."""

    extension: str = "vtt"
    always_include_hours: bool = False
    decimal_marker: str = "."

    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        # WebVTT files must begin with the "WEBVTT" magic line.
        print("WEBVTT\n", file=file)
        cues = self.iterate_result(result, options, **kwargs)
        for start, end, text in cues:
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
-
-
class WriteSRT(SubtitlesWriter):
    """SubRip (.srt) output: 1-based cue index, hours always shown, ',' decimal marker."""

    extension: str = "srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        index = 1
        for start, end, text in self.iterate_result(result, options, **kwargs):
            print(f"{index}\n{start} --> {end}\n{text}\n", file=file, flush=True)
            index += 1
-
-
class WriteTSV(ResultWriter):
    """
    Write a transcript to a file in TSV (tab-separated values) format containing lines like:
    <start time in integer milliseconds>\t<end time in integer milliseconds>\t<transcript text>

    Using integer milliseconds as start and end times means there's no chance of interference from
    an environment setting a language encoding that causes the decimal in a floating point number
    to appear as a comma; also is faster and more efficient to parse & store, e.g., in C++.
    """

    extension: str = "tsv"

    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        # Header row, then one row per segment.
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            fields = (
                str(round(1000 * segment["start"])),
                str(round(1000 * segment["end"])),
                # Embedded tabs would corrupt the column structure.
                segment["text"].strip().replace("\t", " "),
            )
            print("\t".join(fields), file=file, flush=True)
-
-
class WriteJSON(ResultWriter):
    """JSON output: the whole result dict serialized verbatim."""

    extension: str = "json"

    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        # No formatting options apply; dump the dict as-is.
        json.dump(result, file)
-
-
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO, dict], None]:
    """Return a writer callable for `output_format` ("all" fans out to every format).

    Raises KeyError for an unknown format.
    """
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "json": WriteJSON,
    }

    if output_format != "all":
        return writers[output_format](output_dir)

    all_writers = [writer(output_dir) for writer in writers.values()]

    def write_all(result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        # NOTE(review): despite the annotation, the second argument is
        # forwarded as ResultWriter.__call__'s audio_path — confirm callers
        # pass the audio path here, not a file object.
        for writer in all_writers:
            writer(result, file, options, **kwargs)

    return write_all
diff --git a/funasr/models/sense_voice/whisper_lib/version.py b/funasr/models/sense_voice/whisper_lib/version.py
deleted file mode 100644
index c96dd9c..0000000
--- a/funasr/models/sense_voice/whisper_lib/version.py
+++ /dev/null
@@ -1 +0,0 @@
-__version__ = "20231117"
--
Gitblit v1.9.1