From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 24 六月 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords
---
funasr/models/language_model/seq_rnn_lm.py | 9 +++------
1 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/funasr/models/language_model/seq_rnn_lm.py b/funasr/models/language_model/seq_rnn_lm.py
index bef4974..baf63bc 100644
--- a/funasr/models/language_model/seq_rnn_lm.py
+++ b/funasr/models/language_model/seq_rnn_lm.py
@@ -1,4 +1,5 @@
"""Sequential implementation of Recurrent Neural Network Language Model."""
+
from typing import Tuple
from typing import Union
@@ -37,9 +38,7 @@
self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
if rnn_type in ["LSTM", "GRU"]:
rnn_class = getattr(nn, rnn_type)
- self.rnn = rnn_class(
- ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
- )
+ self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True)
else:
try:
nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
@@ -67,9 +66,7 @@
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
- raise ValueError(
- "When using the tied flag, nhid must be equal to emsize"
- )
+ raise ValueError("When using the tied flag, nhid must be equal to emsize")
self.decoder.weight = self.encoder.weight
self.rnn_type = rnn_type
--
Gitblit v1.9.1