aky15
2023-04-12 7d1efe158eda74dc847c397db906f6cb77ac0f84
funasr/models/rnnt_decoder/stateless_decoder.py
File was renamed from funasr/models_transducer/decoder/stateless_decoder.py
@@ -5,8 +5,8 @@
import torch
from typeguard import check_argument_types
from funasr.models_transducer.beam_search_transducer import Hypothesis
from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
from funasr.modules.beam_search.beam_search_transducer import Hypothesis
from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@
        embed_size: int = 256,
        embed_dropout_rate: float = 0.0,
        embed_pad: int = 0,
        use_embed_mask: bool = False,
    ) -> None:
        """Construct a StatelessDecoder object."""
        super().__init__()
@@ -42,14 +41,6 @@
        self.device = next(self.parameters()).device
        self.score_cache = {}
        self.use_embed_mask = use_embed_mask
        if self.use_embed_mask:
            self._embed_mask = SpecAug(
                time_mask_width_range=3,
                num_time_mask=1,
                apply_freq_mask=False,
                apply_time_warp=False
            )
    def forward(
@@ -69,9 +60,6 @@
        """
        dec_embed = self.embed_dropout_rate(self.embed(labels))
        if self.use_embed_mask and self.training:
            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
        return dec_embed
    def score(