liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/scama/beam_search.py
@@ -11,7 +11,7 @@
import torch
from funasr.metrics import end_detect
from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
@@ -31,7 +31,6 @@
            score=float(self.score),
            scores={k: float(v) for k, v in self.scores.items()},
        )._asdict()
class BeamSearchScama(torch.nn.Module):
@@ -151,7 +150,8 @@
        return torch.cat((xs, x))
    def score_full(
        self, hyp: Hypothesis,
        self,
        hyp: Hypothesis,
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
@@ -173,7 +173,9 @@
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
            scores[k], states[k] = d.score(
                hyp.yseq, hyp.states[k], x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
            )
        return scores, states
    def score_partial(
@@ -283,7 +285,8 @@
        return new_states
    def search(
        self, running_hyps: List[Hypothesis],
        self,
        running_hyps: List[Hypothesis],
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
@@ -303,7 +306,9 @@
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds)
            scores, states = self.score_full(
                hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds
            )
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
@@ -327,9 +332,7 @@
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )
@@ -341,7 +344,8 @@
        return best_hyps
    def forward(
        self, x: torch.Tensor,
        self,
        x: torch.Tensor,
        scama_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        maxlenratio: float = 0.0,
@@ -386,20 +390,26 @@
            mask_enc = None
            if scama_mask is not None:
                token_num_predictor = scama_mask.size(1)
                token_id_slice = min(i, token_num_predictor-1)
                mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
                token_id_slice = min(i, token_num_predictor - 1)
                mask_enc = scama_mask[:, token_id_slice : token_id_slice + 1, :]
                # if mask_enc.size(1) == 0:
                #     mask_enc = scama_mask[:, -2:-1, :]
                #     # mask_enc = torch.zeros_like(mask_enc)
            pre_acoustic_embeds_cur = None
            if pre_acoustic_embeds is not None:
                b, t, d = pre_acoustic_embeds.size()
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(device=pre_acoustic_embeds.device)
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
                    device=pre_acoustic_embeds.device
                )
                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                token_id_slice = min(i, t)
                pre_acoustic_embeds_cur = pre_acoustic_embeds[:, token_id_slice:token_id_slice+1, :]
                pre_acoustic_embeds_cur = pre_acoustic_embeds[
                    :, token_id_slice : token_id_slice + 1, :
                ]
            best = self.search(running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur)
            best = self.search(
                running_hyps, x, x_mask=mask_enc, pre_acoustic_embeds=pre_acoustic_embeds_cur
            )
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
@@ -416,8 +426,7 @@
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
                "there is no N-best results, perform recognition " "again with smaller minlenratio."
            )
            return (
                []
@@ -431,17 +440,13 @@
            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x] for x in best.yseq[1:-1]])
                + "\n"
                "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
            )
        return nbest_hyps
@@ -469,15 +474,492 @@
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: "
                + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
                "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
            ]
        # add ended hypotheses to a final list, and removed them from current hypotheses
        # (this will be a problem, number of hyps < beam)
        remained_hyps = []
        for hyp in running_hyps:
            if hyp.yseq[-1] == self.eos:
                # e.g., Word LM needs to add final <eos> score
                for k, d in chain(self.full_scorers.items(), self.part_scorers.items()):
                    s = d.final_score(hyp.states[k])
                    hyp.scores[k] += s
                    hyp = hyp._replace(score=hyp.score + self.weights[k] * s)
                ended_hyps.append(hyp)
            else:
                remained_hyps.append(hyp)
        return remained_hyps
class BeamSearchScamaStreaming(torch.nn.Module):
    """Beam search implementation."""
    def __init__(
        self,
        scorers: Dict[str, ScorerInterface],
        weights: Dict[str, float],
        beam_size: int,
        vocab_size: int,
        sos: int,
        eos: int,
        token_list: List[str] = None,
        pre_beam_ratio: float = 1.5,
        pre_beam_score_key: str = None,
    ):
        """Initialize beam search.
        Args:
            scorers (dict[str, ScorerInterface]): Dict of decoder modules
                e.g., Decoder, CTCPrefixScorer, LM
                The scorer will be ignored if it is `None`
            weights (dict[str, float]): Dict of weights for each scorers
                The scorer will be ignored if its weight is 0
            beam_size (int): The number of hypotheses kept during search
            vocab_size (int): The number of vocabulary
            sos (int): Start of sequence id
            eos (int): End of sequence id
            token_list (list[str]): List of tokens for debug log
            pre_beam_score_key (str): key of scores to perform pre-beam search
            pre_beam_ratio (float): beam size in the pre-beam search
                will be `int(pre_beam_ratio * beam_size)`
        """
        super().__init__()
        # set scorers
        self.weights = weights
        self.scorers = dict()
        self.full_scorers = dict()
        self.part_scorers = dict()
        # this module dict is required for recursive cast
        # `self.to(device, dtype)` in `recog.py`
        self.nn_dict = torch.nn.ModuleDict()
        for k, v in scorers.items():
            w = weights.get(k, 0)
            if w == 0 or v is None:
                continue
            assert isinstance(
                v, ScorerInterface
            ), f"{k} ({type(v)}) does not implement ScorerInterface"
            self.scorers[k] = v
            if isinstance(v, PartialScorerInterface):
                self.part_scorers[k] = v
            else:
                self.full_scorers[k] = v
            if isinstance(v, torch.nn.Module):
                self.nn_dict[k] = v
        # set configurations
        self.sos = sos
        self.eos = eos
        self.token_list = token_list
        self.pre_beam_size = int(pre_beam_ratio * beam_size)
        self.beam_size = beam_size
        self.n_vocab = vocab_size
        if (
            pre_beam_score_key is not None
            and pre_beam_score_key != "full"
            and pre_beam_score_key not in self.full_scorers
        ):
            raise KeyError(f"{pre_beam_score_key} is not found in {self.full_scorers}")
        self.pre_beam_score_key = pre_beam_score_key
        self.do_pre_beam = (
            self.pre_beam_score_key is not None
            and self.pre_beam_size < self.n_vocab
            and len(self.part_scorers) > 0
        )
    def init_hyp(self, x) -> List[Hypothesis]:
        """Get an initial hypothesis data.
        Args:
            x (torch.Tensor): The encoder output feature
        Returns:
            Hypothesis: The initial hypothesis.
        """
        init_states = dict()
        init_scores = dict()
        for k, d in self.scorers.items():
            init_states[k] = d.init_state(x)
            init_scores[k] = 0.0
        return [
            Hypothesis(
                score=0.0,
                scores=init_scores,
                states=init_states,
                yseq=torch.tensor([self.sos], device=x.device),
            )
        ]
    @staticmethod
    def append_token(xs: torch.Tensor, x: int) -> torch.Tensor:
        """Append new token to prefix tokens.
        Args:
            xs (torch.Tensor): The prefix token
            x (int): The new token to append
        Returns:
            torch.Tensor: New tensor contains: xs + [x] with xs.dtype and xs.device
        """
        x = torch.tensor([x], dtype=xs.dtype, device=xs.device)
        return torch.cat((xs, x))
    def score_full(
        self,
        hyp: Hypothesis,
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: dict = {},
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.full_scorers`.
        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            x (torch.Tensor): Corresponding input feature
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.full_scorers`
                and tensor score values of shape: `(self.n_vocab,)`,
                and state dict that has string keys
                and state values of `self.full_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.full_scorers.items():
            scores[k], states[k] = d.score(
                hyp.yseq,
                hyp.states[k],
                x,
                x_mask=x_mask,
                pre_acoustic_embeds=pre_acoustic_embeds,
                cache=cache,
            )
        return scores, states
    def score_partial(
        self, hyp: Hypothesis, ids: torch.Tensor, x: torch.Tensor
    ) -> Tuple[Dict[str, torch.Tensor], Dict[str, Any]]:
        """Score new hypothesis by `self.part_scorers`.
        Args:
            hyp (Hypothesis): Hypothesis with prefix tokens to score
            ids (torch.Tensor): 1D tensor of new partial tokens to score
            x (torch.Tensor): Corresponding input feature
        Returns:
            Tuple[Dict[str, torch.Tensor], Dict[str, Any]]: Tuple of
                score dict of `hyp` that has string keys of `self.part_scorers`
                and tensor score values of shape: `(len(ids),)`,
                and state dict that has string keys
                and state values of `self.part_scorers`
        """
        scores = dict()
        states = dict()
        for k, d in self.part_scorers.items():
            scores[k], states[k] = d.score_partial(hyp.yseq, ids, hyp.states[k], x)
        return scores, states
    def beam(
        self, weighted_scores: torch.Tensor, ids: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute topk full token ids and partial token ids.
        Args:
            weighted_scores (torch.Tensor): The weighted sum scores for each tokens.
            Its shape is `(self.n_vocab,)`.
            ids (torch.Tensor): The partial token ids to compute topk
        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                The topk full token ids and partial token ids.
                Their shapes are `(self.beam_size,)`
        """
        # no pre beam performed
        if weighted_scores.size(0) == ids.size(0):
            top_ids = weighted_scores.topk(self.beam_size)[1]
            return top_ids, top_ids
        # mask pruned in pre-beam not to select in topk
        tmp = weighted_scores[ids]
        weighted_scores[:] = -float("inf")
        weighted_scores[ids] = tmp
        top_ids = weighted_scores.topk(self.beam_size)[1]
        local_ids = weighted_scores[ids].topk(self.beam_size)[1]
        return top_ids, local_ids
    @staticmethod
    def merge_scores(
        prev_scores: Dict[str, float],
        next_full_scores: Dict[str, torch.Tensor],
        full_idx: int,
        next_part_scores: Dict[str, torch.Tensor],
        part_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Merge scores for new hypothesis.
        Args:
            prev_scores (Dict[str, float]):
                The previous hypothesis scores by `self.scorers`
            next_full_scores (Dict[str, torch.Tensor]): scores by `self.full_scorers`
            full_idx (int): The next token id for `next_full_scores`
            next_part_scores (Dict[str, torch.Tensor]):
                scores of partial tokens by `self.part_scorers`
            part_idx (int): The new token id for `next_part_scores`
        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are scalar tensors by the scorers.
        """
        new_scores = dict()
        for k, v in next_full_scores.items():
            new_scores[k] = prev_scores[k] + v[full_idx]
        for k, v in next_part_scores.items():
            new_scores[k] = prev_scores[k] + v[part_idx]
        return new_scores
    def merge_states(self, states: Any, part_states: Any, part_idx: int) -> Any:
        """Merge states for new hypothesis.
        Args:
            states: states of `self.full_scorers`
            part_states: states of `self.part_scorers`
            part_idx (int): The new token id for `part_scores`
        Returns:
            Dict[str, torch.Tensor]: The new score dict.
                Its keys are names of `self.full_scorers` and `self.part_scorers`.
                Its values are states of the scorers.
        """
        new_states = dict()
        for k, v in states.items():
            new_states[k] = v
        for k, d in self.part_scorers.items():
            new_states[k] = d.select_state(part_states[k], part_idx)
        return new_states
    def search(
        self,
        running_hyps: List[Hypothesis],
        x: torch.Tensor,
        x_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        cache: dict = {},
    ) -> List[Hypothesis]:
        """Search new tokens for running hypotheses and encoded speech x.
        Args:
            running_hyps (List[Hypothesis]): Running hypotheses on beam
            x (torch.Tensor): Encoded speech feature (T, D)
        Returns:
            List[Hypotheses]: Best sorted hypotheses
        """
        best_hyps = []
        part_ids = torch.arange(self.n_vocab, device=x.device)  # no pre-beam
        for hyp in running_hyps:
            # scoring
            weighted_scores = torch.zeros(self.n_vocab, dtype=x.dtype, device=x.device)
            scores, states = self.score_full(
                hyp, x, x_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds, cache=cache
            )
            for k in self.full_scorers:
                weighted_scores += self.weights[k] * scores[k]
            # partial scoring
            if self.do_pre_beam:
                pre_beam_scores = (
                    weighted_scores
                    if self.pre_beam_score_key == "full"
                    else scores[self.pre_beam_score_key]
                )
                part_ids = torch.topk(pre_beam_scores, self.pre_beam_size)[1]
            part_scores, part_states = self.score_partial(hyp, part_ids, x)
            for k in self.part_scorers:
                weighted_scores[part_ids] += self.weights[k] * part_scores[k]
            # add previous hyp score
            weighted_scores += hyp.score
            # update hyps
            for j, part_j in zip(*self.beam(weighted_scores, part_ids)):
                # will be (2 x beam at most)
                best_hyps.append(
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )
            # sort and prune 2 x beam -> beam
            best_hyps = sorted(best_hyps, key=lambda x: x.score, reverse=True)[
                : min(len(best_hyps), self.beam_size)
            ]
        return best_hyps
    def forward(
        self,
        x: torch.Tensor,
        scama_mask: torch.Tensor = None,
        pre_acoustic_embeds: torch.Tensor = None,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
        maxlen: int = None,
        minlen: int = 0,
        cache: dict = {},
    ) -> List[Hypothesis]:
        """Perform beam search.
        Args:
            x (torch.Tensor): Encoded speech feature (T, D)
            maxlenratio (float): Input length ratio to obtain max output length.
                If maxlenratio=0.0 (default), it uses a end-detect function
                to automatically find maximum hypothesis lengths
                If maxlenratio<0.0, its absolute value is interpreted
                as a constant max output length.
            minlenratio (float): Input length ratio to obtain min output length.
        Returns:
            list[Hypothesis]: N-best decoding results
        """
        if maxlen is None:
            # set length bounds
            if maxlenratio == 0:
                maxlen = x.shape[0]
            elif maxlenratio < 0:
                maxlen = -1 * int(maxlenratio)
            else:
                maxlen = max(1, int(maxlenratio * x.size(0)))
            minlen = int(minlenratio * x.size(0))
        logging.info("decoder input length: " + str(x.shape[0]))
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))
        # main loop of prefix search
        # running_hyps = self.init_hyp(x)
        running_hyps = cache["running_hyps"]
        ended_hyps = []
        for i in range(maxlen):
            logging.debug("position " + str(i))
            mask_enc = None
            # if scama_mask is not None:
            #     token_num_predictor = scama_mask.size(1)
            #     token_id_slice = min(i, token_num_predictor-1)
            #     mask_enc = scama_mask[:, token_id_slice:token_id_slice+1, :]
            #     # if mask_enc.size(1) == 0:
            #     #     mask_enc = scama_mask[:, -2:-1, :]
            #     #     # mask_enc = torch.zeros_like(mask_enc)
            pre_acoustic_embeds_cur = None
            if pre_acoustic_embeds is not None:
                b, t, d = pre_acoustic_embeds.size()
                pad = torch.zeros((b, 1, d), dtype=pre_acoustic_embeds.dtype).to(
                    device=pre_acoustic_embeds.device
                )
                pre_acoustic_embeds = torch.cat((pre_acoustic_embeds, pad), dim=1)
                token_id_slice = min(i, t)
                pre_acoustic_embeds_cur = pre_acoustic_embeds[
                    :, token_id_slice : token_id_slice + 1, :
                ]
            best = self.search(
                running_hyps,
                x,
                x_mask=mask_enc,
                pre_acoustic_embeds=pre_acoustic_embeds_cur,
                cache=cache["decoder"],
            )
            # post process of one iteration
            running_hyps = self.post_process(i, maxlen, maxlenratio, best, ended_hyps)
            # end detection
            if maxlenratio == 0.0 and end_detect([h.asdict() for h in ended_hyps], i):
                logging.info(f"end detected at {i}")
                break
            if len(running_hyps) == 0:
                logging.info("no hypothesis. Finish decoding.")
                break
            else:
                logging.debug(f"remained hypotheses: {len(running_hyps)}")
        nbest_hyps = sorted(ended_hyps, key=lambda x: x.score, reverse=True)
        # check the number of hypotheses reaching to eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition " "again with smaller minlenratio."
            )
            return (
                []
                if minlenratio < 0.1
                else self.forward(x, maxlenratio, max(0.0, minlenratio - 0.1))
            )
        # report the best result
        for x in nbest_hyps:
            yseq = "".join([self.token_list[x] for x in x.yseq])
            logging.debug("nbest: y: {}, yseq: {}, score: {}".format(x.yseq, yseq, x.score))
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: " + "".join([self.token_list[x] for x in best.yseq[1:-1]]) + "\n"
            )
        return nbest_hyps
    def post_process(
        self,
        i: int,
        maxlen: int,
        maxlenratio: float,
        running_hyps: List[Hypothesis],
        ended_hyps: List[Hypothesis],
    ) -> List[Hypothesis]:
        """Perform post-processing of beam search iterations.
        Args:
            i (int): The length of hypothesis tokens.
            maxlen (int): The maximum length of tokens in beam search.
            maxlenratio (int): The maximum length ratio in beam search.
            running_hyps (List[Hypothesis]): The running hypotheses in beam search.
            ended_hyps (List[Hypothesis]): The ended hypotheses in beam search.
        Returns:
            List[Hypothesis]: The new running hypotheses.
        """
        logging.debug(f"the number of running hypotheses: {len(running_hyps)}")
        if self.token_list is not None:
            logging.debug(
                "best hypo: " + "".join([self.token_list[x] for x in running_hyps[0].yseq[1:]])
            )
        # add eos in the final loop to avoid that there are no ended hyps
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
            ]
        # add ended hypotheses to a final list, and removed them from current hypotheses