From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 funasr/models/language_model/seq_rnn_lm.py |    9 +++------
 1 files changed, 3 insertions(+), 6 deletions(-)

diff --git a/funasr/models/language_model/seq_rnn_lm.py b/funasr/models/language_model/seq_rnn_lm.py
index bef4974..baf63bc 100644
--- a/funasr/models/language_model/seq_rnn_lm.py
+++ b/funasr/models/language_model/seq_rnn_lm.py
@@ -1,4 +1,5 @@
 """Sequential implementation of Recurrent Neural Network Language Model."""
+
 from typing import Tuple
 from typing import Union
 
@@ -37,9 +38,7 @@
         self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
         if rnn_type in ["LSTM", "GRU"]:
             rnn_class = getattr(nn, rnn_type)
-            self.rnn = rnn_class(
-                ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
-            )
+            self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True)
         else:
             try:
                 nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
@@ -67,9 +66,7 @@
         # https://arxiv.org/abs/1611.01462
         if tie_weights:
             if nhid != ninp:
-                raise ValueError(
-                    "When using the tied flag, nhid must be equal to emsize"
-                )
+                raise ValueError("When using the tied flag, nhid must be equal to emsize")
             self.decoder.weight = self.encoder.weight
 
         self.rnn_type = rnn_type

--
Gitblit v1.9.1