From 6e69d784e4814c3dbe35e8f70c6cf4b920c8b20b Mon Sep 17 00:00:00 2001
From: 天地 <tiandiweizun@gmail.com>
Date: 星期三, 19 三月 2025 23:10:13 +0800
Subject: [PATCH] 1. bug fix:list(mean)和list(var),由于mean和var是numpy,导致写入到文件的格式错误,参考上面的话,大概率是list(mean.tolist()),其实外层list没有必要 (#2437)
---
funasr/models/language_model/seq_rnn_lm.py | 9 +++------
1 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/funasr/models/language_model/seq_rnn_lm.py b/funasr/models/language_model/seq_rnn_lm.py
index bef4974..baf63bc 100644
--- a/funasr/models/language_model/seq_rnn_lm.py
+++ b/funasr/models/language_model/seq_rnn_lm.py
@@ -1,4 +1,5 @@
"""Sequential implementation of Recurrent Neural Network Language Model."""
+
from typing import Tuple
from typing import Union
@@ -37,9 +38,7 @@
self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
if rnn_type in ["LSTM", "GRU"]:
rnn_class = getattr(nn, rnn_type)
- self.rnn = rnn_class(
- ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
- )
+ self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True)
else:
try:
nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
@@ -67,9 +66,7 @@
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
- raise ValueError(
- "When using the tied flag, nhid must be equal to emsize"
- )
+ raise ValueError("When using the tied flag, nhid must be equal to emsize")
self.decoder.weight = self.encoder.weight
self.rnn_type = rnn_type
--
Gitblit v1.9.1