游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/models/decoder/rnnt_decoder.py
@@ -3,7 +3,6 @@
from typing import List, Optional, Tuple
import torch
from typeguard import check_argument_types
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.specaug.specaug import SpecAug
@@ -33,11 +32,11 @@
        dropout_rate: float = 0.0,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a RNNDecoder object."""
        super().__init__()
        assert check_argument_types()
        if rnn_type not in ("lstm", "gru"):
            raise ValueError(f"Not supported: rnn_type={rnn_type}")
@@ -66,6 +65,15 @@
        # Device of the first parameter yielded by self.parameters().
        self.device = next(self.parameters()).device
        # Starts out empty; presumably filled by scoring code elsewhere in
        # this class -- TODO confirm against the rest of the file.
        self.score_cache = {}
        self.use_embed_mask = use_embed_mask
        # The masking module is only created when the flag is set, so
        # self._embed_mask does not exist otherwise. It is a SpecAug instance
        # built with frequency masking and time warping switched off, and with
        # the time-mask settings given below.
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=4,
                apply_freq_mask=False,
                apply_time_warp=False
            )
    
    def forward(
        self,
@@ -88,6 +96,8 @@
            states = self.init_state(labels.size(0))
        # Pass the labels through self.embed, then through self.dropout_embed.
        dec_embed = self.dropout_embed(self.embed(labels))
        # The embedding mask is applied only when it was enabled at
        # construction time and self.training is true.
        if self.use_embed_mask and self.training:
            # Only the first element of the mask module's result is kept
            # (presumably the masked embeddings -- TODO confirm).
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
        dec_out, states = self.rnn_forward(dec_embed, states)
        # Only the output is returned; the states from rnn_forward are not.
        return dec_out