support transducer model inference
| | |
| | | MultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttention, # noqa: H301 |
| | | LegacyRelPositionMultiHeadedAttention, # noqa: H301 |
| | | RelPositionMultiHeadedAttentionChunk, |
| | | ) |
| | | from funasr.models.transformer.embedding import ( |
| | | PositionalEncoding, # noqa: H301 |
| | |
| | | return (xs_pad, intermediate_outs), olens, None |
| | | return xs_pad, olens, None |
| | | |
| | | |
class CausalConvolution(torch.nn.Module):
    """Conformer convolution module with an optional causal (streaming) mode.

    Pipeline: pointwise conv (channels -> 2 * channels) -> GLU -> depthwise
    conv -> BatchNorm1d -> activation -> pointwise conv (channels -> channels).

    Args:
        channels: The number of channels.
        kernel_size: Size of the depthwise convolving kernel. Must be odd.
        activation: Activation module applied after normalization. When
            omitted, a new ``torch.nn.ReLU()`` is created for this instance.
        norm_args: Keyword arguments forwarded to ``torch.nn.BatchNorm1d``.
        causal: Whether to use causal convolution (set to True if streaming).
            In causal mode the depthwise convolution has no built-in padding;
            instead ``kernel_size - 1`` past frames (zeros or the streaming
            cache) are prepended in ``forward``, so an output frame never
            depends on later input frames.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(
        self,
        channels: int,
        kernel_size: int,
        activation: Optional[torch.nn.Module] = None,
        norm_args: Optional[Dict] = None,
        causal: bool = False,
    ) -> None:
        """Construct a CausalConvolution object."""
        super().__init__()

        # An odd kernel keeps the non-causal "same" padding symmetric.
        if (kernel_size - 1) % 2 != 0:
            raise ValueError(f"kernel_size must be odd, got {kernel_size}.")

        # Build fresh defaults per instance instead of sharing objects that
        # would be created once at function-definition time.
        if activation is None:
            activation = torch.nn.ReLU()
        if norm_args is None:
            norm_args = {}

        self.kernel_size = kernel_size

        self.pointwise_conv1 = torch.nn.Conv1d(
            channels,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        if causal:
            # Left context is supplied explicitly in forward (zero pad or cache).
            self.lorder = kernel_size - 1
            padding = 0
        else:
            self.lorder = 0
            padding = (kernel_size - 1) // 2

        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size,
            stride=1,
            padding=padding,
            groups=channels,
        )
        self.norm = torch.nn.BatchNorm1d(channels, **norm_args)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.activation = activation

    def forward(
        self,
        x: torch.Tensor,
        cache: Optional[torch.Tensor] = None,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Compute convolution module.

        Args:
            x: Input sequences. (B, T, D_hidden)
            cache: Streaming cache of past GLU outputs,
                (B, D_hidden, kernel_size - 1). Only used in causal mode;
                ``None`` means the sequence is left-padded with zeros.
            right_context: Number of trailing look-ahead frames in ``x`` that
                must not be stored in the returned cache.

        Returns:
            x: Output sequences. (B, T, D_hidden)
            cache: In causal mode, the updated cache
                (B, D_hidden, kernel_size - 1); otherwise the ``cache``
                argument is returned unchanged.
        """
        x = self.pointwise_conv1(x.transpose(1, 2))
        x = torch.nn.functional.glu(x, dim=1)

        if self.lorder > 0:
            if cache is None:
                # No history yet: zero left padding keeps the first outputs causal.
                x = torch.nn.functional.pad(x, (self.lorder, 0), "constant", 0.0)
            else:
                x = torch.cat([cache, x], dim=2)

            # Keep the last `lorder` frames before the look-ahead as history
            # for the next chunk.
            if right_context > 0:
                cache = x[:, :, -(self.lorder + right_context) : -right_context]
            else:
                cache = x[:, :, -self.lorder :]

        x = self.depthwise_conv(x)
        x = self.activation(self.norm(x))

        x = self.pointwise_conv2(x).transpose(1, 2)

        return x, cache
| | | |
class ChunkEncoderLayer(torch.nn.Module):
    """Chunk Conformer module definition.

    Macaron-style block: half-step feed-forward, self-attention, convolution,
    half-step feed-forward, then a final normalization. Every sub-module is
    applied to a pre-normalized input and added back through a residual.

    Args:
        block_size: Input/output size.
        self_att: Self-attention module instance.
        feed_forward: Feed-forward module instance.
        feed_forward_macaron: Feed-forward module instance for macaron network.
        conv_mod: Convolution module instance. Must expose an integer
            ``kernel_size`` (used to size the streaming cache).
        norm_class: Normalization module class.
        norm_args: Normalization module arguments (none when omitted).
        dropout_rate: Dropout rate.
    """

    def __init__(
        self,
        block_size: int,
        self_att: torch.nn.Module,
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Optional[Dict] = None,
        dropout_rate: float = 0.0,
    ) -> None:
        """Construct a ChunkEncoderLayer object."""
        super().__init__()

        # Avoid a shared mutable default dict.
        if norm_args is None:
            norm_args = {}

        self.self_att = self_att

        self.feed_forward = feed_forward
        self.feed_forward_macaron = feed_forward_macaron
        # Both feed-forward modules contribute half a residual step (macaron).
        self.feed_forward_scale = 0.5

        self.conv_mod = conv_mod

        self.norm_feed_forward = norm_class(block_size, **norm_args)
        self.norm_self_att = norm_class(block_size, **norm_args)

        self.norm_macaron = norm_class(block_size, **norm_args)
        self.norm_conv = norm_class(block_size, **norm_args)
        self.norm_final = norm_class(block_size, **norm_args)

        self.dropout = torch.nn.Dropout(dropout_rate)

        self.block_size = block_size
        # [attention key/value history, convolution history]; set by
        # reset_streaming_cache() and updated by chunk_forward().
        self.cache = None

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset self-attention and convolution modules cache for streaming.

        Args:
            left_context: Number of left frames during chunk-by-chunk inference.
            device: Device to use for cache tensor.
        """
        self.cache = [
            # Self-attention history: (1, left_context, block_size).
            torch.zeros(
                (1, left_context, self.block_size),
                device=device,
            ),
            # Convolution history: (1, block_size, kernel_size - 1).
            torch.zeros(
                (
                    1,
                    self.block_size,
                    self.conv_mod.kernel_size - 1,
                ),
                device=device,
            ),
        ]

    def forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode input sequences.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T)
            chunk_mask: Chunk mask. (T_2, T_2)

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            mask: Source mask. (B, T)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
        """
        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.dropout(
            self.feed_forward_macaron(x)
        )

        residual = x
        x = self.norm_self_att(x)
        x_q = x
        x = residual + self.dropout(
            self.self_att(
                x_q,
                x,
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )
        )

        residual = x

        x = self.norm_conv(x)
        # The convolution cache is only needed in chunk_forward.
        x, _ = self.conv_mod(x)
        x = residual + self.dropout(x)
        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.dropout(self.feed_forward(x))

        x = self.norm_final(x)
        return x, mask, pos_enc

    def chunk_forward(
        self,
        x: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 0,
        right_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode chunk of input sequence.

        Requires reset_streaming_cache() to have been called; the stored
        attention and convolution caches are consumed and then replaced.

        Args:
            x: Conformer input sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)
            mask: Source mask. (B, T_2)
            chunk_size: Accepted for interface compatibility; unused here.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Conformer output sequences. (B, T, D_block)
            pos_enc: Positional embedding sequences. (B, 2 * (T - 1), D_block)

        Raises:
            RuntimeError: If the streaming cache has not been initialized.
        """
        if self.cache is None:
            raise RuntimeError(
                "ChunkEncoderLayer.chunk_forward() called before "
                "reset_streaming_cache(); initialize the streaming cache first."
            )

        residual = x

        x = self.norm_macaron(x)
        x = residual + self.feed_forward_scale * self.feed_forward_macaron(x)

        residual = x
        x = self.norm_self_att(x)
        # Prepend the cached left-context frames to form keys/values.
        if left_context > 0:
            key = torch.cat([self.cache[0], x], dim=1)
        else:
            key = x
        val = key

        # History kept for the next chunk: the `left_context` frames that
        # precede any right-context (look-ahead) frames.
        if left_context <= 0:
            # `key[:, -0:, :]` would select the whole sequence; keep none.
            att_cache = key[:, :0, :]
        elif right_context > 0:
            att_cache = key[:, -(left_context + right_context) : -right_context, :]
        else:
            att_cache = key[:, -left_context:, :]
        x = residual + self.self_att(
            x,
            key,
            val,
            pos_enc,
            mask,
            left_context=left_context,
        )

        residual = x
        x = self.norm_conv(x)
        x, conv_cache = self.conv_mod(
            x, cache=self.cache[1], right_context=right_context
        )
        x = residual + x
        residual = x

        x = self.norm_feed_forward(x)
        x = residual + self.feed_forward_scale * self.feed_forward(x)

        x = self.norm_final(x)
        self.cache = [att_cache, conv_cache]

        return x, pos_enc
| | | |
@tables.register("encoder_classes", "ChunkConformerEncoder")
class ConformerChunkEncoder(torch.nn.Module):
    """Chunk-based (streaming-capable) Conformer encoder.

    Only part of the constructor arguments are used by this implementation;
    the remaining ones are accepted so that configurations shared with other
    Conformer encoders can be passed unchanged.

    Args:
        input_size: Input feature dimension.
        output_size: Encoder hidden/output dimension.
        attention_heads: Number of self-attention heads.
        linear_units: Hidden size of the position-wise feed-forward modules.
        num_blocks: Number of ChunkEncoderLayer blocks.
        dropout_rate: Dropout rate inside each encoder block.
        positional_dropout_rate: Dropout rate of the positional encoding; also
            passed as the dropout rate of the position-wise feed-forward modules.
        attention_dropout_rate: Dropout rate passed to the attention modules.
        embed_vgg_like: Passed as ``vgg_like`` to ``StreamingConvInput``.
        activation_type: Activation name passed to ``get_activation``.
        cnn_module_kernel: Kernel size of the convolution modules.
        conv_mod_norm_eps: BatchNorm ``eps`` of the convolution modules.
        conv_mod_norm_momentum: BatchNorm ``momentum`` of the convolution modules.
        simplified_att_score: Passed to ``RelPositionMultiHeadedAttentionChunk``.
        dynamic_chunk_training: Train with a randomly sampled chunk size.
            Also makes the convolution modules causal.
        short_chunk_threshold: In dynamic chunk training, sampled sizes above
            ``max_len * short_chunk_threshold`` use the full length.
        short_chunk_size: In dynamic chunk training, smaller sampled sizes are
            mapped to ``(size % short_chunk_size) + 1``.
        left_chunk_size: ``left_chunk_size`` passed to ``make_chunk_mask``.
        time_reduction_factor: Stride applied to the encoder output frames.
        unified_model_training: Compute both full-context and chunk-masked
            outputs in ``forward``. Also makes the convolution modules causal.
        default_chunk_size: Chunk size used outside training (and the center
            of the jittered chunk size in unified training).
        jitter_range: Maximum +/- jitter of the chunk size in unified training.
        subsampling_factor: Subsampling factor of the input embedding.
        normalize_before, concat_after, positionwise_layer_type,
        positionwise_conv_kernel_size, macaron_style, rel_pos_type,
        pos_enc_layer_type, selfattention_layer_type, use_cnn_module,
        zero_triu, norm_type: Unused by this encoder.
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        embed_vgg_like: bool = False,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 3,
        macaron_style: bool = False,
        rel_pos_type: str = "legacy",
        pos_enc_layer_type: str = "rel_pos",
        selfattention_layer_type: str = "rel_selfattn",
        activation_type: str = "swish",
        use_cnn_module: bool = True,
        zero_triu: bool = False,
        norm_type: str = "layer_norm",
        cnn_module_kernel: int = 31,
        conv_mod_norm_eps: float = 0.00001,
        conv_mod_norm_momentum: float = 0.1,
        simplified_att_score: bool = False,
        dynamic_chunk_training: bool = False,
        short_chunk_threshold: float = 0.75,
        short_chunk_size: int = 25,
        left_chunk_size: int = 0,
        time_reduction_factor: int = 1,
        unified_model_training: bool = False,
        default_chunk_size: int = 16,
        jitter_range: int = 4,
        subsampling_factor: int = 1,
    ) -> None:
        """Construct an Encoder object."""
        super().__init__()

        self.embed = StreamingConvInput(
            input_size=input_size,
            conv_size=output_size,
            subsampling_factor=subsampling_factor,
            vgg_like=embed_vgg_like,
            output_size=output_size,
        )

        self.pos_enc = StreamingRelPositionalEncoding(
            output_size,
            positional_dropout_rate,
        )

        activation = get_activation(activation_type)

        # NOTE(review): the feed-forward dropout reuses positional_dropout_rate;
        # confirm this is intended.
        pos_wise_args = (
            output_size,
            linear_units,
            positional_dropout_rate,
            activation,
        )

        conv_mod_norm_args = {
            "eps": conv_mod_norm_eps,
            "momentum": conv_mod_norm_momentum,
        }

        # Causal convolution is used whenever the encoder is trained chunk-wise.
        conv_mod_args = (
            output_size,
            cnn_module_kernel,
            activation,
            conv_mod_norm_args,
            dynamic_chunk_training or unified_model_training,
        )

        mult_att_args = (
            attention_heads,
            output_size,
            attention_dropout_rate,
            simplified_att_score,
        )

        # Every block gets freshly constructed sub-modules.
        self.encoders = MultiBlocks(
            [
                ChunkEncoderLayer(
                    output_size,
                    RelPositionMultiHeadedAttentionChunk(*mult_att_args),
                    PositionwiseFeedForward(*pos_wise_args),
                    PositionwiseFeedForward(*pos_wise_args),
                    CausalConvolution(*conv_mod_args),
                    dropout_rate=dropout_rate,
                )
                for _ in range(num_blocks)
            ],
            output_size,
        )

        self._output_size = output_size

        self.dynamic_chunk_training = dynamic_chunk_training
        self.short_chunk_threshold = short_chunk_threshold
        self.short_chunk_size = short_chunk_size
        self.left_chunk_size = left_chunk_size

        self.unified_model_training = unified_model_training
        self.default_chunk_size = default_chunk_size
        self.jitter_range = jitter_range

        self.time_reduction_factor = time_reduction_factor

    def output_size(self) -> int:
        """Return the encoder output dimension."""
        return self._output_size

    def get_encoder_input_raw_size(self, size: int, hop_length: int) -> int:
        """Return the number of raw samples for a given chunk size, in frames.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.
            hop_length: Frontend's hop length

        Returns:
            : Number of raw samples
        """
        return self.embed.get_size_before_subsampling(size) * hop_length

    def get_encoder_input_size(self, size: int) -> int:
        """Return the number of input feature frames for a given chunk size.

        Where size is the number of features frames after applying subsampling.

        Args:
            size: Number of frames after subsampling.

        Returns:
            : Number of feature frames before subsampling.
        """
        return self.embed.get_size_before_subsampling(size)

    def reset_streaming_cache(self, left_context: int, device: torch.device) -> None:
        """Initialize/Reset encoder streaming cache.

        Args:
            left_context: Number of frames in left context.
            device: Device ID.
        """
        return self.encoders.reset_streaming_cache(left_context, device)

    def _check_input_length(self, x: torch.Tensor) -> None:
        """Raise TooShortUttError if ``x`` has too few frames for subsampling."""
        short_status, limit_size = check_short_utt(
            self.embed.subsampling_factor, x.size(1)
        )

        if short_status:
            raise TooShortUttError(
                f"has {x.size(1)} frames and is too short for subsampling "
                + f"(it needs more than {limit_size} frames), return empty results",
                x.size(1),
                limit_size,
            )

    def forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Encode input sequences.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            With ``unified_model_training``:
                x_utt: Full-context encoder outputs. (B, T_out, D_enc)
                x_chunk: Chunk-masked encoder outputs. (B, T_out, D_enc)
                olens: Encoder output lengths. (B,)
            Otherwise:
                x: Encoder outputs. (B, T_out, D_enc)
                olens: Encoder output lengths. (B,)
                None

        Raises:
            TooShortUttError: If the input is too short for subsampling.
        """
        self._check_input_length(x)

        mask = make_source_mask(x_len).to(x.device)

        if self.unified_model_training:
            if self.training:
                jitter = torch.randint(
                    -self.jitter_range, self.jitter_range + 1, (1,)
                ).item()
                # A large jitter_range must never produce a non-positive chunk size.
                chunk_size = max(1, self.default_chunk_size + jitter)
            else:
                chunk_size = self.default_chunk_size
            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)
            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
            x_utt = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=None,
            )
            x_chunk = self.encoders(
                x,
                pos_enc,
                mask,
                chunk_mask=chunk_mask,
            )

            # Frames where the mask is 0/False are counted as valid.
            olens = mask.eq(0).sum(1)
            if self.time_reduction_factor > 1:
                x_utt = x_utt[:, :: self.time_reduction_factor, :]
                x_chunk = x_chunk[:, :: self.time_reduction_factor, :]
                # Number of frames kept by the strided slice above.
                olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

            return x_utt, x_chunk, olens

        if self.dynamic_chunk_training:
            max_len = x.size(1)
            if self.training:
                chunk_size = torch.randint(1, max_len, (1,)).item()

                if chunk_size > (max_len * self.short_chunk_threshold):
                    chunk_size = max_len
                else:
                    chunk_size = (chunk_size % self.short_chunk_size) + 1
            else:
                chunk_size = self.default_chunk_size

            x, mask = self.embed(x, mask, chunk_size)
            pos_enc = self.pos_enc(x)

            chunk_mask = make_chunk_mask(
                x.size(1),
                chunk_size,
                left_chunk_size=self.left_chunk_size,
                device=x.device,
            )
        else:
            x, mask = self.embed(x, mask, None)
            pos_enc = self.pos_enc(x)
            chunk_mask = None
        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )

        olens = mask.eq(0).sum(1)
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
            olens = torch.floor_divide(olens - 1, self.time_reduction_factor) + 1

        return x, olens, None

    def full_utt_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
    ) -> torch.Tensor:
        """Encode input sequences with full (unmasked) context.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)

        Returns:
            x_utt: Encoder outputs. (B, T_out, D_enc)

        Raises:
            TooShortUttError: If the input is too short for subsampling.
        """
        self._check_input_length(x)

        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)
        pos_enc = self.pos_enc(x)
        x_utt = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=None,
        )

        if self.time_reduction_factor > 1:
            x_utt = x_utt[:, :: self.time_reduction_factor, :]
        return x_utt

    def simu_chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Simulate chunk-wise encoding on a full sequence with a chunk mask.

        Args:
            x: Encoder input features. (B, T_in, F)
            x_len: Encoder input features lengths. (B,)
            chunk_size: Chunk size used for the embedding and the chunk mask.
            left_context: Unused; ``self.left_chunk_size`` is used instead.
            right_context: Unused.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)

        Raises:
            TooShortUttError: If the input is too short for subsampling.
        """
        self._check_input_length(x)

        # Keep the mask on the same device as the features (as in forward).
        mask = make_source_mask(x_len).to(x.device)

        x, mask = self.embed(x, mask, chunk_size)
        pos_enc = self.pos_enc(x)
        chunk_mask = make_chunk_mask(
            x.size(1),
            chunk_size,
            left_chunk_size=self.left_chunk_size,
            device=x.device,
        )

        x = self.encoders(
            x,
            pos_enc,
            mask,
            chunk_mask=chunk_mask,
        )
        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]

        return x

    def chunk_forward(
        self,
        x: torch.Tensor,
        x_len: torch.Tensor,
        processed_frames: torch.Tensor,
        chunk_size: int = 16,
        left_context: int = 32,
        right_context: int = 0,
    ) -> torch.Tensor:
        """Encode input sequences as chunks.

        Args:
            x: Encoder input features. (1, T_in, F)
            x_len: Encoder input features lengths. (1,)
            processed_frames: Number of frames already seen.
            chunk_size: Chunk size passed to the encoder blocks.
            left_context: Number of frames in left context.
            right_context: Number of frames in right context.

        Returns:
            x: Encoder outputs. (B, T_out, D_enc)
        """
        # Keep the mask on the same device as `processed_mask` below.
        mask = make_source_mask(x_len).to(x.device)
        x, mask = self.embed(x, mask, None)

        if left_context > 0:
            # Left-context slots ordered oldest first hold values
            # [left_context - 1, ..., 0]; slots not yet filled by
            # `processed_frames` are marked True (presumably "ignore",
            # matching the mask.eq(0) = valid convention used in forward).
            processed_mask = (
                torch.arange(left_context, device=x.device)
                .view(1, left_context)
                .flip(1)
            )
            processed_mask = processed_mask >= processed_frames
            mask = torch.cat([processed_mask, mask], dim=1)
        pos_enc = self.pos_enc(x, left_context=left_context)
        x = self.encoders.chunk_forward(
            x,
            pos_enc,
            mask,
            chunk_size=chunk_size,
            left_context=left_context,
            right_context=right_context,
        )

        # Drop the look-ahead frames; they are re-encoded with the next chunk.
        if right_context > 0:
            x = x[:, 0:-right_context, :]

        if self.time_reduction_factor > 1:
            x = x[:, :: self.time_reduction_factor, :]
        return x
| | |
| | | """Search algorithms for Transducer models.""" |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | import numpy as np |
| | | from dataclasses import dataclass |
| | | from typing import Any, Dict, List, Optional, Tuple, Union |
| | | |
| | | import numpy as np |
| | | import torch |
| | | |
| | | from funasr.models.transducer.joint_network import JointNetwork |
| | | |
| | |
| | | """Transducer joint network implementation.""" |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.transformer.utils.nets_utils import get_activation |
| | | |
| | | |
| | | @tables.register("joint_network_classes", "joint_network") |
| | | class JointNetwork(torch.nn.Module): |
| | | """Transducer joint network module. |
| | | |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import time |
| | | import torch |
| | | import logging |
| | | from contextlib import contextmanager |
| | | from typing import Dict, Optional, Tuple |
| | | from distutils.version import LooseVersion |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | import tempfile |
| | | import codecs |
| | | import requests |
| | | import re |
| | | import copy |
| | | import torch |
| | | import torch.nn as nn |
| | | import random |
| | | import numpy as np |
| | | import time |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | # from funasr.models.ctc import CTC |
| | | # from funasr.models.decoder.abs_decoder import AbsDecoder |
| | | # from funasr.models.e2e_asr_common import ErrorCalculator |
| | | # from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | # from funasr.frontends.abs_frontend import AbsFrontend |
| | | # from funasr.models.postencoder.abs_postencoder import AbsPostEncoder |
| | | from funasr.models.paraformer.cif_predictor import mae_loss |
| | | # from funasr.models.preencoder.abs_preencoder import AbsPreEncoder |
| | | # from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | # from funasr.models.base_model import FunASRModel |
| | | # from funasr.models.paraformer.cif_predictor import CifPredictorV3 |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | |
| | | from funasr.models.model_class_factory import * |
| | | from funasr.register import tables |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.transformer.scorers.ctc import CTCPrefixScorer |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | from funasr.models.transformer.utils.nets_utils import get_transducer_task_io |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.models.transducer.beam_search_transducer import BeamSearchTransducer |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.transformer.utils.nets_utils import get_transducer_task_io |
| | | |
| | | |
| | | class Transducer(nn.Module): |
| | | """ESPnet2ASRTransducerModel module definition.""" |
| | | |
| | | |
| | | @tables.register("model_classes", "Transducer") |
| | | class Transducer(torch.nn.Module): |
| | | def __init__( |
| | | self, |
| | | frontend: Optional[str] = None, |
| | |
| | | |
| | | super().__init__() |
| | | |
| | | if frontend is not None: |
| | | frontend_class = frontend_classes.get_class(frontend) |
| | | frontend = frontend_class(**frontend_conf) |
| | | if specaug is not None: |
| | | specaug_class = specaug_classes.get_class(specaug) |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = normalize_classes.get_class(normalize) |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = encoder_classes.get_class(encoder) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | | decoder_class = decoder_classes.get_class(decoder) |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | | **decoder_conf, |
| | | ) |
| | | decoder_output_size = decoder.output_size |
| | | |
| | | joint_network_class = joint_network_classes.get_class(decoder) |
| | | joint_network_class = tables.joint_network_classes.get(joint_network) |
| | | joint_network = joint_network_class( |
| | | vocab_size, |
| | | encoder_output_size, |
| | | decoder_output_size, |
| | | **joint_network_conf, |
| | | ) |
| | | |
| | | |
| | | self.criterion_transducer = None |
| | | self.error_calculator = None |
| | |
| | | self.decoder = decoder |
| | | self.joint_network = joint_network |
| | | |
| | | |
| | | |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | | smoothing=lsm_weight, |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | # |
| | | # if report_cer or report_wer: |
| | | # self.error_calculator = ErrorCalculator( |
| | | # token_list, sym_space, sym_blank, report_cer, report_wer |
| | | # ) |
| | | # |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | self.ctc = None |
| | | self.ctc_weight = 0.0 |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | # import pdb; |
| | | # pdb.set_trace() |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | |
| | | # Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | speech, speech_lengths, ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | |
| | | def init_beam_search(self, |
| | | **kwargs, |
| | | ): |
| | | from funasr.models.transformer.search import BeamSearch |
| | | from funasr.models.transformer.scorers.ctc import CTCPrefixScorer |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | |
| | | weights = dict( |
| | | decoder=1.0 - kwargs.get("decoding_ctc_weight"), |
| | | ctc=kwargs.get("decoding_ctc_weight", 0.0), |
| | | lm=kwargs.get("lm_weight", 0.0), |
| | | ngram=kwargs.get("ngram_weight", 0.0), |
| | | length_bonus=kwargs.get("penalty", 0.0), |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=kwargs.get("beam_size", 2), |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=self.sos, |
| | | eos=self.eos, |
| | | vocab_size=len(token_list), |
| | | token_list=token_list, |
| | | pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", |
| | | beam_search = BeamSearchTransducer( |
| | | self.decoder, |
| | | self.joint_network, |
| | | kwargs.get("beam_size", 2), |
| | | nbest=1, |
| | | ) |
| | | # beam_search.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() |
| | | # for scorer in scorers.values(): |
| | |
| | | # scorer.to(device=kwargs.get("device", "cpu"), dtype=getattr(torch, kwargs.get("dtype", "float32"))).eval() |
| | | self.beam_search = beam_search |
| | | |
| | | def generate(self, |
| | | def inference(self, |
| | | data_in: list, |
| | | data_lengths: list=None, |
| | | key: list=None, |
| | |
| | | # init beamsearch |
| | | is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None |
| | | is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None |
| | | if self.beam_search is None and (is_use_lm or is_use_ctc): |
| | | # if self.beam_search is None and (is_use_lm or is_use_ctc): |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | |
| | | encoder_out = encoder_out[0] |
| | | |
| | | # c. Passed the encoder result and the beam search |
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) |
| | | ) |
| | | |
| | | nbest_hyps = self.beam_search(encoder_out[0], is_final=True) |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | |
| | | results = [] |
| | | b, n, d = encoder_out.size() |
| | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | token_int = hyp.yseq#[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | token_int = hyp.yseq#[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) |
| | |
| | | import random |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import numpy as np |
| | | import torch |
| | | import random |
| | | import numpy as np |
| | | import torch.nn as nn |
| | | import torch.nn.functional as F |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.utils.nets_utils import to_device |
| | | from funasr.models.language_model.rnn.attentions import initial_att |
| | |
| | | ) |
| | | return att_list |
| | | |
| | | |
| | | @tables.register("decoder_classes", "rnn_decoder") |
| | | class RNNDecoder(nn.Module): |
| | | def __init__( |
| | | self, |
| | |
| | | """RNN decoder definition for Transducer models.""" |
| | | |
| | | from typing import List, Optional, Tuple |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import torch |
| | | from typing import List, Optional, Tuple |
| | | |
| | | from funasr.models.transducer.beam_search_transducer import Hypothesis |
| | | from funasr.register import tables |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | from funasr.models.transducer.beam_search_transducer import Hypothesis |
| | | |
| | | |
| | | @tables.register("decoder_classes", "rnnt_decoder") |
| | | class RNNTDecoder(torch.nn.Module): |
| | | """RNN decoder module. |
| | | |
| | |
| | | return self.forward_attention(v, scores, mask) |
| | | |
| | | |
class RelPositionMultiHeadedAttentionChunk(torch.nn.Module):
    """Multi-headed attention with relative positional encoding for chunked input.

    The attention score is the sum of a content term (query x key) and a
    relative-position term computed from ``pos_enc`` and realigned by
    :meth:`rel_shift`. The key/value sequence may be longer than the query
    sequence by ``left_context`` frames (T_2 = T_1 + left_context).

    Args:
        num_heads: Number of attention heads.
        embed_size: Embedding size. Must be divisible by ``num_heads``.
        dropout_rate: Dropout rate applied to the attention weights.
        simplified_attention_score: If True, project ``pos_enc`` to one scalar
            per head and skip the ``pos_bias_u`` / ``pos_bias_v`` terms
            (see https://github.com/k2-fsa/icefall/pull/458).
    """

    def __init__(
        self,
        num_heads: int,
        embed_size: int,
        dropout_rate: float = 0.0,
        simplified_attention_score: bool = False,
    ) -> None:
        """Construct a RelPositionMultiHeadedAttentionChunk object."""
        super().__init__()

        self.d_k = embed_size // num_heads
        self.num_heads = num_heads

        # Format the message eagerly; a (format, args) tuple would be shown raw.
        assert self.d_k * num_heads == embed_size, (
            f"embed_size ({embed_size}) must be divisible by num_heads ({num_heads})"
        )

        self.linear_q = torch.nn.Linear(embed_size, embed_size)
        self.linear_k = torch.nn.Linear(embed_size, embed_size)
        self.linear_v = torch.nn.Linear(embed_size, embed_size)

        self.linear_out = torch.nn.Linear(embed_size, embed_size)

        if simplified_attention_score:
            # One positional scalar per head; no learned positional biases.
            self.linear_pos = torch.nn.Linear(embed_size, num_heads)

            self.compute_att_score = self.compute_simplified_attention_score
        else:
            self.linear_pos = torch.nn.Linear(embed_size, embed_size, bias=False)

            # Uninitialized storage, immediately filled by xavier init below.
            self.pos_bias_u = torch.nn.Parameter(torch.empty(num_heads, self.d_k))
            self.pos_bias_v = torch.nn.Parameter(torch.empty(num_heads, self.d_k))
            torch.nn.init.xavier_uniform_(self.pos_bias_u)
            torch.nn.init.xavier_uniform_(self.pos_bias_v)

            self.compute_att_score = self.compute_attention_score

        self.dropout = torch.nn.Dropout(p=dropout_rate)
        # Last computed attention weights (B, H, T_1, T_2); set in forward_attention.
        self.attn = None

    def rel_shift(self, x: torch.Tensor, left_context: int = 0) -> torch.Tensor:
        """Shift relative-position scores row by row.

        Output element ``[b, h, i, j]`` is ``x[b, h, i, j + T_1 - 1 - i]``.
        Query row ``i`` is therefore shifted left by ``T_1 - 1 - i`` columns
        and truncated to ``T_2`` columns. No data is copied; the result is a
        strided view of ``x``.

        Args:
            x: Input scores. (B, H, T_1, T_1 + T_2 - 1), where
                T_2 = T_1 + left_context (2 * T_1 - 1 when left_context == 0).
            left_context: Number of frames in left context.
        Returns:
            x: Output sequence. (B, H, T_1, T_2)
        """
        batch_size, n_heads, time1, _ = x.shape
        time2 = time1 + left_context

        batch_stride, n_heads_stride, time1_stride, n_stride = x.stride()

        # as_strided's storage_offset is absolute in the underlying storage, so
        # start from x's own offset; otherwise a sliced/view input would be read
        # from the wrong memory.
        return x.as_strided(
            (batch_size, n_heads, time1, time2),
            (batch_stride, n_heads_stride, time1_stride - n_stride, n_stride),
            storage_offset=x.storage_offset() + n_stride * (time1 - 1),
        )

    def compute_simplified_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Simplified attention score computation.

        Reference: https://github.com/k2-fsa/icefall/pull/458
        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, T_1 + T_2 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        # (B, P, size) -> (B, P, H)
        pos_enc = self.linear_pos(pos_enc)

        matrix_ac = torch.matmul(query, key.transpose(2, 3))

        # (B, P, H) -> (B, H, T_1, P). repeat (not expand) so rel_shift gets a
        # real, non-zero-stride buffer to stride over.
        matrix_bd = self.rel_shift(
            pos_enc.transpose(1, 2).unsqueeze(2).repeat(1, 1, query.size(2), 1),
            left_context=left_context,
        )

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def compute_attention_score(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        pos_enc: torch.Tensor,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Attention score computation.

        Args:
            query: Transformed query tensor. (B, H, T_1, d_k)
            key: Transformed key tensor. (B, H, T_2, d_k)
            pos_enc: Positional embedding tensor. (B, T_1 + T_2 - 1, size)
            left_context: Number of frames in left context.
        Returns:
            : Attention score. (B, H, T_1, T_2)
        """
        # (B, P, size) -> (B, P, H, d_k)
        p = self.linear_pos(pos_enc).view(pos_enc.size(0), -1, self.num_heads, self.d_k)

        # Move heads next to d_k so the (H, d_k) biases broadcast, then move back.
        query = query.transpose(1, 2)
        q_with_bias_u = (query + self.pos_bias_u).transpose(1, 2)
        q_with_bias_v = (query + self.pos_bias_v).transpose(1, 2)

        matrix_ac = torch.matmul(q_with_bias_u, key.transpose(-2, -1))

        # (B, H, T_1, d_k) @ (B, H, d_k, P) -> (B, H, T_1, P)
        matrix_bd = torch.matmul(q_with_bias_v, p.permute(0, 2, 3, 1))
        matrix_bd = self.rel_shift(matrix_bd, left_context=left_context)

        return (matrix_ac + matrix_bd) / math.sqrt(self.d_k)

    def forward_qkv(
        self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Transform query, key and value.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
        Returns:
            q: Transformed query tensor. (B, H, T_1, d_k)
            k: Transformed key tensor. (B, H, T_2, d_k)
            v: Transformed value tensor. (B, H, T_2, d_k)
        """
        n_batch = query.size(0)

        def split_heads(linear: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
            # (B, T, size) -> (B, H, T, d_k)
            return linear(x).view(n_batch, -1, self.num_heads, self.d_k).transpose(1, 2)

        q = split_heads(self.linear_q, query)
        k = split_heads(self.linear_k, key)
        v = split_heads(self.linear_v, value)

        return q, k, v

    def forward_attention(
        self,
        value: torch.Tensor,
        scores: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute attention context vector.

        Args:
            value: Transformed value. (B, H, T_2, d_k)
            scores: Attention score. (B, H, T_1, T_2)
            mask: Source mask, True at masked positions. (B, T_2)
            chunk_mask: Chunk mask, True at masked positions. (T_1, T_2);
                (T_1, T_1) when there is no left context.
        Returns:
            attn_output: Transformed value weighted by attention score. (B, T_1, H * d_k)
        """
        batch_size = scores.size(0)
        # (B, T_2) -> (B, 1, 1, T_2) to broadcast over heads and query frames.
        mask = mask.unsqueeze(1).unsqueeze(2)
        if chunk_mask is not None:
            mask = chunk_mask.unsqueeze(0).unsqueeze(1) | mask
        scores = scores.masked_fill(mask, float("-inf"))
        # Zero masked weights after softmax; this also clears NaNs produced by
        # rows in which every key is masked.
        self.attn = torch.softmax(scores, dim=-1).masked_fill(mask, 0.0)

        attn_output = self.dropout(self.attn)
        attn_output = torch.matmul(attn_output, value)

        attn_output = self.linear_out(
            attn_output.transpose(1, 2)
            .contiguous()
            .view(batch_size, -1, self.num_heads * self.d_k)
        )

        return attn_output

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        pos_enc: torch.Tensor,
        mask: torch.Tensor,
        chunk_mask: Optional[torch.Tensor] = None,
        left_context: int = 0,
    ) -> torch.Tensor:
        """Compute scaled dot product attention with rel. positional encoding.

        Args:
            query: Query tensor. (B, T_1, size)
            key: Key tensor. (B, T_2, size)
            value: Value tensor. (B, T_2, size)
            pos_enc: Positional embedding tensor. (B, T_1 + T_2 - 1, size)
            mask: Source mask. (B, T_2)
            chunk_mask: Chunk mask. (T_1, T_2)
            left_context: Number of frames in left context (T_2 - T_1).
        Returns:
            : Output tensor. (B, T_1, H * d_k)
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = self.compute_att_score(q, k, pos_enc, left_context=left_context)
        return self.forward_attention(v, scores, mask, chunk_mask=chunk_mask)
| | | |