#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import logging
from itertools import chain
from typing import Any, Dict, List, NamedTuple, Tuple, Union

import torch

# NOTE(review): the file previously imported `end_detect` from both
# `funasr.metrics` and `funasr.metrics.common`; the latter import came last
# and therefore was the effective binding, so only it is kept here.
from funasr.metrics.common import end_detect
from funasr.models.transformer.scorers.scorer_interface import (
    PartialScorerInterface,
    ScorerInterface,
)
| | | |
| | | class Hypothesis(NamedTuple): |
| | | """Hypothesis data type.""" |
| | |
| | | Hypothesis( |
| | | score=weighted_scores[j], |
| | | yseq=self.append_token(hyp.yseq, j), |
| | | scores=self.merge_scores( |
| | | hyp.scores, scores, j, part_scores, part_j |
| | | ), |
| | | scores=self.merge_scores(hyp.scores, scores, j, part_scores, part_j), |
| | | states=self.merge_states(states, part_states, part_j), |
| | | ) |
| | | ) |
| | |
| | | return best_hyps |
| | | |
| | | def forward( |
| | | self, x: torch.Tensor, am_scores: torch.Tensor, maxlenratio: float = 0.0, minlenratio: float = 0.0 |
| | | self, |
| | | x: torch.Tensor, |
| | | am_scores: torch.Tensor, |
| | | maxlenratio: float = 0.0, |
| | | minlenratio: float = 0.0, |
| | | ) -> List[Hypothesis]: |
| | | """Perform beam search. |
| | | |
| | |
| | | # check the number of hypotheses reaching to eos |
| | | if len(nbest_hyps) == 0: |
| | | logging.warning( |
| | | "there is no N-best results, perform recognition " |
| | | "again with smaller minlenratio." |
| | | "there is no N-best results, perform recognition " "again with smaller minlenratio." |
| | | ) |
| | | return ( |
| | | [] |
| | |
| | | # report the best result |
| | | best = nbest_hyps[0] |
| | | for k, v in best.scores.items(): |
| | | logging.info( |
| | | f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}" |
| | | ) |
| | | logging.info(f"{v:6.2f} * {self.weights[k]:3} = {v * self.weights[k]:6.2f} for {k}") |
| | | logging.info(f"total log probability: {best.score:.2f}") |
| | | logging.info(f"normalized log probability: {best.score / len(best.yseq):.2f}") |
| | | logging.info(f"total number of ended hypotheses: {len(nbest_hyps)}") |
| | | if self.token_list is not None: |
| | | logging.info( |
| | | "best hypo: " |
| | | + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) |
| | | + "\n" |
| | | "best hypo: " + "".join([self.token_list[x.item()] for x in best.yseq[1:-1]]) + "\n" |
| | | ) |
| | | return nbest_hyps |
| | | |
| | |
| | | if i == maxlen - 1: |
| | | logging.info("adding <eos> in the last position in the loop") |
| | | running_hyps = [ |
| | | h._replace(yseq=self.append_token(h.yseq, self.eos)) |
| | | for h in running_hyps |
| | | h._replace(yseq=self.append_token(h.yseq, self.eos)) for h in running_hyps |
| | | ] |
| | | |
| | | # add ended hypotheses to a final list, and removed them from current hypotheses |