From dfcc5d47587d3e793cbfec2e9509c0e9a9e1732c Mon Sep 17 00:00:00 2001
From: Dogvane Huang <dogvane@gmail.com>
Date: 星期二, 02 七月 2024 12:24:13 +0800
Subject: [PATCH] fix c# demo project to new onnx model files (#1689)
---
funasr/models/language_model/seq_rnn_lm.py | 9 +++------
1 files changed, 3 insertions(+), 6 deletions(-)
diff --git a/funasr/models/language_model/seq_rnn_lm.py b/funasr/models/language_model/seq_rnn_lm.py
index bef4974..baf63bc 100644
--- a/funasr/models/language_model/seq_rnn_lm.py
+++ b/funasr/models/language_model/seq_rnn_lm.py
@@ -1,4 +1,5 @@
"""Sequential implementation of Recurrent Neural Network Language Model."""
+
from typing import Tuple
from typing import Union
@@ -37,9 +38,7 @@
self.encoder = nn.Embedding(vocab_size, ninp, padding_idx=ignore_id)
if rnn_type in ["LSTM", "GRU"]:
rnn_class = getattr(nn, rnn_type)
- self.rnn = rnn_class(
- ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True
- )
+ self.rnn = rnn_class(ninp, nhid, nlayers, dropout=dropout_rate, batch_first=True)
else:
try:
nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
@@ -67,9 +66,7 @@
# https://arxiv.org/abs/1611.01462
if tie_weights:
if nhid != ninp:
- raise ValueError(
- "When using the tied flag, nhid must be equal to emsize"
- )
+ raise ValueError("When using the tied flag, nhid must be equal to emsize")
self.decoder.weight = self.encoder.weight
self.rnn_type = rnn_type
--
Gitblit v1.9.1