From 77045e7bb78d4b8a82f96130f9d84e356a32d5c5 Mon Sep 17 00:00:00 2001
From: aky15 <ankeyu.aky@11.17.44.249>
Date: 星期二, 09 五月 2023 11:16:07 +0800
Subject: [PATCH] rnnt bug fix

---
 funasr/modules/nets_utils.py               |   35 +++++++++++++++--
 funasr/models/encoder/conformer_encoder.py |    4 +-
 funasr/bin/asr_inference_rnnt.py           |   19 ++++-----
 funasr/tasks/asr.py                        |    2 
 funasr/models/decoder/rnnt_decoder.py      |   12 ++++++
 funasr/modules/repeat.py                   |    4 +-
 funasr/models/e2e_asr_transducer.py        |    8 ++--
 7 files changed, 59 insertions(+), 25 deletions(-)

diff --git a/funasr/bin/asr_inference_rnnt.py b/funasr/bin/asr_inference_rnnt.py
index d964643..bd36907 100644
--- a/funasr/bin/asr_inference_rnnt.py
+++ b/funasr/bin/asr_inference_rnnt.py
@@ -188,18 +188,15 @@
         self.frontend = frontend
         self.window_size = self.chunk_size + self.right_context
         
-        self._ctx = self.asr_model.encoder.get_encoder_input_size(
-            self.window_size
-        )
+        if self.streaming:
+            self._ctx = self.asr_model.encoder.get_encoder_input_size(
+                self.window_size
+            )
        
-        #self.last_chunk_length = (
-        #    self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        #) * self.hop_length
-
-        self.last_chunk_length = (
-            self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
-        )
-        self.reset_inference_cache()
+            self.last_chunk_length = (
+                self.asr_model.encoder.embed.min_frame_length + self.right_context + 1
+            )
+            self.reset_inference_cache()
 
     def reset_inference_cache(self) -> None:
         """Reset Speech2Text parameters."""
diff --git a/funasr/models/decoder/rnnt_decoder.py b/funasr/models/decoder/rnnt_decoder.py
index 5401ab2..a0fe9ea 100644
--- a/funasr/models/decoder/rnnt_decoder.py
+++ b/funasr/models/decoder/rnnt_decoder.py
@@ -33,6 +33,7 @@
         dropout_rate: float = 0.0,
         embed_dropout_rate: float = 0.0,
         embed_pad: int = 0,
+        use_embed_mask: bool = False,
     ) -> None:
         """Construct a RNNDecoder object."""
         super().__init__()
@@ -66,6 +67,15 @@
 
         self.device = next(self.parameters()).device
         self.score_cache = {}
+
+        self.use_embed_mask = use_embed_mask
+        if self.use_embed_mask:
+            self._embed_mask = SpecAug(
+                time_mask_width_range=3,
+                num_time_mask=4,
+                apply_freq_mask=False,
+                apply_time_warp=False
+            )
     
     def forward(
         self,
@@ -88,6 +98,8 @@
             states = self.init_state(labels.size(0))
 
         dec_embed = self.dropout_embed(self.embed(labels))
+        if self.use_embed_mask and self.training:
+            dec_embed = self._embed_mask(dec_embed, label_lens)[0]
         dec_out, states = self.rnn_forward(dec_embed, states)
         return dec_out
 
diff --git a/funasr/models/e2e_asr_transducer.py b/funasr/models/e2e_asr_transducer.py
index f8ba0f0..a5aaa6c 100644
--- a/funasr/models/e2e_asr_transducer.py
+++ b/funasr/models/e2e_asr_transducer.py
@@ -12,7 +12,7 @@
 from funasr.models.specaug.abs_specaug import AbsSpecAug
 from funasr.models.decoder.rnnt_decoder import RNNTDecoder
 from funasr.models.decoder.abs_decoder import AbsDecoder as AbsAttDecoder
-from funasr.models.encoder.conformer_encoder import ConformerChunkEncoder as Encoder
+from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.models.joint_net.joint_network import JointNetwork
 from funasr.modules.nets_utils import get_transducer_task_io
 from funasr.layers.abs_normalize import AbsNormalize
@@ -62,7 +62,7 @@
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: Encoder,
+        encoder: AbsEncoder,
         decoder: RNNTDecoder,
         joint_network: JointNetwork,
         att_decoder: Optional[AbsAttDecoder] = None,
@@ -286,7 +286,7 @@
                 feats, feats_lengths = self.normalize(feats, feats_lengths)
 
         # 4. Forward encoder
-        encoder_out, encoder_out_lens = self.encoder(feats, feats_lengths)
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -515,7 +515,7 @@
         frontend: Optional[AbsFrontend],
         specaug: Optional[AbsSpecAug],
         normalize: Optional[AbsNormalize],
-        encoder: Encoder,
+        encoder: AbsEncoder,
         decoder: RNNTDecoder,
         joint_network: JointNetwork,
         att_decoder: Optional[AbsAttDecoder] = None,
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 9777cee..434f2a4 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -307,7 +307,7 @@
         feed_forward: torch.nn.Module,
         feed_forward_macaron: torch.nn.Module,
         conv_mod: torch.nn.Module,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_class: torch.nn.Module = LayerNorm,
         norm_args: Dict = {},
         dropout_rate: float = 0.0,
     ) -> None:
@@ -1145,7 +1145,7 @@
             x = x[:,::self.time_reduction_factor,:]
             olens = torch.floor_divide(olens-1, self.time_reduction_factor) + 1
 
-        return x, olens
+        return x, olens, None
 
     def simu_chunk_forward(
         self,
diff --git a/funasr/modules/nets_utils.py b/funasr/modules/nets_utils.py
index 10df124..397a5c4 100644
--- a/funasr/modules/nets_utils.py
+++ b/funasr/modules/nets_utils.py
@@ -485,14 +485,39 @@
         new_k = k.replace(old_prefix, new_prefix)
         state_dict[new_k] = v
 
-
 class Swish(torch.nn.Module):
-    """Construct an Swish object."""
+    """Swish activation definition.
 
-    def forward(self, x):
-        """Return Swich activation function."""
-        return x * torch.sigmoid(x)
+    Swish(x) = (beta * x) * sigmoid(x)
+                 where beta = 1 defines standard Swish activation.
 
+    References:
+        https://arxiv.org/abs/2108.12943 / https://arxiv.org/abs/1710.05941v1.
+        E-swish variant: https://arxiv.org/abs/1801.07145.
+
+    Args:
+        beta: Beta parameter for E-Swish.
+                (beta >= 1. If beta < 1, use standard Swish).
+        use_builtin: Whether to use PyTorch function if available.
+
+    """
+
+    def __init__(self, beta: float = 1.0, use_builtin: bool = False) -> None:
+        super().__init__()
+
+        self.beta = beta
+
+        if beta > 1:
+            self.swish = lambda x: (self.beta * x) * torch.sigmoid(x)
+        else:
+            if use_builtin:
+                self.swish = torch.nn.SiLU()
+            else:
+                self.swish = lambda x: x * torch.sigmoid(x)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward computation."""
+        return self.swish(x)
 
 def get_activation(act):
     """Return activation function."""
diff --git a/funasr/modules/repeat.py b/funasr/modules/repeat.py
index 2b2dac8..ff1e182 100644
--- a/funasr/modules/repeat.py
+++ b/funasr/modules/repeat.py
@@ -7,7 +7,7 @@
 """Repeat the same layer definition."""
 
 from typing import Dict, List, Optional
-
+from funasr.modules.layer_norm import LayerNorm
 import torch
 
 
@@ -48,7 +48,7 @@
         self,
         block_list: List[torch.nn.Module],
         output_size: int,
-        norm_class: torch.nn.Module = torch.nn.LayerNorm,
+        norm_class: torch.nn.Module = LayerNorm,
     ) -> None:
         """Construct a MultiBlocks object."""
         super().__init__()
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index 87db05c..a64b9e7 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -1682,7 +1682,7 @@
 
         # 7. Build model
 
-        if encoder.unified_model_training:
+        if hasattr(encoder, 'unified_model_training') and encoder.unified_model_training:
             model = UnifiedTransducerModel(
                 vocab_size=vocab_size,
                 token_list=token_list,

--
Gitblit v1.9.1