From 15d5ba7882a1c83b75b3154b69b0a79208b132a1 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Tue, 09 May 2023 13:56:58 +0800
Subject: [PATCH] Merge pull request #479 from alibaba-damo-academy/dev_aky

---
 funasr/models/decoder/rnnt_decoder.py |   12 ++++++++++++
 1 file changed, 12 insertions(+), 0 deletions(-)

diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
index 5401ab2..a0fe9ea 100644
--- a/funasr/models/decoder/rnnt_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -33,6 +33,7 @@
         dropout_rate: float = 0.0,
         embed_dropout_rate: float = 0.0,
         embed_pad: int = 0,
+        use_embed_mask: bool = False,
     ) -> None:
         """Construct a RNNDecoder object."""
         super().__init__()
@@ -66,6 +67,15 @@
 
         self.device = next(self.parameters()).device
         self.score_cache = {}
+
+        self.use_embed_mask = use_embed_mask
+        if self.use_embed_mask:
+            self._embed_mask = SpecAug(
+                time_mask_width_range=3,
+                num_time_mask=4,
+                apply_freq_mask=False,
+                apply_time_warp=False
+            )
     
     def forward(
         self,
@@ -88,6 +98,8 @@
             states = self.init_state(labels.size(0))
 
         dec_embed = self.dropout_embed(self.embed(labels))
+        if self.use_embed_mask and self.training:
+            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
         dec_out, states = self.rnn_forward(dec_embed, states)
         return dec_out
 

--
Gitblit v1.9.1