liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/language_model/rnn/decoders.py
@@ -1,4 +1,5 @@
"""RNN decoder module."""
import logging
import math
import random
@@ -15,7 +16,7 @@
 from funasr.metrics import end_detect
 from funasr.models.transformer.utils.nets_utils import mask_by_length
 from funasr.models.transformer.utils.nets_utils import pad_list
-from funasr.models.transformer.utils.nets_utils import th_accuracy
+from funasr.metrics.compute_acc import th_accuracy
 from funasr.models.transformer.utils.nets_utils import to_device
 from funasr.models.language_model.rnn.attentions import att_to_numpy
@@ -45,24 +46,24 @@
    """
    def __init__(
            self,
            eprojs,
            odim,
            dtype,
            dlayers,
            dunits,
            sos,
            eos,
            att,
            verbose=0,
            char_list=None,
            labeldist=None,
            lsm_weight=0.0,
            sampling_probability=0.0,
            dropout=0.0,
            context_residual=False,
            replace_sos=False,
            num_encs=1,
        self,
        eprojs,
        odim,
        dtype,
        dlayers,
        dunits,
        sos,
        eos,
        att,
        verbose=0,
        char_list=None,
        labeldist=None,
        lsm_weight=0.0,
        sampling_probability=0.0,
        dropout=0.0,
        context_residual=False,
        replace_sos=False,
        num_encs=1,
    ):
        torch.nn.Module.__init__(self)
@@ -76,16 +77,20 @@
         self.decoder = torch.nn.ModuleList()
         self.dropout_dec = torch.nn.ModuleList()
         self.decoder += [
-            torch.nn.LSTMCell(dunits + eprojs, dunits)
-            if self.dtype == "lstm"
-            else torch.nn.GRUCell(dunits + eprojs, dunits)
+            (
+                torch.nn.LSTMCell(dunits + eprojs, dunits)
+                if self.dtype == "lstm"
+                else torch.nn.GRUCell(dunits + eprojs, dunits)
+            )
         ]
         self.dropout_dec += [torch.nn.Dropout(p=dropout)]
         for _ in six.moves.range(1, self.dlayers):
             self.decoder += [
-                torch.nn.LSTMCell(dunits, dunits)
-                if self.dtype == "lstm"
-                else torch.nn.GRUCell(dunits, dunits)
+                (
+                    torch.nn.LSTMCell(dunits, dunits)
+                    if self.dtype == "lstm"
+                    else torch.nn.GRUCell(dunits, dunits)
+                )
             ]
             self.dropout_dec += [torch.nn.Dropout(p=dropout)]
             # NOTE: dropout is applied only for the vertical connections
@@ -131,9 +136,7 @@
         else:
             z_list[0] = self.decoder[0](ey, z_prev[0])
             for i in six.moves.range(1, self.dlayers):
-                z_list[i] = self.decoder[i](
-                    self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i]
-                )
+                z_list[i] = self.decoder[i](self.dropout_dec[i - 1](z_list[i - 1]), z_prev[i])
         return z_list, c_list
 
     def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
@@ -197,9 +200,7 @@
                 )
             )
         logging.info(
-            self.__class__.__name__
-            + " output lengths: "
-            + str([y.size(0) for y in ys_out])
+            self.__class__.__name__ + " output lengths: " + str([y.size(0) for y in ys_out])
         )
 
         # initialization
@@ -280,7 +281,7 @@
             ys_hat = y_all.view(batch, olength, -1)
             ys_true = ys_out_pad
             for (i, y_hat), y_true in zip(
-                    enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
+                enumerate(ys_hat.detach().cpu().numpy()), ys_true.detach().cpu().numpy()
             ):
                 if i == MAX_DECODER_OUTPUT:
                     break
@@ -362,9 +363,7 @@
             weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                 recog_args.weights_ctc_dec
             )  # normalize
-            logging.info(
-                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
-            )
+            logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
         else:
             weights_ctc_dec = [1.0]
 
@@ -450,9 +449,7 @@
                        hyp["a_prev"][self.num_encs],
                    )
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(
                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
                )
                z_list, c_list = self.rnn_forward(ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"])
                # get nbest local scores and their ids
                if self.context_residual:
@@ -464,9 +461,7 @@
                 local_att_scores = F.log_softmax(logits, dim=1)
                 if rnnlm:
                     rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
-                    local_scores = (
-                            local_att_scores + recog_args.lm_weight * local_lm_scores
-                    )
+                    local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores
                 else:
                     local_scores = local_att_scores
 
@@ -482,9 +477,7 @@
                         ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                             hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                         )
-                    local_scores = (1.0 - ctc_weight) * local_att_scores[
-                                                        :, local_best_ids[0]
-                                                        ]
+                    local_scores = (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]]
                     if self.num_encs == 1:
                         local_scores += ctc_weight * torch.from_numpy(
                             ctc_scores[0] - hyp["ctc_score_prev"][0]
@@ -492,24 +485,16 @@
                     else:
                         for idx in range(self.num_encs):
                             local_scores += (
-                                    ctc_weight
-                                    * weights_ctc_dec[idx]
-                                    * torch.from_numpy(
-                                ctc_scores[idx] - hyp["ctc_score_prev"][idx]
-                            )
+                                ctc_weight
+                                * weights_ctc_dec[idx]
+                                * torch.from_numpy(ctc_scores[idx] - hyp["ctc_score_prev"][idx])
                             )
                     if rnnlm:
-                        local_scores += (
-                                recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
-                        )
-                    local_best_scores, joint_best_ids = torch.topk(
-                        local_scores, beam, dim=1
-                    )
+                        local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
+                    local_best_scores, joint_best_ids = torch.topk(local_scores, beam, dim=1)
                     local_best_ids = local_best_ids[:, joint_best_ids[0]]
                 else:
-                    local_best_scores, local_best_ids = torch.topk(
-                        local_scores, beam, dim=1
-                    )
+                    local_best_scores, local_best_ids = torch.topk(local_scores, beam, dim=1)
 
                 for j in six.moves.range(beam):
                     new_hyp = {}
@@ -519,9 +504,7 @@
                     if self.num_encs == 1:
                         new_hyp["a_prev"] = att_w[:]
                     else:
-                        new_hyp["a_prev"] = [
-                            att_w_list[idx][:] for idx in range(self.num_encs + 1)
-                        ]
+                        new_hyp["a_prev"] = [att_w_list[idx][:] for idx in range(self.num_encs + 1)]
                     new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                     new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                     new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
@@ -530,27 +513,22 @@
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz[0] is not None:
                        new_hyp["ctc_state_prev"] = [
                            ctc_states[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                            ctc_states[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
                        ]
                        new_hyp["ctc_score_prev"] = [
                            ctc_scores[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                            ctc_scores[idx][joint_best_ids[0, j]] for idx in range(self.num_encs)
                        ]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)
                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]
                hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x["score"], reverse=True)[
                    :beam
                ]
            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )
            logging.debug("best hypo: " + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]]))
            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
@@ -569,9 +547,7 @@
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(hyp["rnnlm_prev"])
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)
@@ -589,21 +565,18 @@
                 break
 
             for hyp in hyps:
-                logging.debug(
-                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
-                )
+                logging.debug("hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]]))
 
             logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))
 
         nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
-                     : min(len(ended_hyps), recog_args.nbest)
-                     ]
+            : min(len(ended_hyps), recog_args.nbest)
+        ]
 
         # check number of hypotheses
         if len(nbest_hyps) == 0:
             logging.warning(
-                "there is no N-best results, "
-                "perform recognition again with smaller minlenratio."
+                "there is no N-best results, " "perform recognition again with smaller minlenratio."
             )
             # should copy because Namespace will be overwritten globally
             recog_args = Namespace(**vars(recog_args))
@@ -623,16 +596,16 @@
         return nbest_hyps
 
     def recognize_beam_batch(
-            self,
-            h,
-            hlens,
-            lpz,
-            recog_args,
-            char_list,
-            rnnlm=None,
-            normalize_score=True,
-            strm_idx=0,
-            lang_ids=None,
+        self,
+        h,
+        hlens,
+        lpz,
+        recog_args,
+        char_list,
+        rnnlm=None,
+        normalize_score=True,
+        strm_idx=0,
+        lang_ids=None,
     ):
         # to support mutiple encoder asr mode, in single encoder mode,
         # convert torch.Tensor to List of torch.Tensor
@@ -667,9 +640,7 @@
             weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                 recog_args.weights_ctc_dec
             )  # normalize
-            logging.info(
-                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
-            )
+            logging.info("ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec]))
         else:
             weights_ctc_dec = [1.0]
 
@@ -686,18 +657,10 @@
        logging.info("min output length: " + str(minlen))
        # initialization
        c_prev = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        z_prev = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        c_list = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        z_list = [
            to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)
        ]
        c_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        z_prev = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        c_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        z_list = [to_device(h[0], torch.zeros(n_bb, self.dunits)) for _ in range(self.dlayers)]
        vscores = to_device(h[0], torch.zeros(batch, beam))
        rnnlm_state = None
@@ -716,14 +679,10 @@
         if self.replace_sos and recog_args.tgt_lang:
             logging.info("<sos> index: " + str(char_list.index(recog_args.tgt_lang)))
             logging.info("<sos> mark: " + recog_args.tgt_lang)
-            yseq = [
-                [char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)
-            ]
+            yseq = [[char_list.index(recog_args.tgt_lang)] for _ in six.moves.range(n_bb)]
         elif lang_ids is not None:
             # NOTE: used for evaluation during training
-            yseq = [
-                [lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)
-            ]
+            yseq = [[lang_ids[b // recog_args.beam_size]] for b in six.moves.range(n_bb)]
         else:
             logging.info("<sos> index: " + str(self.sos))
             logging.info("<sos> mark: " + char_list[self.sos])
@@ -740,8 +699,7 @@
         ]
         exp_hlens = [exp_hlens[idx].view(-1).tolist() for idx in range(self.num_encs)]
         exp_h = [
-            h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous()
-            for idx in range(self.num_encs)
+            h[idx].unsqueeze(1).repeat(1, beam, 1, 1).contiguous() for idx in range(self.num_encs)
         ]
         exp_h = [
             exp_h[idx].view(n_bb, h[idx].size()[1], h[idx].size()[2])
@@ -750,9 +708,7 @@
         if lpz[0] is not None:
             scoring_num = min(
-                int(beam * CTC_SCORING_RATIO)
-                if att_weight > 0.0 and not lpz[0].is_cuda
-                else 0,
+                int(beam * CTC_SCORING_RATIO) if att_weight > 0.0 and not lpz[0].is_cuda else 0,
                 lpz[0].size(-1),
             )
             ctc_scorer = [
@@ -796,9 +752,7 @@
             # attention decoder
             z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_prev, c_prev)
             if self.context_residual:
-                logits = self.output(
-                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
-                )
+                logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
             else:
                 logits = self.output(self.dropout_dec[-1](z_list[-1]))
             local_scores = att_weight * F.log_softmax(logits, dim=1)
@@ -812,9 +766,7 @@
             if ctc_scorer[0]:
                 local_scores[:, 0] = self.logzero  # avoid choosing blank
                 part_ids = (
-                    torch.topk(local_scores, scoring_num, dim=-1)[1]
-                    if scoring_num > 0
-                    else None
+                    torch.topk(local_scores, scoring_num, dim=-1)[1] if scoring_num > 0 else None
                 )
                 for idx in range(self.num_encs):
                     att_w = att_w_list[idx]
@@ -823,8 +775,7 @@
                         yseq, ctc_state[idx], part_ids, att_w_
                     )
                     local_scores = (
-                            local_scores
-                            + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
+                        local_scores + ctc_weight * weights_ctc_dec[idx] * local_ctc_scores
                     )
 
             local_scores = local_scores.view(batch, beam, self.odim)
@@ -839,9 +790,7 @@
 
             # global pruning
             accum_best_scores, accum_best_ids = torch.topk(vscores, beam, 1)
-            accum_odim_ids = (
-                torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
-            )
+            accum_odim_ids = torch.fmod(accum_best_ids, self.odim).view(-1).data.cpu().tolist()
             accum_padded_beam_ids = (
                 (accum_best_ids // self.odim + pad_b).view(-1).data.cpu().tolist()
             )
@@ -867,24 +816,16 @@
                     ]
                 else:
                     # handle the case of location_recurrent when return is a tuple
-                    _a_prev_ = torch.index_select(
-                        att_w_list[idx][0].view(n_bb, -1), 0, vidx
-                    )
-                    _h_prev_ = torch.index_select(
-                        att_w_list[idx][1][0].view(n_bb, -1), 0, vidx
-                    )
-                    _c_prev_ = torch.index_select(
-                        att_w_list[idx][1][1].view(n_bb, -1), 0, vidx
-                    )
+                    _a_prev_ = torch.index_select(att_w_list[idx][0].view(n_bb, -1), 0, vidx)
+                    _h_prev_ = torch.index_select(att_w_list[idx][1][0].view(n_bb, -1), 0, vidx)
+                    _c_prev_ = torch.index_select(att_w_list[idx][1][1].view(n_bb, -1), 0, vidx)
                     _a_prev = (_a_prev_, (_h_prev_, _c_prev_))
                 a_prev.append(_a_prev)
 
             z_prev = [
-                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx)
-                for li in range(self.dlayers)
+                torch.index_select(z_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
             ]
             c_prev = [
-                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx)
-                for li in range(self.dlayers)
+                torch.index_select(c_list[li].view(n_bb, -1), 0, vidx) for li in range(self.dlayers)
             ]
             # pick ended hyps
@@ -900,9 +841,7 @@
                         _vscore = None
                         if eos_vscores[samp_i, beam_j] > thr[samp_i]:
                             yk = y_prev[k][:]
-                            if len(yk) <= min(
-                                    hlens[idx][samp_i] for idx in range(self.num_encs)
-                            ):
+                            if len(yk) <= min(hlens[idx][samp_i] for idx in range(self.num_encs)):
                                 _vscore = eos_vscores[samp_i][beam_j] + penalty_i
                         elif i == maxlen - 1:
                             yk = yseq[k][:]
@@ -910,9 +849,7 @@
                         if _vscore:
                             yk.append(self.eos)
                             if rnnlm:
-                                _vscore += recog_args.lm_weight * rnnlm.final(
-                                    rnnlm_state, index=k
-                                )
+                                _vscore += recog_args.lm_weight * rnnlm.final(rnnlm_state, index=k)
                             _score = _vscore.data.cpu().numpy()
                             ended_hyps[samp_i].append(
                                 {"yseq": yk, "vscore": _vscore, "score": _score}
@@ -938,9 +875,7 @@
         torch.cuda.empty_cache()
 
-        dummy_hyps = [
-            {"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}
-        ]
+        dummy_hyps = [{"yseq": [self.sos, self.eos], "score": np.array([-float("inf")])}]
         ended_hyps = [
             ended_hyps[samp_i] if len(ended_hyps[samp_i]) != 0 else dummy_hyps
             for samp_i in six.moves.range(batch)
@@ -952,7 +887,7 @@
         nbest_hyps = [
             sorted(ended_hyps[samp_i], key=lambda x: x["score"], reverse=True)[
-            : min(len(ended_hyps[samp_i]), recog_args.nbest)
+                : min(len(ended_hyps[samp_i]), recog_args.nbest)
             ]
             for samp_i in six.moves.range(batch)
         ]
@@ -1168,13 +1103,9 @@
                state["a_prev"][self.num_encs],
            )
        ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
        z_list, c_list = self.rnn_forward(
            ey, z_list, c_list, state["z_prev"], state["c_prev"]
        )
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, state["z_prev"], state["c_prev"])
        if self.context_residual:
            logits = self.output(
                torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
            )
            logits = self.output(torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1))
        else:
            logits = self.output(self.dropout_dec[-1](z_list[-1]))
        logp = F.log_softmax(logits, dim=1).squeeze(0)