liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/paraformer/search.py
@@ -1,17 +1,19 @@
from itertools import chain
import logging
from typing import Any
from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Tuple
from typing import Union
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import torch
import logging
from itertools import chain
from typing import Any, Dict, List, NamedTuple, Tuple, Union
from funasr.metrics import end_detect
from funasr.models.transformer.scorers.scorer_interface import PartialScorerInterface
from funasr.models.transformer.scorers.scorer_interface import ScorerInterface
from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import (
    PartialScorerInterface,
    ScorerInterface,
)
class Hypothesis(NamedTuple):
    """Hypothesis data type."""
@@ -318,9 +320,7 @@
                    Hypothesis(
                        score=weighted_scores[j],
                        yseq=self.append_token(hyp.yseq, j),
                        scores=self.merge_scores(
                            hyp.scores, scores, j, part_scores, part_j
                        ),
                        scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j),
                        states=self.merge_states(states, part_states, part_j),
                    )
                )
@@ -332,7 +332,11 @@
        return best_hyps
    def forward(
        self, x: torch.Tensor, am_scores: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0
        self,
        x: torch.Tensor,
        am_scores: torch.Tensor,
        maxlenratio: float = 0.0,
        minlenratio: float = 0.0,
    ) -> List[Hypothesis]:
        """Perform beam search.
@@ -376,8 +380,7 @@
        # check the number of hypotheses that reached eos
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, perform recognition "
                "again with smaller minlenratio."
                "there is no N-best results, perform recognition " "again with smaller minlenratio."
            )
            return (
                []
@@ -388,17 +391,13 @@
        # report the best result
        best = nbest_hyps[0]
        for k, v in best.scores.items():
            logging.info(
                f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}"
            )
            logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}")
        logging.info(f"total log probability: {best.score:.2f}")
        logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}")
        logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}")
        if self.token_list is not None:
            logging.info(
                "best hypo: "
                + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]])
                + "\n"
                "best hypo: " + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) + "\n"
            )
        return nbest_hyps
@@ -433,8 +432,7 @@
        if i == maxlen - 1:
            logging.info("adding <eos> in the last position in the loop")
            running_hyps = [
                h._replace(yseq=self.append_token(h.yseq, self.eos))
                for h in running_hyps
                h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps
            ]
        # add ended hypotheses to a final list, and remove them from current hypotheses