| File was renamed from funasr/models_transducer/decoder/stateless_decoder.py |
| | |
| | | import torch |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models_transducer.beam_search_transducer import Hypothesis |
| | | from funasr.models_transducer.decoder.abs_decoder import AbsDecoder |
| | | from funasr.modules.beam_search.beam_search_transducer import Hypothesis |
| | | from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder |
| | | from funasr.models.specaug.specaug import SpecAug |
| | | |
| | | class StatelessDecoder(AbsDecoder): |
| | |
| | | embed_size: int = 256, |
| | | embed_dropout_rate: float = 0.0, |
| | | embed_pad: int = 0, |
| | | use_embed_mask: bool = False, |
| | | ) -> None: |
| | | """Construct a StatelessDecoder object.""" |
| | | super().__init__() |
| | |
| | | self.device = next(self.parameters()).device |
| | | self.score_cache = {} |
| | | |
| | | self.use_embed_mask = use_embed_mask |
| | | if self.use_embed_mask: |
| | | self._embed_mask = SpecAug( |
| | | time_mask_width_range=3, |
| | | num_time_mask=1, |
| | | apply_freq_mask=False, |
| | | apply_time_warp=False |
| | | ) |
| | | |
| | | |
| | | def forward( |
| | |
| | | |
| | | """ |
| | | dec_embed = self.embed_dropout_rate(self.embed(labels)) |
| | | if self.use_embed_mask and self.training: |
| | | dec_embed = self._embed_mask(dec_embed, label_lens)[0] |
| | | |
| | | return dec_embed |
| | | |
| | | def score( |