Merge pull request #479 from alibaba-damo-academy/dev_aky
rnnt bug fix
| | |
| | | self.frontend = frontend |
| | | self.window_size = self.chunk_size + self.right_context |
| | | |
| | | self._ctx = self.asr_model.encoder.get_encoder_input_size( |
| | | self.window_size |
| | | ) |
| | | if self.streaming: |
| | | self._ctx = self.asr_model.encoder.get_encoder_input_size( |
| | | self.window_size |
| | | ) |
| | | |
| | | #self.last_chunk_length = ( |
| | | # self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | #) * self.hop_length |
| | | |
| | | self.last_chunk_length = ( |
| | | self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | ) |
| | | self.reset_inference_cache() |
| | | self.last_chunk_length = ( |
| | | self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | | ) |
| | | self.reset_inference_cache() |
| | | |
| | | def reset_inference_cache(self) -> None: |
| | | """Reset Speech2Text parameters.""" |
| | |
| | | dropout_rate: float = 0.0, |
| | | embed_dropout_rate: float = 0.0, |
| | | embed_pad: int = 0, |
| | | use_embed_mask: bool = False, |
| | | ) -> None: |
| | | """Construct a RNNDecoder object.""" |
| | | super().__init__() |
| | |
| | | |
| | | self.device = next(self.parameters()).device |
| | | self.score_cache = {} |
| | | |
| | | self.use_embed_mask = use_embed_mask |
| | | if self.use_embed_mask: |
| | | self._embed_mask = SpecAug( |
| | | time_mask_width_range=3, |
| | | num_time_mask=4, |
| | | apply_freq_mask=False, |
| | | apply_time_warp=False |
| | | ) |
| | | |
| | | def forward( |
| | | self, |
| | |
| | | states = self.init_state(labels.size(0)) |
| | | |
| | | dec_embed = self.dropout_embed(self.embed(labels)) |
| | | if self.use_embed_mask and self.training: |
| | | dec_embed = self._embed_mask(dec_embed, label_lens)[0] |
| | | dec_out, states = self.rnn_forward(dec_embed, states) |
| | | return dec_out |
| | | |
| | |
| | | from funasr.models.specaug.abs_specaug import AbsSpecAug |
| | | from funasr.models.decoder.rnnt_decoder import RNNTDecoder |
| | | from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder |
| | | from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.joint_net.joint_network import JointNetwork |
| | | from funasr.modules.nets_utils import get_transducer_task_io |
| | | from funasr.layers.abs_normalize import AbsNormalize |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | encoder: Encoder, |
| | | encoder: AbsEncoder, |
| | | decoder: RNNTDecoder, |
| | | joint_network: JointNetwork, |
| | | att_decoder: Optional[AbsAttDecoder] = None, |
| | |
| | | feats, feats_lengths = self.normalize(feats, feats_lengths) |
| | | |
| | | # 4. Forward encoder |
| | | encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths) |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | |
| | | assert encoder_out.size(0) == speech.size(0), ( |
| | | encoder_out.size(), |
| | |
| | | frontend: Optional[AbsFrontend], |
| | | specaug: Optional[AbsSpecAug], |
| | | normalize: Optional[AbsNormalize], |
| | | encoder: Encoder, |
| | | encoder: AbsEncoder, |
| | | decoder: RNNTDecoder, |
| | | joint_network: JointNetwork, |
| | | att_decoder: Optional[AbsAttDecoder] = None, |
| | |
| | | feed_forward: torch.nn.Module, |
| | | feed_forward_macaron: torch.nn.Module, |
| | | conv_mod: torch.nn.Module, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_class: torch.nn.Module = LayerNorm, |
| | | norm_args: Dict = {}, |
| | | dropout_rate: float = 0.0, |
| | | ) -> None: |
| | |
| | | x = x[:,::self.time_reduction_factor,:] |
| | | olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1 |
| | | |
| | | return x, olens |
| | | return x, olens, None |
| | | |
| | | def simu_chunk_forward( |
| | | self, |
| | |
| | | new_k = k.replace(old_prefix, new_prefix) |
| | | state_dict[new_k] = v |
| | | |
| | | |
| | | class Swish(torch.nn.Module): |
| | | """Construct an Swish object.""" |
| | | """Swish activation definition. |
| | | |
| | | def forward(self, x): |
| | | """Return Swich activation function.""" |
| | | return x * torch.sigmoid(x) |
| | | Swish(x) = (beta * x) * sigmoid(x) |
| | | where beta = 1 defines standard Swish activation. |
| | | |
| | | References: |
| | | https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1. |
| | | E-swish variant: https://arxiv.org/abs/1801.07145. |
| | | |
| | | Args: |
| | | beta: Beta parameter for E-Swish. |
| | | (beta >= 1. If beta < 1, use standard Swish). |
| | | use_builtin: Whether to use PyTorch function if available. |
| | | |
| | | """ |
| | | |
| | | def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None: |
| | | super().__init__() |
| | | |
| | | self.beta = beta |
| | | |
| | | if beta > 1: |
| | | self.swish = lambda x: (self.beta * x) * torch.sigmoid(x) |
| | | else: |
| | | if use_builtin: |
| | | self.swish = torch.nn.SiLU() |
| | | else: |
| | | self.swish = lambda x: x * torch.sigmoid(x) |
| | | |
| | | def forward(self, x: torch.Tensor) -> torch.Tensor: |
| | | """Forward computation.""" |
| | | return self.swish(x) |
| | | |
| | | def get_activation(act): |
| | | """Return activation function.""" |
| | |
| | | """Repeat the same layer definition.""" |
| | | |
| | | from typing import Dict, List, Optional |
| | | |
| | | from funasr.modules.layer_norm import LayerNorm |
| | | import torch |
| | | |
| | | |
| | |
| | | self, |
| | | block_list: List[torch.nn.Module], |
| | | output_size: int, |
| | | norm_class: torch.nn.Module = torch.nn.LayerNorm, |
| | | norm_class: torch.nn.Module = LayerNorm, |
| | | ) -> None: |
| | | """Construct a MultiBlocks object.""" |
| | | super().__init__() |
| | |
| | | |
| | | # 7. Build model |
| | | |
| | | if encoder.unified_model_training: |
| | | if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training: |
| | | model = UnifiedTransducerModel( |
| | | vocab_size=vocab_size, |
| | | token_list=token_list, |