From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg

---
 funasr/models/rnnt_decoder/stateless_decoder.py |   16 ++--------------
 1 files changed, 2 insertions(+), 14 deletions(-)

diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models/rnnt_decoder/stateless_decoder.py
similarity index 86%
rename from funasr/models_transducer/decoder/stateless_decoder.py
rename to funasr/models/rnnt_decoder/stateless_decoder.py
index 07c8f51..a2e1fc1 100644
--- a/funasr/models_transducer/decoder/stateless_decoder.py
+++ b/funasr/models/rnnt_decoder/stateless_decoder.py
@@ -5,8 +5,8 @@
 import torch
 from typeguard import check_argument_types
 
-from funasr.models_transducer.beam_search_transducer import Hypothesis
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
 from funasr.models.specaug.specaug import SpecAug
 
 class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@
         embed_size: int = 256,
         embed_dropout_rate: float = 0.0,
         embed_pad: int = 0,
-        use_embed_mask: bool = False,
     ) -> None:
         """Construct a StatelessDecoder object."""
         super().__init__()
@@ -42,14 +41,6 @@
         self.device = next(self.parameters()).device
         self.score_cache = {}
 
-        self.use_embed_mask = use_embed_mask
-        if self.use_embed_mask:
-            self._embed_mask = SpecAug(
-                time_mask_width_range=3,
-                num_time_mask=1,
-                apply_freq_mask=False,
-                apply_time_warp=False
-            )
 
 
     def forward(
@@ -69,9 +60,6 @@
 
         """
         dec_embed = self.embed_dropout_rate(self.embed(labels))
-        if self.use_embed_mask and self.training:
-            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
-
         return dec_embed
 
     def score(

--
Gitblit v1.9.1