From 26a2a232a94c4a729733d83e8175a16e3f8db481 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 23 十月 2023 16:42:43 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/build_utils/build_asr_model.py                                                                   |    2 
 funasr/modules/rwkv.py                                                                                  |  145 ++++++
 funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp                                                |   32 +
 funasr/modules/rwkv_feed_forward.py                                                                     |   97 ++++
 funasr/modules/rwkv_attention.py                                                                        |  632 ++++++++++++++++++++++++++
 funasr/utils/whisper_utils/__init__.py                                                                  |    0 
 funasr/models/whisper_models/__init__.py                                                                |    0 
 funasr/runtime/docs/SDK_advanced_guide_offline.md                                                       |    4 
 funasr/runtime/docs/SDK_advanced_guide_offline_en.md                                                    |    2 
 funasr/runtime/docs/SDK_advanced_guide_offline_zh.md                                                    |    4 
 funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java |   31 
 funasr/modules/rwkv_subsampling.py                                                                      |  190 +++++++
 funasr/runtime/docs/SDK_advanced_guide_online_zh.md                                                     |    7 
 funasr/models/encoder/rwkv_encoder.py                                                                   |  155 ++++++
 funasr/runtime/docs/SDK_advanced_guide_online.md                                                        |    4 
 funasr/runtime/websocket/bin/websocket-server.cpp                                                       |   27 
 funasr/runtime/websocket/hotwords.txt                                                                   |    2 
 funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py                                    |    2 
 funasr/runtime/run_server.sh                                                                            |    7 
 funasr/runtime/run_server_2pass.sh                                                                      |    7 
 funasr/runtime/websocket/bin/funasr-wss-server.cpp                                                      |   32 +
 funasr/runtime/websocket/bin/websocket-server-2pass.cpp                                                 |   28 
 22 files changed, 1,373 insertions(+), 37 deletions(-)

diff --git a/funasr/build_utils/build_asr_model.py b/funasr/build_utils/build_asr_model.py
index 5e93444..fd47bd3 100644
--- a/funasr/build_utils/build_asr_model.py
+++ b/funasr/build_utils/build_asr_model.py
@@ -42,6 +42,7 @@
 from funasr.models.encoder.branchformer_encoder import BranchformerEncoder
 from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder
 from funasr.models.encoder.transformer_encoder import TransformerEncoder
+from funasr.models.encoder.rwkv_encoder import RWKVEncoder
 from funasr.models.frontend.default import DefaultFrontend
 from funasr.models.frontend.default import MultiChannelFrontend
 from funasr.models.frontend.fused import FusedFrontends
@@ -119,6 +120,7 @@
         e_branchformer=EBranchformerEncoder,
         mfcca_enc=MFCCAEncoder,
         chunk_conformer=ConformerChunkEncoder,
+        rwkv=RWKVEncoder,
     ),
     default="rnn",
 )
diff --git a/funasr/models/encoder/rwkv_encoder.py b/funasr/models/encoder/rwkv_encoder.py
new file mode 100644
index 0000000..8a33520
--- /dev/null
+++ b/funasr/models/encoder/rwkv_encoder.py
@@ -0,0 +1,155 @@
+"""RWKV encoder definition for Transducer models."""
+
+import math
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.modules.rwkv import RWKV
+from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.rwkv_subsampling import RWKVConvInput
+from funasr.modules.nets_utils import make_source_mask
+
+class RWKVEncoder(AbsEncoder):
+    """RWKV encoder module.
+
+    Based on https://arxiv.org/pdf/2305.13048.pdf.
+
+    Args:
+        input_size: Input feature size, forwarded to RWKVConvInput.
+        output_size: Encoder output size.
+        context_size: Context size for WKV computation.
+        linear_size: FeedForward hidden size (4 * output_size when None).
+        attention_size: SelfAttention hidden size (output_size when None).
+        num_blocks: Number of RWKV blocks.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate forwarded to every RWKV block.
+        subsampling_factor: Subsampling factor forwarded to RWKVConvInput.
+        time_reduction_factor: Keep one output frame out of this many.
+        kernel: Convolution kernel size forwarded to RWKVConvInput.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        output_size: int = 512,
+        context_size: int = 1024,
+        linear_size: Optional[int] = None,
+        attention_size: Optional[int] = None,
+        num_blocks: int = 4,
+        att_dropout_rate: float = 0.0,
+        ffn_dropout_rate: float = 0.0,
+        dropout_rate: float = 0.0,
+        subsampling_factor: int = 4,
+        time_reduction_factor: int = 1,
+        kernel: int = 3,
+    ) -> None:
+        """Construct a RWKVEncoder object."""
+        super().__init__()
+
+        self.embed = RWKVConvInput(
+            input_size,
+            [output_size // 4, output_size // 2, output_size],
+            subsampling_factor,
+            conv_kernel_size=kernel,
+            output_size=output_size,
+        )
+
+        self.subsampling_factor = subsampling_factor
+
+        linear_size = output_size * 4 if linear_size is None else linear_size
+        attention_size = output_size if attention_size is None else attention_size
+        self.attention_size = attention_size  # sizes the WKV states in rwkv_infer
+        self.rwkv_blocks = torch.nn.ModuleList(
+            [
+                RWKV(
+                    output_size,
+                    linear_size,
+                    attention_size,
+                    context_size,
+                    block_id,
+                    num_blocks,
+                    att_dropout_rate=att_dropout_rate,
+                    ffn_dropout_rate=ffn_dropout_rate,
+                    dropout_rate=dropout_rate,
+                )
+                for block_id in range(num_blocks)
+            ]
+        )
+
+        self.embed_norm = LayerNorm(output_size)
+        self.final_norm = LayerNorm(output_size)
+
+        self._output_size = output_size
+        self.context_size = context_size
+
+        self.num_blocks = num_blocks
+        self.time_reduction_factor = time_reduction_factor
+
+    def output_size(self) -> int:
+        """Return the encoder output size."""
+        return self._output_size
+
+    def forward(
+        self, x: torch.Tensor, x_len: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
+        """Encode input feature sequences.
+
+        Args:
+            x: Encoder input sequences. (B, T, D_in)
+            x_len: Input lengths, forwarded to make_source_mask.
+
+        Returns:
+            x: Encoder output sequences. (B, U, D)
+            olens: Number of zero entries per row of the subsampled mask.
+            None: Unused third slot of the returned triple.
+        """
+        _, length, _ = x.size()
+        max_length = self.context_size * self.subsampling_factor
+        assert length <= max_length, (
+            "Context size is too short for current length: %d versus %d"
+            % (length, max_length)
+        )
+        mask = make_source_mask(x_len).to(x.device)
+        x, mask = self.embed(x, mask, None)
+        x = self.embed_norm(x)
+        olens = mask.eq(0).sum(1)
+        for block in self.rwkv_blocks:
+            x, _ = block(x)
+        x = self.final_norm(x)
+
+        if self.time_reduction_factor > 1:
+            x = x[:, :: self.time_reduction_factor, :]
+            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1
+
+        return x, olens, None
+
+    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
+        """Run the RWKV blocks frame by frame while carrying a recurrent state.
+
+        Args:
+            xs_pad: Input sequences fed to the first block. (B, U, D)
+
+        Returns:
+            xs_out: Output sequences of the last block. (B, U, D)
+        """
+        # state[1]: attention input cache (D); state[2:]: WKV accumulators (D_att).
+        hidden_sizes = [self._output_size] * 2 + [self.attention_size] * 3
+        state = [
+            torch.zeros(
+                (xs_pad.shape[0], 1, hidden_size, self.num_blocks),
+                dtype=torch.float32,
+                device=xs_pad.device,
+            )
+            for hidden_size in hidden_sizes
+        ]
+        state[4] -= 1e30  # running max starts very negative (empty history)
+        xs_out = []
+        for t in range(xs_pad.shape[1]):
+            x_t = xs_pad[:, t : t + 1, :]  # keep the time axis: (B, 1, D)
+            for block in self.rwkv_blocks:
+                x_t, state = block(x_t, state=state)
+            xs_out.append(x_t)
+        return torch.cat(xs_out, dim=1)
diff --git a/funasr/models/whisper_models/__init__.py b/funasr/models/whisper_models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/whisper_models/__init__.py
diff --git a/funasr/modules/rwkv.py b/funasr/modules/rwkv.py
new file mode 100644
index 0000000..f020828
--- /dev/null
+++ b/funasr/modules/rwkv.py
@@ -0,0 +1,145 @@
+"""Receptance Weighted Key Value (RWKV) block definition.
+
+Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
+
+"""
+
+from typing import Dict, Optional, Tuple
+
+import torch
+
+from funasr.modules.rwkv_attention import EncoderSelfAttention, DecoderSelfAttention
+from funasr.modules.rwkv_feed_forward import FeedForward
+from funasr.modules.layer_norm import LayerNorm
+
+class RWKV(torch.nn.Module):
+    """RWKV encoder block: time mixing followed by channel mixing.
+
+    Args:
+        size: Input/Output size.
+        linear_size: Feed-forward hidden size.
+        attention_size: SelfAttention hidden size.
+        context_size: Context size for WKV computation.
+        block_id: Block index.
+        num_blocks: Number of blocks in the architecture.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate applied to both residual branches.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        linear_size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        num_blocks: int,
+        att_dropout_rate: float = 0.0,
+        ffn_dropout_rate: float = 0.0,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Build the normalization, attention and feed-forward sub-modules."""
+        super().__init__()
+
+        # Pre-normalization of each residual branch.
+        self.layer_norm_att = LayerNorm(size)
+        self.layer_norm_ffn = LayerNorm(size)
+
+        self.att = EncoderSelfAttention(
+            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
+        )
+        self.dropout_att = torch.nn.Dropout(p=dropout_rate)
+
+        self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
+        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Apply the attention and feed-forward residual branches in turn.
+
+        Args:
+            x: RWKV input sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        Returns:
+            x: RWKV output sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        """
+        att_out, state = self.att(self.layer_norm_att(x), state=state)
+        x = x + self.dropout_att(att_out)
+
+        ffn_out, state = self.ffn(self.layer_norm_ffn(x), state=state)
+        return x + self.dropout_ffn(ffn_out), state
+
+class RWKVDecoderLayer(torch.nn.Module):
+    """RWKV decoder block: time mixing followed by channel mixing.
+
+    Args:
+        size: Input/Output size.
+        linear_size: Feed-forward hidden size.
+        attention_size: SelfAttention hidden size.
+        context_size: Context size for WKV computation.
+        block_id: Block index.
+        num_blocks: Number of blocks in the architecture.
+        att_dropout_rate: Dropout rate for the attention module.
+        ffn_dropout_rate: Dropout rate for the feed-forward module.
+        dropout_rate: Dropout rate applied to both residual branches.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        linear_size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        num_blocks: int,
+        att_dropout_rate: float = 0.0,
+        ffn_dropout_rate: float = 0.0,
+        dropout_rate: float = 0.0,
+    ) -> None:
+        """Build the normalization, attention and feed-forward sub-modules."""
+        super().__init__()
+
+        # Pre-normalization of each residual branch.
+        self.layer_norm_att = LayerNorm(size)
+        self.layer_norm_ffn = LayerNorm(size)
+
+        # Time-mixing (attention) branch.
+        self.att = DecoderSelfAttention(
+            size, attention_size, context_size, block_id, att_dropout_rate, num_blocks
+        )
+        self.dropout_att = torch.nn.Dropout(p=dropout_rate)
+
+        # Channel-mixing (feed-forward) branch.
+        self.ffn = FeedForward(size, linear_size, block_id, ffn_dropout_rate, num_blocks)
+        self.dropout_ffn = torch.nn.Dropout(p=dropout_rate)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Apply the attention and feed-forward residual branches in turn.
+
+        Args:
+            x: RWKV input sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        Returns:
+            x: RWKV output sequences. (B, L, size)
+            state: Decoder hidden states. [5 x (B, D_att/size, N)]
+
+        """
+        att_out, state = self.att(self.layer_norm_att(x), state=state)
+        x = x + self.dropout_att(att_out)
+
+        ffn_out, state = self.ffn(self.layer_norm_ffn(x), state=state)
+        return x + self.dropout_ffn(ffn_out), state
diff --git a/funasr/modules/rwkv_attention.py b/funasr/modules/rwkv_attention.py
new file mode 100644
index 0000000..f0c7da3
--- /dev/null
+++ b/funasr/modules/rwkv_attention.py
@@ -0,0 +1,632 @@
+"""Attention (time mixing) modules for RWKV block.
+
+Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py.
+
+Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
+
+"""  # noqa
+
+import math
+from importlib.util import find_spec
+from pathlib import Path
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+wkv_kernel_encoder = None
+wkv_kernel_decoder = None
+
+class WKVLinearAttentionEncoder(torch.autograd.Function):
+    """Autograd function around the encoder WKV kernel (wkv_kernel_encoder)."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.tensor,
+    ) -> torch.Tensor:
+        """Compute the WKV output with the loaded encoder kernel.
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, U, D_att)
+            value: Value tensor. (B, U, D_att)
+
+        Returns:
+            out: Weighted Key-Value tensor. (B, U, D_att)
+
+        """
+        batch, length, dim = key.size()
+        # The loaded kernel only accepts sequences up to its recorded context size.
+        assert length <= wkv_kernel_encoder.context_size, (
+            f"Cannot process key of length {length} while context_size "
+            f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
+        )
+
+        assert batch * dim % min(dim, 32) == 0, (
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+            f"{min(dim, 32)}"
+        )
+
+        ctx.input_dtype = key.dtype
+        # Everything handed to the kernel is float32 and contiguous; decay is -exp(w).
+        time_decay = -torch.exp(time_decay.float().contiguous())
+        time_first = time_first.float().contiguous()
+
+        key = key.float().contiguous()
+        value = value.float().contiguous()
+        # Output buffer passed to the kernel call below.
+        out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+        wkv_kernel_encoder.forward(time_decay, time_first, key, value, out)
+        ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+        return out
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute the input gradients with the loaded encoder kernel.
+
+        Args:
+            grad_output: Output gradient. (B, U, D_att)
+
+        Returns:
+            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+            grad_time_first: Gradient for channel-wise time first vector. (D_att)
+            grad_key: Gradient for key tensor. (B, U, D_att)
+            grad_value: Gradient for value tensor. (B, U, D_att)
+
+        """
+        time_decay, time_first, key, value, output = ctx.saved_tensors
+        grad_dtype = ctx.input_dtype
+
+        batch, _, dim = key.size()
+        # One gradient row per batch item; both buffers are summed over the batch below.
+        grad_time_decay = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_time_first = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+        # The four grad_* buffers are passed to the kernel as its last arguments.
+        wkv_kernel_encoder.backward(
+            time_decay,
+            time_first,
+            key,
+            value,
+            output,
+            grad_output.contiguous(),
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+        # Reduce the (B, D_att) buffers to the (D_att) shape of the parameters.
+        grad_time_decay = torch.sum(grad_time_decay, dim=0)
+        grad_time_first = torch.sum(grad_time_first, dim=0)
+
+        return (
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+class WKVLinearAttentionDecoder(torch.autograd.Function):
+    """Autograd function around the decoder WKV kernel (wkv_kernel_decoder)."""
+
+    @staticmethod
+    def forward(
+        ctx,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute the WKV output with the loaded decoder kernel.
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, U, D_att)
+            value: Value tensor. (B, U, D_att)
+
+        Returns:
+            out: Weighted Key-Value tensor. (B, U, D_att)
+
+        """
+        batch, length, dim = key.size()
+        # The message must name the decoder kernel: "wkv_kernel" does not exist.
+        assert length <= wkv_kernel_decoder.context_size, (
+            f"Cannot process key of length {length} while context_size "
+            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
+        )
+
+        assert batch * dim % min(dim, 32) == 0, (
+            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
+            f"{min(dim, 32)}"
+        )
+
+        ctx.input_dtype = key.dtype
+        # Everything handed to the kernel is float32 and contiguous; decay is -exp(w).
+        time_decay = -torch.exp(time_decay.float().contiguous())
+        time_first = time_first.float().contiguous()
+
+        key = key.float().contiguous()
+        value = value.float().contiguous()
+        # Output buffer passed to the kernel call below.
+        out = torch.empty_like(key, memory_format=torch.contiguous_format)
+
+        wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
+        ctx.save_for_backward(time_decay, time_first, key, value, out)
+
+        return out
+
+    @staticmethod
+    def backward(
+        ctx, grad_output: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Compute the input gradients with the loaded decoder kernel.
+
+        Args:
+            grad_output: Output gradient. (B, U, D_att)
+
+        Returns:
+            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
+            grad_time_first: Gradient for channel-wise time first vector. (D_att)
+            grad_key: Gradient for key tensor. (B, U, D_att)
+            grad_value: Gradient for value tensor. (B, U, D_att)
+
+        """
+        time_decay, time_first, key, value, output = ctx.saved_tensors
+        # Gradients are returned in the float32 dtype of the saved tensors.
+
+        batch, _, dim = key.size()
+        # One gradient row per batch item; both buffers are summed over the batch below.
+        grad_time_decay = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_time_first = torch.empty(
+            (batch, dim),
+            memory_format=torch.contiguous_format,
+            dtype=time_decay.dtype,
+            device=time_decay.device,
+        )
+
+        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
+        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)
+        # The four grad_* buffers are passed to the kernel as its last arguments.
+        wkv_kernel_decoder.backward(
+            time_decay,
+            time_first,
+            key,
+            value,
+            output,
+            grad_output.contiguous(),
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+        # Reduce the (B, D_att) buffers to the (D_att) shape of the parameters.
+        grad_time_decay = torch.sum(grad_time_decay, dim=0)
+        grad_time_first = torch.sum(grad_time_first, dim=0)
+
+        return (
+            grad_time_decay,
+            grad_time_first,
+            grad_key,
+            grad_value,
+        )
+
+def load_encoder_wkv_kernel(context_size: int) -> None:
+    """Compile the encoder WKV CUDA extension and cache it module-wide.
+
+    Args:
+        context_size: Context size, passed to the compiler as ``Tmax``.
+
+    """
+    from torch.utils.cpp_extension import load
+
+    global wkv_kernel_encoder
+
+    cached = wkv_kernel_encoder
+    if cached is not None and cached.context_size == context_size:
+        return
+
+    if find_spec("ninja") is None:
+        raise ImportError(
+            "Ninja package was not found. WKV kernel module can't be loaded "
+            "for training. Please, 'pip install ninja' in your environment."
+        )
+
+    if not torch.cuda.is_available():
+        raise ImportError(
+            "CUDA is currently a requirement for WKV kernel loading. "
+            "Please set your devices properly and launch again."
+        )
+
+    source_dir = Path(__file__).resolve().parent / "cuda_encoder"
+    cuda_flags = [
+        "-res-usage",
+        "--maxrregcount 60",
+        "--use_fast_math",
+        "-O3",
+        "-Xptxas -O3",
+        f"-DTmax={context_size}",
+    ]
+    module = load(
+        name=f"encoder_wkv_{context_size}",
+        sources=[source_dir / "wkv_op.cpp", source_dir / "wkv_cuda.cu"],
+        verbose=True,
+        extra_cuda_cflags=cuda_flags,
+    )
+    module.context_size = context_size
+    wkv_kernel_encoder = module
+
+def load_decoder_wkv_kernel(context_size: int) -> None:
+    """Compile the decoder WKV CUDA extension and cache it module-wide.
+
+    Args:
+        context_size: Context size, passed to the compiler as ``Tmax``.
+
+    """
+    from torch.utils.cpp_extension import load
+
+    global wkv_kernel_decoder
+
+    cached = wkv_kernel_decoder
+    if cached is not None and cached.context_size == context_size:
+        return
+
+    if find_spec("ninja") is None:
+        raise ImportError(
+            "Ninja package was not found. WKV kernel module can't be loaded "
+            "for training. Please, 'pip install ninja' in your environment."
+        )
+
+    if not torch.cuda.is_available():
+        raise ImportError(
+            "CUDA is currently a requirement for WKV kernel loading. "
+            "Please set your devices properly and launch again."
+        )
+
+    source_dir = Path(__file__).resolve().parent / "cuda_decoder"
+    cuda_flags = [
+        "-res-usage",
+        "--maxrregcount 60",
+        "--use_fast_math",
+        "-O3",
+        "-Xptxas -O3",
+        f"-DTmax={context_size}",
+    ]
+    module = load(
+        name=f"decoder_wkv_{context_size}",
+        sources=[source_dir / "wkv_op.cpp", source_dir / "wkv_cuda.cu"],
+        verbose=True,
+        extra_cuda_cflags=cuda_flags,
+    )
+    module.context_size = context_size
+    wkv_kernel_decoder = module
+
+class SelfAttention(torch.nn.Module):
+    """Base time-mixing module shared by the encoder and decoder variants.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        block_id: Block index.
+        dropout_rate: Dropout rate of the ``dropout`` layer used by subclasses.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct a SelfAttention object."""
+        super().__init__()
+        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
+
+        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
+        self.time_first = torch.nn.Parameter(torch.empty(attention_size))
+
+        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
+        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
+        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
+
+        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
+        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
+        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)
+
+        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)
+
+        self.block_id = block_id
+
+        self.reset_parameters(size, attention_size, block_id, num_blocks)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+    def reset_parameters(
+        self, size: int, attention_size: int, block_id: int, num_blocks: int
+    ) -> None:
+        """Reset module parameters.
+
+        Args:
+            size: Block size.
+            attention_size: Attention hidden size.
+            block_id: Block index.
+            num_blocks: Number of blocks in the architecture.
+
+        """
+        # max(..., 1) avoids a division by zero for one block / one channel.
+        ratio_0_to_1 = block_id / max(num_blocks - 1, 1)
+        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)
+
+        time_weight = torch.ones(1, 1, size)
+        for i in range(size):
+            time_weight[0, 0, i] = i / size
+
+        decay_speed = [
+            -5 + 8 * (h / max(attention_size - 1, 1)) ** (0.7 + 1.3 * ratio_0_to_1)
+            for h in range(attention_size)
+        ]
+        decay_speed = torch.tensor(
+            decay_speed, dtype=self.time_decay.dtype, device=self.time_decay.device
+        )
+
+        zigzag = (
+            torch.tensor(
+                [(i + 1) % 3 - 1 for i in range(attention_size)],
+                dtype=self.time_first.dtype,
+                device=self.time_first.device,
+            )
+            * 0.5
+        )
+
+        with torch.no_grad():
+            self.time_decay.data = decay_speed
+            # Scale and offset go *outside* ones_like, else the result is all ones.
+            self.time_first.data = (
+                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
+            )
+            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
+            self.time_mix_value.data = (
+                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
+            )
+            self.time_mix_receptance.data = torch.pow(
+                time_weight, 0.5 * ratio_1_to_almost0
+            )
+
+    @torch.no_grad()
+    def wkv_linear_attention(
+        self,
+        time_decay: torch.Tensor,
+        time_first: torch.Tensor,
+        key: torch.Tensor,
+        value: torch.Tensor,
+        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
+    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
+        """Compute WKV with state (i.e.: for inference).
+
+        Args:
+            time_decay: Channel-wise time decay vector. (D_att)
+            time_first: Channel-wise time first vector. (D_att)
+            key: Key tensor. (B, 1, D_att)
+            value: Value tensor. (B, 1, D_att)
+            state: Numerator, denominator and running max. [3 x (B, 1, D_att)]
+
+        Returns:
+            output: Weighted Key-Value. (B, 1, D_att)
+            state: Updated numerator, denominator and max. [3 x (B, 1, D_att)]
+
+        """
+        num_state, den_state, max_state = state
+        # Output for the current frame: exponents are shifted by their maximum.
+        max_for_output = torch.maximum(max_state, (time_first + key))
+
+        e1 = torch.exp(max_state - max_for_output)
+        e2 = torch.exp((time_first + key) - max_for_output)
+
+        numerator = e1 * num_state + e2 * value
+        denominator = e1 * den_state + e2
+        # State for the next frame: the history is decayed by time_decay.
+        max_for_state = torch.maximum(key, (max_state + time_decay))
+
+        e1 = torch.exp((max_state + time_decay) - max_for_state)
+        e2 = torch.exp(key - max_for_state)
+
+        wkv = numerator / denominator
+
+        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]
+
+        return wkv, state
+
+
+class DecoderSelfAttention(SelfAttention):
+    """Time-mixing module backed by the decoder WKV kernel.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        context_size: Context size for WKV kernel.
+        block_id: Block index.
+        dropout_rate: Dropout rate applied to the WKV output.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct a DecoderSelfAttention object."""
+        super().__init__(
+            size,
+            attention_size,
+            block_id,
+            dropout_rate,
+            num_blocks,
+        )
+        # Load (or reuse) the decoder kernel for this context size.
+        load_decoder_wkv_kernel(context_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute time mixing.
+
+        Args:
+            x: SelfAttention input sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: The state list given as input, updated in place (or None).
+
+        """
+        streaming = state is not None
+        # Previous frame: cached block input when streaming, shifted input otherwise.
+        previous = state[1][..., self.block_id] if streaming else self.time_shift(x)
+
+        mixed_key = x * self.time_mix_key + previous * (1 - self.time_mix_key)
+        mixed_value = x * self.time_mix_value + previous * (1 - self.time_mix_value)
+        mixed_receptance = x * self.time_mix_receptance + previous * (
+            1 - self.time_mix_receptance
+        )
+
+        key = self.proj_key(mixed_key)
+        value = self.proj_value(mixed_value)
+        receptance = torch.sigmoid(self.proj_receptance(mixed_receptance))
+
+        if not streaming:
+            wkv = WKVLinearAttentionDecoder.apply(
+                self.time_decay, self.time_first, key, value
+            )
+        else:
+            # Overwrite the cache only now: the mixes above still needed the old frame.
+            state[1][..., self.block_id] = x
+
+            block_att_state = tuple(s[..., self.block_id] for s in state[2:])
+            wkv, block_att_state = self.wkv_linear_attention(
+                self.time_decay, self.time_first, key, value, block_att_state
+            )
+            for offset, updated in enumerate(block_att_state, start=2):
+                state[offset][..., self.block_id] = updated
+
+        x = self.proj_output(receptance * self.dropout(wkv))
+
+        return x, state
+
+class EncoderSelfAttention(SelfAttention):
+    """Time-mixing module backed by the encoder WKV kernel.
+
+    Args:
+        size: Input/Output size.
+        attention_size: Attention hidden size.
+        context_size: Context size for WKV kernel.
+        block_id: Block index.
+        dropout_rate: Dropout rate applied to the WKV output.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self,
+        size: int,
+        attention_size: int,
+        context_size: int,
+        block_id: int,
+        dropout_rate: float,
+        num_blocks: int,
+    ) -> None:
+        """Construct an EncoderSelfAttention object."""
+        super().__init__(
+            size,
+            attention_size,
+            block_id,
+            dropout_rate,
+            num_blocks,
+        )
+        # Load (or reuse) the encoder kernel for this context size.
+        load_encoder_wkv_kernel(context_size)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        state: Optional[List[torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute time mixing.
+
+        Args:
+            x: SelfAttention input sequences. (B, U, size)
+            state: Decoder hidden states. [5 x (B, 1, D_att, N)]
+
+        Returns:
+            x: SelfAttention output sequences. (B, U, size)
+            state: The state list given as input, updated in place (or None).
+
+        """
+        streaming = state is not None
+        # Previous frame: cached block input when streaming, shifted input otherwise.
+        previous = state[1][..., self.block_id] if streaming else self.time_shift(x)
+
+        mixed_key = x * self.time_mix_key + previous * (1 - self.time_mix_key)
+        mixed_value = x * self.time_mix_value + previous * (1 - self.time_mix_value)
+        mixed_receptance = x * self.time_mix_receptance + previous * (
+            1 - self.time_mix_receptance
+        )
+
+        key = self.proj_key(mixed_key)
+        value = self.proj_value(mixed_value)
+        receptance = torch.sigmoid(self.proj_receptance(mixed_receptance))
+
+        if not streaming:
+            wkv = WKVLinearAttentionEncoder.apply(
+                self.time_decay, self.time_first, key, value
+            )
+        else:
+            # Overwrite the cache only now: the mixes above still needed the old frame.
+            state[1][..., self.block_id] = x
+
+            block_att_state = tuple(s[..., self.block_id] for s in state[2:])
+            wkv, block_att_state = self.wkv_linear_attention(
+                self.time_decay, self.time_first, key, value, block_att_state
+            )
+            for offset, updated in enumerate(block_att_state, start=2):
+                state[offset][..., self.block_id] = updated
+
+        x = self.proj_output(receptance * self.dropout(wkv))
+
+        return x, state
+
diff --git a/funasr/modules/rwkv_feed_forward.py b/funasr/modules/rwkv_feed_forward.py
new file mode 100644
index 0000000..ddb4285
--- /dev/null
+++ b/funasr/modules/rwkv_feed_forward.py
@@ -0,0 +1,97 @@
+"""Feed-forward (channel mixing) module for RWKV block.
+
+Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py
+
+Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py.
+
+"""  # noqa
+
+from typing import List, Optional, Tuple
+
+import torch
+
+
+class FeedForward(torch.nn.Module):
+    """Channel-mixing (feed-forward) sub-layer of an RWKV block.
+
+    Args:
+        size: Input/Output size.
+        hidden_size: Hidden size.
+        block_id: Block index.
+        dropout_rate: Dropout probability applied to the hidden activation.
+        num_blocks: Number of blocks in the architecture.
+
+    """
+
+    def __init__(
+        self, size: int, hidden_size: int, block_id: int, dropout_rate: float, num_blocks: int
+    ) -> None:
+        """Construct a FeedForward object."""
+        super().__init__()
+        self.block_id = block_id
+
+        # Pads one zero frame in front and drops the last one: a one-step shift in time.
+        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))
+
+        # Placeholders only; reset_parameters() below writes the actual values.
+        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
+        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))
+
+        self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
+        self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
+        self.proj_receptance = torch.nn.Linear(size, size, bias=True)
+
+        self.reset_parameters(size, block_id, num_blocks)
+        self.dropout = torch.nn.Dropout(p=dropout_rate)
+
+    def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
+        """Initialise both time-mix vectors from the block depth.
+
+        Args:
+            size: Block size.
+            block_id: Block index.
+            num_blocks: Number of blocks in the architecture.
+
+        """
+        # Exponent is 1.0 for the first block and shrinks towards 0 for later ones.
+        depth_exponent = 1.0 - (block_id / num_blocks)
+
+        # Channel ramp 0, 1/size, ..., (size - 1)/size, shaped (1, 1, size).
+        channel_ramp = torch.tensor([i / size for i in range(size)]).view(1, 1, size)
+
+        with torch.no_grad():
+            self.time_mix_key.data = torch.pow(channel_ramp, depth_exponent)
+            self.time_mix_receptance.data = torch.pow(channel_ramp, depth_exponent)
+
+    def forward(
+        self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
+    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
+        """Compute channel mixing.
+
+        Args:
+            x: FeedForward input sequences. (B, U, size)
+            state: Decoder hidden state. [5 x (B, 1, size, N)]
+
+        Returns:
+            x: FeedForward output sequences. (B, U, size)
+            state: Decoder hidden state. [5 x (B, 1, size, N)]
+
+        """
+        if state is None:
+            previous = self.time_shift(x)
+        else:
+            previous = state[0][..., self.block_id]
+
+        mix_k, mix_r = self.time_mix_key, self.time_mix_receptance
+        key_in = x * mix_k + previous * (1 - mix_k)
+        gate_in = x * mix_r + previous * (1 - mix_r)
+
+        hidden = torch.square(torch.relu(self.proj_key(key_in)))
+        projected = self.proj_value(self.dropout(hidden))
+        gate = torch.sigmoid(self.proj_receptance(gate_in))
+
+        # Stored last: `previous` is read from this same slot.
+        if state is not None:
+            state[0][..., self.block_id] = x
+
+        return gate * projected, state
diff --git a/funasr/modules/rwkv_subsampling.py b/funasr/modules/rwkv_subsampling.py
new file mode 100644
index 0000000..4277093
--- /dev/null
+++ b/funasr/modules/rwkv_subsampling.py
@@ -0,0 +1,190 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Subsampling layer definition."""
+import numpy as np
+import torch
+import torch.nn.functional as F
+from funasr.modules.embedding import PositionalEncoding
+import logging
+from funasr.modules.streaming_utils.utils import sequence_mask
+from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len
+from typing import Optional, Tuple, Union
+import math
+
+class TooShortUttError(Exception):
+    """Error raised for an utterance with too few frames to be subsampled.
+
+    Args:
+        message (str): Message for error catch
+        actual_size (int): the short size that cannot pass the subsampling
+        limit (int): the limit size for subsampling
+
+    """
+
+    def __init__(self, message, actual_size, limit):
+        """Keep the offending length and the required minimum on the exception."""
+        super().__init__(message)
+        self.limit = limit
+        self.actual_size = actual_size
+
+
+def check_short_utt(ins, size):
+    """Return (True, limit) if ``size`` frames are too few for ``ins``, else (False, -1)."""
+    # NOTE(review): Conv2dSubsampling* are not defined/imported in the visible module; confirm scope.
+    limits = (
+        (Conv2dSubsampling2, 3), (Conv2dSubsampling, 7),
+        (Conv2dSubsampling6, 11), (Conv2dSubsampling8, 15),
+    )
+    for module_type, limit in limits:
+        if isinstance(ins, module_type) and size < limit:
+            return True, limit
+    return False, -1
+
+
+class RWKVConvInput(torch.nn.Module):
+    """Streaming ConvInput module definition.
+
+    Six Conv2d + ReLU layers over the (time, feature) map of the input,
+    optionally followed by a linear projection to ``output_size``.
+
+    The projection input width is derived from the actual convolution output,
+    so it matches ``forward`` for any ``input_size`` and ``subsampling_factor``.
+
+    Args:
+        input_size: Input size.
+        conv_size: Convolution size, unpacked as three channel counts.
+        subsampling_factor: Subsampling factor.
+        conv_kernel_size: Kernel size shared by all convolutions.
+        output_size: Block output dimension.
+    """
+
+    def __init__(
+        self,
+        input_size: int,
+        conv_size: Union[int, Tuple],
+        subsampling_factor: int = 4,
+        conv_kernel_size: int = 3,
+        output_size: Optional[int] = None,
+    ) -> None:
+        """Construct a ConvInput object."""
+        super().__init__()
+
+        conv_size1, conv_size2, conv_size3 = conv_size
+        padding = (conv_kernel_size - 1) // 2
+
+        if subsampling_factor == 1:
+            # Time axis untouched; the feature axis is halved by three convolutions.
+            self.subsampling_factor = 1
+            self.stride_1 = 1
+            strides = (1, [1, 2], 1, [1, 2], 1, [1, 2])
+        else:
+            # Time axis: stride_1, then 2. Feature axis: halved by two convolutions.
+            self.subsampling_factor = subsampling_factor
+            self.stride_1 = int(subsampling_factor / 2)
+            strides = (1, [self.stride_1, 2], 1, [2, 2], 1, 1)
+
+        # Channel plan: 1 -> c1 -> c1 -> c2 -> c2 -> c3 -> c3.
+        channels = (1, conv_size1, conv_size1, conv_size2, conv_size2, conv_size3, conv_size3)
+        layers = []
+        for c_in, c_out, stride in zip(channels[:-1], channels[1:], strides):
+            layers.append(
+                torch.nn.Conv2d(c_in, c_out, conv_kernel_size, stride=stride, padding=padding)
+            )
+            layers.append(torch.nn.ReLU())
+        self.conv = torch.nn.Sequential(*layers)
+
+        # Feature-axis size after the stack, from the Conv2d output-size formula.
+        # The former ``(input_size // 2) // 2`` was wrong for subsampling_factor == 1
+        # (three halvings) and for input sizes that are not a multiple of 4.
+        feat_dim = input_size
+        for layer in self.conv:
+            if isinstance(layer, torch.nn.Conv2d):
+                span = feat_dim + 2 * layer.padding[1] - layer.kernel_size[1]
+                feat_dim = span // layer.stride[1] + 1
+        output_proj = conv_size3 * feat_dim
+
+        self.create_new_mask = self.create_new_vgg_mask
+        self.min_frame_length = 7
+
+        if output_size is not None:
+            self.output = torch.nn.Linear(output_proj, output_size)
+            self.output_size = output_size
+        else:
+            self.output = None
+            self.output_size = output_proj
+
+    def forward(
+        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        """Encode input sequences.
+
+        Args:
+            x: ConvInput input sequences. (B, T, D_feats)
+            mask: Mask of input sequences, sliced along dim 1 as time (B, T), or None.
+            chunk_size: If not None, the input is padded and reshaped into chunks of
+                ``chunk_size * subsampling_factor`` frames before the convolutions.
+
+        Returns:
+            x: ConvInput output sequences. (B, sub(T), D_out)
+            mask: Mask of output sequences (B, sub(T)), or None if ``mask`` was None.
+        """
+        valid_len = None
+        if mask is not None:
+            mask = self.create_new_mask(mask)
+            valid_len = max(mask.eq(0).sum(1))  # largest per-row count of zero entries
+
+        b, t, f = x.size()
+        x = x.unsqueeze(1)  # (b, 1, t, f)
+
+        if chunk_size is not None:
+            chunk_len = chunk_size * self.subsampling_factor
+            # Round the time length up to a whole number of chunks.
+            max_input_length = int(chunk_len * math.ceil(float(t) / chunk_len))
+            # presumably pad_to_len pads dim 1 (time) up to max_input_length -- TODO confirm
+            x = torch.stack([pad_to_len(seq, max_input_length, 1) for seq in x], dim=0)
+            n_chunks = max_input_length // chunk_len
+            x = x.view(b * n_chunks, 1, chunk_len, f)
+
+        x = self.conv(x)
+
+        _, c, _, f = x.size()
+        x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
+        if chunk_size is not None and valid_len is not None:
+            x = x[:, :valid_len, :]  # trim to the length derived from the mask
+
+        if self.output is not None:
+            x = self.output(x)
+
+        if mask is None:
+            # Without a mask there is no length to trim to (this path used to crash).
+            return x, None
+        return x, mask[:, :valid_len][:, : x.size(1)]
+
+    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
+        """Create a new mask for VGG output sequences.
+
+        Args:
+            mask: Mask of input sequences. (B, T)
+
+        Returns:
+            mask: Mask of output sequences. (B, sub(T))
+        """
+        if self.subsampling_factor > 1:
+            # Keep every 2nd frame, then every stride_1-th, like the two strided convs.
+            return mask[:, ::2][:, :: self.stride_1]
+        return mask
+
+    def get_size_before_subsampling(self, size: int) -> int:
+        """Return the original size before subsampling for a given size.
+
+        Args:
+            size: Number of frames after subsampling.
+
+        Returns:
+            : Number of frames before subsampling.
+        """
+        return size * self.subsampling_factor
diff --git a/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java b/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java
index be14bd3..f45877c 100644
--- a/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java
+++ b/funasr/runtime/android/AndroidClient/app/src/main/java/com/yeyupiaoling/androidclient/MainActivity.java
@@ -39,6 +39,8 @@
     public static final String TAG = MainActivity.class.getSimpleName();
     // WebSocket鍦板潃
     public String ASR_HOST = "";
+    // 瀹樻柟WebSocket鍦板潃
+    public static final String DEFAULT_HOST = "wss://101.37.77.25:10088";
     // 鍙戦�佺殑JSON鏁版嵁
     public static final String MODE = "2pass";
     public static final String CHUNK_SIZE = "5, 10, 5";
@@ -61,7 +63,6 @@
     // 鎺т欢
     private Button recordBtn;
     private TextView resultText;
-    private WebSocket webSocket;
 
     @SuppressLint("ClickableViewAccessibility")
     @Override
@@ -106,8 +107,8 @@
             ASR_HOST = uri;
         }
         // 璇诲彇鐑瘝
-        String hotWords = sharedPreferences.getString("hotwords", "");
-        if (!hotWords.equals("")) {
+        String hotWords = sharedPreferences.getString("hotwords", null);
+        if (hotWords != null) {
             this.hotWords = hotWords;
         }
     }
@@ -150,6 +151,14 @@
                 editor.apply();
             }
         });
+        builder.setNeutralButton("浣跨敤瀹樻柟鏈嶅姟", (dialog, id) -> {
+            ASR_HOST = DEFAULT_HOST;
+            input.setText(DEFAULT_HOST);
+            Toast.makeText(MainActivity.this, "WebSocket鍦板潃锛�" + ASR_HOST, Toast.LENGTH_SHORT).show();
+            SharedPreferences.Editor editor = sharedPreferences.edit();
+            editor.putString("uri", ASR_HOST);
+            editor.apply();
+        });
         AlertDialog dialog = builder.create();
         dialog.show();
     }
@@ -166,12 +175,10 @@
         builder.setView(view);
         builder.setPositiveButton("纭畾", (dialog, id) -> {
             String hotwords = input.getText().toString();
-            if (!hotwords.equals("")) {
-                this.hotWords = hotwords;
-                SharedPreferences.Editor editor = sharedPreferences.edit();
-                editor.putString("hotwords", hotwords);
-                editor.apply();
-            }
+            this.hotWords = hotwords;
+            SharedPreferences.Editor editor = sharedPreferences.edit();
+            editor.putString("hotwords", hotwords);
+            editor.apply();
         });
         AlertDialog dialog = builder.create();
         dialog.show();
@@ -225,7 +232,7 @@
         Request request = new Request.Builder()
                 .url(ASR_HOST)
                 .build();
-        webSocket = client.newWebSocket(request, new WebSocketListener() {
+        WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() {
 
             @Override
             public void onOpen(@NonNull WebSocket webSocket, @NonNull Response response) {
@@ -311,7 +318,9 @@
             obj.put("chunk_size", array);
             obj.put("chunk_interval", CHUNK_INTERVAL);
             obj.put("wav_name", "default");
-            obj.put("hotwords", hotWords);
+            if (!hotWords.equals("")) {
+                obj.put("hotwords", hotWords);
+            }
             obj.put("wav_format", "pcm");
             obj.put("is_speaking", isSpeaking);
             return obj.toString();
diff --git a/funasr/runtime/docs/SDK_advanced_guide_offline.md b/funasr/runtime/docs/SDK_advanced_guide_offline.md
index 0348308..43a69cd 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_offline.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_offline.md
@@ -83,7 +83,8 @@
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
  ```
 
 Introduction to run_server.sh parameters: 
@@ -102,6 +103,7 @@
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl锛宻et 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
+--hotword <string>: Hotword file path, one hotword per line. If the client also provides hotwords, they are combined with the hotwords in this file. Default is ../../hotwords.txt.
 ```
 
 ### Shutting Down the FunASR Service
diff --git a/funasr/runtime/docs/SDK_advanced_guide_offline_en.md b/funasr/runtime/docs/SDK_advanced_guide_offline_en.md
index c2599ec..b829e67 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_offline_en.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_offline_en.md
@@ -79,7 +79,7 @@
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key
+  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
  ```
 
 Introduction to run_server.sh parameters: 
diff --git a/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md b/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
index c631097..ee65501 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -165,7 +165,8 @@
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
  ```
 **run_server.sh鍛戒护鍙傛暟浠嬬粛**
 ```text
@@ -182,6 +183,7 @@
 --io-thread-num  鏈嶅姟绔惎鍔ㄧ殑IO绾跨▼鏁帮紝榛樿涓� 1
 --certfile  ssl鐨勮瘉涔︽枃浠讹紝榛樿涓猴細../../../ssl_key/server.crt锛屽鏋滈渶瑕佸叧闂璼sl锛屽弬鏁拌缃负0
 --keyfile   ssl鐨勫瘑閽ユ枃浠讹紝榛樿涓猴細../../../ssl_key/server.key
+--hotword   鐑瘝鏂囦欢璺緞锛屾瘡涓�涓儹璇嶄竴琛岋紝濡傛灉瀹㈡埛绔彁渚涚儹璇嶏紝鍒欎笌瀹㈡埛绔彁渚涚殑鐑瘝鍚堝苟涓�璧蜂娇鐢ㄣ�傞粯璁や负锛�../../hotwords.txt
 ```
 
 ### 鍏抽棴FunASR鏈嶅姟
diff --git a/funasr/runtime/docs/SDK_advanced_guide_online.md b/funasr/runtime/docs/SDK_advanced_guide_online.md
index 17fb891..ddc02cf 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_online.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_online.md
@@ -72,7 +72,8 @@
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 
 # If you want to close ssl锛宲lease add锛�--certfile 0
 # If you want to deploy the timestamp or hotword model, please set --model-dir to the corresponding model:
@@ -97,6 +98,7 @@
 --io-thread-num: Number of IO threads that the server starts. Default is 1.
 --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl锛宻et 0
 --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. 
+--hotword <string>: Hotword file path, one hotword per line. If the client also provides hotwords, they are combined with the hotwords in this file. Default is ../../hotwords.txt.
 ```
 
 ### Shutting Down the FunASR Service
diff --git a/funasr/runtime/docs/SDK_advanced_guide_online_zh.md b/funasr/runtime/docs/SDK_advanced_guide_online_zh.md
index 232701e..902ae7a 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_online_zh.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_online_zh.md
@@ -31,7 +31,8 @@
   --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx  \
   --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx  \
   --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \
-  --itn-dir thuduj12/fst_itn_zh > log.out 2>&1 &
+  --itn-dir thuduj12/fst_itn_zh \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
 
 # 濡傛灉鎮ㄦ兂鍏抽棴ssl锛屽鍔犲弬鏁帮細--certfile 0
 # 濡傛灉鎮ㄦ兂浣跨敤鏃堕棿鎴虫垨鑰呯儹璇嶆ā鍨嬭繘琛岄儴缃诧紝璇疯缃�--model-dir涓哄搴旀ā鍨嬶細
@@ -80,7 +81,8 @@
   --io-thread-num  8 \
   --port 10095 \
   --certfile  ../../../ssl_key/server.crt \
-  --keyfile ../../../ssl_key/server.key > log.out 2>&1 &
+  --keyfile ../../../ssl_key/server.key \
+  --hotword ../../hotwords.txt > log.out 2>&1 &
  ```
 **run_server_2pass.sh鍛戒护鍙傛暟浠嬬粛**
 ```text
@@ -98,6 +100,7 @@
 --io-thread-num  鏈嶅姟绔惎鍔ㄧ殑IO绾跨▼鏁帮紝榛樿涓� 1
 --certfile  ssl鐨勮瘉涔︽枃浠讹紝榛樿涓猴細../../../ssl_key/server.crt锛屽鏋滈渶瑕佸叧闂璼sl锛屽弬鏁拌缃负0
 --keyfile   ssl鐨勫瘑閽ユ枃浠讹紝榛樿涓猴細../../../ssl_key/server.key
+--hotword   鐑瘝鏂囦欢璺緞锛屾瘡涓�涓儹璇嶄竴琛岋紝濡傛灉瀹㈡埛绔彁渚涚儹璇嶏紝鍒欎笌瀹㈡埛绔彁渚涚殑鐑瘝鍚堝苟涓�璧蜂娇鐢ㄣ�傞粯璁や负锛�../../hotwords.txt
 ```
 
 ### 鍏抽棴FunASR鏈嶅姟
diff --git a/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py b/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
index 3a01812..7d0060c 100644
--- a/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
+++ b/funasr/runtime/python/libtorch/funasr_torch/utils/timestamp_utils.py
@@ -3,7 +3,7 @@
 
 def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0, total_offset=-1.5):
     if not len(char_list):
-        return []
+        return '', []
     START_END_THRESHOLD = 5
     MAX_TOKEN_DURATION = 30
     TIME_RATE = 10.0 * 6 / 1000 / 3  #  3 times upsampled
diff --git a/funasr/runtime/run_server.sh b/funasr/runtime/run_server.sh
index 6869fd9..f75f159 100644
--- a/funasr/runtime/run_server.sh
+++ b/funasr/runtime/run_server.sh
@@ -9,6 +9,7 @@
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotwordsfile="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -24,7 +25,8 @@
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  "" \
-  --keyfile ""
+  --keyfile "" \
+  --hotwordsfile ${hotwordsfile}
 else
 ./funasr-wss-server  \
   --download-model-dir ${download_model_dir} \
@@ -36,5 +38,6 @@
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  ${certfile} \
-  --keyfile ${keyfile}
+  --keyfile ${keyfile} \
+  --hotwordsfile ${hotwordsfile}
 fi
diff --git a/funasr/runtime/run_server_2pass.sh b/funasr/runtime/run_server_2pass.sh
index 63c2041..941064c 100644
--- a/funasr/runtime/run_server_2pass.sh
+++ b/funasr/runtime/run_server_2pass.sh
@@ -10,6 +10,7 @@
 port=10095
 certfile="../../../ssl_key/server.crt"
 keyfile="../../../ssl_key/server.key"
+hotwordsfile="../../hotwords.txt"
 
 . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1;
 
@@ -26,7 +27,8 @@
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  "" \
-  --keyfile ""
+  --keyfile "" \
+  --hotwordsfile ${hotwordsfile}
 else
 ./funasr-wss-server-2pass  \
   --download-model-dir ${download_model_dir} \
@@ -39,5 +41,6 @@
   --io-thread-num  ${io_thread_num} \
   --port ${port} \
   --certfile  ${certfile} \
-  --keyfile ${keyfile}
+  --keyfile ${keyfile} \
+  --hotwordsfile ${hotwordsfile}
 fi
diff --git a/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp b/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp
index 1f8b632..1c87957 100644
--- a/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp
+++ b/funasr/runtime/websocket/bin/funasr-wss-server-2pass.cpp
@@ -14,6 +14,9 @@
 #include <unistd.h>
 #include "websocket-server-2pass.h"
 
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
@@ -108,6 +111,15 @@
         "default: ../../../ssl_key/server.key, path of keyfile for WSS "
         "connection",
         false, "../../../ssl_key/server.key", "string");
+
+    TCLAP::ValueArg<std::string> hotwordsfile(
+        "", "hotword",
+        "default: ../../hotwords.txt, path of hotwordsfile"
+        "connection",
+        false, "../../hotwords.txt", "string");
+
+    // add file
+    cmd.add(hotwordsfile);
 
     cmd.add(certfile);
     cmd.add(keyfile);
@@ -417,6 +429,21 @@
     std::string s_certfile = certfile.getValue();
     std::string s_keyfile = keyfile.getValue();
 
+    std::string s_hotwordsfile = hotwordsfile.getValue();
+    std::string line;
+    std::ifstream file(s_hotwordsfile);
+    LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+    if (file.is_open()) {
+        while (getline(file, line)) {
+            hotwords += line+HOTWORD_SEP;
+        }
+        LOG(INFO) << "hotwords: " << hotwords;
+        file.close();
+    } else {
+        LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+    }
+
     bool is_ssl = false;
     if (!s_certfile.empty()) {
       is_ssl = true;
@@ -460,8 +487,7 @@
       websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
     }
 
-    std::cout << "asr model init finished. listen on port:" << s_port
-              << std::endl;
+    LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
     // Start the ASIO network io_service run loop
     std::vector<std::thread> ts;
@@ -480,7 +506,7 @@
     }
 
   } catch (std::exception const& e) {
-    std::cerr << "Error: " << e.what() << std::endl;
+    LOG(ERROR) << "Error: " << e.what();
   }
 
   return 0;
diff --git a/funasr/runtime/websocket/bin/funasr-wss-server.cpp b/funasr/runtime/websocket/bin/funasr-wss-server.cpp
index 55ce07b..b571dbe 100644
--- a/funasr/runtime/websocket/bin/funasr-wss-server.cpp
+++ b/funasr/runtime/websocket/bin/funasr-wss-server.cpp
@@ -13,6 +13,9 @@
 #include "websocket-server.h"
 #include <unistd.h>
 
+#include <fstream>
+std::string hotwords = "";
+
 using namespace std;
 void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
               std::map<std::string, std::string>& model_path) {
@@ -94,6 +97,15 @@
     TCLAP::ValueArg<std::string> keyfile("", "keyfile", 
         "default: ../../../ssl_key/server.key, path of keyfile for WSS connection", 
         false, "../../../ssl_key/server.key", "string");
+
+    TCLAP::ValueArg<std::string> hotwordsfile(
+        "", "hotword",
+        "default: ../../hotwords.txt, path of hotwordsfile"
+        "connection",
+        false, "../../hotwords.txt", "string");
+
+    // add file
+    cmd.add(hotwordsfile);
 
     cmd.add(certfile);
     cmd.add(keyfile);
@@ -331,6 +343,21 @@
     std::string s_certfile = certfile.getValue();
     std::string s_keyfile = keyfile.getValue();
 
+    std::string s_hotwordsfile = hotwordsfile.getValue();
+    std::string line;
+    std::ifstream file(s_hotwordsfile);
+    LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile;
+
+    if (file.is_open()) {
+        while (getline(file, line)) {
+            hotwords += line+HOTWORD_SEP;
+        }
+        LOG(INFO) << "hotwords: " << hotwords;
+        file.close();
+    } else {
+        LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile;
+    }
+
     bool is_ssl = false;
     if (!s_certfile.empty()) {
       is_ssl = true;
@@ -374,8 +401,7 @@
       websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
     }
 
-    std::cout << "asr model init finished. listen on port:" << s_port
-              << std::endl;
+    LOG(INFO) << "asr model init finished. listen on port:" << s_port;
 
     // Start the ASIO network io_service run loop
     std::vector<std::thread> ts;
@@ -394,7 +420,7 @@
     }
 
   } catch (std::exception const& e) {
-    std::cerr << "Error: " << e.what() << std::endl;
+    LOG(ERROR) << "Error: " << e.what();
   }
 
   return 0;
diff --git a/funasr/runtime/websocket/bin/websocket-server-2pass.cpp b/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
index 107be40..9e0668f 100644
--- a/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
+++ b/funasr/runtime/websocket/bin/websocket-server-2pass.cpp
@@ -15,7 +15,9 @@
 #include <thread>
 #include <utility>
 #include <vector>
-#include <chrono>
+
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -354,7 +356,14 @@
   unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
   switch (msg->get_opcode()) {
     case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try{
+        jsonresult = nlohmann::json::parse(payload);
+      }catch (std::exception const &e)
+      {
+        LOG(ERROR)<<e.what();
+        break;
+      }
 
       if (jsonresult.contains("wav_name")) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
@@ -370,17 +379,26 @@
           msg_data->msg["hotwords"] = jsonresult["hotwords"];
           if (!msg_data->msg["hotwords"].empty()) {
             std::string hw = msg_data->msg["hotwords"];
-            LOG(INFO)<<"hotwords: " << hw;
-            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+            hw = hw + " " + hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
           }
-        }else{
+        } else {
+          if (hotwords.empty()) {
             std::string hw = "";
             LOG(INFO)<<"hotwords: " << hw;
             std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }else {
+            std::string hw = hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS);
+            msg_data->hotwords_embedding =
+                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }
         }
       }
       if (jsonresult.contains("audio_fs")) {
diff --git a/funasr/runtime/websocket/bin/websocket-server.cpp b/funasr/runtime/websocket/bin/websocket-server.cpp
index da1ffa5..134f5fa 100644
--- a/funasr/runtime/websocket/bin/websocket-server.cpp
+++ b/funasr/runtime/websocket/bin/websocket-server.cpp
@@ -16,6 +16,8 @@
 #include <utility>
 #include <vector>
 
+extern std::string hotwords;
+
 context_ptr WebSocketServer::on_tls_init(tls_mode mode,
                                          websocketpp::connection_hdl hdl,
                                          std::string& s_certfile,
@@ -254,7 +256,15 @@
   unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
   switch (msg->get_opcode()) {
     case websocketpp::frame::opcode::text: {
-      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+      nlohmann::json jsonresult;
+      try{
+        jsonresult = nlohmann::json::parse(payload);
+      }catch (std::exception const &e)
+      {
+        LOG(ERROR)<<e.what();
+        break;
+      }
+
       if (jsonresult["wav_name"] != nullptr) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
       }
@@ -266,17 +276,26 @@
           msg_data->msg["hotwords"] = jsonresult["hotwords"];
           if (!msg_data->msg["hotwords"].empty()) {
             std::string hw = msg_data->msg["hotwords"];
-            LOG(INFO)<<"hotwords: " << hw;
-            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+            hw = hw + " " + hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
           }
-        }else{
+        } else {
+          if (hotwords.empty()) {
             std::string hw = "";
             LOG(INFO)<<"hotwords: " << hw;
             std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
             msg_data->hotwords_embedding =
                 std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }else {
+            std::string hw = hotwords;
+            LOG(INFO) << "hotwords: " << hw;
+            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
+            msg_data->hotwords_embedding =
+                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
+          }
         }
       }
       if (jsonresult.contains("audio_fs")) {
diff --git a/funasr/runtime/websocket/hotwords.txt b/funasr/runtime/websocket/hotwords.txt
new file mode 100644
index 0000000..6179cbc
--- /dev/null
+++ b/funasr/runtime/websocket/hotwords.txt
@@ -0,0 +1,2 @@
+阿里巴巴
+通义实验室
diff --git a/funasr/utils/whisper_utils/__init__.py b/funasr/utils/whisper_utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/utils/whisper_utils/__init__.py

--
Gitblit v1.9.1