From 7d1efe158eda74dc847c397db906f6cb77ac0f84 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期三, 12 四月 2023 16:49:56 +0800
Subject: [PATCH] rnnt reorg
---
funasr/models/rnnt_decoder/stateless_decoder.py | 16 ++--------------
1 files changed, 2 insertions(+), 14 deletions(-)
diff --git a/funasr/models_transducer/decoder/stateless_decoder.py b/funasr/models/rnnt_decoder/stateless_decoder.py
similarity index 86%
rename from funasr/models_transducer/decoder/stateless_decoder.py
rename to funasr/models/rnnt_decoder/stateless_decoder.py
index 07c8f51..a2e1fc1 100644
--- a/funasr/models_transducer/decoder/stateless_decoder.py
+++ b/funasr/models/rnnt_decoder/stateless_decoder.py
@@ -5,8 +5,8 @@
import torch
from typeguard import check_argument_types
-from funasr.models_transducer.beam_search_transducer import Hypothesis
-from funasr.models_transducer.decoder.abs_decoder import AbsDecoder
+from funasr.modules.beam_search.beam_search_transducer import Hypothesis
+from funasr.models.rnnt_decoder.abs_decoder import AbsDecoder
from funasr.models.specaug.specaug import SpecAug
class StatelessDecoder(AbsDecoder):
@@ -26,7 +26,6 @@
embed_size: int = 256,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
- use_embed_mask: bool = False,
) -> None:
"""Construct a StatelessDecoder object."""
super().__init__()
@@ -42,14 +41,6 @@
self.device = next(self.parameters()).device
self.score_cache = {}
- self.use_embed_mask = use_embed_mask
- if self.use_embed_mask:
- self._embed_mask = SpecAug(
- time_mask_width_range=3,
- num_time_mask=1,
- apply_freq_mask=False,
- apply_time_warp=False
- )
def forward(
@@ -69,9 +60,6 @@
"""
dec_embed = self.embed_dropout_rate(self.embed(labels))
- if self.use_embed_mask and self.training:
- dec_embed = self._embed_mask(dec_embed, label_lens)[0]
-
return dec_embed
def score(
--
Gitblit v1.9.1