From 256035b6c1fa6115b6f33972ed243eb43f3e4299 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期五, 14 四月 2023 11:38:00 +0800
Subject: [PATCH] rnnt reorg
---
funasr/models/encoder/conformer_encoder.py | 640 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1 files changed, 639 insertions(+), 1 deletions(-)
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..c837cf5 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -8,6 +8,7 @@
from typing import Optional
from typing import Tuple
from typing import Union
+from typing import Dict
import torch
from torch import nn
@@ -18,6 +19,7 @@
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
+ RelPositionMultiHeadedAttentionChunk,
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
from funasr.modules.embedding import (
@@ -25,16 +27,24 @@
ScaledPositionalEncoding, # noqa: H301
RelPositionalEncoding, # noqa: H301
LegacyRelPositionalEncoding, # noqa: H301
+ StreamingRelPositionalEncoding,
)
from funasr.modules.layer_norm import LayerNorm
+from funasr.modules.normalization import get_normalization
from funasr.modules.multi_layer_conv import Conv1dLinear
from funasr.modules.multi_layer_conv import MultiLayeredConv1d
from funasr.modules.nets_utils import get_activation
from funasr.modules.nets_utils import make_pad_mask
+from funasr.modules.nets_utils import (
+ TooShortUttError,
+ check_short_utt,
+ make_chunk_mask,
+ make_source_mask,
+)
from funasr.modules.positionwise_feed_forward import (
PositionwiseFeedForward, # noqa: H301
)
-from funasr.modules.repeat import repeat
+from funasr.modules.repeat import repeat, MultiBlocks
from funasr.modules.subsampling import Conv2dSubsampling
from funasr.modules.subsampling import Conv2dSubsampling2
from funasr.modules.subsampling import Conv2dSubsampling6
@@ -42,6 +52,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
+from funasr.modules.subsampling import StreamingConvInput
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
@@ -275,6 +287,188 @@
return (x, pos_emb), mask
return x, mask
+
+class ChunkEncoderLayer(torch.nn.Module):
+ """Chunk Conformer module definition.
+ Args:
+ block_size: Input/output size.
+ self_att: Self-attention module instance.
+ feed_forward: Feed-forward module instance.
+ feed_forward_macaron: Feed-forward module instance for macaron network.
+ conv_mod: Convolution module instance.
+ norm_class: Normalization module class.
+ norm_args: Normalization module arguments.
+ dropout_rate: Dropout rate.
+ """
+
+ def __init__(
+ self,
+ block_size: int,
+ self_att: torch.nn.Module,
+ feed_forward: torch.nn.Module,
+ feed_forward_macaron: torch.nn.Module,
+ conv_mod: torch.nn.Module,
+ norm_class: torch.nn.Module = torch.nn.LayerNorm,
+ norm_args: Dict = {},
+ dropout_rate: float = 0.0,
+ ) -> None:
+ """Construct a Conformer object."""
+ super().__init__()
+
+ self.self_att = self_att
+
+ self.feed_forward = feed_forward
+ self.feed_forward_macaron = feed_forward_macaron
+ self.feed_forward_scale = 0.5
+
+ self.conv_mod = conv_mod
+
+ self.norm_feed_forward = norm_class(block_size, **norm_args)
+ self.norm_self_att = norm_class(block_size, **norm_args)
+
+ self.norm_macaron = norm_class(block_size, **norm_args)
+ self.norm_conv = norm_class(block_size, **norm_args)
+ self.norm_final = norm_class(block_size, **norm_args)
+
+ self.dropout = torch.nn.Dropout(dropout_rate)
+
+ self.block_size = block_size
+ self.cache = None
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset self-attention and convolution modules cache for streaming.
+ Args:
+ left_context: Number of left frames during chunk-by-chunk inference.
+ device: Device to use for cache tensor.
+ """
+ self.cache = [
+ torch.zeros(
+ (1, left_context, self.block_size),
+ device=device,
+ ),
+ torch.zeros(
+ (
+ 1,
+ self.block_size,
+ self.conv_mod.kernel_size - 1,
+ ),
+ device=device,
+ ),
+ ]
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_mask: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T)
+ chunk_mask: Chunk mask. (T_2, T_2)
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ mask: Source mask. (B, T)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.dropout(
+ self.feed_forward_macaron(x)
+ )
+
+ residual = x
+ x = self.norm_self_att(x)
+ x_q = x
+ x = residual + self.dropout(
+ self.self_att(
+ x_q,
+ x,
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ )
+
+ residual = x
+
+ x = self.norm_conv(x)
+ x, _ = self.conv_mod(x)
+ x = residual + self.dropout(x)
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))
+
+ x = self.norm_final(x)
+ return x, mask, pos_enc
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ pos_enc: torch.Tensor,
+ mask: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 0,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode chunk of input sequence.
+ Args:
+ x: Conformer input sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ mask: Source mask. (B, T_2)
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Conformer output sequences. (B, T, D_block)
+ pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
+ """
+ residual = x
+
+ x = self.norm_macaron(x)
+ x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)
+
+ residual = x
+ x = self.norm_self_att(x)
+ if left_context > 0:
+ key = torch.cat([self.cache[0], x], dim=1)
+ else:
+ key = x
+ val = key
+
+ if right_context > 0:
+ att_cache = key[:, -(left_context + right_context) : -right_context, :]
+ else:
+ att_cache = key[:, -left_context:, :]
+ x = residual + self.self_att(
+ x,
+ key,
+ val,
+ pos_enc,
+ mask,
+ left_context=left_context,
+ )
+
+ residual = x
+ x = self.norm_conv(x)
+ x, conv_cache = self.conv_mod(
+ x, cache=self.cache[1], right_context=right_context
+ )
+ x = residual + x
+ residual = x
+
+ x = self.norm_feed_forward(x)
+ x = residual + self.feed_forward_scale * self.feed_forward(x)
+
+ x = self.norm_final(x)
+ self.cache = [att_cache, conv_cache]
+
+ return x, pos_enc
class ConformerEncoder(AbsEncoder):
@@ -604,3 +798,447 @@
if len(intermediate_outs) > 0:
return (xs_pad, intermediate_outs), olens, None
return xs_pad, olens, None
+
+
+class CausalConvolution(torch.nn.Module):
+ """ConformerConvolution module definition.
+ Args:
+ channels: The number of channels.
+ kernel_size: Size of the convolving kernel.
+ activation: Type of activation function.
+ norm_args: Normalization module arguments.
+ causal: Whether to use causal convolution (set to True if streaming).
+ """
+
+ def __init__(
+ self,
+ channels: int,
+ kernel_size: int,
+ activation: torch.nn.Module = torch.nn.ReLU(),
+ norm_args: Dict = {},
+ causal: bool = False,
+ ) -> None:
+ """Construct an ConformerConvolution object."""
+ super().__init__()
+
+ assert (kernel_size - 1) % 2 == 0
+
+ self.kernel_size = kernel_size
+
+ self.pointwise_conv1 = torch.nn.Conv1d(
+ channels,
+ 2 * channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ if causal:
+ self.lorder = kernel_size - 1
+ padding = 0
+ else:
+ self.lorder = 0
+ padding = (kernel_size - 1) // 2
+
+ self.depthwise_conv = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ groups=channels,
+ )
+ self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
+ self.pointwise_conv2 = torch.nn.Conv1d(
+ channels,
+ channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ )
+
+ self.activation = activation
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ cache: Optional[torch.Tensor] = None,
+ right_context: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Compute convolution module.
+ Args:
+ x: ConformerConvolution input sequences. (B, T, D_hidden)
+ cache: ConformerConvolution input cache. (1, conv_kernel, D_hidden)
+ right_context: Number of frames in right context.
+ Returns:
+ x: ConformerConvolution output sequences. (B, T, D_hidden)
+ cache: ConformerConvolution output cache. (1, conv_kernel, D_hidden)
+ """
+ x = self.pointwise_conv1(x.transpose(1, 2))
+ x = torch.nn.functional.glu(x, dim=1)
+
+ if self.lorder > 0:
+ if cache is None:
+ x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
+ else:
+ x = torch.cat([cache, x], dim=2)
+
+ if right_context > 0:
+ cache = x[:, :, -(self.lorder + right_context) : -right_context]
+ else:
+ cache = x[:, :, -self.lorder :]
+
+ x = self.depthwise_conv(x)
+ x = self.activation(self.norm(x))
+
+ x = self.pointwise_conv2(x).transpose(1, 2)
+
+ return x, cache
+
+class ConformerChunkEncoder(torch.nn.Module):
+ """Encoder module definition.
+ Args:
+ input_size: Input size.
+ body_conf: Encoder body configuration.
+ input_conf: Encoder input configuration.
+ main_conf: Encoder main configuration.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ embed_vgg_like: bool = False,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 3,
+ macaron_style: bool = False,
+ rel_pos_type: str = "legacy",
+ pos_enc_layer_type: str = "rel_pos",
+ selfattention_layer_type: str = "rel_selfattn",
+ activation_type: str = "swish",
+ use_cnn_module: bool = True,
+ zero_triu: bool = False,
+ norm_type: str = "layer_norm",
+ cnn_module_kernel: int = 31,
+ conv_mod_norm_eps: float = 0.00001,
+ conv_mod_norm_momentum: float = 0.1,
+ simplified_att_score: bool = False,
+ dynamic_chunk_training: bool = False,
+ short_chunk_threshold: float = 0.75,
+ short_chunk_size: int = 25,
+ left_chunk_size: int = 0,
+ time_reduction_factor: int = 1,
+ unified_model_training: bool = False,
+ default_chunk_size: int = 16,
+ jitter_range: int = 4,
+ subsampling_factor: int = 1,
+ **activation_parameters,
+ ) -> None:
+ """Construct an Encoder object."""
+ super().__init__()
+
+ assert check_argument_types()
+
+ self.embed = StreamingConvInput(
+ input_size,
+ output_size,
+ subsampling_factor,
+ vgg_like=embed_vgg_like,
+ output_size=output_size,
+ )
+
+ self.pos_enc = StreamingRelPositionalEncoding(
+ output_size,
+ positional_dropout_rate,
+ )
+
+ activation = get_activation(
+ activation_type, **activation_parameters
+ )
+
+ pos_wise_args = (
+ output_size,
+ linear_units,
+ positional_dropout_rate,
+ activation,
+ )
+
+ conv_mod_norm_args = {
+ "eps": conv_mod_norm_eps,
+ "momentum": conv_mod_norm_momentum,
+ }
+
+ conv_mod_args = (
+ output_size,
+ cnn_module_kernel,
+ activation,
+ conv_mod_norm_args,
+ dynamic_chunk_training or unified_model_training,
+ )
+
+ mult_att_args = (
+ attention_heads,
+ output_size,
+ attention_dropout_rate,
+ simplified_att_score,
+ )
+
+ norm_class, norm_args = get_normalization(
+ norm_type,
+ )
+
+ fn_modules = []
+ for _ in range(num_blocks):
+ module = lambda: ChunkEncoderLayer(
+ output_size,
+ RelPositionMultiHeadedAttentionChunk(*mult_att_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ PositionwiseFeedForward(*pos_wise_args),
+ CausalConvolution(*conv_mod_args),
+ norm_class=norm_class,
+ norm_args=norm_args,
+ dropout_rate=dropout_rate,
+ )
+ fn_modules.append(module)
+
+ self.encoders = MultiBlocks(
+ [fn() for fn in fn_modules],
+ output_size,
+ norm_class=norm_class,
+ norm_args=norm_args,
+ )
+
+ self.output_size = output_size
+
+ self.dynamic_chunk_training = dynamic_chunk_training
+ self.short_chunk_threshold = short_chunk_threshold
+ self.short_chunk_size = short_chunk_size
+ self.left_chunk_size = left_chunk_size
+
+ self.unified_model_training = unified_model_training
+ self.default_chunk_size = default_chunk_size
+ self.jitter_range = jitter_range
+
+ self.time_reduction_factor = time_reduction_factor
+
+ def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
+ """Return the corresponding number of sample for a given chunk size, in frames.
+ Where size is the number of features frames after applying subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ hop_length: Frontend's hop length
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size) * hop_length
+
+ def get_encoder_input_size(self, size: int) -> int:
+ """Return the corresponding number of sample for a given chunk size, in frames.
+ Where size is the number of features frames after applying subsampling.
+ Args:
+ size: Number of frames after subsampling.
+ Returns:
+ : Number of raw samples
+ """
+ return self.embed.get_size_before_subsampling(size)
+
+
+ def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
+ """Initialize/Reset encoder streaming cache.
+ Args:
+ left_context: Number of frames in left context.
+ device: Device ID.
+ """
+ return self.encoders.reset_streaming_cache(left_context, device)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Encode input sequences.
+ Args:
+ x: Encoder input features. (B, T_in, F)
+ x_len: Encoder input features lengths. (B,)
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ x_len: Encoder outputs lenghts. (B,)
+ """
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ if self.unified_model_training:
+ chunk_size = self.default_chunk_size + torch.randint(-self.jitter_range, self.jitter_range+1, (1,)).item()
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ x_utt = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=None,
+ )
+ x_chunk = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x_utt = x_utt[:,::self.time_reduction_factor,:]
+ x_chunk = x_chunk[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x_utt, x_chunk, olens
+
+ elif self.dynamic_chunk_training:
+ max_len = x.size(1)
+ chunk_size = torch.randint(1, max_len, (1,)).item()
+
+ if chunk_size > (max_len * self.short_chunk_threshold):
+ chunk_size = max_len
+ else:
+ chunk_size = (chunk_size % self.short_chunk_size) + 1
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+ else:
+ x, mask = self.embed(x, mask, None)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = None
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
+
+ return x, olens
+
+ def simu_chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ short_status, limit_size = check_short_utt(
+ self.embed.subsampling_factor, x.size(1)
+ )
+
+ if short_status:
+ raise TooShortUttError(
+ f"has {x.size(1)} frames and is too short for subsampling "
+ + f"(it needs more than {limit_size} frames), return empty results",
+ x.size(1),
+ limit_size,
+ )
+
+ mask = make_source_mask(x_len)
+
+ x, mask = self.embed(x, mask, chunk_size)
+ pos_enc = self.pos_enc(x)
+ chunk_mask = make_chunk_mask(
+ x.size(1),
+ chunk_size,
+ left_chunk_size=self.left_chunk_size,
+ device=x.device,
+ )
+
+ x = self.encoders(
+ x,
+ pos_enc,
+ mask,
+ chunk_mask=chunk_mask,
+ )
+ olens = mask.eq(0).sum(1)
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+
+ return x
+
+ def chunk_forward(
+ self,
+ x: torch.Tensor,
+ x_len: torch.Tensor,
+ processed_frames: torch.tensor,
+ chunk_size: int = 16,
+ left_context: int = 32,
+ right_context: int = 0,
+ ) -> torch.Tensor:
+ """Encode input sequences as chunks.
+ Args:
+ x: Encoder input features. (1, T_in, F)
+ x_len: Encoder input features lengths. (1,)
+ processed_frames: Number of frames already seen.
+ left_context: Number of frames in left context.
+ right_context: Number of frames in right context.
+ Returns:
+ x: Encoder outputs. (B, T_out, D_enc)
+ """
+ mask = make_source_mask(x_len)
+ x, mask = self.embed(x, mask, None)
+
+ if left_context > 0:
+ processed_mask = (
+ torch.arange(left_context, device=x.device)
+ .view(1, left_context)
+ .flip(1)
+ )
+ processed_mask = processed_mask >= processed_frames
+ mask = torch.cat([processed_mask, mask], dim=1)
+ pos_enc = self.pos_enc(x, left_context=left_context)
+ x = self.encoders.chunk_forward(
+ x,
+ pos_enc,
+ mask,
+ chunk_size=chunk_size,
+ left_context=left_context,
+ right_context=right_context,
+ )
+
+ if right_context > 0:
+ x = x[:, 0:-right_context, :]
+
+ if self.time_reduction_factor > 1:
+ x = x[:,::self.time_reduction_factor,:]
+ return x
--
Gitblit v1.9.1