Merge branch 'main' of github.com:alibaba-damo-academy/FunASR
add
| | |
| | | from funasr.models.encoder.branchformer_encoder import BranchformerEncoder |
| | | from funasr.models.encoder.e_branchformer_encoder import EBranchformerEncoder |
| | | from funasr.models.encoder.transformer_encoder import TransformerEncoder |
| | | from funasr.models.encoder.rwkv_encoder import RWKVEncoder |
| | | from funasr.models.frontend.default import DefaultFrontend |
| | | from funasr.models.frontend.default import MultiChannelFrontend |
| | | from funasr.models.frontend.fused import FusedFrontends |
| | |
| | | e_branchformer=EBranchformerEncoder, |
| | | mfcca_enc=MFCCAEncoder, |
| | | chunk_conformer=ConformerChunkEncoder, |
| | | rwkv=RWKVEncoder, |
| | | ), |
| | | default="rnn", |
| | | ) |
| New file |
| | |
| | | """RWKV encoder definition for Transducer models.""" |
| | | |
| | | import math |
| | | from typing import Dict, List, Optional, Tuple |
| | | |
| | | import torch |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.modules.rwkv import RWKV |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.rwkv_subsampling import RWKVConvInput |
| | | from funasr.modules.nets_utils import make_source_mask |
| | | |
class RWKVEncoder(AbsEncoder):
    """RWKV encoder module.

    Based on https://arxiv.org/pdf/2305.13048.pdf.

    Args:
        input_size: Input feature size handed to the convolutional front-end.
        output_size: Input/Output size of the RWKV blocks.
        context_size: Context size for WKV computation.
        linear_size: FeedForward hidden size. Defaults to ``4 * output_size``.
        attention_size: SelfAttention hidden size. Defaults to ``output_size``.
        num_blocks: Number of RWKV blocks.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate forwarded to every RWKV block.
        subsampling_factor: Subsampling factor handed to the front-end.
        time_reduction_factor: Stride used to thin out the output frames.
        kernel: Convolution kernel size handed to the front-end.

    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 512,
        context_size: int = 1024,
        linear_size: Optional[int] = None,
        attention_size: Optional[int] = None,
        num_blocks: int = 4,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
        subsampling_factor: int = 4,
        time_reduction_factor: int = 1,
        kernel: int = 3,
    ) -> None:
        """Construct a RWKVEncoder object."""
        super().__init__()

        self.embed = RWKVConvInput(
            input_size,
            [output_size // 4, output_size // 2, output_size],
            subsampling_factor,
            conv_kernel_size=kernel,
            output_size=output_size,
        )

        self.subsampling_factor = subsampling_factor

        linear_size = output_size * 4 if linear_size is None else linear_size
        attention_size = output_size if attention_size is None else attention_size

        self.rwkv_blocks = torch.nn.ModuleList(
            [
                RWKV(
                    output_size,
                    linear_size,
                    attention_size,
                    context_size,
                    block_id,
                    num_blocks,
                    att_dropout_rate=att_dropout_rate,
                    ffn_dropout_rate=ffn_dropout_rate,
                    dropout_rate=dropout_rate,
                )
                for block_id in range(num_blocks)
            ]
        )

        self.embed_norm = LayerNorm(output_size)
        self.final_norm = LayerNorm(output_size)

        self._output_size = output_size
        # Kept so that rwkv_infer can size the attention states correctly.
        self.attention_size = attention_size
        self.context_size = context_size

        self.num_blocks = num_blocks
        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output size."""
        return self._output_size

    def forward(
        self, x: torch.Tensor, x_len
    ) -> Tuple[torch.Tensor, torch.Tensor, None]:
        """Encode input feature sequences.

        Args:
            x: Encoder input sequences (rank-3 tensor; the second axis is time).
            x_len: Sequence lengths used to build the padding mask.
                (presumably a (B,) tensor -- confirm against make_source_mask)

        Returns:
            x: Encoder output sequences. (B, U, D)
            olens: Number of unmasked output frames per sequence.
            None: Placeholder third return value.

        """
        _, length, _ = x.size()

        assert (
            length <= self.context_size * self.subsampling_factor
        ), "Context size is too short for current length: %d versus %d" % (
            length,
            self.context_size * self.subsampling_factor,
        )
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        x = self.embed_norm(x)
        # Output length = number of zero (unmasked) entries of the new mask.
        olens = mask.eq(0).sum(1)

        for block in self.rwkv_blocks:
            x, _ = block(x)
        # for streaming inference
        # xs_pad = self.rwkv_infer(xs_pad)

        x = self.final_norm(x)

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def rwkv_infer(self, xs_pad: torch.Tensor) -> torch.Tensor:
        """Run the RWKV blocks frame by frame with an explicit recurrent state.

        Args:
            xs_pad: Block input sequences. (B, T, D)

        Returns:
            xs_out: Block output sequences. (B, T, D)

        """
        batch_size = xs_pad.shape[0]

        # State layout read/written by the blocks: [0] feed-forward shift and
        # [1] attention shift hold block inputs (output size); [2:5] hold the
        # WKV numerator / denominator / running maximum (attention size).
        hidden_sizes = [self._output_size] * 2 + [self.attention_size] * 3

        state = [
            torch.zeros(
                (batch_size, 1, hidden_size, self.num_blocks),
                dtype=torch.float32,
                device=xs_pad.device,
            )
            for hidden_size in hidden_sizes
        ]

        # NOTE(review): other RWKV implementations start the running maximum
        # at a large negative value (-1e30); -1e-30 is effectively zero.
        # Value kept as found -- confirm which one is intended.
        state[4] -= 1e-30

        xs_out = []
        for t in range(xs_pad.shape[1]):
            # Keep the time axis so the frame matches the (B, 1, D) state slots.
            x_t = xs_pad[:, t : t + 1, :]
            for block in self.rwkv_blocks:
                x_t, state = block(x_t, state=state)
            xs_out.append(x_t)

        return torch.cat(xs_out, dim=1)
| New file |
| | |
| | | """Receptance Weighted Key Value (RWKV) block definition. |
| | | |
| | | Based/modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py |
| | | |
| | | """ |
| | | |
| | | from typing import Dict, Optional, Tuple |
| | | |
| | | import torch |
| | | |
| | | from funasr.modules.rwkv_attention import EncoderSelfAttention, DecoderSelfAttention |
| | | from funasr.modules.rwkv_feed_forward import FeedForward |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | |
class RWKV(torch.nn.Module):
    """Encoder RWKV block: time mixing followed by channel mixing.

    Both sub-layers are fed a layer-normalised copy of the running tensor and
    their (dropped-out) output is added back to it.

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to both residual branches.

    """

    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a RWKV object."""
        super().__init__()

        # One normalisation layer per residual branch.
        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        # Time-mixing branch.
        self.att = EncoderSelfAttention(
            size=size,
            attention_size=attention_size,
            context_size=context_size,
            block_id=block_id,
            dropout_rate=att_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_att = torch.nn.Dropout(dropout_rate)

        # Channel-mixing branch.
        self.ffn = FeedForward(
            size=size,
            hidden_size=linear_size,
            block_id=block_id,
            dropout_rate=ffn_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_ffn = torch.nn.Dropout(dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply time mixing then channel mixing, each with a residual link.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Hidden states, or None. [5 x (B, D_att/size, N)]

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Hidden states as returned by the sub-modules.

        """
        normed = self.layer_norm_att(x)
        time_mixed, state = self.att(normed, state=state)
        hidden = x + self.dropout_att(time_mixed)

        normed = self.layer_norm_ffn(hidden)
        channel_mixed, state = self.ffn(normed, state=state)

        return hidden + self.dropout_ffn(channel_mixed), state
| | | |
class RWKVDecoderLayer(torch.nn.Module):
    """Decoder RWKV block: time mixing followed by channel mixing.

    Same layout as the encoder block, but the time-mixing sub-layer is a
    ``DecoderSelfAttention``.

    Args:
        size: Input/Output size.
        linear_size: Feed-forward hidden size.
        attention_size: SelfAttention hidden size.
        context_size: Context size for WKV computation.
        block_id: Block index.
        num_blocks: Number of blocks in the architecture.
        att_dropout_rate: Dropout rate for the attention module.
        ffn_dropout_rate: Dropout rate for the feed-forward module.
        dropout_rate: Dropout rate applied to both residual branches.

    """

    def __init__(
        self,
        size: int,
        linear_size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        num_blocks: int,
        att_dropout_rate: float = 0.0,
        ffn_dropout_rate: float = 0.0,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a RWKVDecoderLayer object."""
        super().__init__()

        # One normalisation layer per residual branch.
        self.layer_norm_att = LayerNorm(size)
        self.layer_norm_ffn = LayerNorm(size)

        # Time-mixing branch.
        self.att = DecoderSelfAttention(
            size=size,
            attention_size=attention_size,
            context_size=context_size,
            block_id=block_id,
            dropout_rate=att_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_att = torch.nn.Dropout(dropout_rate)

        # Channel-mixing branch.
        self.ffn = FeedForward(
            size=size,
            hidden_size=linear_size,
            block_id=block_id,
            dropout_rate=ffn_dropout_rate,
            num_blocks=num_blocks,
        )
        self.dropout_ffn = torch.nn.Dropout(dropout_rate)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply time mixing then channel mixing, each with a residual link.

        Args:
            x: RWKV input sequences. (B, L, size)
            state: Decoder hidden states, or None. [5 x (B, D_att/size, N)]

        Returns:
            x: RWKV output sequences. (B, L, size)
            state: Decoder hidden states as returned by the sub-modules.

        """
        normed = self.layer_norm_att(x)
        time_mixed, state = self.att(normed, state=state)
        hidden = x + self.dropout_att(time_mixed)

        normed = self.layer_norm_ffn(hidden)
        channel_mixed, state = self.ffn(normed, state=state)

        return hidden + self.dropout_ffn(channel_mixed), state
| New file |
| | |
| | | """Attention (time mixing) modules for RWKV block. |
| | | |
| | | Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py. |
| | | |
| | | Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py. |
| | | |
| | | """ # noqa |
| | | |
| | | import math |
| | | from importlib.util import find_spec |
| | | from pathlib import Path |
| | | from typing import List, Optional, Tuple, Union |
| | | |
| | | import torch |
| | | |
| | | wkv_kernel_encoder = None |
| | | wkv_kernel_decoder = None |
| | | |
class WKVLinearAttentionEncoder(torch.autograd.Function):
    """Autograd function wrapping the encoder WKV kernel.

    Both passes delegate to the module-level ``wkv_kernel_encoder`` object,
    which must have been loaded before this function is applied.
    """

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        """
        batch, length, dim = key.size()

        assert length <= wkv_kernel_encoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_encoder.context_size}). Limit should be increased."
        )

        assert (batch * dim) % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        # Remember the caller's dtype before everything is promoted to float32.
        ctx.input_dtype = key.dtype

        # The kernel receives the decay as -exp(time_decay).
        decay = -torch.exp(time_decay.float().contiguous())
        first = time_first.float().contiguous()
        key_f32 = key.float().contiguous()
        value_f32 = value.float().contiguous()

        # Output buffer, filled in place by the kernel.
        out = torch.empty_like(key_f32, memory_format=torch.contiguous_format)
        wkv_kernel_encoder.forward(decay, first, key_f32, value_f32, out)

        ctx.save_for_backward(decay, first, key_f32, value_f32, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        batch, _, dim = key.size()

        def _per_sample_buffer() -> torch.Tensor:
            # (B, D_att) scratch buffer; reduced over the batch axis below.
            return torch.empty(
                (batch, dim),
                memory_format=torch.contiguous_format,
                dtype=time_decay.dtype,
                device=time_decay.device,
            )

        decay_grads = _per_sample_buffer()
        first_grads = _per_sample_buffer()
        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        # The kernel writes into the four gradient buffers in place.
        wkv_kernel_encoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            decay_grads,
            first_grads,
            grad_key,
            grad_value,
        )

        return (
            torch.sum(decay_grads, dim=0),
            torch.sum(first_grads, dim=0),
            grad_key,
            grad_value,
        )
| | | |
class WKVLinearAttentionDecoder(torch.autograd.Function):
    """Autograd function wrapping the decoder WKV kernel.

    Both passes delegate to the module-level ``wkv_kernel_decoder`` object,
    which must have been loaded before this function is applied.
    """

    @staticmethod
    def forward(
        ctx,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """WKVLinearAttention function forward pass.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, U, D_att)
            value: Value tensor. (B, U, D_att)

        Returns:
            out: Weighted Key-Value tensor. (B, U, D_att)

        Raises:
            AssertionError: If the sequence is longer than the kernel context
                size, or if ``batch * dim`` is not a multiple of ``min(dim, 32)``.

        """
        batch, length, dim = key.size()

        # The message must reference the decoder kernel: the previous name
        # (`wkv_kernel`) does not exist and turned this into a NameError.
        assert length <= wkv_kernel_decoder.context_size, (
            f"Cannot process key of length {length} while context_size "
            f"is ({wkv_kernel_decoder.context_size}). Limit should be increased."
        )

        assert batch * dim % min(dim, 32) == 0, (
            f"batch size ({batch}) by dimension ({dim}) should be a multiple of "
            f"{min(dim, 32)}"
        )

        # Remember the caller's dtype before everything is promoted to float32.
        ctx.input_dtype = key.dtype

        # The kernel receives the decay as -exp(time_decay).
        time_decay = -torch.exp(time_decay.float().contiguous())
        time_first = time_first.float().contiguous()

        key = key.float().contiguous()
        value = value.float().contiguous()

        # Output buffer, filled in place by the kernel.
        out = torch.empty_like(key, memory_format=torch.contiguous_format)

        wkv_kernel_decoder.forward(time_decay, time_first, key, value, out)
        ctx.save_for_backward(time_decay, time_first, key, value, out)

        return out

    @staticmethod
    def backward(
        ctx, grad_output: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """WKVLinearAttention function backward pass.

        Args:
            grad_output: Output gradient. (B, U, D_att)

        Returns:
            grad_time_decay: Gradient for channel-wise time decay vector. (D_att)
            grad_time_first: Gradient for channel-wise time first vector. (D_att)
            grad_key: Gradient for key tensor. (B, U, D_att)
            grad_value: Gradient for value tensor. (B, U, D_att)

        """
        time_decay, time_first, key, value, output = ctx.saved_tensors

        batch, _, dim = key.size()

        # Per-sample (B, D_att) buffers; reduced over the batch axis below.
        grad_time_decay = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_time_first = torch.empty(
            (batch, dim),
            memory_format=torch.contiguous_format,
            dtype=time_decay.dtype,
            device=time_decay.device,
        )

        grad_key = torch.empty_like(key, memory_format=torch.contiguous_format)
        grad_value = torch.empty_like(value, memory_format=torch.contiguous_format)

        # The kernel writes into the four gradient buffers in place.
        wkv_kernel_decoder.backward(
            time_decay,
            time_first,
            key,
            value,
            output,
            grad_output.contiguous(),
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )

        grad_time_decay = torch.sum(grad_time_decay, dim=0)
        grad_time_first = torch.sum(grad_time_first, dim=0)

        return (
            grad_time_decay,
            grad_time_first,
            grad_key,
            grad_value,
        )
| | | |
def load_encoder_wkv_kernel(context_size: int) -> None:
    """Build and load the encoder WKV CUDA kernel for a given context size.

    The loaded extension is stored in the module-level ``wkv_kernel_encoder``
    and tagged with ``context_size``; a second call with the same context
    size returns immediately.

    Args:
        context_size: Context size, passed to the compiler as ``Tmax``.

    Raises:
        ImportError: If the ``ninja`` package or CUDA is not available.

    """
    from torch.utils import cpp_extension

    global wkv_kernel_encoder

    already_loaded = (
        wkv_kernel_encoder is not None
        and wkv_kernel_encoder.context_size == context_size
    )
    if already_loaded:
        return

    ninja_missing = find_spec("ninja") is None
    if ninja_missing:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    # Kernel sources live in a folder next to this file.
    source_dir = Path(__file__).resolve().parent / "cuda_encoder"
    sources = [source_dir / "wkv_op.cpp", source_dir / "wkv_cuda.cu"]

    cuda_flags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    wkv_kernel_encoder = cpp_extension.load(
        name=f"encoder_wkv_{context_size}",
        sources=sources,
        verbose=True,
        extra_cuda_cflags=cuda_flags,
    )
    wkv_kernel_encoder.context_size = context_size
| | | |
def load_decoder_wkv_kernel(context_size: int) -> None:
    """Build and load the decoder WKV CUDA kernel for a given context size.

    The loaded extension is stored in the module-level ``wkv_kernel_decoder``
    and tagged with ``context_size``; a second call with the same context
    size returns immediately.

    Args:
        context_size: Context size, passed to the compiler as ``Tmax``.

    Raises:
        ImportError: If the ``ninja`` package or CUDA is not available.

    """
    from torch.utils import cpp_extension

    global wkv_kernel_decoder

    already_loaded = (
        wkv_kernel_decoder is not None
        and wkv_kernel_decoder.context_size == context_size
    )
    if already_loaded:
        return

    ninja_missing = find_spec("ninja") is None
    if ninja_missing:
        raise ImportError(
            "Ninja package was not found. WKV kernel module can't be loaded "
            "for training. Please, 'pip install ninja' in your environment."
        )

    if not torch.cuda.is_available():
        raise ImportError(
            "CUDA is currently a requirement for WKV kernel loading. "
            "Please set your devices properly and launch again."
        )

    # Kernel sources live in a folder next to this file.
    source_dir = Path(__file__).resolve().parent / "cuda_decoder"
    sources = [source_dir / "wkv_op.cpp", source_dir / "wkv_cuda.cu"]

    cuda_flags = [
        "-res-usage",
        "--maxrregcount 60",
        "--use_fast_math",
        "-O3",
        "-Xptxas -O3",
        f"-DTmax={context_size}",
    ]

    wkv_kernel_decoder = cpp_extension.load(
        name=f"decoder_wkv_{context_size}",
        sources=sources,
        verbose=True,
        extra_cuda_cflags=cuda_flags,
    )
    wkv_kernel_decoder.context_size = context_size
| | | |
class SelfAttention(torch.nn.Module):
    """Base time-mixing module shared by the encoder and decoder variants.

    Holds the parameters, their initialisation and the stateful
    (step-by-step) WKV computation; subclasses provide ``forward``.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        block_id: Block index.
        dropout_rate: Dropout probability of the ``dropout`` layer.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a SelfAttention object."""
        super().__init__()
        # Pads one row at the top and crops one at the bottom of the last two
        # axes, i.e. shifts a (B, U, size) input by one step along U.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_decay = torch.nn.Parameter(torch.empty(attention_size))
        self.time_first = torch.nn.Parameter(torch.empty(attention_size))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_value = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_value = torch.nn.Linear(size, attention_size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, attention_size, bias=True)

        self.proj_output = torch.nn.Linear(attention_size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, attention_size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(
        self, size: int, attention_size: int, block_id: int, num_blocks: int
    ) -> None:
        """Reset module parameters.

        Args:
            size: Block size.
            attention_size: Attention hidden size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        # max(..., 1) keeps single-block / single-channel configurations from
        # dividing by zero; the ratio is then simply 0.
        ratio_0_to_1 = block_id / max(num_blocks - 1, 1)
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        # Channel-wise ramp 0, 1/size, ..., (size - 1)/size.
        time_weight = torch.tensor([i / size for i in range(size)]).view(1, 1, size)

        decay_exponent = 0.7 + 1.3 * ratio_0_to_1
        decay_speed = torch.tensor(
            [
                -5 + 8 * (h / max(attention_size - 1, 1)) ** decay_exponent
                for h in range(attention_size)
            ],
            dtype=self.time_decay.dtype,
            device=self.time_decay.device,
        )

        # Repeating -0.5 / 0 / 0.5 offset pattern over the attention channels.
        zigzag = (
            torch.tensor(
                [(i + 1) % 3 - 1 for i in range(attention_size)],
                dtype=self.time_first.dtype,
                device=self.time_first.device,
            )
            * 0.5
        )

        with torch.no_grad():
            self.time_decay.data = decay_speed
            # ones * log(0.3) + zigzag. The previous code wrapped the whole
            # expression in ones_like, which dropped both terms and left
            # `zigzag` unused.
            self.time_first.data = (
                torch.ones_like(self.time_first) * math.log(0.3) + zigzag
            )

            self.time_mix_key.data = torch.pow(time_weight, ratio_1_to_almost0)
            self.time_mix_value.data = (
                torch.pow(time_weight, ratio_1_to_almost0) + 0.3 * ratio_0_to_1
            )
            self.time_mix_receptance.data = torch.pow(
                time_weight, 0.5 * ratio_1_to_almost0
            )

    @torch.no_grad()
    def wkv_linear_attention(
        self,
        time_decay: torch.Tensor,
        time_first: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        state: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
    ) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """Compute WKV with state (i.e.: for inference).

        Each exponential is taken relative to a running maximum so that its
        argument is never positive.

        Args:
            time_decay: Channel-wise time decay vector. (D_att)
            time_first: Channel-wise time first vector. (D_att)
            key: Key tensor. (B, 1, D_att)
            value: Value tensor. (B, 1, D_att)
            state: Numerator, denominator and running-maximum states.
                [3 x (B, 1, D_att)]

        Returns:
            output: Weighted Key-Value. (B, 1, D_att)
            state: Updated [numerator, denominator, maximum] states, as a list.
                [3 x (B, 1, D_att)]

        """
        num_state, den_state, max_state = state

        # Output for the current step: past state plus the current key/value,
        # the latter weighted with `time_first`.
        max_for_output = torch.maximum(max_state, (time_first + key))

        e1 = torch.exp(max_state - max_for_output)
        e2 = torch.exp((time_first + key) - max_for_output)

        numerator = e1 * num_state + e2 * value
        denominator = e1 * den_state + e2

        # State for the next step: shift the past by `time_decay` and add the
        # current key/value.
        max_for_state = torch.maximum(key, (max_state + time_decay))

        e1 = torch.exp((max_state + time_decay) - max_for_state)
        e2 = torch.exp(key - max_for_state)

        wkv = numerator / denominator

        state = [e1 * num_state + e2 * value, e1 * den_state + e2, max_for_state]

        return wkv, state
| | | |
| | | |
class DecoderSelfAttention(SelfAttention):
    """Time-mixing module backed by the decoder WKV kernel.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout probability applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct a DecoderSelfAttention object and load its WKV kernel."""
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
        load_decoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Without a state the whole sequence goes through the WKV kernel; with a
        state the stateful WKV update is used and this block's slots of the
        state are overwritten in place.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Decoder hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: The state that was passed in (updated in place), or None.

        """
        stateful = state is not None

        if stateful:
            previous = state[1][..., self.block_id]
        else:
            previous = self.time_shift(x)

        def _blend(weight: torch.Tensor) -> torch.Tensor:
            # Interpolate between the current and the previous input.
            return x * weight + previous * (1 - weight)

        key = self.proj_key(_blend(self.time_mix_key))
        value = self.proj_value(_blend(self.time_mix_value))
        receptance = torch.sigmoid(
            self.proj_receptance(_blend(self.time_mix_receptance))
        )

        if not stateful:
            wkv = WKVLinearAttentionDecoder.apply(
                self.time_decay, self.time_first, key, value
            )
        else:
            # The current input becomes the "previous" input of the next step.
            state[1][..., self.block_id] = x

            block_state = tuple(s[..., self.block_id] for s in state[2:])
            wkv, updated = self.wkv_linear_attention(
                self.time_decay, self.time_first, key, value, block_state
            )

            state[2][..., self.block_id] = updated[0]
            state[3][..., self.block_id] = updated[1]
            state[4][..., self.block_id] = updated[2]

        gated = receptance * self.dropout(wkv)

        return self.proj_output(gated), state
| | | |
class EncoderSelfAttention(SelfAttention):
    """Time-mixing module backed by the encoder WKV kernel.

    Args:
        size: Input/Output size.
        attention_size: Attention hidden size.
        context_size: Context size for WKV kernel.
        block_id: Block index.
        dropout_rate: Dropout probability applied to the WKV output.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self,
        size: int,
        attention_size: int,
        context_size: int,
        block_id: int,
        dropout_rate: float,
        num_blocks: int,
    ) -> None:
        """Construct an EncoderSelfAttention object and load its WKV kernel."""
        super().__init__(size, attention_size, block_id, dropout_rate, num_blocks)
        load_encoder_wkv_kernel(context_size)

    def forward(
        self,
        x: torch.Tensor,
        state: Optional[List[torch.Tensor]] = None,
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute time mixing.

        Without a state the whole sequence goes through the WKV kernel; with a
        state the stateful WKV update is used and this block's slots of the
        state are overwritten in place.

        Args:
            x: SelfAttention input sequences. (B, U, size)
            state: Hidden states. [5 x (B, 1, D_att, N)]

        Returns:
            x: SelfAttention output sequences. (B, U, size)
            state: The state that was passed in (updated in place), or None.

        """
        stateful = state is not None

        if stateful:
            previous = state[1][..., self.block_id]
        else:
            previous = self.time_shift(x)

        def _blend(weight: torch.Tensor) -> torch.Tensor:
            # Interpolate between the current and the previous input.
            return x * weight + previous * (1 - weight)

        key = self.proj_key(_blend(self.time_mix_key))
        value = self.proj_value(_blend(self.time_mix_value))
        receptance = torch.sigmoid(
            self.proj_receptance(_blend(self.time_mix_receptance))
        )

        if not stateful:
            wkv = WKVLinearAttentionEncoder.apply(
                self.time_decay, self.time_first, key, value
            )
        else:
            # The current input becomes the "previous" input of the next step.
            state[1][..., self.block_id] = x

            block_state = tuple(s[..., self.block_id] for s in state[2:])
            wkv, updated = self.wkv_linear_attention(
                self.time_decay, self.time_first, key, value, block_state
            )

            state[2][..., self.block_id] = updated[0]
            state[3][..., self.block_id] = updated[1]
            state[4][..., self.block_id] = updated[2]

        gated = receptance * self.dropout(wkv)

        return self.proj_output(gated), state
| | | |
| New file |
| | |
| | | """Feed-forward (channel mixing) module for RWKV block. |
| | | |
| | | Based/Modified from https://github.com/BlinkDL/RWKV-LM/blob/main/RWKV-v4/src/model.py |
| | | |
| | | Some variables are renamed according to https://github.com/huggingface/transformers/blob/main/src/transformers/models/rwkv/modeling_rwkv.py. |
| | | |
| | | """ # noqa |
| | | |
| | | from typing import List, Optional, Tuple |
| | | |
| | | import torch |
| | | |
| | | |
class FeedForward(torch.nn.Module):
    """Feed-forward (channel mixing) module of an RWKV block.

    The input is mixed with a one-step time-shifted copy of itself, then passed
    through a squared-ReLU key projection, a value projection and a sigmoid
    receptance gate.

    Args:
        size: Input/Output size.
        hidden_size: Hidden size.
        block_id: Block index.
        dropout_rate: Dropout probability applied to the key activations
            before the value projection.
        num_blocks: Number of blocks in the architecture.

    """

    def __init__(
        self, size: int, hidden_size: int, block_id: int, dropout_rate: float, num_blocks: int
    ) -> None:
        """Construct a FeedForward object."""
        super().__init__()

        # Pad one zero frame in front of the time axis and drop the last one,
        # i.e. position t receives the input of position t - 1.
        self.time_shift = torch.nn.ZeroPad2d((0, 0, 1, -1))

        self.time_mix_key = torch.nn.Parameter(torch.empty(1, 1, size))
        self.time_mix_receptance = torch.nn.Parameter(torch.empty(1, 1, size))

        self.proj_key = torch.nn.Linear(size, hidden_size, bias=True)
        self.proj_value = torch.nn.Linear(hidden_size, size, bias=True)
        self.proj_receptance = torch.nn.Linear(size, size, bias=True)

        self.block_id = block_id

        self.reset_parameters(size, block_id, num_blocks)
        self.dropout = torch.nn.Dropout(p=dropout_rate)

    def reset_parameters(self, size: int, block_id: int, num_blocks: int) -> None:
        """Reset module parameters.

        Both mixing coefficients are initialised to ``(i / size) ** ratio`` for
        channel ``i``, where ``ratio = 1 - block_id / num_blocks``.

        Args:
            size: Block size.
            block_id: Block index.
            num_blocks: Number of blocks in the architecture.

        """
        ratio_1_to_almost0 = 1.0 - (block_id / num_blocks)

        # Build the ramp i / size in one vectorised op. The division is done in
        # float64 and then cast, which yields the same values as assigning the
        # Python float ``i / size`` element by element.
        time_weight = (
            (torch.arange(size, dtype=torch.float64) / size)
            .to(torch.get_default_dtype())
            .view(1, 1, size)
        )
        time_mix = torch.pow(time_weight, ratio_1_to_almost0)

        with torch.no_grad():
            # copy_ keeps the two parameters on separate storage.
            self.time_mix_key.copy_(time_mix)
            self.time_mix_receptance.copy_(time_mix)

    def forward(
        self, x: torch.Tensor, state: Optional[List[torch.Tensor]] = None
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """Compute channel mixing.

        Args:
            x: FeedForward input sequences. (B, U, size)
            state: Decoder hidden state. [5 x (B, 1, size, N)]

        Returns:
            x: FeedForward output sequences. (B, U, size)
            state: Decoder hidden state. [5 x (B, 1, size, N)]

        """
        # Without a state the previous frame comes from shifting the sequence;
        # with a state it is read from slot 0 at this block's index.
        shifted_x = (
            self.time_shift(x) if state is None else state[0][..., self.block_id]
        )

        key = x * self.time_mix_key + shifted_x * (1 - self.time_mix_key)
        receptance = x * self.time_mix_receptance + shifted_x * (
            1 - self.time_mix_receptance
        )

        key = torch.square(torch.relu(self.proj_key(key)))
        value = self.proj_value(self.dropout(key))
        receptance = torch.sigmoid(self.proj_receptance(receptance))

        if state is not None:
            # Remember the current input as the "previous frame" of the next call.
            state[0][..., self.block_id] = x

        x = receptance * value

        return x, state
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- coding: utf-8 -*- |
| | | |
| | | # Copyright 2019 Shigeki Karita |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """Subsampling layer definition.""" |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn.functional as F |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | import logging |
| | | from funasr.modules.streaming_utils.utils import sequence_mask |
| | | from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len |
| | | from typing import Optional, Tuple, Union |
| | | import math |
| | | |
class TooShortUttError(Exception):
    """Error signalling that an utterance is too short to be subsampled.

    Args:
        message (str): Text passed on to the base ``Exception``.
        actual_size (int): Length of the utterance that could not be subsampled.
        limit (int): Minimum length required by the subsampling layer.

    """

    def __init__(self, message, actual_size, limit):
        """Keep the offending size and the required limit on the exception."""
        super().__init__(message)
        self.limit = limit
        self.actual_size = actual_size
| | | |
| | | |
def check_short_utt(ins, size):
    """Tell whether ``size`` frames are too few for the subsampling module ``ins``.

    Returns:
        ``(True, limit)`` when ``ins`` is one of the known subsampling modules
        and ``size`` is below its minimum length ``limit``;
        ``(False, -1)`` otherwise.
    """
    # NOTE(review): the Conv2dSubsampling* classes are not among the imports
    # visible at the top of this module - confirm they are in scope, otherwise
    # calling this function raises NameError.
    minimum_sizes = (
        (Conv2dSubsampling2, 3),
        (Conv2dSubsampling, 7),
        (Conv2dSubsampling6, 11),
        (Conv2dSubsampling8, 15),
    )
    # Checked in the listed order; the first matching rule wins.
    for module_type, limit in minimum_sizes:
        if isinstance(ins, module_type) and size < limit:
            return True, limit
    return False, -1
| | | |
| | | |
class RWKVConvInput(torch.nn.Module):
    """Streaming ConvInput module definition.

    A stack of six Conv2d + ReLU layers applied to the (time, feature) plane,
    optionally followed by a linear projection to ``output_size``.

    Args:
        input_size: Input (feature) size.
        conv_size: Channel sizes of the three convolution stages, unpacked as
            ``(conv_size1, conv_size2, conv_size3)``.
        subsampling_factor: Subsampling factor along the time axis.
        conv_kernel_size: Kernel size shared by all convolutions.
        output_size: Block output dimension. If None, no projection is applied
            and the output dimension is the flattened convolution output.
    """

    def __init__(
        self,
        input_size: int,
        conv_size: Union[int, Tuple],
        subsampling_factor: int = 4,
        conv_kernel_size: int = 3,
        output_size: Optional[int] = None,
    ) -> None:
        """Construct a ConvInput object."""
        super().__init__()

        conv_size1, conv_size2, conv_size3 = conv_size
        padding = (conv_kernel_size - 1) // 2

        # Per-layer strides, as int or [time_stride, feature_stride].
        if subsampling_factor == 1:
            # No time subsampling; the feature axis is halved three times.
            stride_1 = 1
            strides = [1, [1, 2], 1, [1, 2], 1, [1, 2]]
        else:
            # Time axis reduced by stride_1 * 2; feature axis halved twice.
            stride_1 = int(subsampling_factor / 2)
            strides = [1, [stride_1, 2], 1, [2, 2], 1, 1]

        channels = [
            (1, conv_size1),
            (conv_size1, conv_size1),
            (conv_size1, conv_size2),
            (conv_size2, conv_size2),
            (conv_size2, conv_size3),
            (conv_size3, conv_size3),
        ]

        layers = []
        feat_size = input_size
        for (in_channels, out_channels), stride in zip(channels, strides):
            layers.append(
                torch.nn.Conv2d(
                    in_channels,
                    out_channels,
                    conv_kernel_size,
                    stride=stride,
                    padding=padding,
                )
            )
            layers.append(torch.nn.ReLU())
            # Track the feature-axis size with the Conv2d output-size formula so
            # the projection below always matches what the stack produces.
            feat_stride = stride if isinstance(stride, int) else stride[1]
            feat_size = (feat_size + 2 * padding - conv_kernel_size) // feat_stride + 1

        # Same layer order as before, so Sequential indices (and therefore
        # state_dict keys) are unchanged.
        self.conv = torch.nn.Sequential(*layers)

        output_proj = conv_size3 * feat_size

        self.subsampling_factor = 1 if subsampling_factor == 1 else subsampling_factor
        self.stride_1 = stride_1
        self.create_new_mask = self.create_new_vgg_mask

        self.min_frame_length = 7

        if output_size is not None:
            self.output = torch.nn.Linear(output_proj, output_size)
            self.output_size = output_size
        else:
            self.output = None
            self.output_size = output_proj

    def forward(
        self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: ConvInput input sequences. (B, T, D_feats)
            mask: Mask of input sequences. (B, T), or None.
                Zero entries are counted as valid frames.
            chunk_size: Chunk size (in subsampled frames), or None. When given,
                the input is padded to a whole number of chunks and each chunk
                is convolved independently.

        Returns:
            x: ConvInput output sequences. (B, sub(T), D_out)
            mask: Mask of output sequences. (B, sub(T)), or None if no mask
                was given.
        """
        olens = None
        if mask is not None:
            mask = self.create_new_mask(mask)
            # Longest valid (unmasked) length in the batch after subsampling.
            olens = max(mask.eq(0).sum(1))

        b, t, f = x.size()
        x = x.unsqueeze(1)  # (b, 1, t, f)

        if chunk_size is not None:
            max_input_length = int(
                chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor)))
            )
            x = torch.stack(
                [pad_to_len(inputs, max_input_length, 1) for inputs in x], dim=0
            )
            n_chunks = max_input_length // (chunk_size * self.subsampling_factor)
            x = x.view(b * n_chunks, 1, chunk_size * self.subsampling_factor, f)

        x = self.conv(x)

        _, c, _, f = x.size()
        x = x.transpose(1, 2).contiguous().view(b, -1, c * f)
        if chunk_size is not None and olens is not None:
            # Drop the frames that only exist because of chunk padding.
            x = x[:, :olens, :]

        if self.output is not None:
            x = self.output(x)

        if mask is None:
            return x, None

        return x, mask[:, :olens][:, :x.size(1)]

    def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor:
        """Create a new mask for VGG output sequences.

        Args:
            mask: Mask of input sequences. (B, T)

        Returns:
            mask: Mask of output sequences. (B, sub(T))
        """
        if self.subsampling_factor > 1:
            return mask[:, ::2][:, ::self.stride_1]
        else:
            return mask

    def get_size_before_subsampling(self, size: int) -> int:
        """Return the original size before subsampling for a given size.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of frames before subsampling.
        """
        return size * self.subsampling_factor
| | |
| | | public static final String TAG = MainActivity.class.getSimpleName(); |
| | | // WebSocket地址 |
| | | public String ASR_HOST = ""; |
| | | // 官方WebSocket地址 |
| | | public static final String DEFAULT_HOST = "wss://101.37.77.25:10088"; |
| | | // 发送的JSON数据 |
| | | public static final String MODE = "2pass"; |
| | | public static final String CHUNK_SIZE = "5, 10, 5"; |
| | |
| | | // 控件 |
| | | private Button recordBtn; |
| | | private TextView resultText; |
| | | private WebSocket webSocket; |
| | | |
| | | @SuppressLint("ClickableViewAccessibility") |
| | | @Override |
| | |
| | | ASR_HOST = uri; |
| | | } |
| | | // 读取热词 |
| | | String hotWords = sharedPreferences.getString("hotwords", ""); |
| | | if (!hotWords.equals("")) { |
| | | String hotWords = sharedPreferences.getString("hotwords", null); |
| | | if (hotWords != null) { |
| | | this.hotWords = hotWords; |
| | | } |
| | | } |
| | |
| | | editor.apply(); |
| | | } |
| | | }); |
| | | builder.setNeutralButton("使用官方服务", (dialog, id) -> { |
| | | ASR_HOST = DEFAULT_HOST; |
| | | input.setText(DEFAULT_HOST); |
| | | Toast.makeText(MainActivity.this, "WebSocket地址:" + ASR_HOST, Toast.LENGTH_SHORT).show(); |
| | | SharedPreferences.Editor editor = sharedPreferences.edit(); |
| | | editor.putString("uri", ASR_HOST); |
| | | editor.apply(); |
| | | }); |
| | | AlertDialog dialog = builder.create(); |
| | | dialog.show(); |
| | | } |
| | |
| | | builder.setView(view); |
| | | builder.setPositiveButton("确定", (dialog, id) -> { |
| | | String hotwords = input.getText().toString(); |
| | | if (!hotwords.equals("")) { |
| | | this.hotWords = hotwords; |
| | | SharedPreferences.Editor editor = sharedPreferences.edit(); |
| | | editor.putString("hotwords", hotwords); |
| | | editor.apply(); |
| | | } |
| | | this.hotWords = hotwords; |
| | | SharedPreferences.Editor editor = sharedPreferences.edit(); |
| | | editor.putString("hotwords", hotwords); |
| | | editor.apply(); |
| | | }); |
| | | AlertDialog dialog = builder.create(); |
| | | dialog.show(); |
| | |
| | | Request request = new Request.Builder() |
| | | .url(ASR_HOST) |
| | | .build(); |
| | | webSocket = client.newWebSocket(request, new WebSocketListener() { |
| | | WebSocket webSocket = client.newWebSocket(request, new WebSocketListener() { |
| | | |
| | | @Override |
| | | public void onOpen(@NonNull WebSocket webSocket, @NonNull Response response) { |
| | |
| | | obj.put("chunk_size", array); |
| | | obj.put("chunk_interval", CHUNK_INTERVAL); |
| | | obj.put("wav_name", "default"); |
| | | obj.put("hotwords", hotWords); |
| | | if (!hotWords.equals("")) { |
| | | obj.put("hotwords", hotWords); |
| | | } |
| | | obj.put("wav_format", "pcm"); |
| | | obj.put("is_speaking", isSpeaking); |
| | | return obj.toString(); |
| | |
| | | --io-thread-num 8 \ |
| | | --port 10095 \ |
| | | --certfile ../../../ssl_key/server.crt \ |
| | | --keyfile ../../../ssl_key/server.key > log.out 2>&1 & |
| | | --keyfile ../../../ssl_key/server.key \ |
| | | --hotword ../../hotwords.txt > log.out 2>&1 & |
| | | ``` |
| | | |
| | | Introduction to run_server.sh parameters: |
| | |
| | | --io-thread-num: Number of IO threads that the server starts. Default is 1. |
| | | --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set this to 0. |
| | | --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. |
| | | --hotword Path to the hotword file, one hotword per line. If the client also provides hotwords, they are combined with the ones from this file. Default is ../../hotwords.txt |
| | | ``` |
| | | |
| | | ### Shutting Down the FunASR Service |
| | |
| | | --io-thread-num 8 \ |
| | | --port 10095 \ |
| | | --certfile ../../../ssl_key/server.crt \ |
| | | --keyfile ../../../ssl_key/server.key |
| | | --keyfile ../../../ssl_key/server.key > log.out 2>&1 & |
| | | ``` |
| | | |
| | | Introduction to run_server.sh parameters: |
| | |
| | | --io-thread-num 8 \ |
| | | --port 10095 \ |
| | | --certfile ../../../ssl_key/server.crt \ |
| | | --keyfile ../../../ssl_key/server.key > log.out 2>&1 & |
| | | --keyfile ../../../ssl_key/server.key \ |
| | | --hotword ../../hotwords.txt > log.out 2>&1 & |
| | | ``` |
| | | **run_server.sh命令参数介绍** |
| | | ```text |
| | |
| | | --io-thread-num 服务端启动的IO线程数,默认为 1 |
| | | --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0 |
| | | --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key |
| | | --hotword 热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt |
| | | ``` |
| | | |
| | | ### 关闭FunASR服务 |
| | |
| | | --io-thread-num 8 \ |
| | | --port 10095 \ |
| | | --certfile ../../../ssl_key/server.crt \ |
| | | --keyfile ../../../ssl_key/server.key > log.out 2>&1 & |
| | | --keyfile ../../../ssl_key/server.key \ |
| | | --hotword ../../hotwords.txt > log.out 2>&1 & |
| | | |
| | | # If you want to close ssl,please add:--certfile 0 |
| | | # If you want to deploy the timestamp or hotword model, please set --model-dir to the corresponding model: |
| | |
| | | --io-thread-num: Number of IO threads that the server starts. Default is 1. |
| | | --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set this to 0. |
| | | --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. |
| | | --hotword Path to the hotword file, one hotword per line. If the client also provides hotwords, they are combined with the ones from this file. Default is ../../hotwords.txt |
| | | ``` |
| | | |
| | | ### Shutting Down the FunASR Service |
| | |
| | | --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \ |
| | | --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \ |
| | | --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx \ |
| | | --itn-dir thuduj12/fst_itn_zh > log.out 2>&1 & |
| | | --itn-dir thuduj12/fst_itn_zh \ |
| | | --hotwordsfile ../../hotwords.txt > log.out 2>&1 & |
| | | |
| | | # 如果您想关闭ssl,增加参数:--certfile 0 |
| | | # 如果您想使用时间戳或者热词模型进行部署,请设置--model-dir为对应模型: |
| | |
| | | --io-thread-num 8 \ |
| | | --port 10095 \ |
| | | --certfile ../../../ssl_key/server.crt \ |
| | | --keyfile ../../../ssl_key/server.key > log.out 2>&1 & |
| | | --keyfile ../../../ssl_key/server.key \ |
| | | --hotword ../../hotwords.txt > log.out 2>&1 & |
| | | ``` |
| | | **run_server_2pass.sh命令参数介绍** |
| | | ```text |
| | |
| | | --io-thread-num 服务端启动的IO线程数,默认为 1 |
| | | --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为0 |
| | | --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key |
| | | --hotword 热词文件路径,每一个热词一行,如果客户端提供热词,则与客户端提供的热词合并一起使用。默认为:../../hotwords.txt |
| | | ``` |
| | | |
| | | ### 关闭FunASR服务 |
| | |
| | | |
| | | def time_stamp_lfr6_onnx(us_cif_peak, char_list, begin_time=0.0, total_offset=-1.5): |
| | | if not len(char_list): |
| | | return [] |
| | | return '', [] |
| | | START_END_THRESHOLD = 5 |
| | | MAX_TOKEN_DURATION = 30 |
| | | TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled |
| | |
| | | port=10095 |
| | | certfile="../../../ssl_key/server.crt" |
| | | keyfile="../../../ssl_key/server.key" |
| | | hotwordsfile="../../hotwords.txt" |
| | | |
| | | . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1; |
| | | |
| | |
| | | --io-thread-num ${io_thread_num} \ |
| | | --port ${port} \ |
| | | --certfile "" \ |
| | | --keyfile "" |
| | | --keyfile "" \ |
| | | --hotwordsfile ${hotwordsfile} |
| | | else |
| | | ./funasr-wss-server \ |
| | | --download-model-dir ${download_model_dir} \ |
| | |
| | | --io-thread-num ${io_thread_num} \ |
| | | --port ${port} \ |
| | | --certfile ${certfile} \ |
| | | --keyfile ${keyfile} |
| | | --keyfile ${keyfile} \ |
| | | --hotwordsfile ${hotwordsfile} |
| | | fi |
| | |
| | | port=10095 |
| | | certfile="../../../ssl_key/server.crt" |
| | | keyfile="../../../ssl_key/server.key" |
| | | hotwordsfile="../../hotwords.txt" |
| | | |
| | | . ../../egs/aishell/transformer/utils/parse_options.sh || exit 1; |
| | | |
| | |
| | | --io-thread-num ${io_thread_num} \ |
| | | --port ${port} \ |
| | | --certfile "" \ |
| | | --keyfile "" |
| | | --keyfile "" \ |
| | | --hotwordsfile ${hotwordsfile} |
| | | else |
| | | ./funasr-wss-server-2pass \ |
| | | --download-model-dir ${download_model_dir} \ |
| | |
| | | --io-thread-num ${io_thread_num} \ |
| | | --port ${port} \ |
| | | --certfile ${certfile} \ |
| | | --keyfile ${keyfile} |
| | | --keyfile ${keyfile} \ |
| | | --hotwordsfile ${hotwordsfile} |
| | | fi |
| | |
| | | #include <unistd.h> |
| | | #include "websocket-server-2pass.h" |
| | | |
| | | #include <fstream> |
| | | std::string hotwords = ""; |
| | | |
| | | using namespace std; |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, |
| | | std::map<std::string, std::string>& model_path) { |
| | |
| | | "default: ../../../ssl_key/server.key, path of keyfile for WSS " |
| | | "connection", |
| | | false, "../../../ssl_key/server.key", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> hotwordsfile( |
| | | "", "hotword", |
| | | "default: ../../hotwords.txt, path of hotwordsfile" |
| | | "connection", |
| | | false, "../../hotwords.txt", "string"); |
| | | |
| | | // add file |
| | | cmd.add(hotwordsfile); |
| | | |
| | | cmd.add(certfile); |
| | | cmd.add(keyfile); |
| | |
| | | std::string s_certfile = certfile.getValue(); |
| | | std::string s_keyfile = keyfile.getValue(); |
| | | |
| | | std::string s_hotwordsfile = hotwordsfile.getValue(); |
| | | std::string line; |
| | | std::ifstream file(s_hotwordsfile); |
| | | LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile; |
| | | |
| | | if (file.is_open()) { |
| | | while (getline(file, line)) { |
| | | hotwords += line+HOTWORD_SEP; |
| | | } |
| | | LOG(INFO) << "hotwords: " << hotwords; |
| | | file.close(); |
| | | } else { |
| | | LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile; |
| | | } |
| | | |
| | | bool is_ssl = false; |
| | | if (!s_certfile.empty()) { |
| | | is_ssl = true; |
| | |
| | | websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model |
| | | } |
| | | |
| | | std::cout << "asr model init finished. listen on port:" << s_port |
| | | << std::endl; |
| | | LOG(INFO) << "asr model init finished. listen on port:" << s_port; |
| | | |
| | | // Start the ASIO network io_service run loop |
| | | std::vector<std::thread> ts; |
| | |
| | | } |
| | | |
| | | } catch (std::exception const& e) { |
| | | std::cerr << "Error: " << e.what() << std::endl; |
| | | LOG(ERROR) << "Error: " << e.what(); |
| | | } |
| | | |
| | | return 0; |
| | |
| | | #include "websocket-server.h" |
| | | #include <unistd.h> |
| | | |
| | | #include <fstream> |
| | | std::string hotwords = ""; |
| | | |
| | | using namespace std; |
| | | void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, |
| | | std::map<std::string, std::string>& model_path) { |
| | |
| | | TCLAP::ValueArg<std::string> keyfile("", "keyfile", |
| | | "default: ../../../ssl_key/server.key, path of keyfile for WSS connection", |
| | | false, "../../../ssl_key/server.key", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> hotwordsfile( |
| | | "", "hotword", |
| | | "default: ../../hotwords.txt, path of hotwordsfile" |
| | | "connection", |
| | | false, "../../hotwords.txt", "string"); |
| | | |
| | | // add file |
| | | cmd.add(hotwordsfile); |
| | | |
| | | cmd.add(certfile); |
| | | cmd.add(keyfile); |
| | |
| | | std::string s_certfile = certfile.getValue(); |
| | | std::string s_keyfile = keyfile.getValue(); |
| | | |
| | | std::string s_hotwordsfile = hotwordsfile.getValue(); |
| | | std::string line; |
| | | std::ifstream file(s_hotwordsfile); |
| | | LOG(INFO) << "hotwordsfile path: " << s_hotwordsfile; |
| | | |
| | | if (file.is_open()) { |
| | | while (getline(file, line)) { |
| | | hotwords += line+HOTWORD_SEP; |
| | | } |
| | | LOG(INFO) << "hotwords: " << hotwords; |
| | | file.close(); |
| | | } else { |
| | | LOG(ERROR) << "Unable to open hotwords file: " << s_hotwordsfile; |
| | | } |
| | | |
| | | bool is_ssl = false; |
| | | if (!s_certfile.empty()) { |
| | | is_ssl = true; |
| | |
| | | websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model |
| | | } |
| | | |
| | | std::cout << "asr model init finished. listen on port:" << s_port |
| | | << std::endl; |
| | | LOG(INFO) << "asr model init finished. listen on port:" << s_port; |
| | | |
| | | // Start the ASIO network io_service run loop |
| | | std::vector<std::thread> ts; |
| | |
| | | } |
| | | |
| | | } catch (std::exception const& e) { |
| | | std::cerr << "Error: " << e.what() << std::endl; |
| | | LOG(ERROR) << "Error: " << e.what(); |
| | | } |
| | | |
| | | return 0; |
| | |
| | | #include <thread> |
| | | #include <utility> |
| | | #include <vector> |
| | | #include <chrono> |
| | | |
| | | extern std::string hotwords; |
| | | |
| | | context_ptr WebSocketServer::on_tls_init(tls_mode mode, |
| | | websocketpp::connection_hdl hdl, |
| | | std::string& s_certfile, |
| | |
| | | unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection |
| | | switch (msg->get_opcode()) { |
| | | case websocketpp::frame::opcode::text: { |
| | | nlohmann::json jsonresult = nlohmann::json::parse(payload); |
| | | nlohmann::json jsonresult; |
| | | try{ |
| | | jsonresult = nlohmann::json::parse(payload); |
| | | }catch (std::exception const &e) |
| | | { |
| | | LOG(ERROR)<<e.what(); |
| | | break; |
| | | } |
| | | |
| | | if (jsonresult.contains("wav_name")) { |
| | | msg_data->msg["wav_name"] = jsonresult["wav_name"]; |
| | |
| | | msg_data->msg["hotwords"] = jsonresult["hotwords"]; |
| | | if (!msg_data->msg["hotwords"].empty()) { |
| | | std::string hw = msg_data->msg["hotwords"]; |
| | | LOG(INFO)<<"hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS); |
| | | hw = hw + " " + hotwords; |
| | | LOG(INFO) << "hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS); |
| | | msg_data->hotwords_embedding = |
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding); |
| | | } |
| | | }else{ |
| | | } else { |
| | | if (hotwords.empty()) { |
| | | std::string hw = ""; |
| | | LOG(INFO)<<"hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS); |
| | | msg_data->hotwords_embedding = |
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding); |
| | | }else { |
| | | std::string hw = hotwords; |
| | | LOG(INFO) << "hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(tpass_handle, hw, ASR_TWO_PASS); |
| | | msg_data->hotwords_embedding = |
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding); |
| | | } |
| | | } |
| | | } |
| | | if (jsonresult.contains("audio_fs")) { |
| | |
| | | #include <utility> |
| | | #include <vector> |
| | | |
| | | extern std::string hotwords; |
| | | |
| | | context_ptr WebSocketServer::on_tls_init(tls_mode mode, |
| | | websocketpp::connection_hdl hdl, |
| | | std::string& s_certfile, |
| | |
| | | unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection |
| | | switch (msg->get_opcode()) { |
| | | case websocketpp::frame::opcode::text: { |
| | | nlohmann::json jsonresult = nlohmann::json::parse(payload); |
| | | nlohmann::json jsonresult; |
| | | try{ |
| | | jsonresult = nlohmann::json::parse(payload); |
| | | }catch (std::exception const &e) |
| | | { |
| | | LOG(ERROR)<<e.what(); |
| | | break; |
| | | } |
| | | |
| | | if (jsonresult["wav_name"] != nullptr) { |
| | | msg_data->msg["wav_name"] = jsonresult["wav_name"]; |
| | | } |
| | |
| | | msg_data->msg["hotwords"] = jsonresult["hotwords"]; |
| | | if (!msg_data->msg["hotwords"].empty()) { |
| | | std::string hw = msg_data->msg["hotwords"]; |
| | | LOG(INFO)<<"hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw); |
| | | hw = hw + " " + hotwords; |
| | | LOG(INFO) << "hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hw); |
| | | msg_data->hotwords_embedding = |
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding); |
| | | } |
| | | }else{ |
| | | } else { |
| | | if (hotwords.empty()) { |
| | | std::string hw = ""; |
| | | LOG(INFO)<<"hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw); |
| | | msg_data->hotwords_embedding = |
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding); |
| | | }else { |
| | | std::string hw = hotwords; |
| | | LOG(INFO) << "hotwords: " << hw; |
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw); |
| | | msg_data->hotwords_embedding = |
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding); |
| | | } |
| | | } |
| | | } |
| | | if (jsonresult.contains("audio_fs")) { |