zhifu gao
2023-05-09 15d5ba7882a1c83b75b3154b69b0a79208b132a1
Merge pull request #479 from alibaba-damo-academy/dev_aky

rnnt bug fix
7 files changed
84 lines changed in total across the modified files
funasr/bin/asr_inference_rnnt.py — 19 lines changed (patch | view | raw | blame | history)
funasr/models/decoder/rnnt_decoder.py — 12 lines changed (patch | view | raw | blame | history)
funasr/models/e2e_asr_transducer.py — 8 lines changed (patch | view | raw | blame | history)
funasr/models/encoder/conformer_encoder.py — 4 lines changed (patch | view | raw | blame | history)
funasr/modules/nets_utils.py — 35 lines changed (patch | view | raw | blame | history)
funasr/modules/repeat.py — 4 lines changed (patch | view | raw | blame | history)
funasr/tasks/asr.py — 2 lines changed (patch | view | raw | blame | history)
funasr/bin/asr_inference_rnnt.py
@@ -188,18 +188,15 @@
        self.frontend = frontend
        self.window_size = self.chunk_size + self.right_context
        
        self._ctx = self.asr_model.encoder.get_encoder_input_size(
            self.window_size
        )
        if self.streaming:
            self._ctx = self.asr_model.encoder.get_encoder_input_size(
                self.window_size
            )
       
        #self.last_chunk_length = (
        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
        #) * self.hop_length
        self.last_chunk_length = (
            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
        )
        self.reset_inference_cache()
            self.last_chunk_length = (
                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
            )
            self.reset_inference_cache()
    def reset_inference_cache(self) -> None:
        """Reset Speech2Text parameters."""
funasr/models/decoder/rnnt_decoder.py
@@ -33,6 +33,7 @@
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a RNNDecoder object."""
        super().__init__()
@@ -66,6 +67,15 @@
        self.device = next(self.parameters()).device
        self.score_cache = {}
        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=4,
                apply_freq_mask=False,
                apply_time_warp=False
            )
    
    def forward(
        self,
@@ -88,6 +98,8 @@
            states = self.init_state(labels.size(0))
        dec_embed = self.dropout_embed(self.embed(labels))
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
        dec_out, states = self.rnn_forward(dec_embed, states)
        return dec_out
funasr/models/e2e_asr_transducer.py
@@ -12,7 +12,7 @@
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.decoder.rnnt_decoder import RNNTDecoder
from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.joint_net.joint_network import JointNetwork
from funasr.modules.nets_utils import get_transducer_task_io
from funasr.layers.abs_normalize import AbsNormalize
@@ -62,7 +62,7 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        encoder: AbsEncoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
@@ -286,7 +286,7 @@
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # 4. Forward encoder
        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
@@ -515,7 +515,7 @@
        frontend: Optional[AbsFrontend],
        specaug: Optional[AbsSpecAug],
        normalize: Optional[AbsNormalize],
        encoder: Encoder,
        encoder: AbsEncoder,
        decoder: RNNTDecoder,
        joint_network: JointNetwork,
        att_decoder: Optional[AbsAttDecoder] = None,
funasr/models/encoder/conformer_encoder.py
@@ -307,7 +307,7 @@
        feed_forward: torch.nn.Module,
        feed_forward_macaron: torch.nn.Module,
        conv_mod: torch.nn.Module,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_class: torch.nn.Module = LayerNorm,
        norm_args: Dict = {},
        dropout_rate: float = 0.0,
    ) -> None:
@@ -1145,7 +1145,7 @@
            x = x[:,::self.time_reduction_factor,:]
            olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
        return x, olens
        return x, olens, None
    def simu_chunk_forward(
        self,
funasr/modules/nets_utils.py
@@ -485,14 +485,39 @@
        new_k = k.replace(old_prefix, new_prefix)
        state_dict[new_k] = v
class Swish(torch.nn.Module):
    """Construct an Swish object."""
    """Swish activation definition.
    def forward(self, x):
        """Return Swich activation function."""
        return x * torch.sigmoid(x)
    Swish(x) = (beta * x) * sigmoid(x)
                 where beta = 1 defines standard Swish activation.
    References:
        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
        E-swish variant: https://arxiv.org/abs/1801.07145.
    Args:
        beta: Beta parameter for E-Swish.
                (beta >= 1. If beta < 1, use standard Swish).
        use_builtin: Whether to use PyTorch function if available.
    """
    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
        super().__init__()
        self.beta = beta
        if beta > 1:
            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
        else:
            if use_builtin:
                self.swish = torch.nn.SiLU()
            else:
                self.swish = lambda x: x * torch.sigmoid(x)
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward computation."""
        return self.swish(x)
def get_activation(act):
    """Return activation function."""
funasr/modules/repeat.py
@@ -7,7 +7,7 @@
"""Repeat the same layer definition."""
from typing import Dict, List, Optional
from funasr.modules.layer_norm import LayerNorm
import torch
@@ -48,7 +48,7 @@
        self,
        block_list: List[torch.nn.Module],
        output_size: int,
        norm_class: torch.nn.Module = torch.nn.LayerNorm,
        norm_class: torch.nn.Module = LayerNorm,
    ) -> None:
        """Construct a MultiBlocks object."""
        super().__init__()
funasr/tasks/asr.py
@@ -1684,7 +1684,7 @@
        # 7. Build model
        if encoder.unified_model_training:
        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
            model = UnifiedTransducerModel(
                vocab_size=vocab_size,
                token_list=token_list,