From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/transducer/rnnt_decoder.py |   40 ++++++++++++++++++----------------------
 1 file changed, 18 insertions(+), 22 deletions(-)

diff --git a/funasr/models/transducer/rnnt_decoder.py b/funasr/models/transducer/rnnt_decoder.py
index 6d35b71..bc1787a 100644
--- a/funasr/models/transducer/rnnt_decoder.py
+++ b/funasr/models/transducer/rnnt_decoder.py
@@ -1,12 +1,17 @@
-"""RNN decoder definition for Transducer models."""
-
-from typing import List, Optional, Tuple
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
 
 import torch
+from typing import List, Optional, Tuple
 
-from funasr.models.transducer.beam_search_transducer import Hypothesis
+from funasr.register import tables
 from funasr.models.specaug.specaug import SpecAug
+from funasr.models.transducer.beam_search_transducer import Hypothesis
 
+
+@tables.register("decoder_classes", "rnnt_decoder")
 class RNNTDecoder(torch.nn.Module):
     """RNN decoder module.
 
@@ -37,7 +42,6 @@
         """Construct a RNNDecoder object."""
         super().__init__()
 
-
         if rnn_type not in ("lstm", "gru"):
             raise ValueError(f"Not supported: rnn_type={rnn_type}")
 
@@ -46,9 +50,7 @@
 
         rnn_class = torch.nn.LSTM if rnn_type == "lstm" else torch.nn.GRU
 
-        self.rnn = torch.nn.ModuleList(
-            [rnn_class(embed_size, hidden_size, 1, batch_first=True)]
-        )
+        self.rnn = torch.nn.ModuleList([rnn_class(embed_size, hidden_size, 1, batch_first=True)])
 
         for _ in range(1, num_layers):
             self.rnn += [rnn_class(hidden_size, hidden_size, 1, batch_first=True)]
@@ -72,9 +74,9 @@
                 time_mask_width_range=3,
                 num_time_mask=4,
                 apply_freq_mask=False,
-                apply_time_warp=False
+                apply_time_warp=False,
             )
-    
+
     def forward(
         self,
         labels: torch.Tensor,
@@ -123,13 +125,11 @@
 
         for layer in range(self.dlayers):
             if self.dtype == "lstm":
-                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[
-                    layer
-                ](x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1]))
-            else:
-                x, h_next[layer : layer + 1] = self.rnn[layer](
-                    x, hx=h_prev[layer : layer + 1]
+                x, (h_next[layer : layer + 1], c_next[layer : layer + 1]) = self.rnn[layer](
+                    x, hx=(h_prev[layer : layer + 1], c_prev[layer : layer + 1])
                 )
+            else:
+                x, h_next[layer : layer + 1] = self.rnn[layer](x, hx=h_prev[layer : layer + 1])
 
             x = self.dropout_rnn[layer](x)
 
@@ -198,9 +198,7 @@
         """
         self.device = device
 
-    def init_state(
-        self, batch_size: int
-    ) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
+    def init_state(self, batch_size: int) -> Tuple[torch.Tensor, Optional[torch.tensor]]:
         """Initialize decoder states.
 
         Args:
@@ -262,7 +260,5 @@
         """
         return (
             torch.cat([s[0] for s in new_states], dim=1),
-            torch.cat([s[1] for s in new_states], dim=1)
-            if self.dtype == "lstm"
-            else None,
+            torch.cat([s[1] for s in new_states], dim=1) if self.dtype == "lstm" else None,
         )

--
Gitblit v1.9.1