From 1d1ef01b4e23630a99a3be7e9d1dce9550a793e9 Mon Sep 17 00:00:00 2001
From: yhliang <68215459+yhliang-aslp@users.noreply.github.com>
Date: 星期四, 11 五月 2023 16:26:24 +0800
Subject: [PATCH] Merge branch 'main' into dev_smohan
---
funasr/models/decoder/rnnt_decoder.py | 12 ++++++++++++
1 files changed, 12 insertions(+), 0 deletions(-)
diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
index 5401ab2..a0fe9ea 100644
--- a/funasr/models/decoder/rnnt_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -33,6 +33,7 @@
dropout_rate: float = 0.0,
embed_dropout_rate: float = 0.0,
embed_pad: int = 0,
+ use_embed_mask: bool = False,
) -> None:
"""Construct a RNNDecoder object."""
super().__init__()
@@ -66,6 +67,15 @@
self.device = next(self.parameters()).device
self.score_cache = {}
+
+ self.use_embed_mask = use_embed_mask
+ if self.use_embed_mask:
+ self._embed_mask = SpecAug(
+ time_mask_width_range=3,
+ num_time_mask=4,
+ apply_freq_mask=False,
+ apply_time_warp=False
+ )
def forward(
self,
@@ -88,6 +98,8 @@
states = self.init_state(labels.size(0))
dec_embed = self.dropout_embed(self.embed(labels))
+ if self.use_embed_mask and self.training:
+ dec_embed = self._embed_mask(dec_embed, label_lens)[0]
dec_out, states = self.rnn_forward(dec_embed, states)
return dec_out
--
Gitblit v1.9.1