From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
funasr/models/language_model/rnn/decoders.py | 247 +++++++++++++++++-------------------------------
 1 file changed, 89 insertions(+), 158 deletions(-)
diff --git a/funasr/models/language_model/rnn/decoders.py b/funasr/models/language_model/rnn/decoders.py
index a426b51..e7d35e9 100644
--- a/funasr/models/language_model/rnn/decoders.py
+++ b/funasr/models/language_model/rnn/decoders.py
@@ -1,4 +1,5 @@
"""RNN decoder module."""
+
import logging
import math
import random
@@ -15,7 +16,7 @@
from funasr.metrics import end_detect
from funasr.models.transformer.utils.nets_utils import mask_by_length
from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
from funasr.models.transformer.utils.nets_utils import to_device
from funasr.models.language_model.rnn.attentions import att_to_numpy
@@ -45,24 +46,24 @@
"""
def __init__(
- self,
- eprojs,
- odim,
- dtype,
- dlayers,
- dunits,
- sos,
- eos,
- att,
- verbose=0,
- char_list=None,
- labeldist=None,
- lsm_weight=0.0,
- sampling_probability=0.0,
- dropout=0.0,
- context_residual=False,
- replace_sos=False,
- num_encs=1,
+ self,
+ eprojs,
+ odim,
+ dtype,
+ dlayers,
+ dunits,
+ sos,
+ eos,
+ att,
+ verbose=0,
+ char_list=None,
+ labeldist=None,
+ lsm_weight=0.0,
+ sampling_probability=0.0,
+ dropout=0.0,
+ context_residual=False,
+ replace_sos=False,
+ num_encs=1,
):
torch.nn.Module.__init__(self)
@@ -76,16 +77,20 @@
self.decoder = torch.nn.ModuleList()
self.dropout_dec = torch.nn.ModuleList()
self.decoder += [
- torch.nn.LSTMCell(dunits + eprojs, dunits)
- if self.dtype == "lstm"
- else torch.nn.GRUCell(dunits + eprojs, dunits)
+ (
+ torch.nn.LSTMCell(dunits + eprojs, dunits)
+ if self.dtype == "lstm"
+ else torch.nn.GRUCell(dunits + eprojs, dunits)
+ )
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
for _ in six.moves.range(1, self.dlayers):
self.decoder += [
- torch.nn.LSTMCell(dunits, dunits)
- if self.dtype == "lstm"
- else torch.nn.GRUCell(dunits, dunits)
+ (
+ torch.nn.LSTMCell(dunits, dunits)
+ if self.dtype == "lstm"
+ else torch.nn.GRUCell(dunits, dunits)
+ )
]
self.dropout_dec += [torch.nn.Dropout(p=dropout)]
# NOTE: dropout is applied only for the vertical connections
@@ -131,9 +136,7 @@
else:
z_list[0] = self.decoder[0](ey, z_prev[0])
for i in six.moves.range(1, self.dlayers):
- z_list[i] = self.decoder[i](
- self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
- )
+ z_list[i] = self.decoder[i](self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i])
return z_list, c_list
def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
@@ -197,9 +200,7 @@
)
)
logging.info(
- self.__class__.__name__
- + " output lengths: "
- + str([y.size(0) for y in ys_out])
+ self.__class__.__name__ + " output lengths: " + str([y.size(0) for y in ys_out])
)
# initialization
@@ -280,7 +281,7 @@
ys_hat = y_all.view(batch, olength, -1)
ys_true = ys_out_pad
for (i, y_hat), y_true in zip(
- enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
+ enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
):
if i == MAX_DECODER_OUTPUT:
break
@@ -362,9 +363,7 @@
weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
recog_args.weights_ctc_dec
) # normalize
- logging.info(
- "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
- )
+ logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
else:
weights_ctc_dec = [1.0]
@@ -450,9 +449,7 @@
hyp["a_prev"][self.num_encs],
)
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
- z_list, c_list = self.rnn_forward(
- ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
- )
+ z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"])
# get nbest local scores and their ids
if self.context_residual:
@@ -464,9 +461,7 @@
local_att_scores = F.log_softmax(logits, dim=1)
if rnnlm:
rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
- local_scores = (
- local_att_scores + recog_args.lm_weight * local_lm_scores
- )
+ local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
else:
local_scores = local_att_scores
@@ -482,9 +477,7 @@
ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
)
- local_scores = (1.0 - ctc_weight) * local_att_scores[
- :, local_best_ids[0]
- ]
+ local_scores = (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]]
if self.num_encs == 1:
local_scores += ctc_weight * torch.from_numpy(
ctc_scores[0] - hyp["ctc_score_prev"][0]
@@ -492,24 +485,16 @@
else:
for idx in range(self.num_encs):
local_scores += (
- ctc_weight
- * weights_ctc_dec[idx]
- * torch.from_numpy(
- ctc_scores[idx] - hyp["ctc_score_prev"][idx]
- )
+ ctc_weight
+ * weights_ctc_dec[idx]
+ * torch.from_numpy(ctc_scores[idx] - hyp["ctc_score_prev"][idx])
)
if rnnlm:
- local_scores += (
- recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
- )
- local_best_scores, joint_best_ids = torch.topk(
- local_scores, beam, dim=1
- )
+ local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
+ local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
local_best_ids = local_best_ids[:, joint_best_ids[0]]
else:
- local_best_scores, local_best_ids = torch.topk(
- local_scores, beam, dim=1
- )
+ local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
for j in six.moves.range(beam):
new_hyp = {}
@@ -519,9 +504,7 @@
if self.num_encs == 1:
new_hyp["a_prev"] = att_w[:]
else:
- new_hyp["a_prev"] = [
- att_w_list[idx][:] for idx in range(self.num_encs + 1)
- ]
+ new_hyp["a_prev"] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)]
new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
@@ -530,27 +513,22 @@
new_hyp["rnnlm_prev"] = rnnlm_state
if lpz[0] is not None:
new_hyp["ctc_state_prev"] = [
- ctc_states[idx][joint_best_ids[0, j]]
- for idx in range(self.num_encs)
+ ctc_states[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
]
new_hyp["ctc_score_prev"] = [
- ctc_scores[idx][joint_best_ids[0, j]]
- for idx in range(self.num_encs)
+ ctc_scores[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
]
# will be (2 x beam) hyps at most
hyps_best_kept.append(new_hyp)
- hyps_best_kept = sorted(
- hyps_best_kept, key=lambda x: x["score"], reverse=True
- )[:beam]
+ hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[
+ :beam
+ ]
# sort and get nbest
hyps = hyps_best_kept
logging.debug("number of pruned hypotheses: " + str(len(hyps)))
- logging.debug(
- "best hypo: "
- + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
- )
+ logging.debug("best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]))
# add eos in the final loop to avoid that there are no ended hyps
if i == maxlen - 1:
@@ -569,9 +547,7 @@
if len(hyp["yseq"]) > minlen:
hyp["score"] += (i + 1) * penalty
if rnnlm: # Word LM needs to add final <eos> score
- hyp["score"] += recog_args.lm_weight * rnnlm.final(
- hyp["rnnlm_prev"]
- )
+ hyp["score"] += recog_args.lm_weight * rnnlm.final(hyp["rnnlm_prev"])
ended_hyps.append(hyp)
else:
remained_hyps.append(hyp)
@@ -589,21 +565,18 @@
break
for hyp in hyps:
- logging.debug(
- "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
- )
+ logging.debug("hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]))
logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))
nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
- : min(len(ended_hyps), recog_args.nbest)
- ]
+ : min(len(ended_hyps), recog_args.nbest)
+ ]
# check number of hypotheses
if len(nbest_hyps) == 0:
logging.warning(
- "there is no N-best results, "
- "perform recognition again with smaller minlenratio."
+ "there is no N-best results, " "perform recognition again with smaller minlenratio."
)
# should copy because Namespace will be overwritten globally
recog_args = Namespace(**vars(recog_args))
@@ -623,16 +596,16 @@
return nbest_hyps
def recognize_beam_batch(
- self,
- h,
- hlens,
- lpz,
- recog_args,
- char_list,
- rnnlm=None,
- normalize_score=True,
- strm_idx=0,
- lang_ids=None,
+ self,
+ h,
+ hlens,
+ lpz,
+ recog_args,
+ char_list,
+ rnnlm=None,
+ normalize_score=True,
+ strm_idx=0,
+ lang_ids=None,
):
# to support mutiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
@@ -667,9 +640,7 @@
weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
recog_args.weights_ctc_dec
) # normalize
- logging.info(
- "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
- )
+ logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
else:
weights_ctc_dec = [1.0]
@@ -686,18 +657,10 @@
logging.info("min output length: " + str(minlen))
# initialization
- c_prev = [
- to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
- ]
- z_prev = [
- to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
- ]
- c_list = [
- to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
- ]
- z_list = [
- to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
- ]
+ c_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
+ z_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
+ c_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
+ z_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
vscores = to_device(h[0], torch.zeros(batch, beam))
rnnlm_state = None
@@ -716,14 +679,10 @@
if self.replace_sos and recog_args.tgt_lang:
logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
logging.info("<sos> mark: " + recog_args.tgt_lang)
- yseq = [
- [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
- ]
+ yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
elif lang_ids is not None:
# NOTE: used for evaluation during training
- yseq = [
- [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
- ]
+ yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
else:
logging.info("<sos> index: " + str(self.sos))
logging.info("<sos> mark: " + char_list[self.sos])
@@ -740,8 +699,7 @@
]
exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
exp_h = [
- h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
- for idx in range(self.num_encs)
+ h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs)
]
exp_h = [
exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
@@ -750,9 +708,7 @@
if lpz[0] is not None:
scoring_num = min(
- int(beam * CTC_SCORING_RATIO)
- if att_weight > 0.0 and not lpz[0].is_cuda
- else 0,
+ int(beam * CTC_SCORING_RATIO) if att_weight > 0.0 and not lpz[0].is_cuda else 0,
lpz[0].size(-1),
)
ctc_scorer = [
@@ -796,9 +752,7 @@
# attention decoder
z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
if self.context_residual:
- logits = self.output(
- torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
- )
+ logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
local_scores = att_weight * F.log_softmax(logits, dim=1)
@@ -812,9 +766,7 @@
if ctc_scorer[0]:
local_scores[:, 0] = self.logzero # avoid choosing blank
part_ids = (
- torch.topk(local_scores, scoring_num, dim=-1)[1]
- if scoring_num > 0
- else None
+ torch.topk(local_scores, scoring_num, dim=-1)[1] if scoring_num > 0 else None
)
for idx in range(self.num_encs):
att_w = att_w_list[idx]
@@ -823,8 +775,7 @@
yseq, ctc_state[idx], part_ids, att_w_
)
local_scores = (
- local_scores
- + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
+ local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
)
local_scores = local_scores.view(batch, beam, self.odim)
@@ -839,9 +790,7 @@
# global pruning
accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
- accum_odim_ids = (
- torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
- )
+ accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
accum_padded_beam_ids = (
(accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
)
@@ -867,24 +816,16 @@
]
else:
# handle the case of location_recurrent when return is a tuple
- _a_prev_ = torch.index_select(
- att_w_list[idx][0].view(n_bb, -1), 0, vidx
- )
- _h_prev_ = torch.index_select(
- att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
- )
- _c_prev_ = torch.index_select(
- att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
- )
+ _a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx)
+ _h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx)
+ _c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx)
_a_prev = (_a_prev_, (_h_prev_, _c_prev_))
a_prev.append(_a_prev)
z_prev = [
- torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
- for li in range(self.dlayers)
+ torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
]
c_prev = [
- torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
- for li in range(self.dlayers)
+ torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
]
# pick ended hyps
@@ -900,9 +841,7 @@
_vscore = None
if eos_vscores[samp_i, beam_j] > thr[samp_i]:
yk = y_prev[k][:]
- if len(yk) <= min(
- hlens[idx][samp_i] for idx in range(self.num_encs)
- ):
+ if len(yk) <= min(hlens[idx][samp_i] for idx in range(self.num_encs)):
_vscore = eos_vscores[samp_i][beam_j] + penalty_i
elif i == maxlen - 1:
yk = yseq[k][:]
@@ -910,9 +849,7 @@
if _vscore:
yk.append(self.eos)
if rnnlm:
- _vscore += recog_args.lm_weight * rnnlm.final(
- rnnlm_state, index=k
- )
+ _vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k)
_score = _vscore.data.cpu().numpy()
ended_hyps[samp_i].append(
{"yseq": yk, "vscore": _vscore, "score": _score}
@@ -938,9 +875,7 @@
torch.cuda.empty_cache()
- dummy_hyps = [
- {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
- ]
+ dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
ended_hyps = [
ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
for samp_i in six.moves.range(batch)
@@ -952,7 +887,7 @@
nbest_hyps = [
sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
- : min(len(ended_hyps[samp_i]), recog_args.nbest)
+ : min(len(ended_hyps[samp_i]), recog_args.nbest)
]
for samp_i in six.moves.range(batch)
]
@@ -1168,13 +1103,9 @@
state["a_prev"][self.num_encs],
)
ey = torch.cat((ey, att_c), dim=1) # utt(1) x (zdim + hdim)
- z_list, c_list = self.rnn_forward(
- ey, z_list, c_list, state["z_prev"], state["c_prev"]
- )
+ z_list, c_list = self.rnn_forward(ey, z_list, c_list, state["z_prev"], state["c_prev"])
if self.context_residual:
- logits = self.output(
- torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
- )
+ logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
else:
logits = self.output(self.dropout_dec[-1](z_list[-1]))
logp = F.log_softmax(logits, dim=1).squeeze(0)
--
Gitblit v1.9.1