8 files modified
1 file renamed
14 files deleted
| | |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.joint_network import JointNetwork |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.joint_network import JointNetwork |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Dict |
| | | |
| | | import torch |
| | | from torch import nn |
| | | from typeguard import check_argument_types |
| | |
| | | from funasr.modules.attention import ( |
| | | MultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttentionChunk, |
| | | LegacyRelPositionMultiHeadedAttention, # noqa: H301 |
| | | ) |
| | | from funasr.modules.embedding import ( |
| | | ScaledPositionalEncoding, # noqa: H301 |
| | | RelPositionalEncoding, # noqa: H301 |
| | | LegacyRelPositionalEncoding, # noqa: H301 |
| | | StreamingRelPositionalEncoding, |
| | | ) |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | from funasr.modules.normalization import get_normalization |
| | | from funasr.modules.multi_layer_conv import Conv1dLinear |
| | | from funasr.modules.multi_layer_conv import MultiLayeredConv1d |
| | | from funasr.modules.nets_utils import get_activation |
| | | from funasr.modules.nets_utils import make_pad_mask |
| | | from funasr.modules.nets_utils import ( |
| | | TooShortUttError, |
| | | check_short_utt, |
| | | make_chunk_mask, |
| | | make_source_mask, |
| | | ) |
| | | from funasr.modules.positionwise_feed_forward import ( |
| | | PositionwiseFeedForward, # noqa: H301 |
| | | ) |
| | | from funasr.modules.repeat import repeat, MultiBlocks |
| | | from funasr.modules.subsampling import Conv2dSubsampling |
| | | from funasr.modules.subsampling import Conv2dSubsampling2 |
| | | from funasr.modules.subsampling import Conv2dSubsampling6 |
| | |
| | | from funasr.modules.subsampling import TooShortUttError |
| | | from funasr.modules.subsampling import check_short_utt |
| | | from funasr.modules.subsampling import Conv2dSubsamplingPad |
| | | from funasr.modules.subsampling import StreamingConvInput |
| | | |
| | | class ConvolutionModule(nn.Module): |
| | | """ConvolutionModule in Conformer model. |
| | | |
| | |
| | | return (x, pos_emb), mask |
| | | |
| | | return x, mask |
| | | |
| | | class ChunkEncoderLayer(torch.nn.Module): |
| | | """Chunk Conformer module definition. |
| | | Args: |
| | | block_size: Input/output size. |
| | | self_att: Self-attention module instance. |
| | | feed_forward: Feed-forward module instance. |
| | | feed_forward_macaron: Feed-forward module instance for macaron network. |
| | | conv_mod: Convolution module instance. |
| | | norm_class: Normalization module class. |
| | | norm_args: Normalization module arguments. |
| | | dropout_rate: Dropout rate. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | block_size: int, |
| | | self_att: torch.nn.Module, |
| | | feed_forward: torch.nn.Module, |
| | | feed_forward_macaron: torch.nn.Module, |
| | | conv_mod: torch.nn.Module, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_args: Dict = {}, |
| | | dropout_rate: float = 0.0, |
| | | ) -> None: |
| | | """Construct a Conformer object.""" |
| | | super().__init__() |
| | | |
| | | self.self_att = self_att |
| | | |
| | | self.feed_forward = feed_forward |
| | | self.feed_forward_macaron = feed_forward_macaron |
| | | self.feed_forward_scale = 0.5 |
| | | |
| | | self.conv_mod = conv_mod |
| | | |
| | | self.norm_feed_forward = norm_class(block_size, **norm_args) |
| | | self.norm_self_att = norm_class(block_size, **norm_args) |
| | | |
| | | self.norm_macaron = norm_class(block_size, **norm_args) |
| | | self.norm_conv = norm_class(block_size, **norm_args) |
| | | self.norm_final = norm_class(block_size, **norm_args) |
| | | |
| | | self.dropout = torch.nn.Dropout(dropout_rate) |
| | | |
| | | self.block_size = block_size |
| | | self.cache = None |
| | | |
| | | def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: |
| | | """Initialize/Reset self-attention and convolution modules cache for streaming. |
| | | Args: |
| | | left_context: Number of left frames during chunk-by-chunk inference. |
| | | device: Device to use for cache tensor. |
| | | """ |
| | | self.cache = [ |
| | | torch.zeros( |
| | | (1, left_context, self.block_size), |
| | | device=device, |
| | | ), |
| | | torch.zeros( |
| | | ( |
| | | 1, |
| | | self.block_size, |
| | | self.conv_mod.kernel_size - 1, |
| | | ), |
| | | device=device, |
| | | ), |
| | | ] |
| | | |
| | | def forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | mask: torch.Tensor, |
| | | chunk_mask: Optional[torch.Tensor] = None, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | | """Encode input sequences. |
| | | Args: |
| | | x: Conformer input sequences. (B, T, D_block) |
| | | pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) |
| | | mask: Source mask. (B, T) |
| | | chunk_mask: Chunk mask. (T_2, T_2) |
| | | Returns: |
| | | x: Conformer output sequences. (B, T, D_block) |
| | | mask: Source mask. (B, T) |
| | | pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) |
| | | """ |
| | | residual = x |
| | | |
| | | x = self.norm_macaron(x) |
| | | x = residual + self.feed_forward_scale * self.dropout( |
| | | self.feed_forward_macaron(x) |
| | | ) |
| | | |
| | | residual = x |
| | | x = self.norm_self_att(x) |
| | | x_q = x |
| | | x = residual + self.dropout( |
| | | self.self_att( |
| | | x_q, |
| | | x, |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | ) |
| | | |
| | | residual = x |
| | | |
| | | x = self.norm_conv(x) |
| | | x, _ = self.conv_mod(x) |
| | | x = residual + self.dropout(x) |
| | | residual = x |
| | | |
| | | x = self.norm_feed_forward(x) |
| | | x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x)) |
| | | |
| | | x = self.norm_final(x) |
| | | return x, mask, pos_enc |
| | | |
| | | def chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | mask: torch.Tensor, |
| | | chunk_size: int = 16, |
| | | left_context: int = 0, |
| | | right_context: int = 0, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Encode chunk of input sequence. |
| | | Args: |
| | | x: Conformer input sequences. (B, T, D_block) |
| | | pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) |
| | | mask: Source mask. (B, T_2) |
| | | chunk_size: Current chunk size (kept for interface consistency; unused in this layer). |
| | | left_context: Number of frames in left context. |
| | | right_context: Number of frames in right context. |
| | | Returns: |
| | | x: Conformer output sequences. (B, T, D_block) |
| | | pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block) |
| | | """ |
| | | residual = x |
| | | |
| | | x = self.norm_macaron(x) |
| | | x = residual + self.feed_forward_scale * self.feed_forward_macaron(x) |
| | | |
| | | residual = x |
| | | x = self.norm_self_att(x) |
| | | if left_context > 0: |
| | | key = torch.cat([self.cache[0], x], dim=1) |
| | | else: |
| | | key = x |
| | | val = key |
| | | |
| | | if right_context > 0: |
| | | att_cache = key[:, -(left_context + right_context) : -right_context, :] |
| | | else: |
| | | att_cache = key[:, -left_context:, :] |
| | | x = residual + self.self_att( |
| | | x, |
| | | key, |
| | | val, |
| | | pos_enc, |
| | | mask, |
| | | left_context=left_context, |
| | | ) |
| | | |
| | | residual = x |
| | | x = self.norm_conv(x) |
| | | x, conv_cache = self.conv_mod( |
| | | x, cache=self.cache[1], right_context=right_context |
| | | ) |
| | | x = residual + x |
| | | residual = x |
| | | |
| | | x = self.norm_feed_forward(x) |
| | | x = residual + self.feed_forward_scale * self.feed_forward(x) |
| | | |
| | | x = self.norm_final(x) |
| | | self.cache = [att_cache, conv_cache] |
| | | |
| | | return x, pos_enc |
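
To illustrate the attention-cache bookkeeping in chunk_forward above, here is a minimal standalone sketch (pure torch, independent of FunASR): the cache retains the last left_context key frames, excluding any lookahead (right_context) frames, so the next chunk attends to real history.

    import torch

    left_context, right_context, chunk, dim = 4, 2, 6, 8
    cache = torch.zeros(1, left_context, dim)        # history from previous chunks
    x = torch.randn(1, chunk, dim)                   # current chunk (incl. lookahead)

    key = torch.cat([cache, x], dim=1)               # (1, left_context + chunk, dim)
    if right_context > 0:
        # Drop the lookahead frames: they will be re-seen by the next chunk.
        att_cache = key[:, -(left_context + right_context):-right_context, :]
    else:
        att_cache = key[:, -left_context:, :]
    assert att_cache.shape == (1, left_context, dim)  # rolled forward on the next call
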
| | | |
| | | |
| | | class ConformerEncoder(AbsEncoder): |
| | |
| | | if len(intermediate_outs) > 0: |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |
| | | |
| | | class CausalConvolution(torch.nn.Module): |
| | | """ConformerConvolution module definition. |
| | | Args: |
| | | channels: The number of channels. |
| | | kernel_size: Size of the convolving kernel. |
| | | activation: Type of activation function. |
| | | norm_args: Normalization module arguments. |
| | | causal: Whether to use causal convolution (set to True if streaming). |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | channels: int, |
| | | kernel_size: int, |
| | | activation: torch.nn.Module = torch.nn.ReLU(), |
| | | norm_args: Dict = {}, |
| | | causal: bool = False, |
| | | ) -> None: |
| | | """Construct an ConformerConvolution object.""" |
| | | super().__init__() |
| | | |
| | | assert (kernel_size - 1) % 2 == 0 |
| | | |
| | | self.kernel_size = kernel_size |
| | | |
| | | self.pointwise_conv1 = torch.nn.Conv1d( |
| | | channels, |
| | | 2 * channels, |
| | | kernel_size=1, |
| | | stride=1, |
| | | padding=0, |
| | | ) |
| | | |
| | | if causal: |
| | | self.lorder = kernel_size - 1 |
| | | padding = 0 |
| | | else: |
| | | self.lorder = 0 |
| | | padding = (kernel_size - 1) // 2 |
| | | |
| | | self.depthwise_conv = torch.nn.Conv1d( |
| | | channels, |
| | | channels, |
| | | kernel_size, |
| | | stride=1, |
| | | padding=padding, |
| | | groups=channels, |
| | | ) |
| | | self.norm = torch.nn.BatchNorm1d(channels, **norm_args) |
| | | self.pointwise_conv2 = torch.nn.Conv1d( |
| | | channels, |
| | | channels, |
| | | kernel_size=1, |
| | | stride=1, |
| | | padding=0, |
| | | ) |
| | | |
| | | self.activation = activation |
| | | |
| | | def forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | cache: Optional[torch.Tensor] = None, |
| | | right_context: int = 0, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Compute convolution module. |
| | | Args: |
| | | x: ConformerConvolution input sequences. (B, T, D_hidden) |
| | | cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden) |
| | | right_context: Number of frames in right context. |
| | | Returns: |
| | | x: ConformerConvolution output sequences. (B, T, D_hidden) |
| | | cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden) |
| | | """ |
| | | x = self.pointwise_conv1(x.transpose(1, 2)) |
| | | x = torch.nn.functional.glu(x, dim=1) |
| | | |
| | | if self.lorder > 0: |
| | | if cache is None: |
| | | x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0) |
| | | else: |
| | | x = torch.cat([cache, x], dim=2) |
| | | |
| | | if right_context > 0: |
| | | cache = x[:, :, -(self.lorder + right_context) : -right_context] |
| | | else: |
| | | cache = x[:, :, -self.lorder :] |
| | | |
| | | x = self.depthwise_conv(x) |
| | | x = self.activation(self.norm(x)) |
| | | |
| | | x = self.pointwise_conv2(x).transpose(1, 2) |
| | | |
| | | return x, cache |
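
The cache logic above mirrors plain causal padding. A hedged pure-torch sketch (not the FunASR API) showing that left-padding by kernel_size - 1 keeps the output length equal to the input length, with cached frames replacing zeros after the first chunk:

    import torch

    channels, kernel_size = 4, 3
    conv = torch.nn.Conv1d(channels, channels, kernel_size, groups=channels)
    lorder = kernel_size - 1                          # frames of left context needed

    x = torch.randn(1, channels, 6)                   # (B, D, T) first chunk
    x_pad = torch.nn.functional.pad(x, (lorder, 0))   # zero history for chunk 1
    assert conv(x_pad).shape[-1] == x.shape[-1]       # causal: no length change

    cache = x_pad[:, :, -lorder:]                     # carry the last frames forward
    x2 = torch.randn(1, channels, 6)
    assert conv(torch.cat([cache, x2], dim=2)).shape[-1] == x2.shape[-1]
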
| | | |
| | | class ConformerChunkEncoder(torch.nn.Module): |
| | | """Encoder module definition. |
| | | Args: |
| | | input_size: Input size. |
| | | body_conf: Encoder body configuration. |
| | | input_conf: Encoder input configuration. |
| | | main_conf: Encoder main configuration. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | embed_vgg_like: bool = False, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | positionwise_layer_type: str = "linear", |
| | | positionwise_conv_kernel_size: int = 3, |
| | | macaron_style: bool = False, |
| | | rel_pos_type: str = "legacy", |
| | | pos_enc_layer_type: str = "rel_pos", |
| | | selfattention_layer_type: str = "rel_selfattn", |
| | | activation_type: str = "swish", |
| | | use_cnn_module: bool = True, |
| | | zero_triu: bool = False, |
| | | norm_type: str = "layer_norm", |
| | | cnn_module_kernel: int = 31, |
| | | conv_mod_norm_eps: float = 0.00001, |
| | | conv_mod_norm_momentum: float = 0.1, |
| | | simplified_att_score: bool = False, |
| | | dynamic_chunk_training: bool = False, |
| | | short_chunk_threshold: float = 0.75, |
| | | short_chunk_size: int = 25, |
| | | left_chunk_size: int = 0, |
| | | time_reduction_factor: int = 1, |
| | | unified_model_training: bool = False, |
| | | default_chunk_size: int = 16, |
| | | jitter_range: int = 4, |
| | | subsampling_factor: int = 1, |
| | | **activation_parameters, |
| | | ) -> None: |
| | | """Construct an Encoder object.""" |
| | | super().__init__() |
| | | |
| | | assert check_argument_types() |
| | | |
| | | self.embed = StreamingConvInput( |
| | | input_size, |
| | | output_size, |
| | | subsampling_factor, |
| | | vgg_like=embed_vgg_like, |
| | | output_size=output_size, |
| | | ) |
| | | |
| | | self.pos_enc = StreamingRelPositionalEncoding( |
| | | output_size, |
| | | positional_dropout_rate, |
| | | ) |
| | | |
| | | activation = get_activation( |
| | | activation_type, **activation_parameters |
| | | ) |
| | | |
| | | pos_wise_args = ( |
| | | output_size, |
| | | linear_units, |
| | | positional_dropout_rate, |
| | | activation, |
| | | ) |
| | | |
| | | conv_mod_norm_args = { |
| | | "eps": conv_mod_norm_eps, |
| | | "momentum": conv_mod_norm_momentum, |
| | | } |
| | | |
| | | conv_mod_args = ( |
| | | output_size, |
| | | cnn_module_kernel, |
| | | activation, |
| | | conv_mod_norm_args, |
| | | dynamic_chunk_training or unified_model_training, |
| | | ) |
| | | |
| | | mult_att_args = ( |
| | | attention_heads, |
| | | output_size, |
| | | attention_dropout_rate, |
| | | simplified_att_score, |
| | | ) |
| | | |
| | | norm_class, norm_args = get_normalization( |
| | | norm_type, |
| | | ) |
| | | |
| | | fn_modules = [] |
| | | for _ in range(num_blocks): |
| | | module = lambda: ChunkEncoderLayer( |
| | | output_size, |
| | | RelPositionMultiHeadedAttentionChunk(*mult_att_args), |
| | | PositionwiseFeedForward(*pos_wise_args), |
| | | PositionwiseFeedForward(*pos_wise_args), |
| | | CausalConvolution(*conv_mod_args), |
| | | norm_class=norm_class, |
| | | norm_args=norm_args, |
| | | dropout_rate=dropout_rate, |
| | | ) |
| | | fn_modules.append(module) |
| | | |
| | | self.encoders = MultiBlocks( |
| | | [fn() for fn in fn_modules], |
| | | output_size, |
| | | norm_class=norm_class, |
| | | norm_args=norm_args, |
| | | ) |
| | | |
| | | self.output_size = output_size |
| | | |
| | | self.dynamic_chunk_training = dynamic_chunk_training |
| | | self.short_chunk_threshold = short_chunk_threshold |
| | | self.short_chunk_size = short_chunk_size |
| | | self.left_chunk_size = left_chunk_size |
| | | |
| | | self.unified_model_training = unified_model_training |
| | | self.default_chunk_size = default_chunk_size |
| | | self.jitter_range = jitter_range |
| | | |
| | | self.time_reduction_factor = time_reduction_factor |
| | | |
| | | def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int: |
| | | """Return the corresponding number of raw samples for a given chunk size, |
| | | where size is the number of feature frames after subsampling. |
| | | Args: |
| | | size: Number of frames after subsampling. |
| | | hop_length: Frontend's hop length. |
| | | Returns: |
| | | : Number of raw samples. |
| | | """ |
| | | return self.embed.get_size_before_subsampling(size) * hop_length |
| | | |
| | | def get_encoder_input_size(self, size: int) -> int: |
| | | """Return the corresponding number of feature frames for a given chunk size, |
| | | where size is the number of feature frames after subsampling. |
| | | Args: |
| | | size: Number of frames after subsampling. |
| | | Returns: |
| | | : Number of feature frames before subsampling. |
| | | """ |
| | | return self.embed.get_size_before_subsampling(size) |
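
As a worked example, with subsampling factor 4 and a frontend hop length of 160 samples (illustrative assumptions, e.g. a 16 kHz frontend with a 10 ms hop), a 16-frame chunk maps back as follows:

    # Hypothetical numbers for illustration only.
    chunk_size = 16                     # frames after subsampling
    subsampling_factor = 4
    hop_length = 160                    # frontend hop length, in samples

    feature_frames = chunk_size * subsampling_factor   # 64 feature frames
    raw_samples = feature_frames * hop_length          # 10240 samples (~0.64 s at 16 kHz)
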
| | | |
| | | |
| | | def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: |
| | | """Initialize/Reset encoder streaming cache. |
| | | Args: |
| | | left_context: Number of frames in left context. |
| | | device: Device ID. |
| | | """ |
| | | return self.encoders.reset_streaming_cache(left_context, device) |
| | | |
| | | def forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | x_len: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Encode input sequences. |
| | | Args: |
| | | x: Encoder input features. (B, T_in, F) |
| | | x_len: Encoder input features lengths. (B,) |
| | | Returns: |
| | | x: Encoder outputs. (B, T_out, D_enc) |
| | | x_len: Encoder outputs lengths. (B,) |
| | | """ |
| | | short_status, limit_size = check_short_utt( |
| | | self.embed.subsampling_factor, x.size(1) |
| | | ) |
| | | |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {x.size(1)} frames and is too short for subsampling " |
| | | + f"(it needs more than {limit_size} frames), return empty results", |
| | | x.size(1), |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | |
| | | if self.unified_model_training: |
| | | chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item() |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = make_chunk_mask( |
| | | x.size(1), |
| | | chunk_size, |
| | | left_chunk_size=self.left_chunk_size, |
| | | device=x.device, |
| | | ) |
| | | x_utt = self.encoders( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=None, |
| | | ) |
| | | x_chunk = self.encoders( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | |
| | | olens = mask.eq(0).sum(1) |
| | | if self.time_reduction_factor > 1: |
| | | x_utt = x_utt[:,::self.time_reduction_factor,:] |
| | | x_chunk = x_chunk[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x_utt, x_chunk, olens |
| | | |
| | | elif self.dynamic_chunk_training: |
| | | max_len = x.size(1) |
| | | chunk_size = torch.randint(1, max_len, (1,)).item() |
| | | |
| | | if chunk_size > (max_len * self.short_chunk_threshold): |
| | | chunk_size = max_len |
| | | else: |
| | | chunk_size = (chunk_size % self.short_chunk_size) + 1 |
| | | |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | |
| | | chunk_mask = make_chunk_mask( |
| | | x.size(1), |
| | | chunk_size, |
| | | left_chunk_size=self.left_chunk_size, |
| | | device=x.device, |
| | | ) |
| | | else: |
| | | x, mask = self.embed(x, mask, None) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = None |
| | | x = self.encoders( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | |
| | | olens = mask.eq(0).sum(1) |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens |
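
The dynamic-chunk branch above samples a training chunk size per batch: a uniform draw above short_chunk_threshold * max_len falls back to full-utterance training, and any other draw is folded into [1, short_chunk_size]. A standalone restatement of that rule:

    import torch

    def sample_chunk_size(max_len: int, threshold: float = 0.75, short_size: int = 25) -> int:
        """Standalone restatement of the dynamic-chunk sampling rule above."""
        chunk = torch.randint(1, max_len, (1,)).item()
        if chunk > max_len * threshold:
            return max_len                      # full-utterance (non-streaming) step
        return (chunk % short_size) + 1         # short chunk in [1, short_size]

    # With max_len=100: draws above 75 keep the whole utterance, while
    # draws of 30 and 55 both map to a chunk size of 6 (x % 25 + 1).
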
| | | |
| | | def simu_chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | x_len: torch.Tensor, |
| | | chunk_size: int = 16, |
| | | left_context: int = 32, |
| | | right_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Encode full utterances while applying chunk-wise attention masking, |
| | | simulating streaming conditions in a single forward pass. |
| | | Args: |
| | | x: Encoder input features. (B, T_in, F) |
| | | x_len: Encoder input features lengths. (B,) |
| | | chunk_size: Number of frames per chunk (after subsampling). |
| | | left_context: Number of frames in left context (kept for interface parity). |
| | | right_context: Number of frames in right context (kept for interface parity). |
| | | Returns: |
| | | x: Encoder outputs. (B, T_out, D_enc) |
| | | """ |
| | | short_status, limit_size = check_short_utt( |
| | | self.embed.subsampling_factor, x.size(1) |
| | | ) |
| | | |
| | | if short_status: |
| | | raise TooShortUttError( |
| | | f"has {x.size(1)} frames and is too short for subsampling " |
| | | + f"(it needs more than {limit_size} frames), return empty results", |
| | | x.size(1), |
| | | limit_size, |
| | | ) |
| | | |
| | | mask = make_source_mask(x_len) |
| | | |
| | | x, mask = self.embed(x, mask, chunk_size) |
| | | pos_enc = self.pos_enc(x) |
| | | chunk_mask = make_chunk_mask( |
| | | x.size(1), |
| | | chunk_size, |
| | | left_chunk_size=self.left_chunk_size, |
| | | device=x.device, |
| | | ) |
| | | |
| | | x = self.encoders( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_mask=chunk_mask, |
| | | ) |
| | | olens = mask.eq(0).sum(1) |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | |
| | | return x |
| | | |
| | | def chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | x_len: torch.Tensor, |
| | | processed_frames: torch.Tensor, |
| | | chunk_size: int = 16, |
| | | left_context: int = 32, |
| | | right_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Encode input sequences as chunks. |
| | | Args: |
| | | x: Encoder input features. (1, T_in, F) |
| | | x_len: Encoder input features lengths. (1,) |
| | | processed_frames: Number of frames already seen. |
| | | chunk_size: Number of frames per chunk (after subsampling). |
| | | left_context: Number of frames in left context. |
| | | right_context: Number of frames in right context. |
| | | Returns: |
| | | x: Encoder outputs. (B, T_out, D_enc) |
| | | """ |
| | | mask = make_source_mask(x_len) |
| | | x, mask = self.embed(x, mask, None) |
| | | |
| | | if left_context > 0: |
| | | processed_mask = ( |
| | | torch.arange(left_context, device=x.device) |
| | | .view(1, left_context) |
| | | .flip(1) |
| | | ) |
| | | processed_mask = processed_mask >= processed_frames |
| | | mask = torch.cat([processed_mask, mask], dim=1) |
| | | pos_enc = self.pos_enc(x, left_context=left_context) |
| | | x = self.encoders.chunk_forward( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_size=chunk_size, |
| | | left_context=left_context, |
| | | right_context=right_context, |
| | | ) |
| | | |
| | | if right_context > 0: |
| | | x = x[:, 0:-right_context, :] |
| | | |
| | | if self.time_reduction_factor > 1: |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | return x |
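
A hedged sketch of how these methods compose for chunk-by-chunk inference. The decoding loop itself is not part of this diff, so the names and chunk geometry below are assumptions (it also assumes time_reduction_factor == 1 and a working FunASR checkout; the import path matches this diff):

    import torch
    from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder

    enc = ConformerChunkEncoder(input_size=80, num_blocks=2, subsampling_factor=4)
    chunk_size, left_context, right_context = 16, 32, 0
    feats = torch.randn(1, 640, 80)                     # (1, T_in, F) features

    enc.reset_streaming_cache(left_context, feats.device)
    stride = enc.get_encoder_input_size(chunk_size)     # feature frames per chunk
    processed = torch.tensor(0)
    outs = []
    for start in range(0, feats.size(1), stride):
        piece = feats[:, start:start + stride, :]
        out = enc.chunk_forward(
            piece,
            torch.tensor([piece.size(1)]),
            processed,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )
        processed = processed + out.size(1)
        outs.append(out)
    enc_out = torch.cat(outs, dim=1)                    # (1, ~T_out, D_enc)
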
| | |
| | | import math |
| | | import numpy |
| | | import torch |
| | | from torch import nn |
| | | |
| | | from typing import Optional, Tuple |
| | | |
| | | class MultiHeadedAttention(nn.Module): |
| | | """Multi-Head Attention layer. |
| | |
| | | scores = torch.matmul(q_h, k_h.transpose(-2, -1)) |
| | | att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder) |
| | | return att_outs |
| | | |
| | | class RelPositionMultiHeadedAttentionChunk(torch.nn.Module): |
| | | """RelPositionMultiHeadedAttention definition. |
| | | Args: |
| | | num_heads: Number of attention heads. |
| | | embed_size: Embedding size. |
| | | dropout_rate: Dropout rate. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | num_heads: int, |
| | | embed_size: int, |
| | | dropout_rate: float = 0.0, |
| | | simplified_attention_score: bool = False, |
| | | ) -> None: |
| | | """Construct an MultiHeadedAttention object.""" |
| | | super().__init__() |
| | | |
| | | self.d_k = embed_size // num_heads |
| | | self.num_heads = num_heads |
| | | |
| | | assert self.d_k * num_heads == embed_size, ( |
| | | f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})" |
| | | ) |
| | | |
| | | self.linear_q = torch.nn.Linear(embed_size, embed_size) |
| | | self.linear_k = torch.nn.Linear(embed_size, embed_size) |
| | | self.linear_v = torch.nn.Linear(embed_size, embed_size) |
| | | |
| | | self.linear_out = torch.nn.Linear(embed_size, embed_size) |
| | | |
| | | if simplified_attention_score: |
| | | self.linear_pos = torch.nn.Linear(embed_size, num_heads) |
| | | |
| | | self.compute_att_score = self.compute_simplified_attention_score |
| | | else: |
| | | self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False) |
| | | |
| | | self.pos_bias_u = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) |
| | | self.pos_bias_v = torch.nn.Parameter(torch.Tensor(num_heads, self.d_k)) |
| | | torch.nn.init.xavier_uniform_(self.pos_bias_u) |
| | | torch.nn.init.xavier_uniform_(self.pos_bias_v) |
| | | |
| | | self.compute_att_score = self.compute_attention_score |
| | | |
| | | self.dropout = torch.nn.Dropout(p=dropout_rate) |
| | | self.attn = None |
| | | |
| | | def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: |
| | | """Compute relative positional encoding. |
| | | Args: |
| | | x: Input sequence. (B, H, T_1, 2 * T_1 - 1) |
| | | left_context: Number of frames in left context. |
| | | Returns: |
| | | x: Output sequence. (B, H, T_1, T_2) |
| | | """ |
| | | batch_size, n_heads, time1, n = x.shape |
| | | time2 = time1 + left_context |
| | | |
| | | batch_stride, n_heads_stride, time1_stride, n_stride = x.stride() |
| | | |
| | | return x.as_strided( |
| | | (batch_size, n_heads, time1, time2), |
| | | (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride), |
| | | storage_offset=(n_stride * (time1 - 1)), |
| | | ) |
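
The as_strided call above is the standard relative-shift trick: output element (t, j) reads input column (T_1 - 1) - t + j, i.e. the score for relative offset j - t with column T_1 - 1 as offset zero. A small pure-torch check of that indexing (independent of this class):

    import torch

    T1 = 4
    # Every row holds 0..2*T1-2, so column indices are directly observable.
    x = torch.arange(2 * T1 - 1).repeat(T1, 1).float().view(1, 1, T1, -1)

    bs, hs, ts, ns = x.stride()
    shifted = x.as_strided(
        (1, 1, T1, T1),
        (bs, hs, ts - ns, ns),
        storage_offset=ns * (T1 - 1),
    )

    for t in range(T1):
        for j in range(T1):
            # Offset j - t, centred on column T1 - 1.
            assert shifted[0, 0, t, j].item() == (T1 - 1) - t + j
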
| | | |
| | | def compute_simplified_attention_score( |
| | | self, |
| | | query: torch.Tensor, |
| | | key: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | left_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Simplified attention score computation. |
| | | Reference: https://github.com/k2-fsa/icefall/pull/458 |
| | | Args: |
| | | query: Transformed query tensor. (B, H, T_1, d_k) |
| | | key: Transformed key tensor. (B, H, T_2, d_k) |
| | | pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) |
| | | left_context: Number of frames in left context. |
| | | Returns: |
| | | : Attention score. (B, H, T_1, T_2) |
| | | """ |
| | | pos_enc = self.linear_pos(pos_enc) |
| | | |
| | | matrix_ac = torch.matmul(query, key.transpose(2, 3)) |
| | | |
| | | matrix_bd = self.rel_shift( |
| | | pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1), |
| | | left_context=left_context, |
| | | ) |
| | | |
| | | return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) |
| | | |
| | | def compute_attention_score( |
| | | self, |
| | | query: torch.Tensor, |
| | | key: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | left_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Attention score computation. |
| | | Args: |
| | | query: Transformed query tensor. (B, H, T_1, d_k) |
| | | key: Transformed key tensor. (B, H, T_2, d_k) |
| | | pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) |
| | | left_context: Number of frames in left context. |
| | | Returns: |
| | | : Attention score. (B, H, T_1, T_2) |
| | | """ |
| | | p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k) |
| | | |
| | | query = query.transpose(1, 2) |
| | | q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2) |
| | | q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2) |
| | | |
| | | matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1)) |
| | | |
| | | matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1)) |
| | | matrix_bd = self.rel_shift(matrix_bd, left_context=left_context) |
| | | |
| | | return (matrix_ac + matrix_bd) / math.sqrt(self.d_k) |
| | | |
| | | def forward_qkv( |
| | | self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| | | """Transform query, key and value. |
| | | Args: |
| | | query: Query tensor. (B, T_1, size) |
| | | key: Key tensor. (B, T_2, size) |
| | | value: Value tensor. (B, T_2, size) |
| | | Returns: |
| | | q: Transformed query tensor. (B, H, T_1, d_k) |
| | | k: Transformed key tensor. (B, H, T_2, d_k) |
| | | v: Transformed value tensor. (B, H, T_2, d_k) |
| | | """ |
| | | n_batch = query.size(0) |
| | | |
| | | q = ( |
| | | self.linear_q(query) |
| | | .view(n_batch, -1, self.num_heads, self.d_k) |
| | | .transpose(1, 2) |
| | | ) |
| | | k = ( |
| | | self.linear_k(key) |
| | | .view(n_batch, -1, self.num_heads, self.d_k) |
| | | .transpose(1, 2) |
| | | ) |
| | | v = ( |
| | | self.linear_v(value) |
| | | .view(n_batch, -1, self.num_heads, self.d_k) |
| | | .transpose(1, 2) |
| | | ) |
| | | |
| | | return q, k, v |
| | | |
| | | def forward_attention( |
| | | self, |
| | | value: torch.Tensor, |
| | | scores: torch.Tensor, |
| | | mask: torch.Tensor, |
| | | chunk_mask: Optional[torch.Tensor] = None, |
| | | ) -> torch.Tensor: |
| | | """Compute attention context vector. |
| | | Args: |
| | | value: Transformed value. (B, H, T_2, d_k) |
| | | scores: Attention score. (B, H, T_1, T_2) |
| | | mask: Source mask. (B, T_2) |
| | | chunk_mask: Chunk mask. (T_1, T_1) |
| | | Returns: |
| | | attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k) |
| | | """ |
| | | batch_size = scores.size(0) |
| | | mask = mask.unsqueeze(1).unsqueeze(2) |
| | | if chunk_mask is not None: |
| | | mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask |
| | | scores = scores.masked_fill(mask, float("-inf")) |
| | | self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0) |
| | | |
| | | attn_output = self.dropout(self.attn) |
| | | attn_output = torch.matmul(attn_output, value) |
| | | |
| | | attn_output = self.linear_out( |
| | | attn_output.transpose(1, 2) |
| | | .contiguous() |
| | | .view(batch_size, -1, self.num_heads * self.d_k) |
| | | ) |
| | | |
| | | return attn_output |
| | | |
| | | def forward( |
| | | self, |
| | | query: torch.Tensor, |
| | | key: torch.Tensor, |
| | | value: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | mask: torch.Tensor, |
| | | chunk_mask: Optional[torch.Tensor] = None, |
| | | left_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Compute scaled dot product attention with rel. positional encoding. |
| | | Args: |
| | | query: Query tensor. (B, T_1, size) |
| | | key: Key tensor. (B, T_2, size) |
| | | value: Value tensor. (B, T_2, size) |
| | | pos_enc: Positional embedding tensor. (B, 2 * T_1 - 1, size) |
| | | mask: Source mask. (B, T_2) |
| | | chunk_mask: Chunk mask. (T_1, T_1) |
| | | left_context: Number of frames in left context. |
| | | Returns: |
| | | : Output tensor. (B, T_1, H * d_k) |
| | | """ |
| | | q, k, v = self.forward_qkv(query, key, value) |
| | | scores = self.compute_att_score(q, k, pos_enc, left_context=left_context) |
| | | return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask) |
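
A hypothetical smoke test for the module above (shapes follow the docstrings; this is not part of the FunASR test suite and assumes the class is in scope):

    import torch

    heads, size, T = 4, 64, 10
    att = RelPositionMultiHeadedAttentionChunk(num_heads=heads, embed_size=size)

    x = torch.randn(2, T, size)
    pos_enc = torch.randn(2, 2 * T - 1, size)      # relative offsets T-1 .. -(T-1)
    mask = torch.zeros(2, T, dtype=torch.bool)     # False = keep; no padding here
    out = att(x, x, x, pos_enc, mask)
    assert out.shape == (2, T, size)
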
| | |
| | | outputs = F.pad(outputs, (pad_left, pad_right)) |
| | | outputs = outputs.transpose(1,2) |
| | | return outputs |
| | | |
| | | |
| | | class StreamingRelPositionalEncoding(torch.nn.Module): |
| | | """Relative positional encoding. |
| | | Args: |
| | | size: Module size. |
| | | max_len: Maximum input length. |
| | | dropout_rate: Dropout rate. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, size: int, dropout_rate: float = 0.0, max_len: int = 5000 |
| | | ) -> None: |
| | | """Construct a RelativePositionalEncoding object.""" |
| | | super().__init__() |
| | | |
| | | self.size = size |
| | | |
| | | self.pe = None |
| | | self.dropout = torch.nn.Dropout(p=dropout_rate) |
| | | |
| | | self.extend_pe(torch.tensor(0.0).expand(1, max_len)) |
| | | self._register_load_state_dict_pre_hook(_pre_hook) |
| | | |
| | | def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None: |
| | | """Reset positional encoding. |
| | | Args: |
| | | x: Input sequences. (B, T, ?) |
| | | left_context: Number of frames in left context. |
| | | """ |
| | | time1 = x.size(1) + left_context |
| | | |
| | | if self.pe is not None: |
| | | if self.pe.size(1) >= time1 * 2 - 1: |
| | | if self.pe.dtype != x.dtype or self.pe.device != x.device: |
| | | self.pe = self.pe.to(device=x.device, dtype=x.dtype) |
| | | return |
| | | |
| | | pe_positive = torch.zeros(time1, self.size) |
| | | pe_negative = torch.zeros(time1, self.size) |
| | | |
| | | position = torch.arange(0, time1, dtype=torch.float32).unsqueeze(1) |
| | | div_term = torch.exp( |
| | | torch.arange(0, self.size, 2, dtype=torch.float32) |
| | | * -(math.log(10000.0) / self.size) |
| | | ) |
| | | |
| | | pe_positive[:, 0::2] = torch.sin(position * div_term) |
| | | pe_positive[:, 1::2] = torch.cos(position * div_term) |
| | | pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0) |
| | | |
| | | pe_negative[:, 0::2] = torch.sin(-1 * position * div_term) |
| | | pe_negative[:, 1::2] = torch.cos(-1 * position * div_term) |
| | | pe_negative = pe_negative[1:].unsqueeze(0) |
| | | |
| | | self.pe = torch.cat([pe_positive, pe_negative], dim=1).to( |
| | | dtype=x.dtype, device=x.device |
| | | ) |
| | | |
| | | def forward(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor: |
| | | """Compute positional encoding. |
| | | Args: |
| | | x: Input sequences. (B, T, ?) |
| | | left_context: Number of frames in left context. |
| | | Returns: |
| | | pos_enc: Positional embedding sequences. (B, 2 * (T - 1), ?) |
| | | """ |
| | | self.extend_pe(x, left_context=left_context) |
| | | |
| | | time1 = x.size(1) + left_context |
| | | |
| | | pos_enc = self.pe[ |
| | | :, self.pe.size(1) // 2 - time1 + 1 : self.pe.size(1) // 2 + x.size(1) |
| | | ] |
| | | pos_enc = self.dropout(pos_enc) |
| | | |
| | | return pos_enc |
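
The slice above returns offsets T_2 - 1 down to -(T - 1) from the precomputed table (positive offsets first, then negative), so the embedding length is 2 * T - 1 + left_context. A quick length check (assumes the full embedding module, including its _pre_hook helper, is importable):

    import torch

    pos = StreamingRelPositionalEncoding(size=16)
    x = torch.zeros(1, 10, 16)
    assert pos(x).shape[1] == 2 * 10 - 1                       # offsets 9 .. -9
    assert pos(x, left_context=4).shape[1] == 2 * 10 - 1 + 4   # offsets 13 .. -9
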
| | |
| | | |
| | | """Repeat the same layer definition.""" |
| | | |
| | | from typing import Dict, List, Optional |
| | | |
| | | import torch |
| | | |
| | | |
| | |
| | | |
| | | """ |
| | | return MultiSequential(*[fn(n) for n in range(N)]) |
| | | |
| | | |
| | | class MultiBlocks(torch.nn.Module): |
| | | """MultiBlocks definition. |
| | | Args: |
| | | block_list: Individual blocks of the encoder architecture. |
| | | output_size: Architecture output size. |
| | | norm_class: Normalization module class. |
| | | norm_args: Normalization module arguments. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | block_list: List[torch.nn.Module], |
| | | output_size: int, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_args: Optional[Dict] = None, |
| | | ) -> None: |
| | | """Construct a MultiBlocks object.""" |
| | | super().__init__() |
| | | |
| | | self.blocks = torch.nn.ModuleList(block_list) |
| | | self.norm_blocks = norm_class(output_size, **norm_args) |
| | | |
| | | self.num_blocks = len(block_list) |
| | | |
| | | def reset_streaming_cache(self, left_context: int, device: torch.device) -> None: |
| | | """Initialize/Reset encoder streaming cache. |
| | | Args: |
| | | left_context: Number of left frames during chunk-by-chunk inference. |
| | | device: Device to use for cache tensor. |
| | | """ |
| | | for idx in range(self.num_blocks): |
| | | self.blocks[idx].reset_streaming_cache(left_context, device) |
| | | |
| | | def forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | mask: torch.Tensor, |
| | | chunk_mask: Optional[torch.Tensor] = None, |
| | | ) -> torch.Tensor: |
| | | """Forward each block of the encoder architecture. |
| | | Args: |
| | | x: MultiBlocks input sequences. (B, T, D_block_1) |
| | | pos_enc: Positional embedding sequences. |
| | | mask: Source mask. (B, T) |
| | | chunk_mask: Chunk mask. (T_2, T_2) |
| | | Returns: |
| | | x: Output sequences. (B, T, D_block_N) |
| | | """ |
| | | for block_index, block in enumerate(self.blocks): |
| | | x, mask, pos_enc = block(x, pos_enc, mask, chunk_mask=chunk_mask) |
| | | |
| | | x = self.norm_blocks(x) |
| | | |
| | | return x |
| | | |
| | | def chunk_forward( |
| | | self, |
| | | x: torch.Tensor, |
| | | pos_enc: torch.Tensor, |
| | | mask: torch.Tensor, |
| | | chunk_size: int = 0, |
| | | left_context: int = 0, |
| | | right_context: int = 0, |
| | | ) -> torch.Tensor: |
| | | """Forward each block of the encoder architecture. |
| | | Args: |
| | | x: MultiBlocks input sequences. (B, T, D_block_1) |
| | | pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_att) |
| | | mask: Source mask. (B, T_2) |
| | | left_context: Number of frames in left context. |
| | | right_context: Number of frames in right context. |
| | | Returns: |
| | | x: MultiBlocks output sequences. (B, T, D_block_N) |
| | | """ |
| | | for block_idx, block in enumerate(self.blocks): |
| | | x, pos_enc = block.chunk_forward( |
| | | x, |
| | | pos_enc, |
| | | mask, |
| | | chunk_size=chunk_size, |
| | | left_context=left_context, |
| | | right_context=right_context, |
| | | ) |
| | | |
| | | x = self.norm_blocks(x) |
| | | |
| | | return x |
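
To make the threading of (x, mask, pos_enc) through the block list concrete, a toy block satisfying the interface (illustrative only; the real blocks are ChunkEncoderLayer instances, and MultiBlocks is assumed in scope):

    import torch

    class ToyBlock(torch.nn.Module):
        """Minimal block honouring the (x, mask, pos_enc) return contract."""

        def __init__(self, size: int) -> None:
            super().__init__()
            self.lin = torch.nn.Linear(size, size)

        def forward(self, x, pos_enc, mask, chunk_mask=None):
            return self.lin(x), mask, pos_enc

    blocks = MultiBlocks([ToyBlock(8) for _ in range(3)], output_size=8, norm_args={})
    out = blocks(torch.randn(2, 5, 8), pos_enc=None, mask=None)
    assert out.shape == (2, 5, 8)
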
| | |
| | | from funasr.modules.embedding import PositionalEncoding |
| | | import logging |
| | | from funasr.modules.streaming_utils.utils import sequence_mask |
| | | from funasr.modules.nets_utils import sub_factor_to_params, pad_to_len |
| | | from typing import Optional, Tuple, Union |
| | | import math |
| | | |
| | | class TooShortUttError(Exception): |
| | | """Raised when the utt is too short for subsampling. |
| | | |
| | |
| | | var_dict_tf[name_tf].shape)) |
| | | return var_dict_torch_update |
| | | |
| | | class StreamingConvInput(torch.nn.Module): |
| | | """Streaming ConvInput module definition. |
| | | Args: |
| | | input_size: Input size. |
| | | conv_size: Convolution size. |
| | | subsampling_factor: Subsampling factor. |
| | | vgg_like: Whether to use a VGG-like network. |
| | | output_size: Block output dimension. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | conv_size: Union[int, Tuple], |
| | | subsampling_factor: int = 4, |
| | | vgg_like: bool = True, |
| | | output_size: Optional[int] = None, |
| | | ) -> None: |
| | | """Construct a ConvInput object.""" |
| | | super().__init__() |
| | | if vgg_like: |
| | | if subsampling_factor == 1: |
| | | conv_size1, conv_size2 = conv_size |
| | | |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((1, 2)), |
| | | torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((1, 2)), |
| | | ) |
| | | |
| | | output_proj = conv_size2 * ((input_size // 2) // 2) |
| | | |
| | | self.subsampling_factor = 1 |
| | | |
| | | self.stride_1 = 1 |
| | | |
| | | self.create_new_mask = self.create_new_vgg_mask |
| | | |
| | | else: |
| | | conv_size1, conv_size2 = conv_size |
| | | |
| | | kernel_1 = int(subsampling_factor / 2) |
| | | |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size1, conv_size1, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((kernel_1, 2)), |
| | | torch.nn.Conv2d(conv_size1, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size2, conv_size2, 3, stride=1, padding=1), |
| | | torch.nn.ReLU(), |
| | | torch.nn.MaxPool2d((2, 2)), |
| | | ) |
| | | |
| | | output_proj = conv_size2 * ((input_size // 2) // 2) |
| | | |
| | | self.subsampling_factor = subsampling_factor |
| | | |
| | | self.create_new_mask = self.create_new_vgg_mask |
| | | |
| | | self.stride_1 = kernel_1 |
| | | |
| | | else: |
| | | if subsampling_factor == 1: |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size, 3, [1,2], [1,0]), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size, conv_size, 3, [1,2], [1,0]), |
| | | torch.nn.ReLU(), |
| | | ) |
| | | |
| | | output_proj = conv_size * (((input_size - 1) // 2 - 1) // 2) |
| | | |
| | | self.subsampling_factor = subsampling_factor |
| | | self.kernel_2 = 3 |
| | | self.stride_2 = 1 |
| | | |
| | | self.create_new_mask = self.create_new_conv2d_mask |
| | | |
| | | else: |
| | | kernel_2, stride_2, conv_2_output_size = sub_factor_to_params( |
| | | subsampling_factor, |
| | | input_size, |
| | | ) |
| | | |
| | | self.conv = torch.nn.Sequential( |
| | | torch.nn.Conv2d(1, conv_size, 3, 2), |
| | | torch.nn.ReLU(), |
| | | torch.nn.Conv2d(conv_size, conv_size, kernel_2, stride_2), |
| | | torch.nn.ReLU(), |
| | | ) |
| | | |
| | | output_proj = conv_size * conv_2_output_size |
| | | |
| | | self.subsampling_factor = subsampling_factor |
| | | self.kernel_2 = kernel_2 |
| | | self.stride_2 = stride_2 |
| | | |
| | | self.create_new_mask = self.create_new_conv2d_mask |
| | | |
| | | self.vgg_like = vgg_like |
| | | self.min_frame_length = 7 |
| | | |
| | | if output_size is not None: |
| | | self.output = torch.nn.Linear(output_proj, output_size) |
| | | self.output_size = output_size |
| | | else: |
| | | self.output = None |
| | | self.output_size = output_proj |
| | | |
| | | def forward( |
| | | self, x: torch.Tensor, mask: Optional[torch.Tensor], chunk_size: Optional[torch.Tensor] |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Encode input sequences. |
| | | Args: |
| | | x: ConvInput input sequences. (B, T, D_feats) |
| | | mask: Mask of input sequences. (B, 1, T) |
| | | Returns: |
| | | x: ConvInput output sequences. (B, sub(T), D_out) |
| | | mask: Mask of output sequences. (B, 1, sub(T)) |
| | | """ |
| | | if mask is not None: |
| | | mask = self.create_new_mask(mask) |
| | | olens = max(mask.eq(0).sum(1)) |
| | | |
| | | b, t, f = x.size() |
| | | x = x.unsqueeze(1) # (b, 1, t, f) |
| | | |
| | | if chunk_size is not None: |
| | | max_input_length = int( |
| | | chunk_size * self.subsampling_factor * (math.ceil(float(t) / (chunk_size * self.subsampling_factor) )) |
| | | ) |
| | | x = map(lambda inputs: pad_to_len(inputs, max_input_length, 1), x) |
| | | x = list(x) |
| | | x = torch.stack(x, dim=0) |
| | | N_chunks = max_input_length // ( chunk_size * self.subsampling_factor) |
| | | x = x.view(b * N_chunks, 1, chunk_size * self.subsampling_factor, f) |
| | | |
| | | x = self.conv(x) |
| | | |
| | | _, c, _, f = x.size() |
| | | if chunk_size is not None: |
| | | x = x.transpose(1, 2).contiguous().view(b, -1, c * f)[:,:olens,:] |
| | | else: |
| | | x = x.transpose(1, 2).contiguous().view(b, -1, c * f) |
| | | |
| | | if self.output is not None: |
| | | x = self.output(x) |
| | | |
| | | return x, mask[:,:olens][:,:x.size(1)] |
| | | |
| | | def create_new_vgg_mask(self, mask: torch.Tensor) -> torch.Tensor: |
| | | """Create a new mask for VGG output sequences. |
| | | Args: |
| | | mask: Mask of input sequences. (B, T) |
| | | Returns: |
| | | mask: Mask of output sequences. (B, sub(T)) |
| | | """ |
| | | if self.subsampling_factor > 1: |
| | | vgg1_t_len = mask.size(1) - (mask.size(1) % (self.subsampling_factor // 2 )) |
| | | mask = mask[:, :vgg1_t_len][:, ::self.subsampling_factor // 2] |
| | | |
| | | vgg2_t_len = mask.size(1) - (mask.size(1) % 2) |
| | | mask = mask[:, :vgg2_t_len][:, ::2] |
| | | else: |
| | | mask = mask |
| | | |
| | | return mask |
| | | |
| | | def create_new_conv2d_mask(self, mask: torch.Tensor) -> torch.Tensor: |
| | | """Create new conformer mask for Conv2d output sequences. |
| | | Args: |
| | | mask: Mask of input sequences. (B, T) |
| | | Returns: |
| | | mask: Mask of output sequences. (B, sub(T)) |
| | | """ |
| | | if self.subsampling_factor > 1: |
| | | return mask[:, :-2:2][:, : -(self.kernel_2 - 1) : self.stride_2] |
| | | else: |
| | | return mask |
| | | |
| | | def get_size_before_subsampling(self, size: int) -> int: |
| | | """Return the original size before subsampling for a given size. |
| | | Args: |
| | | size: Number of frames after subsampling. |
| | | Returns: |
| | | : Number of frames before subsampling. |
| | | """ |
| | | return size * self.subsampling_factor |
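
For the non-VGG, factor-4 path, create_new_conv2d_mask mirrors the two stride-2 convolutions. Assuming kernel_2 = 3 and stride_2 = 2 (the usual factor-4 parameters from sub_factor_to_params; an assumption here), the mask lengths track the conv output lengths:

    import torch

    T = 100
    mask = torch.zeros(1, T, dtype=torch.bool)

    m1 = mask[:, :-2:2]                              # after conv 1: kernel 3, stride 2
    assert m1.size(1) == (T - 3) // 2 + 1            # 49 frames

    m2 = m1[:, :-(3 - 1):2]                          # after conv 2 (kernel_2=3, stride_2=2)
    assert m2.size(1) == (m1.size(1) - 3) // 2 + 1   # 24 frames
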
| | |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.rnnt_decoder.rnn_decoder import RNNDecoder |
| | | from funasr.models.rnnt_decoder.stateless_decoder import StatelessDecoder |
| | | from funasr.models.encoder.chunk_encoder import ChunkEncoder as Encoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder |
| | | from funasr.models.e2e_transducer import TransducerModel |
| | | from funasr.models.e2e_transducer_unified import UnifiedTransducerModel |
| | | from funasr.models.joint_network import JointNetwork |
| | |
| | | encoder_choices = ClassChoices( |
| | | "encoder", |
| | | classes=dict( |
| | | encoder=Encoder, |
| | | chunk_conformer=ConformerChunkEncoder, |
| | | ), |
| | | default="encoder", |
| | | default="chunk_conformer", |
| | | ) |
| | | |
| | | decoder_choices = ClassChoices( |