From 3df109adfccedeb134dea4ba2ea9a2da89872048 Mon Sep 17 00:00:00 2001
From: Isuxiz Slidder <48672727+Isuxiz@users.noreply.github.com>
Date: 星期一, 31 三月 2025 17:51:52 +0800
Subject: [PATCH] Update model.py to fix "IndexError: index 1 is out of bounds for dimension 1 with size 0" (#2454)
---
funasr/models/language_model/seq_rnn_lm.py | 9 +++------
1 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/funasr/models/language_model/seq_rnn_lm.py b/funasr/models/language_model/seq_rnn_lm.py
index bef4974..baf63bc 100644
--- a/funasr/models/language_model/seq_rnn_lm.py
+++ b/funasr/models/language_model/seq_rnn_lm.py
@@ -1,4 +1,5 @@
"""Sequential implementation of Recurrent Neural Network Language Model."""
+
from typing import Tuple
from typing import Union
@@ -37,9 +38,7 @@
self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
if rnn_type in ["LSTM", "GRU"]:
rnn_class = getattr(nn, rnn_type)
- self.rnn = rnn_class(
- ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
- )
+ self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True)
else:
try:
nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
@@ -67,9 +66,7 @@
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
- raise ValueError(
- "When using the tied flag, nhid must be equal to emsize"
- )
+ raise ValueError("When using the tied flag, nhid must be equal to emsize")
self.decoder.weight = self.encoder.weight
self.rnn_type = rnn_type
--
Gitblit v1.9.1