From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
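
Reusing the name 'res' for two unrelated values means the second
assignment silently shadows the first, so later reads become ambiguous.
A minimal, self-contained sketch of the hazard (the values below are
hypothetical; only the identifier 'res' and the line references come
from encoders.py):

    def forward():
        res = [1, 2, 3]     # first binding (cf. line 365)
        # ... intervening code ...
        res = sum(res) / 2  # rebinding (cf. line 514) discards the list;
                            # renaming one binding keeps both reachable
        return res

In the sketch, renaming the second 'res' resolves the clash without
changing behavior.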
---
funasr/models/language_model/rnn/encoders.py | 26 +++++++-------------------
 1 file changed, 7 insertions(+), 19 deletions(-)
diff --git a/funasr/models/language_model/rnn/encoders.py b/funasr/models/language_model/rnn/encoders.py
index 819585b..fff3766 100644
--- a/funasr/models/language_model/rnn/encoders.py
+++ b/funasr/models/language_model/rnn/encoders.py
@@ -7,7 +7,7 @@
from torch.nn.utils.rnn import pack_padded_sequence
from torch.nn.utils.rnn import pad_packed_sequence
-from funasr.metrics import get_vgg2l_odim
+from funasr.metrics.common import get_vgg2l_odim
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.nets_utils import to_device
@@ -34,9 +34,7 @@
inputdim = hdim
RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
- rnn = RNN(
- inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
- )
+ rnn = RNN(inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True)
setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
@@ -72,9 +70,7 @@
rnn.flatten_parameters()
if prev_state is not None and rnn.bidirectional:
prev_state = reset_backward_rnn_state(prev_state)
- ys, states = rnn(
- xs_pack, hx=None if prev_state is None else prev_state[layer]
- )
+ ys, states = rnn(xs_pack, hx=None if prev_state is None else prev_state[layer])
elayer_states.append(states)
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
@@ -155,9 +151,7 @@
# ys: utt list of frame x cdim x 2 (2: means bidirectional)
ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
# (sum _utt frame_utt) x dim
- projected = torch.tanh(
- self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2)))
- )
+ projected = torch.tanh(self.l_last(ys_pad.contiguous().view(-1, ys_pad.size(2))))
xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
return xs_pad, ilens, states # x: utt list of frame x dim
@@ -225,9 +219,7 @@
else:
ilens = np.array(ilens, dtype=np.float32)
ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
- ilens = np.array(
- np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
- ).tolist()
+ ilens = np.array(np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64).tolist()
# x: utt_list of frame (remove zeropaded frames) x (input channel num x dim)
xs_pad = xs_pad.transpose(1, 2)
@@ -250,9 +242,7 @@
:param int in_channel: number of input channels
"""
- def __init__(
- self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1
- ):
+ def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
super(Encoder, self).__init__()
typ = etype.lstrip("vgg").rstrip("p")
if typ not in ["lstm", "gru", "blstm", "bgru"]:
@@ -367,6 +357,4 @@
enc_list.append(enc)
return enc_list
else:
- raise ValueError(
- "Number of encoders needs to be more than one. {}".format(num_encs)
- )
+ raise ValueError("Number of encoders needs to be more than one. {}".format(num_encs))
--
Gitblit v1.9.1