From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365

---
 funasr/models/language_model/rnn/decoders.py |  247 +++++++++++++++++-------------------------------
 1 file changed, 89 insertions(+), 158 deletions(-)

diff --git a/funasr/models/language_model/rnn/decoders.py b/funasr/models/language_model/rnn/decoders.py
index a426b51..e7d35e9 100644
--- a/funasr/models/language_model/rnn/decoders.py
+++ b/funasr/models/language_model/rnn/decoders.py
@@ -1,4 +1,5 @@
 """RNN decoder module."""
+
 import logging
 import math
 import random
@@ -15,7 +16,7 @@
 from funasr.metrics import end_detect
 from funasr.models.transformer.utils.nets_utils import mask_by_length
 from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.models.transformer.utils.nets_utils import to_device
 from funasr.models.language_model.rnn.attentions import att_to_numpy
 
@@ -45,24 +46,24 @@
     """
 
     def __init__(
-            self,
-            eprojs,
-            odim,
-            dtype,
-            dlayers,
-            dunits,
-            sos,
-            eos,
-            att,
-            verbose=0,
-            char_list=None,
-            labeldist=None,
-            lsm_weight=0.0,
-            sampling_probability=0.0,
-            dropout=0.0,
-            context_residual=False,
-            replace_sos=False,
-            num_encs=1,
+        self,
+        eprojs,
+        odim,
+        dtype,
+        dlayers,
+        dunits,
+        sos,
+        eos,
+        att,
+        verbose=0,
+        char_list=None,
+        labeldist=None,
+        lsm_weight=0.0,
+        sampling_probability=0.0,
+        dropout=0.0,
+        context_residual=False,
+        replace_sos=False,
+        num_encs=1,
     ):
 
         torch.nn.Module.__init__(self)
@@ -76,16 +77,20 @@
         self.decoder = torch.nn.ModuleList()
         self.dropout_dec = torch.nn.ModuleList()
         self.decoder += [
-            torch.nn.LSTMCell(dunits + eprojs, dunits)
-            if self.dtype == "lstm"
-            else torch.nn.GRUCell(dunits + eprojs, dunits)
+            (
+                torch.nn.LSTMCell(dunits + eprojs, dunits)
+                if self.dtype == "lstm"
+                else torch.nn.GRUCell(dunits + eprojs, dunits)
+            )
         ]
         self.dropout_dec += [torch.nn.Dropout(p=dropout)]
         for _ in six.moves.range(1, self.dlayers):
             self.decoder += [
-                torch.nn.LSTMCell(dunits, dunits)
-                if self.dtype == "lstm"
-                else torch.nn.GRUCell(dunits, dunits)
+                (
+                    torch.nn.LSTMCell(dunits, dunits)
+                    if self.dtype == "lstm"
+                    else torch.nn.GRUCell(dunits, dunits)
+                )
             ]
             self.dropout_dec += [torch.nn.Dropout(p=dropout)]
             # NOTE: dropout is applied only for the vertical connections
@@ -131,9 +136,7 @@
         else:
             z_list[0] = self.decoder[0](ey, z_prev[0])
             for i in six.moves.range(1, self.dlayers):
-                z_list[i] = self.decoder[i](
-                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
-                )
+                z_list[i] = self.decoder[i](self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i])
         return z_list, c_list
 
     def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
@@ -197,9 +200,7 @@
                 )
             )
         logging.info(
-            self.__class__.__name__
-            + " output lengths: "
-            + str([y.size(0) for y in ys_out])
+            self.__class__.__name__ + " output lengths: " + str([y.size(0) for y in ys_out])
         )
 
         # initialization
@@ -280,7 +281,7 @@
             ys_hat = y_all.view(batch, olength, -1)
             ys_true = ys_out_pad
             for (i, y_hat), y_true in zip(
-                    enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
+                enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
             ):
                 if i == MAX_DECODER_OUTPUT:
                     break
@@ -362,9 +363,7 @@
             weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                 recog_args.weights_ctc_dec
             )  # normalize
-            logging.info(
-                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
-            )
+            logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
         else:
             weights_ctc_dec = [1.0]
 
@@ -450,9 +449,7 @@
                         hyp["a_prev"][self.num_encs],
                     )
                 ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
-                z_list, c_list = self.rnn_forward(
-                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
-                )
+                z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"])
 
                 # get nbest local scores and their ids
                 if self.context_residual:
@@ -464,9 +461,7 @@
                 local_att_scores = F.log_softmax(logits, dim=1)
                 if rnnlm:
                     rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
-                    local_scores = (
-                            local_att_scores + recog_args.lm_weight * local_lm_scores
-                    )
+                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                 else:
                     local_scores = local_att_scores
 
@@ -482,9 +477,7 @@
                         ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                             hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                         )
-                    local_scores = (1.0 - ctc_weight) * local_att_scores[
-                                                        :, local_best_ids[0]
-                                                        ]
+                    local_scores = (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]]
                     if self.num_encs == 1:
                         local_scores += ctc_weight * torch.from_numpy(
                             ctc_scores[0] - hyp["ctc_score_prev"][0]
@@ -492,24 +485,16 @@
                     else:
                         for idx in range(self.num_encs):
                             local_scores += (
-                                    ctc_weight
-                                    * weights_ctc_dec[idx]
-                                    * torch.from_numpy(
-                                ctc_scores[idx] - hyp["ctc_score_prev"][idx]
-                            )
+                                ctc_weight
+                                * weights_ctc_dec[idx]
+                                * torch.from_numpy(ctc_scores[idx] - hyp["ctc_score_prev"][idx])
                             )
                     if rnnlm:
-                        local_scores += (
-                                recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
-                        )
-                    local_best_scores, joint_best_ids = torch.topk(
-                        local_scores, beam, dim=1
-                    )
+                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
+                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                     local_best_ids = local_best_ids[:, joint_best_ids[0]]
                 else:
-                    local_best_scores, local_best_ids = torch.topk(
-                        local_scores, beam, dim=1
-                    )
+                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
 
                 for j in six.moves.range(beam):
                     new_hyp = {}
@@ -519,9 +504,7 @@
                     if self.num_encs == 1:
                         new_hyp["a_prev"] = att_w[:]
                     else:
-                        new_hyp["a_prev"] = [
-                            att_w_list[idx][:] for idx in range(self.num_encs + 1)
-                        ]
+                        new_hyp["a_prev"] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)]
                     new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                     new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                     new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
@@ -530,27 +513,22 @@
                         new_hyp["rnnlm_prev"] = rnnlm_state
                     if lpz[0] is not None:
                         new_hyp["ctc_state_prev"] = [
-                            ctc_states[idx][joint_best_ids[0, j]]
-                            for idx in range(self.num_encs)
+                            ctc_states[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
                         ]
                         new_hyp["ctc_score_prev"] = [
-                            ctc_scores[idx][joint_best_ids[0, j]]
-                            for idx in range(self.num_encs)
+                            ctc_scores[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
                         ]
                     # will be (2 x beam) hyps at most
                     hyps_best_kept.append(new_hyp)
 
-                hyps_best_kept = sorted(
-                    hyps_best_kept, key=lambda x: x["score"], reverse=True
-                )[:beam]
+                hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[
+                    :beam
+                ]
 
             # sort and get nbest
             hyps = hyps_best_kept
             logging.debug("number of pruned hypotheses: " + str(len(hyps)))
-            logging.debug(
-                "best hypo: "
-                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
-            )
+            logging.debug("best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]))
 
             # add eos in the final loop to avoid that there are no ended hyps
             if i == maxlen - 1:
@@ -569,9 +547,7 @@
                     if len(hyp["yseq"]) > minlen:
                         hyp["score"] += (i + 1) * penalty
                         if rnnlm:  # Word LM needs to add final <eos> score
-                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
-                                hyp["rnnlm_prev"]
-                            )
+                            hyp["score"] += recog_args.lm_weight * rnnlm.final(hyp["rnnlm_prev"])
                         ended_hyps.append(hyp)
                 else:
                     remained_hyps.append(hyp)
@@ -589,21 +565,18 @@
                 break
 
             for hyp in hyps:
-                logging.debug(
-                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
-                )
+                logging.debug("hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]))
 
             logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))
 
         nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
-                     : min(len(ended_hyps), recog_args.nbest)
-                     ]
+            : min(len(ended_hyps), recog_args.nbest)
+        ]
 
         # check number of hypotheses
         if len(nbest_hyps) == 0:
             logging.warning(
-                "there is no N-best results, "
-                "perform recognition again with smaller minlenratio."
+                "there is no N-best results, " "perform recognition again with smaller minlenratio."
             )
             # should copy because Namespace will be overwritten globally
             recog_args = Namespace(**vars(recog_args))
@@ -623,16 +596,16 @@
         return nbest_hyps
 
     def recognize_beam_batch(
-            self,
-            h,
-            hlens,
-            lpz,
-            recog_args,
-            char_list,
-            rnnlm=None,
-            normalize_score=True,
-            strm_idx=0,
-            lang_ids=None,
+        self,
+        h,
+        hlens,
+        lpz,
+        recog_args,
+        char_list,
+        rnnlm=None,
+        normalize_score=True,
+        strm_idx=0,
+        lang_ids=None,
     ):
         # to support mutiple encoder asr mode, in single encoder mode,
         # convert torch.Tensor to List of torch.Tensor
@@ -667,9 +640,7 @@
             weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                 recog_args.weights_ctc_dec
             )  # normalize
-            logging.info(
-                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
-            )
+            logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
         else:
             weights_ctc_dec = [1.0]
 
@@ -686,18 +657,10 @@
         logging.info("min output length: " + str(minlen))
 
         # initialization
-        c_prev = [
-            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
-        ]
-        z_prev = [
-            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
-        ]
-        c_list = [
-            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
-        ]
-        z_list = [
-            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
-        ]
+        c_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
+        z_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
+        c_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
+        z_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
         vscores = to_device(h[0], torch.zeros(batch, beam))
 
         rnnlm_state = None
@@ -716,14 +679,10 @@
         if self.replace_sos and recog_args.tgt_lang:
             logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
             logging.info("<sos> mark: " + recog_args.tgt_lang)
-            yseq = [
-                [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
-            ]
+            yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
         elif lang_ids is not None:
             # NOTE: used for evaluation during training
-            yseq = [
-                [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
-            ]
+            yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
         else:
             logging.info("<sos> index: " + str(self.sos))
             logging.info("<sos> mark: " + char_list[self.sos])
@@ -740,8 +699,7 @@
         ]
         exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
         exp_h = [
-            h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
-            for idx in range(self.num_encs)
+            h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs)
         ]
         exp_h = [
             exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
@@ -750,9 +708,7 @@
 
         if lpz[0] is not None:
             scoring_num = min(
-                int(beam * CTC_SCORING_RATIO)
-                if att_weight > 0.0 and not lpz[0].is_cuda
-                else 0,
+                int(beam * CTC_SCORING_RATIO) if att_weight > 0.0 and not lpz[0].is_cuda else 0,
                 lpz[0].size(-1),
             )
             ctc_scorer = [
@@ -796,9 +752,7 @@
             # attention decoder
             z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
             if self.context_residual:
-                logits = self.output(
-                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
-                )
+                logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
             else:
                 logits = self.output(self.dropout_dec[-1](z_list[-1]))
             local_scores = att_weight * F.log_softmax(logits, dim=1)
@@ -812,9 +766,7 @@
             if ctc_scorer[0]:
                 local_scores[:, 0] = self.logzero  # avoid choosing blank
                 part_ids = (
-                    torch.topk(local_scores, scoring_num, dim=-1)[1]
-                    if scoring_num > 0
-                    else None
+                    torch.topk(local_scores, scoring_num, dim=-1)[1] if scoring_num > 0 else None
                 )
                 for idx in range(self.num_encs):
                     att_w = att_w_list[idx]
@@ -823,8 +775,7 @@
                         yseq, ctc_state[idx], part_ids, att_w_
                     )
                     local_scores = (
-                            local_scores
-                            + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
+                        local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                     )
 
             local_scores = local_scores.view(batch, beam, self.odim)
@@ -839,9 +790,7 @@
 
             # global pruning
             accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
-            accum_odim_ids = (
-                torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
-            )
+            accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
             accum_padded_beam_ids = (
                 (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
             )
@@ -867,24 +816,16 @@
                     ]
                 else:
                     # handle the case of location_recurrent when return is a tuple
-                    _a_prev_ = torch.index_select(
-                        att_w_list[idx][0].view(n_bb, -1), 0, vidx
-                    )
-                    _h_prev_ = torch.index_select(
-                        att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
-                    )
-                    _c_prev_ = torch.index_select(
-                        att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
-                    )
+                    _a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx)
+                    _h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx)
+                    _c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx)
                     _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
                 a_prev.append(_a_prev)
             z_prev = [
-                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
-                for li in range(self.dlayers)
+                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
             ]
             c_prev = [
-                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
-                for li in range(self.dlayers)
+                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
             ]
 
             # pick ended hyps
@@ -900,9 +841,7 @@
                         _vscore = None
                         if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                             yk = y_prev[k][:]
-                            if len(yk) <= min(
-                                    hlens[idx][samp_i] for idx in range(self.num_encs)
-                            ):
+                            if len(yk) <= min(hlens[idx][samp_i] for idx in range(self.num_encs)):
                                 _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                         elif i == maxlen - 1:
                             yk = yseq[k][:]
@@ -910,9 +849,7 @@
                         if _vscore:
                             yk.append(self.eos)
                             if rnnlm:
-                                _vscore += recog_args.lm_weight * rnnlm.final(
-                                    rnnlm_state, index=k
-                                )
+                                _vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k)
                             _score = _vscore.data.cpu().numpy()
                             ended_hyps[samp_i].append(
                                 {"yseq": yk, "vscore": _vscore, "score": _score}
@@ -938,9 +875,7 @@
 
         torch.cuda.empty_cache()
 
-        dummy_hyps = [
-            {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
-        ]
+        dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
         ended_hyps = [
             ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
             for samp_i in six.moves.range(batch)
@@ -952,7 +887,7 @@
 
         nbest_hyps = [
             sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
-            : min(len(ended_hyps[samp_i]), recog_args.nbest)
+                : min(len(ended_hyps[samp_i]), recog_args.nbest)
             ]
             for samp_i in six.moves.range(batch)
         ]
@@ -1168,13 +1103,9 @@
                 state["a_prev"][self.num_encs],
             )
         ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
-        z_list, c_list = self.rnn_forward(
-            ey, z_list, c_list, state["z_prev"], state["c_prev"]
-        )
+        z_list, c_list = self.rnn_forward(ey, z_list, c_list, state["z_prev"], state["c_prev"])
         if self.context_residual:
-            logits = self.output(
-                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
-            )
+            logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
         else:
             logits = self.output(self.dropout_dec[-1](z_list[-1]))
         logp = F.log_softmax(logits, dim=1).squeeze(0)

--
Gitblit v1.9.1