liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/uniasr/beam_search.py
@@ -33,7 +33,6 @@
        )._asdict()
class BeamSearchScama(torch.nn.Module):
    """Beam search implementation."""
@@ -151,7 +150,8 @@
        return torch.cat((xs, x))
    def score_full(
        self, hyp: Hypothesis,
        self,
        hyp: Hypothesis,
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
@@ -173,7 +173,9 @@
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
            scores[k], states[k] = d.score(
                hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
            )
        return scores, states
    def score_partial(
@@ -283,7 +285,8 @@
        return new_states
    def search(
        self, running_hyps: List[Hypothesis],
        self,
        running_hyps: List[Hypothesis],
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
@@ -303,7 +306,9 @@
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
            scores, states = self.score_full(
                hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
            )
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
@@ -327,9 +332,7 @@
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )
@@ -341,7 +344,8 @@
        return best_hyps
    def forward(
        self, x: torch.Tensor,
        self,
        x: torch.Tensor,
        scama_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        maxlenratio: float = 0.0,
@@ -386,20 +390,26 @@
            mask_enc = None
            if scama_mask is not None:
                token_num_predictor = scama_mask.size(1)
                token_id_slice = min(i, token_num_predictor-1)
                mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
                token_id_slice = min(i, token_num_predictor - 1)
                mask_enc = scama_mask[:, token_id_slice : token_id_slice + 1, :]
                # if mask_enc.size(1) == 0:
                #     mask_enc = scama_mask[:, -2:-1, :]
                #     # mask_enc = torch.zeros_like(mask_enc)
            pre_acoustic_embeds_cur = None
            if pre_acoustic_embeds is not None:
                b, t, d = pre_acoustic_embeds.size()
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
                    device=pre_acoustic_embeds.device
                )
                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                token_id_slice = min(i, t)
                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
                pre_acoustic_embeds_cur = pre_acoustic_embeds[
                    :, token_id_slice : token_id_slice + 1, :
                ]
            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
            best = self.search(
                running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur
            )
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
@@ -416,8 +426,7 @@
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
                "there is no N-best results, perform recognition " "again with smaller minlenratio."
            )
            return (
                []
@@ -431,17 +440,13 @@
            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
                "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
            )
        return nbest_hyps
@@ -469,15 +474,13 @@
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
                "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
            ]
        # add ended hypotheses to a final list, and removed them from current hypotheses